Repository: InternLM/lmdeploy Branch: main Commit: 764f35a85b0e Files: 1274 Total size: 7.7 MB Directory structure: gitextract_4p86pot8/ ├── .clang-format ├── .claude/ │ └── skills/ │ ├── check-env/ │ │ └── SKILL.md │ ├── code-navigation/ │ │ └── SKILL.md │ ├── resolve-review/ │ │ └── SKILL.md │ ├── submit-pr/ │ │ └── SKILL.md │ └── support-new-model/ │ └── SKILL.md ├── .github/ │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE/ │ │ ├── 1-bug-report.yml │ │ ├── 2-feature-request.yml │ │ └── 3-documentation.yml │ ├── pull_request_template.md │ ├── release.yml │ ├── scripts/ │ │ ├── action_tools.py │ │ ├── check_lmdeploy.py │ │ ├── doc_link_checker.py │ │ ├── eval_base_config.py │ │ ├── eval_chat_config.py │ │ ├── eval_regression_base_models.py │ │ ├── eval_regression_chat_models.py │ │ ├── eval_stable_object_config.py │ │ └── eval_stable_subject_config.py │ └── workflows/ │ ├── api_eval.yml │ ├── benchmark.yml │ ├── cuda12.8_whl_release.yml │ ├── daily_ete_test.yml │ ├── daily_ete_test_3090.yml │ ├── daily_ete_test_5080.yml │ ├── docker.yml │ ├── docker_dev.yml │ ├── evaluate.yml │ ├── lint.yml │ ├── linux_x64_gpu.yml │ ├── mllm_api_eval.yml │ ├── pr_ete_test.yml │ ├── pypi.yml │ ├── stable.yml │ ├── stale.yml │ ├── test_docker.yml │ ├── unit_test.yml │ └── windows_x64_gpu.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── CLAUDE.md ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── README_ja.md ├── README_zh-CN.md ├── autotest/ │ ├── benchmark/ │ │ ├── test_apiserver_performance.py │ │ ├── test_longtext_performance.py │ │ ├── test_mllm_apiserver_performance.py │ │ ├── test_prefixcache_performance.py │ │ └── test_throughput_performance.py │ ├── chat_prompt_case.yml │ ├── config.yml │ ├── config_3090.yml │ ├── config_3090_legacy.yml │ ├── config_5080.yml │ ├── config_5080_legacy.yml │ ├── config_ascend.yml │ ├── config_h.yml │ ├── config_h800.yml │ ├── config_h_legacy.yml │ ├── config_legacy.yml │ ├── config_test.yml │ ├── config_testascend.yml │ 
├── conftest.py │ ├── evaluate/ │ │ ├── eval_config_chat.py │ │ ├── test_api_evaluate.py │ │ └── test_mllm_api_evaluate.py │ ├── interface/ │ │ ├── pipeline/ │ │ │ ├── test_pipeline_func.py │ │ │ └── test_pipeline_longtext_func.py │ │ └── restful/ │ │ ├── test_restful_chat_completions_v1.py │ │ ├── test_restful_completions_v1.py │ │ └── test_restful_generate.py │ ├── prompt_case.yml │ ├── pytest.ini │ ├── template.json │ ├── toolchain/ │ │ └── test_lagent.py │ ├── tools/ │ │ ├── chat/ │ │ │ ├── test_command_chat_hf_pytorch.py │ │ │ └── test_command_chat_hf_turbomind.py │ │ ├── common_case_config.py │ │ ├── pipeline/ │ │ │ ├── llm_case.py │ │ │ ├── mllm_case.py │ │ │ ├── test_pipeline_chat_pytorch_llm.py │ │ │ ├── test_pipeline_chat_pytorch_mllm.py │ │ │ ├── test_pipeline_chat_turbomind_llm.py │ │ │ └── test_pipeline_chat_turbomind_mllm.py │ │ ├── quantization/ │ │ │ ├── test_quantization_awq.py │ │ │ └── test_quantization_w8a8.py │ │ └── restful/ │ │ ├── test_restful_chat_hf_pytorch_llm.py │ │ ├── test_restful_chat_hf_pytorch_mllm.py │ │ ├── test_restful_chat_hf_turbomind_llm.py │ │ └── test_restful_chat_hf_turbomind_mllm.py │ └── utils/ │ ├── benchmark_utils.py │ ├── common_utils.py │ ├── config_utils.py │ ├── constant.py │ ├── evaluate_utils.py │ ├── get_run_config.py │ ├── mp_log_utils.py │ ├── pipeline_chat.py │ ├── proxy_distributed_utils.py │ ├── quantization_utils.py │ ├── ray_distributed_utils.py │ ├── restful_return_check.py │ ├── rule_condition_assert.py │ ├── run_client_chat.py │ ├── run_restful_chat.py │ └── toolkit.py ├── benchmark/ │ ├── README.md │ ├── benchmark_decode.py │ ├── benchmark_pipeline.py │ ├── benchmark_serving.py │ ├── benchmark_throughput.py │ ├── lmdeploy.yml │ ├── profile_pipeline_api.py │ ├── profile_restful_api.py │ └── profile_throughput.py ├── builder/ │ ├── manywheel/ │ │ ├── Dockerfile_2014 │ │ ├── README.md │ │ ├── build_all_lmdeploy_builders.sh │ │ ├── build_all_wheel.sh │ │ ├── build_lmdeploy_builder.sh │ │ ├── build_wheel.sh 
│ │ ├── entrypoint_build.sh │ │ └── scripts/ │ │ ├── install_conda.sh │ │ ├── install_cuda.sh │ │ └── install_openmpi.sh │ └── windows/ │ ├── README.md │ ├── generate.ps1 │ └── setup_cuda.ps1 ├── cmake/ │ ├── Modules/ │ │ └── FindNCCL.cmake │ ├── TritonTurboMindBackendConfig.cmake.in │ ├── TurboMindConfig.cmake.in │ └── yaml-cpp_cmake_policy.patch ├── debug.sh ├── docker/ │ ├── Dockerfile │ ├── Dockerfile.jetson │ ├── Dockerfile_ascend_a2_300i │ ├── Dockerfile_ascend_a3 │ ├── Dockerfile_dev │ ├── InternVL_Dockerfile │ ├── Qwen2VL_Dockerfile │ ├── build.sh │ ├── install.sh │ └── prepare_wheel.sh ├── docs/ │ ├── en/ │ │ ├── .readthedocs.yaml │ │ ├── Makefile │ │ ├── _static/ │ │ │ └── css/ │ │ │ └── readthedocs.css │ │ ├── advance/ │ │ │ ├── chat_template.md │ │ │ ├── context_parallel.md │ │ │ ├── debug_turbomind.md │ │ │ ├── long_context.md │ │ │ ├── metrics.md │ │ │ ├── pytorch_multinodes.md │ │ │ ├── pytorch_multithread.md │ │ │ ├── pytorch_new_model.md │ │ │ ├── pytorch_profiling.md │ │ │ ├── spec_decoding.md │ │ │ ├── structed_output.md │ │ │ └── update_weights.md │ │ ├── api/ │ │ │ ├── cli.rst │ │ │ ├── openapi.rst │ │ │ └── pipeline.rst │ │ ├── benchmark/ │ │ │ ├── a100_fp16.md │ │ │ ├── benchmark.md │ │ │ ├── evaluate_with_opencompass.md │ │ │ └── evaluate_with_vlmevalkit.md │ │ ├── conf.py │ │ ├── faq.md │ │ ├── get_started/ │ │ │ ├── ascend/ │ │ │ │ └── get_started.md │ │ │ ├── camb/ │ │ │ │ └── get_started.md │ │ │ ├── get_started.md │ │ │ ├── index.rst │ │ │ ├── installation.md │ │ │ └── maca/ │ │ │ └── get_started.md │ │ ├── index.rst │ │ ├── inference/ │ │ │ ├── load_hf.md │ │ │ ├── pytorch.md │ │ │ ├── turbomind.md │ │ │ └── turbomind_config.md │ │ ├── llm/ │ │ │ ├── api_server.md │ │ │ ├── api_server_lora.md │ │ │ ├── api_server_reasoning.md │ │ │ ├── api_server_tools.md │ │ │ ├── codellama.md │ │ │ ├── pipeline.md │ │ │ └── proxy_server.md │ │ ├── make.bat │ │ ├── multi_modal/ │ │ │ ├── api_server_vl.md │ │ │ ├── cogvlm.md │ │ │ ├── deepseek_vl2.md │ 
│ │ ├── gemma3.md │ │ │ ├── index.rst │ │ │ ├── internvl.md │ │ │ ├── llava.md │ │ │ ├── minicpmv.md │ │ │ ├── molmo.md │ │ │ ├── phi3.md │ │ │ ├── qwen2_5_vl.md │ │ │ ├── qwen2_vl.md │ │ │ ├── vl_pipeline.md │ │ │ └── xcomposer2d5.md │ │ ├── quantization/ │ │ │ ├── kv_quant.md │ │ │ ├── llm_compressor.md │ │ │ ├── w4a16.md │ │ │ └── w8a8.md │ │ └── supported_models/ │ │ ├── reward_models.md │ │ └── supported_models.md │ └── zh_cn/ │ ├── .readthedocs.yaml │ ├── Makefile │ ├── _static/ │ │ └── css/ │ │ └── readthedocs.css │ ├── advance/ │ │ ├── chat_template.md │ │ ├── context_parallel.md │ │ ├── debug_turbomind.md │ │ ├── long_context.md │ │ ├── metrics.md │ │ ├── pytorch_multinodes.md │ │ ├── pytorch_multithread.md │ │ ├── pytorch_new_model.md │ │ ├── pytorch_profiling.md │ │ ├── spec_decoding.md │ │ ├── structed_output.md │ │ └── update_weights.md │ ├── api/ │ │ ├── cli.rst │ │ ├── openapi.rst │ │ └── pipeline.rst │ ├── benchmark/ │ │ ├── benchmark.md │ │ ├── evaluate_with_opencompass.md │ │ └── evaluate_with_vlmevalkit.md │ ├── conf.py │ ├── faq.md │ ├── get_started/ │ │ ├── ascend/ │ │ │ └── get_started.md │ │ ├── camb/ │ │ │ └── get_started.md │ │ ├── get_started.md │ │ ├── index.rst │ │ ├── installation.md │ │ └── maca/ │ │ └── get_started.md │ ├── index.rst │ ├── inference/ │ │ ├── load_hf.md │ │ ├── pytorch.md │ │ ├── turbomind.md │ │ └── turbomind_config.md │ ├── llm/ │ │ ├── api_server.md │ │ ├── api_server_lora.md │ │ ├── api_server_reasoning.md │ │ ├── api_server_tools.md │ │ ├── codellama.md │ │ ├── pipeline.md │ │ └── proxy_server.md │ ├── make.bat │ ├── multi_modal/ │ │ ├── api_server_vl.md │ │ ├── cogvlm.md │ │ ├── deepseek_vl2.md │ │ ├── gemma3.md │ │ ├── index.rst │ │ ├── internvl.md │ │ ├── llava.md │ │ ├── minicpmv.md │ │ ├── molmo.md │ │ ├── phi3.md │ │ ├── qwen2_5_vl.md │ │ ├── qwen2_vl.md │ │ ├── vl_pipeline.md │ │ └── xcomposer2d5.md │ ├── quantization/ │ │ ├── kv_quant.md │ │ ├── llm_compressor.md │ │ ├── w4a16.md │ │ └── w8a8.md │ └── 
supported_models/ │ ├── reward_models.md │ └── supported_models.md ├── eval/ │ ├── config.py │ └── eval.py ├── examples/ │ └── lite/ │ ├── qwen3_30b_a3b_awq.py │ └── qwen3_30b_a3b_gptq.py ├── generate.sh ├── k8s/ │ ├── deployment.yaml │ └── service.yaml ├── lmdeploy/ │ ├── __init__.py │ ├── __main__.py │ ├── api.py │ ├── archs.py │ ├── cli/ │ │ ├── __init__.py │ │ ├── chat.py │ │ ├── cli.py │ │ ├── entrypoint.py │ │ ├── lite.py │ │ ├── serve.py │ │ └── utils.py │ ├── lite/ │ │ ├── __init__.py │ │ ├── apis/ │ │ │ ├── __init__.py │ │ │ ├── auto_awq.py │ │ │ ├── calibrate.py │ │ │ ├── get_small_sharded_hf.py │ │ │ ├── gptq.py │ │ │ └── smooth_quant.py │ │ ├── defaults.py │ │ ├── modeling/ │ │ │ ├── __init__.py │ │ │ ├── internlm2_gptq.py │ │ │ └── internlm3_gptq.py │ │ ├── quantization/ │ │ │ ├── __init__.py │ │ │ ├── activation/ │ │ │ │ ├── __init__.py │ │ │ │ └── observer.py │ │ │ ├── awq.py │ │ │ ├── calibration.py │ │ │ ├── modules/ │ │ │ │ ├── __init__.py │ │ │ │ └── linear.py │ │ │ └── weight/ │ │ │ ├── __init__.py │ │ │ ├── quant_utils.py │ │ │ └── quantizer.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── batch_split.py │ │ ├── cal_qparams.py │ │ ├── calib_dataloader.py │ │ ├── collect.py │ │ ├── global_avail.py │ │ ├── load.py │ │ └── memory_efficient.py │ ├── logger.py │ ├── messages.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── loggers.py │ │ ├── metrics_processor.py │ │ └── stats.py │ ├── model.py │ ├── monitoring/ │ │ ├── docker-compose.yaml │ │ ├── grafana/ │ │ │ ├── dashboards/ │ │ │ │ ├── config/ │ │ │ │ │ └── dashboard.yaml │ │ │ │ └── json/ │ │ │ │ └── lmdeploy-dashboard.json │ │ │ └── datasources/ │ │ │ └── datasource.yaml │ │ └── prometheus.yaml │ ├── pipeline.py │ ├── profiler.py │ ├── pytorch/ │ │ ├── __init__.py │ │ ├── adapter/ │ │ │ ├── __init__.py │ │ │ └── adapter.py │ │ ├── backends/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── apply_rotary_emb.py │ │ │ ├── attention.py │ │ │ ├── awq_modules.py │ │ │ ├── base.py │ │ │ ├── 
blockedf8_modules.py │ │ │ ├── causal_conv1d.py │ │ │ ├── cuda/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── apply_rotary_emb.py │ │ │ │ ├── attention/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── default.py │ │ │ │ │ ├── fa3.py │ │ │ │ │ └── mla.py │ │ │ │ ├── awq_modules.py │ │ │ │ ├── blockedf8_modules.py │ │ │ │ ├── causal_conv1d.py │ │ │ │ ├── flash_attention.py │ │ │ │ ├── gated_delta_rule.py │ │ │ │ ├── graph_runner.py │ │ │ │ ├── lora.py │ │ │ │ ├── moe/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── blocked_fp8.py │ │ │ │ │ ├── default.py │ │ │ │ │ ├── ep_utils.py │ │ │ │ │ └── w8a8.py │ │ │ │ ├── moe_router.py │ │ │ │ ├── multinomial_sampling.py │ │ │ │ ├── norm.py │ │ │ │ ├── nsa.py │ │ │ │ ├── op_backend.py │ │ │ │ ├── qmodules.py │ │ │ │ ├── token_dispatcher.py │ │ │ │ ├── utils.py │ │ │ │ └── warmup_manager.py │ │ │ ├── deepep_moe_checker.py │ │ │ ├── default/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── apply_rotary_emb.py │ │ │ │ ├── awq_modules.py │ │ │ │ ├── embedding.py │ │ │ │ ├── linear.py │ │ │ │ ├── moe.py │ │ │ │ ├── moe_router.py │ │ │ │ ├── multinomial_sampling.py │ │ │ │ ├── norm.py │ │ │ │ ├── op_backend.py │ │ │ │ ├── rotary_embedding.py │ │ │ │ └── token_dispatcher.py │ │ │ ├── dlinfer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── apply_rotary_emb.py │ │ │ │ ├── ascend/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── op_backend.py │ │ │ │ │ └── utils.py │ │ │ │ ├── attention.py │ │ │ │ ├── awq_modules.py │ │ │ │ ├── camb/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── op_backend.py │ │ │ │ ├── flash_attention.py │ │ │ │ ├── linear.py │ │ │ │ ├── maca/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── op_backend.py │ │ │ │ ├── moe.py │ │ │ │ ├── norm.py │ │ │ │ ├── op_backend.py │ │ │ │ ├── qmodules.py │ │ │ │ └── rotary_embedding.py │ │ │ ├── embedding.py │ │ │ ├── flash_attention.py │ │ │ ├── gated_delta_rule.py │ │ │ ├── graph_runner.py │ │ │ ├── linear.py │ │ │ ├── lora.py │ │ │ ├── moe.py │ │ │ ├── moe_router.py │ │ │ 
├── multinomial_sampling.py │ │ │ ├── norm.py │ │ │ ├── nsa.py │ │ │ ├── qmodules.py │ │ │ ├── rotary_embedding.py │ │ │ ├── selector.py │ │ │ └── token_dispatcher.py │ │ ├── block.py │ │ ├── check_env/ │ │ │ ├── __init__.py │ │ │ ├── adapter.py │ │ │ ├── base.py │ │ │ ├── cuda.py │ │ │ ├── deeplink.py │ │ │ ├── dist.py │ │ │ ├── model.py │ │ │ ├── torch.py │ │ │ ├── transformers.py │ │ │ ├── triton.py │ │ │ └── triton_custom_add.py │ │ ├── config.py │ │ ├── configurations/ │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── chatglm.py │ │ │ ├── cogvlm.py │ │ │ ├── deepseek_v2.py │ │ │ ├── deepseek_v32.py │ │ │ ├── deepseek_vl2.py │ │ │ ├── default.py │ │ │ ├── gemma.py │ │ │ ├── glm4.py │ │ │ ├── gpt_oss.py │ │ │ ├── interns1_pro.py │ │ │ ├── internvl.py │ │ │ ├── internvl3_hf.py │ │ │ ├── llama.py │ │ │ ├── llama4.py │ │ │ ├── llava_hf.py │ │ │ ├── minicpm3.py │ │ │ ├── qwen.py │ │ │ ├── qwen3_5.py │ │ │ ├── qwen3_next.py │ │ │ ├── qwen3_vl.py │ │ │ ├── sdar.py │ │ │ └── utils.py │ │ ├── consts.py │ │ ├── devices/ │ │ │ ├── __init__.py │ │ │ └── device_manager.py │ │ ├── disagg/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── backend/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backend.py │ │ │ │ ├── base.py │ │ │ │ ├── dlslime.py │ │ │ │ └── mooncake.py │ │ │ ├── config.py │ │ │ ├── conn/ │ │ │ │ ├── __init__.py │ │ │ │ ├── engine_conn.py │ │ │ │ ├── protocol.py │ │ │ │ └── proxy_conn.py │ │ │ └── messages.py │ │ ├── distributed.py │ │ ├── engine/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cache_engine.py │ │ │ ├── config_builder.py │ │ │ ├── engine.py │ │ │ ├── engine_checker.py │ │ │ ├── engine_instance.py │ │ │ ├── engine_loop.py │ │ │ ├── executor/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── base_worker.py │ │ │ │ ├── dist_utils.py │ │ │ │ ├── mp_executor.py │ │ │ │ ├── ray_executor.py │ │ │ │ └── uni_executor.py │ │ │ ├── guided_process.py │ │ │ ├── input_process.py │ │ │ ├── inputs_maker.py │ │ │ ├── logits_process.py │ │ │ ├── model_agent/ 
│ │ │ │ ├── __init__.py │ │ │ │ ├── agent.py │ │ │ │ ├── inputs_maker.py │ │ │ │ └── profiler.py │ │ │ ├── mp_engine/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── base_worker.py │ │ │ │ ├── ray_engine.py │ │ │ │ ├── zmq_engine.py │ │ │ │ └── zmq_rpc.py │ │ │ └── request.py │ │ ├── envs.py │ │ ├── kernels/ │ │ │ ├── __init__.py │ │ │ ├── cuda/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── apply_rotary_pos_emb.py │ │ │ │ ├── awq_kernels.py │ │ │ │ ├── bitonic_topk.py │ │ │ │ ├── blocked_fp8_fused_moe.py │ │ │ │ ├── blocked_gemm_fp8.py │ │ │ │ ├── causal_conv1d.py │ │ │ │ ├── ds_index.py │ │ │ │ ├── fill_kv_cache.py │ │ │ │ ├── flashattention.py │ │ │ │ ├── flatten_kv_cache.py │ │ │ │ ├── fused_lora.py │ │ │ │ ├── fused_moe.py │ │ │ │ ├── fused_moe_ep.py │ │ │ │ ├── fused_noaux_tc.py │ │ │ │ ├── gated_delta_rule.py │ │ │ │ ├── multinomial_sampling.py │ │ │ │ ├── pagedattention.py │ │ │ │ ├── rms_norm.py │ │ │ │ ├── utils.py │ │ │ │ ├── w8a8_fused_moe.py │ │ │ │ └── w8a8_triton_kernels.py │ │ │ ├── default/ │ │ │ │ ├── __init__.py │ │ │ │ ├── multinomial_sampling.py │ │ │ │ └── w8a8_kernels.py │ │ │ ├── dispatcher.py │ │ │ ├── dlinfer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── apply_rotary_pos_emb.py │ │ │ │ ├── awq_kernels.py │ │ │ │ ├── fill_kv_cache.py │ │ │ │ ├── flash_attention.py │ │ │ │ ├── fused_moe.py │ │ │ │ ├── fused_rotary_emb.py │ │ │ │ ├── linear.py │ │ │ │ ├── moe_gating_topk_softmax.py │ │ │ │ ├── pagedattention.py │ │ │ │ ├── rms_norm.py │ │ │ │ └── w8a8_kernels.py │ │ │ └── w8a8_triton_kernels.py │ │ ├── messages.py │ │ ├── model_inputs.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── baichuan.py │ │ │ ├── chatglm2.py │ │ │ ├── cogvlm.py │ │ │ ├── deepseek.py │ │ │ ├── deepseek_mtp.py │ │ │ ├── deepseek_v2.py │ │ │ ├── deepseek_v32.py │ │ │ ├── deepseek_vl2.py │ │ │ ├── gemma.py │ │ │ ├── gemma3_vl.py │ │ │ ├── glm4.py │ │ │ ├── glm4_1v.py │ │ │ ├── glm4_moe.py │ │ │ ├── glm4moe_mtp.py │ │ │ ├── 
gpt_oss.py │ │ │ ├── internlm.py │ │ │ ├── internlm2.py │ │ │ ├── internlm2_reward.py │ │ │ ├── internlm2_ve.py │ │ │ ├── internlm3.py │ │ │ ├── interns1_pro.py │ │ │ ├── interns1_pro_ts.py │ │ │ ├── internvl.py │ │ │ ├── internvl3_hf.py │ │ │ ├── internvl_patch.py │ │ │ ├── llama.py │ │ │ ├── llama4.py │ │ │ ├── llama_eagle.py │ │ │ ├── llama_eagle3.py │ │ │ ├── llava.py │ │ │ ├── minicpm3.py │ │ │ ├── minicpmv26.py │ │ │ ├── mistral.py │ │ │ ├── mixtral.py │ │ │ ├── module_map.py │ │ │ ├── patch.py │ │ │ ├── phi3.py │ │ │ ├── phi3_moe.py │ │ │ ├── phi3_v.py │ │ │ ├── q_modules.py │ │ │ ├── qwen.py │ │ │ ├── qwen2.py │ │ │ ├── qwen2_5_vl.py │ │ │ ├── qwen2_moe.py │ │ │ ├── qwen2_reward.py │ │ │ ├── qwen2_vl.py │ │ │ ├── qwen3.py │ │ │ ├── qwen3_5.py │ │ │ ├── qwen3_5_moe.py │ │ │ ├── qwen3_moe.py │ │ │ ├── qwen3_next.py │ │ │ ├── qwen3_vl.py │ │ │ ├── qwen3_vl_moe.py │ │ │ ├── sdar.py │ │ │ ├── sdar_moe.py │ │ │ ├── siglip.py │ │ │ ├── starcoder2.py │ │ │ ├── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cudagraph.py │ │ │ │ ├── micro_batch.py │ │ │ │ └── model.py │ │ │ └── whisper.py │ │ ├── multimodal/ │ │ │ ├── __init__.py │ │ │ └── data_type.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── attention.py │ │ │ ├── embedding.py │ │ │ ├── eplb.py │ │ │ ├── gated_delta.py │ │ │ ├── linear/ │ │ │ │ ├── __init__.py │ │ │ │ ├── awq.py │ │ │ │ ├── base.py │ │ │ │ ├── blocked_fp8.py │ │ │ │ ├── default.py │ │ │ │ ├── lora.py │ │ │ │ ├── utils.py │ │ │ │ └── w8a8.py │ │ │ ├── moe/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── blocked_fp8.py │ │ │ │ ├── default.py │ │ │ │ ├── route.py │ │ │ │ └── w8a8.py │ │ │ ├── multinomial_sampling.py │ │ │ ├── norm.py │ │ │ ├── nsa.py │ │ │ ├── quant_utils.py │ │ │ ├── rotary_embedding.py │ │ │ └── utils.py │ │ ├── paging/ │ │ │ ├── __init__.py │ │ │ ├── block_manager/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_block_manager.py │ │ │ │ ├── default_block_manager.py │ │ │ │ └── window_block_manager.py │ │ │ 
├── block_trie.py │ │ │ ├── eviction_helper/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_eviction_helper.py │ │ │ │ └── recompute_eviction_helper.py │ │ │ ├── scheduler.py │ │ │ ├── seq_states/ │ │ │ │ ├── __init__.py │ │ │ │ └── states.py │ │ │ └── state_manager.py │ │ ├── ray.py │ │ ├── spec_decode/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── proposers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── deepseek_mtp.py │ │ │ │ ├── eagle.py │ │ │ │ └── eagle3.py │ │ │ ├── reject_sampler.py │ │ │ └── spec_agent.py │ │ ├── strategies/ │ │ │ ├── __init__.py │ │ │ ├── ar/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cudagraph.py │ │ │ │ ├── engine.py │ │ │ │ ├── model_agent.py │ │ │ │ ├── model_inputs.py │ │ │ │ ├── sampling.py │ │ │ │ └── sequence.py │ │ │ ├── ar_spec/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cudagraph.py │ │ │ │ ├── engine.py │ │ │ │ ├── model_agent.py │ │ │ │ ├── model_inputs.py │ │ │ │ ├── sampling.py │ │ │ │ └── sequence.py │ │ │ ├── base/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cudagraph.py │ │ │ │ ├── engine.py │ │ │ │ ├── model_agent.py │ │ │ │ ├── model_inputs.py │ │ │ │ ├── sampling.py │ │ │ │ └── sequence.py │ │ │ └── dllm/ │ │ │ ├── __init__.py │ │ │ ├── cudagraph.py │ │ │ ├── engine.py │ │ │ ├── model_agent.py │ │ │ ├── model_inputs.py │ │ │ ├── sampling.py │ │ │ ├── sequence.py │ │ │ └── unmasking.py │ │ ├── third_party/ │ │ │ ├── __init__.py │ │ │ ├── deep_gemm/ │ │ │ │ └── __init__.py │ │ │ └── flash_attn_interface.py │ │ ├── tools/ │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── transformers/ │ │ │ ├── __init__.py │ │ │ └── configuration_deepseek_v32.py │ │ ├── utils.py │ │ └── weight_loader/ │ │ ├── __init__.py │ │ └── model_weight_loader.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── async_engine.py │ │ │ ├── exceptions.py │ │ │ └── vl_async_engine.py │ │ ├── managers/ │ │ │ ├── __init__.py │ │ │ └── session_manager.py │ │ ├── openai/ │ │ │ ├── __init__.py │ │ │ ├── api_client.py │ │ │ ├── api_server.py │ │ 
│ ├── harmony_utils.py │ │ │ ├── launch_server.py │ │ │ ├── protocol.py │ │ │ ├── reasoning_parser/ │ │ │ │ ├── __init__.py │ │ │ │ ├── deepseek_r1_reasoning_parser.py │ │ │ │ ├── qwen_qwq_reasoning_parser.py │ │ │ │ └── reasoning_parser.py │ │ │ ├── serving_chat_completion.py │ │ │ ├── serving_completion.py │ │ │ ├── serving_generate.py │ │ │ └── tool_parser/ │ │ │ ├── __init__.py │ │ │ ├── internlm2_parser.py │ │ │ ├── llama3_parser.py │ │ │ ├── qwen2d5_parser.py │ │ │ ├── qwen3_parser.py │ │ │ ├── qwen3coder_parser.py │ │ │ ├── tool_parser.py │ │ │ └── utils.py │ │ ├── processors/ │ │ │ ├── __init__.py │ │ │ └── multimodal.py │ │ ├── proxy/ │ │ │ ├── __init__.py │ │ │ ├── proxy.py │ │ │ ├── streaming_response.py │ │ │ └── utils.py │ │ └── utils/ │ │ ├── __init__.py │ │ └── server_utils.py │ ├── tokenizer.py │ ├── turbomind/ │ │ ├── __init__.py │ │ ├── deploy/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── converter.py │ │ │ ├── loader.py │ │ │ ├── module.py │ │ │ ├── parameter.py │ │ │ ├── policy.py │ │ │ ├── source_model/ │ │ │ │ ├── __init__.py │ │ │ │ ├── baichuan.py │ │ │ │ ├── base.py │ │ │ │ ├── deepseek2.py │ │ │ │ ├── deepseek_vl.py │ │ │ │ ├── glm4.py │ │ │ │ ├── glm4_moe_lite.py │ │ │ │ ├── gpt_oss.py │ │ │ │ ├── internlm2.py │ │ │ │ ├── internvl.py │ │ │ │ ├── llama.py │ │ │ │ ├── llava.py │ │ │ │ ├── minicpmv.py │ │ │ │ ├── mixtral.py │ │ │ │ ├── molmo.py │ │ │ │ ├── qwen.py │ │ │ │ └── xcomposer2.py │ │ │ └── target_model/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── fp.py │ │ ├── supported_models.py │ │ ├── tokenizer_info.py │ │ └── turbomind.py │ ├── utils.py │ ├── version.py │ └── vl/ │ ├── __init__.py │ ├── constants.py │ ├── engine.py │ ├── media/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── connection.py │ │ ├── image.py │ │ ├── time_series.py │ │ ├── video.py │ │ └── video_loader.py │ ├── model/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── builder.py │ │ ├── cogvlm.py │ │ ├── deepseek.py │ │ ├── deepseek_vl2.py │ │ ├── gemma3_vl.py │ 
│ ├── glm4_1v.py │ │ ├── glm4_v.py │ │ ├── interns1_pro.py │ │ ├── internvl.py │ │ ├── internvl3_hf.py │ │ ├── internvl_llava.py │ │ ├── llama4.py │ │ ├── llava.py │ │ ├── llava_hf.py │ │ ├── llava_next.py │ │ ├── minicpmv.py │ │ ├── mllama.py │ │ ├── molmo.py │ │ ├── phi3_vision.py │ │ ├── qwen.py │ │ ├── qwen2.py │ │ ├── qwen3.py │ │ ├── qwen3_5.py │ │ ├── utils.py │ │ ├── xcomposer2.py │ │ └── yi.py │ ├── tools/ │ │ ├── __init__.py │ │ └── merge_xcomposer2d5_task.py │ └── utils.py ├── pyproject.toml ├── setup.py ├── src/ │ ├── CMakeLists.txt │ └── turbomind/ │ ├── CMakeLists.txt │ ├── comm/ │ │ ├── CMakeLists.txt │ │ ├── barrier.h │ │ ├── cuda_ipc/ │ │ │ ├── CMakeLists.txt │ │ │ ├── allgather.cu │ │ │ ├── allreduce.cu │ │ │ ├── bootstrap.h │ │ │ ├── broadcast.cu │ │ │ ├── common.h │ │ │ ├── cuda_ipc_comm.cu │ │ │ ├── cuda_ipc_comm.h │ │ │ ├── fused_allreduce.cu │ │ │ ├── fused_allreduce_ex.cu │ │ │ ├── group_sum.h │ │ │ ├── mscclpp.h │ │ │ ├── multimem.cuh │ │ │ ├── semaphore.cuh │ │ │ └── semaphore.h │ │ ├── device_comm.cc │ │ ├── device_comm.h │ │ ├── env.h │ │ ├── gloo/ │ │ │ ├── CMakeLists.txt │ │ │ ├── gloo_comm.cc │ │ │ ├── hybrid_comm.cc │ │ │ ├── tcp_store.cc │ │ │ ├── tcp_store.h │ │ │ └── test_ipc_comm.cc │ │ ├── host_comm.cc │ │ ├── host_comm.h │ │ ├── nccl/ │ │ │ ├── CMakeLists.txt │ │ │ └── nccl.cu │ │ ├── test_comm.cu │ │ ├── test_host_comm.cc │ │ └── thread_comm.cc │ ├── core/ │ │ ├── CMakeLists.txt │ │ ├── allocator.cc │ │ ├── allocator.h │ │ ├── buffer.cc │ │ ├── buffer.h │ │ ├── check.cc │ │ ├── check.h │ │ ├── common.h │ │ ├── context.cc │ │ ├── context.h │ │ ├── copy.cc │ │ ├── copy.h │ │ ├── core.h │ │ ├── cuda_data_type.h │ │ ├── data_type.h │ │ ├── interval.h │ │ ├── layout.cc │ │ ├── layout.h │ │ ├── module.cc │ │ ├── module.h │ │ ├── ranges.h │ │ ├── serdes.h │ │ ├── state.h │ │ ├── stream.cc │ │ ├── stream.h │ │ ├── tensor.cc │ │ ├── tensor.cu │ │ ├── tensor.h │ │ └── test_core.cc │ ├── engine/ │ │ ├── CMakeLists.txt │ │ ├── batch.h │ │ 
├── engine.cc │ │ ├── engine.h │ │ ├── gateway.cc │ │ ├── gateway.h │ │ ├── model_executor.cc │ │ ├── model_executor.h │ │ ├── model_request.cc │ │ ├── model_request.h │ │ ├── queue.h │ │ ├── request.cc │ │ ├── request.h │ │ ├── request_queue.cc │ │ ├── request_queue.h │ │ └── signal_buffer.h │ ├── generation/ │ │ ├── CMakeLists.txt │ │ ├── base_param.h │ │ ├── generation.cc │ │ ├── generation.h │ │ ├── guided_decoding.cc │ │ ├── guided_decoding.h │ │ ├── logits_processor.cc │ │ ├── logits_processor.h │ │ ├── sampling.cc │ │ ├── sampling.h │ │ ├── stop_criteria.cc │ │ ├── stop_criteria.h │ │ └── utils.h │ ├── kernels/ │ │ ├── CMakeLists.txt │ │ ├── activation.cu │ │ ├── activation.h │ │ ├── activation_kernels.cu │ │ ├── activation_kernels.h │ │ ├── apply_token_bitmask_inplace_cuda.cu │ │ ├── apply_token_bitmask_inplace_cuda.h │ │ ├── attention/ │ │ │ ├── CMakeLists.txt │ │ │ ├── arch.h │ │ │ ├── attention.cu │ │ │ ├── attention.h │ │ │ ├── attention_params.h │ │ │ ├── attention_template.h │ │ │ ├── attention_universal.h │ │ │ ├── block.h │ │ │ ├── block_iterator.h │ │ │ ├── cp_utils.cu │ │ │ ├── cp_utils.h │ │ │ ├── cta_map.h │ │ │ ├── decoding.cu │ │ │ ├── decoding.h │ │ │ ├── decoding_template.h │ │ │ ├── desc.h │ │ │ ├── impl.h │ │ │ ├── impl_16816.h │ │ │ ├── impl_1688.h │ │ │ ├── impl_81616.h │ │ │ ├── impl_884.h │ │ │ ├── impl_m16n8.h │ │ │ ├── impl_simt.h │ │ │ ├── iterator.h │ │ │ ├── iterator_sm70.h │ │ │ ├── iterator_sm80.h │ │ │ ├── kernel/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── attention_sm70_128.cu │ │ │ │ ├── attention_sm70_256.cu │ │ │ │ ├── attention_sm70_576.cu │ │ │ │ ├── attention_sm70_64.cu │ │ │ │ ├── attention_sm75_128.cu │ │ │ │ ├── attention_sm75_256.cu │ │ │ │ ├── attention_sm75_576.cu │ │ │ │ ├── attention_sm75_64.cu │ │ │ │ ├── attention_sm80_128.cu │ │ │ │ ├── attention_sm80_192.cu │ │ │ │ ├── attention_sm80_256.cu │ │ │ │ ├── attention_sm80_576.cu │ │ │ │ ├── attention_sm80_64.cu │ │ │ │ ├── decoding_sm70_128.cu │ │ │ │ ├── 
decoding_sm70_256.cu │ │ │ │ ├── decoding_sm70_576.cu │ │ │ │ ├── decoding_sm70_64.cu │ │ │ │ ├── decoding_sm75_128.cu │ │ │ │ ├── decoding_sm75_256.cu │ │ │ │ ├── decoding_sm75_576.cu │ │ │ │ ├── decoding_sm75_64.cu │ │ │ │ ├── decoding_sm80_128.cu │ │ │ │ ├── decoding_sm80_192.cu │ │ │ │ ├── decoding_sm80_256.cu │ │ │ │ ├── decoding_sm80_576.cu │ │ │ │ └── decoding_sm80_64.cu │ │ │ ├── kernel.h │ │ │ ├── kernel_impl.h │ │ │ ├── kv_cache_utils_v2.cu │ │ │ ├── kv_cache_utils_v2.h │ │ │ ├── linear_iterator.h │ │ │ ├── mainloop.h │ │ │ ├── mainloop_sm70.h │ │ │ ├── mainloop_sm80.h │ │ │ ├── quantization.h │ │ │ ├── reduce.cu │ │ │ ├── reduce.h │ │ │ ├── reference.cu │ │ │ ├── reference.h │ │ │ ├── registrar.h │ │ │ ├── registry.cu │ │ │ ├── registry.h │ │ │ ├── rotary_embedding.h │ │ │ ├── test_attention.cu │ │ │ ├── test_quant.cu │ │ │ ├── test_utils.cu │ │ │ ├── test_utils.h │ │ │ ├── utils.cc │ │ │ └── utils.h │ │ ├── ban_bad_words.cu │ │ ├── ban_bad_words.h │ │ ├── core/ │ │ │ ├── array.h │ │ │ ├── array_ops.h │ │ │ ├── common.h │ │ │ ├── data_type.h │ │ │ ├── floating_point.h │ │ │ ├── layout.h │ │ │ ├── math.h │ │ │ ├── meta.h │ │ │ ├── mma.h │ │ │ ├── pipe_iter.h │ │ │ ├── smem.h │ │ │ ├── sub_byte_ptr.h │ │ │ ├── sync.h │ │ │ └── thread_map.h │ │ ├── decoding_kernels.cu │ │ ├── decoding_kernels.h │ │ ├── gemm/ │ │ │ ├── CMakeLists.txt │ │ │ ├── arch/ │ │ │ │ ├── config_simt.h │ │ │ │ ├── config_sm70_s884.h │ │ │ │ ├── config_sm75_s16816.h │ │ │ │ ├── config_sm80_s16816.h │ │ │ │ ├── mma_simt.h │ │ │ │ ├── mma_sm70.h │ │ │ │ ├── mma_sm80.h │ │ │ │ ├── operand_simt.h │ │ │ │ ├── operand_sm70_s884.h │ │ │ │ ├── operand_sm80_s16816.h │ │ │ │ ├── smem_copy_simt.h │ │ │ │ ├── smem_copy_sm70.h │ │ │ │ └── smem_copy_sm80.h │ │ │ ├── arch.h │ │ │ ├── cast.cu │ │ │ ├── cast.h │ │ │ ├── context.cu │ │ │ ├── context.h │ │ │ ├── convert.cuh │ │ │ ├── convert.h │ │ │ ├── convert_v3.cu │ │ │ ├── cp_async.h │ │ │ ├── cta_map.h │ │ │ ├── cublas.cu │ │ │ ├── desc.h │ │ │ ├── 
dispatch_cache.cu │ │ │ ├── dispatch_cache.h │ │ │ ├── epilogue.h │ │ │ ├── format.h │ │ │ ├── gemm.cu │ │ │ ├── gemm.h │ │ │ ├── gemm_universal.h │ │ │ ├── gemm_universal_sm90.h │ │ │ ├── gemm_universal_sm90_v2.h │ │ │ ├── gemm_universal_sm90_v3.h │ │ │ ├── gemm_universal_sm90_v4.h │ │ │ ├── gemm_universal_sm90_v5.h │ │ │ ├── gpu_metric.cu │ │ │ ├── gpu_metric.h │ │ │ ├── iterator.h │ │ │ ├── iterator_sm70.h │ │ │ ├── iterator_sm80.h │ │ │ ├── iterator_sm90.h │ │ │ ├── kernel/ │ │ │ │ ├── sm70_884_16.cu │ │ │ │ ├── sm70_884_4.cu │ │ │ │ ├── sm70_884_8.cu │ │ │ │ ├── sm75_16816_16.cu │ │ │ │ ├── sm75_16816_4.cu │ │ │ │ ├── sm75_16816_8.cu │ │ │ │ ├── sm80_16816_16.cu │ │ │ │ ├── sm80_16816_4.cu │ │ │ │ ├── sm80_16816_8.cu │ │ │ │ ├── sm90_16816_16.cu │ │ │ │ ├── sm90_16816_4.cu │ │ │ │ ├── sm90_16816_8.cu │ │ │ │ └── sm90_64n32_8.cu │ │ │ ├── kernel.cu │ │ │ ├── kernel.h │ │ │ ├── kernel_impl.h │ │ │ ├── kernel_impl_sm90.h │ │ │ ├── mainloop_sm70.h │ │ │ ├── mainloop_sm80_v2.h │ │ │ ├── matrix_ptr.h │ │ │ ├── moe_utils_v2.cu │ │ │ ├── moe_utils_v2.h │ │ │ ├── operand.h │ │ │ ├── predicate.h │ │ │ ├── registry.cu │ │ │ ├── registry.h │ │ │ ├── scaled_gmma_fp8_sm90.h │ │ │ ├── scheduler.cuh │ │ │ ├── scheduler_sm70.cuh │ │ │ ├── simt.h │ │ │ ├── sm90_utils.h │ │ │ ├── smem_copy.h │ │ │ ├── test/ │ │ │ │ ├── gemm_bench.cu │ │ │ │ ├── models.h │ │ │ │ ├── quantization.cu │ │ │ │ ├── quantization.h │ │ │ │ ├── quantization_impl.h │ │ │ │ ├── reference.cu │ │ │ │ ├── reference.h │ │ │ │ ├── test_gemm_v2.cc │ │ │ │ ├── test_moe_utils.cu │ │ │ │ ├── test_utils.cu │ │ │ │ ├── test_utils.h │ │ │ │ └── testbed_v3.h │ │ │ ├── thread_group_map.h │ │ │ ├── thread_map.h │ │ │ ├── tiled_mma.h │ │ │ ├── tma.cu │ │ │ ├── tma.h │ │ │ ├── transform.h │ │ │ ├── tuner/ │ │ │ │ ├── cache_utils.cu │ │ │ │ ├── cache_utils.h │ │ │ │ ├── measurer.cu │ │ │ │ ├── measurer.h │ │ │ │ ├── params.cc │ │ │ │ ├── params.h │ │ │ │ ├── sampler.cu │ │ │ │ ├── sampler.h │ │ │ │ ├── stats.h │ │ │ │ ├── 
stopping_criterion.cc │ │ │ │ └── stopping_criterion.h │ │ │ ├── types.h │ │ │ ├── unpack.cu │ │ │ └── utils.h │ │ ├── gpt_kernels.cu │ │ ├── gpt_kernels.h │ │ ├── logprob_kernels.cu │ │ ├── logprob_kernels.h │ │ ├── norm/ │ │ │ ├── CMakeLists.txt │ │ │ ├── rms_norm.cu │ │ │ └── rms_norm.h │ │ ├── penalty_types.h │ │ ├── quantization.cu │ │ ├── quantization.cuh │ │ ├── quantization.h │ │ ├── reduce_kernel_utils.cuh │ │ ├── sampling_kernels.cu │ │ ├── sampling_kernels.h │ │ ├── sampling_penalty_kernels.cu │ │ ├── sampling_penalty_kernels.h │ │ ├── sampling_topk_kernels.cu │ │ ├── sampling_topk_kernels.h │ │ ├── sampling_topp_kernels.cu │ │ ├── sampling_topp_kernels.h │ │ ├── stop_criteria_kernels.cu │ │ ├── stop_criteria_kernels.h │ │ ├── test_quantization.cc │ │ ├── unfused_attention_kernels.cu │ │ └── unfused_attention_kernels.h │ ├── macro.h │ ├── models/ │ │ ├── CMakeLists.txt │ │ ├── input_processor.cc │ │ ├── input_processor.h │ │ ├── language_model.cc │ │ ├── language_model.h │ │ ├── llama/ │ │ │ ├── Barrier.h │ │ │ ├── BlockManager.cc │ │ │ ├── BlockManager.h │ │ │ ├── BlockTrie.cc │ │ │ ├── BlockTrie.h │ │ │ ├── CMakeLists.txt │ │ │ ├── GatedDeltaNetLayer.cc │ │ │ ├── GatedDeltaNetLayer.h │ │ │ ├── GatedDeltaNetWeight.cc │ │ │ ├── GatedDeltaNetWeight.h │ │ │ ├── LlamaDecoderLayerWeight.cc │ │ │ ├── LlamaDecoderLayerWeight.h │ │ │ ├── LlamaDenseWeight.cc │ │ │ ├── LlamaDenseWeight.h │ │ │ ├── LlamaFfnLayer.cc │ │ │ ├── LlamaFfnLayer.h │ │ │ ├── LlamaLinear.cu │ │ │ ├── LlamaLinear.h │ │ │ ├── LlamaWeight.cc │ │ │ ├── LlamaWeight.h │ │ │ ├── SequenceManager.cc │ │ │ ├── SequenceManager.h │ │ │ ├── bench_conv1d_silu.cc │ │ │ ├── bench_gated_delta_net.cc │ │ │ ├── context.h │ │ │ ├── gated_delta_net_kernels.cu │ │ │ ├── gated_delta_net_kernels.h │ │ │ ├── llama_kernels.cu │ │ │ ├── llama_kernels.h │ │ │ ├── llama_params.h │ │ │ ├── llama_rope.h │ │ │ ├── llama_utils.cu │ │ │ ├── llama_utils.h │ │ │ ├── mla_utils.cu │ │ │ ├── mla_utils.h │ │ │ ├── 
moe_ffn_layer.cc │ │ │ ├── moe_ffn_layer.h │ │ │ ├── test_cache_manager.cc │ │ │ ├── unified_attention_layer.cc │ │ │ ├── unified_attention_layer.h │ │ │ ├── unified_decoder.cc │ │ │ └── unified_decoder.h │ │ ├── output_processor.cc │ │ └── output_processor.h │ ├── python/ │ │ ├── CMakeLists.txt │ │ ├── bind.cpp │ │ ├── dlpack.h │ │ └── xgrammar_bind.cpp │ ├── turbomind.cc │ ├── turbomind.h │ └── utils/ │ ├── CMakeLists.txt │ ├── anomaly_handler.cu │ ├── anomaly_handler.h │ ├── constant.h │ ├── cuda_bf16_fallbacks.cuh │ ├── cuda_bf16_wrapper.h │ ├── cuda_type_utils.cuh │ ├── cuda_utils.cc │ ├── cuda_utils.h │ ├── debug_utils.h │ ├── dispatch.h │ ├── logger.cc │ ├── logger.h │ ├── memory_utils.cu │ ├── memory_utils.h │ ├── metrics.h │ ├── monotonic.h │ ├── nvtx_utils.cc │ ├── nvtx_utils.h │ ├── parser.cc │ ├── parser.h │ ├── string_utils.h │ └── test_utils.h └── tests/ ├── csrc/ │ ├── CMakeLists.txt │ └── unittests/ │ ├── CMakeLists.txt │ ├── gtest_utils.h │ ├── test_logprob_kernels.cu │ ├── test_penalty_kernels.cu │ ├── test_sampling_kernels.cu │ ├── test_sampling_layer.cu │ └── unittest_utils.h ├── pytorch/ │ ├── config/ │ │ └── test_hf_overrides.py │ ├── engine/ │ │ ├── test_logits_process.py │ │ ├── test_request.py │ │ └── test_zmq_rpc.py │ ├── kernel/ │ │ ├── test_activation.py │ │ ├── test_apply_rotary.py │ │ ├── test_bitonic_topk.py │ │ ├── test_causal_conv1d.py │ │ ├── test_ds_index.py │ │ ├── test_fill_kv_cache.py │ │ ├── test_flash_attention.py │ │ ├── test_flatten_kv_cache.py │ │ ├── test_fuse_moe_blocked_fp8.py │ │ ├── test_fused_lora.py │ │ ├── test_fused_moe.py │ │ ├── test_gated_delta_rule.py │ │ ├── test_gemm_fp8.py │ │ ├── test_moe_route.py │ │ ├── test_multinomial_sampling.py │ │ ├── test_paged_attention.py │ │ └── test_rms_norm.py │ ├── nn/ │ │ └── test_embedding.py │ └── paging/ │ ├── test_block_manager.py │ ├── test_block_trie.py │ └── test_scheduler.py └── test_lmdeploy/ ├── test_auto_backend.py ├── test_content_merge.py ├── test_grammar.py ├── 
test_harmony_gpt_oss_parser.py ├── test_lite/ │ └── test_quantization/ │ └── test_utils/ │ └── test_cal_qparams.py ├── test_messages.py ├── test_model.py ├── test_pipeline.py ├── test_qwen3_parser.py ├── test_qwen3coder_parser.py ├── test_tokenizer.py ├── test_turbomind/ │ └── test_converter.py ├── test_utils.py └── test_vl/ ├── test_hf_chat_template.py ├── test_nonhf_chat_template.py ├── test_qwen3vl_processor.py └── test_vl_encode.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ Language: Cpp AccessModifierOffset: -4 AlignAfterOpenBracket: Align AllowShortEnumsOnASingleLine: false AlignConsecutiveAssignments: true AlignConsecutiveDeclarations: true AlignEscapedNewlines: Right AlignOperands: true AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: true AllowAllArgumentsOnNextLine: true AllowShortBlocksOnASingleLine: Empty AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: Empty AllowShortIfStatementsOnASingleLine: Never AllowShortLoopsOnASingleLine: false AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: false AlwaysBreakTemplateDeclarations: true BinPackArguments: false BinPackParameters: false BreakBeforeBinaryOperators: NonAssignment BreakBeforeBraces: Stroustrup BreakBeforeTernaryOperators: false BreakConstructorInitializers: AfterColon BreakInheritanceList: AfterColon BreakStringLiterals: false ColumnLimit: 120 CompactNamespaces: false ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DerivePointerAlignment: false FixNamespaceComments: true IndentCaseLabels: true IndentPPDirectives: None IndentWidth: 4 IndentWrappedFunctionNames: false KeepEmptyLinesAtTheStartOfBlocks: true MaxEmptyLinesToKeep: 1 NamespaceIndentation: None 
PointerAlignment: Left ReflowComments: true SortIncludes: true SortUsingDeclarations: false SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: false SpaceBeforeAssignmentOperators: true SpaceBeforeCtorInitializerColon: false SpaceBeforeInheritanceColon: false SpaceBeforeParens: ControlStatements SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 2 SpacesInAngles: false SpacesInCStyleCastParentheses: false SpacesInContainerLiterals: false SpacesInParentheses: false SpacesInSquareBrackets: false Standard: c++17 TabWidth: 4 UseTab: Never ================================================ FILE: .claude/skills/check-env/SKILL.md ================================================ --- name: check-env description: Check if the LMDeploy dev environment is properly set up. --- # Check LMDeploy Dev Environment ## 1. Find and activate the conda env ```bash conda env list # starred = currently active conda activate <env_name> # pick the right env for this project ``` ## 2. Verify editable install ```bash python -c "import lmdeploy; print(lmdeploy.__file__)" # Must point into the repo dir, e.g. /path/to/lmdeploy_vl/lmdeploy/__init__.py ``` If it doesn't: ```bash pip install -e . # run from repo root ``` ## 3. Confirm python and CUDA ```bash which python # must show conda env path, not /usr/bin/python python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.device_count())" ``` ## Troubleshooting | Problem | Fix | | -------------------- | ----------------------------------------------- | | `conda: not found` | `source ~/miniconda3/etc/profile.d/conda.sh` | | Wrong Python | `conda deactivate && conda activate <env_name>` | | `lmdeploy` not found | `pip install -e .` from repo root | ================================================ FILE: .claude/skills/code-navigation/SKILL.md ================================================ --- name: code-navigation description: LMDeploy codebase directory map for fast orientation.
--- # LMDeploy Project Structure ```text lmdeploy/ ├── cli/ # Command line interface implementations ├── lib/ # Shared libraries/binary assets ├── lite/ # Quantization Toolkit │ ├── apis/ # Calibration, AWQ, and SmoothQuant entry points │ ├── modeling/ # GPTQ/quantized model specific logic │ ├── quantization/ # Scaling calculation (activations/weights) │ └── utils/ # Quantization helper functions (cal_qparams.py) ├── metrics/ # Statistics and performance monitoring ├── monitoring/ # Monitoring configs (Docker/Grafana) ├── pytorch/ # PyTorch inference backend │ ├── adapter/ # LoRA and adapter logic │ ├── backends/ # Kernel/Operator Dispatchers (FP8, AWQ, CUDA) │ ├── check_env/ # Environment/GPU capability sanity checks │ ├── configurations/ # Per-model engine configurations (Llama, etc.) │ ├── devices/ # Device management (CUDA) │ ├── disagg/ # Disaggregated prefill/decode logic │ ├── engine/ # Main Scheduler and Execution Loop │ ├── kernels/ # Triton/CUDA Kernels (w8a8_triton_kernels.py) │ ├── models/ # Model Patches: Replacing HF layers with kernels │ ├── multimodal/ # Multi-modal input types for Pytorch engine │ ├── nn/ # Reusable PyTorch modules │ ├── paging/ # PagedAttention: KV cache block management │ ├── spec_decode/ # Speculative decoding logic │ ├── strategies/ # Execution and dispatch strategies │ ├── third_party/ # External dependencies/repos │ ├── tools/ # Internal engine debugging tools │ ├── transformers/ # HF Transformers integration depth │ └── weight_loader/ # Sharded/quantized weight loading engine ├── serve/ # Serving: OpenAI-compatible API and gRPC ├── turbomind/ # C++ TurboMind inference backend ├── vl/ # Vision-Language (VL) Support and Image Processing │ ├── media/ # Image/Video/... loaders and base classes │ └── model/ # VL Archs (InternVL, Qwen-VL, LLaVA, etc.) 
and preprocess ├── api.py # High-level entry for model interaction ├── archs.py # Registry: Maps architectures to runtime patches ├── messages.py # Core Types: GenerationConfig, EngineConfig ├── model.py # Chat Templates: CRITICAL for conversation logic ├── pipeline.py # Main Orchestrator: Engine + Tokenizer └── tokenizer.py # Wrapper for HF/SentencePiece tokenizers ``` ================================================ FILE: .claude/skills/resolve-review/SKILL.md ================================================ --- name: resolve-review description: Fetch and resolve PR review comments, then push fixes. --- # Resolve PR Review Comments ## 1. Fetch comments ```bash gh api repos/InternLM/lmdeploy/pulls/<pr_number>/comments \ | python3 -c " import json, sys for c in json.load(sys.stdin): print(f'[{c[\"path\"]}:{c.get(\"line\",\"?\")}]') print(c['body']) print() " ``` ## 2. Fix each issue Read the flagged file, understand the comment, edit the file. ## 3. Lint ```bash pre-commit run --all-files ``` ## 4. Stage & commit ```bash git add <changed_files> git commit -m "fix: address PR review comments" ``` ## 5. Push ```bash git push ``` ================================================ FILE: .claude/skills/submit-pr/SKILL.md ================================================ --- name: submit-pr description: Submit a GitHub pull request for LMDeploy. --- # Submit a PR for LMDeploy ## 1. Create branch (off main) Skip this step if already on a feature branch. ```bash git checkout main && git pull git checkout -b <type>/<short-name> # e.g. feat/qwen3-omni ``` ## 2. Lint ```bash pre-commit run --all-files ``` ## 3. Stage ```bash git add lmdeploy/path/to/changed_file.py # specific files only, never git add . git status # verify staged set ``` ## 4. Commit ```bash git commit -m "feat: add Qwen3-Omni support" # Conventional prefixes: feat | fix | refactor | docs | test | chore ``` ## 5. Push ```bash git push -u origin <branch_name> ``` ## 6.
Create PR ```bash gh pr create --title "<type>: <short description>" --body "$(cat <<'EOF' ## Summary - <key change 1> - <key change 2> ## Test plan - [ ] `pre-commit run --all-files` passes - [ ] unit tests pass: `pytest tests/test_lmdeploy/` - [ ] manual smoke test with pipeline 🤖 Generated with [Claude Code](https://claude.com/claude-code) EOF )" ``` ================================================ FILE: .claude/skills/support-new-model/SKILL.md ================================================ --- name: support-new-model description: Add a new LLM or VLM to LMDeploy's PyTorch backend. --- # Tutorial: Adding a New Model to LMDeploy (PyTorch Backend) This guide walks through adding a new LLM or VLM to LMDeploy's PyTorch backend. ______________________________________________________________________ ## Before Writing Any Code **Study the reference implementations before touching any files.** 1. Read the HF model's `config.json` to understand: `model_type`, `architectures`, layer counts, hidden dims, number of attention heads, MoE parameters (if applicable). 2. Identify which category the model falls into: - **LLM only** — pure text model - **VLM** — text + vision (needs an additional preprocessor in `vl/model/`) 3.
Find the closest existing model in LMDeploy and read it thoroughly: | Reference model | File(s) | | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | | LLM (dense) | `lmdeploy/pytorch/models/qwen3.py` | | LLM (MoE) | `lmdeploy/pytorch/models/qwen3_moe.py` | | VLM preprocessor | `lmdeploy/vl/model/qwen3.py` | | VLM (composite config) | `lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py` + `lmdeploy/pytorch/configurations/qwen3_omni.py` + `lmdeploy/vl/model/qwen3_omni.py` | ______________________________________________________________________ ## Key Files Quick Reference | File | Purpose | | -------------------------------------------- | --------------------------------------------------------------- | | `lmdeploy/pytorch/models/.py` | Attention, MLP, DecoderLayer, Model, ForCausalLM | | `lmdeploy/pytorch/models/module_map.py` | HF class name → LMDeploy class path mapping | | `lmdeploy/pytorch/configurations/.py` | Config builder — only needed for non-standard/nested HF configs | | `lmdeploy/vl/model/.py` | VLM: image/video preprocessing *(VLM only)* | | `lmdeploy/vl/model/base.py` | `VisionModel` base class + `VISION_MODELS` registry | | `lmdeploy/archs.py` | VLM: arch name → task mapping *(VLM only)* | | `lmdeploy/lite/apis/calibrate.py` | Quantization: layer/norm/head mappings *(optional)* | | `lmdeploy/lite/quantization/awq.py` | Quantization: AWQ scale mappings *(optional)* | ______________________________________________________________________ ## Step-by-Step: LLM (PyTorch Backend) ### Step 1 — Create the PyTorch model file **File:** `lmdeploy/pytorch/models/.py` Implement the following class hierarchy (innermost → outermost): 1. **`Attention`** — QKV projection, rotary embedding, attention forward 2. **`MLP`** — gate-up linear, activation, down projection 3. **`DecoderLayer`** — wraps Attention + MLP with layer norms and residual connections 4. 
**`Model`** — embedding table, all decoder layers, final norm, rotary embedding 5. **`ForCausalLM`** — top-level class; inherits `nn.Module`, `DeployModelMixinV1`, `CudaGraphMixin` **Required imports:** ```python import torch import torch.nn as nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config) from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import add_prefix from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, build_embedding ``` **Attention skeleton:** ```python class MyModelAttention(nn.Module): def __init__(self, config, dtype=None, device=None, prefix=''): super().__init__() self.qkv_proj = build_qkv_proj( config.hidden_size, num_q_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, head_size=config.hidden_size // config.num_attention_heads, bias=False, dtype=dtype, device=device, prefix=add_prefix('qkv_proj', prefix)) self.apply_rotary_pos_emb = ApplyRotaryEmb() self.attn_fwd = Attention( config.num_attention_heads, config.hidden_size // config.num_attention_heads, num_kv_heads=config.num_key_value_heads) self.o_proj = build_o_proj( config.num_attention_heads, config.hidden_size // config.num_attention_heads, config.hidden_size, bias=False, dtype=dtype, device=device, prefix=add_prefix('o_proj', prefix)) def forward(self, hidden_states, rotary_pos_emb, past_key_value, attn_metadata): qkv_states = self.qkv_proj(hidden_states) # split q, k, v; apply rotary; call attn_fwd; project output ... 
``` **MLP skeleton:** ```python class MyModelMLP(nn.Module): def __init__(self, config, dtype=None, device=None, prefix=''): super().__init__() self.gate_up_proj = build_gateup_linear( config.hidden_size, config.intermediate_size, bias=False, dtype=dtype, device=device, prefix=add_prefix('gate_up_proj', prefix)) self.down_proj = build_down_linear( config.intermediate_size, config.hidden_size, bias=False, dtype=dtype, device=device, prefix=add_prefix('down_proj', prefix)) self.act_fn = SiluAndMul() def forward(self, x): return self.down_proj(self.act_fn(self.gate_up_proj(x))) ``` **ForCausalLM skeleton (critical fields):** ```python class MyModelForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin): # Maps packed param name → list of original HF param suffixes packed_modules_mapping = { 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], 'gate_up_proj': ['gate_proj', 'up_proj'], } def __init__(self, config, ctx_mgr=None, prefix='', **kwargs): super().__init__() self.model = MyModelModel(config, ...) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.ctx_mgr = ctx_mgr def get_input_embeddings(self): return self.model.embed_tokens def forward(self, input_ids, inputs_embeds, past_key_values, attn_metadata, **kwargs): hidden_states = self.model(input_ids, inputs_embeds, past_key_values, attn_metadata) return hidden_states def get_logits(self, hidden_states): return self.lm_head(hidden_states) # prepare_inputs_for_generation and load_weights: copy from qwen3.py, # update stacked_params_mapping to match this model's HF weight names. ``` ______________________________________________________________________ ### Step 2 — Register in `module_map.py` **File:** `lmdeploy/pytorch/models/module_map.py` Add an entry to `MODULE_MAP`. 
The key is the exact HF architecture class name from `config.json`'s `architectures` field: ```python MODULE_MAP.update({ 'MyModelForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.my_model.MyModelForCausalLM', }) ``` ______________________________________________________________________ ### Step 3 — Add config builder (if needed) **File:** `lmdeploy/pytorch/configurations/.py` **Skip this step** for models with a standard flat HF config — `DefaultModelConfigBuilder` handles them automatically. Only create this file when the HF config is non-standard, e.g.: - Nested config (e.g., Qwen3-Omni has `hf_config.thinker_config.text_config`) - Unusual `model_type` that needs special field remapping ```python from .builder import AutoModelConfigBuilder, DefaultModelConfigBuilder class MyModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): # Must match model_type from config.json exactly return hf_config.model_type == 'my_model' @classmethod def build(cls, hf_config, model_path=None, **kwargs): # Extract the text config if nested; patch fields if needed cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) cfg.hf_config = hf_config # keep full config for VLM layers return cfg ``` Auto-discovery: subclasses of `AutoModelConfigBuilder` register themselves automatically via `__init_subclass__()` — no import needed elsewhere. ______________________________________________________________________ ### Step 4 — Add quantization mappings (optional) Only needed to support AWQ/SmoothQuant calibration for this model family. **`lmdeploy/lite/apis/calibrate.py`** — add layer name, norm name, and head name mappings for the new model type. **`lmdeploy/lite/quantization/awq.py`** — add entries to `NORM_FCS_MAP` (norm → downstream FC layers) and `FC_FCS_MAP` (FC → downstream FC layers) following the existing patterns. 
______________________________________________________________________ ## Step-by-Step: VLM (additional steps) ### Step 5 — Create the VL preprocessor **File:** `lmdeploy/vl/model/.py` The preprocessor handles image/video decoding and feature extraction before the LLM backbone sees the input. ```python from lmdeploy.vl.model.base import VISION_MODELS, VisionModel @VISION_MODELS.register_module() class MyModelVLModel(VisionModel): # Must match hf_config.architectures exactly (can be a list for variants) _arch = ['MyModelForConditionalGeneration'] def build_preprocessor(self): """Load the vision processor from the model checkpoint.""" from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(self.model_path) # Set image_token_id to the token ID of the image placeholder # (used by the engine to know where to inject image features) tokenizer = self.processor.tokenizer self.image_token = '' # model-specific placeholder token self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) # preprocess and to_pytorch: copy from vl/model/qwen3.py and adapt # image token handling (image_token, image_token_id, image_tokens count). ``` Key points: - `collect_images()`, `proc_messages()`, `to_pytorch_aux()` are all provided by `VisionModel` — do not re-implement them. - `image_tokens` tells the engine how many token slots to reserve for each image. - Auto-registered via `@VISION_MODELS.register_module()` when the module is imported. 
**Add an explicit import** in `lmdeploy/vl/model/builder.py` alongside the existing imports so the decorator runs at startup: ```python from .my_model import MyModelVLModel # noqa F401 ``` ______________________________________________________________________ ### Step 6 — Register VLM arch in `archs.py` **File:** `lmdeploy/archs.py` Add the architecture name to the `supported_archs` set inside `check_vl_llm()` so the engine routes the model through the VLM code path: ```python # lmdeploy/archs.py — inside check_vl_llm() supported_archs = set([ ... 'MyModelForConditionalGeneration', # add this line ]) ``` ______________________________________________________________________ ## Checklist **LLM (PyTorch backend):** - [ ] `pytorch/models/.py` — all 5 classes implemented (`Attention`, `MLP`, `DecoderLayer`, `Model`, `ForCausalLM`) - [ ] `module_map.py` — HF architecture class name registered - [ ] `packed_modules_mapping` matches HF parameter naming scheme - [ ] `stacked_params_mapping` in `load_weights()` has correct shard indices - [ ] `pytorch/configurations/.py` — added only if HF config is non-standard - [ ] Weights load cleanly from HF checkpoint (no missing/unexpected key errors) **VLM (additional):** - [ ] `vl/model/.py` — `build_preprocessor`, `preprocess`, `to_pytorch` implemented - [ ] `_arch` matches `config.json` `architectures[0]` exactly - [ ] `image_token_id` correctly resolved from the tokenizer - [ ] `image_tokens` count is correct for the image resolution/encoding scheme - [ ] `vl/model/builder.py` — explicit import added for new model - [ ] `archs.py` entry added **Quantization (optional):** - [ ] `calibrate.py` — layer/norm/head name mappings added - [ ] `awq.py` — `NORM_FCS_MAP` / `FC_FCS_MAP` entries added ______________________________________________________________________ ## Common Pitfalls 1. **Weight name mismatches** — `packed_modules_mapping` keys must match HF param name suffixes exactly. 
Check actual HF weight names with `list(model.state_dict().keys())[:20]` before coding. 2. **Wrong shard index order** — `stacked_params_mapping` for QKV must follow Q→0, K→1, V→2. Wrong order silently produces bad outputs. 3. **Wrong `_arch`** — must match `hf_config.architectures[0]` literally (e.g., `'Qwen3VLForConditionalGeneration'`, not `'Qwen3VL'`). 4. **`image_token_id` is None** — causes the engine to silently skip image feature injection. Always verify with `tokenizer.convert_tokens_to_ids(image_token)` returning a real token ID. 5. **Missing `role='preprocess'` append** — `to_pytorch_aux()` searches messages for exactly `role='preprocess'`; if `preprocess()` does not append it, inference will fail with a confusing error. 6. **Config builder `condition()` mismatch** — `model_type` in `condition()` must match the exact string in `config.json`, not a display name or alias. 7. **MoE routing** — MoE models need `num_experts`, `num_experts_per_tok`, and a TopK gating mechanism in the MLP. Reference `qwen3_moe.py` for the pattern. 8. **CUDA graph + dynamic control flow** — models with data-dependent branching (e.g., conditional expert dispatch) may break CUDA graph capture. Use `_no_cudagraph` guards in `CudaGraphMixin` if needed. 
______________________________________________________________________ ## Verification **LLM basic test:** ```bash python -m lmdeploy.pytorch.chat <model_path> --backend pytorch ``` **VLM basic test:** ```python from lmdeploy import pipeline pipe = pipeline('<model_path>') result = pipe(('Describe this image.', 'path/to/image.jpg')) print(result.text) ``` **Unit tests:** ```bash pytest tests/test_lmdeploy/test_vl/ # VLM tests pytest tests/test_lmdeploy/ # all unit tests ``` **Debug weight loading:** ```bash LMDEPLOY_LOG_LEVEL=DEBUG python -m lmdeploy.pytorch.chat <model_path> --backend pytorch 2>&1 | grep -E "load|weight|miss" ``` ================================================ FILE: .github/CONTRIBUTING.md ================================================ ## Contributing to LMDeploy Welcome to the LMDeploy community! All kinds of contributions are welcome, including but not limited to the following. **Fix bug** You can directly post a Pull Request to fix typos in code or documents. The steps to fix a bug in the code implementation are as follows: 1. If the modification involves significant changes, you should create an issue first and describe the error information and how to trigger the bug. Other developers will discuss it with you and propose a proper solution. 2. Post a pull request after fixing the bug and adding corresponding unit tests. **New Feature or Enhancement** 1. If the modification involves significant changes, you should create an issue to discuss with our developers and propose a proper design. 2. Post a Pull Request after implementing the new feature or enhancement and add corresponding unit tests. **Document** You can directly post a pull request to fix documents. If you want to add a document, you should first create an issue to check if it is reasonable. ### Pull Request Workflow If you're not familiar with Pull Requests, don't worry! The following guidance will tell you how to create a Pull Request step by step.
If you want to dive deeper into the Pull Request workflow, you can refer to the [official documents](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests). #### 1. Fork and clone If you are posting a pull request for the first time, you should fork the LMDeploy repository by clicking the **Fork** button in the top right corner of the GitHub page, and the forked repository will appear under your GitHub profile. Then, you can clone the repository locally: ```shell git clone git@github.com:{username}/lmdeploy.git ``` After that, you should add the official repository as the upstream repository: ```bash git remote add upstream git@github.com:InternLM/lmdeploy.git ``` Check whether the remote repository has been added successfully with `git remote -v`: ```bash origin git@github.com:{username}/lmdeploy.git (fetch) origin git@github.com:{username}/lmdeploy.git (push) upstream git@github.com:InternLM/lmdeploy.git (fetch) upstream git@github.com:InternLM/lmdeploy.git (push) ``` > Here's a brief introduction to origin and upstream. When we use "git clone", we create an "origin" remote by default, which points to the repository we cloned from. As for "upstream", we add it ourselves to point to the target repository. Of course, if you don't like the name "upstream", you could name it as you wish. Usually, we'll push the code to "origin". If the pushed code conflicts with the latest code in the official repository ("upstream"), we should pull the latest code from upstream to resolve the conflicts, and then push to "origin" again. The posted Pull Request will be updated automatically. #### 2. Configure pre-commit You should configure [pre-commit](https://pre-commit.com/#intro) in the local development environment to make sure the code style matches that of LMDeploy. **Note**: The following commands should be executed in the lmdeploy directory.
```shell pip install -U pre-commit pre-commit install ``` Check that pre-commit is configured successfully, and install the hooks defined in `.pre-commit-config.yaml`: ```shell pre-commit run --all-files ``` If the installation process is interrupted, you can repeatedly run `pre-commit run ... ` to continue the installation. If the code does not conform to the code style specification, pre-commit will raise a warning and fix some of the errors automatically. If we want to commit our code bypassing the pre-commit hook, we can use the `--no-verify` option (**only for temporary commits**). ```shell git commit -m "xxx" --no-verify ``` #### 3. Create a development branch After configuring pre-commit, we should create a branch based on the main branch to develop the new feature or fix the bug. The proposed branch name is `username/pr_name`: ```shell git checkout -b yhc/refactor_contributing_doc ``` In subsequent development, if the main branch of the local repository falls behind the main branch of "upstream", we need to pull from upstream to synchronize it with the following command: ```shell git pull upstream main ``` #### 4. Commit the code and pass the unit test - lmdeploy introduces mypy to do static type checking to increase the robustness of the code. Therefore, we need to add Type Hints to our code and pass the mypy check. If you are not familiar with Type Hints, you can refer to [this tutorial](https://docs.python.org/3/library/typing.html). - The committed code should pass the unit tests: ```shell # Pass all unit tests pytest tests # Run the unit tests in a single file pytest tests/test_lmdeploy/test_model.py ``` If the unit tests fail due to missing dependencies, install the required packages listed under `requirements/` first. - If the documents are modified/added, we should check the rendering result by following the [guidance](#document-rendering). #### 5.
Push the code to remote We can push the local commits to remote after they pass the unit tests and pre-commit checks. You can associate the local branch with the remote branch by adding the `-u` option. ```shell git push -u origin {branch_name} ``` This will allow you to use the `git push` command to push code directly next time, without having to specify a branch or the remote repository. #### 6. Create a Pull Request (1) Create a pull request in GitHub's Pull Request interface. (2) Modify the PR description according to the guidelines so that other developers can better understand your changes. Find more details about Pull Request description in [pull request guidelines](#pr-specs). **note** (a) The Pull Request description should contain the reason for the change, the content of the change, and the impact of the change, and be associated with the relevant Issue (see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)). (b) If it is your first contribution, please sign the CLA. (c) Check whether the Pull Request passes the CI. LMDeploy will run unit tests for the posted Pull Request on different platforms (Linux, Windows, Mac), based on different versions of Python, PyTorch, and CUDA to make sure the code is correct. You can see the detailed test information by clicking `Details` next to each CI check, so that you can fix the code if needed. (3) If the Pull Request passes the CI, then you can wait for the review from other developers. You'll modify the code based on the reviewer's comments, and repeat the steps [4](#4-commit-the-code-and-pass-the-unit-test)-[5](#5-push-the-code-to-remote) until all reviewers approve it. Then, we will merge it ASAP. #### 7. Resolve conflicts If your local branch conflicts with the latest main branch of "upstream", you'll need to resolve them.
There are two ways to do this: ```shell git fetch --all --prune git rebase upstream/main ``` or ```shell git fetch --all --prune git merge upstream/main ``` If you are comfortable handling conflicts, you can use rebase to resolve them, as this will keep your commit logs tidy. If you are not familiar with `rebase`, then you can use `merge` to resolve conflicts. ### Guidance #### Document rendering If the documents are modified/added, we should check the rendering result. Install the dependencies and run the following commands to render the documents and check the results: ```shell pip install -r requirements/docs.txt cd docs/zh_cn/ # or docs/en make html # check file in ./docs/zh_cn/_build/html/index.html ``` ### Code style #### Python We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. We use the following tools for linting and formatting: - [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools. - [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports. - [yapf](https://github.com/google/yapf): A formatter for Python files. - [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files. - [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files. - [docformatter](https://github.com/myint/docformatter): A formatter for docstrings. We use a [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`, `markdown files`, fixes `end-of-files`, `double-quoted-strings`, `python-encoding-pragma`, `mixed-line-ending`, and sorts `requirements.txt` automatically on every commit. The config for the pre-commit hook is stored in [.pre-commit-config.yaml](../.pre-commit-config.yaml).
#### C++ and CUDA The clang-format config is stored in [.clang-format](../.clang-format). And it's recommended to use clang-format version **11**. Please do not use older or newer versions as they will result in differences after formatting, which can cause the [lint](https://github.com/InternLM/lmdeploy/blob/main/.github/workflows/lint.yml#L25) to fail. ### PR Specs 1. Use [pre-commit](https://pre-commit.com) hook to avoid issues of code style 2. One short-time branch should be matched with only one PR 3. Accomplish a detailed change in one PR. Avoid large PR - Bad: Support Faster R-CNN - Acceptable: Add a box head to Faster R-CNN - Good: Add a parameter to box head to support custom conv-layer number 4. Provide clear and significant commit message 5. Provide clear and meaningful PR description - Task name should be clarified in title. The general format is: \[Prefix\] Short description of the PR (Suffix) - Prefix: add new feature \[Feature\], fix bug \[Fix\], related to documents \[Docs\], in developing \[WIP\] (which will not be reviewed temporarily) - Introduce main changes, results and influences on other modules in short description - Associate related issues and pull requests with a milestone ================================================ FILE: .github/ISSUE_TEMPLATE/1-bug-report.yml ================================================ name: 🐞 Bug report description: Create a report to help us reproduce and fix the bug title: "[Bug] " labels: ['Bug'] body: - type: checkboxes attributes: label: Checklist options: - label: 1. I have searched related issues but cannot get the expected help. - label: 2. The bug has not been fixed in the latest version. - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. 
- type: textarea attributes: label: Describe the bug description: A clear and concise description of what the bug is. validations: required: true - type: textarea attributes: label: Reproduction description: | 1. What command or script did you run? placeholder: | A placeholder for the command. validations: required: true - type: textarea attributes: label: Environment description: | 1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here. 2. You may add additional information that may be helpful for locating the problem, such as - Which **model** are you using? - How you installed PyTorch \[e.g., pip, conda, source\] - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) placeholder: Environment here. render: Shell validations: required: true - type: textarea attributes: label: Error traceback description: | If applicable, paste the error traceback here. placeholder: Logs and traceback here. render: Shell - type: markdown attributes: value: > If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated! Thanks for your bug report. We appreciate it a lot. ================================================ FILE: .github/ISSUE_TEMPLATE/2-feature-request.yml ================================================ name: 🚀 Feature request description: Suggest an idea for this project title: "[Feature] " body: - type: markdown attributes: value: | We strongly appreciate you creating a PR to implement this feature [here](https://github.com/InternLM/lmdeploy/pulls)! If you need our help, please fill in as much of the following form as you're able to. **The less clear the description, the longer it will take to solve it.** - type: textarea attributes: label: Motivation description: | A clear and concise description of the motivation of the feature. Ex1.
It is inconvenient when \[....\]. validations: required: true - type: textarea attributes: label: Related resources description: | If there is an official code release or a third-party implementation, please also provide the information here, which would be very helpful. - type: textarea attributes: label: Additional context description: | Add any other context or screenshots about the feature request here. If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated. ================================================ FILE: .github/ISSUE_TEMPLATE/3-documentation.yml ================================================ name: 📚 Documentation description: Report an issue related to the documentation. labels: "kind/doc,status/unconfirmed" title: "[Docs] " body: - type: textarea attributes: label: 📚 The doc issue description: > A clear and concise description of the issue. validations: required: true - type: textarea attributes: label: Suggest a potential alternative/fix description: > Tell us how we could improve the documentation in this regard. - type: markdown attributes: value: > Thanks for contributing 🎉! ================================================ FILE: .github/pull_request_template.md ================================================ Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and more likely to receive feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Please describe the motivation of this PR and the goal you want to achieve through this PR. ## Modification Please briefly describe what modification is made in this PR. ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness. 3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects. 4. The documentation has been modified accordingly, like docstring or example tutorials. ================================================ FILE: .github/release.yml ================================================ changelog: categories: - title: 🚀 Features labels: - feature - enhancement - title: 💥 Improvements labels: - improvement - title: 🐞 Bug fixes labels: - bug - Bug:P0 - Bug:P1 - Bug:P2 - Bug:P3 - title: 📚 Documentations labels: - documentation - title: 🌐 Other labels: - '*' exclude: labels: - feature - enhancement - improvement - bug - documentation - Bug:P0 - Bug:P1 - Bug:P2 - Bug:P3 ================================================ FILE: .github/scripts/action_tools.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import glob import json import logging import os import shutil import subprocess import time from collections import OrderedDict from typing import List import fire import pandas as pd from mmengine.config import Config def run_cmd(cmd_lines: List[str], log_path: str, cwd: str = None): """ Args: cmd_lines: (list[str]): A command in multiple line style. log_path (str): Path to log file. cwd (str): Path to the current working directory. Returns: int: error code. 
""" import platform system = platform.system().lower() if system == 'windows': sep = r'`' else: # 'Linux', 'Darwin' sep = '\\' cmd_for_run = ' '.join(cmd_lines) cmd_for_log = f' {sep}\n'.join(cmd_lines) + '\n' with open(log_path, 'w', encoding='utf-8') as file_handler: file_handler.write(f'Command: {cmd_for_log}\n') file_handler.flush() process_res = subprocess.Popen(cmd_for_run, shell=True, cwd=cwd, stdout=file_handler, stderr=file_handler) process_res.wait() return_code = process_res.returncode if return_code != 0: logging.error(f'Got shell abnormal return code={return_code}') with open(log_path, 'r') as f: content = f.read() logging.error(f'Log error message\n{content}') return return_code def _append_summary(content): summary_file = os.environ['GITHUB_STEP_SUMMARY'] with open(summary_file, 'a') as f: f.write(content + '\n') def add_summary(csv_path: str): """Add csv file to github step summary. Args: csv_path (str): Input csv file. """ with open(csv_path, 'r') as fr: lines = fr.readlines() header = lines[0].strip().split(',') n_col = len(header) header = '|' + '|'.join(header) + '|' aligner = '|' + '|'.join([':-:'] * n_col) + '|' _append_summary(header) _append_summary(aligner) for line in lines[1:]: line = '|' + line.strip().replace(',', '|') + '|' _append_summary(line) _append_summary('\n') def evaluate(models: List[str], datasets: List[str], workspace: str, evaluate_type: str, max_num_workers: int = 8, is_smoke: bool = False): """Evaluate models from lmdeploy using opencompass. Args: models: Input models. workspace: Working directory. 
""" os.makedirs(workspace, exist_ok=True) output_csv = os.path.join(workspace, f'results_{evaluate_type}.csv') if os.path.exists(output_csv): os.remove(output_csv) num_model = len(models) for idx, ori_model in enumerate(models): print() print(50 * '==') print(f'Start evaluating {idx+1}/{num_model} {ori_model} ...') model = ori_model.lower() lmdeploy_dir = os.path.abspath(os.environ['LMDEPLOY_DIR']) config_path = os.path.join(lmdeploy_dir, f'.github/scripts/eval_{evaluate_type}_config.py') config_path_new = os.path.join(lmdeploy_dir, 'eval_lmdeploy.py') if os.path.exists(config_path_new): os.remove(config_path_new) shutil.copy(config_path, config_path_new) cfg = Config.fromfile(config_path_new) if not hasattr(cfg, model): logging.error(f'Model {model} not in configuration file') continue model_cfg = cfg[model] logging.info(f'Start evaluating {model} ...\\nn{model_cfg}\n\n') with open(config_path_new, 'a') as f: f.write(f'\ndatasets = {datasets}\n') if is_smoke: f.write('\nfor d in datasets:\n') f.write(" if d['reader_cfg'] is not None:\n") f.write(" d['reader_cfg']['test_range'] = '[0:50]'\n") if model.startswith('hf'): f.write(f'\nmodels = [*{model}]\n') else: f.write(f'\nmodels = [{model}]\n') work_dir = os.path.join(workspace, model) cmd_eval = [ f'opencompass {config_path_new} -w {work_dir} --reuse --max-num-workers {max_num_workers} --dump-res-length' # noqa: E501 ] eval_log = os.path.join(workspace, f'eval.{ori_model}.txt') start_time = time.time() ret = run_cmd(cmd_eval, log_path=eval_log, cwd=lmdeploy_dir) end_time = time.time() task_duration_seconds = round(end_time - start_time, 2) logging.info(f'task_duration_seconds: {task_duration_seconds}\n') if ret != 0: continue csv_files = glob.glob(f'{work_dir}/*/summary/summary_*.csv') if len(csv_files) < 1: logging.error(f'Did not find summary csv file {csv_files}') continue else: csv_file = max(csv_files, key=os.path.getctime) # print csv_txt to screen csv_txt = csv_file.replace('.csv', '.txt') if 
os.path.exists(csv_txt): with open(csv_txt, 'r') as f: print(f.read()) # parse evaluation results from csv file model_results = OrderedDict() with open(csv_file, 'r') as f: lines = f.readlines() for line in lines[1:]: row = line.strip().split(',') row = [_.strip() for _ in row] if row[-1] != '-': model_results[row[0]] = row[-1] crows_pairs_json = glob.glob(os.path.join(work_dir, '*/results/*/crows_pairs.json'), recursive=True) if len(crows_pairs_json) == 1: with open(crows_pairs_json[0], 'r') as f: acc = json.load(f)['accuracy'] acc = f'{float(acc):.2f}' # noqa E231 model_results['crows_pairs'] = acc logging.info(f'\n{model}\n{model_results}') dataset_names = list(model_results.keys()) row = ','.join([model, str(task_duration_seconds)] + [model_results[_] for _ in dataset_names]) if not os.path.exists(output_csv): with open(output_csv, 'w') as f: header = ','.join(['Model', 'task_duration_secs'] + dataset_names) f.write(header + '\n') f.write(row + '\n') else: with open(output_csv, 'a') as f: f.write(row + '\n') # write to github action summary _append_summary('## Evaluation Results') if os.path.exists(output_csv): add_summary(output_csv) def create_model_links(src_dir: str, dst_dir: str): """Create softlinks for models.""" paths = glob.glob(os.path.join(src_dir, '*')) model_paths = [os.path.abspath(p) for p in paths if os.path.isdir(p)] os.makedirs(dst_dir, exist_ok=True) for src in model_paths: _, model_name = os.path.split(src) dst = os.path.join(dst_dir, model_name) if not os.path.exists(dst): os.symlink(src, dst) else: logging.warning(f'Model_path exists: {dst}') def generate_benchmark_report(report_path: str): # write to github action summary _append_summary('## Benchmark Results Start') subfolders = [f.path for f in os.scandir(report_path) if f.is_dir()] for dir_path in subfolders: second_subfolders = [f.path for f in sorted(os.scandir(dir_path), key=lambda x: x.name) if f.is_dir()] for sec_dir_path in second_subfolders: model = 
sec_dir_path.replace(report_path + '/', '') print('-' * 25, model, '-' * 25) _append_summary('-' * 25 + model + '-' * 25 + '\n') benchmark_subfolders = [ f.path for f in sorted(os.scandir(sec_dir_path), key=lambda x: x.name) if f.is_dir() ] for backend_subfolder in benchmark_subfolders: benchmark_type = backend_subfolder.replace(sec_dir_path + '/', '') print('*' * 10, benchmark_type, '*' * 10) _append_summary('-' * 10 + benchmark_type + '-' * 10 + '\n') merged_csv_path = os.path.join(backend_subfolder, 'summary.csv') csv_files = glob.glob(os.path.join(backend_subfolder, '*.csv')) average_csv_path = os.path.join(backend_subfolder, 'average.csv') if merged_csv_path in csv_files: csv_files.remove(merged_csv_path) if average_csv_path in csv_files: csv_files.remove(average_csv_path) merged_df = pd.DataFrame() if len(csv_files) > 0: for f in csv_files: df = pd.read_csv(f) merged_df = pd.concat([merged_df, df], ignore_index=True) if 'throughput' in backend_subfolder or 'longtext' in backend_subfolder: merged_df = merged_df.sort_values(by=merged_df.columns[1]) grouped_df = merged_df.groupby(merged_df.columns[1]) else: merged_df = merged_df.sort_values(by=merged_df.columns[0]) grouped_df = merged_df.groupby(merged_df.columns[0]) if 'generation' not in backend_subfolder: average_values = grouped_df.pipe((lambda group: { 'mean': group.mean(numeric_only=True).round(decimals=3) }))['mean'] average_values.to_csv(average_csv_path, index=True) avg_df = pd.read_csv(average_csv_path) merged_df = pd.concat([merged_df, avg_df], ignore_index=True) add_summary(average_csv_path) merged_df.to_csv(merged_csv_path, index=False) if 'generation' in backend_subfolder: add_summary(merged_csv_path) _append_summary('## Benchmark Results End') def generate_csv_from_profile_result(file_path: str, out_path: str): with open(file_path, 'r') as f: data = f.readlines() data = [json.loads(line) for line in data] data_csv = [] for item in data: row = [ item.get('request_rate'), item.get('completed'), 
round(item.get('completed') / item.get('duration'), 3), round(item.get('median_ttft_ms'), 3), round(item.get('output_throughput'), 3) ] data_csv.append(row) import csv with open(out_path, 'w', newline='') as f: writer = csv.writer(f) writer.writerow(['request_rate', 'completed', 'RPM', 'median_ttft_ms', 'output_throughput']) writer.writerows(data_csv) def generate_output_for_evaluation(result_dir: str): # find latest result latest_csv_file = find_csv_files(result_dir) df = pd.read_csv(latest_csv_file) transposed_df = df.T head_part = transposed_df.head(4) tail_part = transposed_df[4:] sorted_tail_part = tail_part.sort_index() transposed_df = pd.concat([head_part, sorted_tail_part]) transposed_df.to_csv('transposed_output.csv', header=False, index=True) # output to github action summary add_summary('transposed_output.csv') def find_csv_files(directory): csv_files = [] for root, dirs, files in os.walk(directory): for file in files: if file.endswith('.csv') and file.startswith('summary'): csv_files.append(os.path.join(root, file)) csv_files_with_time = {f: os.path.getctime(f) for f in csv_files} sorted_csv_files = sorted(csv_files_with_time.items(), key=lambda x: x[1]) latest_csv_file = sorted_csv_files[-1][0] return latest_csv_file if __name__ == '__main__': fire.Fire() ================================================ FILE: .github/scripts/check_lmdeploy.py ================================================ # Copyright (c) MegFlow. All rights reserved. 
import glob import os import fire def check_module_init(root: str): """Check if a module has __init__.py file.""" all_files = glob.glob(os.path.join(root, '**/*'), recursive=True) not_exist = [] for d in all_files: if not os.path.isdir(d): continue if '__pycache__' in d: continue elif d.startswith('lmdeploy/bin'): continue elif d.startswith('lmdeploy/lib'): continue elif d.startswith('lmdeploy/monitoring'): continue elif d.startswith('lmdeploy/serve/turbomind/triton_models'): continue elif d.startswith('lmdeploy/serve/turbomind/triton_python_backend'): continue init_file = os.path.join(d, '__init__.py') if not os.path.exists(init_file): not_exist.append(init_file) assert len(not_exist) == 0, f'Missing files: {not_exist}' if __name__ == '__main__': fire.Fire() ================================================ FILE: .github/scripts/doc_link_checker.py ================================================ # Copyright (c) MegFlow. All rights reserved. # /bin/python3 import argparse import os import re def make_parser(): parser = argparse.ArgumentParser('Doc link checker') parser.add_argument('--http', default=False, type=bool, help='check http or not ') parser.add_argument('--target', default='./docs', type=str, help='the directory or file to check') return parser pattern = re.compile(r'\[.*?\]\(.*?\)') def analyze_doc(home, path): print('analyze {}'.format(path)) problem_list = [] code_block = 0 with open(path) as f: lines = f.readlines() for line in lines: line = line.strip() if line.startswith('```'): code_block = 1 - code_block if code_block > 0: continue if '[' in line and ']' in line and '(' in line and ')' in line: all = pattern.findall(line) for item in all: # skip ![]() if item.find('[') == item.find(']') - 1: continue # process the case [text()]() offset = item.find('](') if offset == -1: continue item = item[offset:] start = item.find('(') end = item.find(')') ref = item[start + 1:end] if ref.startswith('http') or ref.startswith('#'): continue if '.md#' in ref: 
ref = ref[ref.find('#'):] fullpath = os.path.join(home, ref) if not os.path.exists(fullpath): problem_list.append(ref) else: continue if len(problem_list) > 0: print(f'{path}:') for item in problem_list: print(f'\t {item}') print('\n') raise Exception('found link error') def traverse(target): if os.path.isfile(target): analyze_doc(os.path.dirname(target), target) return for home, dirs, files in os.walk(target): for filename in files: if filename.endswith('.md'): path = os.path.join(home, filename) if os.path.islink(path) is False: analyze_doc(home, path) if __name__ == '__main__': args = make_parser().parse_args() traverse(args.target) ================================================ FILE: .github/scripts/eval_base_config.py ================================================ from copy import deepcopy from mmengine.config import read_base from opencompass.models import TurboMindModel with read_base(): # choose a list of datasets from opencompass.configs.datasets.ARC_c.ARC_c_few_shot_ppl import ARC_c_datasets # noqa: F401, E501 from opencompass.configs.datasets.bbh.bbh_gen_98fba6 import bbh_datasets # noqa: F401, E501 from opencompass.configs.datasets.ceval.ceval_ppl import ceval_datasets # noqa: F401, E501 from opencompass.configs.datasets.cmmlu.cmmlu_ppl_041cbf import cmmlu_datasets # noqa: F401, E501 from opencompass.configs.datasets.crowspairs.crowspairs_ppl import crowspairs_datasets # noqa: F401, E501 from opencompass.configs.datasets.drop.drop_gen_a2697c import drop_datasets # noqa: F401, E501 # Corebench v1.7 from opencompass.configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_d21e37 import \ GaokaoBench_datasets # noqa: F401, E501 from opencompass.configs.datasets.gpqa.gpqa_few_shot_ppl_4b5a83 import gpqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.gsm8k.gsm8k_gen_17d0dc import gsm8k_datasets # noqa: F401, E501 from opencompass.configs.datasets.hellaswag.hellaswag_10shot_ppl_59c85e import \ hellaswag_datasets # noqa: F401, E501 from 
opencompass.configs.datasets.humaneval.internal_humaneval_gen_ce6b06 import \ humaneval_datasets as humaneval_v2_datasets # noqa: F401, E501 from opencompass.configs.datasets.humaneval.internal_humaneval_gen_d2537e import \ humaneval_datasets # noqa: F401, E501 from opencompass.configs.datasets.math.math_4shot_base_gen_43d5b6 import math_datasets # noqa: F401, E501 from opencompass.configs.datasets.MathBench.mathbench_2024_few_shot_mixed_4a3fd4 import \ mathbench_datasets # noqa: F401, E501 from opencompass.configs.datasets.mbpp.sanitized_mbpp_gen_742f0c import sanitized_mbpp_datasets # noqa: F401, E501 from opencompass.configs.datasets.mmlu.mmlu_ppl_ac766d import mmlu_datasets # noqa: F401, E501 from opencompass.configs.datasets.mmlu_pro.mmlu_pro_few_shot_gen_bfaf90 import mmlu_pro_datasets # noqa: F401, E501 from opencompass.configs.datasets.nq.nq_open_1shot_gen_20a989 import nq_datasets # noqa: F401, E501 from opencompass.configs.datasets.race.race_few_shot_ppl import race_datasets # noqa: F401, E501 from opencompass.configs.datasets.SuperGLUE_BoolQ.SuperGLUE_BoolQ_few_shot_ppl import \ BoolQ_datasets # noqa: F401, E501 from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets # noqa: F401, E501 from opencompass.configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import \ triviaqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.wikibench.wikibench_few_shot_ppl_c23d79 import \ wikibench_datasets # noqa: F401, E501 from opencompass.configs.datasets.winogrande.winogrande_5shot_ll_252f01 import \ winogrande_datasets # noqa: F401, E501 # Summary Groups from opencompass.configs.summarizers.groups.cmmlu import cmmlu_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.GaokaoBench import GaokaoBench_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.mathbench_v1_2024 import \ mathbench_2024_summary_groups # noqa: F401, E501 from 
opencompass.configs.summarizers.groups.mmlu import mmlu_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups # noqa: F401, E501 # read models race_datasets = [race_datasets[1]] mmlu_datasets = [ x for x in mmlu_datasets if x['abbr'].replace('lukaemon_mmlu_', '') in [ 'business_ethics', 'clinical_knowledge', 'college_medicine', 'global_facts', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology' ] ] summarizer = dict( dataset_abbrs=[ ['race-high', 'accuracy'], ['ARC-c', 'accuracy'], ['BoolQ', 'accuracy'], ['mmlu_pro', 'naive_average'], ['GPQA_diamond', 'accuracy'], ['cmmlu', 'naive_average'], ['mmlu', 'naive_average'], ['drop', 'accuracy'], ['bbh', 'naive_average'], ['math', 'accuracy'], ['openai_humaneval', 'humaneval_pass@1'], ['openai_humaneval_v2', 'humaneval_pass@1'], ['sanitized_mbpp', 'score'], ['wikibench-wiki-single_choice_cncircular', 'perf_4'], ['gsm8k', 'accuracy'], ['GaokaoBench', 'weighted_average'], ['triviaqa_wiki_1shot', 'score'], ['nq_open_1shot', 'score'], ['winogrande', 'accuracy'], ['hellaswag', 'accuracy'], ['TheoremQA', 'score'], '###### MathBench-A: Application Part ######', 'college', 'high', 'middle', 'primary', 'arithmetic', 'mathbench-a (average)', '###### MathBench-T: Theory Part ######', 'college_knowledge', 'high_knowledge', 'middle_knowledge', 'primary_knowledge', 'mathbench-t (average)', '###### Overall: Average between MathBench-A and MathBench-T ######', 'Overall', '', 'mmlu', 'mmlu-stem', 'mmlu-social-science', 'mmlu-humanities', 'mmlu-other', 'cmmlu', 'cmmlu-stem', 'cmmlu-social-science', 'cmmlu-humanities', 'cmmlu-other', 'cmmlu-china-specific', 'mmlu_pro', 'mmlu_pro_biology', 'mmlu_pro_business', 'mmlu_pro_chemistry', 'mmlu_pro_computer_science', 'mmlu_pro_economics', 'mmlu_pro_engineering', 'mmlu_pro_health', 'mmlu_pro_history', 'mmlu_pro_law', 'mmlu_pro_math', 
'mmlu_pro_philosophy', 'mmlu_pro_physics', 'mmlu_pro_psychology', 'mmlu_pro_other', ], summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []), ) base_model = dict( type=TurboMindModel, engine_config=dict(session_len=7168, tp=1), gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024), max_seq_len=7168, max_out_len=1024, batch_size=32, run_cfg=dict(num_gpus=1), ) turbomind_qwen2_5_1_5b = deepcopy(base_model) turbomind_qwen2_5_1_5b['path'] = 'Qwen/Qwen2.5-1.5B' turbomind_qwen2_5_1_5b['abbr'] = 'turbomind_qwen2_5_1_5b' turbomind_qwen2_5_7b = deepcopy(base_model) turbomind_qwen2_5_7b['path'] = 'Qwen/Qwen2.5-7B' turbomind_qwen2_5_7b['abbr'] = 'turbomind_qwen2_5_7b' turbomind_qwen2_5_32b = deepcopy(base_model) turbomind_qwen2_5_32b['path'] = 'Qwen/Qwen2.5-32B' turbomind_qwen2_5_32b['abbr'] = 'turbomind_qwen2_5_32b' turbomind_qwen2_5_32b['run_cfg']['num_gpus'] = 2 turbomind_qwen2_5_32b['engine_config']['tp'] = 2 turbomind_internlm2_5_7b = deepcopy(base_model) turbomind_internlm2_5_7b['path'] = 'internlm/internlm2_5-7b-chat' turbomind_internlm2_5_7b['abbr'] = 'turbomind_internlm2_5_7b' turbomind_glm_4_9b = deepcopy(base_model) turbomind_glm_4_9b['path'] = 'THUDM/glm-4-9b' turbomind_glm_4_9b['abbr'] = 'turbomind_glm_4_9b' turbomind_llama_3_70b = deepcopy(base_model) turbomind_llama_3_70b['path'] = 'meta-llama/Meta-Llama-3-70B' turbomind_llama_3_70b['abbr'] = 'turbomind_llama_3_70b' turbomind_llama_3_70b['run_cfg']['num_gpus'] = 4 turbomind_llama_3_70b['engine_config']['tp'] = 4 turbomind_llama_3_1_8b = deepcopy(base_model) turbomind_llama_3_1_8b['path'] = 'meta-llama/Llama-3.1-8B' turbomind_llama_3_1_8b['abbr'] = 'turbomind_llama_3_1_8b' turbomind_qwen3_0_6b_base = deepcopy(base_model) turbomind_qwen3_0_6b_base['path'] = 'Qwen/Qwen3-0.6B-Base' turbomind_qwen3_0_6b_base['abbr'] = 'turbomind_qwen3_0_6b_base' turbomind_qwen3_8b_base = deepcopy(base_model) turbomind_qwen3_8b_base['path'] = 'Qwen/Qwen3-8B-Base' 
turbomind_qwen3_8b_base['abbr'] = 'turbomind_qwen3_8b_base' turbomind_qwen3_30b_A3B_base = deepcopy(base_model) turbomind_qwen3_30b_A3B_base['path'] = 'Qwen/Qwen3-30B-A3B-Base' turbomind_qwen3_30b_A3B_base['abbr'] = 'turbomind_qwen3_30b_A3B_base' turbomind_qwen3_30b_A3B_base['run_cfg']['num_gpus'] = 2 turbomind_qwen3_30b_A3B_base['engine_config']['tp'] = 2 pytorch_qwen2_5_1_5b = deepcopy(base_model) pytorch_qwen2_5_1_5b['path'] = 'Qwen/Qwen2.5-1.5B' pytorch_qwen2_5_1_5b['abbr'] = 'pytorch_qwen2_5_1_5b' pytorch_qwen2_5_7b = deepcopy(base_model) pytorch_qwen2_5_7b['path'] = 'Qwen/Qwen2.5-7B' pytorch_qwen2_5_7b['abbr'] = 'pytorch_qwen2_5_7b' pytorch_qwen2_5_32b = deepcopy(base_model) pytorch_qwen2_5_32b['path'] = 'Qwen/Qwen2.5-32B' pytorch_qwen2_5_32b['abbr'] = 'pytorch_qwen2_5_32b' pytorch_qwen2_5_32b['run_cfg']['num_gpus'] = 2 pytorch_qwen2_5_32b['engine_config']['tp'] = 2 pytorch_internlm2_5_7b = deepcopy(base_model) pytorch_internlm2_5_7b['path'] = 'internlm/internlm2_5-7b-chat' pytorch_internlm2_5_7b['abbr'] = 'pytorch_internlm2_5_7b' pytorch_gemma_2_9b = deepcopy(base_model) pytorch_gemma_2_9b['path'] = 'google/gemma-2-9b' pytorch_gemma_2_9b['abbr'] = 'pytorch_gemma_2_9b' pytorch_llama_3_70b = deepcopy(base_model) pytorch_llama_3_70b['path'] = 'meta-llama/Meta-Llama-3-70B' pytorch_llama_3_70b['abbr'] = 'pytorch_llama_3_70b' pytorch_llama_3_70b['run_cfg']['num_gpus'] = 4 pytorch_llama_3_70b['engine_config']['tp'] = 4 pytorch_llama_3_1_8b = deepcopy(base_model) pytorch_llama_3_1_8b['path'] = 'meta-llama/Llama-3.1-8B' pytorch_llama_3_1_8b['abbr'] = 'pytorch_llama_3_1_8b' pytorch_qwen3_0_6b_base = deepcopy(base_model) pytorch_qwen3_0_6b_base['path'] = 'Qwen/Qwen3-0.6B-Base' pytorch_qwen3_0_6b_base['abbr'] = 'pytorch_qwen3_0_6b_base' pytorch_qwen3_8b_base = deepcopy(base_model) pytorch_qwen3_8b_base['path'] = 'Qwen/Qwen3-8B-Base' pytorch_qwen3_8b_base['abbr'] = 'pytorch_qwen3_8b_base' pytorch_qwen3_30b_A3B_base = deepcopy(base_model) 
pytorch_qwen3_30b_A3B_base['path'] = 'Qwen/Qwen3-30B-A3B-Base' pytorch_qwen3_30b_A3B_base['abbr'] = 'pytorch_qwen3_30b_A3B_base' pytorch_qwen3_30b_A3B_base['run_cfg']['num_gpus'] = 2 pytorch_qwen3_30b_A3B_base['engine_config']['tp'] = 2 for model in [v for k, v in locals().items() if k.startswith('pytorch_')]: model['backend'] = 'pytorch' ================================================ FILE: .github/scripts/eval_chat_config.py ================================================ from copy import deepcopy from mmengine.config import read_base from opencompass.models import TurboMindModelwithChatTemplate from opencompass.utils.text_postprocessors import extract_non_reasoning_content with read_base(): # choose a list of datasets from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets # noqa: F401, E501 from opencompass.configs.datasets.ceval.ceval_gen_2daf24 import ceval_datasets # noqa: F401, E501 from opencompass.configs.datasets.cmmlu.cmmlu_gen_c13365 import cmmlu_datasets # noqa: F401, E501 from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets # noqa: F401, E501 from opencompass.configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import \ GaokaoBench_datasets # noqa: F401, E501 from opencompass.configs.datasets.gpqa.gpqa_gen_4baadb import gpqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets # noqa: F401, E501 from opencompass.configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import \ hellaswag_datasets # noqa: F401, E501 from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets # noqa: F401, E501 from opencompass.configs.datasets.IFEval.IFEval_gen_3321a3 import ifeval_datasets # noqa: F401, E501 from opencompass.configs.datasets.math.math_0shot_gen_393424 import math_datasets # noqa: F401, E501 from opencompass.configs.datasets.mbpp.sanitized_mbpp_gen_a0fc46 import sanitized_mbpp_datasets # noqa: F401, 
E501 from opencompass.configs.datasets.mmlu.mmlu_gen_4d595a import mmlu_datasets # noqa: F401, E501 from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import \ mmlu_pro_datasets # noqa: F401, E501 from opencompass.configs.datasets.nq.nq_open_1shot_gen_01cf41 import nq_datasets # noqa: F401, E501 from opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets # noqa: F401, E501 from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets # noqa: F401, E501 from opencompass.configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_eaf81e import \ triviaqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.winogrande.winogrande_5shot_gen_b36770 import \ winogrande_datasets # noqa: F401, E501 # read models from opencompass.configs.models.baichuan.hf_baichuan2_7b_chat import \ models as hf_baichuan2_chat_7b # noqa: F401, E501 from opencompass.configs.models.gemma.hf_gemma2_9b_it import models as hf_gemma2_9b_it # noqa: F401, E501 from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b_chat import \ models as hf_internlm2_5_7b_chat # noqa: F401, E501 from opencompass.configs.models.hf_internlm.hf_internlm2_5_20b_chat import \ models as hf_internlm2_5_20b_chat # noqa: F401, E501 from opencompass.configs.models.hf_internlm.hf_internlm2_chat_7b import \ models as hf_internlm2_chat_7b # noqa: F401, E501 from opencompass.configs.models.hf_internlm.hf_internlm2_chat_20b import \ models as hf_internlm2_chat_20b # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \ models as lmdeploy_internlm2_5_7b_chat # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import \ models as lmdeploy_internlm2_5_20b_chat # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b import \ models as lmdeploy_internlm2_chat_7b # noqa: F401, E501 from 
opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_20b import \ models as lmdeploy_internlm2_chat_20b # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import \ models as lmdeploy_internlm3_8b_instruct # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm_chat_7b import \ models as lmdeploy_internlm_chat_7b # noqa: F401, E501 from opencompass.configs.models.hf_llama.hf_llama2_7b_chat import models as hf_llama2_chat_7b # noqa: F401, E501 from opencompass.configs.models.hf_llama.hf_llama3_1_8b_instruct import \ models as hf_llama3_1_8b_instruct # noqa: F401, E501 from opencompass.configs.models.hf_llama.hf_llama3_8b_instruct import \ models as hf_llama_3_8b_instruct # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama2_7b_chat import \ models as lmdeploy_llama2_7b_chat # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import \ models as lmdeploy_llama3_1_8b_instruct # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import \ models as lmdeploy_llama3_8b_instruct # noqa: F401, E501 from opencompass.configs.models.mistral.hf_mistral_7b_instruct_v0_1 import \ models as hf_mistral_chat_7b # noqa: F401, E501 from opencompass.configs.models.mistral.hf_mixtral_8x7b_instruct_v0_1 import \ models as hf_mixtral_chat_8x7b # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \ models as lmdeploy_qwen2_5_7b_instruct # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import \ models as lmdeploy_qwen2_5_32b_instruct # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen1_5_7b_chat import models as hf_qwen1_5_chat_7b # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen1_5_moe_a2_7b_chat import \ models as hf_qwen1_5_moe_a2_7b_chat # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen2_7b_instruct 
import models as hf_qwen2_7b_instruct # noqa: F401, E501 from opencompass.configs.models.qwen.hf_qwen_7b_chat import models as hf_qwen_chat_7b # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen1_5_7b_chat import \ models as lmdeploy_qwen1_5_7b_chat # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import \ models as lmdeploy_qwen2_7b_instruct # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen_7b_chat import \ models as lmdeploy_qwen_7b_chat # noqa: F401, E501 # Summary Groups from opencompass.configs.summarizers.groups.bbh import bbh_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.cmmlu import cmmlu_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.ds1000 import ds1000_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.GaokaoBench import GaokaoBench_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.humanevalx import humanevalx_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.mathbench_v1_2024 import \ mathbench_2024_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.mmlu import mmlu_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.scicode import scicode_summary_groups # noqa: F401, E501 from opencompass.configs.summarizers.groups.teval import teval_summary_groups # noqa: F401, E501 llama2_meta_template = dict(round=[ dict(role='HUMAN', begin='[INST] ', end=' [/INST]'), dict(role='BOT', begin='', end='', generate=True), ], eos_token_id=2) MAX_SESSION_LEN = 2048 MAX_NEW_TOKENS = 1024 # ===== Configs for internlm/internlm2-chat-7b ===== turbomind_internlm2_chat_7b = deepcopy(*lmdeploy_internlm2_chat_7b) turbomind_internlm2_chat_7b_4bits = deepcopy(*lmdeploy_internlm2_chat_7b) 
turbomind_internlm2_chat_7b_kvint4 = deepcopy(*lmdeploy_internlm2_chat_7b)
turbomind_internlm2_chat_7b_kvint8 = deepcopy(*lmdeploy_internlm2_chat_7b)
pytorch_internlm2_chat_7b = deepcopy(*lmdeploy_internlm2_chat_7b)

# ===== Configs for internlm/internlm2_5_7b_chat =====
turbomind_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
pytorch_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)
pytorch_internlm2_5_7b_chat_w8a8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_batch1 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_batch1_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)

# ===== Configs for internlm/internlm3-8b-instruct =====
turbomind_internlm3_8b_instruct = deepcopy(*lmdeploy_internlm3_8b_instruct)
turbomind_internlm3_8b_instruct_4bits = deepcopy(*lmdeploy_internlm3_8b_instruct)
turbomind_internlm3_8b_instruct_kvint4 = deepcopy(*lmdeploy_internlm3_8b_instruct)
turbomind_internlm3_8b_instruct_kvint8 = deepcopy(*lmdeploy_internlm3_8b_instruct)
pytorch_internlm3_8b_instruct = deepcopy(*lmdeploy_internlm3_8b_instruct)
pytorch_internlm3_8b_instruct_w8a8 = deepcopy(*lmdeploy_internlm3_8b_instruct)

# ===== Configs for internlm/internlm2_5_20b_chat =====
turbomind_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)
turbomind_internlm2_5_20b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_20b_chat)
turbomind_internlm2_5_20b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_20b_chat)
turbomind_internlm2_5_20b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_20b_chat)
pytorch_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)

# ===== Configs for internlm/internlm2_chat_20b =====
turbomind_internlm2_chat_20b = deepcopy(*lmdeploy_internlm2_chat_20b)
turbomind_internlm2_chat_20b_4bits = deepcopy(*lmdeploy_internlm2_chat_20b)
turbomind_internlm2_chat_20b_kvint4 = deepcopy(*lmdeploy_internlm2_chat_20b)
turbomind_internlm2_chat_20b_kvint8 = deepcopy(*lmdeploy_internlm2_chat_20b)
pytorch_internlm2_chat_20b = deepcopy(*lmdeploy_internlm2_chat_20b)

# ===== Configs for Qwen/Qwen1.5-7B-Chat =====
turbomind_qwen1_5_7b_chat = deepcopy(*lmdeploy_qwen1_5_7b_chat)
turbomind_qwen1_5_7b_chat_4bits = deepcopy(*lmdeploy_qwen1_5_7b_chat)
turbomind_qwen1_5_7b_chat_kvint4 = deepcopy(*lmdeploy_qwen1_5_7b_chat)
turbomind_qwen1_5_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen1_5_7b_chat)
pytorch_qwen1_5_7b_chat = deepcopy(*lmdeploy_qwen1_5_7b_chat)

# ===== Configs for Qwen/Qwen-7B-Chat =====
turbomind_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)
turbomind_qwen_7b_chat_4bits = deepcopy(*lmdeploy_qwen_7b_chat)
turbomind_qwen_7b_chat_kvint4 = deepcopy(*lmdeploy_qwen_7b_chat)
turbomind_qwen_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen_7b_chat)
pytorch_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)

# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct =====
turbomind_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)
turbomind_llama3_8b_instruct_4bits = deepcopy(*lmdeploy_llama3_8b_instruct)
turbomind_llama3_8b_instruct_kvint4 = deepcopy(*lmdeploy_llama3_8b_instruct)
turbomind_llama3_8b_instruct_kvint8 = deepcopy(*lmdeploy_llama3_8b_instruct)
pytorch_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)

# ===== Configs for meta-llama/Meta-Llama-3.1-8B-Instruct =====
# The path is overridden once here; every other 3.1 variant copies this dict so it inherits that path.
turbomind_llama3_1_8b_instruct = deepcopy(*lmdeploy_llama3_1_8b_instruct)
turbomind_llama3_1_8b_instruct['path'] = 'meta-llama/Meta-Llama-3-1-8B-Instruct'
turbomind_llama3_1_8b_instruct_4bits = deepcopy(turbomind_llama3_1_8b_instruct)
turbomind_llama3_1_8b_instruct_kvint4 = deepcopy(turbomind_llama3_1_8b_instruct)
turbomind_llama3_1_8b_instruct_kvint8 = deepcopy(turbomind_llama3_1_8b_instruct)
pytorch_llama3_1_8b_instruct = deepcopy(turbomind_llama3_1_8b_instruct)
pytorch_llama3_1_8b_instruct_w8a8 = deepcopy(turbomind_llama3_1_8b_instruct)

# ===== Configs for Qwen/Qwen2-7B-Instruct =====
turbomind_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
turbomind_qwen2_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_7b_instruct)
turbomind_qwen2_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_7b_instruct)
turbomind_qwen2_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_7b_instruct)
pytorch_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
pytorch_qwen2_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_7b_instruct)

# ===== Configs for Qwen/Qwen2.5-7B-Instruct =====
turbomind_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
turbomind_qwen2_5_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
turbomind_qwen2_5_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
turbomind_qwen2_5_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
pytorch_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
pytorch_qwen2_5_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)

# ===== Configs for Qwen/Qwen2.5-32B-Instruct =====
turbomind_qwen2_5_32b_instruct = deepcopy(*lmdeploy_qwen2_5_32b_instruct)
turbomind_qwen2_5_32b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_32b_instruct)
turbomind_qwen2_5_32b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)
turbomind_qwen2_5_32b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)
pytorch_qwen2_5_32b_instruct = deepcopy(*lmdeploy_qwen2_5_32b_instruct)
pytorch_qwen2_5_32b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)

# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====
turbomind_llama2_7b_chat = deepcopy(*lmdeploy_llama2_7b_chat)
turbomind_llama2_7b_chat_4bits = deepcopy(*lmdeploy_llama2_7b_chat)
turbomind_llama2_7b_chat_kvint4 = deepcopy(*lmdeploy_llama2_7b_chat)
turbomind_llama2_7b_chat_kvint8 = deepcopy(*lmdeploy_llama2_7b_chat)

# Shared template for the Qwen3 variants below. Their abbr/path (and later tp/GPU count)
# are filled in by the prefix-matching loops that follow.
base_model = dict(type=TurboMindModelwithChatTemplate,
                  engine_config=dict(session_len=32768, max_batch_size=256),
                  gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=32768),
                  max_seq_len=32768,
                  max_out_len=32768,
                  batch_size=500,
                  pred_postprocessor=dict(type=extract_non_reasoning_content),
                  run_cfg=dict(num_gpus=1))

turbomind_qwen3_32b = deepcopy(base_model)
pytorch_qwen3_32b = deepcopy(base_model)
turbomind_qwen3_32b_4bits = deepcopy(base_model)
turbomind_qwen3_32b_kvint8 = deepcopy(base_model)

turbomind_qwen3_30b_a3b = deepcopy(base_model)
pytorch_qwen3_30b_a3b = deepcopy(base_model)
turbomind_qwen3_30b_a3b_4bits = deepcopy(base_model)
turbomind_qwen3_30b_a3b_kvint8 = deepcopy(base_model)
turbomind_qwen3_30b_a3b_fp8 = deepcopy(base_model)
pytorch_qwen3_30b_a3b_fp8 = deepcopy(base_model)
turbomind_qwen3_30b_a3b_fp8['engine_config']['cache_max_entry_count'] = 0.6

turbomind_qwen3_235b_a22b = deepcopy(base_model)
pytorch_qwen3_235b_a22b = deepcopy(base_model)
turbomind_qwen3_235b_a22b_4bits = deepcopy(base_model)
turbomind_qwen3_235b_a22b_kvint8 = deepcopy(base_model)
turbomind_qwen3_235b_a22b_fp8 = deepcopy(base_model)
pytorch_qwen3_235b_a22b_fp8 = deepcopy(base_model)

# update abbr/path for Qwen3-32B, Qwen3-30B-A3B, Qwen3-235B-A22B.
# Order matters: the generic '..._30b_a3b' / '..._235b_a22b' prefixes also match the
# '_fp8' variants, so the FP8-specific loops must run after them to override abbr/path.
for model in [
        v for k, v in locals().items() if k.startswith('turbomind_qwen3_32b') or k.startswith('pytorch_qwen3_32b')
]:
    model['abbr'] = 'qwen3_32b_turbomind'
    model['path'] = 'Qwen/Qwen3-32B'

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_30b_a3b') or k.startswith('pytorch_qwen3_30b_a3b')
]:
    model['abbr'] = 'qwen3_30b_a3b_turbomind'
    model['path'] = 'Qwen/Qwen3-30B-A3B'

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_30b_a3b_fp8') or k.startswith('pytorch_qwen3_30b_a3b_fp8')
]:
    model['abbr'] = 'qwen3_30b_a3b_fp8_turbomind'
    model['path'] = 'Qwen/Qwen3-30B-A3B-FP8'

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_235b_a22b') or k.startswith('pytorch_qwen3_235b_a22b')
]:
    model['abbr'] = 'qwen3_235b_a22b_turbomind'
    model['path'] = 'Qwen/Qwen3-235B-A22B'

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_235b_a22b_fp8') or k.startswith('pytorch_qwen3_235b_a22b_fp8')
]:
    model['abbr'] = 'qwen3_235b_a22b_fp8_turbomind'
    model['path'] = 'Qwen/Qwen3-235B-A22B-FP8'

# update config for turbomind, 4bits (AWQ), w8a8, kvint4, kvint8, pytorch and batch1 models.
# Variants are selected purely by their variable-name prefix/suffix.
for model in [v for k, v in locals().items() if k.startswith('turbomind_')]:
    model['engine_config']['max_batch_size'] = 512
    model['gen_config']['do_sample'] = False
    model['batch_size'] = 1000

# AWQ checkpoints are expected at '<original path>-inner-4bits'.
for model in [v for k, v in locals().items() if k.endswith('_4bits')]:
    model['engine_config']['model_format'] = 'awq'
    model['abbr'] = model['abbr'] + '_4bits'
    model['path'] = model['path'] + '-inner-4bits'

for model in [v for k, v in locals().items() if k.endswith('_w8a8')]:
    model['abbr'] = model['abbr'] + '_w8a8'
    model['path'] = model['path'] + '-inner-w8a8'

# quant_policy 4 / 8 selects the int4 / int8 KV-cache variants.
for model in [v for k, v in locals().items() if k.endswith('_kvint4')]:
    model['engine_config']['quant_policy'] = 4
    model['abbr'] = model['abbr'] + '_kvint4'

for model in [v for k, v in locals().items() if k.endswith('_kvint8')]:
    model['engine_config']['quant_policy'] = 8
    model['abbr'] = model['abbr'] + '_kvint8'

# Only dicts that exist at this point are touched; models defined further down
# (e.g. those built from `basic_pytorch_chat_tp1`) must set these fields themselves.
for model in [v for k, v in locals().items() if k.startswith('pytorch_')]:
    model['abbr'] = model['abbr'].replace('turbomind', 'pytorch')
    model['backend'] = 'pytorch'
    model['engine_config']['max_batch_size'] = 512
    model['gen_config']['do_sample'] = False
    model['batch_size'] = 1000

for model in [v for k, v in locals().items() if '_batch1' in k]:
    model['abbr'] = model['abbr'] + '_batch1'
    model['engine_config']['max_batch_size'] = 1
    model['batch_size'] = 1

# update parallelism for Qwen3-32B, Qwen3-30B-A3B, Qwen3-235B-A22B
for model in [
        v for k, v in locals().items() if k.startswith('turbomind_qwen3_32b') or k.startswith('pytorch_qwen3_32b')
]:
    model['run_cfg']['num_gpus'] = 2
    model['engine_config']['tp'] = 2
    model['engine_config']['max_batch_size'] = 1024
    model['batch_size'] = 2048

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_30b_a3b') or k.startswith('pytorch_qwen3_30b_a3b')
]:
    model['run_cfg']['num_gpus'] = 2
    model['engine_config']['tp'] = 2
    model['engine_config']['max_batch_size'] = 1024
    model['batch_size'] = 2048

for model in [
        v for k, v in locals().items()
        if k.startswith('turbomind_qwen3_235b_a22b') or k.startswith('pytorch_qwen3_235b_a22b')
]:
    model['run_cfg']['num_gpus'] = 8
    model['engine_config']['tp'] = 8
    model['engine_config']['max_batch_size'] = 1024
    model['batch_size'] = 2048

# The FP8 235B variants were set to tp=8 by the loop above; override them to 4 GPUs.
turbomind_qwen3_235b_a22b_fp8['engine_config']['cache_max_entry_count'] = 0.6
turbomind_qwen3_235b_a22b_fp8['engine_config']['tp'] = 4
turbomind_qwen3_235b_a22b_fp8['run_cfg']['num_gpus'] = 4
pytorch_qwen3_235b_a22b_fp8['engine_config']['tp'] = 4
pytorch_qwen3_235b_a22b_fp8['run_cfg']['num_gpus'] = 4

# Template for pytorch-only models. It is defined after the `pytorch_` loop above, so that
# loop never sets `backend='pytorch'` on models derived from it. Set it explicitly here;
# otherwise the wrapper's default backend is used (presumably turbomind — confirm in
# TurboMindModelwithChatTemplate).
basic_pytorch_chat_tp1 = dict(type=TurboMindModelwithChatTemplate,
                              backend='pytorch',
                              engine_config=dict(session_len=MAX_SESSION_LEN, max_batch_size=512, tp=1),
                              gen_config=dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS),
                              max_out_len=MAX_NEW_TOKENS,
                              max_seq_len=MAX_SESSION_LEN,
                              batch_size=1000,
                              run_cfg=dict(num_gpus=1))

# ===== Configs for Qwen/Qwen1.5-MoE-A2.7B-Chat =====
pytorch_qwen1_5_moe_2_7b_chat = deepcopy(basic_pytorch_chat_tp1)
pytorch_qwen1_5_moe_2_7b_chat['abbr'] = 'pytorch_qwen1_5_moe_2_7b_chat'
pytorch_qwen1_5_moe_2_7b_chat['path'] = 'Qwen/Qwen1.5-MoE-A2.7B-Chat'

# ===== Configs for google/gemma-2-9b-it =====
pytorch_gemma_2_9b_it = deepcopy(basic_pytorch_chat_tp1)
pytorch_gemma_2_9b_it['abbr'] = 'pytorch_gemma_2_9b_it'
pytorch_gemma_2_9b_it['path'] = 'google/gemma-2-9b-it'

# ===== Configs for google/gemma-2-27b-it =====
pytorch_gemma_2_27b_it = deepcopy(basic_pytorch_chat_tp1)
pytorch_gemma_2_27b_it['abbr'] = 'pytorch_gemma_2_27b_it'
pytorch_gemma_2_27b_it['path'] = 'google/gemma-2-27b-it'
pytorch_gemma_2_27b_it['run_cfg']['num_gpus'] = 2
pytorch_gemma_2_27b_it['engine_config']['tp'] = 2

# Keep only the second RACE split (the summarizer below reports 'race-high').
race_datasets = [race_datasets[1]]

# Summarizer
# Summarizer layout: [dataset_abbr, metric] pairs are reported first, then plain group
# names; '' entries act as blank separator rows between groups.
summarizer = dict(
    dataset_abbrs=[
        ['race-high', 'accuracy'],
        ['ARC-c', 'accuracy'],
        ['BoolQ', 'accuracy'],
        ['mmlu_pro', 'naive_average'],
        ['drop', 'accuracy'],
        ['bbh', 'naive_average'],
        ['GPQA_diamond', 'accuracy'],
        ['math', 'accuracy'],
        ['wikibench-wiki-single_choice_cncircular', 'perf_4'],
        ['openai_humaneval', 'humaneval_pass@1'],
        ['sanitized_mbpp', 'score'],
        ['cmmlu', 'naive_average'],
        ['mmlu', 'naive_average'],
        ['teval', 'naive_average'],
        ['SciCode', 'accuracy'],
        ['SciCode', 'sub_accuracy'],
        ['humanevalx', 'naive_average'],
        ['ds1000', 'naive_average'],
        ['IFEval', 'Prompt-level-strict-accuracy'],
        ['gsm8k', 'accuracy'],
        ['GaokaoBench', 'weighted_average'],
        ['triviaqa_wiki_1shot', 'score'],
        ['nq_open_1shot', 'score'],
        ['hellaswag', 'accuracy'],
        ['TheoremQA', 'score'],
        '###### MathBench-A: Application Part ######',
        'college',
        'high',
        'middle',
        'primary',
        'arithmetic',
        'mathbench-a (average)',
        '###### MathBench-T: Theory Part ######',
        'college_knowledge',
        'high_knowledge',
        'middle_knowledge',
        'primary_knowledge',
        'mathbench-t (average)',
        '###### Overall: Average between MathBench-A and MathBench-T ######',
        'Overall',
        '',
        # Previously written as `'', '' 'mmlu',` — a missing comma that relied on implicit
        # string concatenation ('' 'mmlu' == 'mmlu'); spelled out explicitly here.
        'mmlu',
        'mmlu-stem',
        'mmlu-social-science',
        'mmlu-humanities',
        'mmlu-other',
        '',
        'cmmlu',
        'cmmlu-stem',
        'cmmlu-social-science',
        'cmmlu-humanities',
        'cmmlu-other',
        'cmmlu-china-specific',
        '',
        'mmlu_pro',
        'mmlu_pro_biology',
        'mmlu_pro_business',
        'mmlu_pro_chemistry',
        'mmlu_pro_computer_science',
        'mmlu_pro_economics',
        'mmlu_pro_engineering',
        'mmlu_pro_health',
        'mmlu_pro_history',
        'mmlu_pro_law',
        'mmlu_pro_math',
        'mmlu_pro_philosophy',
        'mmlu_pro_physics',
        'mmlu_pro_psychology',
        'mmlu_pro_other',
        '',
        'humanevalx-python',
        'humanevalx-cpp',
        'humanevalx-go',
        'humanevalx-java',
        'humanevalx-js',
        '',
        'ds1000_Pandas',
        'ds1000_Numpy',
        'ds1000_Tensorflow',
        'ds1000_Scipy',
        'ds1000_Sklearn',
        'ds1000_Pytorch',
        'ds1000_Matplotlib',
    ],
    # Concatenate every module-level `*_summary_groups` list imported by this config.
    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)
================================================ FILE: .github/scripts/eval_regression_base_models.py ================================================ from copy import deepcopy from mmengine.config import read_base with read_base(): # choose a list of datasets from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.gsm8k.gsm8k_gen_17d0dc import gsm8k_datasets # noqa: F401, E501 from opencompass.configs.datasets.race.race_ppl import race_datasets # noqa: F401, E501 from opencompass.configs.datasets.winogrande.winogrande_5shot_ll_252f01 import \ winogrande_datasets # noqa: F401, E501 # read hf models - chat models from opencompass.configs.models.chatglm.lmdeploy_glm4_9b import models as lmdeploy_glm4_9b_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_7b_base import \ models as lmdeploy_deepseek_7b_base_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_67b_base import \ models as lmdeploy_deepseek_67b_base_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2 import lmdeploy_deepseek_v2_model # noqa: F401, E501 from opencompass.configs.models.gemma.lmdeploy_gemma_9b import models as pytorch_gemma_9b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_1_8b import \ models as lmdeploy_internlm2_1_8b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b import \ models as lmdeploy_internlm2_5_7b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_20b import \ models as lmdeploy_internlm2_20b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_base_7b import \ models as lmdeploy_internlm2_base_7b_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b import \ models as lmdeploy_llama3_1_8b_model # noqa: 
F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b import \ models as lmdeploy_llama3_8b_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_70b import \ models as lmdeploy_llama3_70b_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_1_5b import \ models as lmdeploy_qwen2_5_1_5b_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b import \ models as lmdeploy_qwen2_5_7b_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b import \ models as lmdeploy_qwen2_5_32b_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b import \ models as lmdeploy_qwen2_5_72b_model # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen2_1_5b import \ models as lmdeploy_qwen2_1_5b_model # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen2_7b import models as lmdeploy_qwen2_7b_model # noqa: F401, E501 from opencompass.configs.models.yi.lmdeploy_yi_1_5_9b import models as lmdeploy_yi_1_5_9b_model # noqa: F401, E501 from .volc import infer as volc_infer # noqa: F401, E501 race_datasets = [race_datasets[1]] datasets = sum([v for k, v in locals().items() if k.endswith('_datasets')], []) pytorch_glm4_9b_model = deepcopy(lmdeploy_glm4_9b_model) pytorch_deepseek_7b_base_model = deepcopy(lmdeploy_deepseek_7b_base_model) pytorch_deepseek_67b_base_model = deepcopy(lmdeploy_deepseek_67b_base_model) pytorch_deepseek_v2_model = deepcopy(lmdeploy_deepseek_v2_model) pytorch_internlm2_5_7b_model = deepcopy(lmdeploy_internlm2_5_7b_model) pytorch_internlm2_20b_model = deepcopy(lmdeploy_internlm2_20b_model) pytorch_internlm2_base_7b_model = deepcopy(lmdeploy_internlm2_base_7b_model) pytorch_llama3_1_8b_model = deepcopy(lmdeploy_llama3_1_8b_model) pytorch_llama3_70b_model = deepcopy(lmdeploy_llama3_70b_model) pytorch_qwen2_5_1_5b_model = deepcopy(lmdeploy_qwen2_5_1_5b_model) pytorch_qwen2_5_72b_model = 
deepcopy(lmdeploy_qwen2_5_72b_model) pytorch_qwen2_7b_model = deepcopy(lmdeploy_qwen2_7b_model) pytorch_yi_1_5_9b_model = deepcopy(lmdeploy_yi_1_5_9b_model) pytorch_deepseek_v2_model['engine_config']['cache_max_entry_count'] = 0.6 lmdeploy_glm4_9b_model_native = deepcopy(lmdeploy_glm4_9b_model) lmdeploy_deepseek_7b_base_model_native = deepcopy(lmdeploy_deepseek_7b_base_model) lmdeploy_deepseek_67b_base_model_native = deepcopy(lmdeploy_deepseek_67b_base_model) lmdeploy_deepseek_v2_model_native = deepcopy(lmdeploy_deepseek_v2_model) lmdeploy_internlm2_5_7b_model_native = deepcopy(lmdeploy_internlm2_5_7b_model) lmdeploy_internlm2_20b_model_native = deepcopy(lmdeploy_internlm2_20b_model) lmdeploy_internlm2_base_7b_model_native = deepcopy(lmdeploy_internlm2_base_7b_model) lmdeploy_llama3_1_8b_model_native = deepcopy(lmdeploy_llama3_1_8b_model) lmdeploy_llama3_70b_model_native = deepcopy(lmdeploy_llama3_70b_model) lmdeploy_qwen2_5_1_5b_model_native = deepcopy(lmdeploy_qwen2_5_1_5b_model) lmdeploy_qwen2_5_72b_model_native = deepcopy(lmdeploy_qwen2_5_72b_model) lmdeploy_qwen2_7b_model_native = deepcopy(lmdeploy_qwen2_7b_model) lmdeploy_yi_1_5_9b_model_native = deepcopy(lmdeploy_yi_1_5_9b_model) for model in [v for k, v in locals().items() if k.startswith('lmdeploy_') or k.startswith('pytorch_')]: for m in model: m['engine_config']['max_batch_size'] = 512 m['gen_config']['do_sample'] = False m['batch_size'] = 5000 for model in [v for k, v in locals().items() if k.startswith('lmdeploy_')]: for m in model: m['backend'] = 'turbomind' for model in [v for k, v in locals().items() if k.startswith('pytorch_')]: for m in model: m['abbr'] = m['abbr'].replace('turbomind', 'pytorch').replace('lmdeploy', 'pytorch') m['backend'] = 'pytorch' for model in [v for k, v in locals().items() if k.endswith('_native')]: for m in model: m['abbr'] = m['abbr'] + '_native' m['engine_config']['communicator'] = 'native' # models = sum([v for k, v in locals().items() if k.startswith('lmdeploy_') or 
k.startswith('pytorch_')], []) # models = sorted(models, key=lambda x: x['run_cfg']['num_gpus']) summarizer = dict( dataset_abbrs=[ ['gsm8k', 'accuracy'], ['GPQA_diamond', 'accuracy'], ['race-high', 'accuracy'], ['winogrande', 'accuracy'], ], summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []), ) ================================================ FILE: .github/scripts/eval_regression_chat_models.py ================================================ from copy import deepcopy from mmengine.config import read_base with read_base(): # choose a list of datasets from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets # noqa: F401, E501 from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets # noqa: F401, E501 from opencompass.configs.datasets.math.math_0shot_gen_11c4b5 import math_datasets # noqa: F401, E501 # read hf models - chat models from opencompass.configs.models.chatglm.lmdeploy_glm4_9b_chat import \ models as lmdeploy_glm4_9b_chat_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_r1_distill_qwen_32b import \ models as lmdeploy_deepseek_r1_distill_qwen_32b_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2_5_1210 import \ models as lmdeploy_deepseek_v2_5_1210_model # noqa: F401, E501 from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2_lite import \ models as lmdeploy_deepseek_v2_lite_model # noqa: F401, E501 from opencompass.configs.models.gemma.lmdeploy_gemma_9b_it import \ models as pytorch_gemma_9b_it_model # noqa: F401, E501 from opencompass.configs.models.gemma.lmdeploy_gemma_27b_it import \ models as pytorch_gemma_27b_it_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \ models as lmdeploy_internlm2_5_7b_chat_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import \ models as 
lmdeploy_internlm2_5_20b_chat_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_1_8b import \ models as lmdeploy_internlm2_chat_1_8b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_1_8b_sft import \ models as lmdeploy_internlm2_chat_1_8b_sft_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b import \ models as lmdeploy_internlm2_chat_7b_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b_sft import \ models as lmdeploy_internlm2_chat_7b_sft_model # noqa: F401, E501 from opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import \ models as lmdeploy_internlm3_8b_instruct_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama2_7b_chat import \ models as lmdeploy_llama2_7b_chat_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import \ models as lmdeploy_llama3_1_8b_instruct_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_2_3b_instruct import \ models as lmdeploy_llama3_2_3b_instruct_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_3_70b_instruct import \ models as lmdeploy_llama3_3_70b_instruct_model # noqa: F401, E501 from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import \ models as lmdeploy_llama3_8b_instruct_model # noqa: F401, E501 from opencompass.configs.models.mistral.lmdeploy_mistral_large_instruct_2411 import \ models as lmdeploy_mistral_large_instruct_2411_model # noqa: F401, E501 from opencompass.configs.models.mistral.lmdeploy_mistral_nemo_instruct_2407 import \ models as lmdeploy_mistral_nemo_instruct_2407_model # noqa: F401, E501 from opencompass.configs.models.mistral.lmdeploy_mistral_small_instruct_2409 import \ models as lmdeploy_mistral_small_instruct_2409_model # noqa: F401, E501 from 
opencompass.configs.models.nvidia.lmdeploy_nemotron_70b_instruct_hf import \ models as lmdeploy_nemotron_70b_instruct_hf_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_0_5b_instruct import \ models as lmdeploy_qwen2_5_0_5b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_3b_instruct import \ models as lmdeploy_qwen2_5_3b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import \ models as lmdeploy_qwen2_5_14b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import \ models as lmdeploy_qwen2_5_32b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b_instruct import \ models as lmdeploy_qwen2_5_72b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen2_1_5b_instruct import \ models as lmdeploy_qwen2_1_5b_instruct_model # noqa: F401, E501 from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import \ models as lmdeploy_qwen2_7b_instruct_model # noqa: F401, E501 from opencompass.configs.models.yi.lmdeploy_yi_1_5_6b_chat import \ models as lmdeploy_yi_1_5_6b_chat_model # noqa: F401, E501 from opencompass.configs.models.yi.lmdeploy_yi_1_5_9b_chat import \ models as lmdeploy_yi_1_5_9b_chat_model # noqa: F401, E501 from opencompass.configs.models.yi.lmdeploy_yi_1_5_34b_chat import \ models as lmdeploy_yi_1_5_34b_chat_model # noqa: F401, E501 from .volc import infer as volc_infer # noqa: F401, E501 datasets = sum([v for k, v in locals().items() if k.endswith('_datasets')], []) pytorch_glm4_9b_chat_model = deepcopy(lmdeploy_glm4_9b_chat_model) pytorch_deepseek_v2_lite_model = deepcopy(lmdeploy_deepseek_v2_lite_model) pytorch_deepseek_v2_5_1210_model = deepcopy(lmdeploy_deepseek_v2_5_1210_model) pytorch_internlm3_8b_instruct_model = deepcopy(lmdeploy_internlm3_8b_instruct_model) pytorch_internlm2_5_7b_chat_model 
= deepcopy(lmdeploy_internlm2_5_7b_chat_model) pytorch_internlm2_5_20b_chat_model = deepcopy(lmdeploy_internlm2_5_20b_chat_model) pytorch_llama3_2_3b_instruct_model = deepcopy(lmdeploy_llama3_2_3b_instruct_model) pytorch_llama3_3_70b_instruct_model = deepcopy(lmdeploy_llama3_3_70b_instruct_model) pytorch_mistral_nemo_instruct_2407_model = deepcopy(lmdeploy_mistral_nemo_instruct_2407_model) pytorch_mistral_small_instruct_2409_model = deepcopy(lmdeploy_mistral_small_instruct_2409_model) pytorch_qwen2_5_72b_instruct_model = deepcopy(lmdeploy_qwen2_5_72b_instruct_model) pytorch_qwen2_5_32b_instruct_model = deepcopy(lmdeploy_qwen2_5_32b_instruct_model) pytorch_qwen2_7b_instruct_model = deepcopy(lmdeploy_qwen2_7b_instruct_model) pytorch_yi_1_5_34b_chat_model = deepcopy(lmdeploy_yi_1_5_34b_chat_model) pytorch_deepseek_v2_5_1210_model['engine_config']['cache_max_entry_count'] = 0.6 lmdeploy_glm4_9b_chat_model_native = deepcopy(lmdeploy_glm4_9b_chat_model) lmdeploy_deepseek_r1_distill_qwen_32b_model_native = deepcopy(lmdeploy_deepseek_r1_distill_qwen_32b_model) lmdeploy_deepseek_v2_lite_model_native = deepcopy(lmdeploy_deepseek_v2_lite_model) lmdeploy_deepseek_v2_5_1210_model_native = deepcopy(lmdeploy_deepseek_v2_5_1210_model) lmdeploy_internlm3_8b_instruct_model_native = deepcopy(lmdeploy_internlm3_8b_instruct_model) lmdeploy_internlm2_5_7b_chat_model_native = deepcopy(lmdeploy_internlm2_5_7b_chat_model) lmdeploy_internlm2_5_20b_chat_model_native = deepcopy(lmdeploy_internlm2_5_20b_chat_model) lmdeploy_llama3_1_8b_instruct_model_native = deepcopy(lmdeploy_llama3_1_8b_instruct_model) lmdeploy_llama3_2_3b_instruct_model_native = deepcopy(lmdeploy_llama3_2_3b_instruct_model) lmdeploy_llama3_8b_instruct_model_native = deepcopy(lmdeploy_llama3_8b_instruct_model) lmdeploy_llama3_3_70b_instruct_model_native = deepcopy(lmdeploy_llama3_3_70b_instruct_model) lmdeploy_mistral_large_instruct_2411_model_native = deepcopy(lmdeploy_mistral_large_instruct_2411_model) 
lmdeploy_mistral_nemo_instruct_2407_model_native = deepcopy(lmdeploy_mistral_nemo_instruct_2407_model)
lmdeploy_mistral_small_instruct_2409_model_native = deepcopy(lmdeploy_mistral_small_instruct_2409_model)
lmdeploy_nemotron_70b_instruct_hf_model_native = deepcopy(lmdeploy_nemotron_70b_instruct_hf_model)
lmdeploy_qwen2_5_0_5b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_0_5b_instruct_model)
lmdeploy_qwen2_5_14b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_14b_instruct_model)
lmdeploy_qwen2_5_32b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_32b_instruct_model)
lmdeploy_qwen2_5_72b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_72b_instruct_model)
lmdeploy_qwen2_7b_instruct_model_native = deepcopy(lmdeploy_qwen2_7b_instruct_model)
lmdeploy_yi_1_5_6b_chat_model_native = deepcopy(lmdeploy_yi_1_5_6b_chat_model)
lmdeploy_yi_1_5_34b_chat_model_native = deepcopy(lmdeploy_yi_1_5_34b_chat_model)

# The loops below select model lists purely by module-level variable name (prefix/suffix),
# so any new global whose name matches these patterns would be swept in as well.
# Each matched value is iterated as a list of model dicts.

# Common generation settings for every lmdeploy_*/pytorch_* model list.
for model in [v for k, v in locals().items() if k.startswith('lmdeploy_') or k.startswith('pytorch_')]:
    for m in model:
        m['engine_config']['max_batch_size'] = 512
        m['gen_config']['do_sample'] = False
        m['batch_size'] = 5000

# lmdeploy_* lists (including the *_native copies) run on the turbomind backend.
for model in [v for k, v in locals().items() if k.startswith('lmdeploy_')]:
    for m in model:
        m['backend'] = 'turbomind'

# pytorch_* lists run on the pytorch backend; their abbr is renamed so results do not collide
# with the turbomind runs.
for model in [v for k, v in locals().items() if k.startswith('pytorch_')]:
    for m in model:
        m['abbr'] = m['abbr'].replace('turbomind', 'pytorch').replace('lmdeploy', 'pytorch')
        m['backend'] = 'pytorch'

# *_native lists additionally switch the engine communicator to 'native'.
for model in [v for k, v in locals().items() if k.endswith('_native')]:
    for m in model:
        m['abbr'] = m['abbr'] + '_native'
        m['engine_config']['communicator'] = 'native'

# `models` is intentionally not assigned here; presumably the caller selects model lists by
# variable name — confirm against the workflow that consumes this config.
# models = sum([v for k, v in locals().items() if k.startswith('lmdeploy_') or k.startswith('pytorch_')], [])
# models = sorted(models, key=lambda x: x['run_cfg']['num_gpus'])

summarizer = dict(
    dataset_abbrs=[
        ['GPQA_diamond', 'accuracy'],
        ['math', 'accuracy'],
        ['IFEval', 'Prompt-level-strict-accuracy'],
    ],
    # Concatenates any module-level `*_summary_groups` lists (none are imported in this config).
    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)

================================================
FILE: .github/scripts/eval_stable_object_config.py
================================================
# OpenCompass config that evaluates an already-running lmdeploy OpenAI-compatible API server
# (see `openai_api_base` below) on a broad set of objective benchmarks.
from mmengine.config import read_base
from opencompass.models import OpenAISDK

with read_base():
    # choose a list of datasets
    from opencompass.configs.datasets.ARC_c.ARC_c_cot_gen_926652 import ARC_c_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.CHARM.charm_reason_cot_only_gen_f7b7d3 import \
        charm_reason_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.cmmlu.cmmlu_0shot_cot_gen_305931 import cmmlu_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.drop.drop_openai_simple_evals_gen_3857b0 import drop_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.ds1000.ds1000_service_eval_gen_cbc84f import ds1000_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import gsm8k_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import \
        hellaswag_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.humaneval.humaneval_openai_sample_evals_gen_159614 import \
        humaneval_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.humanevalx.humanevalx_gen_620cfa import humanevalx_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.IFEval.IFEval_gen_3321a3 import ifeval_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.LCBench.lcbench_gen_5ff288 import LCBench_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.math.math_0shot_gen_393424 import math_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.MathBench.mathbench_2024_gen_50a320 import mathbench_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.mbpp.sanitized_mbpp_mdblock_gen_a447ff import \
        sanitized_mbpp_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.mmlu.mmlu_openai_simple_evals_gen_b618ea import mmlu_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import \
        mmlu_pro_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.race.race_cot_gen_d95929 import race_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.scicode.scicode_gen_085b98 import SciCode_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.SuperGLUE_BoolQ.SuperGLUE_BoolQ_cot_gen_1d56df import \
        BoolQ_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.teval.teval_en_gen_1ac254 import \
        teval_datasets as teval_en_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.teval.teval_zh_gen_1ac254 import \
        teval_datasets as teval_zh_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets  # noqa: F401, E501
    from opencompass.configs.datasets.wikibench.wikibench_gen_0978ad import wikibench_datasets  # noqa: F401, E501

# Collect every *_datasets list except SciCode and T-Eval, then append those explicitly so they
# come last in `datasets`.
datasets = sum(
    (v for k, v in locals().items() if k.endswith('_datasets') and 'scicode' not in k.lower() and 'teval' not in k),
    [])
datasets += teval_en_datasets
datasets += teval_zh_datasets
datasets += SciCode_datasets

# Maps OpenCompass roles onto the API roles sent to the server.
api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
)

# Single API-backed model; the server at port 23344 is expected to be started separately
# with the model found at `path`.
models = [
    dict(
        abbr='lmdeploy-api-test',
        type=OpenAISDK,
        key='EMPTY',
        openai_api_base='http://localhost:23344/v1',
        path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',
        tokenizer_path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',
        rpm_verbose=True,
        meta_template=api_meta_template,
        query_per_second=100,
        max_out_len=1024,
        max_seq_len=4096,
        temperature=0.01,
        batch_size=128,
        retry=3,
    )
]
================================================ FILE: .github/scripts/eval_stable_subject_config.py ================================================ from mmengine.config import read_base from opencompass.models import OpenAISDK from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner from opencompass.runners import LocalRunner from opencompass.tasks.subjective_eval import SubjectiveEvalTask with read_base(): # choose a list of datasets from opencompass.configs.datasets.subjective.alignbench.alignbench_judgeby_critiquellm import \ alignbench_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import \ alpacav2_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.arena_hard.arena_hard_compare import \ arenahard_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.compassarena.compassarena_compare import \ compassarena_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.fofo.fofo_bilingual_judge import fofo_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.multiround.mtbench101_judge import \ mtbench101_datasets # noqa: F401, E501 from opencompass.configs.datasets.subjective.wildbench.wildbench_pair_judge import \ wildbench_datasets # noqa: F401, E501 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets') and 'wildbench' not in k), []) datasets += wildbench_datasets api_meta_template = dict( round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), ], reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')], ) models = [ dict( abbr='lmdeploy-api-test', type=OpenAISDK, key='EMPTY', openai_api_base='http://localhost:23344/v1', path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat', tokenizer_path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat', rpm_verbose=True, meta_template=api_meta_template, query_per_second=100, max_out_len=1024, max_seq_len=4096, 
temperature=0.01, batch_size=128, retry=3, ) ] judge_models = models eval = dict( partitioner=dict( type=SubjectiveNaivePartitioner, models=models, judge_models=judge_models, ), runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)), ) ================================================ FILE: .github/workflows/api_eval.yml ================================================ name: api_eval on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM/lmdeploy' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' backend: required: true description: 'Set backend filter. Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" execution_mode: required: false description: 'Select execution mode: infer, eval, or both. Default is "both"' type: choice options: - both - infer - eval default: 'both' run_id: required: false description: 'Set custom run ID. 
If not provided, github.run_id will be used' type: string default: '' offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true REPORT_DIR: /nvme/qa_test_models/evaluation_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }} COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy COMPASS_DATA_CACHE: /nvme/qa_test_models/compass_data_cache HF_DATASETS_OFFLINE: 1 HF_DATASETS_CACHE: /nvme/qa_test_models/hf_datasets HF_HUB_OFFLINE: 1 HF_EVALUATE_OFFLINE: 1 RUN_ID: ${{ inputs.repo_ref }}_${{ github.run_id }} jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' 
builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, linux-a100] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | chmod -R 777 ${{env.TEST_CODE_PATH}} mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt test_evaluation: needs: download_pkgs if: ${{ !cancelled() }} runs-on: [self-hosted, linux-a100] timeout-minutes: 7200 strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8'] transformers: ["", "legacy"] env: TEST_ENV: ${{ matrix.transformers }} container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/github-actions/resources:/root/resources - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install opencompass run: | git clone https://github.com/open-compass/opencompass.git --depth 1 cd opencompass python3 -m pip install . python3 -m pip install langdetect - name: Downgrade transformers if: ${{matrix.transformers == 'legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Setup paths for evaluation if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') run: | overall_exit=0 ln -s /mnt/104/opencompass-data/data ./data ln -s /nvme/qa_test_models/resource/nltk_data /usr/share/nltk_data execution_mode="${{ github.event.inputs.execution_mode || 'both' }}" ulimit -n 65535 if [ "$execution_mode" = "both" ] || [ "$execution_mode" = "infer" ]; then pytest autotest/evaluate/test_api_evaluate.py -m "${{matrix.gpu_num}} and ${{matrix.backend}} and infer" --alluredir=${{env.REPORT_DIR}} || overall_exit=$? fi if [ "$execution_mode" = "both" ] || [ "$execution_mode" = "eval" ]; then pytest autotest/evaluate/test_api_evaluate.py -m "${{matrix.gpu_num}} and ${{matrix.backend}} and eval" -n 4 --alluredir=${{env.REPORT_DIR}} || overall_exit=$? 
fi exit $overall_exit - name: Clear workspace if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.REPORT_DIR}} export workdir=$(pwd) rm -rf $workdir/* ================================================ FILE: .github/workflows/benchmark.yml ================================================ name: benchmark_test on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' benchmark_type: required: true description: 'Set benchmark type. Default is "["longtext", "throughput", "api_server", "prefixcache"]"' type: string default: "['apiserver', 'mllm_apiserver', 'throughput', 'longtext', 'prefixcache']" backend: required: true description: 'Set backend filter. Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} REPORT_DIR: /nvme/qa_test_models/benchmark_report/${{ inputs.repo_ref }}_${{ github.run_id }} ALLURE_REPORT_DIR: /nvme/qa_test_models/benchmark_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }} TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true RUN_ID: ${{ inputs.repo_ref }}_${{ github.run_id }} jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ 
matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, linux-a100] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} 
&& chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. ${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | chmod -R 777 ${{env.TEST_CODE_PATH}} mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt benchmark: needs: download_pkgs if: ${{github.event_name == 'schedule' || !cancelled()}} runs-on: [self-hosted, linux-a100] strategy: fail-fast: false matrix: benchmark_type: ${{fromJSON(github.event.inputs.benchmark_type)}} gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8'] transformers: ["", "legacy"] include: - n: 8 gpu_num: gpu_num_1 - n: 4 gpu_num: gpu_num_2 - n: 2 gpu_num: gpu_num_4 - n: 1 gpu_num: gpu_num_8 env: TEST_ENV: ${{ matrix.transformers }} timeout-minutes: 480 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - 
name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Downgrade transformers if: ${{matrix.transformers == 'legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env - name: Run other benchmark - all if: contains(fromJson(github.event.inputs.backend), 'turbomind') && contains(fromJson(github.event.inputs.backend), 'pytorch') run: | pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function' --alluredir=${{env.ALLURE_REPORT_DIR}} - name: Run other benchmark - turbomind if: contains(fromJson(github.event.inputs.backend), 'turbomind') && !contains(fromJson(github.event.inputs.backend), 'pytorch') run: | pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function and turbomind' --alluredir=${{env.ALLURE_REPORT_DIR}} - name: Run other benchmark - pytorch if: contains(fromJson(github.event.inputs.backend), 'pytorch') && !contains(fromJson(github.event.inputs.backend), 'turbomind') run: | pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function and pytorch' --alluredir=${{env.ALLURE_REPORT_DIR}} - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/cuda12.8_whl_release.yml ================================================ name: cuda12.8-whl-release on: push: tags: - '*' workflow_dispatch: permissions: contents: write jobs: linux-build: strategy: matrix: pyver: [py310, py311, py312, py313] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 OUTPUT_FOLDER: cuda12.8_dist CUDA_VER: 12.8 steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }}/* retention-days: 1 name: linux-${{ matrix.pyver }} windows-build: strategy: matrix: pyver: ['3.10', '3.11', '3.12', '3.13'] runs-on: windows-latest steps: - name: Set git for windows run: | git config --global core.longpaths true - name: Checkout repository uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: ${{ matrix.pyver }} - name: Install python packages run: | pip install build change-wheel-version - name: Setup CUDA Toolkit id: cuda-toolkit shell: pwsh run: ./builder/windows/setup_cuda.ps1 env: INPUT_CUDA_VERSION: '12.8.1' - name: Build 
wheel run: | python -m build --wheel -o build/wheel Get-ChildItem -Path "build" -Filter "*.whl" | ForEach-Object { change_wheel_version $_.FullName --local-version cu128 --delete-old-wheel } - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: build/wheel/* retention-days: 1 name: windows-${{ matrix.pyver }} publish: runs-on: ubuntu-latest environment: 'prod' needs: - linux-build - windows-build steps: - name: Checkout repository uses: actions/checkout@v3 - name: Download artifacts uses: actions/download-artifact@v4 with: path: artifact merge-multiple: true - name: Add cuda version to package name run: | ver=$(cat lmdeploy/version.py | grep '__version__ =' | cut -d\' -f2) cuver=$ver+cu128 ls -lh cd artifact for file in *; do mv "$file" "`echo $file | sed "s/$ver/$cuver/g"`"; done - name: Display artifacts run: ls artifact/ -lh - name: Publish uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: files: artifact/* ================================================ FILE: .github/workflows/daily_ete_test.yml ================================================ name: daily_ete_test on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' backend: required: true description: 'Set backend filter. Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" model: required: true description: 'Set testcase module filter: llm, mllm. Default contains all models' type: string default: "['llm','mllm']" function: required: true description: 'Set testcase function filter: chat, restful, pipeline. 
Default contains all functions' type: string default: '["pipeline", "restful", "chat"]' offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false regression_func: required: true description: 'regression functions' type: string default: "['quant', 'tools','restful','pipeline','benchmark','evaluation']" schedule: - cron: '00 14 * * 0-4' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true ROOT_DIR: /nvme/qa_test_models REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt DEEPSEEK_VL: /nvme/qa_test_models/offline_pkg/DeepSeek-VL RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} jobs: linux-build: if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: 
${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, linux-a100] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | chmod -R 777 ${{env.TEST_CODE_PATH}} mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt test_quantization: needs: download_pkgs if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}} runs-on: [self-hosted, linux-a100] timeout-minutes: 150 strategy: matrix: transformers: ["", "legacy"] env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: ${{ matrix.transformers }} container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install auto_gptq matplotlib attrdict python3 -m pip install -r requirements/lite.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt pip install ${{env.DEEPSEEK_VL}} --no-deps rm -rf ${{env.DEEPSEEK_VL}}/build - name: Check env run: | pip install transformers==4.57.6 python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - quantization w4a16 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind') run: | pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - quantization w8a8 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch') run: | pytest autotest/tools/quantization/test_quantization_w8a8.py -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_tools: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization timeout-minutes: 300 strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} model: ${{ fromJSON(inputs.model || '["llm", "mllm"]')}} transformers: ["", "legacy"] function: ${{ fromJSON(inputs.function || '["pipeline","restful","chat"]')}} exclude: - backend: turbomind model: mllm function: chat - backend: pytorch model: mllm function: chat include: - backend: turbomind model: llm function: other env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: ${{ matrix.transformers }} container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt pip install ${{env.DEEPSEEK_VL}} --no-deps rm -rf ${{env.DEEPSEEK_VL}}/build - name: Downgrade transformers if: ${{matrix.transformers == 'legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env cp -r /nvme/qa_test_models/offline_pkg/lora . rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - chat continue-on-error: true if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat' run: | pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - pipeline 
continue-on-error: true if: matrix.function == 'pipeline' run: | pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - restful continue-on-error: true if: matrix.function == 'restful' run: | pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest 
autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - local testcase if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'other' run: | pytest autotest/toolchain --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_restful: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} model_path: ['Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-32B', 'OpenGVLab/InternVL3_5-30B-A3B', 'OpenGVLab/InternVL3-38B', 'Qwen/Qwen3-VL-8B-Instruct', 'Qwen/Qwen3-VL-30B-A3B-Instruct'] include: - tp: 2 model: Qwen3-8B-Base model_path: Qwen/Qwen3-8B-Base case_info: ['completions_v1'] generate_type: base - tp: 2 model: Qwen3-30B-A3B model_path: Qwen/Qwen3-30B-A3B case_info: ['chat_completions_v1', 'generate'] generate_type: all extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts' backend: pytorch - tp: 2 model: Qwen3-30B-A3B model_path: Qwen/Qwen3-30B-A3B case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' backend: turbomind - tp: 2 model: InternVL3_5-30B-A3B model_path: OpenGVLab/InternVL3_5-30B-A3B case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts' backend: pytorch - tp: 
2 model: InternVL3_5-30B-A3B model_path: OpenGVLab/InternVL3_5-30B-A3B case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' backend: turbomind - tp: 2 model: Qwen3-VL-30B-A3B-Instruct model_path: Qwen/Qwen3-VL-30B-A3B-Instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts' backend: pytorch - tp: 2 model: Qwen3-VL-30B-A3B-Instruct model_path: Qwen/Qwen3-VL-30B-A3B-Instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' backend: turbomind - tp: 2 model: Qwen3-32B model_path: Qwen/Qwen3-32B case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' - tp: 1 model: Qwen3-VL-8B-Instruct model_path: Qwen/Qwen3-VL-8B-Instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts' backend: pytorch - tp: 1 model: Qwen3-VL-8B-Instruct model_path: Qwen/Qwen3-VL-8B-Instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' backend: turbomind - tp: 2 model: InternVL3-38B model_path: OpenGVLab/InternVL3-38B case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' timeout-minutes: 60 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r 
${{env.TEST_CODE_PATH}}/. . mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Check env run: | python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Start restful api run: | lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} --allow-terminate-by-client > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 & echo "restful_pid=$!" for i in $(seq 1 240) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then echo "health check success" exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 1 - name: Test lmdeploy - chat_completions_v1 if: matrix.model != 'internlm2_5-20b-chat' && matrix.model != 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - chat_completions_v1 if: matrix.model == 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 
'${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - chat_completions_v1 - internlm2_5-20b-chat if: matrix.model == 'internlm2_5-20b-chat' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - internlm2_5-20b if: matrix.model == 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - other if: matrix.model != 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - base if: matrix.generate_type == 'base' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not logprob and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date 
+'%Y%m%d%H%M%S') - name: Test generate - logprob if: matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - all if: matrix.generate_type == 'all' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Kill api server if: always() run: | curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_pipeline: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'pipeline'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization timeout-minutes: 240 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt pip install ${{env.DEEPSEEK_VL}} --no-deps rm -rf ${{env.DEEPSEEK_VL}}/build - name: Check env run: | python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - interface pipeline case run: | pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_8 and not pr_test' -n 1 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} 
export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_benchmark: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'benchmark'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization timeout-minutes: 120 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt pip install ${{env.DEEPSEEK_VL}} --no-deps rm -rf ${{env.DEEPSEEK_VL}}/build - name: Check env run: | python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test benchmark script run: | pytest autotest/benchmark -n 4 -m function --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_restful_legacy: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} model_path: ['internlm/Intern-S1'] include: - tp: 8 model: Intern-S1 model_path: internlm/Intern-S1 case_info: ['chat_completions_v1', 'generate'] generate_type: base timeout-minutes: 60 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Check env run: | pip install transformers==4.57.6 python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Start restful api run: | lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} --allow-terminate-by-client > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 & echo "restful_pid=$!" for i in $(seq 1 240) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then echo "health check success" exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 1 - name: Test lmdeploy - chat_completions_v1 if: matrix.model != 'internlm2_5-20b-chat' && matrix.model != 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - chat_completions_v1 if: matrix.model == 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 
'${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - chat_completions_v1 - internlm2_5-20b-chat if: matrix.model == 'internlm2_5-20b-chat' && contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - internlm2_5-20b if: matrix.model == 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - other if: matrix.model != 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - base if: matrix.generate_type == 'base' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not logprob and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date 
+'%Y%m%d%H%M%S') - name: Test generate - logprob if: matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - all if: matrix.generate_type == 'all' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Kill api server if: always() run: | curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_pipeline_legacy: if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'pipeline'))}} runs-on: [self-hosted, linux-a100] needs: test_quantization timeout-minutes: 240 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt pip install ${{env.DEEPSEEK_VL}} --no-deps rm -rf ${{env.DEEPSEEK_VL}}/build - name: Check env run: | pip install transformers==4.57.6 python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - interface pipeline case run: | # mv stays best-effort (|| true) for every shard except the last, so a shard that wrote no .coverage cannot abort the shards after it pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_8 and not pr_test' -n 1 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >>
${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir get_coverage_report: if: ${{!cancelled()}} runs-on: [self-hosted, linux-a100] # the legacy jobs also mv .coverage.* files into REPORT_DIR; wait for them so coverage combine sees their data needs: [test_tools, test_restful, test_pipeline, test_benchmark, test_restful_legacy, test_pipeline_legacy] timeout-minutes: 5 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Get coverage report run: | pip install coverage coverage combine ${{env.REPORT_DIR}} coverage xml -o ${{env.REPORT_DIR}}/coverage.xml coverage report -m mv .coverage ${{env.REPORT_DIR}}/.coverage - name: Clear workfile if: always() run: | chmod -R 777 ${{env.ROOT_DIR}} export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/daily_ete_test_3090.yml ================================================ name: daily_ete_test_3090 on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' backend: required: true description: 'Set backend filter.
Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" model: required: true description: 'Set testcase module filter: llm, mllm. Default contains all models' type: string default: "['llm','mllm']" function: required: true description: 'Set testcase function filter: chat, restful, pipeline. Default contains all functions' type: string default: '["pipeline", "restful", "chat"]' offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false regression_func: required: true description: 'regression functions' type: string default: "['quant', 'tools', 'restful']" schedule: - cron: '00 14 * * 0-4' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda12.4_dist_${{ github.run_id }} ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy FAIL_CONFIG: ${{ github.event_name == 'schedule' && github.run_attempt != 1 && '--lf --lfnf none' || '--lf'}} TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} jobs: linux-build: if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.4 steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, 
but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, 3090-r1] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /data1:/data1 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt test_quantization: needs: download_pkgs if: ${{!cancelled() && contains(needs.download_pkgs.result, 'success') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}} runs-on: [self-hosted, 3090-r1] timeout-minutes: 150 env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: 3090_legacy container: image: openmmlab/lmdeploy:latest-cu12 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /data1:/data1 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install auto_gptq matplotlib python3 -m pip install -r requirements/lite.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Check env run: | python3 -m pip list pip install transformers==4.57.6 lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - quantization w4a16 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind') run: | pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - quantization w8a8 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch') run: | pytest autotest/tools/quantization/test_quantization_w8a8.py --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_tools: if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}} runs-on: [self-hosted, 3090-r1] needs: test_quantization timeout-minutes: 300 strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} transformers: ["3090", "3090_legacy"] model: ${{ fromJSON(inputs.model || '["llm", "mllm"]')}} function: ${{ fromJSON(inputs.function || '["pipeline","restful","chat"]')}} exclude: - backend: turbomind model: mllm function: chat - backend: pytorch model: mllm function: chat env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: ${{matrix.transformers}} container: image: openmmlab/lmdeploy:latest-cu12 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /data1:/data1 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Downgrade transformers if: ${{matrix.transformers == '3090_legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - chat continue-on-error: true if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat' run: | pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Test lmdeploy - pipeline continue-on-error: true if: matrix.function == 'pipeline' run: | pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Test lmdeploy - restful continue-on-error: true if: matrix.function == 'restful' run: | pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 
$REPORT_DIR export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_restful: if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}} runs-on: [self-hosted, 3090-r1] needs: test_quantization strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} transformers: ["3090", "3090_legacy"] model_path: ['internlm/internlm3-8b-instruct', 'Qwen/Qwen3-8B'] include: - tp: 1 model: internlm3-8b-instruct model_path: internlm/internlm3-8b-instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' - tp: 1 model: Qwen3-8B model_path: Qwen/Qwen3-8B case_info: ['completions_v1'] generate_type: base timeout-minutes: 60 container: image: openmmlab/lmdeploy:latest-cu12 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro env: TEST_ENV: ${{matrix.transformers}} steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Downgrade transformers if: ${{matrix.transformers == '3090_legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Start restful api run: | lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 & # keep the PID in this shell as well: $GITHUB_ENV only reaches later steps, but the health-check failure branch of this step kills $restful_pid restful_pid=$! echo "restful_pid=$restful_pid"
>> "$GITHUB_ENV" for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then echo "health check success" exit 0 fi done echo "health check fail" kill -15 $restful_pid 2>/dev/null || true exit 1 - name: Test lmdeploy - chat_completions_v1 if: contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - other if: contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - logprob if: matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Kill api server if: always() run: | kill -15 "$restful_pid" - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir get_coverage_report: if: ${{!cancelled()}} runs-on: [self-hosted, 3090-r1] needs: [test_tools, test_restful] timeout-minutes: 5 container: image: openmmlab/lmdeploy:latest-cu12 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Get coverage report run: | pip install coverage coverage combine ${{env.REPORT_DIR}} coverage xml -o ${{env.REPORT_DIR}}/coverage.xml coverage report -m mv .coverage ${{env.REPORT_DIR}}/.coverage - name: Clear workfile if: always() run: | chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/daily_ete_test_5080.yml ================================================ name: daily_ete_test_5080 on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' backend: required: true description: 'Set backend filter. Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" model: required: true description: 'Set testcase module filter: llm, mllm. 
Default contains all models' type: string default: "['llm','mllm']" function: required: true description: 'Set testcase function filter: chat, restful, pipeline. Default contains all functions' type: string default: '["pipeline", "restful", "chat"]' offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false regression_func: required: true description: 'regression functions' type: string default: "['quant', 'tools', 'restful']" schedule: - cron: '00 14 * * 0-4' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy FAIL_CONFIG: ${{ github.event_name == 'schedule' && github.run_attempt != 1 && '--lf --lfnf none' || '--lf'}} TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }} jobs: linux-build: if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout 
repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, 5080-r1] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/3090:/mnt/3090 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt test_quantization: needs: download_pkgs if: ${{!cancelled() && contains(needs.download_pkgs.result, 'success') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}} runs-on: [self-hosted, 5080-r1] timeout-minutes: 150 env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: 5080 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/3090:/mnt/3090 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install auto_gptq matplotlib python3 -m pip install -r requirements/lite.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Check env run: | for i in $(seq 1 10); do output=$(lmdeploy check_env 2>&1) if echo "$output" | grep -q "CUDA available: False"; then echo "CUDA not available (attempt $i/10), retrying in 5 seconds..." sleep 5 else echo "CUDA check passed" break fi done python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - quantization w4a16 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind') run: | pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - quantization w8a8 continue-on-error: true if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch') run: | pytest autotest/tools/quantization/test_quantization_w8a8.py --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Clear workfile if: always() run: | chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_tools: if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}} runs-on: [self-hosted, 5080-r1] needs: test_quantization timeout-minutes: 300 strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} model: ${{ fromJSON(inputs.model || '["llm", "mllm"]')}} transformers: ["5080", "5080_legacy"] function: ${{ fromJSON(inputs.function || '["pipeline","restful","chat"]')}} exclude: - backend: turbomind model: mllm function: chat - backend: pytorch model: mllm function: chat env: PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules TEST_ENV: ${{ matrix.transformers }} container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/3090:/mnt/3090 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Downgrade transformers if: ${{matrix.transformers == '5080_legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | for i in $(seq 1 10); do output=$(lmdeploy check_env 2>&1) if echo "$output" | grep -q "CUDA available: False"; then echo "CUDA not available (attempt $i/10), retrying in 5 seconds..." sleep 5 else echo "CUDA check passed" break fi done python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Test lmdeploy - chat continue-on-error: true if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat' run: | pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Test lmdeploy - pipeline continue-on-error: true if: matrix.function == 'pipeline' run: | pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Test lmdeploy - restful continue-on-error: true if: matrix.function == 'restful' run: | pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' 
--alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true - name: Clear workfile if: always() run: | chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir test_restful: if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}} runs-on: [self-hosted, 5080-r1] needs: test_quantization strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} model_path: ['meta-llama/Llama-3.2-3B-Instruct', 'Qwen/Qwen3-4B'] transformers: ["5080", "5080_legacy"] include: - tp: 1 model: Llama-3.2-3B-Instruct model_path: meta-llama/Llama-3.2-3B-Instruct case_info: ['chat_completions_v1', 'generate'] generate_type: logprob extra: '--logprobs-mode raw_logprobs' - tp: 1 model: Qwen3-4B model_path: Qwen/Qwen3-4B case_info: ['completions_v1'] generate_type: base timeout-minutes: 60 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/3090:/mnt/3090 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro env: TEST_ENV: ${{ matrix.transformers }} steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}} - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Downgrade transformers if: ${{matrix.transformers == '5080_legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | for i in $(seq 1 10); do output=$(lmdeploy check_env 2>&1) if echo "$output" | grep -q "CUDA available: False"; then echo "CUDA not available (attempt $i/10), retrying in 5 seconds..." sleep 5 else echo "CUDA check passed" break fi done python3 -m pip list lmdeploy check_env rm -rf allure-results # remove tmp log in testcase mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest - name: Start restful api run: | lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 & echo "restful_pid=$!" 
>> "$GITHUB_ENV" for i in $(seq 1 50) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then echo "health check success" exit 0 fi done echo "health check fail" kill -15 $restful_pid 2>/dev/null || true exit 1 - name: Test lmdeploy - chat_completions_v1 if: contains(matrix.case_info, 'chat_completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test lmdeploy - completions_v1 - other if: contains(matrix.case_info, 'completions_v1') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Test generate - logprob if: matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate') timeout-minutes: 60 run: | pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') - name: Kill api server if: always() run: | kill -15 "$restful_pid" - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir get_coverage_report: if: ${{!cancelled()}} runs-on: [self-hosted, 5080-r1] needs: [test_tools, test_restful] timeout-minutes: 5 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/3090:/mnt/3090 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: cp -r ${{env.TEST_CODE_PATH}}/. . - name: Install lmdeploy run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Get coverage report run: | pip install coverage coverage combine ${{env.REPORT_DIR}} coverage xml -o ${{env.REPORT_DIR}}/coverage.xml coverage report -m mv .coverage ${{env.REPORT_DIR}}/.coverage - name: Clear workfile if: always() run: | chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/docker.yml ================================================ name: publish-docker on: push: paths-ignore: - "!.github/workflows/docker.yml" - ".github/**" - "docs/**" - "resources/**" - "benchmark/**" - "tests/**" - "**/*.md" - "autotest/**" - "builder/**" - "k8s/**" branches: - main tags: - "v*.*.*" workflow_dispatch: inputs: repo_ref: required: false description: 'Set branch or tag or commit id. Default is ""' type: string default: 'main' image_tag: required: true description: 'Set docker image tag. 
Default is "latest"' type: string default: latest jobs: publish_docker_image: runs-on: ubuntu-latest environment: 'prod' strategy: fail-fast: false matrix: cuda_version: ['cu12.8', 'cu12'] env: CUDA_VERSION: ${{ matrix.cuda_version }} TAG_PREFIX: "openmmlab/lmdeploy" TAG: "openmmlab/lmdeploy:latest-${{matrix.cuda_version}}" steps: - name: Checkout repository uses: actions/checkout@v4 with: ref: ${{github.event.inputs.repo_ref}} - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Login to Docker Hub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Update docker TAG from workflow input if: github.event_name == 'workflow_dispatch' run: | export TAG=$TAG_PREFIX:${{github.event.inputs.image_tag}}-${CUDA_VERSION} echo $TAG echo "TAG=${TAG}" >> $GITHUB_ENV - name: Build and push Docker image run: | echo $TAG docker build . 
-f docker/Dockerfile -t ${TAG} --build-arg CUDA_VERSION=${CUDA_VERSION} docker push $TAG - name: Push Docker image as latest if: endsWith(env.TAG, 'latest-cu12') == true run: | export latest_TAG=${TAG_PREFIX}:latest echo $latest_TAG docker tag $TAG $latest_TAG docker push $latest_TAG - name: Push docker image with released tag if: startsWith(github.ref, 'refs/tags/') == true run: | export RELEASE_TAG=${TAG_PREFIX}:${{github.ref_name}}-${CUDA_VERSION} echo $RELEASE_TAG docker tag $TAG $RELEASE_TAG docker push $RELEASE_TAG publish_ascend_docker_image: runs-on: ubuntu-latest environment: 'prod' env: TAG_PREFIX: "openmmlab/lmdeploy" TAG: "openmmlab/lmdeploy:ascend" steps: - name: Checkout repository uses: actions/checkout@v4 with: ref: ${{github.event.inputs.repo_ref}} - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Login to Docker Hub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Update docker TAG from workflow input if: github.event_name == 'workflow_dispatch' run: | export TAG=$TAG_PREFIX:${{github.event.inputs.image_tag}}-ascend echo $TAG echo "TAG=${TAG}" >> $GITHUB_ENV - name: Build and push Docker image run: | echo $TAG docker build . 
-t ${TAG} -f docker/Dockerfile_ascend_a3 --platform linux/arm64 docker push $TAG ================================================ FILE: .github/workflows/docker_dev.yml ================================================ name: publish-dev-docker on: workflow_dispatch: inputs: repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' jobs: publish_dev_docker_image: runs-on: ubuntu-latest environment: 'prod' env: TAG: "openmmlab/lmdeploy:dev-cu12.8" steps: - name: Checkout repository uses: actions/checkout@v4 with: ref: ${{ github.event.inputs.repo_ref }} - name: Free disk space uses: jlumbroso/free-disk-space@v1.3.1 with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Login to Docker Hub uses: docker/login-action@v2 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push Docker image run: | echo $TAG docker build . -f docker/Dockerfile_dev -t ${TAG} docker push $TAG ================================================ FILE: .github/workflows/evaluate.yml ================================================ name: evaluate on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM/lmdeploy' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' base_models: required: true description: 'Tested TurboMind models list. eg. 
[turbomind_qwen2_5_1_5b, turbomind_qwen2_5_7b, turbomind_qwen2_5_32b, turbomind_glm_4_9b, turbomind_llama_3_1_8b, turbomind_llama_3_70b, turbomind_qwen3_0_6b_base, turbomind_qwen3_8b_base, turbomind_qwen3_30b_A3B_base, pytorch_qwen2_5_1_5b, pytorch_qwen2_5_7b, pytorch_qwen2_5_32b, pytorch_gemma_2_9b, pytorch_llama_3_70b, pytorch_llama_3_1_8b, pytorch_qwen3_0_6b_base, pytorch_qwen3_8b_base, pytorch_qwen3_30b_A3B_base]' type: string default: '[turbomind_qwen2_5_1_5b, turbomind_qwen2_5_7b, turbomind_qwen2_5_32b, turbomind_glm_4_9b, turbomind_llama_3_1_8b, turbomind_llama_3_70b, turbomind_qwen3_0_6b_base, turbomind_qwen3_8b_base, turbomind_qwen3_30b_A3B_base, pytorch_qwen2_5_1_5b, pytorch_qwen2_5_7b, pytorch_qwen2_5_32b, pytorch_gemma_2_9b, pytorch_llama_3_70b, pytorch_llama_3_1_8b, pytorch_qwen3_0_6b_base, pytorch_qwen3_8b_base, pytorch_qwen3_30b_A3B_base]' baes_datasets: required: true description: 'Tested datasets list. eg. [*mmlu_datasets, *gsm8k_datasets]' type: string default: '[*mmlu_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]' oc_repo_org: required: false description: 'Tested repository organization name. Default is open-compass/opencompass' type: string default: 'open-compass/opencompass' oc_repo_ref: required: false description: 'Set branch or tag or commit id. 
Default is "main"' type: string default: 'main' offline_mode: required: true description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself' type: boolean default: false env: ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true COMPASS_DATA_CACHE: /nvme/qa_test_models/compass_data_cache jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v6 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} evaluate: needs: linux-build if: ${{github.event_name == 'schedule' || !cancelled()}} runs-on: [self-hosted, linux-a100] timeout-minutes: 4320 # 72hours strategy: fail-fast: false matrix: evaluate_type: ['base'] container: image: openmmlab/lmdeploy:latest-cu12.8 options: 
"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/github-actions/resources:/root/resources - /nvme/qa_test_models/evaluation_report:/root/evaluation_report - /nvme/qa_test_models:/root/models - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Setup systems run: | export TIME_STAMP="$(date +'%Y%m%d-%H%M%S')" echo "TIME_STAMP=$TIME_STAMP" >> $GITHUB_ENV - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: cp -r /root/models/offline_pkg/lmdeploy/. . - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Install lmdeploy - dependency run: | python3 -m pip install -r /root/models/offline_pkg/requirements.txt - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | python3 -m pip install /root/models/offline_pkg/py310/lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install opencompass run: | git clone https://github.com/${{ github.event.inputs.oc_repo_org}}.git cd opencompass git checkout ${{ github.event.inputs.oc_repo_ref}} python3 -m pip install . 
echo "OPENCOMPASS_DIR=$(pwd)" >> $GITHUB_ENV - name: Check env run: | python3 -m pip list lmdeploy check_env - name: Setup paths for evaluation run: | ln -s /root/opencompass-data ./data python3 .github/scripts/action_tools.py create_model_links /root/models . - name: Evaluate base models if: matrix.evaluate_type == 'base' run: | echo ${{github.event.inputs.base_models}} echo ${{github.event.inputs.baes_datasets}} export LMDEPLOY_DIR=$(pwd) python3 .github/scripts/action_tools.py evaluate "${{github.event.inputs.base_models}}" "${{github.event.inputs.baes_datasets}}" /root/evaluation_report/${{ github.run_id }} base - name: Clear workspace if: always() run: | export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/lint.yml ================================================ name: lint on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python 3.10 uses: actions/setup-python@v4 with: python-version: '3.10' - name: Install pre-commit hook run: | python -m pip install pre-commit pre-commit install - name: Linting run: pre-commit run --all-files - name: Check markdown link uses: gaurav-nelson/github-action-markdown-link-check@v1 with: use-quiet-mode: 'yes' use-verbose-mode: 'yes' # check-modified-files-only: 'yes' config-file: '.github/md-link-config.json' file-path: './README.md, ./LICENSE, ./README_zh-CN.md' - name: Check module init files run: | python -m pip install fire python .github/scripts/check_lmdeploy.py check_module_init lmdeploy - name: Check doc link run: | python .github/scripts/doc_link_checker.py --target README_zh-CN.md python .github/scripts/doc_link_checker.py --target README.md - name: Check docstring coverage run: | python -m pip install interrogate interrogate -v --exclude ./lmdeploy/pytorch_poc/modeling/ --ignore-init-method --ignore-magic --ignore-module --ignore-private 
--ignore-nested-functions --ignore-nested-classes --fail-under 70 lmdeploy - name: Check pylint score run: | python -m pip install pylint pylint lmdeploy ================================================ FILE: .github/workflows/linux_x64_gpu.yml ================================================ name: linux-x64-gpu on: push: paths: - '.github/workflows/linux_x64_gpu.yml' - 'src/**' - 'CMakeLists.txt' - 'cmake/**' - 'examples/**' - '3rdparty/**' - 'tests/csrc/**' pull_request: paths: - '.github/workflows/linux_x64_gpu.yml' - 'src/**' - 'CMakeLists.txt' - 'cmake/**' - 'examples/**' - '3rdparty/**' - 'tests/csrc/**' concurrency: group: linux-x64-gpu-${{ github.ref }} cancel-in-progress: true permissions: contents: read jobs: build: strategy: fail-fast: false matrix: cudaver: [12.4, 12.8] name: cuda-${{ matrix.cudaver }} runs-on: ubuntu-latest steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 - name: Build run: | docker run --rm \ -v ${{ github.workspace }}:/work \ -w /work \ openmmlab/lmdeploy-builder:cuda${{ matrix.cudaver }} \ bash -c " source /opt/conda/bin/activate && \ conda activate py310 && \ pip install build && \ python -m build --wheel " ================================================ FILE: .github/workflows/mllm_api_eval.yml ================================================ name: mllm_api_eval on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. Default is InternLM/lmdeploy' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. 
Default is "main"' type: string default: 'main' backend: required: true description: 'Set backend filter. Default is "["turbomind", "pytorch"]"' type: string default: "['turbomind', 'pytorch']" execution_mode: required: false description: 'Select execution mode: infer, eval, or both. Default is "both"' type: choice options: - both - infer - eval default: 'both' run_id: required: false description: 'Set custom run ID. If not provided, github.run_id will be used' type: string default: '' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true REPORT_DIR: /nvme/qa_test_models/mllm_evaluation_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }} COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }} OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt DEEPSEEK_VL: /nvme/qa_test_models/offline_pkg/DeepSeek-VL LMUData: /nvme/qa_test_models/LMUData LOCAL_LLM: turbomind_Qwen2.5-32B-Instruct_nccl_tp2_0 OPENAI_API_KEY: sk-empty HF_DATASETS_OFFLINE: 1 HF_DATASETS_CACHE: /nvme/qa_test_models/hf_datasets HF_HUB_OFFLINE: 1 HF_EVALUATE_OFFLINE: 1 RUN_ID: ${{ inputs.repo_ref }}_${{ github.run_id }} jobs: linux-build: if: ${{ !cancelled() }} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.8 OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }} steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true 
swap-storage: false - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} download_pkgs: needs: linux-build if: ${{!cancelled()}} runs-on: [self-hosted, linux-a100] timeout-minutes: 50 container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}} - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Copy Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Copy Artifacts - offline if: ${{inputs.offline_mode}} run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}} - name: Mark as start run: | chmod -R 777 ${{env.TEST_CODE_PATH}} mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt test_evaluation: needs: download_pkgs if: ${{ !cancelled() }} runs-on: [self-hosted, linux-a100] timeout-minutes: 2400 strategy: fail-fast: false matrix: backend: ${{ fromJSON(inputs.backend || '["turbomind", "pytorch"]')}} gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8'] transformers: ["", "legacy"] env: TEST_ENV: ${{ matrix.transformers }} container: image: openmmlab/lmdeploy:latest-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/github-actions/resources:/root/resources - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/huggingface_hub:/nvme/huggingface_hub - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Copy repository and Artifacts run: | cp -r ${{env.TEST_CODE_PATH}}/. . 
mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Install lmdeploy - dependency run: | python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt - name: Install lmdeploy run: | python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install vlmeval run: | python3 -m pip install pandas datasets scikit-learn pylatexenc math_verify apt update && apt install -y libgl1 libglib2.0-0 cp -r /nvme/qa_test_models/offline_pkg/VLMEvalKit . cd VLMEvalKit && pip install . - name: Downgrade transformers if: ${{matrix.transformers == 'legacy'}} run: | pip install transformers==4.57.6 - name: Check env run: | python3 -m pip list lmdeploy check_env mkdir ${{env.REPORT_DIR}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Setup paths for evaluation if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') run: | unset HTTP_PROXY;unset HTTPS_PROXY;unset http_proxy;unset https_proxy; cd VLMEvalKit && cp -r ../autotest . execution_mode="${{ github.event.inputs.execution_mode || 'both' }}" ulimit -n 65535 if [ "$execution_mode" = "both" ] || [ "$execution_mode" = "infer" ]; then pytest autotest/evaluate/test_mllm_api_evaluate.py -m "${{matrix.gpu_num}} and ${{matrix.backend}} and infer" --alluredir=${{env.REPORT_DIR}} || overall_exit=$? fi if [ "$execution_mode" = "both" ] || [ "$execution_mode" = "eval" ]; then pytest autotest/evaluate/test_mllm_api_evaluate.py -m "${{matrix.gpu_num}} and ${{matrix.backend}} and eval" -n 4 --alluredir=${{env.REPORT_DIR}} || overall_exit=$? 
fi exit $overall_exit - name: Clear workspace if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt chmod -R 777 ${{env.REPORT_DIR}} export workdir=$(pwd) rm -rf $workdir/* ================================================ FILE: .github/workflows/pr_ete_test.yml ================================================ name: pr_ete_test on: pull_request: paths: - ".github/workflows/pr_ete_test.yml" - "cmake/**" - "src/**" - "autotest/**" - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA jobs: pr_functions_test: runs-on: [self-hosted, linux-a100-pr] timeout-minutes: 120 env: REPORT_DIR: /nvme/qa_test_models/test-reports/${{ github.head_ref }}_${{ github.run_id }} SERVER_LOG: /nvme/qa_test_models/server_log/${{ github.head_ref }}_${{ github.run_id }} container: image: openmmlab/lmdeploy:dev-cu12.8 options: --gpus all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never volumes: - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip - /nvme/share_data/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/121:/mnt/121 - /mnt/104:/mnt/104 - /mnt/bigdisk:/mnt/bigdisk - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v2 - name: Install lmdeploy run: | python3 -m pip install -r requirements/lite.txt python3 -m pip install -r requirements/test.txt python3 -m pip install -e . 
- name: Check env run: | python3 -m pip list lmdeploy check_env mkdir ${{env.REPORT_DIR}} -p mkdir ${{env.SERVER_LOG}} -p echo "starttime=$(date +%s)" > ${{env.REPORT_DIR}}/status.txt - name: Test lmdeploy - func run: | pytest autotest -m 'pr_test and gpu_num_2' -x --alluredir=${{env.REPORT_DIR}} --clean-alluredir pytest autotest -m 'pr_test and gpu_num_1' -n 2 -x --alluredir=${{env.REPORT_DIR}} - name: Update transformers run: | pip install transformers==4.57.3 - name: Test restful server - turbomind Qwen3-32B run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-32B --tp 2 --backend turbomind --logprobs-mode raw_logprobs --allow-terminate-by-client > ${{env.SERVER_LOG}}/turbomind_Qwen3-32B_start_restful.log 2>&1 & echo "restful_pid=$!" for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-32B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-32B and turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/turbomind_Qwen3-32B_start_restful.log exit 1 - name: Test restful server - turbomind InternVL3-38B run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/OpenGVLab/InternVL3-38B --tp 2 --backend turbomind --logprobs-mode raw_logprobs --allow-terminate-by-client > ${{env.SERVER_LOG}}/turbomind_InternVL3-38B_start_restful.log 2>&1 & echo "restful_pid=$!" 
for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'OpenGVLab/InternVL3-38B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'OpenGVLab/InternVL3-38B and turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/turbomind_InternVL3-38B_start_restful.log exit 1 - name: Test restful server - turbomind Qwen3-30B-A3B run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-30B-A3B --tp 2 --backend turbomind --logprobs-mode raw_logprobs --allow-terminate-by-client> ${{env.SERVER_LOG}}/turbomind_Qwen3-30B-A3B_start_restful.log 2>&1 & echo "restful_pid=$!" 
for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-30B-A3B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-30B-A3B and turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/turbomind_Qwen3-30B-A3B_start_restful.log exit 1 - name: Test restful server - pytorch Qwen3-30B-A3B run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-30B-A3B --tp 2 --backend pytorch --logprobs-mode raw_logprobs --enable-return-routed-experts --allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_Qwen3-30B-A3B_start_restful.log 2>&1 & echo "restful_pid=$!" 
for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-30B-A3B and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-30B-A3B and pytorch' -m 'not not_pytorch' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/pytorch_Qwen3-30B-A3B_start_restful.log exit 1 - name: Test restful server - pytorch Qwen3-VL-30B-A3B-Instruct run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-VL-30B-A3B-Instruct --tp 2 --backend pytorch --logprobs-mode raw_logprobs --allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_Qwen3-VL-30B-A3B-Instruct_start_restful.log 2>&1 & echo "restful_pid=$!" 
for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-VL-30B-A3B-Instruct and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-VL-30B-A3B-Instruct and pytorch' -m 'not not_pytorch and not experts' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/pytorch_Qwen3-VL-30B-A3B-Instruct_start_restful.log exit 1 - name: Test restful server - pytorch InternVL3_5-30B-A3B run: | CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/OpenGVLab/InternVL3_5-30B-A3B --tp 2 --backend pytorch --logprobs-mode raw_logprobs --allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_InternVL3_5-30B-A3B_start_restful.log 2>&1 & echo "restful_pid=$!" 
for i in $(seq 1 180) do sleep 5 echo "health check try $i" if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'OpenGVLab/InternVL3_5-30B-A3B and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}} pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'OpenGVLab/InternVL3_5-30B-A3B and pytorch' -m 'not not_pytorch and not experts' --alluredir=${{env.REPORT_DIR}} curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 exit 0 fi done echo "health check fail" curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1 cat ${{env.SERVER_LOG}}/pytorch_InternVL3_5-30B-A3B_start_restful.log exit 1 - name: Clear workfile if: always() run: | echo "status=done" >> ${{env.REPORT_DIR}}/status.txt export workdir=$(pwd) cd .. rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/pypi.yml ================================================ name: publish to pypi on: push: branches: - main paths: - "lmdeploy/version.py" workflow_dispatch: jobs: linux-build: strategy: matrix: pyver: [py310, py311, py312, py313] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda12.4 OUTPUT_FOLDER: cuda12_dist steps: - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Checkout repository uses: actions/checkout@v3 - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' 
builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }}/* retention-days: 1 name: linux-${{ matrix.pyver }} windows-build: strategy: matrix: pyver: ['3.10', '3.11', '3.12', '3.13'] runs-on: windows-latest steps: - name: Set git for windows run: | git config --global core.longpaths true - name: Checkout repository uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: ${{ matrix.pyver }} - name: Install python packages run: | pip install build change-wheel-version - name: Setup CUDA Toolkit id: cuda-toolkit shell: pwsh run: ./builder/windows/setup_cuda.ps1 env: INPUT_CUDA_VERSION: '12.6.2' - name: Build wheel run: | python -m build --wheel -o build/wheel Get-ChildItem -Path "build" -Filter "*.whl" | ForEach-Object { change_wheel_version $_.FullName --local-version cu121 --delete-old-wheel } - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: build/wheel/* retention-days: 1 name: windows-${{ matrix.pyver }} publish: runs-on: ubuntu-latest environment: 'prod' needs: - linux-build - windows-build steps: - name: Download artifacts uses: actions/download-artifact@v4 with: path: artifact merge-multiple: true - name: Display artifacts run: ls artifact/ -lh - name: Set up python 3.10 uses: actions/setup-python@v4 with: python-version: '3.10' - name: Upload to pypi run: | pip install twine twine upload artifact/* -u __token__ -p ${{ secrets.pypi_password }} ================================================ FILE: .github/workflows/stable.yml ================================================ name: stable_test on: workflow_dispatch: inputs: repo_org: required: false description: 'Tested repository organization name. 
Default is InternLM' type: string default: 'InternLM/lmdeploy' repo_ref: required: false description: 'Set branch or tag or commit id. Default is "main"' type: string default: 'main' offline_mode: required: true description: 'Whether to start in offline mode. If true, you should prepare the code and whl package yourself' type: boolean default: false schedule: - cron: '00 8 * * 1' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda11.8_dist_${{ github.run_id }} REPORT_DIR: /nvme/qa_test_models/stable_reports/${{ github.run_id }} ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true COMPASS_DATA_CACHE: /nvme/qa_test_models/dataset jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} PLAT_NAME: manylinux2014_x86_64 DOCKER_TAG: cuda11.8 steps: - name: Checkout repository uses: actions/checkout@v3 with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Build run: | echo ${PYTHON_VERSION} echo ${PLAT_NAME} echo ${DOCKER_TAG} echo ${OUTPUT_FOLDER} echo ${GITHUB_RUN_ID} # remove -it sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} - name: Upload Artifacts uses: actions/upload-artifact@v4 with: if-no-files-found: error path: builder/manywheel/${{ env.OUTPUT_FOLDER }} retention-days: 1 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }} benchmark: needs: linux-build if: ${{github.event_name == 'schedule' || !cancelled()}} runs-on: [self-hosted, lmdeploy-stable] timeout-minutes: 10080 strategy: fail-fast: false matrix: model: ['internlm/internlm2_5-20b-chat'] container: image: openmmlab/lmdeploy:latest-cu11 options: "--gpus=all --ipc=host --user root -e
PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 -e NO_PROXY=localhost,127.0.0.1 -e no_proxy=localhost,127.0.0.1 --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /mnt/187:/mnt/187 - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v3 if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} with: repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Copy repository - offline if: ${{inputs.offline_mode}} run: cp -r /nvme/qa_test_models/offline_pkg/lmdeploy/. . - name: Download Artifacts if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: name: my-artifact-${{ github.run_id }}-py310 - name: Install lmdeploy - dependency run: | # manually install flash attn # the install packages are from https://github.com/Dao-AILab/flash-attention/releases python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl --no-deps python3 -m pip install -r requirements/test.txt - name: Install opencompass run: | git clone --depth=1 https://github.com/open-compass/opencompass.git cd opencompass python3 -m pip install -e . cd ..
- name: Check env run: | python3 -m pip list lmdeploy check_env - name: Start restful api turbomind run: | mkdir ${{env.REPORT_DIR}} -p CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model}} --tp 2 --max-batch-size 256 --cache-max-entry-count 0.9 --server-port 23344 > ${{env.REPORT_DIR}}/restful.log 2>&1 & echo "restful_pid=$!" >> "$GITHUB_ENV" sleep 120s - name: Run OC result continue-on-error: true run: | ln -s /nvme/qa_test_models/dataset/data . opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-1 opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-1 opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-2 opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-2 opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-3 opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-3 - name: Test lmdeploy - restful api run: | python3 benchmark/profile_restful_api.py --backend lmdeploy --base-url http://localhost:23344 --dataset-path /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10000 --output-file ${{env.REPORT_DIR}}/stable.jsonl > ${{env.REPORT_DIR}}/stable.log python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-1.csv > ${{env.REPORT_DIR}}/stable-internal-1.log python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} 
/nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-2.csv > ${{env.REPORT_DIR}}/stable-internal-2.log python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-3.csv > ${{env.REPORT_DIR}}/stable-internal-3.log python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-4.csv > ${{env.REPORT_DIR}}/stable-internal-4.log python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-5.csv > ${{env.REPORT_DIR}}/stable-internal-5.log - name: Attach result if: always() run: | python3 .github/scripts/action_tools.py generate_csv_from_profile_result ${{env.REPORT_DIR}}/stable.jsonl ${{env.REPORT_DIR}}/stable.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-1.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-2.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-3.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-4.csv python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-5.csv - name: Kill api server if: always() run: | kill -15 "$restful_pid" - name: Clear workfile if: always() run: | chmod -R 777 $REPORT_DIR export workdir=$(pwd) cd ..
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/stale.yml ================================================ name: 'Close stale issues and PRs' on: schedule: # check issues and pull requests once a day at 01:30 UTC - cron: '30 1 * * *' permissions: contents: read jobs: stale: permissions: issues: write pull-requests: write runs-on: ubuntu-latest steps: - uses: actions/stale@v7 with: stale-issue-message: 'This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.' stale-pr-message: 'This PR is marked as stale because there has been no activity in the past 45 days. It will be closed in 10 days if the stale label is not removed or if there are no further updates.' close-issue-message: 'This issue is closed because it has been stale for 5 days. Please open a new issue if you have similar issues or you have any new updates now.' close-pr-message: 'This PR is closed because it has been stale for 10 days. Please reopen this PR if you have any updates and want to keep contributing the code.'
# only issues/PRS with following labels are checked any-of-labels: 'invalid, awaiting response, duplicate' days-before-issue-stale: 7 days-before-pr-stale: 45 days-before-issue-close: 5 days-before-pr-close: 10 # automatically remove the stale label when the issues or the pull requests are updated or commented remove-stale-when-updated: true operations-per-run: 50 ================================================ FILE: .github/workflows/test_docker.yml ================================================ name: test-docker on: push: paths: - 'docker/**' - '.github/workflows/*docker.yml' pull_request: paths: - 'docker/**' - '.github/workflows/*docker.yml' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: test_docker_image: permissions: pull-requests: write runs-on: ubuntu-latest strategy: matrix: cuda_version: [cu13, cu12] python_version: ['3.10', '3.11', '3.12', '3.13'] env: CUDA_VERSION: ${{ matrix.cuda_version }} PYTHON_VERSION: ${{ matrix.python_version }} steps: - name: Checkout repository uses: actions/checkout@v3 with: ref: ${{github.event.inputs.repo_ref}} - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Build Docker image run: | docker build . 
-t lmdeploy:latest -f docker/Dockerfile --build-arg CUDA_VERSION=${CUDA_VERSION} --build-arg PYTHON_VERSION=${PYTHON_VERSION} - name: Test image with lmdeploy check_env run: | docker images docker run --rm lmdeploy:latest lmdeploy check_env - name: Dive if: ${{ matrix.cuda_version == 'cu12' }} uses: MaxymVlasov/dive-action@v1.5.0 with: image: lmdeploy:latest github-token: ${{ secrets.GITHUB_TOKEN }} test_ascend_docker_image: permissions: pull-requests: write runs-on: ubuntu-22.04-arm steps: - name: Checkout repository uses: actions/checkout@v3 with: ref: ${{github.event.inputs.repo_ref}} - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Build Docker image run: | docker build . 
-t lmdeploy:ascend -f docker/Dockerfile_ascend_a3 # - name: Test image with lmdeploy check_env # run: | # docker images # docker run --rm lmdeploy:ascend lmdeploy check_env - name: Dive uses: MaxymVlasov/dive-action@v1.5.0 with: image: lmdeploy:ascend github-token: ${{ secrets.GITHUB_TOKEN }} test_jetson_docker_image: permissions: pull-requests: write runs-on: ubuntu-22.04-arm steps: - name: Checkout repository uses: actions/checkout@v3 with: ref: ${{github.event.inputs.repo_ref}} - name: Free disk space uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false docker-images: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true swap-storage: false - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Get docker info run: | docker info # remove http extraheader git config --local --unset "http.https://github.com/.extraheader" - name: Build Docker image run: | docker build . 
-t lmdeploy:jetson -f docker/Dockerfile.jetson - name: Test image with lmdeploy check_env run: | docker images docker run --rm lmdeploy:jetson lmdeploy check_env - name: Dive uses: MaxymVlasov/dive-action@v1.5.0 with: image: lmdeploy:jetson github-token: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/unit_test.yml ================================================ name: unit-test on: pull_request: paths: - ".github/workflows/unit_test.yml" - "cmake/**" - "src/**" - "tests/**" - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" push: branches: - main paths: - ".github/workflows/unit_test.yml" - "cmake/**" - "src/**" - "tests/**" - "3rdparty/**" - "lmdeploy/**" - "requirements/**" - "requirements_cuda.txt" - "CMakeLists.txt" - "setup.py" tags: - "v*.*.*" jobs: unit_test: runs-on: [self-hosted, linux-a100-s2] timeout-minutes: 4320 # 72hours container: image: openmmlab/lmdeploy:dev-cu12.8 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 -e HF_HOME=/root/.cache/huggingface --pull never" volumes: - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface - /nvme/share_data/github-actions/packages:/root/packages - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository uses: actions/checkout@v5 - name: Install lmdeploy run: | python3 -m pip install -r requirements/test.txt python3 -m pip install -e . - name: Check env run: | python3 -m pip list lmdeploy check_env - name: Test lmdeploy python UT run: | coverage run --branch --source lmdeploy -m pytest -rsE tests coverage xml coverage report -m - name: Clear workfile if: always() run: | export workdir=$(pwd) cd .. 
rm -rf $workdir mkdir $workdir chmod -R 777 $workdir ================================================ FILE: .github/workflows/windows_x64_gpu.yml ================================================ name: windows-x64-gpu on: push: paths: - '.github/workflows/windows_x64_gpu.yml' - 'src/**' - 'CMakeLists.txt' - 'cmake/**' - 'examples/**' - '3rdparty/**' - 'tests/csrc/**' pull_request: paths: - '.github/workflows/windows_x64_gpu.yml' - 'src/**' - 'CMakeLists.txt' - 'cmake/**' - 'examples/**' - '3rdparty/**' - 'tests/csrc/**' concurrency: group: windows-x64-gpu-${{ github.ref }} cancel-in-progress: true permissions: contents: read jobs: build: strategy: fail-fast: false matrix: cudaver: [12.6.2, 12.8.1] name: cuda-${{ matrix.cudaver }} runs-on: windows-latest steps: - name: Set git for windows run: | git config --global core.longpaths true - name: Checkout repository uses: actions/checkout@v3 - name: Set up python uses: actions/setup-python@v4 with: python-version: '3.10' - name: Install python packages run: | pip install build - name: Setup CUDA Toolkit id: cuda-toolkit shell: pwsh run: ./builder/windows/setup_cuda.ps1 env: INPUT_CUDA_VERSION: ${{ matrix.cudaver }} - name: Build wheel run: | python -m build --wheel ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class .vscode/ .idea/ # C extensions *.so # Distribution / packaging .Python triton-rerope/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ .venv/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST tmp/ # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache *build*/ !builder/ lmdeploy/lib/ lmdeploy/bin/ dist/ examples/cpp/llama/*.csv *.npy *.weight install/ /docs/*/_static/*.yaml # LMDeploy workspace/ work_dir*/ # Huggingface *.bin *config.json *generate_config.json !lmdeploy/turbomind/hf_repo/config.json # Pytorch *.pt *.pth *.py~ *.sh~ *.pyc **/src/pytorch-sphinx-theme/ # Outputs and logs *.txt *.log *.out *.csv !start_ids.csv *.pkl !CMakeLists.txt proxy_config.yml ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/PyCQA/flake8 rev: 5.0.4 hooks: - id: flake8 args: ['--extend-ignore=E231', "--max-line-length=120"] - repo: https://github.com/PyCQA/isort rev: 5.11.5 hooks: - id: isort args: ["--line-length=120"] - repo: https://github.com/google/yapf rev: v0.43.0 hooks: - id: yapf args: ['-i', '--style={based_on_style: pep8, column_limit: 120}'] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: - id: trailing-whitespace - id: check-yaml - id: end-of-file-fixer - id: requirements-txt-fixer - id: double-quote-string-fixer - id: check-merge-conflict - id: fix-encoding-pragma args: ["--remove"] - id: mixed-line-ending args: ["--fix=lf"] - repo: https://github.com/executablebooks/mdformat rev: 0.7.9 hooks: - id: mdformat args: ["--number"] additional_dependencies: - mdformat-openmmlab - mdformat_frontmatter - linkify-it-py - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: - id: codespell args: ["--skip=third_party/*,*.ipynb,*.proto,src/turbomind/*,docker/Dockerfile_ascend*,docs/en/get_started/ascend/get_started.md,docs/zh_cn/get_started/ascend/get_started.md"] - repo: https://github.com/myint/docformatter rev: v1.7.7 hooks: - id: docformatter language_version: python3.10 args: ["--in-place", "--wrap-descriptions", 
"120"] - repo: https://github.com/open-mmlab/pre-commit-hooks rev: v0.2.0 hooks: - id: check-copyright args: ["lmdeploy"] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v11.1.0 hooks: - id: clang-format files: ^src/ types_or: [c, c++, cuda] exclude: | (?x)( ^cmake/.*\.patch$ ) ================================================ FILE: .pylintrc ================================================ [MASTER] # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loaded into the active Python interpreter and may # run arbitrary code. extension-pkg-whitelist= # Specify a score threshold under which the program exits with an error. fail-under=8.5 # Add files or directories to the blacklist. They should be base names, not # paths. ignore=CVS,configs # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use. jobs=1 # Control the number of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or # complex, nested conditions. limit-inference-results=100 # List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. load-plugins= # Pickle collected data for later comparisons. persistent=yes # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels.
Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. confidence= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". disable=print-statement, parameter-unpacking, unpacking-in-except, old-raise-syntax, backtick, long-suffix, old-ne-operator, old-octal-literal, import-star-module-level, non-ascii-bytes-literal, raw-checker-failed, bad-inline-option, locally-disabled, file-ignored, suppressed-message, useless-suppression, deprecated-pragma, use-symbolic-message-instead, apply-builtin, basestring-builtin, buffer-builtin, cmp-builtin, coerce-builtin, execfile-builtin, file-builtin, long-builtin, raw_input-builtin, reduce-builtin, standarderror-builtin, unicode-builtin, xrange-builtin, coerce-method, delslice-method, getslice-method, setslice-method, no-absolute-import, old-division, dict-iter-method, dict-view-method, next-method-called, metaclass-assignment, indexing-exception, raising-string, reload-builtin, oct-method, hex-method, nonzero-method, cmp-method, input-builtin, round-builtin, intern-builtin, unichr-builtin, map-builtin-not-iterating, zip-builtin-not-iterating, range-builtin-not-iterating, filter-builtin-not-iterating, using-cmp-argument, eq-without-hash, div-method, idiv-method, rdiv-method, exception-message-attribute, invalid-str-codec, sys-max-int, bad-python3-import, deprecated-string-function, deprecated-str-translate-call, 
deprecated-itertools-function, deprecated-types-field, next-method-defined, dict-items-not-iterating, dict-keys-not-iterating, dict-values-not-iterating, deprecated-operator-function, deprecated-urllib-function, xreadlines-attribute, deprecated-sys-function, exception-escape, comprehension-escape, no-member, invalid-name, too-many-branches, wrong-import-order, too-many-arguments, missing-function-docstring, missing-module-docstring, too-many-locals, too-few-public-methods, abstract-method, broad-except, too-many-nested-blocks, too-many-instance-attributes, missing-class-docstring, duplicate-code, not-callable, protected-access, dangerous-default-value, no-name-in-module, logging-fstring-interpolation, super-init-not-called, redefined-builtin, attribute-defined-outside-init, arguments-differ, cyclic-import, bad-super-call, too-many-statements, unused-argument, import-outside-toplevel, import-error, super-with-arguments # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifiers separated by comma (,) or put this option # multiple times (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable=c-extension-no-member [REPORTS] # Python expression which should return a score less than or equal to 10. You # have access to the variables 'error', 'warning', 'refactor', and 'convention' # which contain the number of messages in each category, as well as 'statement' # which is the total number of statements analyzed. This score is used by the # global evaluation report (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See the documentation for all details. #msg-template= # Set the output format. Available formats are text, parseable, colorized, json # and msvs (visual studio).
You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. output-format=text # Tells whether to display a full report or only the messages. reports=yes # Activate the evaluation score. score=yes [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete names of functions that never return. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis). It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices=1 # List of decorators that change the signature of a decorated function. signature-mutators= [SPELLING] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions=4 # Spelling dictionary name. Available dictionaries: none. To make it work, # install the python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains the private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to the private dictionary (see the # --spelling-private-dict-file option) instead of raising a message. spelling-store-unknown-words=no [LOGGING] # The type of string formatting that logging methods do. `old` means using % # formatting, `new` is for `{}` formatting. logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. 
additional-builtins= # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables=yes # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_, _cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. Defaults to names # with a leading underscore. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused imports in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )?<?https?://\S+>?$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Maximum number of characters on a single line. max-line-length=100 # Maximum number of lines in a module. max-module-lines=1000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [STRING] # This flag controls whether inconsistent-quotes generates a warning when the # character used as a quote delimiter is used inconsistently within a module. check-quote-consistency=no # This flag controls whether the implicit-str-concat should generate a warning # on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no [SIMILARITIES] # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no # Minimum lines number of a similarity. min-similarity-lines=4 [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME, XXX, TODO # Regular expression of note tags to take in consideration. #notes-rgx= [BASIC] # Naming style matching correct argument names. argument-naming-style=snake_case # Regular expression matching correct argument names. Overrides argument- # naming-style. #argument-rgx= # Naming style matching correct attribute names. attr-naming-style=snake_case # Regular expression matching correct attribute names. Overrides attr-naming- # style. #attr-rgx= # Bad variable names which should always be refused, separated by a comma. bad-names=foo, bar, baz, toto, tutu, tata # Bad variable names regexes, separated by a comma. If names match any regex, # they will always be refused bad-names-rgxs= # Naming style matching correct class attribute names. class-attribute-naming-style=any # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style. #class-attribute-rgx= # Naming style matching correct class names. class-naming-style=PascalCase # Regular expression matching correct class names. Overrides class-naming- # style. #class-rgx= # Naming style matching correct constant names. const-naming-style=UPPER_CASE # Regular expression matching correct constant names. Overrides const-naming- # style. #const-rgx= # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 # Naming style matching correct function names. function-naming-style=snake_case # Regular expression matching correct function names. Overrides function- # naming-style. 
#function-rgx= # Good variable names which should always be accepted, separated by a comma. good-names=i, j, k, ex, Run, _, x, y, w, h, a, b # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted good-names-rgxs= # Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming style matching correct inline iteration names. inlinevar-naming-style=any # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style. #inlinevar-rgx= # Naming style matching correct method names. method-naming-style=snake_case # Regular expression matching correct method names. Overrides method-naming- # style. #method-rgx= # Naming style matching correct module names. module-naming-style=snake_case # Regular expression matching correct module names. Overrides module-naming- # style. #module-rgx= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Naming style matching correct variable names. variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. #variable-rgx= [DESIGN] # Maximum number of arguments for function / method. max-args=5 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. max-branches=12 # Maximum number of locals for function / method body. 
max-locals=15 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of return / yield for function / method body. max-returns=6 # Maximum number of statements in function / method body. max-statements=50 # Minimum number of public methods for a class (see R0903). min-public-methods=2 [IMPORTS] # List of modules that can be imported at any level, not just the top level # one. allow-any-import-level= # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules=optparse,tkinter.tix # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled). ext-import-graph= # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled). import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Couples of modules and preferred modules, separated by a comma. preferred-modules= [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp, __post_init__ # List of member names, which should be excluded from the protected access # warning. 
exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=cls [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". overgeneral-exceptions=BaseException, Exception ================================================ FILE: CLAUDE.md ================================================ # CLAUDE.md This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## Commands **Linting:** ```bash pre-commit run --all-files ``` Style: PEP8, max line length 120, double quotes, LF endings. C++ source under `src/` uses clang-format. **Tests:** ```bash pytest tests/test_lmdeploy # all unit tests pytest tests/test_lmdeploy/test_model.py # specific file pytest tests/test_lmdeploy/test_lite/ # quantization tests pytest tests/test_lmdeploy/test_vl/ # vision-language tests ``` **Debug logging:** ```bash LMDEPLOY_LOG_LEVEL=DEBUG python ... ``` **Build (TurboMind C++ extension):** - Controlled via `setup.py` + CMake. Relevant env vars: `LMDEPLOY_TARGET_DEVICE` (default `cuda`), `DISABLE_TURBOMIND`, `CMAKE_BUILD_TYPE`, `CUDACXX`. - Requirements split by device: `requirements/runtime_cuda.txt`, `runtime_ascend.txt`, etc. ## Architecture ### Two Backends, One Pipeline `lmdeploy/pipeline.py` is the main user-facing entry point (`pipeline()` in `api.py`). It instantiates either the **PyTorch engine** (`lmdeploy/pytorch/`) or the **TurboMind engine** (`lmdeploy/turbomind/`) based on config. ### PyTorch Backend **Model patching** is the core mechanism: HuggingFace models are loaded normally, then their layers are dynamically replaced with optimized LMDeploy implementations. - `lmdeploy/pytorch/models/module_map.py` — registry mapping HF class names → LMDeploy replacement classes. 
Device-specific overrides in `DEVICE_SPECIAL_MODULE_MAP`. - `lmdeploy/pytorch/models/patch.py` — applies the substitutions at runtime via `_get_rewrite_qualname()` / `_class_from_qualname()`. - `lmdeploy/pytorch/models/` — 40+ per-model files (e.g., `llama.py`, `qwen.py`, `deepseek_v2.py`). Each reimplements attention, MLP, and embeddings using custom kernels. - `lmdeploy/pytorch/nn/` — reusable optimized modules: `linear/` (AWQ, W8A8, blocked-FP8, LoRA variants), `attention.py`, `norm.py`, `rotary_embedding.py`, `moe/`. - `lmdeploy/pytorch/kernels/` — Triton/CUDA kernels (e.g., `w8a8_triton_kernels.py`). - `lmdeploy/pytorch/backends/` — kernel/operator dispatchers per quantization type (FP8, AWQ, CUDA). **Engine execution flow (key files):** - `engine.py` — main PyTorch engine. - `paging/scheduler.py` — sequences → batches; prefill/decode, block eviction, prefix caching (`BlockTrie`). - `engine/engine_loop.py` — async inference loop. - (See `pytorch/engine/` and `pytorch/paging/` for full execution detail.) **Configuration dataclasses** (`lmdeploy/pytorch/config.py`): `ModelConfig`, `CacheConfig`, `SchedulerConfig`, `BackendConfig`, `DistConfig`, `MiscConfig`. ### TurboMind Backend - Python wrapper: `lmdeploy/turbomind/turbomind.py` (~800 lines). Bridges into `lmdeploy/lib/_turbomind` (pybind11 extension built from `src/turbomind/`). - Tensor interop via `torch.from_dlpack()` / `_tm.from_dlpack()`. - Config and model conversion: `lmdeploy/turbomind/deploy/config.py`, `supported_models.py`. - Parallel config helpers: `update_parallel_config()`, `complete_parallel_config()` in `messages.py`. ### Lite / Quantization Entrypoints in `lmdeploy/lite/apis/`: `calibrate.py` (main), `auto_awq.py`, `gptq.py`, `smooth_quant.py`. **Flow:** load HF model → `CalibrationContext` collects activation statistics → scale computation (`lmdeploy/lite/quantization/`) → write quantized weights. 
- `lite/quantization/awq.py` — AWQ (NORM_FCS_MAP, FC_FCS_MAP define per-model layer structure). - `lite/quantization/weight/quantizer.py` — weight quantizer. - `lite/quantization/activation/observer.py` — activation statistics. - `lite/modeling/` — model-specific GPTQ implementations (e.g., `internlm2_gptq.py`). - `lite/utils/cal_qparams.py` — quantization parameter calculation utilities. Layer/norm/head mappings per model family are defined directly in `calibrate.py` and `awq.py`. ### Vision-Language Models - `lmdeploy/vl/model/` — VLM preprocessing (InternVL, Qwen-VL, LLaVA, CogVLM, etc.). - `lmdeploy/vl/media/` — image/video loaders and base classes. - `lmdeploy/pytorch/multimodal/` — multimodal input handling for the PyTorch engine. - Reference VLM implementation: `lmdeploy/vl/model/qwen3.py`. ### Other Key Files - `lmdeploy/messages.py` — core types: `GenerationConfig`, `EngineConfig`, `TurbomindEngineConfig`, `SchedulerSequence`, `MessageStatus`. - `lmdeploy/model.py` — chat templates; critical for correct conversation formatting. - `lmdeploy/archs.py` — architecture registry mapping model arch names to runtime patches. - `lmdeploy/tokenizer.py` — HuggingFace/SentencePiece tokenizer wrapper. - `lmdeploy/serve/openai/` — OpenAI-compatible API server. ## Adding a New PyTorch Model Use the `/support-new-model` skill for a complete step-by-step guide. ================================================ FILE: CMakeLists.txt ================================================ # Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. cmake_minimum_required(VERSION 3.11 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13 cmake_policy(SET CMP0074 NEW) project(TurboMind LANGUAGES CXX CUDA) if (MSVC) # use standard conformant preprocessor add_compile_options($<$:/Zc:preprocessor>) add_compile_options($<$:/Zc:__cplusplus>) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor -Xcompiler=/Zc:__cplusplus") endif () find_package(CUDAToolkit REQUIRED) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11") add_definitions("-DENABLE_BF16") endif() set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) option(BUILD_MULTI_GPU "Build multi-gpu support" ON) option(BUILD_PY_FFI "Build python ffi" ON) option(BUILD_TEST "Build tests" OFF) option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF) option(BUILD_FAST_MATH "Build in fast math mode" ON) include(FetchContent) if (BUILD_TEST) FetchContent_Declare( Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2.git GIT_TAG v3.8.0 GIT_SHALLOW ON GIT_PROGRESS TRUE USES_TERMINAL_DOWNLOAD TRUE EXCLUDE_FROM_ALL ) FetchContent_MakeAvailable(Catch2) endif() FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git GIT_TAG v3.9.2 GIT_SHALLOW ON GIT_PROGRESS TRUE USES_TERMINAL_DOWNLOAD TRUE EXCLUDE_FROM_ALL ) set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES ON CACHE BOOL "Enable extended GMMA shapes") set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") FetchContent_MakeAvailable(repo-cutlass) FetchContent_Declare( yaml-cpp GIT_REPOSITORY 
https://github.com/jbeder/yaml-cpp.git GIT_TAG 65c1c270dbe7eec37b2df2531d7497c4eea79aee GIT_PROGRESS TRUE USES_TERMINAL_DOWNLOAD TRUE ) set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp") FetchContent_MakeAvailable(yaml-cpp) FetchContent_Declare( xgrammar GIT_REPOSITORY https://github.com/mlc-ai/xgrammar.git GIT_TAG v0.1.27 GIT_SUBMODULES "3rdparty/dlpack" GIT_PROGRESS TRUE USES_TERMINAL_DOWNLOAD TRUE ) FetchContent_GetProperties(xgrammar) if(NOT xgrammar_POPULATED) # Fetch the content using previously declared details FetchContent_Populate(xgrammar) file(WRITE ${xgrammar_SOURCE_DIR}/config.cmake "set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)\n") if(NOT MSVC) file(APPEND ${xgrammar_SOURCE_DIR}/config.cmake "set(CMAKE_CXX_FLAGS \"-Wno-error\")\n") endif() # Bring the populated content into the build add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR}) if(TARGET xgrammar) target_compile_options(xgrammar PRIVATE $<$:/utf-8>) target_compile_options(xgrammar PRIVATE $<$:/utf-8>) endif() endif() # the environment variable # ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0 # LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6 # must be set at runtime # https://github.com/google/sanitizers/issues/1322 if (LMDEPLOY_ASAN_ENABLE) add_compile_options($<$:-fsanitize=address>) add_link_options(-fsanitize=address) endif () # notice that ubsan has linker issues for ubuntu < 18.04, see # https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed if (LMDEPLOY_UBSAN_ENABLE) add_compile_options($<$:-fsanitize=undefined>) add_link_options(-fsanitize=undefined) endif () if(BUILD_MULTI_GPU) execute_process( COMMAND python -c "import importlib.util; print(importlib.util.find_spec('nvidia.nccl').submodule_search_locations[0])" RESULT_VARIABLE result OUTPUT_VARIABLE nccl_path ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE ) if(result EQUAL 0 AND NOT nccl_path STREQUAL "") 
set(NCCL_ROOT ${nccl_path}) message(STATUS "Found NCCL at: ${nccl_path}") if(result EQUAL 0 AND NOT nccl_path STREQUAL "") file(GLOB nccl_lib_files "${nccl_path}/lib/libnccl.so.*") if(nccl_lib_files) list(GET nccl_lib_files -1 latest_lib) string(REGEX MATCH "\\.([0-9]+)$" version_match ${latest_lib}) if(version_match) set(NCCL_ROOT ${nccl_path}) set(ENV{NCCL_VERSION} ${CMAKE_MATCH_1}) endif() endif() endif() endif() add_definitions("-DBUILD_MULTI_GPU=1") set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) find_package(NCCL) if (NCCL_FOUND) set(USE_NCCL ON) add_definitions("-DUSE_NCCL=1") endif () endif() set(CXX_STD "17" CACHE STRING "C++ standard") # enable gold linker for binary and .so if(NOT MSVC) find_program(GOLD_PATH ld.gold REQUIRED) if(NOT GOLD_PATH) message(FATAL_ERROR "GNU gold linker is required but not found. " "Please install binutils-gold package.") endif() set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fuse-ld=gold") endif() set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) set(CUSPARSELT_PATH "" CACHE STRING "cuSPARSELt path") list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64) # profiling option(USE_NVTX "Whether or not to use nvtx" ON) if(USE_NVTX) message(STATUS "NVTX is enabled.") add_definitions("-DUSE_NVTX") endif() # setting compiler flags set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") # -Xptxas -v if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") set(ARCH "x86_64") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64") set(ARCH "x86_64") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") # cmake reports AMD64 on Windows, but we might be building for 32-bit. 
if(CMAKE_SIZEOF_VOID_P EQUAL 8) set(ARCH "x86_64") else() set(ARCH "x86") endif() elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86") set(ARCH "x86") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "i386") set(ARCH "x86") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "i686") set(ARCH "x86") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") set(ARCH "aarch64") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") set(ARCH "aarch64") # Apple A12 Bionic chipset which is added in iPhone XS/XS Max/XR uses arm64e architecture. elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64e") set(ARCH "aarch64") elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm*") set(ARCH "arm") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "mips") # Just to avoid the “unknown processor” error. set(ARCH "generic") elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le") set(ARCH "ppc64le") else() message(FATAL_ERROR "Unknown processor:" ${CMAKE_SYSTEM_PROCESSOR}) endif() if(ARCH STREQUAL "x86_64") if (NOT CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES "") if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS "13.0") list(APPEND CMAKE_CUDA_ARCHITECTURES 70-real 75-real) # V100, 2080 endif() if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11") list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real) # A100 endif () if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.1") list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real) # 3090 endif () if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.8") list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) # 4090 endif () if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.0") list(APPEND CMAKE_CUDA_ARCHITECTURES 90a-real) # H100 endif () if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8") list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) # 5090 endif () if (MSVC) list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real) endif () endif () elseif(ARCH STREQUAL "aarch64") if (NOT CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 72-real 87-real) # Jetson endif() else() 
message(FATAL_ERROR "Unsupported Architecture:" ${ARCH}) endif() message(STATUS "Building with CUDA archs: ${CMAKE_CUDA_ARCHITECTURES}") set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0") # set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall") set(CMAKE_CXX_STANDARD "${CXX_STD}") set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE}") string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3") set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -O3") if(BUILD_FAST_MATH) set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math") message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}") endif() set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDA_PATH}/include ${CUTLASS_HEADER_DIR} ) message("-- COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") set(COMMON_LIB_DIRS 
${CUDA_PATH}/lib64 ) if (SPARSITY_SUPPORT) list(APPEND COMMON_HEADER_DIRS ${CUSPARSELT_PATH}/include) list(APPEND COMMON_LIB_DIRS ${CUSPARSELT_PATH}/lib64) add_definitions(-DSPARSITY_ENABLED=1) endif() set(PYTHON_PATH "python" CACHE STRING "Python path") # turn off warnings on windows if (MSVC) foreach( flag_var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO CMAKE_CUDA_FLAGS CMAKE_CUDA_FLAGS_DEBUG CMAKE_CUDA_FLAGS_RELEASE CMAKE_CUDA_FLAGS_MINSIZEREL CMAKE_CUDA_FLAGS_RELWITHDEBINFO) string(REGEX REPLACE "-Wall" " /W0 " ${flag_var} "${${flag_var}}") endforeach() # avoid min/max macro in "windows.h" conflict with std::min/std::max add_definitions(-DNOMINMAX=1) endif() include_directories( ${COMMON_HEADER_DIRS} ) link_directories( ${COMMON_LIB_DIRS} ) add_subdirectory(src) # if(BUILD_TEST) # add_subdirectory(tests/csrc) # endif() # install python api if (BUILD_PY_FFI) if (CALL_FROM_SETUP_PY) install(TARGETS _turbomind DESTINATION ${CMAKE_INSTALL_PREFIX}) install(TARGETS _xgrammar DESTINATION ${CMAKE_INSTALL_PREFIX}) else() install(TARGETS _turbomind DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib) install(TARGETS _xgrammar DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib) endif() endif () ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2023-2024 Shanghai AI Laboratory. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ include lmdeploy/lib/*.so include lmdeploy/lib/*.so* include lmdeploy/lib/*.dll include lmdeploy/lib/*.pyd include lmdeploy/bin/* ================================================ FILE: README.md ================================================
[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy) ![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy) [![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE) [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [📘Documentation](https://lmdeploy.readthedocs.io/en/latest/) | [🛠️Quick Start](https://lmdeploy.readthedocs.io/en/latest/get_started/get_started.html) | [🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose) English | [简体中文](README_zh-CN.md) | [日本語](README_ja.md) 👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)
______________________________________________________________________ ## Latest News 🎉
2026 - \[2026/02\] Support [Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) - \[2026/02\] Support [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 4-bit symmetric/asymmetric quantization. See [here](./docs/en/quantization/llm_compressor.md) for a detailed guide.
2025 - \[2025/09\] TurboMind supports MXFP4 on NVIDIA GPUs starting from V100, achieving 1.5x the performance of vLLM on H800 for OpenAI gpt-oss models! - \[2025/06\] Comprehensive inference optimization for FP8 MoE Models - \[2025/06\] DeepSeek PD Disaggregation deployment is now supported through integration with [DLSlime](https://github.com/DeepLink-org/DLSlime) and [Mooncake](https://github.com/kvcache-ai/Mooncake). Huge thanks to both teams! - \[2025/04\] Enhance DeepSeek inference performance by integrating deepseek-ai techniques: FlashMLA, DeepGemm, DeepEP, MicroBatch and eplb - \[2025/01\] Support DeepSeek V3 and R1
2024 - \[2024/11\] Support Mono-InternVL with PyTorch engine - \[2024/10\] PyTorchEngine supports graph mode on the Ascend platform, doubling the inference speed - \[2024/09\] LMDeploy PyTorchEngine adds support for [Huawei Ascend](./docs/en/get_started/ascend/get_started.md). See supported models [here](docs/en/supported_models/supported_models.md) - \[2024/09\] LMDeploy PyTorchEngine achieves 1.3x faster Llama3-8B inference by introducing CUDA graph - \[2024/08\] LMDeploy is integrated into [modelscope/swift](https://github.com/modelscope/swift) as the default accelerator for VLM inference - \[2024/07\] Support Llama3.1 8B, 70B and their tool calling - \[2024/07\] Support [InternVL2](docs/en/multi_modal/internvl.md) full-series models, [InternLM-XComposer2.5](docs/en/multi_modal/xcomposer2d5.md) and [function call](docs/en/llm/api_server_tools.md) of InternLM2.5 - \[2024/06\] PyTorch engine supports DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next - \[2024/05\] Balance the vision model when deploying VLMs with multiple GPUs - \[2024/05\] Support 4-bit weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2 - \[2024/04\] Support Llama3 and more VLMs, such as InternVL v1.1, v1.2, MiniGemini, InternLMXComposer2. - \[2024/04\] TurboMind adds online int8/int4 KV cache quantization and inference for all supported devices. See [here](docs/en/quantization/kv_quant.md) for a detailed guide - \[2024/04\] TurboMind's latest upgrade boosts GQA, rocketing the [internlm2-20b](https://huggingface.co/internlm/internlm2-20b) model inference to 16+ RPS, about 1.8x faster than vLLM. - \[2024/04\] Support Qwen1.5-MOE and dbrx. - \[2024/03\] Support DeepSeek-VL offline inference pipeline and serving. - \[2024/03\] Support VLM offline inference pipeline and serving. - \[2024/02\] Support Qwen 1.5, Gemma, Mistral, Mixtral, Deepseek-MOE and so on. 
- \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) seamless integration with [LMDeploy Serving Service](docs/en/llm/api_server.md). - \[2024/01\] Support for multi-model, multi-machine, multi-card inference services. For usage instructions, please refer to [here](docs/en/llm/proxy_server.md) - \[2024/01\] Support [PyTorch inference engine](./docs/en/inference/pytorch.md), developed entirely in Python, helping to lower the barriers for developers and enable rapid experimentation with new features and technologies.
2023 - \[2023/12\] Turbomind supports multimodal input. - \[2023/11\] Turbomind supports loading hf model directly. Click [here](docs/en/inference/load_hf.md) for details. - \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75 - \[2023/09\] TurboMind supports Qwen-14B - \[2023/09\] TurboMind supports InternLM-20B - \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/llm/codellama.md) for deployment guide - \[2023/09\] TurboMind supports Baichuan2-7B - \[2023/08\] TurboMind supports flash-attention2. - \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling - \[2023/08\] TurboMind supports Windows (tp=1) - \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation. Check [this](docs/en/quantization/w4a16.md) guide for detailed info - \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models. - \[2023/08\] LMDeploy supports 4-bit quantization using the [AWQ](https://arxiv.org/abs/2306.00978) algorithm. - \[2023/07\] TurboMind supports Llama-2 70B with GQA. - \[2023/07\] TurboMind supports Llama-2 7B/13B. - \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
______________________________________________________________________ # Introduction LMDeploy is a toolkit for compressing, deploying, and serving LLMs, developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams. It has the following core features: - **Efficient Inference**: LMDeploy delivers up to 1.8x higher request throughput than vLLM, by introducing key features like persistent batch (a.k.a. continuous batching), blocked KV cache, dynamic split&fuse, tensor parallelism, high-performance CUDA kernels and so on. - **Effective Quantization**: LMDeploy supports weight-only and k/v quantization, and the 4-bit inference performance is 2.4x higher than FP16. The quantization quality has been confirmed via OpenCompass evaluation. - **Effortless Distribution Server**: Leveraging the request distribution service, LMDeploy facilitates an easy and efficient deployment of multi-model services across multiple machines and cards. - **Excellent Compatibility**: LMDeploy allows [KV Cache Quant](docs/en/quantization/kv_quant.md), [AWQ](docs/en/quantization/w4a16.md) and [Automatic Prefix Caching](docs/en/inference/turbomind_config.md) to be used simultaneously. # Performance ![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba) # Supported Models
LLMs VLMs
  • Llama (7B - 65B)
  • Llama2 (7B - 70B)
  • Llama3 (8B, 70B)
  • Llama3.1 (8B, 70B)
  • Llama3.2 (1B, 3B)
  • InternLM (7B - 20B)
  • InternLM2 (7B - 20B)
  • InternLM3 (8B)
  • InternLM2.5 (7B)
  • Qwen (1.8B - 72B)
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • Qwen3-Next(80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • ChatGLM2 (6B)
  • GLM-4 (9B)
  • GLM-4-0414 (9B, 32B)
  • CodeGeeX4 (9B)
  • YI (6B-34B)
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
  • DeepSeek-V3 (685B)
  • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
  • Phi-3-mini (3.8B)
  • Phi-3.5-mini (3.8B)
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • SDAR (1.7B-30B)
  • gpt-oss (20B, 120B)
  • GLM-4.7-Flash (30B)
  • GLM-5 (754B)
  • LLaVA(1.5,1.6) (7B-34B)
  • InternLM-XComposer2 (7B, 4khd-7B)
  • InternLM-XComposer2.5 (7B)
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
  • Qwen3-VL (2B - 235B)
  • Qwen3.5 (0.8B - 397B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
  • InternVL2 (1B-76B)
  • InternVL2.5(MPO) (1B-78B)
  • InternVL3 (1B-78B)
  • InternVL3.5 (1B-241BA28B)
  • Intern-S1 (241B)
  • Intern-S1-mini (8.3B)
  • Intern-S1-Pro (1TB)
  • Mono-InternVL (2B)
  • ChemVLM (8B-26B)
  • CogVLM-Chat (17B)
  • CogVLM2-Chat (19B)
  • MiniCPM-Llama3-V-2_5
  • MiniCPM-V-2_6
  • Phi-3-vision (4.2B)
  • Phi-3.5-vision (4.2B)
  • GLM-4V (9B)
  • GLM-4.1V-Thinking (9B)
  • Llama3.2-vision (11B, 90B)
  • Molmo (7B-D,72B)
  • Gemma3 (1B - 27B)
  • Llama4 (Scout, Maverick)
LMDeploy has developed two inference engines - [TurboMind](./docs/en/inference/turbomind.md) and [PyTorch](./docs/en/inference/pytorch.md), each with a different focus. The former strives for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to lower the barrier for developers. They differ in the types of supported models and the inference data type. Please refer to [this table](./docs/en/supported_models/supported_models.md) for each engine's capability and choose the one that best fits your actual needs. # Quick Start [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ) ## Installation It is recommended to install lmdeploy using pip in a conda environment (python 3.10 - 3.13): ```shell conda create -n lmdeploy python=3.10 -y conda activate lmdeploy pip install lmdeploy ``` The default prebuilt package has been compiled with **CUDA 12** since v0.3.0. For the GeForce RTX 50 series, please install the LMDeploy prebuilt package compiled with **CUDA 12.8** ```shell export LMDEPLOY_VERSION=0.12.2 export PYTHON_VERSION=310 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu128-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu128 ``` For more information on installing on a CUDA 11+ platform, or for instructions on building from source, please refer to the [installation guide](docs/en/get_started/installation.md). ## Offline Batch Inference ```python import lmdeploy with lmdeploy.pipeline("internlm/internlm3-8b-instruct") as pipe: response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` > \[!NOTE\] > By default, LMDeploy downloads models from HuggingFace. 
If you would like to use models from ModelScope, please install ModelScope by `pip install modelscope` and set the environment variable: > > `export LMDEPLOY_USE_MODELSCOPE=True` > > If you would like to use models from openMind Hub, please install openMind Hub by `pip install openmind_hub` and set the environment variable: > > `export LMDEPLOY_USE_OPENMIND_HUB=True` For more information about the inference pipeline, please see [here](docs/en/llm/pipeline.md). # Tutorials Please review the [getting_started](docs/en/get_started/get_started.md) section for the basic usage of LMDeploy. For detailed user guides and advanced guides, please refer to our [tutorials](https://lmdeploy.readthedocs.io/en/latest/): - User Guide - [LLM Inference pipeline](docs/en/llm/pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ) - [VLM Inference pipeline](docs/en/multi_modal/vl_pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing) - [LLM Serving](docs/en/llm/api_server.md) - [VLM Serving](docs/en/multi_modal/api_server_vl.md) - [Quantization](docs/en/quantization) - Advanced Guide - [Inference Engine - TurboMind](docs/en/inference/turbomind.md) - [Inference Engine - PyTorch](docs/en/inference/pytorch.md) - [Customize chat templates](docs/en/advance/chat_template.md) - [Add a new model](docs/en/advance/pytorch_new_model.md) - gemm tuning - [Long context inference](docs/en/advance/long_context.md) - [Multi-model inference service](docs/en/llm/proxy_server.md) # Third-party projects - Deploying LLMs offline on the NVIDIA Jetson platform with LMDeploy: [LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson) - Example project for deploying LLMs using LMDeploy and BentoML: [BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy) # 
Contributing We appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline. # Acknowledgement - [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) - [llm-awq](https://github.com/mit-han-lab/llm-awq) - [vLLM](https://github.com/vllm-project/vllm) - [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) # Citation ```bibtex @misc{2023lmdeploy, title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM}, author={LMDeploy Contributors}, howpublished = {\url{https://github.com/InternLM/lmdeploy}}, year={2023} } ``` ```bibtex @article{zhang2025efficient, title={Efficient Mixed-Precision Large Language Model Inference with TurboMind}, author={Zhang, Li and Jiang, Youhe and He, Guoliang and Chen, Xin and Lv, Han and Yao, Qian and Fu, Fangcheng and Chen, Kai}, journal={arXiv preprint arXiv:2508.15601}, year={2025} } ``` # License This project is released under the [Apache 2.0 license](LICENSE). ================================================ FILE: README_ja.md ================================================
[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy) ![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy) [![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE) [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [📘Documentation](https://lmdeploy.readthedocs.io/en/latest/) | [🛠️Quick Start](https://lmdeploy.readthedocs.io/en/latest/get_started/get_started.html) | [🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose) [English](README.md) | [简体中文](README_zh-CN.md) | 日本語 👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)
______________________________________________________________________ ## 最新ニュース 🎉
2024 - \[2024/08\] 🔥🔥 LMDeployは[modelscope/swift](https://github.com/modelscope/swift)に統合され、VLMs推論のデフォルトアクセラレータとなりました - \[2024/07\] 🎉🎉 Llama3.1 8B、70Bおよびそのツールコールをサポート - \[2024/07\] [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e)全シリーズモデル、[InternLM-XComposer2.5](docs/en/multi_modal/xcomposer2d5.md)およびInternLM2.5の[ファンクションコール](docs/en/llm/api_server_tools.md)をサポート - \[2024/06\] PyTorchエンジンはDeepSeek-V2およびいくつかのVLMs、例えばCogVLM2、Mini-InternVL、LlaVA-Nextをサポート - \[2024/05\] 複数のGPUでVLMsをデプロイする際にビジョンモデルをバランスさせる - \[2024/05\] InternVL v1.5、LLaVa、InternLMXComposer2などのVLMsで4ビットの重みのみの量子化と推論をサポート - \[2024/04\] Llama3およびInternVL v1.1、v1.2、MiniGemini、InternLMXComposer2などのVLMモデルをサポート - \[2024/04\] TurboMindはすべてのサポートされているデバイスでのオンラインint8/int4 KVキャッシュ量子化と推論を追加しました。詳細なガイドは[こちら](docs/en/quantization/kv_quant.md)を参照してください - \[2024/04\] TurboMindの最新アップグレードによりGQAが強化され、[internlm2-20b](https://huggingface.co/internlm/internlm2-20b)モデルの推論が16+ RPSに達し、vLLMの約1.8倍の速さになりました - \[2024/04\] Qwen1.5-MOEおよびdbrxをサポート - \[2024/03\] DeepSeek-VLのオフライン推論パイプラインとサービングをサポート - \[2024/03\] VLMのオフライン推論パイプラインとサービングをサポート - \[2024/02\] Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOEなどをサポート - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE)が[LMDeployサービングサービス](./docs/en/llm/api_server.md)とシームレスに統合されました - \[2024/01\] 複数モデル、複数マシン、複数カードの推論サービスをサポート。使用方法は[こちら](./docs/en/llm/proxy_server.md)を参照してください - \[2024/01\] [PyTorch推論エンジン](./docs/en/inference/pytorch.md)をサポートし、完全にPythonで開発されており、開発者の障壁を下げ、新機能や技術の迅速な実験を可能にします
2023 - \[2023/12\] Turbomindはマルチモーダル入力をサポート - \[2023/11\] Turbomindはhfモデルの直接読み込みをサポート。詳細は[こちら](docs/en/inference/load_hf.md)をクリックしてください - \[2023/11\] TurboMindの主要なアップグレード。Paged Attention、シーケンス長制限のない高速なアテンションカーネル、2倍速いKV8カーネル、Split-Kデコーディング(Flash Decoding)、およびsm_75向けのW4A16推論を含みます - \[2023/09\] TurboMindはQwen-14Bをサポート - \[2023/09\] TurboMindはInternLM-20Bをサポート - \[2023/09\] TurboMindはCode Llamaのすべての機能をサポート:コード補完、インフィリング、チャット/インストラクト、Pythonスペシャリスト。デプロイメントガイドは[こちら](./docs/en/llm/codellama.md)をクリックしてください - \[2023/09\] TurboMindはBaichuan2-7Bをサポート - \[2023/08\] TurboMindはflash-attention2をサポート - \[2023/08\] TurboMindはQwen-7B、動的NTK-RoPEスケーリング、動的logNスケーリングをサポート - \[2023/08\] TurboMindはWindowsをサポート(tp=1) - \[2023/08\] TurboMindは4ビット推論をサポートし、FP16の2.4倍の速さで、最速のオープンソース実装です。詳細な情報は[こちら](docs/en/quantization/w4a16.md)のガイドを確認してください - \[2023/08\] LMDeployは[HuggingFace Hub](https://huggingface.co/lmdeploy)で公開され、すぐに使用できる4ビットモデルを提供しています - \[2023/08\] LMDeployは[AWQ](https://arxiv.org/abs/2306.00978)アルゴリズムを使用した4ビット量子化をサポート - \[2023/07\] TurboMindはGQAを使用したLlama-2 70Bをサポート - \[2023/07\] TurboMindはLlama-2 7B/13Bをサポート - \[2023/07\] TurboMindはInternLMのテンソル並列推論をサポート
______________________________________________________________________ # 紹介 LMDeployは、[MMRazor](https://github.com/open-mmlab/mmrazor)および[MMDeploy](https://github.com/open-mmlab/mmdeploy)チームによって開発された、LLMの圧縮、デプロイ、およびサービングのためのツールキットです。以下の主要な機能を備えています: - **効率的な推論**:LMDeployは、persistent batch(連続バッチ)、ブロック化されたKVキャッシュ、動的分割と融合、テンソル並列、高性能なCUDAカーネルなどの主要な機能を導入し、vLLMよりも最大1.8倍のリクエストスループットを提供します。 - **効果的な量子化**:LMDeployは、重みのみおよびk/vの量子化をサポートし、4ビットの推論性能はFP16の2.4倍です。量子化の品質はOpenCompassの評価を通じて確認されています。 - **簡単な分散サーバー**:リクエスト分散サービスを活用することで、LMDeployは複数のマシンおよびカードにわたるマルチモデルサービスのデプロイを容易にします。 - **優れた互換性**:LMDeployは、[KV Cache Quant](docs/en/quantization/kv_quant.md)、[AWQ](docs/en/quantization/w4a16.md)、および[Automatic Prefix Caching](docs/en/inference/turbomind_config.md)を同時に使用することをサポートします。 # パフォーマンス LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざまな規模のモデルで、vLLMの1.36〜1.85倍のリクエストを毎秒処理します。静的推論能力の面では、TurboMind 4ビットモデルの推論速度(out token/s)はFP16/BF16推論をはるかに上回ります。小さなバッチでは、2.4倍に向上します。 ![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba) # サポートされているモデル
LLMs VLMs
  • Llama (7B - 65B)
  • Llama2 (7B - 70B)
  • Llama3 (8B, 70B)
  • Llama3.1 (8B, 70B)
  • Llama3.2 (1B, 3B)
  • InternLM (7B - 20B)
  • InternLM2 (7B - 20B)
  • InternLM3 (8B)
  • InternLM2.5 (7B)
  • Qwen (1.8B - 72B)
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • Qwen3-Next(80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • ChatGLM2 (6B)
  • GLM-4 (9B)
  • GLM-4-0414 (9B, 32B)
  • CodeGeeX4 (9B)
  • YI (6B-34B)
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
  • DeepSeek-V3 (685B)
  • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
  • Phi-3-mini (3.8B)
  • Phi-3.5-mini (3.8B)
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • SDAR (1.7B-30B)
  • gpt-oss (20B, 120B)
  • GLM-4.7-Flash (30B)
  • GLM-5 (754B)
  • LLaVA(1.5,1.6) (7B-34B)
  • InternLM-XComposer2 (7B, 4khd-7B)
  • InternLM-XComposer2.5 (7B)
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
  • Qwen3-VL (2B - 235B)
  • Qwen3.5 (0.8B - 397B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
  • InternVL2 (1B-76B)
  • InternVL2.5(MPO) (1B-78B)
  • InternVL3 (1B-78B)
  • InternVL3.5 (1B-241BA28B)
  • Intern-S1 (241B)
  • Intern-S1-mini (8.3B)
  • Mono-InternVL (2B)
  • ChemVLM (8B-26B)
  • CogVLM-Chat (17B)
  • CogVLM2-Chat (19B)
  • MiniCPM-Llama3-V-2_5
  • MiniCPM-V-2_6
  • Phi-3-vision (4.2B)
  • Phi-3.5-vision (4.2B)
  • GLM-4V (9B)
  • GLM-4.1V-Thinking (9B)
  • Llama3.2-vision (11B, 90B)
  • Molmo (7B-D,72B)
  • Gemma3 (1B - 27B)
  • Llama4 (Scout, Maverick)
LMDeployは、[TurboMind](./docs/en/inference/turbomind.md)および[PyTorch](./docs/en/inference/pytorch.md)の2つの推論エンジンを開発しました。それぞれ異なる焦点を持っています。前者は推論性能の究極の最適化を目指し、後者は完全にPythonで開発されており、開発者の障壁を下げることを目指しています。 サポートされているモデルの種類や推論データタイプに違いがあります。各エンジンの能力については[この表](./docs/en/supported_models/supported_models.md)を参照し、実際のニーズに最適なものを選択してください。 # クイックスタート [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ) ## インストール クリーンなconda環境(Python 3.10 - 3.13)でlmdeployをインストールすることをお勧めします。 ```shell conda create -n lmdeploy python=3.10 -y conda activate lmdeploy pip install lmdeploy ``` v0.3.0から、デフォルトの事前構築済みパッケージはCUDA 12でコンパイルされています。 CUDA 11+プラットフォームでのインストールに関する情報、またはソースからのビルド手順については、[インストールガイドを](docs/en/get_started/installation.md)参照してください。 ## オフラインバッチ推論 ```python import lmdeploy with lmdeploy.pipeline("internlm/internlm3-8b-instruct") as pipe: response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` > \[!NOTE\] > デフォルトでは、LMDeployはHuggingFaceからモデルをダウンロードします。ModelScopeからモデルを使用する場合は、`pip install modelscope`コマンドでModelScopeをインストールし、環境変数を設定してください: > > `export LMDEPLOY_USE_MODELSCOPE=True` > > openMind Hubからモデルを使用する場合は、`pip install openmind_hub`コマンドでopenMind Hubをインストールし、環境変数を設定してください: > > `export LMDEPLOY_USE_OPENMIND_HUB=True` 推論パイプラインに関する詳細情報は[こちら](./docs/en/llm/pipeline.md)を参照してください。 # チュートリアル LMDeployの基本的な使用方法については、[getting_started](docs/en/get_started/get_started.md)セクションを参照してください。 詳細なユーザーガイドと高度なガイドについては、[チュートリアル](https://lmdeploy.readthedocs.io/en/latest/)を参照してください: - ユーザーガイド - [LLM推論パイプライン](./docs/en/llm/pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ) - [VLM推論パイプライン](./docs/en/multi_modal/vl_pipeline.md) [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing) - [LLMサービング](docs/en/llm/api_server.md) - [VLMサービング](docs/en/multi_modal/api_server_vl.md) - [量子化](docs/en/quantization) - 高度なガイド - [推論エンジン - TurboMind](docs/en/inference/turbomind.md) - [推論エンジン - PyTorch](docs/en/inference/pytorch.md) - [カスタムチャットテンプレート](docs/en/advance/chat_template.md) - [新しいモデルの追加](docs/en/advance/pytorch_new_model.md) - gemmチューニング - [長文推論](docs/en/advance/long_context.md) - [マルチモデル推論サービス](docs/en/llm/proxy_server.md) # サードパーティプロジェクト - LMDeployを使用してNVIDIA JetsonプラットフォームでLLMをオフラインでデプロイ:[LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson) - LMDeployとBentoMLを使用してLLMをデプロイするためのサンプルプロジェクト:[BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy) # 貢献 LMDeployへのすべての貢献に感謝します。貢献ガイドラインについては、[CONTRIBUTING.md](.github/CONTRIBUTING.md)を参照してください。 # 謝辞 - [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) - [llm-awq](https://github.com/mit-han-lab/llm-awq) - [vLLM](https://github.com/vllm-project/vllm) - [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) # 引用 ```bibtex @misc{2023lmdeploy, title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM}, author={LMDeploy Contributors}, howpublished = {\url{https://github.com/InternLM/lmdeploy}}, year={2023} } ``` # ライセンス このプロジェクトは[Apache 2.0ライセンス](LICENSE)の下でリリースされています。 ================================================ FILE: README_zh-CN.md ================================================
[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy) ![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy) [![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE) [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [📘Documentation](https://lmdeploy.readthedocs.io/zh-cn/latest/) | [🛠️Quick Start](https://lmdeploy.readthedocs.io/zh-cn/latest/get_started/get_started.html) | [🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose) [English](README.md) | 简体中文 | [日本語](README_ja.md) 👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm) [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)
______________________________________________________________________ ## 最新进展 🎉
2026 - \[2026/02\] 支持 [Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) - \[2026/02\] 支持 [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 4bit 对称和非对称量化。 具体操作指南详见[此处](./docs/zh_cn/quantization/llm_compressor.md)
2025 - 【2025年9月】TurboMind 引擎支持 MXFP4,适用于 NVIDIA V100 及以上 GPU。在 H800 上推理 openai gpt-oss 模型,性能可达 vLLM 的 1.5倍! - 【2025年6月】深度优化 FP8 MoE 模型推理 - 【2025年6月】集成[DLSlime](https://github.com/DeepLink-org/DLSlime)和[Mooncake](https://github.com/kvcache-ai/Mooncake),实现DeepSeek PD分离部署,向两个团队表示诚挚的感谢! - 【2025年4月】集成deepseek-ai组件FlashMLA、DeepGemm、DeepEP、MicroBatch、eplb等,提升DeepSeek推理性能 - 【2025年1月】新增对DeepSeek V3及R1的支持
2024 - \[2024/11\] PyTorch engine 支持 Mono-InternVL 模型 - \[2024/10\] PyTorchEngine 在 ascend 平台上支持了图模式,推理性能提高了 1 倍 - \[2024/09\] LMDeploy PyTorchEngine 增加了对 [华为 Ascend](docs/zh_cn/get_started/ascend/get_started.md) 的支持。支持的模型请见[这里](docs/zh_cn/supported_models/supported_models.md) - \[2024/09\] 通过引入 CUDA Graph,LMDeploy PyTorchEngine 在 Llama3-8B 推理上实现了 1.3 倍的加速 - \[2024/08\] LMDeploy现已集成至 [modelscope/swift](https://github.com/modelscope/swift),成为 VLMs 推理的默认加速引擎 - \[2024/07\] 支持 Llama3.1 8B 和 70B 模型,以及工具调用功能 - \[2024/07\] 支持 [InternVL2](docs/zh_cn/multi_modal/internvl.md) 全系列模型,[InternLM-XComposer2.5](docs/zh_cn/multi_modal/xcomposer2d5.md) 模型和 InternLM2.5 的 [function call 功能](docs/zh_cn/llm/api_server_tools.md) - \[2024/06\] PyTorch engine 支持了 DeepSeek-V2 和若干 VLM 模型推理, 比如 CogVLM2,Mini-InternVL,LlaVA-Next - \[2024/05\] 在多 GPU 上部署 VLM 模型时,支持把视觉部分的模型均分到多卡上 - \[2024/05\] 支持InternVL v1.5, LLaVa, InternLMXComposer2 等 VLMs 模型的 4bit 权重量化和推理 - \[2024/04\] 支持 Llama3 和 InternVL v1.1, v1.2,MiniGemini,InternLM-XComposer2 等 VLM 模型 - \[2024/04\] TurboMind 支持 kv cache int4/int8 在线量化和推理,适用已支持的所有型号显卡。详情请参考[这里](docs/zh_cn/quantization/kv_quant.md) - \[2024/04\] TurboMind 引擎升级,优化 GQA 推理。[internlm2-20b](https://huggingface.co/internlm/internlm2-20b) 推理速度达 16+ RPS,约是 vLLM 的 1.8 倍 - \[2024/04\] 支持 Qwen1.5-MOE 和 dbrx. - \[2024/03\] 支持 DeepSeek-VL 的离线推理 pipeline 和推理服务 - \[2024/03\] 支持视觉-语言模型(VLM)的离线推理 pipeline 和推理服务 - \[2024/02\] 支持 Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOE 等模型 - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) 发布,支持无缝接入[LMDeploy Serving Service](docs/zh_cn/llm/api_server.md) - \[2024/01\] 支持多模型、多机、多卡推理服务。使用方法请参考[此处](docs/zh_cn/llm/proxy_server.md) - \[2024/01\] 增加 [PyTorch 推理引擎](./docs/zh_cn/inference/pytorch.md),作为 TurboMind 引擎的补充。帮助降低开发门槛,和快速实验新特性、新技术
2023 - \[2023/12\] Turbomind 支持多模态输入 - \[2023/11\] Turbomind 支持直接读取 Huggingface 模型。点击[这里](docs/zh_cn/inference/load_hf.md)查看使用方法 - \[2023/11\] TurboMind 重磅升级。包括:Paged Attention、更快的且不受序列最大长度限制的 attention kernel、2+倍快的 KV8 kernels、Split-K decoding (Flash Decoding) 和 支持 sm_75 架构的 W4A16 - \[2023/09\] TurboMind 支持 Qwen-14B - \[2023/09\] TurboMind 支持 InternLM-20B 模型 - \[2023/09\] TurboMind 支持 Code Llama 所有功能:代码续写、填空、对话、Python专项。点击[这里](./docs/zh_cn/llm/codellama.md)阅读部署方法 - \[2023/09\] TurboMind 支持 Baichuan2-7B - \[2023/08\] TurboMind 支持 flash-attention2 - \[2023/08\] TurboMind 支持 Qwen-7B,动态NTK-RoPE缩放,动态logN缩放 - \[2023/08\] TurboMind 支持 Windows (tp=1) - \[2023/08\] TurboMind 支持 4-bit 推理,速度是 FP16 的 2.4 倍,是目前最快的开源实现。部署方式请看[这里](docs/zh_cn/quantization/w4a16.md) - \[2023/08\] LMDeploy 开通了 [HuggingFace Hub](https://huggingface.co/lmdeploy) ,提供开箱即用的 4-bit 模型 - \[2023/08\] LMDeploy 支持使用 [AWQ](https://arxiv.org/abs/2306.00978) 算法进行 4-bit 量化 - \[2023/07\] TurboMind 支持使用 GQA 的 Llama-2 70B 模型 - \[2023/07\] TurboMind 支持 Llama-2 7B/13B 模型 - \[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
______________________________________________________________________ # 简介 LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](https://github.com/open-mmlab/mmrazor) 团队联合开发,是涵盖了 LLM 任务的全套轻量化、部署和服务解决方案。 这个强大的工具箱提供以下核心功能: - **高效的推理**:LMDeploy 开发了 Persistent Batch(即 Continuous Batch),Blocked K/V Cache,动态拆分和融合,张量并行,高效的计算 kernel等重要特性。推理性能是 vLLM 的 1.8 倍 - **可靠的量化**:LMDeploy 支持权重量化和 k/v 量化。4bit 模型推理效率是 FP16 下的 2.4 倍。量化模型的可靠性已通过 OpenCompass 评测得到充分验证。 - **便捷的服务**:通过请求分发服务,LMDeploy 支持多模型在多机、多卡上的推理服务。 - **卓越的兼容性**: LMDeploy 支持 [KV Cache 量化](docs/zh_cn/quantization/kv_quant.md), [AWQ](docs/zh_cn/quantization/w4a16.md) 和 [Automatic Prefix Caching](docs/zh_cn/inference/turbomind_config.md) 同时使用。 # 性能 LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型上,每秒处理的请求数是 vLLM 的 1.36 ~ 1.85 倍。在静态推理能力方面,TurboMind 4bit 模型推理速度(out token/s)远高于 FP16/BF16 推理。在小 batch 时,提高到 2.4 倍。 ![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba) # 支持的模型
LLMs VLMs
  • Llama (7B - 65B)
  • Llama2 (7B - 70B)
  • Llama3 (8B, 70B)
  • Llama3.1 (8B, 70B)
  • Llama3.2 (1B, 3B)
  • InternLM (7B - 20B)
  • InternLM2 (7B - 20B)
  • InternLM3 (8B)
  • InternLM2.5 (7B)
  • Qwen (1.8B - 72B)
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
  • Qwen2-MoE (57BA14B)
  • Qwen2.5 (0.5B - 32B)
  • Qwen3, Qwen3-MoE
  • Qwen3-Next(80B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • ChatGLM2 (6B)
  • GLM-4 (9B)
  • GLM-4-0414 (9B, 32B)
  • CodeGeeX4 (9B)
  • YI (6B-34B)
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
  • DeepSeek-V3 (685B)
  • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
  • Phi-3-mini (3.8B)
  • Phi-3.5-mini (3.8B)
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • SDAR (1.7B-30B)
  • gpt-oss (20B, 120B)
  • GLM-4.7-Flash (30B)
  • GLM-5 (754B)
  • LLaVA(1.5,1.6) (7B-34B)
  • InternLM-XComposer2 (7B, 4khd-7B)
  • InternLM-XComposer2.5 (7B)
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
  • Qwen3-VL (2B - 235B)
  • Qwen3.5 (0.8B - 397B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
  • InternVL2 (1B-76B)
  • InternVL2.5(MPO) (1B-78B)
  • InternVL3 (1B-78B)
  • InternVL3.5 (1B-241BA28B)
  • Intern-S1 (241B)
  • Intern-S1-mini (8.3B)
  • Intern-S1-Pro (1TB)
  • Mono-InternVL (2B)
  • ChemVLM (8B-26B)
  • CogVLM-Chat (17B)
  • CogVLM2-Chat (19B)
  • MiniCPM-Llama3-V-2_5
  • MiniCPM-V-2_6
  • Phi-3-vision (4.2B)
  • Phi-3.5-vision (4.2B)
  • GLM-4V (9B)
  • GLM-4.1V-Thinking (9B)
  • Llama3.2-vision (11B, 90B)
  • Molmo (7B-D,72B)
  • Gemma3 (1B - 27B)
  • Llama4 (Scout, Maverick)
LMDeploy 支持 2 种推理引擎: [TurboMind](./docs/zh_cn/inference/turbomind.md) 和 [PyTorch](./docs/zh_cn/inference/pytorch.md),它们侧重不同。前者追求推理性能的极致优化,后者纯用python开发,着重降低开发者的门槛。 它们在支持的模型类别、计算精度方面有所差别。用户可参考[这里](./docs/zh_cn/supported_models/supported_models.md), 查阅每个推理引擎的能力,并根据实际需求选择合适的。 # 快速开始 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ) ## 安装 我们推荐在一个干净的conda环境下(python 3.10 - 3.13),安装 lmdeploy: ```shell conda create -n lmdeploy python=3.10 -y conda activate lmdeploy pip install lmdeploy ``` 自 v0.3.0 版本起,默认预编译包基于 **CUDA 12** 编译。 若使用 GeForce RTX 50 系列显卡,请安装基于 **CUDA 12.8** 编译的 LMDeploy 预编译包。 ```shell export LMDEPLOY_VERSION=0.12.2 export PYTHON_VERSION=310 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu128-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu128 ``` 如果需要在 CUDA 11+ 下安装 LMDeploy,或者源码安装 LMDeploy,请参考[安装文档](docs/zh_cn/get_started/installation.md) ## 离线批处理 ```python import lmdeploy with lmdeploy.pipeline("internlm/internlm3-8b-instruct") as pipe: response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` > \[!NOTE\] > LMDeploy 默认从 HuggingFace 上面下载模型,如果要从 ModelScope 上面下载模型,请通过命令 `pip install modelscope` 安装ModelScope,并设置环境变量: > > `export LMDEPLOY_USE_MODELSCOPE=True` > > 如果要从 openMind Hub 上面下载模型,请通过命令 `pip install openmind_hub` 安装openMind Hub,并设置环境变量: > > `export LMDEPLOY_USE_OPENMIND_HUB=True` 关于 pipeline 的更多推理参数说明,请参考[这里](docs/zh_cn/llm/pipeline.md) # 用户教程 请阅读[快速上手](docs/zh_cn/get_started/get_started.md)章节,了解 LMDeploy 的基本用法。 为了帮助用户更进一步了解 LMDeploy,我们准备了用户指南和进阶指南,请阅读我们的[文档](https://lmdeploy.readthedocs.io/zh-cn/latest/): - 用户指南 - [LLM 推理 pipeline](docs/zh_cn/llm/pipeline.md) [![Open In
    Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)
  - [VLM 推理 pipeline](docs/zh_cn/multi_modal/vl_pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing)
  - [LLM 推理服务](docs/zh_cn/llm/api_server.md)
  - [VLM 推理服务](docs/zh_cn/multi_modal/api_server_vl.md)
  - [模型量化](./docs/zh_cn/quantization)
- 进阶指南
  - [推理引擎 - TurboMind](./docs/zh_cn/inference/turbomind.md)
  - [推理引擎 - PyTorch](./docs/zh_cn/inference/pytorch.md)
  - [自定义对话模板](./docs/zh_cn/advance/chat_template.md)
  - [支持新模型](./docs/zh_cn/advance/pytorch_new_model.md)
  - gemm tuning
  - [长文本推理](./docs/zh_cn/advance/long_context.md)
  - [多模型推理服务](docs/zh_cn/llm/proxy_server.md)

# 社区项目

- 使用 LMDeploy 在英伟达 Jetson 系列板卡部署大模型:[LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson)
- 使用 LMDeploy 和 BentoML 部署大模型的示例项目:[BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy)

# 贡献指南

我们感谢所有的贡献者为改进和提升 LMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。

# 致谢

- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)
- [llm-awq](https://github.com/mit-han-lab/llm-awq)
- [vLLM](https://github.com/vllm-project/vllm)
- [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII)

# 引用

```bibtex
@misc{2023lmdeploy,
    title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM},
    author={LMDeploy Contributors},
    howpublished = {\url{https://github.com/InternLM/lmdeploy}},
    year={2023}
}
```

```bibtex
@article{zhang2025efficient,
    title={Efficient Mixed-Precision Large Language Model Inference with TurboMind},
    author={Zhang, Li and Jiang, Youhe and He, Guoliang and Chen, Xin and Lv, Han and Yao, Qian and Fu, Fangcheng and Chen, Kai},
    journal={arXiv preprint arXiv:2508.15601},
    year={2025}
}
```

# 开源许可证

该项目采用 [Apache 2.0 开源许可证](LICENSE)。

================================================
FILE: 
autotest/benchmark/test_apiserver_performance.py ================================================ import pytest from utils.benchmark_utils import restful_test from utils.config_utils import get_func_config_list def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, func_type='benchmark') @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1})) def test_turbomind_apiserver_tp1(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2})) def test_turbomind_apiserver_tp2(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4})) def test_turbomind_apiserver_tp4(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8})) def test_turbomind_apiserver_tp8(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1})) def test_pytorch_apiserver_tp1(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_2 
@pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2})) def test_pytorch_apiserver_tp2(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4})) def test_pytorch_apiserver_tp4(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8})) def test_pytorch_apiserver_tp8(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16})) def test_pytorch_apiserver_tp16(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.function @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 4, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }]) def 
test_restful_func_tp2(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_smoke=True) assert result, msg ================================================ FILE: autotest/benchmark/test_longtext_performance.py ================================================ import pytest from utils.benchmark_utils import longtext_throughput_test from utils.config_utils import get_func_config_list def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, func_type='longtext_benchmark') @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1})) def test_turbomind_longtext_throughput_tp1(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2})) def test_turbomind_longtext_throughput_tp2(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4})) def test_turbomind_longtext_throughput_tp4(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8})) def test_turbomind_longtext_throughput_tp8(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_1 
@pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1})) def test_pytorch_longtext_throughput_tp1(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2})) def test_pytorch_longtext_throughput_tp2(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4})) def test_pytorch_longtext_throughput_tp4(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8})) def test_pytorch_longtext_throughput_tp8(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16})) def test_pytorch_longtext_throughput_tp16(config, run_config, worker_id): result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id) assert result, msg ================================================ FILE: autotest/benchmark/test_mllm_apiserver_performance.py ================================================ import pytest from utils.benchmark_utils import restful_test from utils.config_utils import get_func_config_list def get_models(backend, parallel_config): return 
get_func_config_list(backend, parallel_config, model_type='vl_model', func_type='mllm_evaluate') @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1})) def test_turbomind_mllm_apiserver_tp1(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2})) def test_turbomind_mllm_apiserver_tp2(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4})) def test_turbomind_mllm_apiserver_tp4(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8})) def test_turbomind_mllm_apiserver_tp8(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1})) def test_pytorch_mllm_apiserver_tp1(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2})) def 
test_pytorch_mllm_apiserver_tp2(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4})) def test_pytorch_mllm_apiserver_tp4(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8})) def test_pytorch_mllm_apiserver_tp8(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16})) def test_pytorch_mllm_apiserver_tp16(config, run_config, worker_id): result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True) assert result, msg ================================================ FILE: autotest/benchmark/test_prefixcache_performance.py ================================================ import pytest from utils.benchmark_utils import prefixcache_throughput_test from utils.config_utils import get_func_config_list def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, func_type='benchmark') @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1})) def test_turbomind_prefix_tp1(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) 
@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2})) def test_turbomind_prefix_tp2(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4})) def test_turbomind_prefix_tp4(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8})) def test_turbomind_prefix_tp8(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1})) def test_pytorch_prefix_tp1(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2})) def test_pytorch_prefix_tp2(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4})) def test_pytorch_prefix_tp4(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_8 
@pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8})) def test_pytorch_prefix_tp8(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16})) def test_pytorch_prefix_tp16(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.function @pytest.mark.parametrize('run_config', [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }]) def test_pytorch_prefix_pr_test_tp1(config, run_config, worker_id): result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id, is_smoke=True) assert result, msg ================================================ FILE: autotest/benchmark/test_throughput_performance.py ================================================ import pytest from utils.benchmark_utils import throughput_test from utils.config_utils import get_func_config_list, get_workerid def get_models(backend, parallel_config): run_configs = get_func_config_list(backend, parallel_config, func_type='benchmark') return [item for item in run_configs if 'gpt' not in item['model']] # gpt models are excluded because of openai_harmony is not supported yet @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1})) def 
test_turbomind_throughput_tp1(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2})) def test_turbomind_throughput_tp2(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4})) def test_turbomind_throughput_tp4(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8})) def test_turbomind_throughput_tp8(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1})) def test_pytorch_throughput_tp1(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2})) def test_pytorch_throughput_tp2(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4})) def test_pytorch_throughput_tp4(config, 
run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8})) def test_pytorch_throughput_tp8(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16})) def test_pytorch_throughput_tp16(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id) assert result, msg @pytest.mark.function @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }]) def test_throughput_func_tp2(config, run_config, worker_id): result, msg = throughput_test(config, run_config, worker_id=worker_id, is_smoke=True) assert result, msg @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', [{ 'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-VL-8B-Instruct', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }]) def test_throughput_prtest_tp1(config, run_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) result, msg = throughput_test(config, run_config, worker_id=worker_id, is_smoke=True) assert 
result, msg ================================================ FILE: autotest/chat_prompt_case.yml ================================================ base_testcase: - 乌鲁木齐的景点A brief introduction to Urumqi’s attractions: - contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - introduce - 水磨沟 - 天池 - len_g: 10 - end: - 介绍它的相应美食#please introduce some delicious foods: - not contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - introduce - 羊肉 - len_g: 10 chat_testcase: - 你好,你叫什么名字#hi, what's your name: - end: - 简要介绍乌鲁木齐的景点#A brief introduction to Urumqi’s attractions: - contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - 介绍它的相应美食#please introduce some delicious foods: - contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - 羊肉 - end: - 介绍相应美食#please introduce some delicious foods: - not contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 code_testcase: - 使用python编写一个int数组的冒泡排序代码: - contain: - def - bubble - 冒泡 - 快速排序呢: - contain: - def - quick ================================================ FILE: autotest/config.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json env_tag: a100 device: cuda config: tp: meta-llama/Llama-4-Scout-17B-16E-Instruct: 4 meta-llama/Meta-Llama-3-1-70B-Instruct: 4 OpenGVLab/InternVL3-38B: 2 Qwen/Qwen3-235B-A22B: 8 Qwen/Qwen3-30B-A3B: 2 Qwen/Qwen3-32B: 2 Qwen/Qwen3-VL-30B-A3B-Instruct: 2 Qwen/Qwen3-30B-A3B-Base: 2 Qwen/Qwen2.5-VL-32B-Instruct: 2 mistralai/Mixtral-8x7B-Instruct-v0.1: 2 OpenGVLab/InternVL3_5-30B-A3B: 2 zai-org/GLM-4.7-Flash: 2 turbomind_chat_model: tp: - 
meta-llama/Llama-3.2-1B-Instruct - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct-AWQ - meta-llama/Meta-Llama-3-1-70B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct - internlm/internlm3-8b-instruct - internlm/internlm3-8b-instruct-awq - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-38B - OpenGVLab/InternVL3_5-30B-A3B - Qwen/Qwen3-0.6B - Qwen/Qwen3-4B - Qwen/Qwen3-8B - Qwen/Qwen3-32B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-GPTQ-Int4 - Qwen/Qwen3-235B-A22B - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2.5-VL-32B-Instruct - Qwen/Qwen1.5-MoE-A2.7B-Chat - mistralai/Mixtral-8x7B-Instruct-v0.1 - THUDM/glm-4-9b-chat - zai-org/GLM-4.7-Flash pytorch_chat_model: tp: - meta-llama/Llama-4-Scout-17B-16E-Instruct - meta-llama/Llama-3.2-1B-Instruct - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-1-70B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-38B - OpenGVLab/InternVL3_5-30B-A3B - Qwen/Qwen3-0.6B - Qwen/Qwen3-4B - Qwen/Qwen3-8B - Qwen/Qwen3-32B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-30B-A3B-Instruct - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b - THUDM/glm-4-9b-chat - google/gemma-2-9b-it - google/gemma-2-27b-it - zai-org/GLM-4.7-Flash - microsoft/Phi-3.5-vision-instruct - microsoft/Phi-3-vision-128k-instruct turbomind_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-38B - OpenGVLab/InternVL3_5-30B-A3B - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2.5-VL-32B-Instruct pytorch_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3_5-30B-A3B - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-30B-A3B-Instruct - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B - THUDM/glm-4v-9b - microsoft/Phi-3-vision-128k-instruct - microsoft/Phi-3.5-vision-instruct turbomind_base_model: tp: - Qwen/Qwen3-8B-Base - 
Qwen/Qwen3-30B-A3B-Base pytorch_base_model: tp: - Qwen/Qwen3-8B-Base - Qwen/Qwen3-30B-A3B-Base turbomind_quantization: no_awq: - meta-llama/Meta-Llama-3-1-70B-Instruct - internlm/internlm3-8b-instruct # ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py) - OpenGVLab/InternVL3-8B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-30B-A3B-Base - Qwen/Qwen1.5-MoE-A2.7B-Chat - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2.5-VL-32B-Instruct - OpenGVLab/InternVL3_5-30B-A3B - zai-org/GLM-4.7-Flash gptq: - empty no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B - OpenGVLab/InternVL3-8B - Qwen/Qwen3-0.6B - Qwen/Qwen3-4B - Qwen/Qwen3-8B - Qwen/Qwen3-32B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-GPTQ-Int4 - Qwen/Qwen3-235B-A22B - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2.5-VL-32B-Instruct - Qwen/Qwen1.5-MoE-A2.7B-Chat - Qwen/Qwen3-8B-Base - Qwen/Qwen3-30B-A3B-Base - zai-org/GLM-4.7-Flash no_kvint8: - deepseek-ai/DeepSeek-V2-Chat - zai-org/GLM-4.7-Flash pytorch_quantization: awq: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py) - Qwen/Qwen3-0.6B - Qwen/Qwen3-4B - Qwen/Qwen3-8B - microsoft/Phi-3-mini-4k-instruct - THUDM/glm-4v-9b w8a8: - meta-llama/Llama-3.2-1B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Meta-Llama-3-1-8B-Instruct # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py) - microsoft/Phi-3-mini-4k-instruct no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B - OpenGVLab/InternVL3-8B - Qwen/Qwen3-8B-Base - Qwen/Qwen3-30B-A3B-Base - Qwen/Qwen3-0.6B - Qwen/Qwen3-4B - 
Qwen/Qwen3-8B - Qwen/Qwen3-32B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-30B-A3B-Instruct - microsoft/Phi-3-vision-128k-instruct - microsoft/Phi-3.5-vision-instruct - zai-org/GLM-4.7-Flash no_kvint8: - zai-org/GLM-4.7-Flash longtext_benchmark_model: - Qwen/Qwen3-8B - Qwen/Qwen3-30B-A3B evaluate_model: - google/gemma-2-9b-it - google/gemma-2-27b-it - meta-llama/Meta-Llama-3-1-8B-Instruct - Qwen/Qwen1.5-MoE-A2.7B-Chat - Qwen/Qwen3-30B-A3B benchmark_model: - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-1-70B-Instruct - internlm/internlm3-8b-instruct - THUDM/glm-4-9b-chat - Qwen/Qwen3-30B-A3B mllm_evaluate_model: - OpenGVLab/InternVL3-8B - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-30B-A3B-Instruct - OpenGVLab/InternVL3_5-30B-A3B ================================================ FILE: autotest/config_3090.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json env_tag: 3090 device: cuda turbomind_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct pytorch_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B 
- Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct turbomind_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B pytorch_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B turbomind_base_model: tp: - internlm/internlm3-8b-instruct - Qwen/Qwen3-8B pytorch_base_model: tp: - internlm/internlm3-8b-instruct - Qwen/Qwen3-8B turbomind_quantization: no_awq: - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct gptq: - empty no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-3B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Chat pytorch_quantization: awq: - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct w8a8: - meta-llama/Llama-3.2-3B-Instruct no_kvint4: - OpenGVLab/InternVL3-8B - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat ================================================ FILE: autotest/config_3090_legacy.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: 
/nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json env_tag: 3090 device: cuda turbomind_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct pytorch_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-3B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct turbomind_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B pytorch_vl_model: tp: - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen2.5-VL-3B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct turbomind_base_model: tp: - internlm/internlm3-8b-instruct - Qwen/Qwen3-8B pytorch_base_model: tp: - internlm/internlm3-8b-instruct - Qwen/Qwen3-8B turbomind_quantization: no_awq: - internlm/internlm3-8b-instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct gptq: - empty no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-3B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Chat pytorch_quantization: awq: - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct w8a8: - 
meta-llama/Llama-3.2-3B-Instruct no_kvint4: - OpenGVLab/InternVL3-8B - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-8B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-3B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat ================================================ FILE: autotest/config_5080.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json env_tag: 5080 device: cuda turbomind_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B pytorch_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B turbomind_vl_model: tp: - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B pytorch_vl_model: tp: - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B turbomind_base_model: tp: - Qwen/Qwen3-4B pytorch_base_model: tp: - Qwen/Qwen3-4B turbomind_quantization: no_awq: - OpenGVLab/InternVL3-2B-Instruct gptq: - empty no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - 
OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-VL-3B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Chat pytorch_quantization: awq: - meta-llama/Llama-3.2-3B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B w8a8: - meta-llama/Llama-3.2-3B-Instruct no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat ================================================ FILE: autotest/config_5080_legacy.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json env_tag: 5080 device: cuda turbomind_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B pytorch_chat_model: tp: - meta-llama/Llama-3.2-3B-Instruct - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B turbomind_vl_model: tp: - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B pytorch_vl_model: tp: - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen2.5-VL-3B-Instruct turbomind_base_model: tp: - Qwen/Qwen3-4B pytorch_base_model: tp: - Qwen/Qwen3-4B 
turbomind_quantization: no_awq: - OpenGVLab/InternVL3-2B-Instruct gptq: - empty no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-VL-3B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Chat pytorch_quantization: awq: - meta-llama/Llama-3.2-3B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B w8a8: - meta-llama/Llama-3.2-3B-Instruct no_kvint4: - meta-llama/Llama-3.2-1B-Instruct - OpenGVLab/InternVL3-2B-Instruct - OpenGVLab/InternVL3-1B-Instruct - OpenGVLab/InternVL2_5-1B - Qwen/Qwen3-4B - Qwen/Qwen3-1.7B - Qwen/Qwen3-0.6B - Qwen/Qwen2.5-VL-3B-Instruct no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat ================================================ FILE: autotest/config_ascend.yml ================================================ model_path: /mnt/vc-intern-delivery/qa-llm-cicd/qa_test_models resource_path: /mnt/vc-intern-delivery/qa-llm-cicd/resource log_path: /mnt/vc-intern-delivery/qa-llm-cicd/log server_log_path: /mnt/vc-intern-delivery/qa-llm-cicd/server_log eval_path: /mnt/vc-intern-delivery/qa-llm-cicd/evaluation_report mllm_eval_path: /mnt/vc-intern-delivery/qa-llm-cicd/mllm_evaluation_report benchmark_path: /mnt/vc-intern-delivery/qa-llm-cicd/benchmark_report dataset_path: /mnt/vc-intern-delivery/qa-llm-cicd/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /mnt/vc-intern-delivery/qa-llm-cicd/datasets/prefix_cache_test.json env_tag: ascend device: ascend config: tp: Qwen/Qwen3-30B-A3B: 4 Qwen/Qwen3-235B-A22B: 16 Qwen/Qwen3-32B: 4 Qwen/Qwen3-8B: 2 internlm/Intern-S1: 16 internlm/Intern-S1-mini: 2 OpenGVLab/InternVL3_5-8B: 2 OpenGVLab/InternVL3_5-38B: 4 Qwen/Qwen3-VL-30B-A3B-Instruct: 4 Qwen/Qwen3-VL-8B-Instruct: 2 Qwen/Qwen3-VL-32B-Instruct: 4 pytorch_chat_model: tp: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-32B - Qwen/Qwen3-8B - 
Qwen/Qwen3-0.6B pytorch_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3_5-2B - OpenGVLab/InternVL3_5-8B - OpenGVLab/InternVL3_5-38B - Qwen/Qwen3-VL-30B-A3B-Instruct - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-32B-Instruct pytorch_base_model: tp: - Qwen/Qwen3-0.6B pytorch_quantization: awq: - Empty w8a8: - Empty no_kvint4: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-32B - Qwen/Qwen3-8B - Qwen/Qwen3-0.6B - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3_5-2B - OpenGVLab/InternVL3_5-8B - OpenGVLab/InternVL3_5-38B - Qwen/Qwen3-VL-30B-A3B-Instruct - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-32B-Instruct no_kvint8: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-32B - Qwen/Qwen3-8B - Qwen/Qwen3-0.6B - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3_5-2B - OpenGVLab/InternVL3_5-8B - OpenGVLab/InternVL3_5-38B - Qwen/Qwen3-VL-30B-A3B-Instruct - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-32B-Instruct longtext_model: - Qwen/Qwen3-30B-A3B benchmark_model: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-32B - Qwen/Qwen3-8B - Qwen/Qwen3-0.6B - internlm/Intern-S1 - internlm/Intern-S1-mini evaluate_model: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B mllm_evaluate_model: - Qwen/Qwen3-VL-30B-A3B-Instruct - Qwen/Qwen3-VL-8B-Instruct - Qwen/Qwen3-VL-32B-Instruct ================================================ FILE: autotest/config_h.yml ================================================ model_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/model resource_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/resource log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/log server_log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/server_log eval_path: 
/mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/evaluation_report mllm_eval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/mllm_evaluation_report benchmark_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/benchmark_report dataset_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/datasets/prefix_cache_test.json env_tag: h device: cuda config: tp: Qwen/Qwen3-235B-A22B-FP8: 4 internlm/Intern-S1: 4 Qwen/Qwen3-235B-A22B-Thinking-2507-FP8: 4 Qwen/Qwen3-30B-A3B: 2 Qwen/Qwen3-32B: 2 openai/gpt-oss-120b: 2 openai/gpt-oss-120b-BF16: 4 openai/gpt-oss-20b-BF16: 2 deepseek/DeepSeek-V3.1: 8 Qwen/Qwen3-30B-A3B-Base: 2 JetLM/SDAR-30B-A3B-Sci: 2 moonshotai/Kimi-K2-Instruct-0905: 16 Qwen/Qwen3-235B-A22B-Thinking-2507: 8 OpenGVLab/InternVL3_5-38B: 2 Qwen/Qwen3-VL-30B-A3B-Instruct: 2 internlm/Intern-S1-Pro-FP8: 16 dp_ep: moonshotai/Kimi-K2-Instruct-0905: dp: 16 ep: 16 Qwen/Qwen3-235B-A22B-Thinking-2507: dp: 8 ep: 8 internlm/Intern-S1-Pro-FP8: dp: 16 ep: 16 cp_tp: Qwen/Qwen3-235B-A22B-Thinking-2507: cp: 2 tp: 8 turbomind_chat_model: tp: - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - OpenGVLab/InternVL3_5-38B - openai/gpt-oss-120b - openai/gpt-oss-20b cp_tp: - Qwen/Qwen3-235B-A22B-Thinking-2507 pytorch_chat_model: tp: - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - 
Qwen/Qwen3-32B-FP8 - Qwen/Qwen3-VL-30B-A3B-Instruct - OpenGVLab/InternVL3_5-38B - unsloth/gpt-oss-120b-BF16 - unsloth/gpt-oss-20b-BF16 - deepseek/DeepSeek-V3.1 - moonshotai/Kimi-K2-Instruct-0905 - internlm/Intern-S1-Pro-FP8 - JetLM/SDAR-30B-A3B-Sci dp_ep: - moonshotai/Kimi-K2-Instruct-0905 - Qwen/Qwen3-235B-A22B-Thinking-2507 - internlm/Intern-S1-Pro-FP8 turbomind_vl_model: tp: - OpenGVLab/InternVL3_5-38B pytorch_vl_model: tp: - OpenGVLab/InternVL3_5-38B - Qwen/Qwen3-VL-30B-A3B-Instruct turbomind_base_model: tp: - Qwen/Qwen3-4B-FP8 - openai/gpt-oss-20b pytorch_base_model: tp: - Qwen/Qwen3-8B-Base - Qwen/Qwen3-30B-A3B-Base turbomind_quantization: no_awq: - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b - Qwen/Qwen3-235B-A22B-Thinking-2507 gptq: - empty no_kvint4: - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b - Qwen/Qwen3-235B-A22B-Thinking-2507 no_kvint8: - Qwen/Qwen3-235B-A22B-Thinking-2507 pytorch_quantization: awq: - empty w8a8: - empty no_kvint4: - Qwen/Qwen3-8B-Base - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - moonshotai/Kimi-K2-Instruct-0905 - Qwen/Qwen3-235B-A22B-Thinking-2507 - internlm/Intern-S1-Pro-FP8 - JetLM/SDAR-30B-A3B-Sci - deepseek/DeepSeek-V3.1 no_kvint8: - Qwen/Qwen3-235B-A22B-Thinking-2507 - internlm/Intern-S1-Pro-FP8 - deepseek/DeepSeek-V3.1 longtext_model: - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B-Thinking-2507 
benchmark_model: - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-1-70B-Instruct - Qwen/Qwen3-32B - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-235B-A22B-Thinking-2507 - Qwen/Qwen2.5-72B-Instruct - openai/gpt-oss-120b - openai/gpt-oss-20b - unsloth/gpt-oss-20b-BF16 - unsloth/gpt-oss-120b-BF16 evaluate_model: - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-235B-A22B-Thinking-2507 - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 - openai/gpt-oss-120b - unsloth/gpt-oss-120b-BF16 - deepseek/DeepSeek-V3.1 - moonshotai/Kimi-K2-Instruct-0905 - internlm/Intern-S1-Pro-FP8 - JetLM/SDAR-30B-A3B-Sci mllm_evaluate_model: - OpenGVLab/InternVL3_5-38B - Qwen/Qwen3-VL-30B-A3B-Instruct ================================================ FILE: autotest/config_h800.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_model/log eval_path: /nvme/qa_test_models/evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json env_tag: h800 device: cuda tp_config: Intern-S1: 8 Qwen3-235B-A22B: 8 Qwen3-235B-A22B-FP8: 4 Qwen3-30B-A3B: 2 Qwen3-32B: 2 gpt-oss-120b: 2 gpt-oss-120b-BF16: 4 gpt-oss-20b-BF16: 2 Qwen2.5-32B-Instruct: 2 turbomind_chat_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b pytorch_chat_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - 
Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - unsloth/gpt-oss-120b-BF16 - unsloth/gpt-oss-20b-BF16 turbomind_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini pytorch_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini turbomind_base_model: tp: - internlm/Intern-S1-mini - Qwen/Qwen3-4B-FP8 - openai/gpt-oss-20b pytorch_base_model: tp: - internlm/Intern-S1-mini - Qwen/Qwen3-4B-FP8 - unsloth/gpt-oss-20b-BF16 turbomind_quantization: no_awq: - internlm/Intern-S1 - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b gptq: - empty no_kvint4: - internlm/Intern-S1 - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b no_kvint8: - empty pytorch_quantization: awq: - empty w8a8: - empty no_kvint4: - internlm/Intern-S1 - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 no_kvint8: - empty evaluate_model: - internlm/Intern-S1-mini - Qwen/Qwen3-0.6B-FP8 - Qwen/Qwen3-1.7B-FP8 - Qwen/Qwen3-4B-FP8 - Qwen/Qwen3-8B-FP8 - Qwen/Qwen3-14B-FP8 - Qwen/Qwen3-32B - Qwen/Qwen3-32B-FP8 - Qwen/Qwen3-30B-A3B - Qwen/Qwen3-30B-A3B-FP8 - Qwen/Qwen3-235B-A22B - Qwen/Qwen3-235B-A22B-FP8 - openai/gpt-oss-120b - openai/gpt-oss-20b - unsloth/gpt-oss-120b-BF16 - unsloth/gpt-oss-20b-BF16 
================================================ FILE: autotest/config_h_legacy.yml ================================================ model_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/model resource_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/resource log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/log server_log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/server_log eval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/evaluation_report mllm_eval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/mllm_evaluation_report benchmark_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/benchmark_report dataset_path: /mnt/shared-storage-user/auto-eval-pipeline/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /mnt/shared-storage-user/auto-eval-pipeline/datasets/prefix_cache_test.json env_tag: h device: cuda config: tp: internlm/Intern-S1: 4 turbomind_chat_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini pytorch_chat_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini turbomind_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini pytorch_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini turbomind_base_model: tp: pytorch_base_model: tp: turbomind_quantization: no_awq: - internlm/Intern-S1 - internlm/Intern-S1-mini gptq: - empty no_kvint4: - internlm/Intern-S1 - internlm/Intern-S1-mini no_kvint8: - empty pytorch_quantization: awq: - empty w8a8: - empty no_kvint4: - internlm/Intern-S1 - internlm/Intern-S1-mini no_kvint8: - empty benchmark_model: - internlm/Intern-S1 - internlm/Intern-S1-mini mllm_evaluate_model: - internlm/Intern-S1 - internlm/Intern-S1-mini ================================================ FILE: autotest/config_legacy.yml 
================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json env_tag: a100 device: cuda config: tp: meta-llama/Llama-4-Scout-17B-16E-Instruct: 4 meta-llama/Meta-Llama-3-1-70B-Instruct: 4 internlm/Intern-S1: 8 OpenGVLab/InternVL3-38B: 2 OpenGVLab/InternVL2_5-26B: 2 OpenGVLab/InternVL2_5-26B-MPO: 2 OpenGVLab/InternVL2_5-38B: 4 OpenGVLab/InternVL2-40B: 4 Qwen/Qwen2.5-72B-Instruct: 4 deepseek-ai/deepseek-vl-1.3b-chat: 2 baichuan-inc/Baichuan2-13B-Chat: 2 mistralai/Mixtral-8x7B-Instruct-v0.1: 2 google/gemma-2-27b-it: 2 OpenGVLab/InternVL2-Llama3-76B-AWQ: 4 unsloth/gpt-oss-20b-BF16: 2 unsloth/gpt-oss-120b-BF16: 4 OpenGVLab/InternVL3_5-30B-A3B: 2 turbomind_chat_model: tp: - meta-llama/Llama-2-7b-chat-hf - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL2_5-8B - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 - OpenGVLab/InternVL2-Llama3-76B-AWQ - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4 - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct - baichuan-inc/Baichuan2-7B-Chat - liuhaotian/llava-v1.6-vicuna-7b - codellama/CodeLlama-7b-Instruct-hf # - allenai/Molmo-7B-D-0924 This modeling file requires the following packages that were not found in your environment: tensorflow. 
Run `pip install tensorflow` pytorch_chat_model: tp: - meta-llama/Llama-2-7b-chat-hf - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL2_5-8B # - OpenGVLab/Mono-InternVL-2B 'dict' object has no attribute 'image_size' - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct - unsloth/gpt-oss-20b-BF16 - mistralai/Mixtral-8x7B-Instruct-v0.1 - google/gemma-3-12b-it - google/gemma-2-9b-it - google/gemma-2-27b-it - google/gemma-7b-it - baichuan-inc/Baichuan2-13B-Chat - deepseek-ai/deepseek-moe-16b-chat - THUDM/chatglm2-6b - microsoft/Phi-4-mini-instruct turbomind_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL2_5-8B - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 - OpenGVLab/InternVL2-Llama3-76B-AWQ - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct - liuhaotian/llava-v1.6-vicuna-7b pytorch_vl_model: tp: - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3-8B - OpenGVLab/InternVL2_5-8B # - OpenGVLab/Mono-InternVL-2B 'dict' object has no attribute 'image_size' - Qwen/Qwen2-VL-7B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct turbomind_base_model: tp: - codellama/CodeLlama-7b-hf pytorch_base_model: tp: - bigcode/starcoder2-7b turbomind_quantization: no_awq: - internlm/Intern-S1 - internlm/Intern-S1-mini - OpenGVLab/InternVL3-8B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4 - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct - OpenGVLab/InternVL3_5-30B-A3B - codellama/CodeLlama-7b-Instruct-hf # - allenai/Molmo-7B-D-0924 This modeling file requires the following packages that were not found in your environment: tensorflow. Run `pip install tensorflow` gptq: - empty no_kvint4: - OpenGVLab/InternVL3-8B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct # - allenai/Molmo-7B-D-0924 This modeling file requires the following packages that were not found in your environment: tensorflow. 
Run `pip install tensorflow` no_kvint8: - Qwen/Qwen2.5-7B-Instruct pytorch_quantization: awq: - meta-llama/Llama-2-7b-chat-hf # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py) - Qwen/Qwen2.5-7B-Instruct # - microsoft/Phi-4-mini-instruct The size of tensor a (5120) must match the size of tensor b (3072) at non-singleton dimension 0 w8a8: - meta-llama/Llama-2-7b-chat-hf # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py) - Qwen/Qwen2.5-7B-Instruct # - microsoft/Phi-4-mini-instruct The size of tensor a (5120) must match the size of tensor b (3072) at non-singleton dimension 0 no_kvint4: - OpenGVLab/InternVL3-8B - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen2-VL-7B-Instruct - microsoft/Phi-3-vision-128k-instruct - microsoft/Phi-3.5-vision-instruct - unsloth/gpt-oss-20b-BF16 no_kvint8: - empty longtext_benchmark_model: - internlm/Intern-S1-mini benchmark_model: - internlm/Intern-S1 - internlm/Intern-S1-mini - meta-llama/Llama-2-7b-chat-hf - unsloth/gpt-oss-20b-BF16 evaluate_model: - Qwen/Qwen2.5-7B-Instruct mllm_evaluate_model: - internlm/Intern-S1-mini - internlm/Intern-S1 ================================================ FILE: autotest/config_test.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_model/log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json env_tag: 
test device: cuda config: tp: test/test_tp2: 2 test/test_tp2_gpqa: 2 test/test_tp2_int4: 2 test/test_tp8: 8 test/test_vl_tp2: 2 test/test_vl_tp2_gpqa: 2 test/test_vl_tp2_int4: 2 test/test_vl_tp8: 8 test/test_allkind: 8 dp_ep: test/test_dpep16: dp: 16 ep: 16 test/test_dpep8: dp: 8 ep: 8 test/test_vl_dpep16: dp: 16 ep: 16 test/test_vl_dpep8: dp: 8 ep: 8 test/test_allkind: dp: 8 ep: 8 cp_tp: test/test_cp2tp8: cp: 2 tp: 8 test/test_vl_cp2tp8: cp: 2 tp: 8 test/test_allkind: cp: 2 tp: 8 turbomind_chat_model: tp: - test/test_tp1 - test/test_tp2 - test/test_tp2_gpqa - test/test_tp2_int4 - test/test_tp8 - test/test_vl_tp1 - test/test_vl_tp2 - test/test_vl_tp2_gpqa - test/test_vl_tp2_int4 - test/test_vl_tp8 - test/test_allkind cp_tp: - test/test_cp2tp8 - test/test_vl_cp2tp8 - test/test_allkind pytorch_chat_model: tp: - test/test_tp1 - test/test_tp1_pytorch - test/test_tp2 - test/test_tp2_gpqa - test/test_tp2_int4 - test/test_tp8 - test/test_vl_tp1 - test/test_vl_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp2_gpqa - test/test_vl_tp2_int4 - test/test_vl_tp8 - test/test_allkind dp_ep: - test/test_dpep8 - test/test_dpep16 - test/test_vl_dpep8 - test/test_vl_dpep16 - test/test_allkind cp_tp: - test/test_cp2tp8 - test/test_vl_cp2tp8 - test/test_allkind turbomind_vl_model: tp: - test/test_vl_tp1 - test/test_vl_tp2 - test/test_vl_tp2_gpqa - test/test_vl_tp2_int4 - test/test_vl_tp8 - test/test_allkind pytorch_vl_model: tp: - test/test_vl_tp1 - test/test_vl_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp2_gpqa - test/test_vl_tp2_int4 - test/test_vl_tp8 - test/test_allkind dp_ep: - test/test_vl_dpep8 - test/test_vl_dpep16 - test/test_allkind turbomind_base_model: tp: - test/test_tp1 - test/test_tp2 pytorch_base_model: tp: - test/test_tp1 - test/test_tp1_pytorch - test/test_tp2 turbomind_quantization: no_awq: - test/test_tp2 - test/test_vl_tp2 - test/test_tp2_gpqa - test/test_vl_tp2_gpqa - test/test_cp2tp8 - test/test_dpep8 gptq: - test/test_tp1 - test/test_vl_tp1 - test/test_cp2tp8 
- test/test_dpep8 no_kvint4: - test/test_tp2 - test/test_vl_tp2 - test/test_cp2tp8 - test/test_vl_dpep8 no_kvint8: - test/test_tp1 - test/test_vl_tp1 - test/test_dpep8 - test/test_vl_cp2tp8 pytorch_quantization: awq: - test/test_tp1 w8a8: - test/test_tp2 no_kvint4: - test/test_tp2 - test/test_cp2tp8 - test/test_vl_cp2tp8 no_kvint8: - test/test_tp1 - test/test_vl_tp1 - test/test_vl_dpep8 longtext_model: - test/test_tp1 - test/test_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp8 - test/test_cp2tp8 - test/test_vl_dpep8 benchmark_model: - test/test_tp1 - test/test_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp8 - test/test_cp2tp8 - test/test_vl_dpep8 mllm_benchmark_model: - test/test_vl_tp1 - test/test_vl_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp8 - test/test_vl_dpep16 - test/test_vl_cp2tp8 evaluate_model: - test/test_tp1 - test/test_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp8 - test/test_cp2tp8 - test/test_dpep16 - test/test_vl_dpep8 mllm_evaluate_model: - test/test_vl_tp1 - test/test_vl_tp1_pytorch - test/test_vl_tp2 - test/test_vl_tp8 - test/test_vl_dpep16 - test/test_vl_cp2tp8 ================================================ FILE: autotest/config_testascend.yml ================================================ model_path: /nvme/qa_test_models resource_path: /nvme/qa_test_models/resource log_path: /nvme/qa_test_models/autotest_model/log server_log_path: /nvme/qa_test_models/server_log eval_path: /nvme/qa_test_models/evaluation_report mllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report benchmark_path: /nvme/qa_test_models/benchmark_report dataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json prefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json env_tag: ascend device: ascend config: tp: dp_ep: cp_tp: pytorch_chat_model: tp: - test/test_tp1 pytorch_quantization: awq: - test/test_tp1 w8a8: - test/test_tp1 no_kvint4: - test/test_tp1 no_kvint8: - test/test_tp1 
================================================ FILE: autotest/conftest.py ================================================ import os import pytest import yaml from utils.config_utils import get_config from utils.constant import DEFAULT_SERVER from utils.proxy_distributed_utils import ProxyDistributedManager from utils.ray_distributed_utils import RayLMDeployManager cli_prompt_case_file = 'autotest/chat_prompt_case.yml' common_prompt_case_file = 'autotest/prompt_case.yml' config_file = 'autotest/config.yml' PROXY_PORT = 8000 @pytest.fixture(scope='session') def config(): # Use device-specific config file if DEVICE environment variable is set return get_config() @pytest.fixture(scope='session') def cli_case_config(): case_path = os.path.join(cli_prompt_case_file) with open(case_path) as f: case_config = yaml.load(f.read(), Loader=yaml.SafeLoader) return case_config @pytest.fixture(scope='class', autouse=True) def common_case_config(): case_path = os.path.join(common_prompt_case_file) with open(case_path) as f: case_config = yaml.load(f.read(), Loader=yaml.SafeLoader) return case_config @pytest.fixture(scope='session') def shared_ray_manager(): master_addr = DEFAULT_SERVER env_tag = os.environ.get('TEST_ENV') if env_tag: device_config_path = f'autotest/config_{env_tag}.yml' if os.path.exists(device_config_path): config_path = device_config_path else: config_path = config_file else: config_path = config_file with open(config_path) as f: env_config = yaml.load(f.read(), Loader=yaml.SafeLoader) run_id = os.environ.get('RUN_ID', 'local_run') log_dir = os.path.join(env_config.get('server_log_path', '/tmp/lmdeploy_test'), str(run_id).replace('/', '_')) manager = RayLMDeployManager(master_addr=master_addr, api_port=PROXY_PORT, log_dir=log_dir, health_check=True) manager.start_ray_cluster() if manager.is_master: print('🎯 Master node: Ray cluster started, waiting for worker nodes to join...') yield manager print(f'\n[Final Cleanup] Node {manager.node_rank} performing final 
resource cleanup...') manager.cleanup(force=True) @pytest.fixture(scope='session') def shared_proxy_manager(): master_addr = DEFAULT_SERVER manager = ProxyDistributedManager() if manager.is_master: manager.start() print(f'🎯 Master node: LMDeploy Proxy started on {master_addr}:{manager.proxy_port}') print('⏳ Waiting for worker nodes to connect...') yield manager print(f'\n[Final Cleanup] Node {manager.node_rank} performing final resource cleanup...') manager.cleanup() ================================================ FILE: autotest/evaluate/eval_config_chat.py ================================================ # flake8: noqa from mmengine.config import read_base from opencompass.models import OpenAISDK from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner from opencompass.runners import LocalRunner from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask from opencompass.utils.text_postprocessors import extract_non_reasoning_content ####################################################################### # PART 0 Essential Configs # ####################################################################### with read_base(): # Datasets from opencompass.configs.datasets.aime2025.aime2025_llmjudge_academic import aime2025_datasets from opencompass.configs.datasets.gpqa.gpqa_cascade_eval_academic import gpqa_datasets from opencompass.configs.datasets.HLE.hle_llmverify_academic import hle_datasets from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets from opencompass.configs.datasets.livecodebench.livecodebench_v6_academic import LCBCodeGeneration_dataset from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import mmlu_pro_datasets # Summary Groups from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups ####################################################################### # Model Configuration # ####################################################################### 
MODEL_NAME = '' MODEL_PATH = '' API_BASE = '' JUDGE_MODEL_NAME = '' JUDGE_MODEL_PATH = '' JUDGE_API_BASE = '' api_meta_template = dict(round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), ]) # Use OpenAISDK to configure LMDeploy OpenAI interface models = [ dict(type=OpenAISDK, abbr=f'{MODEL_NAME}', path=MODEL_PATH, key='EMPTY', openai_api_base=API_BASE, retry=3, run_cfg=dict(num_gpus=0), meta_template=api_meta_template, timeout=10800, pred_postprocessor=dict(type=extract_non_reasoning_content)) ] ####################################################################### # PART 1 Datasets List # ####################################################################### # Restrict MMLU-Pro to its math and other subsets mmlu_pro_datasets = [x for x in mmlu_pro_datasets if 'math' in x['abbr'] or 'other' in x['abbr']] # Collect every module-level *_datasets list (hle_datasets included; any new top-level name ending in '_datasets' is picked up too) and append LCBCodeGeneration_dataset datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + [LCBCodeGeneration_dataset] # LLM judge config: using LLM to evaluate predictions judge_cfg = dict( type=OpenAISDK, abbr=f'{JUDGE_MODEL_NAME}', path=JUDGE_MODEL_NAME, key='EMPTY', openai_api_base=JUDGE_API_BASE, meta_template=dict(round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), ]), query_per_second=16, batch_size=1024, temperature=0.001, tokenizer_path=JUDGE_MODEL_PATH, verbose=True, max_out_len=8192, max_seq_len=32768, mode='mid', ) for item in datasets: if 'judge_cfg' in item['eval_cfg']['evaluator']: item['eval_cfg']['evaluator']['judge_cfg'] = judge_cfg if 'llm_evaluator' in item['eval_cfg']['evaluator'].keys( ) and 'judge_cfg' in item['eval_cfg']['evaluator']['llm_evaluator']: item['eval_cfg']['evaluator']['llm_evaluator']['judge_cfg'] = judge_cfg ####################################################################### # PART 2 Dataset Summarizer # #######################################################################
core_summary_groups = [ { 'name': 'core_average', 'subsets': [ ['IFEval', 'Prompt-level-strict-accuracy'], ['hle_llmjudge', 'accuracy'], ['aime2025_repeat_32', 'accuracy (32 runs average)'], ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'], ['mmlu_pro', 'naive_average'], 'mmlu_pro_math', 'mmlu_pro_other', ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'], ], }, ] summarizer = dict( dataset_abbrs=[ ['core_average', 'naive_average'], ['IFEval', 'Prompt-level-strict-accuracy'], ['hle_llmjudge', 'accuracy'], ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'], ['aime2025_repeat_32', 'accuracy (32 runs average)'], ['mmlu_pro', 'naive_average'], 'mmlu_pro_math', 'mmlu_pro_other', ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'], ], summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []) + core_summary_groups, ) for item in datasets: if 'max_out_len' in item['infer_cfg']['inferencer']: del item['infer_cfg']['inferencer']['max_out_len'] NUM_WORKERS = 8 infer = dict( partitioner=dict(type=NumWorkerPartitioner, num_worker=NUM_WORKERS), runner=dict( type=LocalRunner, max_num_workers=64, retry=0, task=dict(type=OpenICLInferTask), ), ) # eval with local runner eval = dict( partitioner=dict(type=NaivePartitioner, n=10), runner=dict(type=LocalRunner, max_num_workers=64, task=dict(type=OpenICLEvalTask)), ) infer['partitioner']['num_worker'] = 64 ================================================ FILE: autotest/evaluate/test_api_evaluate.py ================================================ import os import time import pytest import utils.constant as constant from utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid from utils.evaluate_utils import eval_test from utils.proxy_distributed_utils import ApiServerPerTest, proxy_worker_node_wait from utils.ray_distributed_utils import ray_worker_node_wait from utils.run_restful_chat import start_openai_service, start_proxy_server, 
stop_restful_api, terminate_restful_api def _run_ray_distributed_test( config, run_config, worker_id, test_type='infer', manager=None, # ← New parameter: pass in shared manager eval_config_name='default'): """Universal distributed test executor (using shared Ray cluster)""" assert manager is not None, 'Manager instance must be provided' if 'gpt' in run_config.get('model', '').lower(): eval_config_name = 'gpt' elif 'intern-s1-pro' in run_config.get('model', '').lower(): eval_config_name = 'intern-s1-pro' if str(config.get('env_tag')) == 'ascend': eval_config_name = f'{eval_config_name}-2batch' preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {}) if manager.is_master: model_path = os.path.join(config['model_path'], run_config['model']) eval_path = config.get('eval_path') # Start API Server for current model (master node starts/stops, worker nodes verify) manager.start_lmdeploy_api_server(config=config, run_config=run_config) try: print(f'🧪 Master node executing {test_type} test ({eval_config_name})...') case_name = get_case_str_by_config(run_config) result, msg = eval_test(model_path, eval_path, case_name, port=constant.PROXY_PORT, test_type=test_type, **preset_config) assert result, f'❌ {test_type} test failed: {msg}' print(f'✅ {test_type} test passed') finally: # Clean up API Server for current model (worker nodes skip) manager.cleanup(force=False) else: time.sleep(10) ray_worker_node_wait(manager, timeout_minutes=4880) def _run_proxy_distributed_test(config, run_config, worker_id, test_type='infer', manager=None, eval_config_name='default'): assert manager is not None, 'Manager instance must be provided' if 'gpt' in run_config.get('model', '').lower(): eval_config_name = 'gpt' elif 'intern-s1-pro' in run_config.get('model', '').lower(): eval_config_name = 'intern-s1-pro' if str(config.get('env_tag')) == 'ascend': eval_config_name = f'{eval_config_name}-2batch' preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {}) model_name = 
run_config['model'] model_path = os.path.join(config['model_path'], model_name) api_server = ApiServerPerTest(proxy_manager=manager, config=config, run_config=run_config) api_server.start() try: if manager.is_master: api_server.wait_until_ready() print(f'🧪 Master node executing {test_type} test ({eval_config_name})...') eval_path = config.get('eval_path') case_name = get_case_str_by_config(run_config) extra_config = {'max-num-workers': 16} result, msg = eval_test(model_path, eval_path, case_name, port=constant.PROXY_PORT, test_type=test_type, extra_config=extra_config, **preset_config) assert result, f'❌ {test_type} test failed: {msg}' print(f'✅ {test_type} test passed') else: print(f'⏸️ Worker node {manager.node_rank} waiting for master to complete test...') proxy_worker_node_wait(manager, timeout_minutes=4880) finally: api_server.cleanup() if manager.is_master: time.sleep(1) def run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default'): """Run test with specified evaluation configuration.""" if 'gpt' in run_config.get('model', '').lower(): eval_config_name = 'gpt' elif 'sdar' in run_config.get('model', '').lower(): eval_config_name = 'sdar' elif 'intern-s1-pro' in run_config.get('model', '').lower(): eval_config_name = 'intern-s1-pro' if str(config.get('env_tag')) == 'a100': eval_config_name = f'{eval_config_name}-32k' elif str(config.get('env_tag')) == 'ascend': eval_config_name = f'{eval_config_name}-2batch' preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {}) eval_path = config.get('eval_path') total_gpus = int(os.environ.get('TOTAL_GPU_COUNT', '8')) work_num = int(total_gpus / run_config.get('parallel_config', {}).get('tp', 1)) extra_config = {'max-num-workers': min(work_num * 16, 64)} case_name = get_case_str_by_config(run_config) if test_type == 'infer': proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), constant.PROXY_PORT, f'{case_name}_infer') run_config_new = run_config.copy() if 
'extra_params' not in run_config_new: run_config_new['extra_params'] = {} run_config_new['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{constant.PROXY_PORT}' run_config_new['extra_params']['server-name'] = constant.DEFAULT_SERVER from concurrent.futures import ThreadPoolExecutor def run_openai_service_start(i): return start_openai_service(config, run_config_new, f'gw{i}') with ThreadPoolExecutor(max_workers=work_num) as executor: futures = [executor.submit(run_openai_service_start, i) for i in range(int(work_num))] results = [] for future in futures: pid, content = future.result() results.append((pid, content)) try: model_path = os.path.join(config.get('model_path'), run_config.get('model')) eval_test(model_path, eval_path, case_name, port=constant.PROXY_PORT, test_type=test_type, extra_config=extra_config, **preset_config) finally: for i in range(work_num): terminate_restful_api(f'gw{i}') stop_restful_api(proxy_pid, proxy_process) else: # eval port = constant.PROXY_PORT + get_workerid(worker_id) proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), port, f'{case_name}_eval') eval_run_config = constant.EVAL_RUN_CONFIG.copy() if 'extra_params' not in eval_run_config: eval_run_config['extra_params'] = {} eval_run_config['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{port}' pid, content = start_openai_service(config, eval_run_config, worker_id) try: if pid > 0: model_path = os.path.join(config.get('model_path'), eval_run_config.get('model')) eval_test(model_path, eval_path, case_name, port=port, test_type=test_type, extra_config=extra_config, **preset_config) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) stop_restful_api(proxy_pid, proxy_process) def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, func_type='evaluate', extra={'session_len': 65536}) @pytest.mark.infer 
@pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1})) def test_turbomind_infer_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2})) def test_turbomind_infer_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4})) def test_turbomind_infer_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8})) def test_turbomind_infer_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_cp2tp8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'cp': 2, 'tp': 8})) def test_turbomind_infer_cp2tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1})) def test_pytorch_restful_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2})) def test_pytorch_restful_tp2(config, run_config, worker_id): run_eval_test(config, 
run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4})) def test_pytorch_restful_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8})) def test_pytorch_restful_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_restful_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_tp16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_restful_distributed_tp16(shared_ray_manager, config, run_config, worker_id): _run_ray_distributed_test(config=config, run_config=run_config, worker_id=worker_id, test_type='infer', manager=shared_ray_manager) @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 8, 'ep': 8})) def test_pytorch_restful_distributed_dpep8(shared_proxy_manager, config, run_config, worker_id): _run_proxy_distributed_test(config=config, run_config=run_config, worker_id=worker_id, test_type='infer', manager=shared_proxy_manager) @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 16, 'ep': 16})) def 
test_pytorch_restful_distributed_dpep16(shared_proxy_manager, config, run_config, worker_id): _run_proxy_distributed_test(config=config, run_config=run_config, worker_id=worker_id, test_type='infer', manager=shared_proxy_manager) @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1})) def test_turbomind_eval_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2})) def test_turbomind_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4})) def test_turbomind_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8})) def test_turbomind_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1})) def test_pytorch_eval_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2})) def test_pytorch_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch 
@pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4})) def test_pytorch_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8})) def test_pytorch_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_eval_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_tp16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_eval_distributed_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 8, 'ep': 8})) def test_pytorch_eval_distributed_dpep8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_distributed_dpep16 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 16, 'ep': 16})) def test_pytorch_eval_distributed_dpep16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_cp2tp8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'cp': 2, 'tp': 8})) def 
test_turbomind_eval_cp2tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') ================================================ FILE: autotest/evaluate/test_mllm_api_evaluate.py ================================================ import os import pytest import utils.constant as constant from utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid from utils.evaluate_utils import mllm_eval_test from utils.run_restful_chat import start_openai_service, start_proxy_server, stop_restful_api, terminate_restful_api def run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default'): extra_config = constant.MLLM_EVAL_CONFIGS.get(eval_config_name, {}) eval_path = config.get('mllm_eval_path') case_name = get_case_str_by_config(run_config) if test_type == 'infer': proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), constant.PROXY_PORT, f'{case_name}_infer') total_gpus = int(os.environ.get('TOTAL_GPU_COUNT', '8')) work_num = int(total_gpus / run_config.get('parallel_config', {}).get('tp', 1)) run_config_new = run_config.copy() if 'extra_params' not in run_config_new: run_config_new['extra_params'] = {} run_config_new['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{constant.PROXY_PORT}' from concurrent.futures import ThreadPoolExecutor def run_openai_service_start(i): return start_openai_service(config, run_config_new, f'gw{i}') with ThreadPoolExecutor(max_workers=work_num) as executor: futures = [executor.submit(run_openai_service_start, i) for i in range(int(work_num))] results = [] for future in futures: pid, content = future.result() results.append((pid, content)) try: model_path = os.path.join(config.get('model_path'), run_config.get('model')) extra_config['api-nproc'] = work_num * 16 mllm_eval_test(model_path, eval_path, case_name, port=constant.PROXY_PORT, test_type=test_type, extra_config=extra_config) finally: for i in range(work_num): 
terminate_restful_api(f'gw{i}') stop_restful_api(proxy_pid, proxy_process) else: # eval port = constant.PROXY_PORT + get_workerid(worker_id) proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), port, f'{case_name}_eval') eval_run_config = constant.EVAL_RUN_CONFIG.copy() if 'extra_params' not in eval_run_config: eval_run_config['extra_params'] = {} eval_run_config['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{port}' pid, content = start_openai_service(config, eval_run_config, worker_id) try: if pid > 0: model_path = os.path.join(config.get('model_path'), eval_run_config.get('model')) mllm_eval_test(model_path, eval_path, case_name, port=port, test_type=test_type) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) stop_restful_api(proxy_pid, proxy_process) def get_models(backend, parallel_config): return get_func_config_list(backend, parallel_config, model_type='vl_model', func_type='mllm_evaluate', extra={ 'session-len': 65536, 'cache-max-entry-count': 0.6 }) @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1})) def test_turbomind_vl_eval_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2})) def test_turbomind_vl_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4})) def test_turbomind_vl_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.turbomind 
@pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8})) def test_turbomind_vl_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1})) def test_pytorch_vl_eval_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2})) def test_pytorch_vl_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4})) def test_pytorch_vl_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8})) def test_pytorch_vl_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.infer @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_vl_eval_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'infer') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1})) def test_turbomind_eval_tp1(config, run_config, worker_id): 
run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2})) def test_turbomind_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_4 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4})) def test_turbomind_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.turbomind @pytest.mark.gpu_num_8 @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8})) def test_turbomind_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_1 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1})) def test_pytorch_eval_tp1(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2})) def test_pytorch_eval_tp2(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4})) def test_pytorch_eval_tp4(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8})) def 
test_pytorch_eval_tp8(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') @pytest.mark.eval @pytest.mark.pytorch @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.flaky(reruns=0) @pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16})) def test_pytorch_eval_tp16(config, run_config, worker_id): run_eval_test(config, run_config, worker_id, 'eval') ================================================ FILE: autotest/interface/pipeline/test_pipeline_func.py ================================================ import multiprocessing as mp import pydantic import pytest from utils.config_utils import set_device_env_variable, unset_device_env_variable from utils.pipeline_chat import (assert_pipeline_batch_return, assert_pipeline_batch_stream_return, assert_pipeline_common_log, assert_pipeline_single_return, assert_pipeline_single_stream_return, save_pipeline_common_log) from utils.restful_return_check import has_repeated_fragment from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline from lmdeploy.utils import is_bf16_supported def init_pipeline(model_path, backend_config): if not is_bf16_supported() and isinstance(backend_config, PytorchEngineConfig): backend_config.dtype = 'float16' return pipeline(model_path, backend_config=backend_config) def run_case_in_spawn(worker_id, target, args): needs_device_env = 'gw' in worker_id if needs_device_env: set_device_env_variable(worker_id, parallel_config=2) ctx = mp.get_context('spawn') process = ctx.Process(target=target, args=args) process.start() process.join() if needs_device_env: unset_device_env_variable() def run_pipeline_testcase_prompt(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response = pipe('Hi, pls intro yourself') result, msg = assert_pipeline_single_return(response) 
save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_prompt_stream(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response = [] for item in pipe.stream_infer('Hi, pls intro yourself'): response.append(item) result, msg = assert_pipeline_single_stream_return(response) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_multi_prompt(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) result, msg = assert_pipeline_batch_return(response, 2) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_multi_prompt_stream(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response = [] for item in pipe.stream_infer(['Pls intro yourself', 'Shanghai is']): response.append(item) result, msg = assert_pipeline_batch_stream_return(response, 2) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_message(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}]] response = pipe(prompts) result, msg = assert_pipeline_batch_return(response) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_message_stream(config, model, backend, file_name): model_path = 
'/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}]] response = [] for item in pipe.stream_infer(prompts): response.append(item) result, msg = assert_pipeline_single_stream_return(response) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_message_batch(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}], [{'role': 'user', 'content': 'Shanghai is'}]] response = pipe(prompts) result, msg = assert_pipeline_batch_return(response, 2) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_message_batch_stream(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}], [{'role': 'user', 'content': 'Shanghai is'}]] response = [] for item in pipe.stream_infer(prompts): response.append(item) result, msg = assert_pipeline_batch_stream_return(response, 2) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_logprobs(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40, do_sample=True) response = pipe('Hi, pls intro yourself', gen_config=gen_config) result, msg = assert_pipeline_single_return(response, logprobs_num=10) save_pipeline_common_log(config, file_name, result, 
response, msg) pipe.close() def run_pipeline_testcase_logprobs_stream(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40, do_sample=True) response = [] for item in pipe.stream_infer('Hi, pls intro yourself', gen_config=gen_config): response.append(item) result, msg = assert_pipeline_single_stream_return(response, logprobs_num=10) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_session_len(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(session_len=10, tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) result = True for i in range(2): result &= response[i].finish_reason == 'error' result &= response[i].generate_token_len == 0 result &= response[i].text == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR' save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_min_new_tokens(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(min_new_tokens=200, ignore_eos=True) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) result = True for i in range(2): result &= response[i].finish_reason == 'length' result &= response[i].index == i save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_stop_words(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, 
backend_config=backend_config) gen_config = GenerationConfig(stop_words=[' and', '浦', ' to']) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) result = True for i in range(2): result &= '浦' not in response[i].text result &= ' and' not in response[i].text and ' to ' not in response[i].text result &= response[i].finish_reason == 'stop' and response[i].generate_token_len < 50 save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_bad_words(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(bad_words=[' and', '浦', ' to']) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) result = True for i in range(2): result &= '浦' not in response[i].text and ' and' not in response[i].text and ' to ' not in response[i].text save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_special_words_false(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompt = '<|im_start|>system\n当开启工具以及代码时,根据需求选择合适的工具进行调用\n' + \ '<|im_end|><|im_start|>system name=<|interpreter|>\n你现在已经' + \ '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \ '发送含有 Python >代码的消息时,它将在该环境中执行。这个工具适用于多种场景,' + \ '如数据分析或处理(包括数据操作、统计分析、图表绘制),复杂的计算问题(解决数学和物理' + \ '难题),编程示例(理解编程概念或特性),文本处理和分析(比如文本解析和自然语言处理),机器学习和数据科学(用于' + \ '展示模型训练和数据可视化),以及文件操作和数据导入(处理CSV、JSON等格式的文件)。<|im_end|>\n' + \ '<|im_start|>user\n设 $L$ 为圆周$x^2+y^2=2x$,计算曲线积分:$I=\\int_L' + \ '{x\\mathrm{d}s}=$<|im_end|>\n<|im_start|>assistant' gen_config = GenerationConfig(skip_special_tokens=False) response = pipe(prompt, gen_config=gen_config) result = '<|action_start|><|interpreter|>' in response.text 
save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_special_words_true(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) prompt = '<|im_start|>system\n当开启工具以及代码时,根据需求选择合适的工具进行调用\n' + \ '<|im_end|><|im_start|>system name=<|interpreter|>\n你现在已经' + \ '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \ '发送含有 Python >代码的消息时,它将在该环境中执行。这个工具适用于多种场景,' + \ '如数据分析或处理(包括数据操作、统计分析、图表绘制),复杂的计算问题(解决数学和物理' + \ '难题),编程示例(理解编程概念或特性),文本处理和分析(比如文本解析和自然语言处理),机器学习和数据科学(用于' + \ '展示模型训练和数据可视化),以及文件操作和数据导入(处理CSV、JSON等格式的文件)。<|im_end|>\n' + \ '<|im_start|>user\n设 $L$ 为圆周$x^2+y^2=2x$,计算曲线积分:$I=\\int_L' + \ '{x\\mathrm{d}s}=$<|im_end|>\n<|im_start|>assistant' gen_config = GenerationConfig(skip_special_tokens=True) response = pipe(prompt, gen_config=gen_config) result = '<|action_start|><|interpreter|>' not in response.text save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_repetition_penalty(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(repetition_penalty=0.01, random_seed=1, min_new_tokens=50, do_sample=True) response = pipe('Shanghai is', gen_config=gen_config) result, msg = has_repeated_fragment(response.text) save_pipeline_common_log(config, file_name, result, response, msg=msg) pipe.close() def run_pipeline_testcase_repetition_penalty_bigger(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(repetition_penalty=1.2, random_seed=1) response = pipe('Shanghai is', gen_config=gen_config) result, msg = 
assert_pipeline_single_return(response) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_min_top_p(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(top_p=0, random_seed=1) response = pipe('Shanghai is', gen_config=gen_config) result, msg = assert_pipeline_single_return(response) save_pipeline_common_log(config, file_name, result, response, msg) pipe.close() def run_pipeline_testcase_min_top_k(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(top_k=1, max_new_tokens=20, do_sample=True) response_list = [] for _ in range(3): response_list.append(pipe('Shanghai is', gen_config=gen_config)) result = response_list[0].text == response_list[1].text and response_list[1].text == response_list[2].text save_pipeline_common_log(config, file_name, result, response_list) pipe.close() def run_pipeline_testcase_diff_random_seed(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) response_list = [] for i in range(3): gen_config = GenerationConfig(random_seed=i, temperature=1.0, top_k=40, do_sample=True) response_list.append(pipe('Shanghai is', gen_config=gen_config)) result = response_list[0].text != response_list[1].text and response_list[1].text != response_list[2].text save_pipeline_common_log(config, file_name, result, response_list) pipe.close() def run_pipeline_testcase_same_random_seed(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, 
backend_config=backend_config) gen_config = GenerationConfig(random_seed=1, top_k=40, do_sample=True) response_list = [] for _ in range(3): response_list.append(pipe('Shanghai is', gen_config=gen_config)) result = response_list[0].text == response_list[1].text and response_list[1].text == response_list[2].text save_pipeline_common_log(config, file_name, result, response_list) pipe.close() def run_pipeline_testcase_do_sample_batch(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(temperature=1.0, top_k=40, do_sample=True) response = pipe(['Shanghai is'] * 3, gen_config=gen_config) result = response[0].text != response[1].text and response[1].text != response[2].text save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_max_new_tokens(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(max_new_tokens=5) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) result = True for i in range(2): result &= response[i].finish_reason == 'length' result &= response[i].generate_token_len == 6 or response[i].generate_token_len == 5 save_pipeline_common_log(config, file_name, result, response) pipe.close() def run_pipeline_testcase_ignore_eos(config, model, backend, file_name): model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(ignore_eos=True, max_new_tokens=256) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) result = True for i in range(2): result &= response[i].finish_reason == 'length' result &= 
response[i].generate_token_len == 257 or response[i].generate_token_len == 256 save_pipeline_common_log(config, file_name, result, response) pipe.close() @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_prompt(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_prompt, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_prompt_stream(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_prompt_stream, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_multi_prompt(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_multi_prompt, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_multi_prompt_stream(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_multi_prompt_stream, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) 
@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_message(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_message, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_message_stream(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_message_stream, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_message_batch(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_message_batch, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_return_with_message_batch_stream(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_message_batch_stream, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig]) def test_return_check_logprobs(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_logprobs, (config, 
model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig]) def test_return_check_logprobs_stream(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_logprobs_stream, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_backend_config_session_len(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_session_len, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_min_new_tokens(config, model, backend, worker_id): file_name = f'pipeline_log_min_new_tokens_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_min_new_tokens, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_stop_words(config, model, backend, worker_id): file_name = f'pipeline_log_stop_words_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_stop_words, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) 
def test_gen_config_bad_words(config, model, backend, worker_id): file_name = f'pipeline_log_bad_words_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_bad_words, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_special_words_false(config, model, backend, worker_id): file_name = f'pipeline_log_special_words_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_special_words_false, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_special_words_true(config, model, backend, worker_id): file_name = f'pipeline_log_special_words_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_special_words_true, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_minimum_repetition_penalty(config, model, backend, worker_id): file_name = f'pipeline_log_repetition_penalty_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_repetition_penalty, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_repetition_penalty_bigger_than_1(config, model, backend, worker_id): file_name = f'pipeline_log_repetition_penalty_{worker_id}.txt' run_case_in_spawn(worker_id, 
run_pipeline_testcase_repetition_penalty_bigger, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_minimun_topp(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_min_top_p, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_minimun_topk(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_min_top_k, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_diff_random_seed(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_diff_random_seed, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_same_random_seed(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_same_random_seed, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) 
@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_do_sample_batch(config, model, backend, worker_id): file_name = f'pipeline_log_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_do_sample_batch, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_max_new_tokens(config, model, backend, worker_id): file_name = f'pipeline_log_max_new_tokens_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_max_new_tokens, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_gen_config_ignore_eos(config, model, backend, worker_id): file_name = f'pipeline_log_ignore_eos_{worker_id}.txt' run_case_in_spawn(worker_id, run_pipeline_testcase_ignore_eos, (config, model, backend, file_name)) assert_pipeline_common_log(config, file_name) @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) def test_backend_config_input_validation(config, model, backend, worker_id): if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=2) pipe = init_pipeline(model_path, backend_config=backend_config) with pytest.raises(AssertionError): gen_config = GenerationConfig(top_p=-0.01) pipe('Shanghai is', gen_config=gen_config) with pytest.raises(AssertionError): gen_config = GenerationConfig(top_p=1.01) pipe('Shanghai is', gen_config=gen_config) with pytest.raises(AssertionError): 
gen_config = GenerationConfig(temperature=-1) pipe('Shanghai is', gen_config=gen_config) with pytest.raises(AssertionError): gen_config = GenerationConfig(temperature=2.01) pipe('Shanghai is', gen_config=gen_config) with pytest.raises(AssertionError): gen_config = GenerationConfig(top_k=-1) pipe('Shanghai is', gen_config=gen_config) with pytest.raises(AssertionError): gen_config = GenerationConfig(n=-1) pipe('Shanghai is', gen_config=gen_config) pipe.close() if 'gw' in worker_id: unset_device_env_variable() @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig]) def test_backend_config_validate_turbomind(config, model, backend, worker_id): if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) model_path = '/'.join([config.get('model_path'), model]) with pytest.raises(pydantic.ValidationError, match='tp must be a positive integer'): backend_config = backend(tp=0) pipeline(model_path, backend_config=backend_config) with pytest.raises(AssertionError, match='max_batch_size should be greater than 0, but got 0'): backend_config = backend(max_batch_size=0) pipeline(model_path, backend_config=backend_config) with pytest.raises(pydantic.ValidationError): backend_config = backend(cache_max_entry_count=0) pipeline(model_path, backend_config=backend_config) with pytest.raises(pydantic.ValidationError): backend_config = backend(quant_policy=1) pipeline(model_path, backend_config=backend_config) with pytest.raises(pydantic.ValidationError): backend_config = backend(rope_scaling_factor=-1) pipeline(model_path, backend_config=backend_config) with pytest.raises(pydantic.ValidationError): backend_config = backend(max_prefill_token_num=-1) pipeline(model_path, backend_config=backend_config) with pytest.raises(pydantic.ValidationError): backend_config = backend(num_tokens_per_iter=-1) pipeline(model_path, backend_config=backend_config) if 'gw' in worker_id: 
unset_device_env_variable() @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B']) @pytest.mark.parametrize('backend', [PytorchEngineConfig]) def test_backend_config_validate_pytorch(config, model, backend, worker_id): if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) model_path = '/'.join([config.get('model_path'), model]) with pytest.raises(AssertionError): backend_config = backend(tp=0) init_pipeline(model_path, backend_config=backend_config) with pytest.raises(SystemExit): backend_config = backend(max_batch_size=0) init_pipeline(model_path, backend_config=backend_config) with pytest.raises(AssertionError): backend_config = backend(cache_max_entry_count=0) init_pipeline(model_path, backend_config=backend_config) with pytest.raises(AssertionError): backend_config = backend(num_cpu_blocks=-1) init_pipeline(model_path, backend_config=backend_config) with pytest.raises(AssertionError): backend_config = backend(num_gpu_blocks=-1) init_pipeline(model_path, backend_config=backend_config) if 'gw' in worker_id: unset_device_env_variable() @pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig]) def test_backend_config_tp(config, model, backend, worker_id): with pytest.raises(AssertionError): if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) model_path = '/'.join([config.get('model_path'), model]) backend_config = backend(tp=100) pipe = init_pipeline(model_path, backend_config=backend_config) pipe.close() if 'gw' in worker_id: unset_device_env_variable() ================================================ FILE: autotest/interface/pipeline/test_pipeline_longtext_func.py ================================================ import multiprocessing as mp import os import numpy as np import pytest from utils.config_utils import set_device_env_variable, unset_device_env_variable from lmdeploy import GenerationConfig, 
PytorchEngineConfig, TurbomindEngineConfig, pipeline SESSION_LEN = 198000 SESSION_LEN_128K = 128000 SESSION_LEN_32K = 32000 SESSION_LEN_CONFIG = { 'Qwen/Qwen2.5-7B-Instruct': SESSION_LEN_32K, 'Qwen/Qwen3-235B-A22B': SESSION_LEN_128K, 'Qwen/Qwen3-30B-A3B': SESSION_LEN_128K, 'Qwen/Qwen3-32B': SESSION_LEN_128K, 'meta-llama/Meta-Llama-3-1-8B-Instruct': SESSION_LEN_128K, 'meta-llama/Meta-Llama-3-1-70B-Instruct': SESSION_LEN_128K, } def run_case_in_spawn(target, args): ctx = mp.get_context('spawn') process = ctx.Process(target=target, args=args) process.start() process.join() @pytest.mark.gpu_num_1 @pytest.mark.parametrize('model', ['Qwen/Qwen3-8B']) def test_history_issue_tp1(config, model, worker_id): if 'gw' in worker_id: set_device_env_variable(worker_id) run_case_in_spawn(stream_infer_worker, (config, model, 1)) if 'gw' in worker_id: unset_device_env_variable() @pytest.mark.gpu_num_2 @pytest.mark.parametrize('model', ['Qwen/Qwen3-32B', 'Qwen/Qwen3-32B-inner-4bits', 'Qwen/Qwen3-30B-A3B']) def test_history_issue_tp2(config, model, worker_id): if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500) run_case_in_spawn(stream_infer_worker, (config, model, 2)) if 'gw' in worker_id: unset_device_env_variable() def stream_infer_worker(config, model, tp_num): model_path = os.path.join(config.get('model_path'), model) backend_config = TurbomindEngineConfig(session_len=SESSION_LEN, tp=tp_num) pipe = pipeline(model_path, backend_config=backend_config) prompt = '今 天 心 ' * int(SESSION_LEN / 6) gen_config = GenerationConfig(top_k=40) # stream infer for outputs in pipe.stream_infer(prompt, gen_config=gen_config): continue print(outputs) prompts = ['今 天 心 ' * int(SESSION_LEN / 6)] * 2 # stream infer for outputs in pipe.stream_infer(prompts, gen_config=gen_config): continue print(outputs) pipe.close() @pytest.mark.gpu_num_1 @pytest.mark.parametrize('model', ['Qwen/Qwen2.5-7B-Instruct', 
'meta-llama/Meta-Llama-3-1-8B-Instruct']) @pytest.mark.parametrize('backend', ['turbomind', 'pytorch']) def test_long_test_passkey_tp1(config, model, backend, worker_id): log_name = ''.join(['pipeline_longtext_passkey_', worker_id, '.log']) if 'gw' in worker_id: set_device_env_variable(worker_id) run_case_in_spawn(passkey_retrival_worker, (config, model, backend, log_name, 1, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K))) if 'gw' in worker_id: unset_device_env_variable() @pytest.mark.gpu_num_2 @pytest.mark.parametrize('model', ['Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-32B']) @pytest.mark.parametrize('backend', ['turbomind', 'pytorch']) def test_long_test_passkey_tp2(config, model, backend, worker_id): log_name = ''.join(['pipeline_longtext_passkey_', worker_id, '.log']) if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=2) os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500) run_case_in_spawn(passkey_retrival_worker, (config, model, backend, log_name, 2, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K))) if 'gw' in worker_id: unset_device_env_variable() @pytest.mark.gpu_num_8 @pytest.mark.parametrize('model', ['Qwen/Qwen3-235B-A22B', 'meta-llama/Meta-Llama-3-1-70B-Instruct']) @pytest.mark.parametrize('backend', ['turbomind', 'pytorch']) def test_long_test_passkey_tp8(config, model, backend, worker_id): log_name = ''.join(['pipeline_longtext_passkey_', worker_id, '.log']) if 'gw' in worker_id: set_device_env_variable(worker_id, parallel_config=8) os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500) run_case_in_spawn(passkey_retrival_worker, (config, model, backend, log_name, 8, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K))) if 'gw' in worker_id: unset_device_env_variable() YARN_CONFIG = {'rope_scaling': {'rope_type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 32768}} NTK_CONFIG = { 'rope_scaling': { 'type': 'dynamic', 'factor': 2.0 }, } def passkey_retrival_worker(config, model, backend, 
log_name, tp_num, session_len: int = SESSION_LEN_128K): model_path = '/'.join([config.get('model_path'), model]) if backend == 'turbomind': if 'qwen' in model.lower(): backend_config = TurbomindEngineConfig(session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=tp_num, hf_overrides=YARN_CONFIG) elif 'intern-s1' in model.lower(): backend_config = TurbomindEngineConfig(session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=tp_num, hf_overrides={'text_config': NTK_CONFIG}) else: backend_config = TurbomindEngineConfig(session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=tp_num) else: if 'qwen' in model.lower(): backend_config = PytorchEngineConfig(session_len=session_len, tp=tp_num, max_batch_size=1, hf_overrides=YARN_CONFIG) elif 'intern-s1' in model.lower(): backend_config = TurbomindEngineConfig(session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=tp_num, hf_overrides={'text_config': NTK_CONFIG}) else: backend_config = PytorchEngineConfig(session_len=session_len, tp=tp_num, max_batch_size=1) pipe = pipeline(model_path, backend_config=backend_config) gen_config = GenerationConfig(top_k=40) # inference pass_key1, prompt = get_passkey_prompt(pipe, session_len) response1 = pipe(prompt, gen_config=gen_config) # inference pass_key2, prompt = get_passkey_prompt(pipe, session_len) response2 = pipe([prompt] * 2, gen_config=gen_config) pipe.close() assert str(pass_key1) in response1.text, str(response1) assert str(pass_key2) in response2[0].text and str(pass_key2) in response2[1].text, str(response2) def get_passkey_prompt(pipe, session_len): # create long context input tok = pipe.tokenizer task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.' # noqa: E501 garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.' 
# noqa: E501 n_times = (session_len - 1000) // len(tok.encode(garbage)) n_garbage_prefix = np.random.randint(0, n_times) n_garbage_suffix = n_times - n_garbage_prefix garbage_prefix = ' '.join([garbage] * n_garbage_prefix) garbage_suffix = ' '.join([garbage] * n_garbage_suffix) pass_key = np.random.randint(1, 50000) information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.' # noqa: E501 final_question = 'What is the pass key? The pass key is' lines = [ task_description, garbage_prefix, information_line, garbage_suffix, final_question, ] # inference prompt = ' '.join(lines) return pass_key, prompt ================================================ FILE: autotest/interface/restful/test_restful_chat_completions_v1.py ================================================ from typing import Literal import pytest from openai import OpenAI from utils.constant import BACKEND_LIST, RESTFUL_MODEL_LIST from utils.restful_return_check import (assert_chat_completions_batch_return, assert_chat_completions_stream_return, has_repeated_fragment) from lmdeploy.serve.openai.api_client import APIClient, get_model_list BASE_HTTP_URL = 'http://localhost' DEFAULT_PORT = 23333 MODEL = 'internlm/Intern-S1' BASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)]) @pytest.mark.order(8) @pytest.mark.chat @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize('backend', BACKEND_LIST) @pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST) class TestRestfulInterfaceBase: @pytest.mark.interns1 def test_get_model(self, config, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] assert model_name == '/'.join([config.get('model_path'), MODEL]), api_client.available_models model_list = get_model_list(BASE_URL + '/v1/models') assert model_name in model_list, model_list @pytest.mark.interns1 def test_encode_s1(self, backend, model_case): api_client = APIClient(BASE_URL) input_ids1, length1 = api_client.encode('Hi, pls intro 
yourself') input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False) input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True) input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False) input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False) assert len(input_ids1) == length1 and length1 > 0 assert len(input_ids2) == length2 and length2 > 0 assert len(input_ids3) == length3 and length3 > 0 assert len(input_ids4) == length4 and length4 > 0 assert len(input_ids5) == length5 and length5 > 0 assert length1 == length2 assert input_ids2 == input_ids1 assert input_ids1[0] == 13048 and input_ids3[0] == 151644 assert length5 == length2 * 100 assert input_ids5 == input_ids2 * 100 @pytest.mark.internlm2_5 def test_encode(self, backend, model_case): api_client = APIClient(BASE_URL) input_ids1, length1 = api_client.encode('Hi, pls intro yourself') input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False) input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True) input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False) input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False) assert len(input_ids1) == length1 and length1 > 0 assert len(input_ids2) == length2 and length2 > 0 assert len(input_ids3) == length3 and length3 > 0 assert len(input_ids4) == length4 and length4 > 0 assert len(input_ids5) == length5 and length5 > 0 assert length1 == length2 + 1 assert input_ids2 == input_ids1[1:] assert input_ids1[0] == 1 and input_ids3[0] == 1 assert length5 == length2 * 100 assert input_ids5 == input_ids2 * 100 @pytest.mark.order(8) @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize('backend', BACKEND_LIST) @pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST) class TestRestfulInterfaceChatCompletions: def 
test_return_info_with_prompt(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) def test_return_info_with_messegae(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) def test_return_info_with_prompt_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, temperature=0.01): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) def test_return_info_with_messegae_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], stream=True, temperature=0.01): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) def test_single_stopword(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 
'role': 'user', 'content': 'Shanghai is' }, ], stop=' is', temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) assert ' is' not in output.get('choices')[0].get('message').get('content') assert output.get('choices')[0].get('finish_reason') == 'stop' def test_single_stopword_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], stop=' is', stream=True, temperature=0.01): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) assert ' to' not in outputList[index].get('choices')[0].get('delta').get('content') assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop' def test_array_stopwords(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], stop=[' is', '上海', ' to'], temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) assert ' is' not in output.get('choices')[0].get('message').get('content') assert ' 上海' not in output.get('choices')[0].get('message').get('content') assert ' to ' not in output.get('choices')[0].get('message').get('content') assert output.get('choices')[0].get('finish_reason') == 'stop' def test_array_stopwords_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], stop=[' is', '上海', ' to'], stream=True, temperature=0.01): outputList.append(output) 
assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) assert ' is' not in outputList[index].get('choices')[0].get('delta').get('content') assert '上海' not in outputList[index].get('choices')[0].get('delta').get('content') assert ' to ' not in outputList[index].get('choices')[0].get('delta').get('content') assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop' @pytest.mark.internlm2_5 def test_special_words(self, backend, model_case): message = '<|im_start|>system\n当开启工具以及代码时,根据需求选择合适的工具进行调用\n' + \ '<|im_end|><|im_start|>system name=<|interpreter|>\n你现在已经' + \ '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \ '发送含有 Python >代码的消息时,它将在该环境中执行。这个工具适用于多种场景,' + \ '如数据分析或处理(包括数据操作、统计分析、图表绘制),复杂的计算问题(解决数学和物理' + \ '难题),编程示例(理解编程概念或特性),文本处理和分析(比如文本解析和自然语言处理),机器学习和数据科学(用于' + \ '展示模型训练和数据可视化),以及文件操作和数据导入(处理CSV、JSON等格式的文件)。<|im_end|>\n' + \ '<|im_start|>user\n设 $L$ 为圆周$x^2+y^2=2x$,计算曲线积分:$I=\\int_L' + \ '{x\\mathrm{d}s}=$<|im_end|>\n<|im_start|>assistant' api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=message, skip_special_tokens=False, temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) assert '<|action_start|><|interpreter|>' in output.get('choices')[0].get('message').get('content') for output in api_client.chat_completions_v1(model=model_name, messages=message, skip_special_tokens=True, temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) assert '<|action_start|><|interpreter|>' not in output.get('choices')[0].get('message').get('content') def test_minimum_repetition_penalty(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 
'user', 'content': 'Shanghai is' }, ], repetition_penalty=0.0000001, temperature=0.01, max_tokens=200, min_new_tokens=100): continue assert_chat_completions_batch_return(output, model_name) result, msg = has_repeated_fragment(output.get('choices')[0].get('message').get('content')) assert result, msg def test_minimum_repetition_penalty_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, repetition_penalty=0.0000001, temperature=0.01, max_tokens=200, min_new_tokens=100): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') result, msg = has_repeated_fragment(response) assert result, msg def test_repetition_penalty_bigger_than_1(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], repetition_penalty=1.2, temperature=0.01, max_tokens=200): continue assert_chat_completions_batch_return(output, model_name) def test_repetition_penalty_bigger_than_1_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, repetition_penalty=1.2, temperature=0.01, max_tokens=200): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): 
assert_chat_completions_stream_return(outputList[index], model_name) continue def test_minimum_topp(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for i in range(3): for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], top_p=0.0000000001, max_tokens=10): outputList.append(output) assert_chat_completions_batch_return(output, model_name) assert outputList[0].get('choices')[0].get('message').get('content') == outputList[1].get('choices')[0].get( 'message').get('content') assert outputList[1].get('choices')[0].get('message').get('content') == outputList[2].get('choices')[0].get( 'message').get('content') def test_minimum_topp_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] responseList = [] for i in range(3): outputList = [] response = '' for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, top_p=0.0000000001, max_tokens=10): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') responseList.append(response) assert responseList[0] == responseList[1] or responseList[1] == responseList[2] def test_mistake_modelname_return(self, backend, model_case): api_client = APIClient(BASE_URL) for output in api_client.chat_completions_v1(model='error', messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], temperature=0.01): continue assert output.get('code') == 404 assert output.get('message') == 'The model \'error\' does not exist.' 
assert output.get('object') == 'error' def test_mistake_modelname_return_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) outputList = [] for output in api_client.chat_completions_v1(model='error', messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, max_tokens=5, temperature=0.01): outputList.append(output) assert output.get('code') == 404 assert output.get('message') == 'The model \'error\' does not exist.' assert output.get('object') == 'error' assert len(outputList) == 1 def test_mutilple_times_response_should_not_same(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for i in range(3): for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is', }, ], max_tokens=100): outputList.append(output) assert_chat_completions_batch_return(output, model_name) assert outputList[0].get('choices')[0].get('message').get('content') != outputList[1].get('choices')[0].get( 'message').get('content') or outputList[1].get('choices')[0].get('message').get( 'content') != outputList[2].get('choices')[0].get('message').get('content') def test_mutilple_times_response_should_not_same_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] responseList = [] for i in range(3): outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is', }, ], stream=True, max_tokens=100): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') responseList.append(response) assert responseList[0] != responseList[1] or responseList[1] == responseList[2] def 
test_longtext_input(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' * 100000, }, ], temperature=0.01): continue assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('choices')[0].get('message').get('content') == '' def test_longtext_input_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' * 100000, }, ], stream=True, temperature=0.01): outputList.append(output) assert_chat_completions_stream_return(outputList[0], model_name, is_last=True) assert outputList[0].get('choices')[0].get('finish_reason') == 'length' assert outputList[0].get('choices')[0].get('delta').get('content') == '' assert len(outputList) == 1 def test_ignore_eos(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, what is your name?' }, ], ignore_eos=True, max_tokens=100, temperature=0.01): continue assert_chat_completions_batch_return(output, model_name) assert output.get('usage').get('completion_tokens') == 101 or output.get('usage').get( 'completion_tokens') == 100 assert output.get('choices')[0].get('finish_reason') == 'length' def test_ignore_eos_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, what is your name?' 
}, ], ignore_eos=True, stream=True, max_tokens=100, temperature=0.01): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length >= 99 and length <= 101 def __test_max_tokens_or_max_completion_tokens( self, max_tokens_or_max_completion_tokens: Literal['max_tokens', 'max_completion_tokens'], ): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] if max_tokens_or_max_completion_tokens == 'max_tokens': for output in api_client.chat_completions_v1( model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, ): continue else: for output in api_client.chat_completions_v1( model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_completion_tokens=5, temperature=0.01, ): continue assert_chat_completions_batch_return(output, model_name) assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5 def test_max_tokens(self, backend, model_case): self.__test_max_tokens_or_max_completion_tokens('max_tokens') def test_max_completion_tokens(self, backend, model_case): self.__test_max_tokens_or_max_completion_tokens('max_completion_tokens') def __test_max_tokens_streaming_or_max_completion_tokens_streaming( self, max_tokens_or_max_completion_tokens: Literal['max_tokens', 'max_completion_tokens'], ): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] if max_tokens_or_max_completion_tokens == 'max_tokens': for output in 
api_client.chat_completions_v1( model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, max_tokens=5, temperature=0.01, ): outputList.append(output) else: for output in api_client.chat_completions_v1( model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, max_completion_tokens=5, temperature=0.01, ): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length == 5 or length == 6 def test_max_tokens_streaming(self, backend, model_case): self.__test_max_tokens_streaming_or_max_completion_tokens_streaming('max_tokens') def test_max_completion_tokens_streaming(self, backend, model_case): self.__test_max_tokens_streaming_or_max_completion_tokens_streaming('max_completion_tokens') @pytest.mark.not_pytorch def test_logprobs(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] for output in api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, logprobs=True, top_logprobs=10): continue assert_chat_completions_batch_return(output, model_name, check_logprobs=True, logprobs_num=10) assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5 @pytest.mark.not_pytorch def test_logprobs_streaming(self, backend, model_case): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] outputList = [] for output in 
api_client.chat_completions_v1(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], stream=True, max_tokens=5, temperature=0.01, logprobs=True, top_logprobs=10): outputList.append(output) assert_chat_completions_stream_return(outputList[-1], model_name, True, check_logprobs=True, logprobs_num=10) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name, check_logprobs=True, logprobs_num=10) response += outputList[index].get('choices')[0].get('delta').get('content') length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length == 5 or length == 6 @pytest.mark.order(8) @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize('backend', BACKEND_LIST) @pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST) class TestRestfulOpenAI: @pytest.mark.pr_test def test_return_info(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], temperature=0.01) output = outputs.model_dump() assert_chat_completions_batch_return(output, model_name) @pytest.mark.pr_test def test_return_info_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], temperature=0.01, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) def test_single_stopword(self, backend, 
model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], temperature=0.01, stop=' is') output = outputs.model_dump() assert_chat_completions_batch_return(output, model_name) assert ' is' not in output.get('choices')[0].get('message').get('content') assert output.get('choices')[0].get('finish_reason') == 'stop' @pytest.mark.pr_test def test_single_stopword_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], stop=' is', temperature=0.01, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) assert ' is ' not in outputList[index].get('choices')[0].get('delta').get('content') assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop' def test_array_stopwords(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create( model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], temperature=0.01, stop=[' is', '上海', ' to'], ) output = outputs.model_dump() assert_chat_completions_batch_return(output, model_name) assert ' is' not in output.get('choices')[0].get('message').get('content') assert ' 上海' not in output.get('choices')[0].get('message').get('content') assert ' to' not in output.get('choices')[0].get('message').get('content') assert output.get('choices')[0].get('finish_reason') == 'stop' def 
test_array_stopwords_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], stop=[' is', '上海', ' to'], temperature=0.01, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) assert ' is' not in outputList[index].get('choices')[0].get('delta').get('content') assert '上海' not in outputList[index].get('choices')[0].get('delta').get('content') assert ' to ' not in outputList[index].get('choices')[0].get('delta').get('content') assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop' @pytest.mark.pr_test def test_minimum_topp(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputList = [] for i in range(3): outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], temperature=0.01, top_p=0.0000000001, max_tokens=10) output = outputs.model_dump() outputList.append(output) assert_chat_completions_batch_return(output, model_name) assert outputList[0].get('choices')[0].get('message').get('content') == outputList[1].get('choices')[0].get( 'message').get('content') assert outputList[1].get('choices')[0].get('message').get('content') == outputList[2].get('choices')[0].get( 'message').get('content') def test_minimum_topp_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id responseList = [] for i in range(3): outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 
'user', 'content': 'Hi, pls intro yourself' }, ], top_p=0.0000000001, max_tokens=10, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') responseList.append(response) assert responseList[0] == responseList[1] or responseList[1] == responseList[2] @pytest.mark.pr_test def test_mistake_modelname_return(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') with pytest.raises(Exception, match='The model \'error\' does not exist.'): client.chat.completions.create( model='error', messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], temperature=0.01, stop=[' is', '上海', ' to'], ) def test_mistake_modelname_return_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') with pytest.raises(Exception, match='The model \'error\' does not exist.'): client.chat.completions.create(model='error', messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, stream=True) @pytest.mark.pr_test def test_mutilple_times_response_should_not_same(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputList = [] for i in range(3): outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Shanghai is' }, ], max_tokens=100) output = outputs.model_dump() outputList.append(output) assert_chat_completions_batch_return(output, model_name) assert outputList[0].get('choices')[0].get('message').get('content') != outputList[1].get('choices')[0].get( 'message').get('content') or 
outputList[1].get('choices')[0].get('message').get( 'content') != outputList[2].get('choices')[0].get('message').get('content') def test_mutilple_times_response_should_not_same_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id responseList = [] for i in range(3): outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=100, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') responseList.append(response) assert responseList[0] != responseList[1] or responseList[1] == responseList[2] def test_longtext_input(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' * 100000 }, ], max_tokens=100) output = outputs.model_dump() print(output) assert output.get('choices')[0].get('finish_reason') == 'error' assert output.get('choices')[0].get('message').get( 'content') == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR' @pytest.mark.pr_test def test_longtext_input_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' * 100000 }, ], max_tokens=100, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) 
assert_chat_completions_stream_return(outputList[0], model_name, is_last=True) assert outputList[0].get('choices')[0].get('finish_reason') == 'error' assert outputList[0].get('choices')[0].get('delta').get( 'content') == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR' assert len(outputList) == 1 @pytest.mark.pr_test def test_max_tokens(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01) output = outputs.model_dump() assert_chat_completions_batch_return(output, model_name) assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5 def test_max_tokens_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name) response += outputList[index].get('choices')[0].get('delta').get('content') api_client = APIClient(BASE_URL) length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length == 5 or length == 6 @pytest.mark.not_pytorch @pytest.mark.pr_test def test_logprobs(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = 
client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, logprobs=True, top_logprobs=10) output = outputs.model_dump() assert_chat_completions_batch_return(output, model_name, check_logprobs=True, logprobs_num=10) assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5 @pytest.mark.not_pytorch @pytest.mark.pr_test def test_logprobs_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id outputs = client.chat.completions.create(model=model_name, messages=[ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], max_tokens=5, temperature=0.01, logprobs=True, top_logprobs=10, stream=True) outputList = [] for output in outputs: outputList.append(output.model_dump()) assert_chat_completions_stream_return(outputList[-1], model_name, True, check_logprobs=True, logprobs_num=10) response = '' for index in range(0, len(outputList) - 1): assert_chat_completions_stream_return(outputList[index], model_name, check_logprobs=True, logprobs_num=10) response += outputList[index].get('choices')[0].get('delta').get('content') api_client = APIClient(BASE_URL) length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length == 5 or length == 6 def test_input_validation(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id messages = [ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, top_p=0) with pytest.raises(Exception): client.chat.completions.create(model=model_name, 
messages=messages, top_p=1.01) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, top_p='test') with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, n=0) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, n='test') with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature=-0.01) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature=2.01) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature='test') def test_input_validation_streaming(self, backend, model_case): client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1') model_name = client.models.list().data[0].id messages = [ { 'role': 'user', 'content': 'Hi, pls intro yourself' }, ], with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, top_p=0, stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, top_p=1.01, stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, top_p='test', stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, n=0, stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, n='test', stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature=-0.01, stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature=2.01, stream=True) with pytest.raises(Exception): client.chat.completions.create(model=model_name, messages=messages, temperature='test', stream=True) @pytest.mark.interns1 def 
test_disable_think(self, backend, model_case):
        # Sends the same prompt twice: once with default settings, then with
        # extra_body={'enable_thinking': False}.
        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')
        model_name = client.models.list().data[0].id
        output = client.chat.completions.create(model=model_name,
                                                messages=[
                                                    {
                                                        'role': 'user',
                                                        'content': 'Hi, pls intro yourself'
                                                    },
                                                ],
                                                temperature=0.8,
                                                top_p=0.8)
        print(output)
        # NOTE(review): '' is a substring of every string, so this assertion is
        # always true. The marker it was meant to look for (presumably a thinking
        # tag) seems to be missing from the literal — confirm against upstream.
        assert '' in str(output.model_dump())
        output = client.chat.completions.create(model=model_name,
                                                messages=[
                                                    {
                                                        'role': 'user',
                                                        'content': 'Hi, pls intro yourself'
                                                    },
                                                ],
                                                temperature=0.8,
                                                top_p=0.8,
                                                extra_body={
                                                    'enable_thinking': False,
                                                })
        response = output.model_dump()
        # NOTE(review): model_dump() results are read with .get() throughout this
        # file, i.e. `response` is a dict, so this only tests for a key named ''.
        # Comparing against str(response) would not help either while the literal
        # is empty; the intended marker string needs to be restored — confirm.
        assert '' not in response
        assert_chat_completions_batch_return(response, model_name)

    @pytest.mark.interns1
    def test_disable_think_with_image(self, backend, model_case):
        # Same check as test_disable_think, but with a text + image_url message.
        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')
        model_name = client.models.list().data[0].id
        output = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    'role': 'user',
                    'content': [{
                        'type': 'text',
                        'text': 'Describe the image please',
                    }, {
                        'type': 'image_url',
                        'image_url': {
                            'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
                        },
                    }],
                },
            ],
            temperature=0.8,
            top_p=0.8)
        print(output)
        # NOTE(review): always true ('' is in every string); intended marker
        # literal appears to be missing — confirm against upstream.
        assert '' in str(output.model_dump())
        output = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    'role': 'user',
                    'content': [{
                        'type': 'text',
                        'text': 'Describe the image please',
                    }, {
                        'type': 'image_url',
                        'image_url': {
                            'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
                        },
                    }],
                },
            ],
            temperature=0.8,
            top_p=0.8,
            extra_body={
                'enable_thinking': False,
            })
        response = output.model_dump()
        # NOTE(review): `response` is a dict, so this only tests for a key named '';
        # it can never fail — restore the intended marker literal.
        assert '' not in response
        assert_chat_completions_batch_return(response, model_name)

================================================
FILE: autotest/interface/restful/test_restful_completions_v1.py
================================================
import pytest
from utils.constant import BACKEND_LIST, RESTFUL_BASE_MODEL_LIST
from utils.restful_return_check import assert_completions_batch_return, assert_completions_stream_return

from lmdeploy.serve.openai.api_client import APIClient

# The api_server under test is expected to listen on localhost:23333.
BASE_HTTP_URL = 'http://localhost'
DEFAULT_PORT = 23333
# Model path suffix that test_get_model expects the server to report.
MODEL = 'internlm/internlm2_5-20b'
BASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])


@pytest.mark.parametrize('backend', BACKEND_LIST)
@pytest.mark.parametrize('model_case', RESTFUL_BASE_MODEL_LIST)
class TestRestfulInterfaceBase:
    # Tests for the plain /v1/completions endpoint, driven through lmdeploy's APIClient.

    @pytest.mark.internlm2_5
    def test_get_model(self, config, backend, model_case):
        # The served model name must be the full local path '<model_path>/<MODEL>'.
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        assert model_name == '/'.join([config.get('model_path'), MODEL]), api_client.available_models

    @pytest.mark.internlm2_5
    def test_encode(self, backend, model_case):
        # Checks the encode endpoint with/without BOS and preprocessing, and that
        # encoding a repeated string repeats the ids.
        # NOTE(review): assumes the default call adds a BOS token whose id is 1
        # for this model — confirm for other tokenizers.
        api_client = APIClient(BASE_URL)
        input_ids1, length1 = api_client.encode('Hi, pls intro yourself')
        input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False)
        input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True)
        input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False)
        input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False)
        assert len(input_ids1) == length1 and length1 > 0
        assert len(input_ids2) == length2 and length2 > 0
        assert len(input_ids3) == length3 and length3 > 0
        assert len(input_ids4) == length4 and length4 > 0
        assert len(input_ids5) == length5 and length5 > 0
        assert length1 == length2 + 1
        assert input_ids2 == input_ids1[1:]
        assert input_ids1[0] == 1 and input_ids3[0] == 1
        assert length5 == length2 * 100
        assert input_ids5 == input_ids2 * 100

    def test_return(self, backend, model_case):
        # max_tokens=16 must produce 16 or 17 completion tokens and finish with 'length'.
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        for item in api_client.completions_v1(
                model=model_name,
                prompt='Hi, pls intro yourself',
                max_tokens=16,
                temperature=0.01,
        ):
            completion_tokens = item['usage']['completion_tokens']
            assert completion_tokens > 0
            assert completion_tokens <= 17
            assert completion_tokens >= 16
            assert item.get('choices')[0].get('finish_reason') in ['length']
            assert_completions_batch_return(item, model_name)

    def test_return_streaming(self, backend, model_case):
        # The last chunk is validated as the final one, every earlier chunk as a regular one.
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        outputList = []
        for item in api_client.completions_v1(model=model_name,
                                              prompt='Hi, pls intro yourself',
                                              max_tokens=16,
                                              stream=True,
                                              temperature=0.01):
            outputList.append(item)
        assert_completions_stream_return(outputList[-1], model_name, True)
        for index in range(0, len(outputList) - 1):
            assert_completions_stream_return(outputList[index], model_name)

    def test_max_tokens(self, backend, model_case):
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        for item in api_client.completions_v1(model=model_name,
                                              prompt='Hi, pls intro yourself',
                                              max_tokens=16,
                                              temperature=0.01):
            completion_tokens = item['usage']['completion_tokens']
            assert completion_tokens > 0
            assert completion_tokens <= 17
            assert completion_tokens >= 16
            assert item.get('choices')[0].get('finish_reason') in ['length']

    def test_single_stopword(self, backend, model_case):
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        for item in api_client.completions_v1(model=model_name,
                                              prompt='Shanghai is',
                                              max_tokens=200,
                                              stop=' Shanghai',
                                              temperature=0.01):
            assert ' Shanghai' not in item.get('choices')[0].get('text')
            assert item.get('choices')[0].get('finish_reason') in ['stop', 'length']

    def test_array_stopwords(self, backend, model_case):
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        for item in api_client.completions_v1(model=model_name,
                                              prompt='Shanghai is',
                                              max_tokens=200,
                                              stop=[' Shanghai', ' city', ' China'],
                                              temperature=0.01):
            assert ' Shanghai' not in item.get('choices')[0].get('text')
            assert ' city' not in item.get('choices')[0].get('text')
            assert ' China' not in item.get('choices')[0].get('text')
            assert item.get('choices')[0].get('finish_reason') in ['stop', 'length']

    def test_completions_stream(self, backend, model_case):
        # NOTE(review): stream='true' passes a string where test_return_streaming
        # passes the bool True; presumably it is only used for truthiness — confirm
        # and prefer True.
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        outputList = []
        for output in api_client.completions_v1(model=model_name,
                                                prompt='Shanghai is',
                                                stream='true',
                                                temperature=0.01):
            outputList.append(output)
        # Skips the first chunk and the last one; every chunk in between must carry text.
        for index in range(1, len(outputList) - 1):
            output = outputList[index]
            assert (output.get('model') == model_name)
            for message in output.get('choices'):
                assert message.get('index') == 0
                assert len(message.get('text')) > 0

        output_last = outputList[len(outputList) - 1]
        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']

    def test_completions_stream_stopword(self, backend, model_case):
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        outputList = []
        for output in api_client.completions_v1(model=model_name,
                                                prompt='Beijing is',
                                                stream='true',
                                                stop=' is',
                                                temperature=0.01):
            outputList.append(output)
        # Skips the first chunk and the last two chunks.
        for index in range(1, len(outputList) - 2):
            output = outputList[index]
            assert (output.get('model') == model_name)
            assert (output.get('object') == 'text_completion')
            for message in output.get('choices'):
                assert ' is' not in message.get('text')
                assert message.get('index') == 0
                assert len(message.get('text')) > 0

        # The final chunk must carry no text, only the finish reason.
        output_last = outputList[len(outputList) - 1]
        assert output_last.get('choices')[0].get('text') == ''
        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']

    def test_completions_stream_stopwords(self, backend, model_case):
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        outputList = []
        for output in api_client.completions_v1(model=model_name,
                                                prompt='Beijing is',
                                                stream='true',
                                                stop=[' Beijing', ' city', ' China'],
                                                temperature=0.01):
            outputList.append(output)
        # Skips the first chunk and the last two chunks.
        for index in range(1, len(outputList) - 2):
            output = outputList[index]
            assert (output.get('model') == model_name)
            assert (output.get('object') == 'text_completion')
            for message in output.get('choices'):
                assert ' Beijing' not in message.get('text')
                assert ' city' not in message.get('text')
                assert ' China' not in message.get('text')
                assert message.get('index') == 0
                assert len(message.get('text')) > 0

        output_last = outputList[len(outputList) - 1]
        assert output_last.get('choices')[0].get('text') == ''
        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']

    def test_batch_prompt_order(self, backend, model_case):
        # Batch of five Chinese prompts: "Hello", "How is the weather today",
        # "Who are you", "Write me a five-character regulated verse about plum
        # blossoms", "What is 5+2". Each answer is checked for a keyword from its
        # own prompt ('天' day/weather, '梅' plum or '对仗' parallelism, '7'),
        # so the choices must come back in prompt order.
        api_client = APIClient(BASE_URL)
        model_name = api_client.available_models[0]
        for item in api_client.completions_v1(model=model_name,
                                              prompt=['你好', '今天天气怎么样', '你是谁', '帮我写一首以梅花为主题的五言律诗', '5+2等于多少'],
                                              max_tokens=400,
                                              min_tokens=50):
            print(str(item))
            assert '天' in item.get('choices')[1].get('text'), item.get('choices')[1].get('text')
            assert '梅' in item.get('choices')[3].get('text') or '对仗' in item.get('choices')[3].get('text'), item.get(
                'choices')[3].get('text')
            assert '7' in item.get('choices')[4].get('text'), item.get('choices')[4].get('text')

================================================
FILE: autotest/interface/restful/test_restful_generate.py
================================================
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Any

import pytest
import requests
from transformers import AutoTokenizer
from utils.constant import BACKEND_LIST, DEFAULT_SERVER, RESTFUL_MODEL_LIST
from utils.toolkit import encode_text, parse_sse_stream

BASE_HTTP_URL = f'http://{DEFAULT_SERVER}'
DEFAULT_PORT = 23333
BASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])


@pytest.mark.parametrize('backend', BACKEND_LIST)
@pytest.mark.parametrize('model_name', RESTFUL_MODEL_LIST)
class TestGenerateComprehensive:
    # Tests for the /generate endpoint, posted to directly with `requests`.

    @pytest.fixture(autouse=True)
    def setup_api(self, request, config, model_name, backend):
        # Per-test setup: endpoint URL, JSON headers and a per-test log file path.
        self.api_url = f'{BASE_URL}/generate'
        self.headers = {'Content-Type': 'application/json'}
self.model_name = model_name test_name = request.node.name safe_test_name = re.sub(r'[^\w\.-]', '_', test_name) safe_model_name = self.model_name.replace('/', '_') log_base = config.get('log_path', './logs') self.log_dir = os.path.join(log_base, safe_model_name) os.makedirs(self.log_dir, exist_ok=True) self.log_file = os.path.join(self.log_dir, f'{backend}_{safe_test_name}.log') def _log_request_response(self, payload, response_data, stream_raw=None): log_entry = { 'timestamp': datetime.now().isoformat(), 'model': self.model_name, 'request': payload, 'response': response_data, } if stream_raw is not None: log_entry['stream_raw'] = stream_raw try: with open(self.log_file, 'a', encoding='utf-8') as f: json.dump(log_entry, f, indent=2, ensure_ascii=False) f.write('\n') except Exception as e: print(f'[LOG WARN] Failed to write {self.log_file}: {e}') def _post(self, payload, stream=False): if 'model' not in payload: payload['model'] = self.model_name resp = requests.post(self.api_url, json=payload, headers=self.headers, stream=stream, timeout=60) resp.raise_for_status() if stream: raw_content = '' for chunk in resp.iter_content(chunk_size=None): if chunk: raw_content += chunk.decode('utf-8') events = parse_sse_stream(raw_content) accumulated_text = '' output_ids = [] stream_events_count = 0 for event in events: if event == '[DONE]': break try: data_str = event.replace('data: ', '').strip() if not data_str: continue data = json.loads(data_str) delta = data.get('text', '') if isinstance(delta, str): accumulated_text += delta ids = data.get('output_ids') if isinstance(ids, list): output_ids.extend(ids) stream_events_count += 1 except Exception as e: print(f'Error parsing stream event: {e}') continue fake_resp = { 'text': accumulated_text, 'output_ids': output_ids, 'meta_info': { 'stream_events': stream_events_count } } self._log_request_response(payload, fake_resp, raw_content) class MockResp: def json(self): return fake_resp @property def status_code(self): return 200 
return MockResp() else: data = resp.json() self._log_request_response(payload, data) return resp def _validate_generation_response(self, data: dict[str, Any], expected_fields: list[str] | None = None, validate_tokens: bool = True, expect_logprobs: bool = False, validate_experts: bool = False) -> None: assert isinstance(data, dict), f'Response should be a dict, got {type(data)}' required_fields = ['text'] for field in required_fields: assert field in data, f'Missing required field: {field}' assert data[field] is not None, f'Field {field} should not be None' assert isinstance(data['text'], str), \ f"text should be string, got {type(data['text'])}" if validate_experts: assert 'routed_experts' in data[ 'meta_info'], "Response should contain 'routed_experts' when validate_experts=True" experts_data = data['meta_info']['routed_experts'] assert isinstance(experts_data, list) assert len(experts_data) > 0 total_steps = len(experts_data) for step_idx in range(total_steps): token_experts = experts_data[step_idx] assert isinstance(token_experts, list) assert len(token_experts) > 0 for layer_idx in range(len(token_experts)): layer_experts = token_experts[layer_idx] assert isinstance(layer_experts, list) assert len(layer_experts) == 8 for expert_idx, expert_id in enumerate(layer_experts): assert isinstance(expert_id, int) assert 0 <= expert_id < 256, f'Invalid expert_id: {expert_id}. 
Must be in [0, 256)' if validate_tokens: assert 'output_ids' in data, "Response should contain 'output_ids'" output_ids = data['output_ids'] assert isinstance(output_ids, list), \ f'output_ids should be list, got {type(output_ids)}' assert len(output_ids) >= 0, 'output_ids should not be empty' for i, token_id in enumerate(output_ids): assert isinstance(token_id, int), \ f'output_ids[{i}] should be int, got {type(token_id)}' if 'meta_info' in data: meta = data['meta_info'] assert isinstance(meta, dict), 'meta_info should be dict' if 'completion_tokens' in meta: assert meta['completion_tokens'] == len(output_ids), \ f"meta.completion_tokens ({meta['completion_tokens']}) " \ f'should equal len(output_ids) ({len(output_ids)})' if expect_logprobs: assert 'meta_info' in data, \ "Response should contain 'meta_info' when expecting logprobs" meta = data['meta_info'] assert isinstance(meta, dict) assert 'output_token_logprobs' in meta, \ "meta_info missing 'output_token_logprobs'" logprobs_data = meta['output_token_logprobs'] assert isinstance(logprobs_data, list), \ 'output_token_logprobs should be a list' assert len(logprobs_data) > 0, \ 'output_token_logprobs should not be empty' if 'output_ids' in data: assert len(logprobs_data) == len(data['output_ids']), \ f'Logprobs outer list length ({len(logprobs_data)}) != ' \ f"Output IDs length ({len(data['output_ids'])})" for idx, item in enumerate(logprobs_data): assert isinstance(item, list), \ f'Logprobs item at index {idx} should be a list, got {type(item)}' assert len(item) == 2, \ f'Logprobs item at index {idx} should have 2 elements ' \ f'[logprob, token_id], got {len(item)}' logprob_val = item[0] assert isinstance(logprob_val, (float, int)), \ f'Logprob value at [{idx}][0] should be number, ' \ f'got {type(logprob_val)}' assert logprob_val <= 0, \ f'Logprob value should be <= 0, got {logprob_val}' token_id_in_logprob = item[1] assert isinstance(token_id_in_logprob, int), \ f'Token ID in logprobs at [{idx}][1] should be 
int, ' \ f'got {type(token_id_in_logprob)}' if 'output_ids' in data and idx < len(data['output_ids']): assert token_id_in_logprob == data['output_ids'][idx], \ f'Token ID mismatch at index {idx}: output_ids has ' \ f"{data['output_ids'][idx]}, but logprobs has " \ f'{token_id_in_logprob}' if expected_fields: for field in expected_fields: assert field in data, f'Missing expected field: {field}' if 'error' in data: assert not data['error'], f"Response contains error: {data['error']}" if 'code' in data and data['code'] != 0: assert False, f"Response contains error code: {data['code']}" def test_basic_generation(self): print(f'\n[Model: {self.model_name}] Running basic generation test') test_cases = [{ 'name': 'simple prompt', 'payload': { 'prompt': 'The sky is', 'max_tokens': 5 }, }, { 'name': 'prompt with spaces', 'payload': { 'prompt': ' Hello world ', 'max_tokens': 3 }, }, { 'name': 'unicode prompt', 'payload': { 'prompt': 'Hello, world', 'max_tokens': 3 }, }, { 'name': 'longer generation', 'payload': { 'prompt': 'Once upon a time', 'max_tokens': 10 }, }] for test_case in test_cases: test_name = test_case['name'] print(f'\n[Test: {test_name}]') resp = self._post(test_case['payload']) data = resp.json() self._validate_generation_response(data=data, validate_tokens=True) prompt = test_case['payload']['prompt'] generated_text = data['text'] assert generated_text != prompt.strip(), \ f"Generated text should be different from prompt: '{generated_text}'" if 'output_ids' in data: output_ids = data['output_ids'] max_tokens = test_case['payload']['max_tokens'] max_allowed = max_tokens + 1 assert len(output_ids) <= max_allowed, \ f'Too many tokens generated: {len(output_ids)} > {max_allowed}' meta = data.get('meta_info', {}) finish_type = meta.get('finish_reason', {}).get('type') if len(output_ids) >= max_tokens and finish_type != 'length': print(f'[WARN] Generated {len(output_ids)} tokens but ' f"finish_reason is not 'length': {finish_type}") print(f" Generated text: 
'{generated_text[:50]}...'") print(f" Generated tokens: {len(data.get('output_ids', []))}") def test_input_ids_mode(self, config): print(f'\n[Model: {self.model_name}] Running input_ids mode test') model_path = os.path.join(config.get('model_path'), self.model_name) test_cases = [{ 'name': 'simple text', 'text': 'Hello world', 'max_tokens': 5, 'expected_min_text': 3 }, { 'name': 'question', 'text': 'What is the meaning of life?', 'max_tokens': 8, 'expected_min_text': 5 }, { 'name': 'short input', 'text': 'Yes', 'max_tokens': 3, 'expected_min_text': 1 }] for test_case in test_cases: test_name = test_case['name'] print(f'\n[Test: input_ids - {test_name}]') try: input_ids = encode_text(model_path, test_case['text']) except Exception as e: pytest.skip(f'Tokenizer failed for {test_name}: {e}') assert isinstance(input_ids, list), \ f'input_ids should be list, got {type(input_ids)}' assert len(input_ids) > 0, 'input_ids should not be empty' for i, token_id in enumerate(input_ids): assert isinstance(token_id, int), \ f'input_ids[{i}] should be int, got {type(token_id)}' assert token_id >= 0, \ f'input_ids[{i}] should be >= 0, got {token_id}' resp = self._post({'input_ids': input_ids, 'max_tokens': test_case['max_tokens']}) data = resp.json() self._validate_generation_response(data=data, validate_tokens=True) generated_text = data['text'] try: generated_text.encode('utf-8') except UnicodeEncodeError: pytest.fail(f'Generated text contains invalid UTF-8 characters: ' f'{generated_text[:100]}') print(f' Input tokens: {len(input_ids)}') print(f" Output tokens: {len(data.get('output_ids', []))}") print(f" Generated text: '{generated_text[:50]}...'") def test_conflict_prompt_and_input_ids(self): print(f'\n[Model: {self.model_name}] Running conflict test') test_cases = [{ 'name': 'both provided', 'payload': { 'prompt': 'Hello world', 'input_ids': [1, 2, 3, 4, 5], 'max_tokens': 5 }, 'expected_status': 400, 'expected_error_keywords': [ 'conflict', 'both', 'either', 'cannot', 
'mutually exclusive', 'specify exactly one', 'prompt', 'input_ids' ] }, { 'name': 'prompt with empty input_ids', 'payload': { 'prompt': 'Test', 'input_ids': [], 'max_tokens': 3 }, 'expected_status': 400, 'expected_error_keywords': ['conflict', 'invalid', 'empty', 'specify exactly one', 'prompt', 'input_ids'] }, { 'name': 'empty prompt with input_ids', 'payload': { 'prompt': '', 'input_ids': [100, 200, 300], 'max_tokens': 3 }, 'expected_status': 400, 'expected_error_keywords': ['conflict', 'empty', 'invalid', 'specify exactly one', 'prompt', 'input_ids'] }] for test_case in test_cases: test_name = test_case['name'] print(f'\n[Test: conflict - {test_name}]') try: resp = requests.post(self.api_url, json=test_case['payload'], headers=self.headers, timeout=30) assert resp.status_code == test_case['expected_status'], \ f"Expected status {test_case['expected_status']}, " \ f'got {resp.status_code}' error_data = resp.json() assert 'error' in error_data or 'message' in error_data, \ "Error response should contain 'error' or 'message' field" error_msg = '' if 'error' in error_data: error_msg = str(error_data['error']).lower() elif 'message' in error_data: error_msg = str(error_data['message']).lower() keywords_found = any(keyword in error_msg for keyword in test_case['expected_error_keywords']) if not keywords_found: has_both_fields = ('prompt' in error_msg and 'input_ids' in error_msg) has_exclusivity = any(phrase in error_msg for phrase in [ 'only one', 'specify exactly', 'cannot both', 'mutually exclusive', 'exactly one', 'must specify' ]) if has_both_fields and has_exclusivity: keywords_found = True assert keywords_found, \ f'Error message should indicate conflict between prompt and ' \ f'input_ids, got: {error_msg}' assert 'text' not in error_data, \ "Error response should not contain 'text' field" assert 'output_ids' not in error_data, \ "Error response should not contain 'output_ids' field" print(f' Got expected error: {error_msg[:100]}...') except Exception as e: 
print(f' Unexpected error: {e}') raise @pytest.mark.logprob def test_input_ids_with_logprob(self, config): print(f'\n[Model: {self.model_name}] Running input_ids with logprob test') model_path = os.path.join(config.get('model_path'), self.model_name) test_cases = [{ 'name': 'basic logprob', 'text': 'The weather is', 'max_tokens': 3, 'expected_min_text': 3 }, { 'name': 'single token generation', 'text': 'Hello', 'max_tokens': 1, 'expected_min_text': 1 }, { 'name': 'multiple tokens with logprob', 'text': 'Artificial intelligence is', 'max_tokens': 5, 'expected_min_text': 5 }] for test_case in test_cases: test_name = test_case['name'] print(f'\n[Test: logprob - {test_name}]') try: input_ids = encode_text(model_path, test_case['text']) except Exception as e: pytest.skip(f'Tokenizer failed for {test_name}: {e}') request_payload = {'input_ids': input_ids, 'max_tokens': test_case['max_tokens'], 'return_logprob': True} resp = self._post(request_payload) data = resp.json() self._validate_generation_response(data=data, validate_tokens=True, expect_logprobs=True) assert 'meta_info' in data, \ "Response should contain 'meta_info' when return_logprob=True" meta = data['meta_info'] assert 'output_token_logprobs' in meta, \ "meta_info should contain 'output_token_logprobs'" logprobs = meta['output_token_logprobs'] logprob_values = [] for i, item in enumerate(logprobs): logprob_values.append(item[0]) avg_logprob = sum(logprob_values) / len(logprob_values) if avg_logprob < -15.0: pytest.fail(f'Generation confidence critically low ' f'(Avg: {avg_logprob:.2f})') generated_text = data.get('text', '') print(f' Generated tokens: {len(logprob_values)}') print(f' Avg Logprob: {avg_logprob:.3f}') print(f" Generated text: '{generated_text[:50]}...'") def test_stop_str_with_include_flag(self): print(f'\n[Model: {self.model_name}] Running stop_str with include flag test') test_cases = [{ 'name': 'simple stop word', 'prompt': 'Count to 10: 1, 2, 3, ', 'stop_word': '6', 'max_tokens': 20, }] for 
test_case in test_cases: test_name = test_case['name'] print(f'\n[Test: stop_str - {test_name}]') prompt = test_case['prompt'] stop_word = test_case['stop_word'] max_tokens = test_case['max_tokens'] print(' Testing EXCLUDE mode (include_stop=False)...') resp1 = self._post({ 'prompt': prompt, 'max_tokens': max_tokens, 'stop': [stop_word], 'include_stop_str_in_output': False, 'return_logprob': True }) self._validate_generation_response(resp1.json()) text_exclude = resp1.json()['text'] assert stop_word not in text_exclude, \ f"Stop word '{stop_word}' should NOT be in output when include_stop=False" print(' Testing INCLUDE mode (include_stop=True)...') resp2 = self._post({ 'prompt': prompt, 'max_tokens': max_tokens, 'stop': [stop_word], 'include_stop_str_in_output': True, 'return_logprob': True }) self._validate_generation_response(resp2.json()) text_include = resp2.json()['text'] assert stop_word in text_include, \ f"Stop word '{stop_word}' should be in output when include_stop=True" def test_streaming_mode(self): print(f'\n[Model: {self.model_name}] Running streaming mode test') prompt = 'Count to 10: 1, 2,' resp = self._post({'prompt': prompt, 'max_tokens': 8, 'stream': True}, stream=True) assert resp.status_code == 200 data = resp.json() text = data['text'] output_ids = data['output_ids'] meta = data['meta_info'] assert isinstance(text, str) and len(text.strip()) > 0, \ 'Generated text cannot be empty' assert len(output_ids) >= 3, 'Output token count should be reasonable' import re count_matches = len(re.findall(r'\b[3-9]\b', text)) assert count_matches >= 2, \ f'Expected continuation of counting, but not enough numbers found ' \ f'(found {count_matches})' stream_events = meta.get('stream_events', []) assert stream_events <= len(output_ids) + 2, \ 'Streaming event count should be less than output token count' print(f" Generated text: '{text}'") print(f' Output tokens: {len(output_ids)}, ' f'Stream events: {stream_events}') def 
test_streaming_incremental_correctness(self): print(f'\n[Model: {self.model_name}] Running streaming incremental correctness test') prompt = 'The sky is ' raw_resp = requests.post(self.api_url, json={ 'prompt': prompt, 'max_tokens': 10, 'stream': True }, headers=self.headers, stream=True, timeout=30) raw_resp.raise_for_status() full_text_from_delta = '' tokens_from_delta = [] event_count = 0 print(' Streaming chunks:') for line in raw_resp.iter_lines(): if line: line_str = line.decode('utf-8').strip() if line_str.startswith('data: ') and '[DONE]' not in line_str: try: json_str = line_str[6:] payload = json.loads(json_str) delta_text = payload.get('text', '') token_id = payload.get('token_id') full_text_from_delta += delta_text if token_id is not None: tokens_from_delta.append(token_id) event_count += 1 if delta_text.strip(): print(f"+'{delta_text}'") except Exception as e: print(f'[Parse warning]: {e}') continue assert len(full_text_from_delta.strip()) > 0, \ 'Assembled text from streaming deltas is empty' assert event_count >= 3, \ f'Too few streaming events received ({event_count}), ' \ f'connection might be interrupted' print(f" Final assembled text: '{full_text_from_delta}'") print(f' Total events received: {event_count}') @pytest.mark.logprob def test_return_logprob(self): print(f'\n[Model: {self.model_name}] Running return_logprob test') resp = self._post({'prompt': 'Paris is the capital of', 'max_tokens': 2, 'return_logprob': True}) data = resp.json() self._validate_generation_response(data, validate_tokens=True, expect_logprobs=True) print(f" Generated text: '{data['text']}'") def test_same_session_id_allowed(self): print(f'\n[Model: {self.model_name}] Running same session_id test') sid = int(time.time_ns()) % 100000 resp1 = self._post({'prompt': 'First message:', 'session_id': sid, 'max_tokens': 2}) resp2 = self._post({'prompt': 'Second message:', 'session_id': sid, 'max_tokens': 2}) assert resp1.status_code == 200 assert resp2.status_code == 200 data1 = 
resp1.json() data2 = resp2.json() self._validate_generation_response(data1) self._validate_generation_response(data2) text1 = data1['text'].strip() text2 = data2['text'].strip() assert text1 != text2 print(f" First response: '{data1['text']}'") print(f" Second response: '{data2['text']}'") def test_empty_prompt_rejected(self): print(f'\n[Model: {self.model_name}] Running empty prompt test') with pytest.raises(requests.HTTPError) as exc: self._post({'prompt': '', 'max_tokens': 5}) assert exc.value.response.status_code == 400 try: error_response = exc.value.response.json() print(f' Error response: {error_response}') assert 'error' in error_response or 'message' in error_response except json.JSONDecodeError: print(f' Non-JSON error: {exc.value.response.text[:100]}') def test_input_ids_rejected(self): print(f'\n[Model: {self.model_name}] Running input_ids invalid cases test') invalid_cases = [{ 'case': { 'input_ids': [], 'max_tokens': 5 }, 'desc': 'Empty input_ids list' }, { 'case': { 'input_ids': 'not_a_list', 'max_tokens': 5 }, 'desc': 'input_ids is a string, not list' }, { 'case': { 'max_tokens': 5 }, 'desc': 'Missing input_ids field' }] for invalid_case in invalid_cases: test_desc = invalid_case['desc'] payload = invalid_case['case'] with pytest.raises(requests.HTTPError) as exc_info: self._post(payload) response = exc_info.value.response assert response.status_code in [400, 422], (f"Bad Request for case '{test_desc}', " f'but got {response.status_code}') def test_stress_concurrent_requests(self): print(f'\n[Model: {self.model_name}] Running stress concurrent requests test') def single_request(idx): start_time = time.time() try: resp = requests.post(self.api_url, json={ 'prompt': f'Hello, task {idx}', 'max_tokens': 5, 'stream': False }, headers=self.headers, timeout=10) resp.raise_for_status() data = resp.json() if 'text' in data and len(data['text'].strip()) > 0: latency = time.time() - start_time return {'success': True, 'latency': latency} else: return 
{'success': False, 'error': 'Empty response'} except Exception as e: return {'success': False, 'error': str(e)} success_count = 0 total_latency = 0 failures = [] with ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(single_request, i) for i in range(20)] for i, future in enumerate(as_completed(futures)): result = future.result() if result['success']: success_count += 1 total_latency += result['latency'] print(f" Req {i}: ✓ (Latency: {result['latency']:.2f}s)") else: failures.append(result['error']) print(f' Req {i}: ✗') success_rate = success_count / 20 assert success_rate == 1.0, \ f'Stress test failed: success rate {success_rate*100}% < 80%' if success_count > 0: avg_latency = total_latency / success_count assert avg_latency < 5.0, \ f'Average latency too high: {avg_latency:.2f}s' print(f' Performance: Avg Latency={avg_latency:.2f}s') print(f' Summary: {success_count}/20 succeeded') def test_stress_long_prompt_and_generation(self): print(f'\n[Model: {self.model_name}] Running stress long prompt test') long_prompt = 'Summarize: The quick brown fox jumps over the lazy dog. 
' * 100 resp = self._post({'prompt': long_prompt, 'max_tokens': 512, 'temperature': 0.7}) data = resp.json() self._validate_generation_response(data=data, validate_tokens=True) def test_stress_streaming_under_load(self): print(f'\n[Model: {self.model_name}] Running stress streaming under load test') def stream_request(idx): try: resp = requests.post(self.api_url, json={ 'prompt': f'Stream load test {idx}', 'max_tokens': 10, 'stream': True }, headers=self.headers, stream=True, timeout=30) assert resp.status_code == 200 content_type = resp.headers.get('Content-Type', '') assert 'text/event-stream' in content_type or \ 'application/x-ndjson' in content_type full_text = '' event_count = 0 for line in resp.iter_lines(): if line and line.startswith(b'data:'): event_count += 1 if b'[DONE]' in line: break try: payload = json.loads(line.decode().replace('data: ', '', 1)) full_text += payload.get('text', '') except Exception: pass assert len(full_text) > 0 assert event_count >= 3 return True except Exception as e: print(f' Stream {idx} error: {e}') return False with ThreadPoolExecutor(max_workers=5) as executor: futures = [executor.submit(stream_request, i) for i in range(10)] results = [f.result() for f in futures] success_count = sum(results) assert success_count == 10, \ f'Concurrent streaming test failure rate too high: {success_count}/10' print(f' Streaming under load: {success_count}/10 succeeded') def test_temperature_parameter(self): print(f'\n[Model: {self.model_name}] Running temperature parameter test') prompt = 'The capital of France is' resp_low = self._post({'prompt': prompt, 'max_tokens': 10, 'temperature': 0.1, 'stream': False}) resp_high = self._post({'prompt': prompt, 'max_tokens': 10, 'temperature': 0.9, 'stream': False}) data_low = resp_low.json() data_high = resp_high.json() self._validate_generation_response(data=data_low, validate_tokens=True) self._validate_generation_response(data=data_high, validate_tokens=True) assert 'Paris' in data_low['text'] or 
\ 'paris' in data_low['text'].lower(), \ "Low temperature didn't answer correct capital" assert data_low['text'] != data_high['text'], \ 'High and low temperature outputs identical, ' \ 'temperature may not be effective' def test_top_p_parameter(self): print(f'\n[Model: {self.model_name}] Running top_p parameter test') prompt = 'The weather today is' resp_strict = self._post({'prompt': prompt, 'max_tokens': 20, 'top_p': 0.01, 'stream': False}) resp_loose = self._post({'prompt': prompt, 'max_tokens': 20, 'top_p': 0.99, 'stream': False}) text_strict = resp_strict.json() text_loose = resp_loose.json() self._validate_generation_response(data=text_strict, validate_tokens=True) self._validate_generation_response(data=text_loose, validate_tokens=True) def test_top_k_parameter(self): print(f'\n[Model: {self.model_name}] Running top_k parameter test') prompt = 'Artificial intelligence' resp_k10 = self._post({'prompt': prompt, 'max_tokens': 10, 'top_k': 10, 'stream': False}) resp_k50 = self._post({'prompt': prompt, 'max_tokens': 10, 'top_k': 50, 'stream': False}) text_k10 = resp_k10.json() text_k50 = resp_k50.json() self._validate_generation_response(data=text_k10, validate_tokens=True) self._validate_generation_response(data=text_k50, validate_tokens=True) def test_min_p_parameter(self): print(f'\n[Model: {self.model_name}] Running min_p parameter test') prompt = 'Machine learning is' resp = self._post({'prompt': prompt, 'max_tokens': 10, 'min_p': 0.05, 'stream': False}) data = resp.json() self._validate_generation_response(data) def test_repetition_penalty(self): print(f'\n[Model: {self.model_name}] Running repetition penalty test') prompt = 'Repeat repeat repeat repeat' resp_no_penalty = self._post({'prompt': prompt, 'max_tokens': 10, 'repetition_penalty': 1.0, 'stream': False}) resp_penalty = self._post({'prompt': prompt, 'max_tokens': 10, 'repetition_penalty': 1.5, 'stream': False}) text_no_penalty = resp_no_penalty.json()['text'] text_penalty = 
resp_penalty.json()['text'] def count_repeats(text): words = text.lower().split() return sum(1 for i in range(1, len(words)) if words[i] == words[i - 1]) repeats_no_penalty = count_repeats(text_no_penalty) repeats_penalty = count_repeats(text_penalty) assert repeats_penalty <= repeats_no_penalty, ( f'High penalty coefficient ({1.5}) repetition count ({repeats_penalty}) ' f'not less than low penalty ({1.0}) count ({repeats_no_penalty}), ' f'repetition_penalty ineffective') def test_ignore_eos_parameter(self): print(f'\n[Model: {self.model_name}] Running ignore_eos parameter test') prompt = 'The sky is blue.' resp_normal = self._post({'prompt': prompt, 'ignore_eos': False, 'stream': False}) data_normal = resp_normal.json() self._validate_generation_response(data_normal) resp_ignore = self._post({'prompt': prompt, 'ignore_eos': True, 'stream': False}) data_ignore = resp_ignore.json() self._validate_generation_response(data_ignore) reason_ignore = data_ignore.get('meta_info', {}).get('finish_reason', {}).get('type', 'unknown') assert reason_ignore == 'length', \ f'ignore_eos=True must end due to length, actual: {reason_ignore}' def test_skip_special_tokens(self, config): print(f'[Model: {self.model_name}] Running skip_special_tokens test') model_path = os.path.join(config.get('model_path'), self.model_name) user_content = 'Hello [world]! This is a [test].' 
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) special_tokens_map = tokenizer.special_tokens_map special_patterns = list(special_tokens_map.values()) special_patterns = [ item for sublist in special_patterns for item in (sublist if isinstance(sublist, list) else [sublist]) ] print('Special patterns:', special_patterns) print(' Executing skip_special_tokens=True') payload_true = {'prompt': user_content, 'max_tokens': 100, 'skip_special_tokens': True, 'stream': False} resp_true = self._post(payload_true) data_true = resp_true.json() self._validate_generation_response(data=data_true, validate_tokens=True) generated_text = data_true['text'] assert not any(pattern in generated_text for pattern in special_patterns), \ 'Expected no special pattern in the generated text but found one.' def test_stop_token_ids(self): print(f'\n[Model: {self.model_name}] Running stop_token_ids test') payload = {'prompt': 'Once upon a time', 'max_tokens': 500, 'stop_token_ids': [11, 281], 'stream': False} resp = self._post(payload) assert resp.status_code == 200, \ f'HTTP request failed, status code: {resp.status_code}' try: data = resp.json() except Exception as e: pytest.fail(f'Response JSON parsing failed: {e}') self._validate_generation_response(data) generated_text = data.get('text', '') finish_reason = data.get('meta_info', {}).get('finish_reason', {}).get('type', 'unknown') actual_length = len(generated_text) print(f'\n stop_token_ids=[11, 281] generation result: length={actual_length}, ' f"end reason='{finish_reason}', text='{generated_text[:20]}...'") assert finish_reason in ['stop'], \ f'Expected generation to end due to stop token, ' \ f'actual reason: {finish_reason}. This may mean stop_token_ids [11, 281] ' \ f"didn't take effect, or generation was truncated." 
def test_combined_parameters(self): print(f'\n[Model: {self.model_name}] Running combined parameters test') resp = self._post({ 'prompt': 'The future of AI', 'max_tokens': 15, 'temperature': 0.7, 'top_p': 0.9, 'top_k': 40, 'repetition_penalty': 1.1, 'stream': False }) assert resp.status_code == 200 data = resp.json() self._validate_generation_response(data) def test_streaming_with_all_parameters(self): print(f'\n[Model: {self.model_name}] Running streaming with all parameters test') resp = self._post( { 'prompt': 'Streaming test with parameters', 'max_tokens': 10, 'temperature': 0.8, 'top_p': 0.85, 'top_k': 30, 'repetition_penalty': 1.2, 'stop': ['test'], 'stream': True }, stream=True) assert resp.status_code == 200 data = resp.json() self._validate_generation_response(data) stream_events = data['meta_info'].get('stream_events', []) assert stream_events <= len(data['output_ids']) + 2, \ 'Streaming event count should be less than generated token count' def test_invalid_temperature_values(self): print(f'\n[Model: {self.model_name}] Running invalid temperature values test') resp1 = self._post({'prompt': 'Test', 'max_tokens': 3, 'temperature': 0.0, 'stream': False}) assert resp1.status_code == 200, 'temperature=0.0 should be valid' with pytest.raises(requests.HTTPError) as exc_info: self._post({'prompt': 'Test', 'max_tokens': 3, 'temperature': -0.5, 'stream': False}) assert exc_info.value.response.status_code in [400, 422] print(' Invalid temperature values test passed') def test_invalid_top_p_values(self): print(f'\n[Model: {self.model_name}] Running invalid top_p values test') with pytest.raises(requests.HTTPError) as exc_info: self._post({'prompt': 'Test', 'max_tokens': 3, 'top_p': 1.5, 'stream': False}) assert exc_info.value.response.status_code in [400, 422] print(' Invalid top_p values test passed') def test_invalid_top_k_values(self): print(f'\n[Model: {self.model_name}] Running invalid top_k values test') with pytest.raises(requests.HTTPError) as exc_info: 
self._post({'prompt': 'Test', 'max_tokens': 3, 'top_k': -5, 'stream': False}) assert exc_info.value.response.status_code in [400, 422] print(' Invalid top_k values test passed') def test_boundary_max_tokens(self): print(f'\n[Model: {self.model_name}] Running boundary max_tokens test') resp1 = self._post({'prompt': 'Min tokens', 'max_tokens': 1, 'stream': False}) assert resp1.status_code == 200 data1 = resp1.json() assert data1['meta_info']['completion_tokens'] >= 1 resp2 = self._post({'prompt': 'Max tokens test', 'max_tokens': 2048, 'stream': False}) assert resp2.status_code == 200 with pytest.raises(requests.HTTPError) as exc: self._post({'prompt': 'Test', 'max_tokens': -2, 'stream': False}) assert exc.value.response.status_code == 400 with pytest.raises(requests.HTTPError) as exc: self._post({'prompt': 'Test', 'max_tokens': 0, 'stream': False}) assert exc.value.response.status_code == 400 print(' Max tokens boundary test passed') def test_parameter_interactions(self): print(f'\n[Model: {self.model_name}] Running parameter interactions test') resp1 = self._post({ 'prompt': 'Deterministic generation', 'max_tokens': 10, 'temperature': 0.0, 'top_p': 0.5, 'top_k': 10, 'stream': False }) assert resp1.status_code == 200 data1 = resp1.json() self._validate_generation_response(data1) print(' Parameter interaction (temp=0 with top_p/k) passed') def test_session_id_with_all_parameters(self): print(f'\n[Model: {self.model_name}] Running session_id with all parameters test') session_id = int(time.time_ns()) % 100000 resp1 = self._post({ 'session_id': session_id, 'prompt': 'Hello, introduce yourself briefly.', 'max_tokens': 20, 'temperature': 0.7, 'stream': False }) assert resp1.status_code == 200 data1 = resp1.json() self._validate_generation_response(data1) resp2 = self._post({ 'session_id': session_id, 'prompt': 'What was I just talking about?', 'max_tokens': 20, 'temperature': 0.7, 'stream': False }) assert resp2.status_code == 200 data2 = resp2.json() 
self._validate_generation_response(data2) assert 'What' in data2['text'] or 'hello' in data2['text'].lower() or \ len(data2['text']) > 0 print(f' Session {session_id} test passed') def test_edge_cases_stop_conditions(self): print(f'\n[Model: {self.model_name}] Running edge cases stop conditions test') resp1 = self._post({'prompt': 'Test with empty stop list', 'max_tokens': 10, 'stop': [], 'stream': False}) assert resp1.status_code == 200 data1 = resp1.json() assert len(data1['text']) > 0 resp2 = self._post({ 'prompt': 'Write a sentence ending with a period. Stop here test.', 'max_tokens': 200, 'stop': ['.'], 'stream': False }) assert resp2.status_code == 200 data2 = resp2.json() text2 = data2['text'] finish_reason = data2['meta_info']['finish_reason']['type'] assert '. ' not in text2 and not text2.strip().endswith( '.'), "Stop token '.' should cause generation to end at period" assert not text2.strip().endswith('.'), "Stop token '.' should cause generation to end at period" assert finish_reason in ['stop', 'eos'], \ f'Expected to end due to stop token, actual: {finish_reason}, content is {text2}' print(f" Stop at '.': generated '{text2}' (Reason: {finish_reason})") def test_spaces_between_special_tokens(self, config): print(f'[Model: {self.model_name}] Running spaces_between_special_tokens test') model_path = os.path.join(config.get('model_path'), self.model_name) user_content = 'Hello [world]! This is a [test].' 
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) special_tokens_map = tokenizer.special_tokens_map special_patterns = list(special_tokens_map.values()) special_patterns = [ item for sublist in special_patterns for item in (sublist if isinstance(sublist, list) else [sublist]) ] print(' Executing skip_special_tokens=False and checking spaces between special tokens') payload_false = {'prompt': user_content, 'max_tokens': 100, 'skip_special_tokens': False, 'stream': False} resp_false = self._post(payload_false) data_false = resp_false.json() self._validate_generation_response(data=data_false, validate_tokens=True) generated_text = data_false['text'] for i in range(len(generated_text) - 1): if generated_text[i] in special_patterns and generated_text[i + 1] not in [' ', '\n']: assert False, f'Expected space after special token {generated_text[i]} but found none.' @pytest.mark.experts @pytest.mark.not_turbomind def test_request_returns_experts(self): print(f'\n[Model: {self.model_name}] Running request with experts test') resp1 = self._post({ 'prompt': 'Deterministic generation', 'max_tokens': 50, 'temperature': 0.8, 'return_routed_experts': True }) assert resp1.status_code == 200 data1 = resp1.json() self._validate_generation_response(data1, validate_experts=True) ================================================ FILE: autotest/prompt_case.yml ================================================ identity: - 你好,你叫什么名字#hi, what's your name: memory_test: - 简要介绍乌鲁木齐的景点#A brief introduction to Urumqi’s attractions: - contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - uwumqi - Ürümqi - 介绍它的相应美食#please introduce some delicious foods: - contain: - urumqi - 乌鲁木齐 - 乌市 - xinjiang - 新疆 - uwumqi - Ürümqi chinese_poem_case: - 给我一首中文诗,需要添加标点符号,请用中文回答Give me a Chinese poem in Chinese: - contain: - "," - "。" - poem - poetry - \n - len_g: 5 english_poem_case: - write a romantic English poem in English: - contain: - " " - contain: - "." 
- "," - len_g: 30 emoji_case: - 请输出👍赞的emoji#print output the emoji of good👍: - contain: - 👍 - 😊 - 😀 - 🎉 - 👏 - 👌 - good - like - 赞 - 好 - '!' - u1f44d - 🌟 traditional_chinese_case: - 介紹澳門景點,使用繁體: - contain: - 澳門 - 景點 - 澳门 - macau code_testcase: - 使用python编写一个int数组的冒泡排序代码: - contain: - def - bubble - 冒泡 - code - python - llama2: - contain: - def - bubble - 冒泡 - code - python - assist - however ================================================ FILE: autotest/pytest.ini ================================================ [pytest] python_files = test*_*.py # test file python_classes = Test* # test class python_functions = test_* # test function pytest_runtest_call.tryfirst = True filterwarnings = ignore::UserWarning reruns = 2 reruns_delay = 1 ================================================ FILE: autotest/template.json ================================================ { "model_name": "base", "capability": "completion" } ================================================ FILE: autotest/toolchain/test_lagent.py ================================================ import pytest @pytest.mark.order(10) @pytest.mark.lagent @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize('model', ['internlm/internlm2_5-7b-chat']) def test_repeat(config, model): from lagent.llms import INTERNLM2_META, LMDeployPipeline model = LMDeployPipeline( path='/'.join([config.get('model_path'), model]), meta_template=INTERNLM2_META, tp=1, top_k=40, top_p=0.8, temperature=1.2, stop_words=['<|im_end|>'], max_new_tokens=4096, ) response_list = [] for i in range(3): print(f'run_{i}:') response = model.chat([{ 'role': 'user', 'content': '已知$$z_{1}=1$$,$$z_{2}=\\text{i}$$,$$z_{3}=-1$$,$$z_{4}=-\\text{i}$$,顺次连结它们所表示的点,则所得图形围成的面积为( )\nA. $$\\dfrac{1}{4}$$\n B. $$\\dfrac{1}{2}$$\n C. $$1$$\n D. 
$$2$$\n\n' # noqa: F401, E501 }]) print(response) response_list.append(response) assert len(response) > 10 assert response_list[0] != response_list[1] and response_list[1] != response_list[2] ================================================ FILE: autotest/tools/chat/test_command_chat_hf_pytorch.py ================================================ import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2, PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2) from utils.config_utils import get_func_config_list, get_workerid from utils.run_client_chat import run_tests BACKEND = 'pytorch' @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list('pytorch', {'tp': 1})) def test_hf_pytorch_chat_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_hf_pytorch_chat_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_hf_pytorch_chat_tp4(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_hf_pytorch_chat_tp8(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) 
@pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16})) def test_hf_pytorch_chat_tp16(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, 'base_model')) def test_hf_pytorch_base_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, 'base_model')) def test_hf_pytorch_base_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2) def test_hf_pytorch_chat_pr_tp2(config, run_config, cli_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1) def test_hf_pytorch_chat_pr_tp1(config, run_config, cli_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def test_modelscope_pytorch_chat_tp1(config, run_config, cli_case_config, worker_id): 
run_config['env'] = {'LMDEPLOY_USE_MODELSCOPE': 'True'} run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_pytorch_chat @pytest.mark.gpu_num_1 @pytest.mark.other @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1) def test_pytorch_chat_with_lora_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_pytorch_chat @pytest.mark.gpu_num_1 @pytest.mark.other @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2) def test_pytorch_chat_with_lora_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) ================================================ FILE: autotest/tools/chat/test_command_chat_hf_turbomind.py ================================================ import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, TURBOMIND_FALLBACK_TEST_LLM_GPU1, TURBOMIND_FALLBACK_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1, TURBOMIND_PR_TEST_LLM_GPU2) from utils.config_utils import get_func_config_list, get_workerid from utils.run_client_chat import run_tests BACKEND = 'turbomind' @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1})) def test_hf_turbomind_chat_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_hf_turbomind_chat_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) 
@pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_hf_turbomind_chat_tp4(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_hf_turbomind_chat_tp8(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1) def test_hf_turbomind_chat_fallback_backend_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2) def test_hf_turbomind_chat_fallback_backend_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, 'base_model')) def test_hf_turbomind_base_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, 'base_model')) def test_hf_turbomind_base_tp2(config, run_config, cli_case_config, worker_id): run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_2 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', 
TURBOMIND_PR_TEST_LLM_GPU2) def test_hf_turbomind_chat_pr_tp2(config, run_config, cli_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1) def test_hf_turbomind_chat_pr_tp1(config, run_config, cli_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def test_modelscope_turbomind_chat_tp1(config, run_config, cli_case_config, worker_id): run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id) ================================================ FILE: autotest/tools/common_case_config.py ================================================ TURBOMIND_PR_TEST_LLM_GPU2 = [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }] TURBOMIND_PR_TEST_LLM_GPU1 = [{ 'model': 'Qwen/Qwen3-0.6B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-0.6B-inner-4bits', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-8B', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] TURBOMIND_PR_TEST_MLLM_GPU1 = [{ 'model': 
'OpenGVLab/InternVL3-8B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'OpenGVLab/InternVL3-8B', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] TURBOMIND_PR_TEST_MLLM_GPU2 = [{ 'model': 'OpenGVLab/InternVL3_5-30B-A3B', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'OpenGVLab/InternVL3_5-30B-A3B', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }] TURBOMIND_FALLBACK_TEST_LLM_GPU1 = [{ 'model': 'THUDM/cogvlm-chat-hf', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'microsoft/Phi-3.5-vision-instruct', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] TURBOMIND_FALLBACK_TEST_LLM_GPU2 = [{ 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] TURBOMIND_FALLBACK_TEST_MLLM_GPU1 = [{ 'model': 'THUDM/glm-4v-9b', 'backend': 'turbomind', 'communicator': 'cuda-ipc', 'quant_policy': 4, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'THUDM/glm-4v-9b', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] TURBOMIND_LOGPROBS_TEST_LLM_GPU2 = [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'OpenGVLab/InternVL3-38B', 'backend': 
'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }] BASE_MODELSCOPE_CONFIG = [{ 'model': 'Qwen/Qwen2.5-7B-Instruct', 'communicator': 'cuda-ipc', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {}, 'env': { 'LMDEPLOY_USE_MODELSCOPE': 'True' } }, { 'model': 'Qwen/Qwen2.5-7B-Instruct', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {}, 'env': { 'LMDEPLOY_USE_MODELSCOPE': 'True' } }] MODELSCOPE_CONFIG = [{ **item, 'backend': 'turbomind' } for item in BASE_MODELSCOPE_CONFIG] + [{ **item, 'backend': 'pytorch' } for item in BASE_MODELSCOPE_CONFIG] PYTORCH_LORA_TEST_LLM_GPU1 = [{ 'model': 'meta-llama/Llama-2-7b-chat-hf', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': { 'adapters': { 'default': 'lora/Llama2-Chinese-7b-Chat-LoRA' } } }] PYTORCH_LORA_TEST_LLM_GPU2 = [{ 'model': 'baichuan-inc/Baichuan2-13B-Chat', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': { 'adapters': { 'a': 'lora/2024-01-25_self_dup', 'b': 'lora/2024-01-25_self' } } }] PYTORCH_PR_TEST_LLM_GPU2 = [{ 'model': 'Qwen/Qwen3-30B-A3B', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }, { 'model': 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': {} }] PYTORCH_PR_TEST_LLM_GPU1 = [{ 'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }, { 'model': 'Qwen/Qwen3-0.6B', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'tp': 1 }, 'extra_params': {} }] BASE_TOOLCALL_TEST_LLM = [{ 'model': 'Qwen/Qwen3-8B', 'communicator': 'nccl', 'quant_policy': 0, 
'parallel_config': { 'tp': 1 }, 'extra_params': { 'tool-call-parser': 'qwen' } }, { 'model': 'meta-llama/Meta-Llama-3-1-70B-Instruct', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 4 }, 'extra_params': { 'tool-call-parser': 'llama3' } }, { 'model': 'Qwen/Qwen3-30B-A3B', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': { 'tool-call-parser': 'qwen' } }] BASE_REASONING_TEST_LLM = [{ 'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': { 'reasoning-parser': 'qwen-qwq' } }, { 'model': 'Qwen/Qwen3-30B-A3B', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': { 'reasoning-parser': 'qwen-qwq' } }] TOOLCALL_TEST_LLM = [{ **item, 'backend': 'turbomind' } for item in BASE_TOOLCALL_TEST_LLM] + [{ **item, 'backend': 'pytorch' } for item in BASE_TOOLCALL_TEST_LLM] REASONING_TEST_LLM = [{ **item, 'backend': 'turbomind' } for item in BASE_REASONING_TEST_LLM] + [{ **item, 'backend': 'pytorch' } for item in BASE_REASONING_TEST_LLM] BASE_SPECULATIVE_DECODING_PIPELINE_TEST_LLM = [{ 'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': { 'max_batch_size': 128, 'speculative_config': { 'method': 'eagle3', 'num_speculative_tokens': 3, 'model': 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B' } } }] SPECULATIVE_DECODING_PIPELINE_TEST_LLM = [{ **item, 'backend': 'pytorch' } for item in BASE_SPECULATIVE_DECODING_PIPELINE_TEST_LLM] BASE_SPECULATIVE_DECODING_RESTFUL_TEST_LLM = [{ 'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 1 }, 'extra_params': { 'speculative-draft-model': 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B', 'speculative-algorithm': 'eagle3', 'speculative-num-draft-tokens': 3, 'max-batch-size': 128 } }, { 'model': 'deepseek/DeepSeek-V3', 'communicator': 
'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 16 }, 'extra_params': { 'speculative-algorithm': 'deepseek_mtp', 'speculative-num-draft-tokens': 3, 'max-batch-size': 128 } }] SPECULATIVE_DECODING_RESTFUL_TEST_LLM = [{ **item, 'backend': 'pytorch' } for item in BASE_SPECULATIVE_DECODING_RESTFUL_TEST_LLM] ================================================ FILE: autotest/tools/pipeline/llm_case.py ================================================ import json import os import fire import yaml from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline from lmdeploy.messages import SpeculativeConfig gen_config = GenerationConfig(max_new_tokens=500, min_new_tokens=10) def run_pipeline_chat_test(model_path, run_config, cases_path, is_pr_test: bool = False): backend = run_config.get('backend') communicator = run_config.get('communicator') quant_policy = run_config.get('quant_policy') extra_params = run_config.get('extra_params', {}) parallel_config = run_config.get('parallel_config', {}) if backend == 'pytorch': backend_config = PytorchEngineConfig(quant_policy=quant_policy) else: backend_config = TurbomindEngineConfig(communicator=communicator, quant_policy=quant_policy) # quant format model_lower = model_path.lower() if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower: backend_config.model_format = 'awq' elif 'gptq' in model_lower: backend_config.model_format = 'gptq' # Parallel config for para_key in ('dp', 'ep', 'cp'): if para_key in parallel_config: setattr(backend_config, para_key, parallel_config[para_key]) if 'tp' in parallel_config and parallel_config['tp'] > 1: backend_config.tp = parallel_config['tp'] # Extract speculative_config from extra_params if present speculative_config = None spec_cfg = extra_params.pop('speculative_config', None) if isinstance(spec_cfg, dict): speculative_config = SpeculativeConfig(**spec_cfg) # Extra params # Map CLI param names to PytorchEngineConfig attribute names 
param_name_map = {'device': 'device_type'} for key, value in extra_params.items(): attr_name = param_name_map.get(key, key) try: setattr(backend_config, attr_name, value) except AttributeError: print(f"Warning: Cannot set attribute '{attr_name}' on backend_config. Skipping.") print('backend_config config: ' + str(backend_config)) print('speculative_config config: ' + str(speculative_config)) pipe = pipeline(model_path, backend_config=backend_config, speculative_config=speculative_config) cases_path = os.path.join(cases_path) with open(cases_path) as f: cases_info = yaml.load(f.read(), Loader=yaml.SafeLoader) for case in cases_info.keys(): if is_pr_test and case != 'memory_test': continue if case != 'code_testcase' and 'code' in model_path.lower(): continue case_info = cases_info.get(case) prompts = [] response_list = [] for prompt_detail in case_info: prompt = list(prompt_detail.keys())[0] prompts.append({'role': 'user', 'content': prompt}) response = pipe([prompts], gen_config=gen_config, log_level='INFO', max_log_len=10)[0].text response_list.append({'prompt': prompt, 'response': response}) prompts.append({'role': 'assistant', 'content': response}) print(f'[caseresult {case} start]' + json.dumps(response_list, ensure_ascii=False) + f'[caseresult {case} end]\n') pipe.close() if __name__ == '__main__': fire.Fire() ================================================ FILE: autotest/tools/pipeline/mllm_case.py ================================================ import json import fire import numpy as np from PIL import Image from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline from lmdeploy.vl import encode_image_base64, load_image from lmdeploy.vl.constants import IMAGE_TOKEN gen_config = GenerationConfig(max_new_tokens=500, min_new_tokens=10) PIC1 = 'tiger.jpeg' PIC2 = 'human-pose.jpg' PIC_BEIJING = 'Beijing_Small.jpeg' PIC_CHONGQING = 'Chongqing_Small.jpeg' PIC_REDPANDA = 'redpanda.jpg' PIC_PANDA = 'panda.jpg' DESC = 'What are the 
similarities and differences between these two images.' DESC_ZH = '两张图有什么相同和不同的地方.' def run_pipeline_mllm_test(model_path, run_config, resource_path, is_pr_test: bool = False): backend = run_config.get('backend') communicator = run_config.get('communicator') quant_policy = run_config.get('quant_policy') extra_params = run_config.get('extra_params', {}) parallel_config = run_config.get('parallel_config', {}) if 'pytorch' == backend: backend_config = PytorchEngineConfig(session_len=65152, quant_policy=quant_policy, cache_max_entry_count=0.6) else: backend_config = TurbomindEngineConfig(session_len=65152, communicator=communicator, quant_policy=quant_policy, cache_max_entry_count=0.6) # quant format model_lower = model_path.lower() if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower: backend_config.model_format = 'awq' elif 'gptq' in model_lower: backend_config.model_format = 'gptq' # Parallel config for para_key in ('dp', 'ep', 'cp'): if para_key in parallel_config: setattr(backend_config, para_key, parallel_config[para_key]) if 'tp' in parallel_config and parallel_config['tp'] > 1: backend_config.tp = parallel_config['tp'] # Extra params # Map CLI param names to PytorchEngineConfig attribute names param_name_map = {'device': 'device_type'} for key, value in extra_params.items(): attr_name = param_name_map.get(key, key) try: setattr(backend_config, attr_name, value) except AttributeError: print(f"Warning: Cannot set attribute '{attr_name}' on backend_config. 
Skipping.") print('backend_config config: ' + str(backend_config)) pipe = pipeline(model_path, backend_config=backend_config) image = load_image(f'{resource_path}/{PIC1}') if 'deepseek' in model_lower: prompt = f'describe this image{IMAGE_TOKEN}' else: prompt = 'describe this image' response = pipe((prompt, image)).text print('[caseresult single1 start]' + json.dumps(response, ensure_ascii=False) + '[caseresult single1 end]\n') prompts = [{ 'role': 'user', 'content': [{ 'type': 'text', 'text': prompt }, { 'type': 'image_url', 'image_url': { 'url': f'{resource_path}/{PIC1}' } }] }] response = pipe(prompts, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult single2 start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult single2 end]\n') image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}'] images = [load_image(img_url) for img_url in image_urls] response = pipe((prompt, images)) print('[caseresult multi-imagese start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult multi-imagese end]\n') image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}'] prompts = [(prompt, load_image(img_url)) for img_url in image_urls] response = pipe(prompts, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult batch-example1 start]' + json.dumps(response[0].text, ensure_ascii=False) + '[caseresult batch-example1 end]\n') print('[caseresult batch-example2 start]' + json.dumps(response[1].text, ensure_ascii=False) + '[caseresult batch-example2 end]\n') image = load_image(f'{resource_path}/{PIC2}') sess = pipe.chat((prompt, image)) print('[caseresult multi-turn1 start]' + json.dumps(sess.response.text, ensure_ascii=False) + '[caseresult multi-turn1 end]\n') sess = pipe.chat('What is the woman doing?', session=sess) print('[caseresult multi-turn2 start]' + json.dumps(sess.response.text, ensure_ascii=False) + '[caseresult multi-turn2 end]\n') if not is_pr_test: if 'internvl' in 
model_path.lower() and 'internvl2-4b' not in model_path.lower(): internvl_vl_testcase(pipe, resource_path) internvl_vl_testcase(pipe, resource_path, lang='cn') if 'minicpm' in model_path.lower(): MiniCPM_vl_testcase(pipe, resource_path) if 'qwen' in model_path.lower(): Qwen_vl_testcase(pipe, resource_path) pipe.close() def internvl_vl_testcase(pipe, resource_path, lang='en'): if lang == 'cn': description = DESC_ZH else: description = DESC # multi-image multi-round conversation, combined images messages = [ dict(role='user', content=[ dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\n{description}'), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-combined-images-{lang} start]' + json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-combined-images-{lang} end]\n') messages.append(dict(role='assistant', content=response.text)) messages.append(dict(role='user', content=description)) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-combined-images2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-combined-images2-{lang} end]\n') # multi-image multi-round conversation, separate images messages = [ dict( role='user', content=[ dict( type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\n' + # noqa E251,E501 description), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-separate-images-{lang} start]' + 
json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-separate-images-{lang} end]\n') messages.append(dict(role='assistant', content=response.text)) messages.append(dict(role='user', content=description)) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-separate-images2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-separate-images2-{lang} end]\n') # video multi-round conversation def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array( [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)]) return frame_indices def load_video(video_path, bound=None, num_segments=32): import cv2 cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f'Cannot open video file: {video_path}') max_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1 fps = cap.get(cv2.CAP_PROP_FPS) frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) imgs = [] for frame_index in frame_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) ret, frame = cap.read() if ret: rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) img = Image.fromarray(rgb_frame).convert('RGB') imgs.append(img) cap.release() return imgs video_path = resource_path + '/red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' for i in range(len(imgs)): question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' if lang == 'cn': question += '视频里有什么动物,它在做什么?' else: question += 'What animals are in the video, and what are they doing?' 
content = [{'type': 'text', 'text': question}] for img in imgs: content.append({ 'type': 'image_url', 'image_url': { 'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}' # noqa E231 } }) messages = [dict(role='user', content=content)] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-video-{lang} start]' + json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-video-{lang} end]\n') messages.append(dict(role='assistant', content=response.text)) if lang == 'cn': messages.append(dict(role='user', content='描述视频详情,不要重复')) else: messages.append(dict(role='user', content='Describe this video in detail. Don\'t repeat.')) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print(f'[caseresult internvl-video2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) + f'[caseresult internvl-video2-{lang} end]\n') def MiniCPM_vl_testcase(pipe, resource_path): # Chat with multiple images messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(max_slice_nums=9, url=f'{resource_path}/{PIC_REDPANDA}')), dict(type='image_url', image_url=dict(max_slice_nums=9, url=f'{resource_path}/{PIC_PANDA}')) ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult minicpm-combined-images start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult minicpm-combined-images end]\n') messages.append(dict(role='assistant', content=response.text)) messages.append(dict(role='user', content=DESC)) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult minicpm-combined-images2 start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult minicpm-combined-images2 end]\n') # In-context few-shot learning question = 'production date' messages = [ dict(role='user', 
content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url=f'{resource_path}/data1.jpeg')), ]), dict(role='assistant', content='2021.08.29'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url=f'{resource_path}/data2.jpeg')), ]), dict(role='assistant', content='1999.05.15'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url=f'{resource_path}/data3.jpeg')), ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult minicpm-fewshot start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult minicpm-fewshot end]\n') # Chat with video MAX_NUM_FRAMES = 64 # if cuda OOM set a smaller number def encode_video(video_path): def uniform_sample(length, n): gap = len(length) / n idxs = [int(i * gap + gap / 2) for i in range(n)] return [length[i] for i in idxs] import cv2 cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f'Cannot open video file: {video_path}') fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) sample_fps = round(fps / 1) # FPS frame_idx = [i for i in range(0, total_frames, sample_fps)] if len(frame_idx) > MAX_NUM_FRAMES: frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) frames = [] for idx in frame_idx: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(Image.fromarray(rgb_frame.astype('uint8')).convert('RGB')) cap.release() print('num frames:', len(frames)) return frames video_path = resource_path + '/red-panda.mp4' frames = encode_video(video_path) question = 'What animals are in the video, and what are they doing?' 
content = [dict(type='text', text=question)] for frame in frames: content.append( dict(type='image_url', image_url=dict(use_image_id=False, max_slice_nums=2, url=f'data:image/jpeg;base64,{encode_image_base64(frame)}'))) # noqa E231 messages = [dict(role='user', content=content)] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult minicpm-video start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult minicpm-video end]\n') def Qwen_vl_testcase(pipe, resource_path): # multi-image multi-round conversation, combined images messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')), dict(type='image_url', image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}')) ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult qwen-combined-images start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult qwen-combined-images end]\n') messages.append(dict(role='assistant', content=response.text)) messages.append(dict(role='user', content=DESC)) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult qwen-combined-images2 start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult qwen-combined-images2 end]\n') # image resolution for performance boost min_pixels = 64 * 28 * 28 max_pixels = 64 * 28 * 28 messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url=f'{resource_path}/{PIC_BEIJING}')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url=f'{resource_path}/{PIC_CHONGQING}')) ]) ] response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult qwen-performance-images start]' 
+ json.dumps(response.text, ensure_ascii=False) + '[caseresult qwen-performance-images end]\n') messages.append(dict(role='assistant', content=response.text)) messages.append(dict(role='user', content=DESC)) response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10) print('[caseresult qwen-performance-images2 start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult qwen-performance-images2 end]\n') if __name__ == '__main__': fire.Fire() ================================================ FILE: autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py ================================================ import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2, PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2, SPECULATIVE_DECODING_PIPELINE_TEST_LLM) from utils.config_utils import get_func_config_list, get_workerid from utils.pipeline_chat import run_pipeline_llm_test BACKEND = 'pytorch' @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1})) def test_pipeline_chat_tp1(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_pipeline_chat_tp2(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_pipeline_chat_tp4(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, 
worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_8 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_pipeline_chat_tp8(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16})) def test_pipeline_chat_tp16(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None})) def test_pipeline_chat_pytorch_prefix_cache_tp2(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2) def test_hf_pytorch_chat_pr_tp2(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1) def test_hf_pytorch_chat_pr_tp1(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def 
test_modelscope_pipeline_chat_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1) def test_pytorch_chat_with_lora_tp1(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2) def test_pytorch_chat_with_lora_tp2(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in SPECULATIVE_DECODING_PIPELINE_TEST_LLM if item['parallel_config'].get('tp') == 1]) def test_pipeline_chat_speculative_decoding_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) ================================================ FILE: autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py ================================================ import pytest from utils.config_utils import get_func_config_list from utils.pipeline_chat import run_pipeline_mllm_test BACKEND = 'pytorch' def get_models(parallel_config): return get_func_config_list(BACKEND, parallel_config, model_type='vl_model', extra={'session_len': 8192}) @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_models({'tp': 1})) def test_restful_chat_tp1(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) 
@pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_models({'tp': 2})) def test_restful_chat_tp2(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_models({'tp': 4})) def test_restful_chat_tp4(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_models({'tp': 8})) def test_restful_chat_tp8(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_16 @pytest.mark.parametrize('run_config', get_models({'tp': 16})) def test_restful_chat_tp16(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) ================================================ FILE: autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py ================================================ import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, TURBOMIND_FALLBACK_TEST_LLM_GPU1, TURBOMIND_FALLBACK_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1, TURBOMIND_PR_TEST_LLM_GPU2) from utils.config_utils import get_func_config_list, get_workerid from utils.pipeline_chat import run_pipeline_llm_test BACKEND = 'turbomind' @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1})) def test_pipeline_chat_tp1(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_pipeline_chat_tp2(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_4 
@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_pipeline_chat_tp4(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_pipeline_chat_tp8(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None})) def test_pipeline_chat_prefix_cache_tp2(config, run_config, common_case_config, worker_id): run_pipeline_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1) def test_pipeline_chat_fallback_backend_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2) def test_pipeline_chat_fallback_backend_tp2(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU2) def test_pipeline_chat_pr_tp2(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) case_config = 
{k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1) def test_pipeline_chat_pr_tp1(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def test_modelscope_restful_chat_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_pipeline_llm_test(config, run_config, case_config, worker_id) ================================================ FILE: autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py ================================================ import pytest from tools.common_case_config import (TURBOMIND_FALLBACK_TEST_MLLM_GPU1, TURBOMIND_PR_TEST_MLLM_GPU1, TURBOMIND_PR_TEST_MLLM_GPU2) from utils.config_utils import get_func_config_list, get_workerid from utils.pipeline_chat import run_pipeline_mllm_test BACKEND = 'turbomind' def get_models(parallel_config): return get_func_config_list(BACKEND, parallel_config, model_type='vl_model', extra={'session_len': 8192}) @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_models({'tp': 1})) def test_restful_chat_tp1(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_models({'tp': 2})) def test_restful_chat_tp2(config, run_config, worker_id): 
run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_models({'tp': 4})) def test_restful_chat_tp4(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_models({'tp': 8})) def test_restful_chat_tp8(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_16 @pytest.mark.parametrize('run_config', get_models({'tp': 16})) def test_restful_chat_tp16(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_1 @pytest.mark.other @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_MLLM_GPU1) def test_restful_chat_fallback_backend_tp1(config, run_config, worker_id): run_pipeline_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_1 @pytest.mark.other @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_MLLM_GPU1) def test_pipeline_pr_test(config, run_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) run_pipeline_mllm_test(config, run_config, worker_id, is_smoke=True) @pytest.mark.gpu_num_2 @pytest.mark.other @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_MLLM_GPU2) def test_pipeline_pr_tp2_test(config, run_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_pipeline_mllm_test(config, run_config, worker_id, is_smoke=True) ================================================ FILE: autotest/tools/quantization/test_quantization_awq.py ================================================ import os import allure import pytest from utils.config_utils import get_cuda_prefix_by_workerid, get_quantization_model_list from utils.quantization_utils import quantization @pytest.mark.order(3) @pytest.mark.test_3090 @pytest.mark.timeout(900) @pytest.mark.parametrize('model', get_quantization_model_list('awq')) def 
test_quantization_awq(config, model, worker_id): quantization_type = 'awq' quantization_all(config, model + '-inner-4bits', model, quantization_type, get_cuda_prefix_by_workerid(worker_id, {'tp': 1})) @pytest.mark.order(3) @pytest.mark.timeout(900) @pytest.mark.parametrize('model', get_quantization_model_list('gptq')) def test_quantization_gptq(config, model, worker_id): quantization_type = 'gptq' quantization_all(config, model + '-inner-gptq', model, quantization_type, get_cuda_prefix_by_workerid(worker_id, {'tp': 1})) @pytest.mark.order(3) @pytest.mark.pr_test @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.timeout(900) @pytest.mark.parametrize('model', ['Qwen/Qwen3-0.6B']) def test_quantization_awq_pr(config, model): quantization_type = 'awq' quantization_all(config, model + '-inner-4bits', model, quantization_type, cuda_prefix='CUDA_VISIBLE_DEVICES=6') def quantization_all(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix: str = ''): result, msg = quantization(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( log_path, '_'.join(['quantization', quantization_type, quantization_model_name.split('/')[1]]) + '.log') allure.attach.file(quantization_log, name=quantization_log, attachment_type=allure.attachment_type.TEXT) assert result, msg ================================================ FILE: autotest/tools/quantization/test_quantization_w8a8.py ================================================ import os import allure import pytest from utils.config_utils import get_cuda_prefix_by_workerid, get_quantization_model_list from utils.quantization_utils import quantization @pytest.mark.order(2) @pytest.mark.quantization_w8a8 @pytest.mark.timeout(900) @pytest.mark.parametrize('model', get_quantization_model_list('w8a8')) def test_quantization_w8a8(config, model, worker_id): quantization_w8a8(config, model + '-inner-w8a8', 
model, get_cuda_prefix_by_workerid(worker_id, {'tp': 1})) def quantization_w8a8(config, quantization_model_name, origin_model_name, cuda_prefix): quantization_type = 'w8a8' result, msg = quantization(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( log_path, '_'.join(['quantization', quantization_type, quantization_model_name.split('/')[1]]) + '.log') allure.attach.file(quantization_log, name=quantization_log, attachment_type=allure.attachment_type.TEXT) assert result, msg ================================================ FILE: autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py ================================================ import time import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2, PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2, REASONING_TEST_LLM, SPECULATIVE_DECODING_RESTFUL_TEST_LLM, TOOLCALL_TEST_LLM) from utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid from utils.constant import PROXY_PORT from utils.proxy_distributed_utils import ApiServerPerTest, proxy_worker_node_wait from utils.ray_distributed_utils import ray_worker_node_wait from utils.run_restful_chat import run_all_step, run_llm_test, run_reasoning_case, run_tools_case BACKEND = 'pytorch' def _run_ray_distributed_test( config, run_config, common_case_config, manager=None, # ← New parameter: pass in shared manager ): """Universal distributed test executor (using shared Ray cluster)""" assert manager is not None, 'Manager instance must be provided' if manager.is_master: # Start API Server for current model (master node starts/stops, worker nodes verify) manager.start_lmdeploy_api_server(config=config, run_config=run_config) try: case_name = get_case_str_by_config(run_config) run_all_step(config.get('log_path'), case_name, common_case_config, port=PROXY_PORT) finally: # Clean 
up API Server for current model (worker nodes skip) manager.cleanup(force=False) else: time.sleep(10) ray_worker_node_wait(manager, timeout_minutes=4880) def _run_proxy_distributed_test( config, run_config, common_case_config, manager=None, # ← New parameter: pass in shared manager ): """Universal distributed test executor (using shared Ray cluster)""" assert manager is not None, 'Manager instance must be provided' api_server = ApiServerPerTest(proxy_manager=manager, config=config, run_config=run_config) api_server.start() try: if manager.is_master: api_server.wait_until_ready() case_name = get_case_str_by_config(run_config) run_all_step(config.get('log_path'), case_name, common_case_config, port=PROXY_PORT) else: print(f'⏸️ Worker node {manager.node_rank} waiting for master to complete test...') proxy_worker_node_wait(manager, timeout_minutes=4880) finally: api_server.cleanup() if manager.is_master: time.sleep(1) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1})) def test_restful_chat_tp1(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_restful_chat_tp2(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_4 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_restful_chat_tp4(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_8 @pytest.mark.test_ascend 
@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_restful_chat_tp8(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_16 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16})) def test_restful_chat_tp16(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api_pytorch @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_distributed_tp16 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16})) def test_restful_chat_distributed_tp16(shared_ray_manager, config, run_config, common_case_config, worker_id): _run_ray_distributed_test(config=config, run_config=run_config, common_case_config=common_case_config, manager=shared_ray_manager) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api_pytorch @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_distributed_dpep16 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'dp': 16, 'ep': 16})) def test_restful_chat_distributed_dpep16(shared_proxy_manager, config, run_config, common_case_config, worker_id): _run_proxy_distributed_test(config=config, run_config=run_config, common_case_config=common_case_config, manager=shared_proxy_manager) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.test_ascend @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None})) def test_restful_chat_pytorch_prefix_cache_tp2(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.pr_test 
@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2) def test_hf_pytorch_chat_pr_tp2(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1) def test_hf_pytorch_chat_pr_tp1(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def test_modelscope_restful_chat_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1) def test_pytorch_chat_with_lora_tp1(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2) def test_pytorch_chat_with_lora_tp2(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1]) def test_restful_chat_reasoning_tp1(config, run_config, worker_id): run_reasoning_case(config, run_config, worker_id) 
@pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.parametrize( 'run_config', [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2]) def test_restful_chat_reasoning_tp2(config, run_config, worker_id): run_reasoning_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1]) def test_restful_chat_tools_tp1(config, run_config, worker_id): run_tools_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2]) def test_restful_chat_tools_tp2(config, run_config, worker_id): run_tools_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_4 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 4]) def test_restful_chat_tools_tp4(config, run_config, worker_id): run_tools_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in SPECULATIVE_DECODING_RESTFUL_TEST_LLM if item['parallel_config'].get('tp') == 1]) def test_restful_chat_speculative_decoding_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') 
@pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_distributed_tp16 @pytest.mark.parametrize( 'run_config', [item for item in SPECULATIVE_DECODING_RESTFUL_TEST_LLM if item['parallel_config'].get('tp') == 16]) def test_restful_chat_speculative_decoding_tp16(shared_ray_manager, config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} _run_ray_distributed_test(config=config, run_config=run_config, common_case_config=case_config, manager=shared_ray_manager) ================================================ FILE: autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py ================================================ import pytest from utils.config_utils import get_func_config_list from utils.run_restful_chat import run_mllm_test BACKEND = 'pytorch' @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, model_type='vl_model')) def test_restful_chat_tp1(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, model_type='vl_model')) def test_restful_chat_tp2(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}, model_type='vl_model')) def test_restful_chat_tp4(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}, model_type='vl_model')) def test_restful_chat_tp8(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_16 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}, model_type='vl_model')) def test_restful_chat_tp16(config, run_config, worker_id): run_mllm_test(config, run_config, 
worker_id) ================================================ FILE: autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py ================================================ import pytest from tools.common_case_config import (MODELSCOPE_CONFIG, REASONING_TEST_LLM, TOOLCALL_TEST_LLM, TURBOMIND_FALLBACK_TEST_LLM_GPU1, TURBOMIND_FALLBACK_TEST_LLM_GPU2, TURBOMIND_LOGPROBS_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1, TURBOMIND_PR_TEST_LLM_GPU2) from utils.config_utils import get_func_config_list, get_workerid from utils.run_restful_chat import run_llm_test, run_logprob_test, run_reasoning_case, run_tools_case BACKEND = 'turbomind' @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1})) def test_restful_chat_tp1(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2})) def test_restful_chat_tp2(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4})) def test_restful_chat_tp4(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8})) def test_restful_chat_tp8(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, 
extra={'enable-prefix-caching': None})) def test_restful_chat_prefix_cache_tp2(config, run_config, common_case_config, worker_id): run_llm_test(config, run_config, common_case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1) def test_restful_chat_fallback_backend_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2) def test_restful_chat_fallback_backend_tp2(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU2) def test_restful_chat_pr_tp2(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test @pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1) def test_restful_chat_pr_tp1(config, run_config, common_case_config, worker_id): worker_id = 'gw' + str(6 + get_workerid(worker_id)) case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.pr_test @pytest.mark.parametrize('run_config', 
TURBOMIND_LOGPROBS_TEST_LLM_GPU2) def test_restful_logprobs(config, run_config, worker_id): worker_id = 'gw' + str(3 + get_workerid(worker_id)) run_logprob_test(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.gpu_num_1 @pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND]) def test_modelscope_restful_chat_tp1(config, run_config, common_case_config, worker_id): case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'} run_llm_test(config, run_config, case_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1]) def test_restful_chat_reasoning_tp1(config, run_config, worker_id): run_reasoning_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.parametrize( 'run_config', [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2]) def test_restful_chat_reasoning_tp2(config, run_config, worker_id): run_reasoning_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_1 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1]) def test_restful_chat_tools_tp1(config, run_config, worker_id): run_tools_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_2 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2]) def test_restful_chat_tools_tp2(config, run_config, 
worker_id): run_tools_case(config, run_config, worker_id) @pytest.mark.usefixtures('common_case_config') @pytest.mark.flaky(reruns=0) @pytest.mark.gpu_num_4 @pytest.mark.parametrize( 'run_config', [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 4]) def test_restful_chat_tools_tp4(config, run_config, worker_id): run_tools_case(config, run_config, worker_id) ================================================ FILE: autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py ================================================ import pytest from tools.common_case_config import TURBOMIND_FALLBACK_TEST_MLLM_GPU1 from utils.config_utils import get_func_config_list from utils.run_restful_chat import run_mllm_test BACKEND = 'turbomind' @pytest.mark.gpu_num_1 @pytest.mark.test_3090 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, model_type='vl_model')) def test_restful_chat_tp1(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_2 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, model_type='vl_model')) def test_restful_chat_tp2(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_4 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}, model_type='vl_model')) def test_restful_chat_tp4(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_8 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}, model_type='vl_model')) def test_restful_chat_tp8(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_16 @pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}, model_type='vl_model')) def test_restful_chat_tp16(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) @pytest.mark.gpu_num_1 
@pytest.mark.other @pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_MLLM_GPU1) def test_restful_chat_fallback_backend_tp1(config, run_config, worker_id): run_mllm_test(config, run_config, worker_id) ================================================ FILE: autotest/utils/benchmark_utils.py ================================================ import os import time import allure import utils.constant as constant from utils.common_utils import execute_command_with_logging from utils.config_utils import get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid, get_workerid from utils.run_restful_chat import health_check, start_openai_service, terminate_restful_api def throughput_test(config, run_config, worker_id: str = '', is_smoke: bool = False): model = run_config.get('model') model_path = os.path.join(config.get('model_path'), model) dataset_path = config.get('dataset_path') case_name = get_case_str_by_config(run_config) benchmark_path = os.path.join(config.get('benchmark_path'), 'throughput') work_dir = os.path.join(benchmark_path, f'wk_{case_name}') os.makedirs(work_dir, exist_ok=True) max_cache_entry = get_max_cache_entry(model, run_config.get('backend')) if max_cache_entry is not None: if 'extra_params' not in run_config: run_config['extra_params'] = {} run_config['extra_params']['cache-max-entry-count'] = max_cache_entry cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) command = f'{cuda_prefix} python3 benchmark/profile_throughput.py {dataset_path} {model_path} {get_cli_common_param(run_config)}' # noqa if is_smoke: num_prompts = '--num-prompts 100' else: num_prompts = '--num-prompts 5000' env = os.environ.copy() env.update(run_config.get('env', {})) for batch in [128, 256]: csv_path = os.path.join(work_dir, f'{batch}.csv') timestamp = time.strftime('%Y%m%d_%H%M%S') benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{batch}_{timestamp}.log') cmd = ' '.join([command, '--concurrency', 
str(batch), num_prompts, '--csv ', csv_path]).strip() result, stderr = execute_command_with_logging(cmd, benchmark_log, env=env) allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT) if result and not os.path.isfile(csv_path): return False, 'result is empty' if not result: return False, stderr return True, 'success' def longtext_throughput_test(config, run_config, worker_id: str = ''): model = run_config.get('model') model_path = os.path.join(config.get('model_path'), model) dataset_path = config.get('dataset_path') case_name = get_case_str_by_config(run_config) benchmark_path = os.path.join(config.get('benchmark_path'), 'longtext-throughput') work_dir = os.path.join(benchmark_path, f'wk_{case_name}') os.makedirs(work_dir, exist_ok=True) max_cache_entry = get_max_cache_entry(model, run_config.get('backend')) if max_cache_entry is not None: if 'extra_params' not in run_config: run_config['extra_params'] = {} run_config['extra_params']['cache-max-entry-count'] = max_cache_entry run_config['extra_params'].pop('session-len', None) cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) command = f'{cuda_prefix} python3 benchmark/profile_pipeline_api.py {dataset_path} {model_path} {get_cli_common_param(run_config)}' # noqa env = os.environ.copy() env.update(run_config.get('env', {})) for input_len, out_len, num_prompts, session_info, concurrency in [(1, 32768, 3, '32k', 3), (1, 65536, 1, '64k', 1), (65536, 1024, 5, '64k-1k', 5), (198000, 1024, 1, '198k-1k', 1)]: session_len = input_len + out_len + 1 csv_path = os.path.join(work_dir, f'{case_name}_{session_info}.csv') timestamp = time.strftime('%Y%m%d_%H%M%S') benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{session_info}_{timestamp}.log') cmd = ' '.join([ command, '--dataset-name random', f'--random-input-len {input_len}', f'--random-output-len {out_len}', f'--num-prompts {num_prompts}', f'--concurrency {concurrency}', 
'--stream-output', f'--session-len {session_len}', '--random-range-ratio 1', f'--csv {csv_path}' ]).strip() result, stderr = execute_command_with_logging(cmd, benchmark_log, timeout=7200, env=env) allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT) if result and not os.path.isfile(csv_path): return False, 'result is empty' if not result: return False, stderr return True, 'success' def restful_test(config, run_config, worker_id: str = '', is_smoke: bool = False, is_mllm: bool = False): max_cache_entry = get_max_cache_entry(run_config.get('model'), run_config.get('backend')) if max_cache_entry is not None: if 'extra_params' not in run_config: run_config['extra_params'] = {} run_config['extra_params']['cache-max-entry-count'] = max_cache_entry pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: if is_mllm: return mllm_restful_profile(config, run_config, port=constant.DEFAULT_PORT + get_workerid(worker_id), is_smoke=is_smoke) else: return restful_profile(config, run_config, port=constant.DEFAULT_PORT + get_workerid(worker_id), is_smoke=is_smoke) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) BASE_HTTP_URL = f'http://{constant.DEFAULT_SERVER}' def restful_profile(config, run_config, port, is_smoke: bool = False): model_path = os.path.join(config.get('model_path'), run_config.get('model')) case_name = get_case_str_by_config(run_config) dataset_path = config.get('dataset_path') benchmark_path = os.path.join(config.get('benchmark_path'), 'restful') work_dir = os.path.join(benchmark_path, f'wk_{case_name}') timestamp = time.strftime('%Y%m%d_%H%M%S') benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{timestamp}.log') os.makedirs(work_dir, exist_ok=True) http_url = f'{BASE_HTTP_URL}:{port}' # noqa: E231 if not health_check(http_url, case_name): return False, 'server not start' csv_path = 
f'{work_dir}/restful.csv' command = f'python benchmark/profile_restful_api.py --backend lmdeploy --dataset-name sharegpt --dataset-path {dataset_path} --tokenizer {model_path} --base-url {http_url} --output-file {csv_path}' # noqa if is_smoke: command += ' --num-prompts 100' else: command += ' --num-prompts 5000' result, stderr = execute_command_with_logging(command, benchmark_log) allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT) if result and not os.path.isfile(csv_path): return False, 'result is empty' if not result: return False, stderr return True, 'success' def mllm_restful_profile(config, run_config, port, is_smoke: bool = False): model_path = os.path.join(config.get('model_path'), run_config.get('model')) case_name = get_case_str_by_config(run_config) benchmark_path = os.path.join(config.get('benchmark_path'), 'mllm_restful') work_dir = os.path.join(benchmark_path, f'wk_{case_name}') timestamp = time.strftime('%Y%m%d_%H%M%S') benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{timestamp}.log') os.makedirs(work_dir, exist_ok=True) http_url = f'{BASE_HTTP_URL}:{port}' # noqa: E231 if not health_check(http_url, case_name): return False, 'server not start' csv_path = f'{work_dir}/mllm_restful.csv' command = f'python benchmark/profile_restful_api.py --backend lmdeploy-chat --dataset-name image --tokenizer {model_path} --model {case_name} --model-path {model_path} --random-input-len 100 --random-output-len 100 --random-range-ratio 1 --image-format jpeg --image-count 1 --image-content random --image-resolution 1024x1024 --base-url {http_url} --output-file {csv_path}' # noqa if is_smoke: command += ' --num-prompts 100' else: command += ' --num-prompts 1000' result, stderr = execute_command_with_logging(command, benchmark_log) allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT) if result and not os.path.isfile(csv_path): return False, 'result is empty' if not 
result: return False, stderr return True, 'success' def prefixcache_throughput_test(config, run_config, worker_id: str = '', is_smoke: bool = False): model = run_config.get('model') model_path = os.path.join(config.get('model_path'), model) dataset_path = config.get('prefix_dataset_path') case_name = get_case_str_by_config(run_config) benchmark_path = os.path.join(config.get('benchmark_path'), 'prefix-throughtput') work_dir = os.path.join(benchmark_path, f'wk_{case_name}') os.makedirs(work_dir, exist_ok=True) max_cache_entry = get_max_cache_entry(model, run_config.get('backend')) if max_cache_entry is not None: if 'extra_params' not in run_config: run_config['extra_params'] = {} run_config['extra_params']['cache-max-entry-count'] = max_cache_entry cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) run_config_new = run_config.copy() if 'extra_params' not in run_config_new: run_config_new['extra_params'] = {} run_config_new['extra_params'].pop('enable-prefix-caching', None) run_config_new['extra_params']['session-len'] = 32768 command = f'{cuda_prefix} python3 benchmark/profile_pipeline_api.py {dataset_path} {model_path} {get_cli_common_param(run_config_new)}' # noqa env = os.environ.copy() env.update(run_config.get('env', {})) if is_smoke: test_configs = [(4096, 256, 10, '4k', None)] else: test_configs = [(4096, 256, 100, '4k', None)] for enable_prefix_caching in [False, True]: suffix = 'cache' if enable_prefix_caching else 'no_cache' for input_len, out_len, num_prompts, session_info, concurrency in test_configs: timestamp = time.strftime('%Y%m%d_%H%M%S') benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{session_info}_{suffix}_{timestamp}.log') csv_path = os.path.join(work_dir, f'{session_info}_{suffix}.csv') command = ' '.join([ command, '--dataset-name random', f'--random-input-len {input_len}', f'--random-output-len {out_len}', '--random-range-ratio 1.0', f'--num-prompts {num_prompts}', '--stream-output', f'--csv 
{csv_path}' ]).strip() if enable_prefix_caching: command += ' --enable-prefix-caching' if concurrency: command += f' --concurrency {concurrency}' result, stderr = execute_command_with_logging(command, benchmark_log, env=env) allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT) if result and not os.path.isfile(csv_path): return False, 'result is empty' if not result: return False, stderr return True, 'success' def get_max_cache_entry(model, backend): if backend == 'pytorch': return 0.8 if 'Llama-2' in model: return 0.95 elif 'internlm2' in model: return 0.9 elif 'Qwen/Qwen3-235B-A22B' == model or 'internlm/Intern-S1' == model: return 0.7 else: return None ================================================ FILE: autotest/utils/common_utils.py ================================================ import os import subprocess import sys def execute_command_with_logging(cmd, log_file_path: str, timeout: int = 3600, env=None, should_print=True) -> tuple[bool, str]: if env is None: env = os.environ.copy() if os.path.isfile(log_file_path): write_type = 'a' else: write_type = 'w' try: result = True with open(log_file_path, write_type, encoding='utf-8') as log_file: start_msg = f'execute command: {cmd}\n' print(start_msg, end='') log_file.write(start_msg) log_file.flush() process = subprocess.run(cmd, shell=True, text=True, encoding='utf-8', errors='replace', stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, bufsize=1, timeout=timeout, start_new_session=True) if process.stdout: if should_print: print(process.stdout, end='') log_file.write(process.stdout) if process.returncode == 0: result_msg = 'execute command success!\n' else: result = False result_msg = f'execute command fail: {process.returncode}\n' log_file.write(result_msg) return result, result_msg.strip() except Exception as e: error_msg = f'execute command fail exception: {str(e)}\n' print(error_msg, file=sys.stderr, end='') with open(log_file_path, 'a', 
encoding='utf-8') as log_file: log_file.write(error_msg) return False, error_msg.strip() ================================================ FILE: autotest/utils/config_utils.py ================================================ import copy import os from collections import OrderedDict from typing import Any import yaml from lmdeploy.utils import is_bf16_supported SUFFIX_INNER_AWQ = '-inner-4bits' SUFFIX_INNER_GPTQ = '-inner-gptq' SUFFIX_INNER_W8A8 = '-inner-w8a8' def resolve_extra_params(extra_params: dict[str, Any], model_base_path: str) -> None: """Resolve relative model paths in extra_params to absolute paths. Centralised helper so that every call-site does not need its own ``if key in extra_params …`` guard – adding a new key here is enough. """ # Keys in extra_params whose string values are relative model paths model_path_keys = ['speculative-draft-model'] # Flat string-valued keys for key in model_path_keys: if key in extra_params: value = extra_params[key] if value and isinstance(value, str) and not os.path.isabs(value): extra_params[key] = os.path.join(model_base_path, value) # Nested speculative_config (pipeline usage) spec_cfg = extra_params.get('speculative_config') if isinstance(spec_cfg, dict) and 'model' in spec_cfg: model = spec_cfg['model'] if model and isinstance(model, str) and not os.path.isabs(model): spec_cfg['model'] = os.path.join(model_base_path, model) def get_func_config_list(backend: str, parallel_config: dict[str, int], model_type: str = 'chat_model', func_type: str = 'func', extra: dict[str, Any] | None = None) -> list[dict[str, Any]]: """Generate all valid running config combinations (communicator + quant policy + model). 
Args: backend: Backend type (turbomind/pytorch) parallel_config: Parallel config for tensor parallel model_type: Model type, default: chat_model func_type: Test func type filter, default: func extra: extra config to update in each run config dict Returns: list[dict]: All valid run config dicts """ config = get_config() device = config.get('device', 'cuda') base_case_list = get_model_list(config, backend, parallel_config, model_type, func_type) if extra is None: extra = {} run_configs = [] dtype = 'float16' if not is_bf16_supported(device) else None for communicator in _get_communicator_list(config, backend, parallel_config): for model in base_case_list: for quant_policy in [0, 4, 8]: # temp remove testcase because of issue 3434 if 'turbomind' == backend and communicator == 'cuda-ipc' and parallel_config.get( 'tp', 1) > 1 and ('InternVL3' in model or 'InternVL2_5' in model or 'MiniCPM-V-2_6' in model or 'InternVL2-Llama3' in model): # noqa continue if 'turbomind' == backend and parallel_config.get( 'tp', 1 ) > 1 and model_type == 'vl_model' and func_type == 'mllm_evaluate': # mllm eval with bug when tp > 2 continue # [TM][FATAL] models/llama/LlamaBatch.cc(362): Check failed: r->session.start_flag Mrope doesn't support interactive chat # noqa if ('Qwen2.5-VL' in model or 'Qwen2-VL' in model) and 'turbomind' == backend: continue # AssertionError: prompts should be a list if 'phi' in model.lower() and model_type == 'vl_model': continue if not _is_kvint_model(config, backend, model, quant_policy): continue run_config = { 'model': model, 'backend': backend, 'communicator': communicator, 'quant_policy': quant_policy, 'parallel_config': parallel_config, 'extra_params': copy.copy(extra) } if dtype and backend == 'pytorch': run_config['extra_params']['dtype'] = dtype if device != 'cuda': run_config['extra_params']['device'] = device run_configs.append(run_config) for run_config in run_configs: if 'Qwen3-235B-A22B-Thinking-2507' in run_config['model']: 
run_config['extra_params']['cache-max-entry-count'] = 0.9 run_config['extra_params']['max-batch-size'] = 1024 if config.get('env_tag', '') in ['3090', '5080']: run_config['extra_params']['cache-max-entry-count'] = 0.5 if config.get('env_tag', '') in ['a100'] and ('Qwen3-235B-A22B' in run_config['model'] or run_config['model'] == 'internlm/Intern-S1'): run_config['extra_params']['cache-max-entry-count'] = 0.6 if 'sdar' in run_config['model'].lower(): run_config['extra_params']['dllm-block-length'] = 4 run_config['extra_params']['dllm-denoising-steps'] = 4 run_config['extra_params']['dllm-confidence-threshold'] = 0.9 if 'kimi' in run_config['model'].lower(): para_conf = run_config.get('parallel_config', {}) if para_conf.get('dp', 0) == 16 and para_conf.get('ep', 0) == 16: run_config['extra_params']['max-batch-size'] = 256 if 'Intern-S1-Pro-FP8' in run_config['model'] or 'Intern-S1-Pro-BF16' in run_config['model']: if 'Intern-S1-Pro-FP8' in run_config['model']: run_config['extra_params']['model-format'] = 'fp8' para_conf = run_config.get('parallel_config', {}) # For dpep16 configuration, add max-prefill-token-num if para_conf.get('dp', 0) == 16 and para_conf.get('ep', 0) == 16: run_config['extra_params']['max-prefill-token-num'] = 1024 run_config['extra_params']['max-batch-size'] = 128 return run_configs def get_cli_common_param(run_config: dict[str, Any]) -> str: """Generate cli common params string by run config dict.""" backend = run_config.get('backend') model = run_config.get('model') communicator = run_config.get('communicator') quant_policy = run_config.get('quant_policy') extra_params = run_config.get('extra_params', {}) parallel_config = run_config.get('parallel_config', {}) cli_params = [f'--backend {backend}', f'--communicator {communicator}'] # Optional params if quant_policy != 0: cli_params.append(f'--quant-policy {quant_policy}') # quant format model_lower = model.lower() if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower: 
cli_params.append('--model-format awq') if 'gptq' in model_lower: cli_params.append('--model-format gptq') # Parallel config for para_key in ('dp', 'ep', 'cp'): if para_key in parallel_config and parallel_config[para_key] > 1: cli_params.append(f'--{para_key} {parallel_config[para_key]}') if 'tp' in parallel_config and parallel_config['tp'] > 1: tp_num = parallel_config['tp'] cli_params.append(f'--tp {tp_num}') # noqa # Extra params cli_params.append(get_cli_str(extra_params)) return ' '.join(cli_params).strip() def get_cli_str(config: dict[str, Any]) -> str: cli_str = [] # Extra params for key, value in config.items(): key = key.replace('_', '-') if value is None: cli_str.append(f'--{key}') elif isinstance(value, list): tmp_cli = ' '.join(map(str, value)) cli_str.append(f'--{key} {tmp_cli}') elif isinstance(value, dict): tmp_cli = ' '.join([f'{k}={v}' for k, v in value.items()]) cli_str.append(f'--{key} {tmp_cli}') else: cli_str.append(f'--{key} {value}' if value else f'--{key}') return ' '.join(cli_str) def get_parallel_config(config: dict[str, Any], model_name: str) -> list[dict[str, int]]: """Get matched parallel config dict by model name, default tp:1 if no match.""" result = [] base_model = _base_model_name(model_name) parallel_configs = config.get('config', {}) for conf_key, model_map in parallel_configs.items(): if model_map is None: continue if base_model in model_map: conf_value = model_map[base_model] if isinstance(conf_value, dict): result.append(conf_value.copy()) elif isinstance(conf_value, int): result.append({conf_key: conf_value}) return result if result else [{'tp': 1}] def _extract_models_from_config(config_value: Any) -> list[str]: """Extract flat model name list from config value (dict/list supported)""" models = [] if isinstance(config_value, dict): for model_list in config_value.values(): if isinstance(model_list, list): models.extend([m for m in model_list if isinstance(m, str)]) elif isinstance(config_value, list): models.extend([m for m in 
config_value if isinstance(m, str)]) return models def get_model_list(config: dict[str, Any], backend: str, parallel_config: dict[str, int] | None = None, model_type: str = 'chat_model', func_type: str = 'func') -> list[str]: """Get filtered model list with quantization extended models by backend/parallel config/model type/func type. Args: config: Global system config dict backend: Backend type (turbomind/pytorch) parallel_config: Parallel filter config model_type: Model type, default: chat_model func_type: Test func type filter, default: func Returns: list[str]: Base models + quantization extended models """ model_config_key = f'{backend}_{model_type}' all_models = [] if model_config_key in config: all_models = _extract_models_from_config(config[model_config_key]) all_models = _filter_by_test_func_type(config, all_models, func_type) all_models = list(OrderedDict.fromkeys(all_models)) # Deduplicate, keep order all_models = [model for model in all_models if is_model_in_list(config, parallel_config, model)] extended_models = list(all_models) quantization_config = config.get(f'{backend}_quantization', {}) # Append quantization models by backend if backend == 'turbomind': _extend_turbomind_quant_models(quantization_config, all_models, extended_models) elif backend == 'pytorch': _extend_pytorch_quant_models(quantization_config, all_models, extended_models) return extended_models def _filter_by_test_func_type(config: dict[str, Any], model_list: list[str], func_type: str) -> list[str]: """Filter model list by test function type, return intersection of two model sets.""" if func_type == 'func': return model_list filtered_models = [] model_config_key = f'{func_type}_model' if model_config_key in config: filtered_models = _extract_models_from_config(config[model_config_key]) return list(set(filtered_models) & set(model_list)) def _extend_turbomind_quant_models(quant_config: dict[str, Any], base_models: list[str], target_list: list[str]) -> None: """Append turbomind 
quantization models to target list (AWQ 4bits + GPTQ)""" no_awq_models = quant_config.get('no_awq', []) # Append AWQ 4bits quantization models for model_name in base_models: if model_name in target_list and model_name not in no_awq_models and not is_quantization_model(model_name): target_list.append(model_name + SUFFIX_INNER_AWQ) # Append GPTQ quantization models for model_name in quant_config.get('gptq', []): if model_name in target_list: target_list.append(model_name + SUFFIX_INNER_GPTQ) def _extend_pytorch_quant_models(quant_config: dict[str, Any], base_models: list[str], target_list: list[str]) -> None: """Append pytorch quantization models to target list (AWQ 4bits + W8A8)""" # Append AWQ quantization models for model_name in quant_config.get('awq', []): if model_name in target_list: target_list.append(model_name + SUFFIX_INNER_AWQ) # Append W8A8 quantization models for model_name in quant_config.get('w8a8', []): if model_name in target_list: target_list.append(model_name + SUFFIX_INNER_W8A8) def _is_kvint_model(config: dict[str, Any], backend: str, model: str, quant_policy: int) -> bool: """Check if model supports the kv quantization policy, quant_policy=0 always return True.""" if quant_policy == 0: return True no_kvint_black_list = config.get(f'{backend}_quantization', {}).get(f'no_kvint{quant_policy}', []) return _base_model_name(model) not in no_kvint_black_list def _base_model_name(model: str) -> str: """Simplify model name by removing quantization suffix for config matching.""" return model.replace('-inner-4bits', '').replace('-inner-w8a8', '').replace('-inner-gptq', '') def get_quantization_model_list(type: str) -> list[str]: """Get quantization model list by specified quant type(awq/gptq/w8a8)""" config = get_config() quant_model_list = [] if type == 'awq': # Get all turbomind chat/base models & deduplicate turbo_chat = _extract_models_from_config( config['turbomind_chat_model']) if 'turbomind_chat_model' in config else [] turbo_base = 
_extract_models_from_config( config['turbomind_base_model']) if 'turbomind_base_model' in config else [] all_turbo_models = list(OrderedDict.fromkeys(turbo_chat + turbo_base)) # Filter turbomind valid awq models no_awq = config.get('turbomind_quantization', {}).get('no_awq', []) quant_model_list = [m for m in all_turbo_models if m not in no_awq and not is_quantization_model(m)] # Append pytorch awq models torch_awq = config.get('pytorch_quantization', {}).get('awq', []) for model in torch_awq: if model not in quant_model_list: quant_model_list.append(model) elif type == 'gptq': quant_model_list = config.get('turbomind_quantization', {}).get(type, []) elif type == 'w8a8': quant_model_list = config.get('pytorch_quantization', {}).get(type, []) return quant_model_list def get_config() -> dict[str, Any]: """Load & get yaml config file, auto adapt device env & update log path.""" # Get device env & match config file path env_tag = os.environ.get('TEST_ENV') config_path = f'autotest/config_{env_tag}.yml' if env_tag else 'autotest/config.yml' # Fallback to default config if device-specific config not exist if env_tag and not os.path.exists(config_path): config_path = 'autotest/config.yml' # Load yaml config file safely with open(config_path, 'r', encoding='utf-8') as f: config = yaml.load(f.read(), Loader=yaml.SafeLoader) # Deep copy config to avoid modify raw data, update log path with github run id config_copy = copy.deepcopy(config) run_id = os.environ.get('RUN_ID', 'local_run') config_copy['log_path'] = os.path.join(config_copy['log_path'], str(run_id).replace('/', '_')) config_copy['eval_path'] = os.path.join(config_copy['eval_path'], str(run_id).replace('/', '_')) config_copy['mllm_eval_path'] = os.path.join(config_copy['mllm_eval_path'], str(run_id).replace('/', '_')) config_copy['benchmark_path'] = os.path.join(config_copy['benchmark_path'], str(run_id).replace('/', '_')) config_copy['server_log_path'] = os.path.join(config_copy['server_log_path'], 
str(run_id).replace('/', '_')) os.makedirs(config_copy['log_path'], exist_ok=True) os.makedirs(config_copy['eval_path'], exist_ok=True) os.makedirs(config_copy['mllm_eval_path'], exist_ok=True) os.makedirs(config_copy['benchmark_path'], exist_ok=True) os.makedirs(config_copy['server_log_path'], exist_ok=True) return config_copy def get_cuda_prefix_by_workerid(worker_id: str | None, parallel_config: dict[str, int] | None = None) -> str | None: """Get cuda/ascend visible devices env prefix by worker id & parallel config.""" para_conf = parallel_config or {} device_type = os.environ.get('DEVICE', 'cuda') tp_num = para_conf.get('tp') if not tp_num: return '' cuda_id = get_cuda_id_by_workerid(worker_id, tp_num) if not cuda_id: return '' return f'ASCEND_RT_VISIBLE_DEVICES={cuda_id}' if device_type == 'ascend' else f'CUDA_VISIBLE_DEVICES={cuda_id}' def get_cuda_id_by_workerid(worker_id: str | None, tp_num: int = 1) -> str | None: """Get cuda id str by worker id and tp num, return None if invalid worker id.""" if worker_id is None or 'gw' not in worker_id: return None base_id = int(worker_id.replace('gw', '')) cuda_num = base_id * tp_num return ','.join([str(cuda_num + i) for i in range(tp_num)]) def get_workerid(worker_id: str | None) -> int: """Parse numeric worker id from worker id str, return 0 if invalid worker id.""" if worker_id is None or 'gw' not in worker_id: return 0 return int(worker_id.replace('gw', '')) def is_quantization_model(model: str) -> bool: """Check if model name contains quantization related keywords.""" lower_name = model.lower() return any(key in lower_name for key in ('awq', '4bits', 'w4', 'int4')) def _get_communicator_list(config: dict[str, Any], backend: str, parallel_config: dict[str, int] | None = None) -> list[str]: """Get available communicator list by device and parallel config.""" device = config.get('device', None) if device == 'ascend': return ['nccl'] if backend == 'pytorch': return ['nccl'] if ('cp' in parallel_config or 'dp' in 
parallel_config or 'ep' in parallel_config): return ['nccl'] if 'tp' in parallel_config and parallel_config['tp'] == 1: return ['nccl'] return ['nccl', 'cuda-ipc'] def set_device_env_variable(worker_id: str | None, parallel_config: dict[str, int] | None = None) -> None: """Set device environment variable based on the device type.""" device = os.environ.get('DEVICE', 'cuda') tp_num = 1 if parallel_config is not None: if isinstance(parallel_config, int): tp_num = parallel_config elif isinstance(parallel_config, dict): tp_num = parallel_config.get('tp', 1) if device == 'ascend': device_id = get_cuda_id_by_workerid(worker_id, tp_num) if device_id is not None: os.environ['ASCEND_RT_VISIBLE_DEVICES'] = device_id else: cuda_id = get_cuda_id_by_workerid(worker_id, tp_num) if cuda_id is not None: os.environ['CUDA_VISIBLE_DEVICES'] = cuda_id def unset_device_env_variable(): device_type = os.environ.get('DEVICE', 'cuda') if device_type == 'ascend': if 'ASCEND_RT_VISIBLE_DEVICES' in os.environ: del os.environ['ASCEND_RT_VISIBLE_DEVICES'] else: if 'CUDA_VISIBLE_DEVICES' in os.environ: del os.environ['CUDA_VISIBLE_DEVICES'] def is_model_in_list(config: dict[str, Any], parallel_config: dict[str, int], model: str) -> bool: """Check if model matches the target parallel config.""" model_config = get_parallel_config(config, model) return parallel_config in model_config def get_case_str_by_config(run_config: dict[str, Any], is_simple: bool = True) -> str: """Generate case name string by run config dict.""" model_name = run_config['model'] backend_type = run_config['backend'] communicator = run_config.get('communicator', 'nccl') quant_policy = run_config.get('quant_policy', 0) parallel_config = run_config.get('parallel_config', {'tp': 1}) extra_params = run_config.get('extra_params', {}) # Sorted parallel config to fixed string format sorted_items = sorted(parallel_config.items()) parallel_str = '_'.join(f'{k}{v}' for k, v in sorted_items) # Get last section of model name, compatible 
with model name contains '/' pure_model_name = model_name.split('/')[-1].replace('_', '-') extra_params_case = '' if not is_simple: for k, v in extra_params.items(): if len(v) > 10: extra_params_case += f'_{k}'.replace('_', '-').replace('/', '-').replace('.', '-') else: extra_params_case += f'_{k}{v}'.replace('_', '-').replace('/', '-').replace('.', '-') return f'{backend_type}_{pure_model_name}_{communicator}_{parallel_str}_{quant_policy}{extra_params_case}' def parse_config_by_case(case_str: str) -> dict[str, Any]: """Parse run config dict from case name string (fix split & type convert bug)""" case_parts = case_str.split('_') # Parse fixed field & reassemble dynamic parallel config backend = case_parts[0] model = case_parts[1] communicator = case_parts[2] quant_policy = int(case_parts[-1]) parallel_parts = case_parts[3:-1] # Convert parallel str to dict, e.g: ['tp1','pp2'] -> {'tp':1, 'pp':2} parallel_config = {} for part in parallel_parts: for idx, char in enumerate(part): if char.isdigit(): k = part[:idx] v = int(part[idx:]) parallel_config[k] = v break return { 'backend': backend, 'model': model, 'communicator': communicator, 'parallel_config': parallel_config, 'quant_policy': quant_policy } def test_config(): os.environ['DEVICE'] = 'test' config = get_config() assert 'model_path' in config.keys() assert 'resource_path' in config.keys() assert 'log_path' in config.keys() assert 'server_log_path' in config.keys() assert 'eval_path' in config.keys() assert 'mllm_eval_path' in config.keys() assert 'benchmark_path' in config.keys() assert 'dataset_path' in config.keys() assert 'prefix_dataset_path' in config.keys() assert 'env_tag' in config.keys() assert 'config' in config.keys() assert 'tp' in config.get('config') assert is_model_in_list(config, parallel_config={'tp': 1}, model='test/test_tp1') assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp1') is False assert is_model_in_list(config, parallel_config={'ep': 1}, 
model='test/test_tp1') is False, is_model_in_list(config, parallel_config={'ep': 1}, model='test/test_tp1') assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp2-inner-4bits') assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp2-inner-w8a8') assert is_model_in_list(config, parallel_config={'tp': 8}, model='test/test_tp8-inner-gptq') assert is_model_in_list(config, parallel_config={'tp': 8}, model='test/test_cp2tp8') is False assert is_model_in_list(config, parallel_config={'tp': 8, 'cp': 2}, model='test/test_cp2tp8') assert is_model_in_list(config, parallel_config={'cp': 2, 'tp': 8}, model='test/test_cp2tp8') assert is_model_in_list(config, parallel_config={'cp': 4, 'tp': 8}, model='test/test_cp2tp8') is False assert is_model_in_list(config, parallel_config={'dp': 8, 'ep': 8}, model='test/test_dpep8') assert is_model_in_list(config, parallel_config={'dp': 4, 'ep': 8}, model='test/test_dpep8') is False assert is_model_in_list(config, parallel_config={'ep': 4, 'dp': 8}, model='test/test_dpep8') is False assert _is_kvint_model(config, 'turbomind', 'test/test_tp1-inner-4bits', 8) is False assert _is_kvint_model(config, 'turbomind', 'test/test_tp1-inner-4bits', 4) assert _is_kvint_model(config, 'turbomind', 'any', 0) assert _is_kvint_model(config, 'pytorch', 'test/test_tp1-inner-gptq', 8) is False assert _is_kvint_model(config, 'pytorch', 'test/test_tp1-inner-gptq', 4) assert _is_kvint_model(config, 'pytorch', 'test/test_vl_tp1-inner-gptq', 8) is False assert _is_kvint_model(config, 'pytorch', 'test/test_cp2tp8-inner-w8a8', 4) is False os.unsetenv('DEVICE') def test_get_case_str_by_config(): run_config = { 'model': 'test/test_dpep16', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'dp': 16, 'ep': 16 } } case_str = get_case_str_by_config(run_config) assert case_str == 'turbomind_test-dpep16_nccl_dp16_ep16_8', case_str run_config_parsed = parse_config_by_case(case_str) assert 
run_config_parsed['model'] == 'test-dpep16' assert run_config_parsed['backend'] == 'turbomind' assert run_config_parsed['communicator'] == 'nccl' assert run_config_parsed['quant_policy'] == 8 assert run_config_parsed['parallel_config']['dp'] == 16 assert run_config_parsed['parallel_config']['ep'] == 16 def test_cli_common_param(): run_config = { 'model': 'test/test_dpep16-inner-4bits', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 8, 'parallel_config': { 'dp': 16, 'ep': 16 }, 'extra_params': { 'dtype': 'bfloat16', 'device': 'ascend', 'enable_prefix_caching': None, 'max_batch_size': 2048, 'session_len': 8192, 'cache_max_entry_count': 0.75, 'adapters': { 'a': 'lora/2024-01-25_self_dup', 'b': 'lora/2024-01-25_self' } } } cli_params = get_cli_common_param(run_config) assert cli_params == '--backend turbomind --communicator nccl --quant-policy 8 --model-format awq --dp 16 --ep 16 --dtype bfloat16 --device ascend --enable-prefix-caching --max-batch-size 2048 --session-len 8192 --cache-max-entry-count 0.75 --adapters a=lora/2024-01-25_self_dup b=lora/2024-01-25_self', cli_params # noqa run_config = { 'model': 'test/test_dpep16-inner-4bits', 'backend': 'pytorch', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 8 } } cli_params = get_cli_common_param(run_config) assert cli_params == '--backend pytorch --communicator nccl --model-format awq --tp 8', cli_params os.unsetenv('TEST_ENV') def test_return_info_turbomind(): os.environ['TEST_ENV'] = 'test' backend = 'turbomind' func_chat_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func') assert len(func_chat_tp1) == 12, len(func_chat_tp1) func_chat_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='func') assert len(func_chat_tp2) == 32, len(func_chat_tp2) func_chat_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='func') assert 
len(func_chat_tp8) == 36, len(func_chat_tp8) func_chat_cptp = get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='func') assert len(func_chat_cptp) == 14, len(func_chat_cptp) func_chat_dpep8 = get_func_config_list(backend, parallel_config={ 'dp': 8, 'ep': 8 }, model_type='chat_model', func_type='func') assert len(func_chat_dpep8) == 6, len(func_chat_dpep8) func_chat_dpep16 = get_func_config_list(backend, parallel_config={ 'dp': 16, 'ep': 16 }, model_type='chat_model', func_type='func') assert len(func_chat_dpep16) == 0, len(func_chat_dpep16) func_base_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='base_model', func_type='func') assert len(func_base_tp1) == 6, len(func_base_tp1) func_base_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='base_model', func_type='func') assert len(func_base_tp2) == 4, len(func_base_tp2) evaluate_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='evaluate') assert len(evaluate_tp1) == 6, len(evaluate_tp1) benchmark_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='benchmark') assert len(benchmark_tp2) == 4, len(benchmark_tp2) longtext_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='longtext') assert len(longtext_tp8) == 12, len(longtext_tp8) evaluate_cptp = get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='evaluate') assert len(evaluate_cptp) == 4, len(evaluate_cptp) benchmark_dpep8 = get_func_config_list(backend, parallel_config={ 'dp': 8, 'ep': 8 }, model_type='chat_model', func_type='benchmark') assert len(benchmark_dpep8) == 0, len(benchmark_dpep8) mllm_benchmark_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='mllm_benchmark') assert len(mllm_benchmark_tp1) == 6, 
len(mllm_benchmark_tp1) mllm_longtext_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='mllm_longtext') assert len(mllm_longtext_tp2) == 0, len(mllm_longtext_tp2) mllm_evaluate_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='mllm_evaluate') assert len(mllm_evaluate_tp8) == 12, len(mllm_evaluate_tp8) mllm_evaluate_dpep16 = get_func_config_list(backend, parallel_config={ 'dp': 16, 'ep': 16 }, model_type='chat_model', func_type='evaluate') assert len(mllm_evaluate_dpep16) == 0, len(mllm_evaluate_dpep16) mllm_benchmark_cptp = get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='benchmark') assert len(mllm_benchmark_cptp) == 4, len(mllm_benchmark_cptp) os.unsetenv('TEST_ENV') def test_return_info_pytorch(): os.environ['TEST_ENV'] = 'test' backend = 'pytorch' func_chat_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func') assert len(func_chat_tp1) == 12, len(func_chat_tp1) func_chat_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='func') assert len(func_chat_tp2) == 19, len(func_chat_tp2) func_chat_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='func') assert len(func_chat_tp8) == 9, len(func_chat_tp8) func_chat_cptp = get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='func') assert len(func_chat_cptp) == 7, len(func_chat_cptp) func_chat_dpep8 = get_func_config_list(backend, parallel_config={ 'dp': 8, 'ep': 8 }, model_type='chat_model', func_type='func') assert len(func_chat_dpep8) == 8, len(func_chat_dpep8) func_chat_dpep16 = get_func_config_list(backend, parallel_config={ 'dp': 16, 'ep': 16 }, model_type='chat_model', func_type='func') assert len(func_chat_dpep16) == 6, len(func_chat_dpep16) func_base_tp1 = 
get_func_config_list(backend, parallel_config={'tp': 1}, model_type='base_model', func_type='func') assert len(func_base_tp1) == 7, len(func_base_tp1) func_base_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='base_model', func_type='func') assert len(func_base_tp2) == 4, len(func_base_tp2) evaluate_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='evaluate') assert len(evaluate_tp1) == 7, len(evaluate_tp1) benchmark_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='benchmark') assert len(benchmark_tp2) == 3, len(benchmark_tp2) longtext_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='longtext') assert len(longtext_tp8) == 3, len(longtext_tp8) evaluate_cptp = get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='evaluate') assert len(evaluate_cptp) == 2, len(evaluate_cptp) benchmark_dpep8 = get_func_config_list(backend, parallel_config={ 'dp': 8, 'ep': 8 }, model_type='chat_model', func_type='benchmark') assert len(benchmark_dpep8) == 2, len(benchmark_dpep8) mllm_benchmark_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='mllm_benchmark') assert len(mllm_benchmark_tp1) == 5, len(mllm_benchmark_tp1) mllm_longtext_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='mllm_longtext') assert len(mllm_longtext_tp2) == 0, len(mllm_longtext_tp2) mllm_evaluate_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='mllm_evaluate') assert len(mllm_evaluate_tp8) == 3, len(mllm_evaluate_tp8) mllm_evaluate_dpep16 = get_func_config_list(backend, parallel_config={ 'dp': 16, 'ep': 16 }, model_type='chat_model', func_type='evaluate') assert len(mllm_evaluate_dpep16) == 3, len(mllm_evaluate_dpep16) mllm_benchmark_cptp = 
get_func_config_list(backend, parallel_config={ 'cp': 2, 'tp': 8 }, model_type='chat_model', func_type='benchmark') assert len(mllm_benchmark_cptp) == 2, len(mllm_benchmark_cptp) os.unsetenv('TEST_ENV') def test_run_config(): os.environ['TEST_ENV'] = 'test' backend = 'turbomind' run_config1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')[0] assert run_config1['model'] == 'test/test_tp1' assert run_config1['backend'] == 'turbomind' assert run_config1['communicator'] == 'nccl' assert run_config1['quant_policy'] == 0 assert run_config1['parallel_config'] == {'tp': 1} os.environ['TEST_ENV'] = 'testascend' backend = 'pytorch' run_config2 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')[0] assert run_config2['model'] == 'test/test_tp1' assert run_config2['backend'] == 'pytorch' assert run_config2['communicator'] == 'nccl' assert run_config2['quant_policy'] == 0 assert run_config2['parallel_config'] == {'tp': 1} run_config3 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func', extra={ 'speculative_algorithm': 'eagle', 'session_len': 1024 })[0] assert run_config3['model'] == 'test/test_tp1' assert run_config3['backend'] == 'pytorch' assert run_config3['communicator'] == 'nccl' assert run_config3['quant_policy'] == 0 assert run_config3['parallel_config'] == {'tp': 1} assert run_config3['extra_params']['speculative_algorithm'] == 'eagle' assert run_config3['extra_params']['session_len'] == 1024 os.unsetenv('TEST_ENV') def test_get_parallel_config(): test = get_parallel_config({}, 'empty') assert test == [{'tp': 1}] test = get_parallel_config( { 'config': { 'tp': { 'empty': 1 }, 'dp_ep': { 'empty': { 'dp': 1, 'ep': 8 } }, 'cp_tp': { 'empty': { 'cp': 8, 'tp': 8 } } } }, 'empty') assert test == [{'tp': 1}, {'dp': 1, 'ep': 8}, {'cp': 8, 'tp': 8}] if __name__ == '__main__': test_get_parallel_config() test_cli_common_param() 
test_run_config() test_get_case_str_by_config() test_return_info_pytorch() test_config() test_return_info_turbomind() ================================================ FILE: autotest/utils/constant.py ================================================ import os DEFAULT_PORT = 23333 DEFAULT_SERVER = os.getenv('MASTER_ADDR', '127.0.0.1') PROXY_PORT = 8000 EVAL_CONFIGS = { 'default': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.6, }, 'default-32k': { 'query_per_second': 4, 'max_out_len': 32768, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.6, }, 'default-2batch': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 2, 'temperature': 0.6, }, 'gpt': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.6, 'openai_extra_kwargs': { 'reasoning_effort': 'high', } }, 'gpt-32k': { 'query_per_second': 4, 'max_out_len': 32768, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.6, 'openai_extra_kwargs': { 'reasoning_effort': 'high', } }, 'gpt-2batch': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 2, 'temperature': 0.6, 'openai_extra_kwargs': { 'reasoning_effort': 'high', } }, 'sdar': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 1.0, 'openai_extra_kwargs': { 'top_p': 1.0, }, 'extra_body': { 'top_k': 0, } }, 'sdar-32k': { 'query_per_second': 4, 'max_out_len': 32768, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 1.0, 'openai_extra_kwargs': { 'top_p': 1.0, }, 'extra_body': { 'top_k': 0, } }, 'sdar-2batch': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 2, 'temperature': 1.0, 'openai_extra_kwargs': { 'top_p': 1.0, }, 'extra_body': { 'top_k': 0, } }, 'intern-s1-pro': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.8, 
'openai_extra_kwargs': { 'top_p': 0.95, }, 'extra_body': { 'top_k': 50, 'min_p': 0.0, } }, 'intern-s1-pro-32k': { 'query_per_second': 4, 'max_out_len': 32768, 'max_seq_len': 65536, 'batch_size': 500, 'temperature': 0.8, 'openai_extra_kwargs': { 'top_p': 0.95, }, 'extra_body': { 'top_k': 50, 'min_p': 0.0, } }, 'intern-s1-pro-2batch': { 'query_per_second': 4, 'max_out_len': 64000, 'max_seq_len': 65536, 'batch_size': 2, 'temperature': 0.8, 'openai_extra_kwargs': { 'top_p': 0.95, }, 'extra_body': { 'top_k': 50, 'min_p': 0.0, } } } MLLM_EVAL_CONFIGS = { 'default': {}, 'internvl': { 'repetition-penalty': 1.0, 'top-p': 0.8, 'top-k': 20, 'temperature': 0.7, } } BACKEND_LIST = ['turbomind', 'pytorch'] RESTFUL_MODEL_LIST = [ 'Qwen/Qwen3-0.6B', 'Qwen/Qwen3-VL-2B-Instruct', 'Qwen/Qwen3-30B-A3B', 'internlm/Intern-S1', 'internlm/internlm2_5-20b', 'Qwen/Qwen3-32B', 'OpenGVLab/InternVL3_5-30B-A3B', 'OpenGVLab/InternVL3-38B', 'Qwen/Qwen3-VL-8B-Instruct', 'internlm/internlm3-8b-instruct', 'meta-llama/Llama-3.2-3B-Instruct', 'Qwen/Qwen3-VL-30B-A3B-Instruct' ] RESTFUL_BASE_MODEL_LIST = [ 'Qwen/Qwen3-8B-Base', 'internlm/internlm2_5-20b', 'Qwen/Qwen3-4B', 'internlm/internlm3-8b-instruct' ] SUFFIX_INNER_AWQ = '-inner-4bits' SUFFIX_INNER_GPTQ = '-inner-gptq' SUFFIX_INNER_W8A8 = '-inner-w8a8' EVAL_RUN_CONFIG = { 'model': 'Qwen/Qwen2.5-32B-Instruct', 'backend': 'turbomind', 'communicator': 'nccl', 'quant_policy': 0, 'parallel_config': { 'tp': 2 }, 'extra_params': { 'server-name': DEFAULT_SERVER, 'session-len': 76000, 'cache-max-entry-count': 0.7 } } ================================================ FILE: autotest/utils/evaluate_utils.py ================================================ import csv import glob import json import os import subprocess import time import allure import pandas as pd from mmengine.config import Config from utils.common_utils import execute_command_with_logging from utils.config_utils import get_case_str_by_config, get_cli_str, parse_config_by_case from utils.constant 
import DEFAULT_PORT, DEFAULT_SERVER, EVAL_RUN_CONFIG def write_to_summary(case_name, result, msg, metrics, result_dir): status = '✅ PASS' if result else f'❌ FAIL {msg}' config = parse_config_by_case(case_name) backend = config['backend'] model = config['model'] communicator = config['communicator'] parallel_config_str = config['parallel_config'] quant_policy = config['quant_policy'] dataset_name = [] dataset_metrics = [] for key in sorted(metrics.keys()): dataset_name.append(key) dataset_metrics.append(metrics.get(key, '')) summary_dataset_name = ' | '.join(dataset_name) summary_dataset_metrics = ' | '.join(dataset_metrics) summary_file = os.environ.get('GITHUB_STEP_SUMMARY', '') md_summary_file = f'{result_dir}/summary_{case_name}.md' summary_line = f'| {model} | {quant_policy} | {backend} | {communicator} | {parallel_config_str} | {status} | {summary_dataset_metrics} |\n' # noqa: E501 write_header = not os.path.exists(md_summary_file) or os.path.getsize(md_summary_file) == 0 with open(md_summary_file, 'a') as f: if write_header: dash_line = '-----|' * (len(metrics.keys())) f.write('## Model Evaluation Results\n') f.write( f'| Model | QuantPolicy | Backend | Communicator | Parallel config | Status | {summary_dataset_name} |\n' # noqa ) f.write(f'|-------|-------------|---------|--------------|----|--------|{dash_line}\n') f.write(summary_line) if summary_file: write_header = not os.path.exists(summary_file) or os.path.getsize(summary_file) == 0 with open(summary_file, 'a') as f: if write_header: dash_line = '-----|' * (len(metrics.keys())) f.write('## Model Evaluation Results\n') f.write( f'| Model | QuantPolicy | Backend | Communicator | Parallel config | Status | {summary_dataset_name} |\n' # noqa ) f.write(f'|-------|-------------|---------|--------------|----|--------|{dash_line}\n') f.write(summary_line) else: print( f'Summary: {model} | {backend} | {communicator} | {parallel_config_str} | {status} | {summary_dataset_metrics}' # noqa: E501 ) def 
llm_summary(case_name, result, msg, work_dir, result_dir=None): metrics = {} if work_dir and os.path.exists(work_dir): try: summary_dirs = glob.glob(os.path.join(work_dir, '*', 'summary')) if not summary_dirs: raise FileNotFoundError('No summary directory found') summary_dir = summary_dirs[0] csv_files = glob.glob(os.path.join(summary_dir, 'summary_*.csv')) if not csv_files: raise FileNotFoundError('No CSV files found') csv_file = sorted(csv_files)[-1] if not os.path.exists(csv_file): raise FileNotFoundError('CSV file does not exist') with open(csv_file, 'r') as f: reader = csv.reader(f) next(reader) for row in reader: if len(row) < 5 or not row[4]: continue dataset = row[0] metric_value = row[4] try: metrics[dataset] = f'{float(metric_value):.2f}' # noqa: E231 except ValueError: metrics[dataset] = metric_value except Exception as e: print(f'Error reading metrics: {str(e)}') if not result_dir: result_dir = work_dir write_to_summary(case_name, result, msg, metrics, result_dir) def mllm_summary(case_name, result, msg, work_dir, result_dir=None, dataset_list=['MMBench_V11_MINI', 'MMStar_MINI', 'AI2D_MINI', 'OCRBench_MINI']): metrics = {} pattern = os.path.join(work_dir, case_name, 'T*') t_dirs = [d for d in glob.glob(pattern) if os.path.isdir(d)] if not t_dirs: return # 按修改时间排序 t_dirs.sort(key=os.path.getmtime, reverse=True) latest_dir = t_dirs[0] for dataset in dataset_list: if dataset == 'OCRBench_MINI': score_file = f'{latest_dir}/{case_name}_{dataset}_score.json' cur_score = 0 with open(score_file, 'r') as f: total_score = json.load(f) cur_score = total_score['Final Score Norm'] metrics[dataset] = f'{cur_score:.2f}' # noqa: E231 else: score_file = f'{latest_dir}/{case_name}_{dataset}_acc.csv' df = pd.read_csv(score_file) cur_score = df['Overall'].iloc[0] if dataset == 'MMBench_V11_MINI': cur_score = df.loc[df['split'] == 'dev', 'Overall'].values cur_score = cur_score * 100 metrics[dataset] = f'{cur_score.item():.2f}' # noqa: E231 if result_dir is None: result_dir 
= work_dir write_to_summary(case_name, result, msg, metrics, result_dir) def eval_test(model_path, eval_path, case_name, port=DEFAULT_PORT, test_type='infer', extra_config={}, **kwargs): work_dir = None try: work_dir = os.path.join(eval_path, f'wk_{case_name}') timestamp = time.strftime('%Y%m%d_%H%M%S') eval_log = os.path.join(eval_path, f'log_{case_name}_{test_type}_{timestamp}.log') temp_config_path = os.path.join(eval_path, f'temp_{case_name}.py') current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(current_dir) config_file = os.path.join(parent_dir, 'evaluate/eval_config_chat.py') print(f'Starting OpenCompass evaluation for model: {model_path}') print(f'Model path: {model_path}') print(f'Case: {case_name}') print(f'Config file: {config_file}') original_cwd = os.getcwd() os.makedirs(work_dir, exist_ok=True) test_url = f'http://{DEFAULT_SERVER}:{port}/v1' try: if test_type == 'infer': if not os.path.exists(config_file): return False, f'Config file {config_file} not found' cfg = Config.fromfile(config_file) cfg.MODEL_NAME = case_name cfg.MODEL_PATH = model_path cfg.API_BASE = test_url # noqa: E231 if cfg.models and len(cfg.models) > 0: model_cfg = cfg.models[0] model_cfg['abbr'] = case_name model_cfg['path'] = case_name model_cfg['openai_api_base'] = test_url model_cfg['tokenizer_path'] = model_path for key, value in kwargs.items(): model_cfg[key] = value cfg.NUM_WORKERS = extra_config.get('max-num-workers', 8) cfg.infer['partitioner']['num_worker'] = extra_config.get('max-num-workers', 8) cfg.dump(temp_config_path) print(f'Modified config saved to: {temp_config_path}') elif test_type == 'eval': if not os.path.exists(temp_config_path): error_msg = f'Temp config file {temp_config_path} not found for eval stage' llm_summary(case_name, False, error_msg, work_dir, eval_path) return False, error_msg cfg = Config.fromfile(temp_config_path) print(f'Using existing temp config file: {temp_config_path}') eval_run_config = EVAL_RUN_CONFIG 
eval_case_name = get_case_str_by_config(eval_run_config) cfg.JUDGE_API_BASE = test_url cfg.JUDGE_MODEL_PATH = model_path cfg.JUDGE_MODEL_NAME = eval_case_name if hasattr(cfg, 'judge_cfg'): cfg.judge_cfg['path'] = eval_case_name cfg.judge_cfg['abbr'] = eval_case_name cfg.judge_cfg['openai_api_base'] = test_url cfg.judge_cfg['tokenizer_path'] = model_path if hasattr(cfg, 'datasets') and cfg.datasets: for dataset in cfg.datasets: if 'eval_cfg' in dataset and 'evaluator' in dataset['eval_cfg']: evaluator = dataset['eval_cfg']['evaluator'] if 'judge_cfg' in evaluator: evaluator['judge_cfg']['abbr'] = cfg.JUDGE_MODEL_NAME evaluator['judge_cfg']['path'] = cfg.JUDGE_MODEL_NAME evaluator['judge_cfg']['openai_api_base'] = cfg.JUDGE_API_BASE evaluator['judge_cfg']['tokenizer_path'] = cfg.JUDGE_MODEL_PATH if 'llm_evaluator' in evaluator and 'judge_cfg' in evaluator['llm_evaluator']: evaluator['llm_evaluator']['judge_cfg']['abbr'] = cfg.JUDGE_MODEL_NAME evaluator['llm_evaluator']['judge_cfg']['path'] = cfg.JUDGE_MODEL_NAME evaluator['llm_evaluator']['judge_cfg']['openai_api_base'] = cfg.JUDGE_API_BASE evaluator['llm_evaluator']['judge_cfg']['tokenizer_path'] = cfg.JUDGE_MODEL_PATH cfg.dump(temp_config_path) print(f'Modified config for eval stage saved to: {temp_config_path}') extra_config_str = get_cli_str(extra_config) cmd = f'opencompass {temp_config_path} --reuse -w {work_dir} -m {test_type} --dump-res-length {extra_config_str}' # noqa print(f'Running command: {cmd}') print(f'Work directory: {work_dir}') result, stderr = execute_command_with_logging(cmd, eval_log, timeout=259200) allure.attach.file(eval_log, name=eval_log, attachment_type=allure.attachment_type.TEXT) if test_type == 'eval': llm_summary(case_name, result, stderr, work_dir, eval_path) return result, stderr except Exception as e: print(f'Error occurred: {e}') return False, f'Error occurred: {e}' finally: os.chdir(original_cwd) print(f'Returned to directory: {original_cwd}') except subprocess.TimeoutExpired: 
timeout_msg = (f'Evaluation timed out for {model_path} ' f'after 259200 seconds') if work_dir and test_type == 'eval': llm_summary(case_name, False, timeout_msg, work_dir, eval_path) return False, timeout_msg except Exception as e: error_msg = f'Error during evaluation for {model_path}: {str(e)}' if work_dir and test_type == 'eval': llm_summary(case_name, False, error_msg, work_dir, eval_path) return False, error_msg def mllm_eval_test(model_path, eval_path, case_name, port=DEFAULT_PORT, test_type='infer', extra_config={}): work_dir = os.path.join(eval_path, f'wk_{case_name}') timestamp = time.strftime('%Y%m%d_%H%M%S') eval_log = os.path.join(eval_path, f'log_{case_name}_{timestamp}.log') print(f'Starting VLMEvalKit evaluation for model: {model_path}') print(f'Model path: {model_path}') print(f'Case: {case_name}') print(f'Work directory: {work_dir}') os.makedirs(work_dir, exist_ok=True) extra_config_str = get_cli_str(extra_config) if test_type == 'infer': cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:{port}/v1 --reuse --work-dir {work_dir} --mode infer {extra_config_str}' # noqa elif test_type == 'eval': cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:empty/v1 --reuse --work-dir {work_dir} --api-nproc 32 --mode eval --judge turbomind_Qwen2.5-32B-Instruct_nccl_tp2_0 --judge-base-url http://{DEFAULT_SERVER}:{port}/v1' # noqa result, msg = execute_command_with_logging(cmd, eval_log) allure.attach.file(eval_log, name=eval_log, attachment_type=allure.attachment_type.TEXT) if test_type == 'eval': mllm_summary(case_name, result, msg, work_dir, eval_path, dataset_list=['MMBench_V11_MINI', 'MMStar_MINI', 'AI2D_MINI', 'OCRBench_MINI']) return result, msg ================================================ FILE: autotest/utils/get_run_config.py ================================================ from 
lmdeploy.model import MODELS # Deprecated function def get_model_name(model): model_names = ['llama', 'llama2', 'llama3', 'internlm', 'internlm2', 'baichuan2', 'chatglm2', 'yi', 'qwen'] model_names += list(MODELS.module_dict.keys()) model_names.sort() model_name = _simple_model_name(model) model_name = model_name.lower() if model_name in model_names: return model_name if model_name in model_names: return model_name if ('llama-2' in model_name): return 'llama2' if ('llama-3-1' in model_name): return 'llama3_1' if ('llama-3' in model_name): return 'llama3' if 'vicuna' in model_name and 'llava' not in model_name: return 'vicuna' if 'llava' in model_name and 'v1' in model_name and 'v1.6-34b' not in model_name and 'mistral' not in model_name: return 'llava-v1' if 'llava' in model_name and 'v1.6-34b' in model_name: return 'llava-chatml' if 'internvl-chat' in model_name and 'v1-2' in model_name: return 'internvl-zh-hermes2' elif 'llava-1.5' in model_name: return 'llava-v1' if ('yi-vl' in model_name): return 'yi-vl' if ('qwen' in model_name): return 'qwen' if ('internvl') in model_name: return 'internvl-internlm2' if ('internlm2') in model_name: return 'internlm2' if ('internlm-xcomposer2d5') in model_name: return 'internlm-xcomposer2d5' if ('internlm-xcomposer2') in model_name: return 'internlm-xcomposer2' if ('glm-4') in model_name: return 'glm4' if len(model_name.split('-')) > 2 and '-'.join(model_name.split('-')[0:2]) in model_names: return '-'.join(model_name.split('-')[0:2]) return model_name.split('-')[0] def _simple_model_name(model): if '/' in model: model_name = model.split('/')[1] else: model_name = model model_name = model_name.replace('-inner-4bits', '') model_name = model_name.replace('-inner-w8a8', '') model_name = model_name.replace('-4bits', '') return model_name ================================================ FILE: autotest/utils/mp_log_utils.py ================================================ import os import allure from pytest_assume.plugin import 
assume def write_log(config, result, msg, is_new: bool = True, case_path_tag: str = 'default'): try: log_path = os.path.join(config.get('log_path'), case_path_tag) if is_new: file = open(log_path, 'w') else: file = open(log_path, 'a') file.writelines('result:' + result + ', reason:' + msg + '\n') file.close() except Exception as e: return False, None, f'Unknown error: {e}' def assert_log(config, case_path_tag: str = 'default'): log_path = os.path.join(config.get('log_path'), case_path_tag) with open(log_path, 'r') as f: lines = f.readlines() for line in lines: if 'result:False, reason:' in line: result = False msg = line break if 'result:True, reason:' in line and not result: result = True allure.attach.file(log_path, name=log_path, attachment_type=allure.attachment_type.TEXT) with assume: assert result, msg ================================================ FILE: autotest/utils/pipeline_chat.py ================================================ import json import os import shutil import time import allure from pytest_assume.plugin import assume from utils.common_utils import execute_command_with_logging from utils.config_utils import get_case_str_by_config, get_cuda_prefix_by_workerid, get_workerid, resolve_extra_params from utils.rule_condition_assert import assert_result def run_pipeline_llm_test(config, run_config, common_case_config, worker_id: str = '', is_smoke: bool = False): model = run_config.get('model') if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True': model_path = model else: model_path = os.path.join(config.get('model_path'), model) log_path = config.get('log_path') case_name = get_case_str_by_config(run_config) timestamp = time.strftime('%Y%m%d_%H%M%S') pipeline_log = os.path.join(log_path, f'pipeline_llm_{case_name}_{timestamp}.log') env = os.environ.copy() env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500) env.update(run_config.get('env', {})) run_config_bk = run_config.copy() run_config_bk.pop('env', None) 
run_config_bk.pop('model', None) resolve_extra_params(run_config_bk.get('extra_params', {}), config.get('model_path')) run_config_string = json.dumps(run_config_bk, ensure_ascii=False, indent=None) run_config_string = run_config_string.replace(' ', '').replace('"', '\\"').replace(',', '\\,') cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) cmd = f'{cuda_prefix} python3 autotest/tools/pipeline/llm_case.py run_pipeline_chat_test {model_path} {run_config_string} autotest/prompt_case.yml {is_smoke}' # noqa E501 result, stderr = execute_command_with_logging(cmd, pipeline_log, timeout=1800, env=env) with assume: assert result, stderr with open(pipeline_log, 'r', encoding='utf-8') as file: output_text = file.read() with open(pipeline_log, 'a') as file: for case in common_case_config.keys(): if is_smoke and case != 'memory_test': continue if case != 'code_testcase' and 'code' in model_path.lower(): continue with allure.step(case): case_info = common_case_config.get(case) case_result = True reason = '' for prompt_detail in case_info: prompt = list(prompt_detail.keys())[0] case_result, reason = assert_result(get_response_from_output_by_prompt(output_text, case, prompt), prompt_detail.values(), model_path) if not case_result: print(f'{case} result: {case_result}, reason: {reason} \n') file.writelines(f'{case} result: {case_result}, reason: {reason} \n') with assume: assert case_result, reason allure.attach.file(pipeline_log, name=pipeline_log, attachment_type=allure.attachment_type.TEXT) def run_pipeline_mllm_test(config, run_config, worker_id: str = '', is_smoke: bool = False): model = run_config.get('model') if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True': model_path = model else: model_path = os.path.join(config.get('model_path'), model) log_path = config.get('log_path') case_name = get_case_str_by_config(run_config) timestamp = time.strftime('%Y%m%d_%H%M%S') pipeline_log = os.path.join(log_path, 
f'pipeline_mllm_{case_name}_{timestamp}.log') env = os.environ.copy() env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500) env.update(run_config.get('env', {})) run_config_bk = run_config.copy() run_config_bk.pop('env', None) run_config_bk.pop('model', None) run_config_string = json.dumps(run_config_bk, ensure_ascii=False, indent=None) run_config_string = run_config_string.replace(' ', '').replace('"', '\\"').replace(',', '\\,') cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) resource_path = config.get('resource_path') cmd = f'{cuda_prefix} python3 autotest/tools/pipeline/mllm_case.py run_pipeline_mllm_test {model_path} {run_config_string} {resource_path} {is_smoke}' # noqa E501 result, stderr = execute_command_with_logging(cmd, pipeline_log, timeout=1800, env=env, should_print=False) with assume: assert result, stderr with open(pipeline_log, 'r', encoding='utf-8') as file: output_text = file.read() with open(pipeline_log, 'a') as file: with allure.step('single1 pic'): response = get_response_from_output(output_text, 'single1') case_result = any(word in response.lower() for word in ['tiger', '虎']) file.writelines(f'single1 pic result: {case_result} reason: simple example tiger should in {response} \n') with assume: assert case_result, f'reason: simple example tiger should in {response}' with allure.step('single2 pic'): response = get_response_from_output(output_text, 'single2') case_result = any(word in response.lower() for word in ['tiger', '虎']) file.writelines(f'single2 pic result: {case_result} reason: simple example tiger should in {response} \n') with assume: assert case_result, f'reason: simple example tiger should in {response}' with allure.step('multi-imagese'): response = get_response_from_output(output_text, 'multi-imagese') case_result = any(word in response.lower() for word in ['tiger', '虎', '滑雪', 'ski']) file.writelines(f'multi-imagese pic result: {case_result} reason: tiger or ski should in {response} \n') 
with assume: assert case_result, f'reason: Multi-images example: tiger or ski should in {response}' with allure.step('batch-example1'): response = get_response_from_output(output_text, 'batch-example1') case_result = any(word in response.lower() for word in ['滑雪', 'ski']) file.writelines(f'batch-example1 pic result: {case_result} reason: ski should in {response} \n') with assume: assert case_result, f'reason: batch-example1: ski should in {response}' with allure.step('batch-example2'): response = get_response_from_output(output_text, 'batch-example2') case_result = any(word in response.lower() for word in ['tiger', '虎']) file.writelines(f'batch-example2 pic result: {case_result} reason: tiger should in {response} \n') with assume: assert case_result, f'reason: batch-example1: tiger should in {response}' with allure.step('multi-turn1'): response = get_response_from_output(output_text, 'multi-turn1') case_result = any(word in response.lower() for word in ['滑雪', 'ski']) file.writelines(f'multi-turn1 pic result: {case_result} reason: ski should in {response} \n') with assume: assert case_result, f'reason: batch-example1: ski should in {response}' with allure.step('multi-turn2'): response = get_response_from_output(output_text, 'multi-turn2') case_result = any(word in response.lower() for word in ['滑雪', 'ski']) file.writelines(f'multi-turn2 pic result: {case_result} reason: ski should in {response} \n') with assume: assert case_result, f'reason: batch-example1: ski should in {response}' if not is_smoke: if 'internvl' in model.lower() and 'internvl2-4b' not in model.lower(): internvl_vl_testcase(output_text, file) internvl_vl_testcase(output_text, file, 'cn') if 'minicpm' in model.lower(): MiniCPM_vl_testcase(output_text, file) if 'qwen' in model.lower(): Qwen_vl_testcase(output_text, file) with open(pipeline_log, 'r', encoding='utf-8') as file: output_text = file.read() print(output_text) allure.attach.file(pipeline_log, name=pipeline_log, 
attachment_type=allure.attachment_type.TEXT) def get_response_from_output(output_text, case): return output_text.split(f'[caseresult {case} start]')[1].split(f'[caseresult {case} end]')[0] def get_response_from_output_by_prompt(output_text, case, prompt): output_list = output_text.split(f'[caseresult {case} start]')[1].split(f'[caseresult {case} end]')[0] output_dict = json.loads(output_list.rstrip()) for output in output_dict: if output.get('prompt') == prompt: return output.get('response') return None def assert_pipeline_single_return(output, logprobs_num: int = 0): result = assert_pipeline_single_element(output, is_last=True, logprobs_num=logprobs_num) if not result: return result, 'single_stream_element is wrong' return result & (len(output.token_ids) == output.generate_token_len or len(output.token_ids) == output.generate_token_len - 1), 'token_is len is not correct' def assert_pipeline_batch_return(output, size: int = 1): if len(output) != size: return False, 'length is not correct' for single_output in output: result, msg = assert_pipeline_single_return(single_output) if not result: return result, msg return True, '' def assert_pipeline_single_stream_return(output, logprobs_num: int = 0): for i in range(0, len(output) - 2): if not assert_pipeline_single_element(output[i], is_stream=True, logprobs_num=logprobs_num): return False, f'single_stream_element is false, index is {i}' if assert_pipeline_single_element(output[-1], is_stream=True, is_last=True, logprobs_num=logprobs_num) is False: return False, 'last single_stream_element is false' return True, '' def assert_pipeline_batch_stream_return(output, size: int = 1): for i in range(size): output_list = [item for item in output if item.index == i] result, msg = assert_pipeline_single_stream_return(output_list) if not result: return result, msg return True, '' def assert_pipeline_single_element(output, is_stream: bool = False, is_last: bool = False, logprobs_num: int = 0): result = True result &= 
output.generate_token_len > 0 result &= output.input_token_len > 0 result &= output.index >= 0 if is_last: result &= output.text is not None result &= output.finish_reason in ['stop', 'length'] if is_stream: result &= output.token_ids is None or output.token_ids == [] else: result &= len(output.token_ids) > 0 else: result &= len(output.text) > 0 result &= output.finish_reason is None result &= len(output.token_ids) > 0 if logprobs_num == 0 or (is_last and is_stream): result &= output.logprobs is None else: if is_stream: result &= len(output.logprobs) >= 1 else: result &= len(output.logprobs) == output.generate_token_len or len( output.logprobs) == output.generate_token_len + 1 if result: for content in output.logprobs: result &= len(content.keys()) <= logprobs_num for key in content.keys(): result &= isinstance(content.get(key), float) return result def internvl_vl_testcase(output_text, file, lang: str = 'en'): with allure.step(f'internvl-combined-images-{lang}'): response = get_response_from_output(output_text, f'internvl-combined-images-{lang}') case_result = any(word in response.lower() for word in ['panda', '熊猫']) file.writelines(f'internvl-combined-images-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: combined images: panda should in {response}' with allure.step(f'internvl-combined-images2-{lang}'): response = get_response_from_output(output_text, f'internvl-combined-images2-{lang}') case_result = any(word in response.lower() for word in ['panda', '熊猫']) file.writelines( f'internvl-combined-images2-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: combined images2: panda should in {response}' with allure.step(f'internvl-separate-images-{lang}'): response = get_response_from_output(output_text, f'internvl-separate-images-{lang}') case_result = any(word in response.lower() for word in ['panda', '熊猫', 'same', 'different', 'eat', 
'cute']) file.writelines(f'internvl-separate-images-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: separate images: panda should in {response}' with allure.step(f'internvl-separate-images2-{lang}'): response = get_response_from_output(output_text, f'internvl-separate-images2-{lang}') case_result = any(word in response.lower() for word in ['panda', '熊猫', 'same', 'different', 'difference', 'identical']) file.writelines( f'internvl-separate-images2-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: separate images2: panda should in {response}' with allure.step(f'internvl-video-{lang}'): response = get_response_from_output(output_text, f'internvl-video-{lang}') case_result = any(word in response.lower() for word in ['red panda', 'eat', '熊猫', '竹子', 'food', 'hold']) file.writelines(f'internvl-video-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: video: panda should in {response}' with allure.step(f'internvl-video2-{lang}'): response = get_response_from_output(output_text, f'internvl-video2-{lang}') case_result = any(word in response.lower() for word in ['red panda', 'eat', '熊猫', '竹子']) file.writelines(f'internvl-video2-{lang} result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: video2: panda should in {response}' def MiniCPM_vl_testcase(output_text, file): with allure.step('minicpm-combined-images'): response = get_response_from_output(output_text, 'minicpm-combined-images') case_result = any(word in response.lower() for word in ['panda', '熊猫']) file.writelines(f'minicpm-combined-images result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: combined images: panda should in {response}' with allure.step('minicpm-combined-images2'): response = get_response_from_output(output_text, 
'minicpm-combined-images2') case_result = any(word in response.lower() for word in ['panda', '熊猫']) file.writelines(f'minicpm-combined-images2 result: {case_result}, reason: panda should in {response} \n') with assume: assert case_result, f'reason: combined images2: panda should in {response}' with allure.step('minicpm-fewshot'): response = get_response_from_output(output_text, 'minicpm-fewshot') case_result = any(word in response.lower() for word in ['2021', '14']) file.writelines(f'minicpm-fewshot result: {case_result} reason: 2021 or 14 should in {response} \n') with assume: assert case_result, f'reason: fewshot: 2021 or 14 should in {response}' with allure.step('minicpm-video'): response = get_response_from_output(output_text, 'minicpm-video') case_result = any(word in response.lower() for word in ['red panda', '熊猫']) file.writelines(f'minicpm-video result: {case_result} reason: video: panda should in {response} \n') with assume: assert case_result, f'reason: video: panda should in {response}' def Qwen_vl_testcase(output_text, file): with allure.step('qwen-combined-images'): response = get_response_from_output(output_text, 'qwen-combined-images') case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city']) file.writelines(f'qwen-combined-images result: {case_result}, reason: buildings should in {response} \n') with assume: assert case_result, f'reason: combined images: buildings should in {response}' with allure.step('qwen-combined-images2'): response = get_response_from_output(output_text, 'qwen-combined-images2') case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city']) file.writelines(f'qwen-combined-images2 result: {case_result}, reason: buildings should in {response} \n') with assume: assert case_result, f'reason: combined images2: buildings should in {response}' with allure.step('qwen-performance-images'): response = get_response_from_output(output_text, 'qwen-performance-images') 
case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city']) file.writelines(f'qwen-performance-images result: {case_result}, reason: buildings should in {response} \n') with assume: assert case_result, f'reason: performance images: buildings should in {response}' with allure.step('qwen-performance-images2'): response = get_response_from_output(output_text, 'qwen-performance-images2') case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city']) file.writelines(f'qwen-performance-images2 result: {case_result}, reason: buildings should in {response} \n') with assume: assert case_result, f'reason: performance images2: buildings should in {response}' def save_pipeline_common_log(config, log_name, result, content, msg: str = '', write_type: str = 'w'): log_path = config.get('log_path') config_log = os.path.join(log_path, log_name) file = open(config_log, write_type) file.writelines(f'result:{result}, reason: {msg}, content: {content}') # noqa E231 file.close() def assert_pipeline_common_log(config, log_name): log_path = config.get('log_path') config_log = os.path.join(log_path, log_name) allure.attach.file(config_log, name=config_log, attachment_type=allure.attachment_type.TEXT) msg = 'result is empty, please check again' result = False with open(config_log, 'r') as f: lines = f.readlines() for line in lines: if 'result:False, reason:' in line: result = False msg = line break if 'result:True, reason:' in line and not result: result = True msg = '' try: if os.path.isfile(config_log): os.remove(config_log) elif os.path.isdir(config_log): shutil.rmtree(config_log) except OSError: pass # Ignore errors when removing log file assert result, msg ================================================ FILE: autotest/utils/proxy_distributed_utils.py ================================================ import os import random import socket import subprocess import time from typing import Any import requests from 
utils.config_utils import get_case_str_by_config, get_cli_common_param, resolve_extra_params from utils.ray_distributed_utils import verify_service_functionality time_time = time.time DEFAULT_PROXY_PORT = 8000 WORKER_WAIT_INTERVAL = 15 # seconds def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool: """Check if a port is open.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(timeout) try: s.connect((host, port)) return True except (socket.timeout, ConnectionRefusedError, OSError): return False def check_nodes_status(host: str, proxy_port: int, model_name: str, expected_instances: int, check_count: int, current_time: float, last_progress_print: float, progress_print_interval: int) -> tuple[bool, int]: try: nodes_url = f'http://{host}:{proxy_port}/nodes/status' resp = requests.get(nodes_url, timeout=10) if resp.status_code != 200: if current_time - last_progress_print >= progress_print_interval: print(f'🔧 Check {check_count}: Failed to get node status, status code: {resp.status_code}') return False, 0 nodes_data = resp.json() ready_instances = 0 total_instances = len(nodes_data) for node_info in nodes_data.values(): models = node_info.get('models', []) if model_name in models: ready_instances += 1 should_print = current_time - last_progress_print >= progress_print_interval if should_print: basename = os.path.basename(model_name) print(f'📊 Check {check_count}: Model registration progress: ' f'{ready_instances}/{expected_instances} instances ready ' f'(Total reported: {total_instances})') for node_url, node_info in nodes_data.items(): models = node_info.get('models', []) if model_name in models: print(f' ✅ Instance {node_url} registered model {basename}') else: print(f' ⏳ Instance {node_url} has not registered target model') if ready_instances >= expected_instances: if should_print: print(f'🎯 All {expected_instances} API server instances have registered the target model') return True, ready_instances else: if should_print: 
print(f'⏳ Waiting for more instances to register... ({ready_instances}/{expected_instances})') return False, ready_instances except Exception as e: if current_time - last_progress_print >= progress_print_interval: print(f'🔧 Check {check_count}: Exception getting node status - {e}') return False, 0 def wait_for_model_service_ready(host: str, proxy_port: int, model_name: str, timeout_seconds: int = 2000, expected_instances: int = None) -> bool: if expected_instances: print(f'⏳ Waiting for model service to be fully ready (Model: {model_name}), ' f'expected instances: {expected_instances}, timeout: {timeout_seconds}s') else: print(f'⏳ Waiting for model service to be fully ready (Model: {model_name}), ' f'timeout: {timeout_seconds}s') start_time = time_time() check_count = 0 last_progress_print = 0 progress_print_interval = 30 initial_delay = random.uniform(1, 5) time.sleep(initial_delay) while time_time() - start_time < timeout_seconds: check_count += 1 current_time = time_time() try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(5) if sock.connect_ex((host, proxy_port)) != 0: if current_time - last_progress_print >= progress_print_interval: print(f'🔌 Check {check_count}: proxy port not ready') last_progress_print = current_time time.sleep(10) continue if expected_instances: instances_ready, ready_count = check_nodes_status(host, proxy_port, model_name, expected_instances, check_count, current_time, last_progress_print, progress_print_interval) if not instances_ready: if ready_count is not None and current_time - last_progress_print >= progress_print_interval: last_progress_print = current_time time.sleep(10) continue service_ready = verify_service_functionality(host, proxy_port, model_name, check_count) if service_ready: if expected_instances: print(f'✅ All {expected_instances} API server instances are ready and service is functional!') else: print('✅ Model service is fully ready!') return True except 
requests.exceptions.RequestException as e: if current_time - last_progress_print >= progress_print_interval: print(f'🔧 Check {check_count}: Request exception - {e}') last_progress_print = current_time except Exception as e: if current_time - last_progress_print >= progress_print_interval: print(f'🔧 Check {check_count}: Unknown exception - {e}') last_progress_print = current_time sleep_time = 10 + random.uniform(-2, 2) time.sleep(sleep_time) print(f'❌ Model service startup timed out ({timeout_seconds} seconds)') return False def proxy_worker_node_wait(manager, timeout_minutes: int = 120): """Worker node waits by periodically checking if the master's proxy service is still alive. If the proxy becomes unreachable for several consecutive checks, assume master has finished. Args: manager: ProxyDistributedManager instance timeout_minutes: Maximum time to wait before giving up (default: 120 minutes) """ print(f'⏸️ Worker node {manager.node_rank} entering monitoring mode...') max_checks = (timeout_minutes * 60) // WORKER_WAIT_INTERVAL consecutive_failures = 0 max_consecutive_failures = 3 for i in range(max_checks): if not is_port_open(manager.master_addr, manager.proxy_port, timeout=2.0): consecutive_failures += 1 print(f'⚠️ Proxy connection to master failed ({consecutive_failures}/{max_consecutive_failures})') if consecutive_failures >= max_consecutive_failures: print('📡 Master proxy service stopped, worker node exiting') break else: consecutive_failures = 0 if i % 4 == 0: elapsed = (i * WORKER_WAIT_INTERVAL) // 60 print(f'⏳ Worker node {manager.node_rank} monitoring... 
Running for {elapsed} minutes') time.sleep(WORKER_WAIT_INTERVAL) else: print(f'⏰ Worker node {manager.node_rank} monitoring timed out ({timeout_minutes} minutes)') print(f'✅ Worker node {manager.node_rank} completed waiting') class ProxyDistributedManager: def __init__(self): self.master_addr = os.getenv('MASTER_ADDR', '127.0.0.1') self.node_rank = int(os.getenv('NODE_RANK', '0')) self.proxy_port = int(os.getenv('PROXY_PORT', str(DEFAULT_PROXY_PORT))) self.is_master = (self.node_rank == 0) self.proxy_process = None def start(self): if not self.is_master: return cmd = [ 'lmdeploy', 'serve', 'proxy', '--server-name', self.master_addr, '--server-port', str(self.proxy_port), '--routing-strategy', 'min_expected_latency', '--serving-strategy', 'Hybrid' ] print(f"[Proxy] Starting: {' '.join(cmd)}") self.proxy_process = subprocess.Popen(cmd) time.sleep(5) def cleanup(self): if self.proxy_process and self.proxy_process.poll() is None: print('[Proxy] Terminating proxy process...') self.proxy_process.terminate() try: self.proxy_process.wait(timeout=10) except subprocess.TimeoutExpired: self.proxy_process.kill() class ApiServerPerTest: def __init__(self, proxy_manager: ProxyDistributedManager, config: dict[str, Any], run_config: dict[str, Any]): self.proxy_manager = proxy_manager self.config = config self.run_config = run_config model_name = run_config['model'] self.model_path = os.path.join(config['model_path'], model_name) self.master_addr = proxy_manager.master_addr self.proxy_port = proxy_manager.proxy_port self.node_rank = int(os.getenv('NODE_RANK', '0')) self.node_count = int(os.getenv('NODE_COUNT', '1')) self.proc_per_node = int(os.getenv('PROC_PER_NODE', '1')) self.expected_instances = self.node_count * self.proc_per_node self.is_master = (self.node_rank == 0) self.api_process = None def start(self): proxy_url = f'http://{self.master_addr}:{self.proxy_port}' extra_params = self.run_config.get('extra_params', {}) resolve_extra_params(extra_params, 
self.config['model_path']) # Get model-name: use extra_params['model-name'] if specified, otherwise use case_name case_name = get_case_str_by_config(self.run_config) self.model_name = case_name if extra_params.get('model-name', None) is None else extra_params.get('model-name') cmd = [ 'lmdeploy', 'serve', 'api_server', self.model_path, '--model-name', self.model_name, ] + get_cli_common_param(self.run_config).split() + [ '--proxy-url', proxy_url, ] if self.node_count > 1: cmd += ['--nnodes', str(self.node_count), '--node-rank', str(self.node_rank)] print(f"[API Server] Starting: {' '.join(cmd)}") timestamp = time.strftime('%Y%m%d_%H%M%S') log_dir = self.config.get('server_log_path', '/tmp/lmdeploy_test') os.makedirs(log_dir, exist_ok=True) log_path = os.path.join(log_dir, f'log_{case_name}_{timestamp}.log') self._log_file = open(log_path, 'w') self.api_process = subprocess.Popen(cmd, stdout=self._log_file, stderr=self._log_file) print(f'📝 API Server log: {log_path}') def wait_until_ready(self): if not self.is_master: return success = wait_for_model_service_ready(host=self.master_addr, proxy_port=self.proxy_port, model_name=self.model_name, timeout_seconds=2000, expected_instances=self.expected_instances) if not success: raise RuntimeError(f'API Server failed to register model: {self.model_name}') def cleanup(self): if self.api_process and self.api_process.poll() is None: print(f'[API Server] Terminating for model: {self.model_path}') self.api_process.terminate() try: self.api_process.wait(timeout=15) except subprocess.TimeoutExpired: self.api_process.kill() if hasattr(self, '_log_file') and self._log_file and not self._log_file.closed: self._log_file.close() ================================================ FILE: autotest/utils/quantization_utils.py ================================================ import os import subprocess from subprocess import PIPE def quantization(config, quantization_model_name, origin_model_name, quantization_type: str = 'awq', cuda_prefix: 
str = 'CUDA_VISIBLE_DEVICES=0'): model_path = config.get('model_path') log_path = config.get('log_path') origin_model_path = os.path.join(config.get('model_path'), origin_model_name) quantization_model_path = os.path.join(model_path, quantization_model_name) quantization_log = os.path.join( log_path, '_'.join(['quantization', quantization_type, quantization_model_name.split('/')[1]]) + '.log') if quantization_type == 'awq': quantization_cmd = ' '.join( ['lmdeploy lite auto_awq', origin_model_path, '--work-dir', quantization_model_path]) elif quantization_type == 'gptq': quantization_cmd = ' '.join( ['lmdeploy lite auto_gptq', origin_model_path, '--work-dir', quantization_model_path]) elif quantization_type == 'w8a8': quantization_cmd = ' '.join( ['lmdeploy lite smooth_quant', origin_model_path, '--work-dir', quantization_model_path]) else: return False, 'quantization type should in [awq, gptq, w8a8], \ now the type is ' + quantization_type # Add device option if specified in environment device = os.environ.get('DEVICE', '') if device == 'ascend': quantization_cmd += ' --device npu ' if cuda_prefix is not None: quantization_cmd = ' '.join([cuda_prefix, quantization_cmd]) if 'llama-3' in origin_model_name.lower(): quantization_cmd += ' --search-scale' if quantization_type == 'gptq' or str(config.get('env_tag')) == '3090' or str(config.get('env_tag')) == '5080': quantization_cmd += ' --batch-size 8' else: quantization_cmd += ' --batch-size 32' with open(quantization_log, 'w') as f: # remove existing folder subprocess.run([' '.join(['rm -rf', quantization_model_path])], stdout=f, stderr=f, shell=True, text=True, encoding='utf-8') f.writelines('reproduce command quantization_cmd: ' + quantization_cmd + '\n') print('reproduce command quantization_cmd: ' + quantization_cmd) # quantization quantizationRes = subprocess.run([quantization_cmd], stdout=f, stderr=PIPE, shell=True, text=True, encoding='utf-8', errors='replace') f.writelines(quantizationRes.stderr) result = 
quantizationRes.returncode == 0 return result, quantizationRes.stderr ================================================ FILE: autotest/utils/ray_distributed_utils.py ================================================ import os import random import socket import subprocess import time from time import time as time_time from typing import Any import requests from utils.config_utils import get_case_str_by_config, get_cli_common_param, resolve_extra_params # Default constants LM_DEPLOY_API_PORT = 8000 RAY_PORT = 6379 HEALTH_CHECK_TIMEOUT = 30 CONNECTION_CHECK_TIMEOUT = 5 WORKER_WAIT_INTERVAL = 30 def wait_for_model_service_ready( host: str, api_port: int, model_name: str, timeout_seconds: int = 1000, ) -> bool: """Wait for LMDeploy API Server to be ready and verify basic functionality. No longer checks multi-node registration (API Server is a single-point service). """ print(f'⏳ Waiting for LMDeploy API Server to be ready (Model: {model_name}), Timeout: {timeout_seconds}s') start_time = time_time() check_count = 0 last_progress_print = 0 progress_print_interval = 30 # Random initial delay to avoid multiple clients requesting simultaneously time.sleep(random.uniform(1, 5)) while time_time() - start_time < timeout_seconds: check_count += 1 current_time = time_time() try: # Check if port is open with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(5) if sock.connect_ex((host, api_port)) != 0: if current_time - last_progress_print >= progress_print_interval: print(f'🔌 Check {check_count}: API port {api_port} not ready') last_progress_print = current_time time.sleep(10) continue # Verify service functionality if verify_service_functionality(host, api_port, model_name, check_count): print('✅ LMDeploy API Server is fully ready!') return True except Exception as e: if current_time - last_progress_print >= progress_print_interval: print(f'🔧 Check {check_count}: Exception - {e}') last_progress_print = current_time sleep_time = 10 + random.uniform(-2, 2) 
time.sleep(sleep_time) print(f'❌ LMDeploy API Server startup timed out ({timeout_seconds} seconds)') return False def verify_service_functionality(host: str, api_port: int, model_name: str, check_count: int) -> bool: """Verify that the API Server can respond to basic requests.""" try: test_data = { 'model': model_name, 'messages': [{ 'role': 'user', 'content': 'hi' }], 'max_tokens': 5, 'stream': False } resp = requests.post(f'http://{host}:{api_port}/v1/chat/completions', json=test_data, timeout=15) if resp.status_code == 200: print(f'✅ Check {check_count}: Service functionality normal (received valid response)') return True elif resp.status_code == 400: print(f'✅ Check {check_count}: Service framework activated (received 400)') return True else: print(f'🔧 Check {check_count}: Service test failed, status code: {resp.status_code}') return False except requests.exceptions.RequestException as e: print(f'🔧 Check {check_count}: Service test exception - {e}') return False class RayLMDeployManager: def __init__( self, master_addr: str, ray_port: int = RAY_PORT, api_port: int = LM_DEPLOY_API_PORT, log_dir: str = '.', health_check: bool = True, ): self.master_addr = master_addr self.ray_port = ray_port self.api_port = api_port self.log_dir = log_dir self.health_check = health_check self._cleaned = False # Determine if this is the master node (via environment variable NODE_RANK) self.node_rank = int(os.getenv('NODE_RANK', '0')) self.is_master = (self.node_rank == 0) os.makedirs(self.log_dir, exist_ok=True) print(f'📝 Node {self.node_rank} log directory: {self.log_dir}') # Print cluster information self.node_count = int(os.getenv('NODE_COUNT', '1')) self.job_id = os.getenv('JOB_ID', 'unknown') print(f'🎯 Node {self.node_rank} cluster information:') print(f'- Total nodes: {self.node_count}') print(f"- Role: {'Master node' if self.is_master else 'Worker node'}") print(f'- Master address: {self.master_addr}') print(f'- Ray port: {self.ray_port}') print(f'- API port: 
{self.api_port}') print(f'- Job ID: {self.job_id}') def start_ray_cluster(self): """Start or join Ray cluster.""" if self.is_master: cmd = ['ray', 'start', '--head', '--port', str(self.ray_port)] print(f'🚀 Master node starting Ray cluster (Port: {self.ray_port})') else: cmd = ['ray', 'start', '--address', f'{self.master_addr}:{self.ray_port}'] print(f'🔌 Worker node {self.node_rank} joining Ray cluster: {self.master_addr}:{self.ray_port}') try: subprocess.run(cmd, capture_output=True, text=True, check=True) print('✅ Ray started successfully') except subprocess.CalledProcessError as e: print(f'💥 Ray startup failed: {e.stderr}') raise def start_lmdeploy_api_server(self, config: dict[str, Any], run_config: dict[str, Any]) -> None: """ Master node: Start LMDeploy API Server and wait for it to be ready. Worker nodes: Do not start the service, only verify that the master node's API Server is ready. """ # Derive model_path from config and run_config model_path = os.path.join(config['model_path'], run_config['model']) extra_params = run_config.get('extra_params', {}) resolve_extra_params(extra_params, config['model_path']) # Get model-name: use extra_params['model-name'] if specified, otherwise use case_name case_name = get_case_str_by_config(run_config) extra_params = run_config.get('extra_params', {}) model_name = case_name if extra_params.get('model-name', None) is None else extra_params.get('model-name') if self.is_master: # === Master node logic: Start service === timestamp = time.strftime('%Y%m%d_%H%M%S') log_path = os.path.join(self.log_dir, f'log_{model_name}_{timestamp}.log') cmd = [ 'lmdeploy', 'serve', 'api_server', model_path, '--server-port', str(self.api_port), '--model-name', model_name, ] + get_cli_common_param(run_config).split() print(f"🚀 Master node starting LMDeploy API Server: {' '.join(cmd)}") self._log_file = open(log_path, 'w') self._api_process = subprocess.Popen(cmd, stdout=self._log_file, stderr=self._log_file) print(f'📝 API Server log: 
{log_path}') # Wait for service to be ready if self.health_check: ready = wait_for_model_service_ready(host=self.master_addr, api_port=self.api_port, model_name=model_name, timeout_seconds=1000) if not ready: print('❌ API Server failed to be ready, terminating process') self._api_process.terminate() try: self._api_process.wait(timeout=10) except subprocess.TimeoutExpired: self._api_process.kill() raise RuntimeError('LMDeploy API Server failed to start') else: # === Worker node logic: Only verify that the master node service is ready === print(f'🔍 Worker node {self.node_rank} is verifying that the master node ' f'({self.master_addr}:{self.api_port}) API Server is ready...') if self.health_check: ready = wait_for_model_service_ready(host=self.master_addr, api_port=self.api_port, model_name=model_name, timeout_seconds=1000) if not ready: raise RuntimeError(f'Worker node {self.node_rank}: Master node API Server not ready ' f'within 1000 seconds, cannot continue') else: print('⚠️ health_check=False, skipping API Server readiness check (not recommended)') def cleanup(self, force: bool = True): """Clean up resources. Args: force (bool): - False: Only stop LMDeploy API Server (used after individual test completion) - True: Stop API Server + Ray cluster (used for final cleanup at session end) """ if self._cleaned and force: # Note: If this is just an intermediate cleanup with force=False, we shouldn't skip due to _cleaned # So only skip when force=True and already cleaned return print(f'🧹 Node {self.node_rank} cleaning resources... 
(force={force})') # Stop API Server (master node only) if hasattr(self, '_api_process') and self._api_process.poll() is None: self._api_process.terminate() try: self._api_process.wait(timeout=10) except subprocess.TimeoutExpired: self._api_process.kill() print('✅ LMDeploy API Server stopped') # Note: We don't clear the _api_process attribute here so it can be checked later if hasattr(self, '_log_file') and self._log_file and not self._log_file.closed: self._log_file.close() # Stop Ray (only when force=True) if force: try: subprocess.run(['ray', 'stop', '--force'], check=False, capture_output=True) print('✅ Ray cluster stopped') except Exception as e: print(f'⚠️ Ray stop exception: {e}') self._cleaned = True # Only mark as "fully cleaned" when force=True def get_cluster_info(self) -> dict[str, Any]: return { 'node_rank': self.node_rank, 'node_count': self.node_count, 'master_addr': self.master_addr, 'ray_port': self.ray_port, 'api_port': self.api_port, 'is_master': self.is_master, 'job_id': self.job_id, } def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.cleanup() def ray_worker_node_wait(manager: RayLMDeployManager, timeout_minutes: int = 60): """Worker node waits for Ray master node (Head Node) to be alive (by detecting GCS service port)""" if manager.is_master: return print(f'⏸️ Worker node {manager.node_rank} entering wait mode...') max_checks = (timeout_minutes * 60) // WORKER_WAIT_INTERVAL consecutive_failures = 0 max_consecutive_failures = 3 for i in range(max_checks): try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(CONNECTION_CHECK_TIMEOUT) if sock.connect_ex((manager.master_addr, RAY_PORT)) == 0: consecutive_failures = 0 else: consecutive_failures += 1 except Exception: consecutive_failures += 1 if consecutive_failures >= max_consecutive_failures: print('📡 Ray master node GCS service unreachable, worker node exiting') break if i % 4 == 0: elapsed = (i * WORKER_WAIT_INTERVAL) // 60 
print(f'⏳ Worker node {manager.node_rank} waiting... Running for {elapsed} minutes') time.sleep(WORKER_WAIT_INTERVAL) else: print(f'⏰ Worker node {manager.node_rank} wait timeout ({timeout_minutes} minutes)') manager.cleanup() ================================================ FILE: autotest/utils/restful_return_check.py ================================================ import re def assert_chat_completions_batch_return(output, model_name, check_logprobs: bool = False, logprobs_num: int = 5): assert_usage(output.get('usage')) assert output.get('id') is not None assert output.get('object') == 'chat.completion' assert output.get('model') == model_name output_message = output.get('choices') assert len(output_message) == 1 for message in output_message: assert message.get('finish_reason') in ['stop', 'length'] assert message.get('index') == 0 assert len(message.get('message').get('content')) > 0 assert message.get('message').get('role') == 'assistant' if check_logprobs: len(message.get('logprobs').get('content')) == output.get('usage').get('completion_tokens') for logprob in message.get('logprobs').get('content'): assert_logprobs(logprob, logprobs_num) def assert_completions_batch_return(output, model_name, check_logprobs: bool = False, logprobs_num: int = 5): assert_usage(output.get('usage')) assert output.get('id') is not None assert output.get('object') == 'text_completion' assert output.get('model') == model_name output_message = output.get('choices') assert len(output_message) == 1 for message in output_message: assert message.get('finish_reason') in ['stop', 'length'] assert message.get('index') == 0 assert len(message.get('text')) > 0 if check_logprobs: len(message.get('logprobs').get('content')) == output.get('usage').get('completion_tokens') for logprob in message.get('logprobs').get('content'): assert_logprobs(logprob, logprobs_num) def assert_usage(usage): assert usage.get('prompt_tokens') > 0 assert usage.get('total_tokens') > 0 assert 
usage.get('completion_tokens') > 0 assert usage.get('completion_tokens') + usage.get('prompt_tokens') == usage.get('total_tokens') def assert_logprobs(logprobs, logprobs_num): assert_logprob_element(logprobs) assert len(logprobs.get('top_logprobs')) >= 0 assert type(logprobs.get('top_logprobs')) == list assert len(logprobs.get('top_logprobs')) <= logprobs_num for logprob_element in logprobs.get('top_logprobs'): assert_logprob_element(logprob_element) def assert_logprob_element(logprob): assert len(logprob.get('token')) > 0 and type(logprob.get('token')) == str assert len(logprob.get('bytes')) > 0 and type(logprob.get('bytes')) == list assert type(logprob.get('logprob')) == float def assert_chat_completions_stream_return(output, model_name, is_last: bool = False, check_logprobs: bool = False, logprobs_num: int = 5): print(output) assert output.get('id') is not None assert output.get('object') == 'chat.completion.chunk' assert output.get('model') == model_name output_message = output.get('choices') assert len(output_message) == 1 for message in output_message: assert message.get('delta').get('role') == 'assistant' assert message.get('index') == 0 assert len(message.get('delta').get('content')) >= 0 if not is_last: assert message.get('finish_reason') is None if check_logprobs: assert (len(message.get('logprobs').get('content')) >= 1) for content in message.get('logprobs').get('content'): assert_logprobs(content, logprobs_num) if is_last is True: assert len(message.get('delta').get('content')) == 0 or 'error' in message.get('delta').get('content') assert message.get('finish_reason') in ['stop', 'length', 'error'] if check_logprobs is True: assert message.get('logprobs') is None def assert_completions_stream_return(output, model_name, is_last: bool = False, check_logprobs: bool = False, logprobs_num: int = 5): print(output) assert output.get('id') is not None assert output.get('object') == 'text_completion' assert output.get('model') == model_name output_message = 
output.get('choices') assert len(output_message) == 1 for message in output_message: assert message.get('index') == 0 assert len(message.get('text')) >= 0 if is_last is False: assert message.get('finish_reason') is None if check_logprobs: assert (len(message.get('logprobs').get('content')) >= 1) for content in message.get('logprobs').get('content'): assert_logprobs(content, logprobs_num) if is_last is True: assert len(message.get('text')) == 0 assert message.get('finish_reason') in ['stop', 'length'] if check_logprobs is True: assert message.get('logprobs') is None def has_repeated_fragment(text, repeat_count=5): pattern = r'(.+?)\1{' + str(repeat_count - 1) + ',}' match = re.search(pattern, text.replace('\n', '')) if match: repeated_fragment = match.group(1) start_pos = match.start() return True, {'repeated_fragment': repeated_fragment, 'position': start_pos} return False, f'{text} does not contain repeated fragments' ================================================ FILE: autotest/utils/rule_condition_assert.py ================================================ def assert_result(input, rule_condition, model_name: str = None): input = input.replace('\n', '\\n') input_lower = input.lower() for dict in rule_condition: if dict is None: return True, '' for rule in dict: operator = list(rule.keys())[0] value = list(rule.values())[0] if model_name is not None and model_name == operator: dict = value for rule in dict: operator = list(rule.keys())[0] value = list(rule.values())[0] if input is None or len(input) == 0: return False, 'response is empty' if operator == 'contain': if isinstance(value, list): tmpResult = False for word in value: if word.lower() in input_lower: tmpResult = True if not tmpResult: return False, ','.join(value) + " doesn't exist in " + input else: if value.lower() not in input_lower: msg = value + " doesn't exist in:" + input return False, msg if operator == 'not_contain': if isinstance(value, list): for word in value: if word.lower() in input_lower: 
msg = word + " shouldn't exist in:" + input return False, msg else: if value.lower() in input_lower: msg = value + " shouldn't exist in " + input return False, msg if operator == 'len_g': if len(input) < int(value): return False, input + ' length: ' + str(len(input)) + ', should greater than ' + str(value) return True, '' if __name__ == '__main__': input = '成都的景点hot potdddd' condition = ([[{'contain': ['hot pot']}, {'contain': ['。']}, {'len_g': [10]}]]) print(assert_result(input, condition)) ================================================ FILE: autotest/utils/run_client_chat.py ================================================ import os import time from subprocess import PIPE, Popen import allure from utils.config_utils import get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid from utils.rule_condition_assert import assert_result TEMPLATE = 'autotest/template.json' def run_tests(config, usercase, cli_case_config, run_config, worker_id): if 'coder' in run_config['model'].lower() and usercase == 'chat_testcase': usercase = 'code_testcase' hf_command_line_test(config, usercase, cli_case_config.get(usercase), run_config, cuda_prefix=get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))) def hf_command_line_test(config, case, case_info, run_config, cuda_prefix: str = ''): model = run_config.get('model') if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True': model_path = model else: model_path = os.path.join(config.get('model_path'), model) run_config['extra_params']['session_len'] = 4096 if case == 'base_testcase': run_config['extra_params']['chat_template'] = TEMPLATE run_config['extra_params']['session_len'] = 512 print(run_config) cmd = ' '.join([cuda_prefix, ' '.join(['lmdeploy chat', model_path, get_cli_common_param(run_config)])]).strip() result, chat_log, msg = command_test(config, cmd, run_config, case_info, True) if chat_log: allure.attach.file(chat_log, name=chat_log, 
attachment_type=allure.attachment_type.TEXT) assert result, msg def command_test(config, cmd, run_config, case_info, need_extract_output): try: log_path = config.get('log_path') case_name = get_case_str_by_config(run_config) timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) chat_log = os.path.join(log_path, f'chat_{case_name}_{timestamp}.log') returncode = -1 result = True spliter = '\n\n' # join prompt together prompt = '' for item in case_info: prompt += list(item.keys())[0] + spliter prompt += 'exit' + spliter msg = '' env = os.environ.copy() env.update(run_config.get('env', {})) with Popen([cmd], stdin=PIPE, stdout=PIPE, stderr=PIPE, shell=True, text=True, encoding='utf-8', errors='replace', env=env, start_new_session=True) as proc, open(chat_log, 'a') as file: print(f'reproduce command chat: {cmd} \n') file.writelines(f'reproduce command chat: {cmd} \n') file.writelines('prompt:' + prompt + '\n') outputs, errors = proc.communicate(input=prompt) returncode = proc.returncode if returncode != 0: file.writelines('error:' + errors + '\n') result = False return result, chat_log, errors outputDialogs = parse_dialogue(outputs) file.writelines('answersize:' + str(len(outputDialogs)) + '\n') index = 0 for prompt_detail in case_info: if need_extract_output: output = extract_output(outputDialogs[index], run_config.get('model')) else: output = outputDialogs[index] case_result, reason = assert_result(output, prompt_detail.values(), run_config.get('model')) file.writelines(f'prompt: {list(prompt_detail.keys())[0]}\n') file.writelines(f'output: {output}\n') file.writelines(f'result: {case_result}, reason: {reason}\n') index += 1 if not case_result: print(f'prompt: {list(prompt_detail.keys())[0]}\n') print(f'output: {output}\n') print(f'result: {case_result}, reason: {reason}\n') msg += reason result = result and case_result file.writelines('\n\n\n' + 'full log:' + outputs + '\n') return result, chat_log, msg except Exception as e: return False, None, f'Unknown 
error: {e}' def parse_dialogue(inputs: str): dialogues = inputs.strip() sep = 'double enter to end input >>>' dialogues = dialogues.strip() dialogues = dialogues.split(sep) dialogues = [d.strip() for d in dialogues] return dialogues[1:-1] def extract_output(output: str, model: str): if 'Qwen' in model or 'internlm2' in model: if len(output.split('<|im_start|>assistant')) >= 2: return output.split('<|im_start|>assistant')[1] if 'Baichuan2' in model: if len(output.split('')) >= 2: return output.split('')[1] if 'internlm' in model: if len(output.split('<|Bot|>: ')) >= 2: return output.split('<|Bot|>: ')[1] if 'llama' in model or 'Llama' in model: if len(output.split('[/INST]')) >= 2: return output.split('[/INST]')[1] return output ================================================ FILE: autotest/utils/run_restful_chat.py ================================================ import json import os import subprocess import time import allure import psutil import requests from openai import OpenAI from pytest_assume.plugin import assume from utils.config_utils import (get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid, get_workerid, resolve_extra_params) from utils.constant import DEFAULT_PORT, DEFAULT_SERVER from utils.restful_return_check import assert_chat_completions_batch_return from utils.rule_condition_assert import assert_result from lmdeploy.serve.openai.api_client import APIClient BASE_HTTP_URL = f'http://{DEFAULT_SERVER}' def start_openai_service(config, run_config, worker_id, timeout: int = 1200): port = DEFAULT_PORT + get_workerid(worker_id) case_name = get_case_str_by_config(run_config) timestamp = time.strftime('%Y%m%d_%H%M%S') server_log = os.path.join(config.get('server_log_path'), f'log_{case_name}_{port}_{timestamp}.log') model = run_config.get('model') if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True': model_path = model else: model_path = os.path.join(config.get('model_path'), model) cuda_prefix = 
get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')) # Ensure extra_params exists before modifying if 'extra_params' not in run_config: run_config['extra_params'] = {} resolve_extra_params(run_config['extra_params'], config.get('model_path')) run_config['extra_params']['server-port'] = str(port) run_config['extra_params']['allow-terminate-by-client'] = None model_name = case_name if run_config['extra_params'].get( 'model-name', None) is None else run_config['extra_params'].pop('model-name') cmd = ' '.join([ cuda_prefix, 'lmdeploy serve api_server', model_path, get_cli_common_param(run_config), f'--model-name {model_name}' ]).strip() env = os.environ.copy() env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500) env.update(run_config.get('env', {})) file = open(server_log, 'w') print('reproduce command restful: ' + cmd) file.write('reproduce command restful: ' + cmd + '\n') startRes = subprocess.Popen(cmd, stdout=file, stderr=file, shell=True, text=True, env=env, encoding='utf-8', errors='replace', start_new_session=True) pid = startRes.pid http_url = ':'.join([BASE_HTTP_URL, str(port)]) start_time = int(time.time()) start_timeout = timeout time.sleep(5) for i in range(start_timeout): time.sleep(1) end_time = int(time.time()) total_time = end_time - start_time result = health_check(http_url, case_name) if result or total_time >= start_timeout: break try: # Check if process is still running return_code = startRes.wait(timeout=1) # Small timeout to check status if return_code != 0: with open(server_log, 'r') as f: content = f.read() print(content) return 0, content except subprocess.TimeoutExpired: continue file.close() allure.attach.file(server_log, name=server_log, attachment_type=allure.attachment_type.TEXT) return pid, '' def stop_restful_api(pid, startRes): if pid > 0: parent = psutil.Process(pid) for child in parent.children(recursive=True): child.terminate() parent.terminate() def terminate_restful_api(worker_id): port = DEFAULT_PORT + 
get_workerid(worker_id) http_url = ':'.join([BASE_HTTP_URL, str(port)]) response = None request_error = None try: response = requests.get(f'{http_url}/terminate') except requests.exceptions.RequestException as exc: request_error = exc if request_error is not None: assert False, f'terminate request failed: {request_error}' assert response is not None and response.status_code == 200, f'terminate with {response}' def run_all_step(log_path, case_name, cases_info, port: int = DEFAULT_PORT): http_url = ':'.join([BASE_HTTP_URL, str(port)]) model = get_model(http_url) if model is None: assert False, 'server not start correctly' for case in cases_info.keys(): if case != 'code_testcase' and 'code' in model.lower(): continue case_info = cases_info.get(case) with allure.step(case + ' restful_test - openai chat'): restful_result, restful_log, msg = open_chat_test(log_path, case_name, case_info, http_url) allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT) with assume: assert restful_result, msg def open_chat_test(log_path, case_name, case_info, url): timestamp = time.strftime('%Y%m%d_%H%M%S') restful_log = os.path.join(log_path, f'log_restful_{case_name}_{timestamp}.log') file = open(restful_log, 'w') result = True client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{url}/v1') model_name = client.models.list().data[0].id messages = [] msg = '' for prompt_detail in case_info: if not result: break prompt = list(prompt_detail.keys())[0] messages.append({'role': 'user', 'content': prompt}) file.writelines('prompt:' + prompt + '\n') outputs = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, top_p=0.8, max_completion_tokens=1024, stream=True) content_chunks = [] reasoning_content_chunks = [] for output in outputs: # Safely handle streaming chunks: choices may be empty and content may be None if not getattr(output, 'choices', None): continue choice = output.choices[0] delta = getattr(choice, 'delta', None) 
reasoning_content = getattr(delta, 'reasoning_content', None) if delta is not None else None content = getattr(delta, 'content', None) if delta is not None else None if reasoning_content: reasoning_content_chunks.append(reasoning_content) if content: content_chunks.append(content) reasoning_content = ''.join(reasoning_content_chunks) output_content = ''.join(content_chunks) file.writelines(f'reasoning_content :{reasoning_content}, content: {output_content}\n') messages.append({'role': 'assistant', 'content': output_content}) case_result, reason = assert_result(reasoning_content + output_content, prompt_detail.values(), model_name) file.writelines('result:' + str(case_result) + ',reason:' + reason + '\n') if not case_result: msg += reason result = result and case_result file.close() return result, restful_log, msg def health_check(url, model_name): try: api_client = APIClient(url) model_name_current = api_client.available_models[0] messages = [] messages.append({'role': 'user', 'content': '你好'}) for output in api_client.chat_completions_v1(model=model_name, messages=messages, top_k=1): if output.get('code') is not None and output.get('code') != 0: return False # Return True on first successful response return model_name == model_name_current return False # No output received except Exception: return False def get_model(url): print(url) try: api_client = APIClient(url) model_name = api_client.available_models[0] return model_name.split('/')[-1] except Exception: return None def _run_logprobs_test(port: int = DEFAULT_PORT): http_url = ':'.join([BASE_HTTP_URL, str(port)]) api_client = APIClient(http_url) model_name = api_client.available_models[0] output = None for output in api_client.chat_completions_v1(model=model_name, messages='Hi, pls intro yourself', max_tokens=5, temperature=0.01, logprobs=True, top_logprobs=10): continue if output is None: assert False, 'No output received from logprobs test' print(output) assert_chat_completions_batch_return(output, 
model_name, check_logprobs=True, logprobs_num=10) assert output.get('choices')[0].get('finish_reason') == 'length' assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5 PIC = 'tiger.jpeg' # noqa E501 PIC2 = 'human-pose.jpg' # noqa E501 def run_vl_testcase(log_path, resource_path, port: int = DEFAULT_PORT): http_url = ':'.join([BASE_HTTP_URL, str(port)]) model = get_model(http_url) if model is None: assert False, 'server not start correctly' client = OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1') model_name = client.models.list().data[0].id timestamp = time.strftime('%Y%m%d_%H%M%S') simple_model_name = model_name.split('/')[-1] restful_log = os.path.join(log_path, f'restful_vl_{simple_model_name}_{str(port)}_{timestamp}.log') # noqa file = open(restful_log, 'w') prompt_messages = [{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': f'{resource_path}/{PIC}', }, }, { 'type': 'image_url', 'image_url': { 'url': f'{resource_path}/{PIC2}', }, }], }] response = client.chat.completions.create(model=model_name, messages=prompt_messages, temperature=0.8, top_p=0.8) file.writelines(str(response).lower() + '\n') api_client = APIClient(http_url) model_name = api_client.available_models[0] for item in api_client.chat_completions_v1(model=model_name, messages=prompt_messages): continue file.writelines(str(item) + '\n') file.close() allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT) assert 'tiger' in str(response).lower() or '虎' in str(response).lower() or 'ski' in str( response).lower() or '滑雪' in str(response).lower(), response assert 'tiger' in str(item).lower() or '虎' in str(item).lower() or 'ski' in str(item).lower() or '滑雪' in str( item).lower(), item def _run_reasoning_case(log_path, port: int = DEFAULT_PORT): http_url = ':'.join([BASE_HTTP_URL, str(port)]) model = get_model(http_url) 
if model is None: assert False, 'server not start correctly' timestamp = time.strftime('%Y%m%d_%H%M%S') restful_log = os.path.join(log_path, f'restful_reasoning_{model}_{str(port)}_{timestamp}.log') file = open(restful_log, 'w') client = OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1') model_name = client.models.list().data[0].id with allure.step('step1 - stream'): messages = [{'role': 'user', 'content': '9.11 and 9.8, which is greater?'}] response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=True) outputList = [] final_content = '' final_reasoning_content = '' for stream_response in response: if stream_response.choices[0].delta.content is not None: final_content += stream_response.choices[0].delta.content if stream_response.choices[0].delta.reasoning_content is not None: final_reasoning_content += stream_response.choices[0].delta.reasoning_content outputList.append(stream_response) file.writelines(str(outputList) + '\n') with assume: assert '9.11' in final_reasoning_content and '9.11' in final_content and len(outputList) > 1, str( outputList) with allure.step('step2 - batch'): response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=False) print(response) reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content file.writelines(str(outputList) + '\n') with assume: assert '9.11' in reasoning_content and '9.11' in content and len(outputList) > 1, str(outputList) file.close() allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT) def test_internlm_multiple_round_prompt(client, model): def add(a: int, b: int): return a + b def mul(a: int, b: int): return a * b tools = [{ 'type': 'function', 'function': { 'name': 'add', 'description': 'Compute the sum of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': 
{ 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }, { 'type': 'function', 'function': { 'name': 'mul', 'description': 'Calculate the product of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': { 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }] messages = [{'role': 'user', 'content': 'Compute (3+5)*2'}] response = client.chat.completions.create(model=model, messages=messages, temperature=0.01, stream=False, tools=tools) print(response) response_list = [response] func1_name = response.choices[0].message.tool_calls[0].function.name func1_args = response.choices[0].message.tool_calls[0].function.arguments func1_args_dict = json.loads(func1_args) func1_out = add(**func1_args_dict) if func1_name == 'add' else mul(**func1_args_dict) with assume: assert response.choices[0].finish_reason == 'tool_calls' with assume: assert func1_name == 'add' with assume: assert func1_args == '{"a": 3, "b": 5}' with assume: assert func1_out == 8 with assume: assert response.choices[0].message.tool_calls[0].type == 'function' messages.append({'role': 'assistant', 'content': response.choices[0].message.content}) messages.append({'role': 'environment', 'content': f'3+5={func1_out}', 'name': 'plugin'}) response = client.chat.completions.create(model=model, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) response_list.append(response) func2_name = response.choices[0].message.tool_calls[0].function.name func2_args = response.choices[0].message.tool_calls[0].function.arguments func2_args_dict = json.loads(func2_args) func2_out = add(**func2_args_dict) if func2_name == 'add' else mul(**func2_args_dict) with assume: assert response.choices[0].finish_reason == 'tool_calls' with assume: assert func2_name == 'mul' with assume: assert func2_args == '{"a": 8, "b": 2}' with assume: assert func2_out == 16 with assume: assert 
response.choices[0].message.tool_calls[0].type == 'function' return response_list def test_qwen_multiple_round_prompt(client, model): def get_current_temperature(location: str, unit: str = 'celsius'): """Get current temperature at a location. Args: location: The location to get the temperature for, in the format "City, State, Country". unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) Returns: the temperature, the location, and the unit in a dict """ return { 'temperature': 26.1, 'location': location, 'unit': unit, } def get_temperature_date(location: str, date: str, unit: str = 'celsius'): """Get temperature at a location and date. Args: location: The location to get the temperature for, in the format 'City, State, Country'. date: The date to get the temperature for, in the format 'Year-Month-Day'. unit: The unit to return the temperature in. Defaults to 'celsius'. (choices: ['celsius', 'fahrenheit']) Returns: the temperature, the location, the date and the unit in a dict """ return { 'temperature': 25.9, 'location': location, 'date': date, 'unit': unit, } def get_function_by_name(name): if name == 'get_current_temperature': return get_current_temperature if name == 'get_temperature_date': return get_temperature_date tools = [{ 'type': 'function', 'function': { 'name': 'get_current_temperature', 'description': 'Get current temperature at a location.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' }, 'unit': { 'type': 'string', 'enum': ['celsius', 'fahrenheit'], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' 
} }, 'required': ['location'] } } }, { 'type': 'function', 'function': { 'name': 'get_temperature_date', 'description': 'Get temperature at a location and date.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' }, 'date': { 'type': 'string', 'description': 'The date to get the temperature for, in the format \'Year-Month-Day\'.' }, 'unit': { 'type': 'string', 'enum': ['celsius', 'fahrenheit'], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' } }, 'required': ['location', 'date'] } } }] messages = [{ 'role': 'user', 'content': 'Today is 2024-11-14, What\'s the temperature in San Francisco now? How about tomorrow?' }] response = client.chat.completions.create(model=model, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) response_list = [response] func1_name = response.choices[0].message.tool_calls[0].function.name func1_args = response.choices[0].message.tool_calls[0].function.arguments func2_name = response.choices[0].message.tool_calls[1].function.name func2_args = response.choices[0].message.tool_calls[1].function.arguments with assume: assert response.choices[0].finish_reason == 'tool_calls' assert func1_name == 'get_current_temperature' assert func1_args == '{"location": "San Francisco, CA, USA"}' \ or func1_args == '{"location": "San Francisco, California, USA", "unit": "celsius"}' assert func2_name == 'get_temperature_date' assert func2_args == '{"location": "San Francisco, CA, USA", "date": "2024-11-15"}' \ or func2_args == '{"location": "San Francisco, California, USA", "date": "2024-11-15", "unit": "celsius"}' assert response.choices[0].message.tool_calls[0].type == 'function' messages.append(response.choices[0].message) for tool_call in response.choices[0].message.tool_calls: tool_call_args = json.loads(tool_call.function.arguments) tool_call_result = 
get_function_by_name(tool_call.function.name)(**tool_call_args) messages.append({ 'role': 'tool', 'name': tool_call.function.name, 'content': tool_call_result, 'tool_call_id': tool_call.id }) response = client.chat.completions.create(model=model, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) response_list.append(response) with assume: assert response.choices[0].finish_reason == 'stop' assert '26.1' in response.choices[0].message.content return response_list def _run_tools_case(log_path, port: int = DEFAULT_PORT): http_url = ':'.join([BASE_HTTP_URL, str(port)]) model = get_model(http_url) if model is None: assert False, 'server not start correctly' timestamp = time.strftime('%Y%m%d_%H%M%S') restful_log = os.path.join(log_path, f'restful_toolcall_{model}_{str(port)}_{timestamp}.log') file = open(restful_log, 'w') client = OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1') model_name = client.models.list().data[0].id with open(restful_log, 'a') as file: with allure.step('step1 - one_round_prompt'): tools = [{ 'type': 'function', 'function': { 'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The city and state, e.g. 
San Francisco, CA', }, 'unit': { 'type': 'string', 'enum': ['celsius', 'fahrenheit'] }, }, 'required': ['location'], }, } }] messages = [{'role': 'user', 'content': 'What\'s the weather like in Boston today?'}] response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=False, tools=tools) print(response) with assume: assert response.choices[0].finish_reason == 'tool_calls' with assume: assert response.choices[0].message.tool_calls[0].function.name == 'get_current_weather' with assume: assert 'Boston' in response.choices[0].message.tool_calls[0].function.arguments with assume: assert response.choices[0].message.tool_calls[0].type == 'function' file.writelines(str(response) + '\n') with allure.step('step2 - search prompt'): tools = [{ 'type': 'function', 'function': { 'name': 'search', 'description': 'BING search API', 'parameters': { 'type': 'object', 'properties': { 'query': { 'type': 'string', 'description': 'list of search query strings' } }, 'required': ['location'] } } }] messages = [{'role': 'user', 'content': '搜索最近的人工智能发展趋势'}] response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=False, tools=tools) print(response) with assume: assert response.choices[0].finish_reason == 'tool_calls' with assume: assert response.choices[0].message.tool_calls[0].function.name == 'search' with assume: assert '人工智能' in response.choices[0].message.tool_calls[0].function.arguments with assume: assert response.choices[0].message.tool_calls[0].type == 'function' file.writelines(str(response) + '\n') with allure.step('step3 - multiple_round_prompt'): response_list = None if 'intern' in model.lower(): response_list = test_internlm_multiple_round_prompt(client, model_name) elif 'qwen' in model.lower(): response_list = test_qwen_multiple_round_prompt(client, model_name) if response_list is not None: file.writelines(str(response_list) + '\n') allure.attach.file(restful_log, name=restful_log, 
attachment_type=allure.attachment_type.TEXT) def proxy_health_check(url): """Check if proxy server is healthy.""" try: # For proxy server, we check if it responds to the /v1/models endpoint import requests response = requests.get(f'{url}/v1/models', timeout=5) if response.status_code == 200: return True return False except Exception: return False def start_proxy_server(log_path, port, case_name: str = 'default'): """Start the proxy server for testing with enhanced error handling and logging.""" if log_path is None: log_path = '/nvme/qa_test_models/evaluation_report' timestamp = time.strftime('%Y%m%d_%H%M%S') proxy_log = os.path.join(log_path, f'proxy_server_{case_name}_{str(port)}_{timestamp}.log') proxy_url = f'http://{DEFAULT_SERVER}:{port}' # noqa: E231, E261 try: response = requests.get(f'{proxy_url}/nodes/status', timeout=5) if response.status_code == 200: print(f'Terminating existing nodes on proxy {proxy_url}') requests.get(f'{proxy_url}/nodes/terminate_all', timeout=10) time.sleep(5) except requests.exceptions.RequestException: pass cmd = (f'lmdeploy serve proxy --server-name {DEFAULT_SERVER} --server-port {port} ' f'--routing-strategy min_expected_latency --serving-strategy Hybrid') print(f'Starting proxy server with command: {cmd}') print(f'Proxy log will be saved to: {proxy_log}') proxy_file = open(proxy_log, 'w') proxy_process = subprocess.Popen([cmd], stdout=proxy_file, stderr=proxy_file, shell=True, text=True, encoding='utf-8') pid = proxy_process.pid start_time = int(time.time()) timeout = 300 time.sleep(5) for i in range(timeout): time.sleep(1) if proxy_health_check(f'http://{DEFAULT_SERVER}:{port}'): # noqa: E231, E261 break try: # Check if process is still running return_code = proxy_process.wait(timeout=1) # Small timeout to check status if return_code != 0: with open(proxy_log, 'r') as f: content = f.read() print(content) return 0, proxy_process except subprocess.TimeoutExpired: continue end_time = int(time.time()) total_time = end_time - 
start_time if total_time >= timeout: break proxy_file.close() allure.attach.file(proxy_log, name=proxy_log, attachment_type=allure.attachment_type.TEXT) print(f'Proxy server started successfully with PID: {pid}') return pid, proxy_process def run_llm_test(config, run_config, common_case_config, worker_id): pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: case_name = get_case_str_by_config(run_config) run_all_step(config.get('log_path'), case_name, common_case_config, port=DEFAULT_PORT + get_workerid(worker_id)) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) def run_mllm_test(config, run_config, worker_id): pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: run_vl_testcase(config.get('log_path'), config.get('resource_path'), port=DEFAULT_PORT + get_workerid(worker_id)) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) def run_reasoning_case(config, run_config, worker_id): pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: _run_reasoning_case(config.get('log_path'), port=DEFAULT_PORT + get_workerid(worker_id)) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) def run_tools_case(config, run_config, worker_id): pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: _run_tools_case(config.get('log_path'), port=DEFAULT_PORT + get_workerid(worker_id)) else: assert False, f'Failed to start RESTful API server: {content}' finally: if pid > 0: terminate_restful_api(worker_id) def run_logprob_test(config, run_config, worker_id): pid, content = start_openai_service(config, run_config, worker_id) try: if pid > 0: _run_logprobs_test(port=DEFAULT_PORT + get_workerid(worker_id)) else: assert False, f'Failed to start RESTful API server: 
{content}' finally: if pid > 0: terminate_restful_api(worker_id) ================================================ FILE: autotest/utils/toolkit.py ================================================ from functools import lru_cache from transformers import AutoTokenizer def parse_sse_stream(content: str) -> list[str]: """Parse SSE (Server-Sent Events) stream content into a list of events. Each event is either a JSON string or "[DONE]". """ lines = content.strip().split('\n') events = [] for line in lines: line = line.strip() if line.startswith('data: '): data = line[6:] # remove "data: " if data.strip() == '[DONE]': events.append('[DONE]') else: events.append(data) return events @lru_cache(maxsize=4) def _load_tokenizer_cached(model_path: str): try: tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) return tokenizer except Exception as e: raise RuntimeError(f"Failed to load tokenizer from '{model_path}': {e}") def encode_text(model_path: str, text: str) -> list[int]: tokenizer = _load_tokenizer_cached(model_path) encoded = tokenizer.encode(text) return encoded ================================================ FILE: benchmark/README.md ================================================ # Benchmark We provide several profiling tools to benchmark our models. ## profile with dataset Download the dataset below or create your own dataset. ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` Profiling your model with `profile_throughput.py` ```bash python profile_throughput.py \ ShareGPT_V3_unfiltered_cleaned_split.json \ /path/to/your/model \ --concurrency 64 ``` ## profile restful api `profile_restful_api.py` is used to do benchmark on api server. 
```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json python3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json ``` ================================================ FILE: benchmark/benchmark_decode.py ================================================ import json import pickle import time from pathlib import Path import fire import numpy as np from transformers import AutoTokenizer from lmdeploy.pytorch.decode import Engine def benchmark(model_path, share_gpt_path, downsample=100, accel=None, save_to='decode_result'): """Benchmark using ShareGPT data. Please download `ShareGPT_V3_unfiltered_cleaned_split.json` as data for this benchmark. """ start = time.monotonic() content = json.load(open(share_gpt_path, 'r')) texts = [] for c in content: for cc in c['conversations']: texts.append(cc['value']) print(f'Parse json in {time.monotonic() - start} seconds.') tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = 'right' texts = texts[::downsample] input_ids = tokenizer(texts, padding=False).input_ids print(F'Number of prompts: {len(input_ids)}') print(F'Maximum length: {max(map(len, input_ids))}') print(F'Total length: {sum(map(len, input_ids))}') start = time.monotonic() # Init an engine engine = Engine(model_path, tokenizer=tokenizer, accel=accel) # decode prompts probs = engine.decode(input_ids) total_tokens = sum(map(len, input_ids)) elapsed = time.monotonic() - start print(f'Decoded {total_tokens} tokens in {elapsed:.1f} seconds, ' f'{total_tokens / elapsed:.1f} tokens/s.') print(f'Decoded {len(probs)} prompts in {elapsed:.1f} seconds, ' f'{len(probs) / elapsed:.1f} requests/s.') pkl_path = Path(save_to).with_suffix('.pkl') with pkl_path.open('wb') as f: pickle.dump(probs, f) txt_path = Path(save_to).with_suffix('.txt') np.savetxt(txt_path.as_posix(), probs, 
fmt='%.4e') if __name__ == '__main__': fire.Fire(benchmark) # llama-2 on 1 A100: # data = ShareGPT, downsample = 100 # Decoded 1579536 tokens in 175.3 seconds, 9012.821089984884 tokens/s. # Decoded 7022 prompts in 175.3 seconds, 40.067481648961376 requests/s. # llama-2 on 3 A100: # data = ShareGPT, downsample = 100 # Decoded 1579536 tokens in 77.9 seconds, 20268.736076299527 tokens/s. # Decoded 7022 prompts in 77.9 seconds, 90.10688248180179 requests/s. # llama-2 on 8 A100: # data = ShareGPT, downsample = 100 # Decoded 1579536 tokens in 55.2 seconds, 28630.35872677815 tokens/s. # Decoded 7022 prompts in 55.2 seconds, 127.27939026361929 requests/s. # llama-2 on 8 A100: # data = ShareGPT, downsample = 10 # Decoded 15991314 tokens in 242.7 seconds, 65893.38488718234 tokens/s. # Decoded 70216 prompts in 242.7 seconds, 289.33018970413536 requests/s. # Above time all includes time for workers to load model. ================================================ FILE: benchmark/benchmark_pipeline.py ================================================ import os import subprocess from typing import Dict, List import fire import yaml def get_cmd(model_path, backend, engine_config, data_config): assert backend in ['turbomind', 'pytorch'] current_dir = os.path.dirname(os.path.abspath(__file__)) dataset_path = data_config.pop('dataset_path') data_config.pop('dataset_name') cmd = ['python3', f'{current_dir}/profile_pipeline_api.py', dataset_path, model_path, '--backend', backend] for key, value in engine_config.items(): # profile_pipeline_api.py uses "--concurrency" to pass the "max_batch_size" value if key == 'max_batch_size': key = 'concurrency' # change the key like 'cache_max_entry_count' to 'cache-max-entry-count' to suit the optional # arguments in "python3 benchmark/profile_pipeline_api.py" key = key.replace('_', '-') cmd.append(f'--{key}') cmd.append(str(value)) for key, value in data_config.items(): # change the key like 'sharegpt_output_len' to 'sharegpt-output-len' to suit the 
optional # arguments in "python3 benchmark/profile_pipeline_api.py" key = key.replace('_', '-') cmd.append(f'--{key}') cmd.append(str(value)) return cmd def benchmark(model_path, backend, engine_config, data_config): """Benchmark the performance with the given configuration. Args: model_path: Path to the model. :param backend: Backend to use. :param engine_config: Configuration for the inference engine. :param data_config: Configuration for the data. """ model_name = os.path.basename(model_path) bs = engine_config['max_batch_size'] cach_ratio = engine_config.get('cache_max_entry_count', 0.8) tp = engine_config.get('tp', 1) output_file = f'benchmark_pipeline_{model_name}_{backend}_bs{bs}_tp{tp}_cache{cach_ratio}.csv' try: if isinstance(data_config, Dict): data_config = [data_config] assert isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config) for _data_config in data_config: _data_config['csv'] = output_file cmd = get_cmd(model_path, backend, engine_config, _data_config) print(f"Running command: {' '.join(cmd)}") subprocess.run(cmd, check=True) except Exception as e: print(f'exception happened, {e}') def main(model_path=None, backend=None, config_path=None): with open(config_path, 'r') as f: config = yaml.safe_load(f) engine_configs = config['engine'] data_config = config['data'] if isinstance(engine_configs, Dict): engine_configs = [engine_configs] assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs) for engine_config in engine_configs: # The model_path provided by the user will override the model_path in the config file. 
model_path = model_path or engine_config.pop('model_path') engine_config.pop('model_path', '') benchmark(model_path, backend, engine_config, data_config) if __name__ == '__main__': fire.Fire(main) ================================================ FILE: benchmark/benchmark_serving.py ================================================ import os import subprocess import time from typing import Dict, List, Optional, Tuple import fire import yaml def get_launching_server_cmd(model_path, backend, server_config): if backend in ['turbomind', 'pytorch']: cmd = ['lmdeploy', 'serve', 'api_server', model_path, '--backend', backend] elif backend == 'sglang': cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path] elif backend == 'vllm': cmd = ['vllm', 'serve', model_path] else: raise ValueError(f'unknown backend: {backend}') for key, value in server_config.items(): # Convert snake_case to kebab-case for command line args key = key.replace('_', '-') cmd.append(f'--{key}') if str(value): cmd.append(str(value)) # Special handling for proxy server case if server_config.get('proxy_url') and server_config.get('dp'): cmd.append('--allow-terminate-by-client') return cmd def get_output_file(model_path, backend, server_config): """Generate the benchmark output filename.""" model_name = server_config.get('model_name', None) or os.path.basename(model_path) if backend not in ['turbomind', 'pytorch', 'sglang', 'vllm']: raise ValueError(f'Unknown backend: {backend}') if backend in ['sglang', 'vllm']: return f'benchmark_{model_name}_{backend}.csv' # For turbomind/pytorch backends params = [ ('bs', server_config['max_batch_size']), ('tp', server_config.get('tp', 1)), ('dp', server_config.get('dp', '')), ('ep', server_config.get('ep', '')), ('cache', server_config.get('cache_max_entry_count', 0.8)), ('mptk', server_config.get('max_prefill_token_num', '')), ] params_str = '_'.join(f'{k}{v}' for k, v in params if v != '') # Turbomind-specific additions if backend == 'turbomind' 
and (comm := server_config.get('communicator')): params_str += f'_{comm}' return f'benchmark_{model_name}_{backend}_{params_str}.csv' def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]: if backend in ['turbomind', 'pytorch']: if server_config.get('proxy_url'): # If proxy_url is set, we use the proxy server's IP and port parts = server_config['proxy_url'].split(':') server_ip = parts[1].lstrip('//') server_port = int(parts[2]) else: # Default to the server IP and port specified in the config server_ip = server_config.get('server_ip', '0.0.0.0') server_port = server_config.get('server_port', 23333) elif backend == 'sglang': return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000)) elif backend == 'vllm': return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000)) else: raise ValueError(f'unknown backend: {backend}') return server_ip, server_port def wait_server_ready(server_ip: str, server_port: int) -> bool: """Wait for the API server to become ready.""" from openai import OpenAI while True: try: client = OpenAI(api_key='DUMMPY', base_url=f'http://{server_ip}:{server_port}/v1') model_name = client.models.list().data[0].id if model_name: print('Server is ready.') return True except Exception as e: print(f'connect to server http://{server_ip}:{server_port} failed {e}') time.sleep(5) def get_client_cmd(backend: str, server_ip: str, server_port: int, client_config: Dict) -> List[str]: """Generate the client benchmark command.""" current_dir = os.path.dirname(os.path.abspath(__file__)) if backend in ['turbomind', 'pytorch']: backend = 'lmdeploy' cmd = [ 'python3', f'{current_dir}/profile_restful_api.py', '--backend', backend, '--host', server_ip, '--port', str(server_port) ] for key, value in client_config.items(): # change the key like 'dataset_path' to 'dataset-path' to suit the optional when performing # "python3 benchmark/profile_restful_api.py" key = key.replace('_', '-') if key == 
'disable-warmup': if str(value).lower() == 'true': cmd.append(f'--{key}') continue cmd.append(f'--{key}') cmd.append(str(value)) return cmd def benchmark(model_path: str, backend: str, server_config: Dict, data_config: Dict | List[Dict]): """Benchmark the server with the given configuration. Args: model_path: Path to the model. backend: Backend to use. server_config: Configuration for the server and the inference engine. data_config: Configuration for the data. """ if isinstance(data_config, Dict): data_config = [data_config] if not (isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config)): raise ValueError('data_config must be a dict or list of dicts') server_cmd = get_launching_server_cmd(model_path, backend, server_config) server_ip, server_port = get_server_ip_port(backend, server_config) proc = None try: print(f"Starting api_server: {' '.join(server_cmd)}", flush=True) proc = subprocess.Popen(server_cmd) # Wait for the server to be ready wait_server_ready(server_ip, server_port) # Run benchmarks output_file = get_output_file(model_path, backend, server_config) for data in data_config: data = data.copy() data['output_file'] = output_file client_cmd = get_client_cmd(backend, server_ip, server_port, data) print(f"Running benchmark: {' '.join(client_cmd)}") subprocess.run(client_cmd, check=True) except Exception as e: print(f'Unexpected error: {e}') raise finally: # Clean up server process if proc and proc.poll() is None: if server_config.get('proxy_url') and server_config.get('dp'): # Sending termination request to proxy_server. 
The request will be broadcasted to # api_server on each dp_rank by proxy server # Note that api_server is supposed to be launched with --allow-terminate-by-client print('Sending termination request to proxy server') subprocess.run(['curl', '-X', 'POST', f'{server_config["proxy_url"]}/nodes/terminate_all'], check=True, timeout=10) proc.terminate() try: proc.wait(timeout=30) except subprocess.TimeoutExpired: print('Server did not terminate gracefully - killing') proc.kill() def validate_config(config: Dict) -> None: """Validate the configuration structure. Args: config: Loaded configuration dictionary Raises: BenchmarkConfigError: If configuration is invalid """ required_sections = ['api_server', 'engine', 'data'] for section in required_sections: if section not in config: raise ValueError(f'Missing required config section: {section}') if not isinstance(config['engine'], (Dict, List)): raise ValueError('engine config must be a dict or list of dicts') if not isinstance(config['data'], (Dict, List)): raise ValueError('data config must be a dict or list of dicts') def main(backend: str, config_path: str, model_path: Optional[str] = None): """Main entry point for the benchmark script. Args: backend: Backend to use config_path: Path to config file model_path: Optional override for model path Raises: BenchmarkConfigError: If required parameters are missing or config is invalid """ with open(config_path, 'r') as f: config = yaml.safe_load(f) server_config = config['server'] engine_configs = config['engine'] data_config = config['data'] if isinstance(engine_configs, Dict): engine_configs = [engine_configs] assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs) for engine_config in engine_configs: server_config = server_config.copy() server_config.update(engine_config) # Merge engine config with server config # The model_path provided by the user will override the model_path in the config file. 
model_path = model_path or server_config.pop('model_path') # Remove model_path from server_config to avoid passing it to the server command server_config.pop('model_path', None) benchmark(model_path, backend, server_config, data_config) if __name__ == '__main__': fire.Fire(main) ================================================ FILE: benchmark/benchmark_throughput.py ================================================ import os import subprocess from typing import Dict, List import fire import yaml def get_cmd(model_path, backend, engine_config, data_config): assert backend in ['turbomind', 'pytorch'] current_dir = os.path.dirname(os.path.abspath(__file__)) dataset_path = data_config.pop('dataset_path') cmd = ['python3', f'{current_dir}/profile_throughput.py', dataset_path, model_path, '--backend', backend] for key, value in engine_config.items(): # profile_throughput.py uses "--concurrency" to pass the "max_batch_size" value if key == 'max_batch_size': key = 'concurrency' # change the key like 'cache_max_entry_count' to 'cache-max-entry-count' to suit the optional # arguments in "python3 benchmark/profile_throughput.py" key = key.replace('_', '-') cmd.append(f'--{key}') cmd.append(str(value)) for key, value in data_config.items(): # change the key like 'sharegpt_output_len' to 'sharegpt-output-len' to suit the optional # arguments in "python3 benchmark/profile_throughput.py" key = key.replace('_', '-') cmd.append(f'--{key}') cmd.append(str(value)) return cmd def benchmark(model_path, backend, engine_config, data_config): """Benchmark the performance with the given configuration. Args: model_path: Path to the model. :param backend: Backend to use. :param engine_config: Configuration for the inference engine. :param data_config: Configuration for the data. 
""" model_name = os.path.basename(model_path) bs = engine_config['max_batch_size'] cach_ratio = engine_config.get('cache_max_entry_count', 0.8) tp = engine_config.get('tp', 1) output_file = f'benchmark_throughput_{model_name}_{backend}_bs{bs}_tp{tp}_cache{cach_ratio}.csv' try: if isinstance(data_config, Dict): data_config = [data_config] assert isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config) for _data_config in data_config: _data_config['csv'] = output_file cmd = get_cmd(model_path, backend, engine_config, _data_config) print(f"Running command: {' '.join(cmd)}") subprocess.run(cmd, check=True) except Exception as e: print(f'exception happened, {e}') def main(model_path=None, backend=None, config_path=None): with open(config_path, 'r') as f: config = yaml.safe_load(f) engine_configs = config['engine'] data_config = config['data'] if isinstance(engine_configs, Dict): engine_configs = [engine_configs] assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs) for engine_config in engine_configs: # The model_path provided by the user will override the model_path in the config file. 
model_path = model_path or engine_config.pop('model_path') engine_config.pop('model_path', '') benchmark(model_path, backend, engine_config, data_config) if __name__ == '__main__': fire.Fire(main) ================================================ FILE: benchmark/lmdeploy.yml ================================================ num_promts: &num_prompts 10000 dataset_path: &dataset_path "/nvme1/shared/ShareGPT_V3_unfiltered_cleaned_split.json" dataset_name: &dataset_name "sharegpt" model_path: &model_path "Qwen/Qwen3-30B-A3B-FP8" server: server_port: 23333 # Inference engine configuration engine: - model_path: *model_path max_batch_size: 1280 cache_max_entry_count: 0.9 tp: 1 - model_path: *model_path max_batch_size: 1280 cache_max_entry_count: 0.9 max_prefill_token_num: 4096 tp: 1 - model_path: "Qwen/Qwen3-235B-A22B-FP8" max_batch_size: 64 cache_max_entry_count: 0.7 max_prefill_token_num: 4096 dp: 8 ep: 8 proxy_url: "http://localhost:8000" # Benchmark test configuration for profile_restful_api.py # Defines multiple test cases with different output lengths to evaluate API performance data: - dataset_name: *dataset_name dataset_path: *dataset_path num_prompts: *num_prompts - dataset_name: *dataset_name dataset_path: *dataset_path sharegpt_output_len: 2048 num_prompts: *num_prompts - dataset_name: *dataset_name dataset_path: *dataset_path sharegpt_output_len: 4096 num_prompts: *num_prompts - dataset_name: *dataset_name dataset_path: *dataset_path sharegpt_output_len: 8192 num_prompts: *num_prompts - dataset_name: *dataset_name dataset_path: *dataset_path sharegpt_output_len: 16384 num_prompts: *num_prompts - dataset_name: *dataset_name dataset_path: *dataset_path sharegpt_output_len: 32768 num_prompts: *num_prompts ================================================ FILE: benchmark/profile_pipeline_api.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
import argparse import json import os import random from typing import List, Optional, Tuple import numpy as np from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter from lmdeploy.profiler import Profiler, Session from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError('output_len too small') # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) prompt_len = len(prompt_token_ids) output_len = (len(completion_token_ids) if fixed_output_len is None else fixed_output_len) if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None): # Prune too long sequences. 
continue filtered_dataset.append((prompt, prompt_len, output_len)) print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}') print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}') return filtered_dataset def sample_random_requests( input_len: int, output_len: int, num_prompts: int, range_ratio: float, tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, size=num_prompts, ) output_lens = np.random.randint( int(output_len * range_ratio), output_len + 1, size=num_prompts, ) if True: # Sample token ids from ShareGPT and repeat/truncate them to # satisfy the input_lens # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # remove the empty prompt dataset = [(query, answer) for query, answer in dataset if len(query) > 0] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short input_requests: List[Tuple[str, int, int]] = [] for i in range(num_prompts): # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) prompt_len = len(prompt_token_ids) if prompt_len > input_lens[i]: input_ids = prompt_token_ids[:input_lens[i]] else: ratio = (input_lens[i] + prompt_len - 1) // prompt_len input_ids = (prompt_token_ids * ratio)[:input_lens[i]] prompt = tokenizer.decode(input_ids) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) else: # Sample token ids from random integers. # This can cause some NaN issues. 
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) print(f'#Input tokens: {np.sum(input_lens)}') print(f'#Output tokens: {np.sum(output_lens)}') return input_requests class Engine: def __init__(self, model_path: str, engine_config, csv: str): self.pipe = pipeline(model_path, backend_config=engine_config, log_level='ERROR') self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) self.return_routed_experts = getattr(self.pipe.backend_config, 'enable_return_routed_experts', False) self.csv = csv def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output): prompts = [prompt for prompt, _, _ in requests] gen_configs = [ GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k, ignore_eos=True, do_sample=False, return_routed_experts=self.return_routed_experts, max_new_tokens=output_len) for _, _, output_len in requests ] sess: List[Session] = [] for _, input_len, output_len in requests: sess.append(profiler.new_session(input_len, output_len)) def _to_status(finish_reason): if finish_reason == 'length': return Session.SUCCESS else: return Session.FAIL profiler.start() for s in sess: s.tick(0) if stream_output: pbar = tqdm(total=len(requests)) for output in self.pipe.stream_infer(prompts, gen_config=gen_configs, do_preprocess=False): index = output.index n_token = output.generate_token_len finish_reason = output.finish_reason sess[index].tick(n_token) if finish_reason is not None: sess[index].finish(_to_status(finish_reason)) pbar.update(1) pbar.close() else: for output in self.pipe(prompts, gen_configs, do_preprocess=False, use_tqdm=True): index = output.index n_token = output.generate_token_len finish_reason = output.finish_reason sess[index].tick(n_token) 
sess[index].finish(_to_status(finish_reason)) profiler.finish() # report first failure for i, s in enumerate(sess): if s.status != Session.SUCCESS or s.ns[-1] < s.req_output_len: logger.error(f'Request {i} failed with {s.ns[-1]}/{s.req_output_len} tokens generated' # noqa: E501 ) logger.error(f'Prompt: {prompts[i]}') logger.warning('Got failed requests, metrics may be invalid') break def parse_args(): parser = argparse.ArgumentParser(description='Benchmark the request throughput of lmdeploy ' 'in localhost', formatter_class=DefaultsAndTypesHelpFormatter) parser.add_argument('dataset', type=str, help='the path dataset') parser.add_argument('model_path', type=str, help='the path of the model in localhost or ' 'the repo_id of the model in huggingface.co') parser.add_argument('-c', '--concurrency', type=int, help='Number of working threads to process the sampled prompts', default=256) parser.add_argument('-n', '--num-prompts', type=int, help='Number of prompts to process', default=5000) parser.add_argument('--csv', type=str, help='Where to save the result.', default='./profile_pipeline_api.csv') parser.add_argument('--seed', type=int, default=0, help='Seed used in sampling prompts from dataset') parser.add_argument('--stream-output', action='store_true', help='Trust remote code for loading hf models') parser.add_argument('--dataset-name', type=str, default='sharegpt', choices=['sharegpt', 'random'], help='Name of the dataset to benchmark on.') parser.add_argument( '--sharegpt-output-len', type=int, default=None, help='Output length for each request. 
Overrides the output length ' 'from the ShareGPT dataset.', ) parser.add_argument( '--random-input-len', type=int, help='Number of input tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-output-len', type=int, help='Number of output tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-range-ratio', type=float, default=0.0, help='Range of sampled ratio of input/output length, ' 'used only for random dataset.', ) # other args ArgumentHelper.top_p(parser) ArgumentHelper.temperature(parser) ArgumentHelper.top_k(parser) ArgumentHelper.log_level(parser) ArgumentHelper.backend(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) ArgumentHelper.enable_return_routed_experts(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) # turbomind engine args tb_group = parser.add_argument_group('TurboMind engine argument') tb_group._group_actions.append(tp_act) tb_group._group_actions.append(cache_count_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(cache_block_seq_len_act) tb_group._group_actions.append(prefix_caching_act) ArgumentHelper.model_format(tb_group, default='hf') ArgumentHelper.quant_policy(tb_group, default=0) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.communicator(tb_group) ArgumentHelper.async_(tb_group) args = parser.parse_args() return args def main(): args = parse_args() random.seed(args.seed) os.environ['TM_LOG_LEVEL'] = args.log_level if args.backend == 'turbomind': engine_config = TurbomindEngineConfig(max_batch_size=args.concurrency, tp=args.tp, 
cache_max_entry_count=args.cache_max_entry_count, session_len=args.session_len, cache_block_seq_len=args.cache_block_seq_len, model_format=args.model_format, quant_policy=args.quant_policy, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, enable_prefix_caching=args.enable_prefix_caching, communicator=args.communicator, enable_metrics=False, async_=args.async_) elif args.backend == 'pytorch': engine_config = PytorchEngineConfig( cache_max_entry_count=args.cache_max_entry_count, session_len=args.session_len, block_size=args.cache_block_seq_len, max_batch_size=args.concurrency, tp=args.tp, thread_safe=False, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, enable_return_routed_experts=args.enable_return_routed_experts, ) engine = Engine(args.model_path, engine_config, csv=args.csv) profiler = Profiler(args.stream_output, [50, 75, 95, 99]) if args.dataset_name == 'sharegpt': assert args.random_input_len is None and args.random_output_len is None requests = sample_sharegpt_requests( dataset_path=args.dataset, num_requests=args.num_prompts, tokenizer=engine.tokenizer, fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == 'random': assert args.random_input_len is not None and \ args.random_output_len is not None requests = sample_random_requests( input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, range_ratio=args.random_range_ratio, tokenizer=engine.tokenizer, dataset_path=args.dataset, ) else: raise ValueError(f'Unknown dataset: {args.dataset_name}') engine.process_request(requests, profiler, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, stream_output=args.stream_output) hyperparams = [('Concurrency', args.concurrency), ('Stream output', str(args.stream_output).lower())] profiler.compute_metrics() profiler.summarize(title='Profile Pipeline API', hyperparams=hyperparams) if args.csv: # profiler.save_csv(args.csv, 
(('batch', args.concurrency), ('num_prompts', args.num_prompts))) profiler.save_csv(args.csv, ( ('backend', args.backend), ('bs', args.concurrency), ('dataset_name', args.dataset_name), ('sharegpt_output_len', args.sharegpt_output_len), ('random_input_len', args.random_input_len), ('random_output_len', args.random_output_len), ('random_range_ratio', args.random_range_ratio), ('num_prompts', args.num_prompts), )) if __name__ == '__main__': main() ================================================ FILE: benchmark/profile_restful_api.py ================================================ # Modify from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_serving.py # noqa # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py # noqa # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py # noqa """Benchmark online serving with dynamic requests. 
Usage: python3 -m sglang.bench_serving --backend sglang --num-prompt 10 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi """ # noqa import argparse import asyncio import csv import io import json import os import random import resource import sys import time import traceback import warnings from argparse import ArgumentParser from dataclasses import dataclass, field from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union import aiohttp import numpy as np import pybase64 import requests from PIL import Image from tqdm.asyncio import tqdm from transformers import (AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None) _timeout_value = os.getenv('AIOHTTP_TIMEOUT', None) if _timeout_value is not None: try: _timeout_value = int(_timeout_value) if _timeout_value < 0: raise ValueError('AIOHTTP_TIMEOUT cannot be negative.') AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=_timeout_value * 60 * 60) except ValueError as e: print(f'Invalid AIOHTTP_TIMEOUT: {e}.') AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None) global args @dataclass class RequestFuncInput: prompt: str api_url: str prompt_len: int output_len: int model: str image_data: Optional[List[str]] extra_request_body: Dict[str, Any] @dataclass class RequestFuncOutput: generated_text: str = '' success: bool = False latency: float = 0.0 ttft: float = 0.0 # Time to first token itl: List[float] = field(default_factory=list) # List of inter-token latencies prompt_len: int = 0 output_len: int = 0 error: str = '' def remove_prefix(text: str, prefix: str) -> str: return text[len(prefix):] if 
text.startswith(prefix) else text # trt llm not support ignore_eos # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 async def async_request_trt_llm( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith('generate_stream') async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { 'accumulate_tokens': True, 'text_input': request_func_input.prompt, 'temperature': 0.000001, 'top_p': 1.0, 'max_tokens': request_func_input.output_len, 'stream': True, 'min_length': request_func_input.output_len, 'end_id': 1048576, **request_func_input.extra_request_body, } if args.disable_ignore_eos: del payload['min_length'] del payload['end_id'] output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data:') data = json.loads(chunk) output.generated_text += data['text_output'] timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp output.latency = most_recent_timestamp - st output.success = True output.output_len = request_func_input.output_len else: output.error = response.reason or '' output.success = False except Exception: output.success = False exc_info = sys.exc_info() output.error = ''.join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) return output # set ignore_eos True by default async def async_request_openai_completions( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: 
api_url = request_func_input.api_url assert api_url.endswith('completions'), "OpenAI Completions API URL must end with 'completions'." prompt = request_func_input.prompt async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { 'model': request_func_input.model, 'prompt': prompt, 'temperature': 0.0, 'best_of': 1, 'max_tokens': request_func_input.output_len, 'stream': not args.disable_stream, 'ignore_eos': not args.disable_ignore_eos, **request_func_input.extra_request_body, } headers = {'Authorization': f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len generated_text = '' ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ') latency = time.perf_counter() - st if chunk == '[DONE]': pass else: data = json.loads(chunk) # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated if data['choices'][0]['text']: timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += data['choices'][0]['text'] output.generated_text = generated_text output.success = True output.latency = latency output.output_len = request_func_input.output_len else: output.error = response.reason or '' output.success = False except Exception: output.success = False exc_info = sys.exc_info() output.error = ''.join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) return output async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, 
pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith('chat/completions'), "OpenAI Chat Completions API URL must end with 'chat/completions'." if request_func_input.image_data: # Build multi-image content: a list of image_url entries followed by the text content_items = [{ 'type': 'image_url', 'image_url': { 'url': img_url }, } for img_url in request_func_input.image_data] content_items.append({'type': 'text', 'text': request_func_input.prompt}) messages = [ { 'role': 'user', 'content': content_items, }, ] else: messages = [{'role': 'user', 'content': request_func_input.prompt}] async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { 'model': request_func_input.model, 'messages': messages, 'temperature': 0.0, 'max_completion_tokens': request_func_input.output_len, 'stream': not args.disable_stream, 'ignore_eos': not args.disable_ignore_eos, **request_func_input.extra_request_body, } headers = {'Authorization': f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len generated_text = '' output_len = request_func_input.output_len ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: if args.disable_stream: # Non-streaming response response_json = await response.json() output.generated_text = response_json['choices'][0]['message']['content'] output.success = True output.latency = time.perf_counter() - st output.ttft = (output.latency) # For non-streaming, TTFT = total latency output.output_len = response_json.get('usage', {}).get('completion_tokens', output_len) else: # Streaming response async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ') latency = time.perf_counter() - st if chunk 
== '[DONE]': pass else: data = json.loads(chunk) # Check if this chunk contains content delta = data.get('choices', [{}])[0].get('delta', {}) content = delta.get('content', '') if content: timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = timestamp - st output.ttft = ttft # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += content # Check for usage info in final chunk output_len = (data.get('usage') or {}).get('completion_tokens', output_len) output.generated_text = generated_text output.success = True output.latency = latency output.output_len = output_len else: output.error = ((response.reason or '') + ': ' + (await response.text())) output.success = False except Exception: output.success = False exc_info = sys.exc_info() output.error = ''.join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) return output async def async_request_sglang_generate( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url prompt = request_func_input.prompt async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { 'text': prompt, 'sampling_params': { 'temperature': 0.0, 'max_new_tokens': request_func_input.output_len, 'ignore_eos': not args.disable_ignore_eos, }, 'stream': not args.disable_stream, **request_func_input.extra_request_body, } headers = {} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len generated_text = '' ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue # print(chunk_bytes) chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ') latency = time.perf_counter() - st if chunk == '[DONE]': pass else: data = 
json.loads(chunk) # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated if data['text']: timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text = data['text'] output.generated_text = generated_text output.success = True output.latency = latency output.output_len = request_func_input.output_len else: output.error = response.reason or '' output.success = False except Exception: output.success = False exc_info = sys.exc_info() output.error = ''.join(traceback.format_exception(*exc_info)) if pbar: pbar.update(1) return output async def async_request_gserver( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: raise NotImplementedError() def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('SGLANG_USE_MODELSCOPE', 'False').lower() == 'true': import huggingface_hub.constants from modelscope import snapshot_download model_path = snapshot_download( model_id=pretrained_model_name_or_path, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ignore_file_pattern=['.*.pt', '.*.safetensors', '.*.bin'], ) return model_path return pretrained_model_name_or_path def get_tokenizer(pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path.endswith('.json') or pretrained_model_name_or_path.endswith('.model'): from sglang.srt.hf_transformers_utils import get_tokenizer return get_tokenizer(pretrained_model_name_or_path) if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path): pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) def 
get_processor(pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: assert (pretrained_model_name_or_path is not None and pretrained_model_name_or_path != '') if pretrained_model_name_or_path.endswith('.json') or pretrained_model_name_or_path.endswith('.model'): from sglang.srt.utils.hf_transformers_utils import get_processor return get_processor(pretrained_model_name_or_path) if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path): pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) return AutoProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) ASYNC_REQUEST_FUNCS = { 'sglang': async_request_sglang_generate, 'sglang-native': async_request_sglang_generate, 'sglang-oai': async_request_openai_completions, 'sglang-oai-chat': async_request_openai_chat_completions, 'vllm': async_request_openai_completions, 'vllm-chat': async_request_openai_chat_completions, 'lmdeploy': async_request_openai_completions, 'lmdeploy-chat': async_request_openai_chat_completions, 'trt': async_request_trt_llm, 'gserver': async_request_gserver, } @dataclass class BenchmarkMetrics: completed: int total_input: int total_input_text: int total_input_vision: int total_output: int total_output_retokenized: int request_throughput: float input_throughput: float output_throughput: float output_throughput_retokenized: float mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float p99_ttft_ms: float mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float p99_tpot_ms: float mean_itl_ms: float median_itl_ms: float std_itl_ms: float p99_itl_ms: float mean_e2e_latency_ms: float median_e2e_latency_ms: float SHAREGPT_URL = 'https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json' # noqa def download_and_cache_file(url: str, filename: Optional[str] = None): """Read and cache a file from a url.""" if 
filename is None: filename = os.path.join('/tmp', url.split('/')[-1]) # Check if the cache file already exists if os.path.exists(filename): return filename print(f'Downloading from {url} to {filename}') # Stream the response to show the progress bar response = requests.get(url, stream=True) response.raise_for_status() # Check for request errors # Total size of the file in bytes total_size = int(response.headers.get('content-length', 0)) chunk_size = 1024 # Download in chunks of 1KB # Use tqdm to display the progress bar with open(filename, 'wb') as f, tqdm( desc=filename, total=total_size, unit='B', unit_scale=True, unit_divisor=1024, ) as bar: for chunk in response.iter_content(chunk_size=chunk_size): f.write(chunk) bar.update(len(chunk)) return filename @dataclass class DatasetRow: prompt: str prompt_len: int output_len: int text_prompt_len: Optional[int] = None vision_prompt_len: Optional[int] = None image_data: Optional[List[str]] = None def __post_init__(self): if self.text_prompt_len is None: self.text_prompt_len = self.prompt_len if self.vision_prompt_len is None: self.vision_prompt_len = 0 def sample_sharegpt_requests(dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None) -> List[DatasetRow]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError('output_len too small') # Download sharegpt if necessary if not os.path.isfile(dataset_path): dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # Shuffle the dataset. 
random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[DatasetRow] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) prompt_len = len(prompt_token_ids) output_len = (len(completion_token_ids) if fixed_output_len is None else fixed_output_len) if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None): # Prune too long sequences. continue filtered_dataset.append(DatasetRow( prompt=prompt, prompt_len=prompt_len, output_len=output_len, )) print(f'#Input tokens: {sum(x.prompt_len for x in filtered_dataset)}') print(f'#Output tokens: {sum(x.output_len for x in filtered_dataset)}') return filtered_dataset def compute_random_lens(full_len: int, range_ratio: float, num: int): return np.random.randint( max(int(full_len * range_ratio), 1), full_len + 1, size=num, ) def sample_random_requests( input_len: int, output_len: int, num_prompts: int, range_ratio: float, tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[DatasetRow]: input_lens = compute_random_lens( full_len=input_len, range_ratio=range_ratio, num=num_prompts, ) output_lens = compute_random_lens( full_len=output_len, range_ratio=range_ratio, num=num_prompts, ) # Sample token ids from ShareGPT and repeat/truncate them to # satisfy the input_lens # Download sharegpt if necessary if not os.path.isfile(dataset_path): dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. 
dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # remove the empty prompt dataset = [(query, answer) for query, answer in dataset if len(query) > 0] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short input_requests: List[DatasetRow] = [] origin_output_lens: List[int] = [] for i in range(num_prompts): # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) prompt_len = len(prompt_token_ids) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) origin_output_lens.append(len(completion_token_ids)) if prompt_len > input_lens[i]: input_ids = prompt_token_ids[:input_lens[i]] else: ratio = (input_lens[i] + prompt_len - 1) // prompt_len input_ids = (prompt_token_ids * ratio)[:input_lens[i]] prompt = tokenizer.decode(input_ids) input_requests.append(DatasetRow( prompt=prompt, prompt_len=int(input_lens[i]), output_len=int(output_lens[i]), )) print(f'#Input tokens: {sum(x.prompt_len for x in input_requests)}') print(f'#Output tokens: {sum(x.output_len for x in input_requests)}') return input_requests def parse_image_resolution(image_resolution: str) -> Tuple[int, int]: """Parse image resolution into (width, height). Supports presets '1080p', '720p', '360p'. And custom 'heightxwidth' format (e.g., '1080x1920' means height=1080, width=1920) will be parsed into (width, height). """ resolution_to_size = { '4k': (3840, 2160), '1080p': (1920, 1080), '720p': (1280, 720), '360p': (640, 360), } if image_resolution in resolution_to_size: return resolution_to_size[image_resolution] res = image_resolution.strip().lower() if 'x' in res: parts = res.split('x') if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit(): height = int(parts[0]) width = int(parts[1]) if height > 0 and width > 0: return (width, height) raise ValueError(f'Unsupported image resolution: {image_resolution}. 
' "Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920).") def gen_mm_prompt(tokenizer, image_pad_id, token_num): """Generate a random prompt of specified token length using tokenizer vocabulary.""" all_available_tokens = list(tokenizer.get_vocab().values()) if image_pad_id: all_available_tokens.remove(image_pad_id) selected_tokens = random.choices(all_available_tokens, k=token_num) return tokenizer.decode(selected_tokens) def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor, backend): try: content_items = [{'type': 'image', 'image': {'url': image_base64}} for image_base64 in images_base64] content_items.append({'type': 'text', 'text': text_prompt}) prompt_str = processor.apply_chat_template( [{ 'role': 'user', 'content': content_items }], add_generation_prompt=True, tokenize=False, ) except Exception as e: # Note (Xinyuan): This is a workaround for an issue where some tokenizers # do not support content as a list. (e.g. 
InternVL) print(f'Error applying chat template: {e}, fallback to tag') # Some tokenizers do not support list content; fall back to a placeholder in the text prompt_str = f'{text_prompt}' # Calculate total tokens (text + vision) prompt_len = processor( text=[prompt_str], images=images, padding=False, return_tensors='pt', )['input_ids'].numel() # Calculate text-only tokens try: # Create text-only version of the prompt text_only_prompt = processor.apply_chat_template( [{ 'role': 'user', 'content': text_prompt }], add_generation_prompt=True, tokenize=False, ) text_prompt_len = processor( text=[text_only_prompt], padding=False, return_tensors='pt', )['input_ids'].numel() except Exception: # Fallback: just tokenize the text prompt directly tokenizer_to_use = (processor.tokenizer if hasattr(processor, 'tokenizer') else processor) text_prompt_len = len(tokenizer_to_use.encode(text_prompt)) # Vision tokens = total tokens - text tokens vision_prompt_len = prompt_len - text_prompt_len use_raw_prompt = backend in [ 'sglang', 'sglang-oai', 'sglang-oai-chat', 'vllm', 'vllm-chat', 'lmdeploy', 'lmdeploy-chat', ] return DatasetRow( prompt=text_prompt if use_raw_prompt else prompt_str, prompt_len=prompt_len, output_len=output_len, text_prompt_len=text_prompt_len, vision_prompt_len=vision_prompt_len, image_data=images_base64, ) def sample_image_requests( num_requests: int, image_count: int, input_len: int, output_len: int, range_ratio: float, processor: AutoProcessor, image_content: str, image_format: str, image_resolution: str, backend: str, ) -> List[DatasetRow]: """Generate requests with images. - Each request includes ``image_count`` images. - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360), or custom 'heightxwidth' (e.g., 1080x1920). - Text lengths follow the 'random' dataset sampling rule. ``prompt_len`` only counts text tokens and excludes image data. 
""" # Parse resolution (supports presets and 'heightxwidth') width, height = parse_image_resolution(image_resolution) # Check for potentially problematic combinations and warn user if width * height >= 1920 * 1080 and image_count * num_requests >= 100: warnings.warn( f'High resolution ({width}x{height}) with {image_count * num_requests} total images ' f'may take a long time. Consider reducing resolution or image count.', UserWarning, stacklevel=2, ) # Sample text lengths input_lens = compute_random_lens( full_len=input_len, range_ratio=range_ratio, num=num_requests, ) output_lens = compute_random_lens( full_len=output_len, range_ratio=range_ratio, num=num_requests, ) def _gen_random_image_data_uri(width: int = width, height: int = height) -> Tuple[Image.Image, str, int]: if image_content == 'blank': # Generate blank white image arr = np.full((height, width, 3), 255, dtype=np.uint8) else: # Generate random colored image arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8) img = Image.fromarray(arr) buf = io.BytesIO() img.save(buf, format=image_format, quality=85) encoded = pybase64.b64encode(buf.getvalue()).decode('utf-8') image_data = f'data:image/{image_format};base64,{encoded}' # noqa image_bytes = len(image_data.encode('utf-8')) return img, image_data, image_bytes dataset: List[DatasetRow] = [] total_image_bytes = 0 for i in range(num_requests): # Generate text prompt text_prompt = gen_mm_prompt( processor.tokenizer, processor.image_token_id if hasattr(processor, 'image_token_id') else None, int(input_lens[i]), ) # Generate image list images, images_base64, images_bytes = zip(*[_gen_random_image_data_uri() for _ in range(image_count)]) total_image_bytes += sum(list(images_bytes)) data_row = create_mm_data_row( text_prompt, list(images), list(images_base64), int(output_lens[i]), processor, backend, ) dataset.append(data_row) avg_image_bytes = total_image_bytes // num_requests if num_requests > 0 else 0 print(f'#Input tokens: {np.sum([x.prompt_len for x 
in dataset])}') print(f'#Output tokens: {np.sum([x.output_len for x in dataset])}') print(f'\nCreated {len(dataset)} {image_content} {image_format} images \ with average {avg_image_bytes} bytes per request') # noqa return dataset async def get_request( input_requests: List[DatasetRow], request_rate: float, ) -> AsyncGenerator[DatasetRow, None]: input_requests = iter(input_requests) for request in input_requests: yield request if request_rate == float('inf'): # If the request rate is infinity, then we don't need to wait. continue # Sample the request interval from the exponential distribution. interval = np.random.exponential(1.0 / request_rate) # The next request will be sent after the interval. await asyncio.sleep(interval) def calculate_metrics( input_requests: List[DatasetRow], outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, backend: str, ) -> Tuple[BenchmarkMetrics, List[int]]: output_lens: List[int] = [] retokenized_output_lens: List[int] = [] total_input = 0 total_input_text = 0 total_input_vision = 0 completed = 0 itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] e2e_latencies: List[float] = [] for i in range(len(outputs)): if outputs[i].success: output_len = outputs[i].output_len output_lens.append(output_len) retokenized_output_len = len(tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)) retokenized_output_lens.append(retokenized_output_len) total_input += input_requests[i].prompt_len total_input_text += input_requests[i].text_prompt_len total_input_vision += input_requests[i].vision_prompt_len if output_len > 1: tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl ttfts.append(outputs[i].ttft) e2e_latencies.append(outputs[i].latency) completed += 1 else: output_lens.append(0) retokenized_output_lens.append(0) if completed == 0: warnings.warn( 'All requests failed. 
This is likely due to a misconfiguration ' 'on the benchmark arguments.', stacklevel=2, ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, total_input_text=total_input_text, total_input_vision=total_input_vision, total_output=sum(output_lens), total_output_retokenized=sum(retokenized_output_lens), request_throughput=completed / dur_s, input_throughput=total_input / dur_s, output_throughput=sum(output_lens) / dur_s, output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, p99_itl_ms=np.percentile(itls or 0, 99) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000, ) return metrics, output_lens async def benchmark( backend: str, api_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[DatasetRow], request_rate: float, disable_tqdm: bool, extra_request_body: Dict[str, Any], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] else: raise ValueError(f'Unknown backend: {backend}') if not args.disable_warmup: print('Starting initial single prompt test run...') start_warmup = time.perf_counter() test_request = input_requests[0] test_input = RequestFuncInput( model=model_id, prompt=test_request.prompt, api_url=api_url, prompt_len=test_request.prompt_len, output_len=test_request.output_len, extra_request_body=extra_request_body, image_data=test_request.image_data, ) test_output = await 
request_func(request_func_input=test_input) if not test_output.success: raise ValueError('Initial test run failed - Please make sure benchmark arguments ' f'are correctly specified. Error: {test_output.error}') else: print('Initial test run completed. Starting main benchmark run...') end_warmup = time.perf_counter() print(f'warmup time: {end_warmup - start_warmup:.2f}s') time.sleep(1.5) pbar = None if disable_tqdm else tqdm(total=len(input_requests)) benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): request_func_input = RequestFuncInput( model=model_id, prompt=request.prompt, api_url=api_url, prompt_len=request.prompt_len, output_len=request.output_len, image_data=request.image_data, extra_request_body=extra_request_body, ) tasks.append(asyncio.create_task(request_func(request_func_input=request_func_input, pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) if pbar is not None: pbar.close() benchmark_duration = time.perf_counter() - benchmark_start_time metrics, output_lens = calculate_metrics( input_requests=input_requests, outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, backend=backend, ) print('\n{s:{c}^{n}}'.format(s=' Serving Benchmark Result ', n=50, c='=')) print('{:<40} {:<10}'.format('Backend:', backend)) print('{:<40} {:<10}'.format('Traffic request rate:', request_rate)) print('{:<40} {:<10}'.format('Successful requests:', metrics.completed)) print('{:<40} {:<10.2f}'.format('Benchmark duration (s):', benchmark_duration)) print('{:<40} {:<10}'.format('Total input tokens:', metrics.total_input)) print('{:<40} {:<10}'.format('Total input text tokens:', metrics.total_input_text)) print('{:<40} {:<10}'.format('Total input vision tokens:', metrics.total_input_vision)) print('{:<40} {:<10}'.format('Total generated tokens:', metrics.total_output)) print('{:<40} {:<10}'.format('Total generated tokens (retokenized):', 
metrics.total_output_retokenized)) print('{:<40} {:<10.2f}'.format('Request throughput (req/s):', metrics.request_throughput)) print('{:<40} {:<10.2f}'.format('Input token throughput (tok/s):', metrics.input_throughput)) print('{:<40} {:<10.2f}'.format('Output token throughput (tok/s):', metrics.output_throughput)) print('{s:{c}^{n}}'.format(s='End-to-End Latency', n=50, c='-')) print('{:<40} {:<10.2f}'.format('Mean E2E Latency (ms):', metrics.mean_e2e_latency_ms)) print('{:<40} {:<10.2f}'.format('Median E2E Latency (ms):', metrics.median_e2e_latency_ms)) print('{s:{c}^{n}}'.format(s='Time to First Token', n=50, c='-')) print('{:<40} {:<10.2f}'.format('Mean TTFT (ms):', metrics.mean_ttft_ms)) print('{:<40} {:<10.2f}'.format('Median TTFT (ms):', metrics.median_ttft_ms)) print('{:<40} {:<10.2f}'.format('P99 TTFT (ms):', metrics.p99_ttft_ms)) print('{s:{c}^{n}}'.format(s='Time per Output Token (excl. 1st token)', n=50, c='-')) print('{:<40} {:<10.2f}'.format('Mean TPOT (ms):', metrics.mean_tpot_ms)) print('{:<40} {:<10.2f}'.format('Median TPOT (ms):', metrics.median_tpot_ms)) print('{:<40} {:<10.2f}'.format('P99 TPOT (ms):', metrics.p99_tpot_ms)) print('{s:{c}^{n}}'.format(s='Inter-token Latency', n=50, c='-')) print('{:<40} {:<10.2f}'.format('Mean ITL (ms):', metrics.mean_itl_ms)) print('{:<40} {:<10.2f}'.format('Median ITL (ms):', metrics.median_itl_ms)) print('{:<40} {:<10.2f}'.format('P99 ITL (ms):', metrics.p99_itl_ms)) print('=' * 50) if (metrics.median_ttft_ms is not None and metrics.mean_itl_ms is not None and metrics.output_throughput is not None): FIELD_NAMES = [ 'backend', 'dataset_name', 'sharegpt_output_len', 'random_input_len', 'random_output_len', 'random_range_ratio', 'request_rate', 'completed', 'total_input_tokens', 'total_output_tokens', 'duration', 'request_throughput', 'input_throughput', 'output_throughput', 'mean_e2e_latency_ms', 'mean_ttft_ms', 'mean_tpot_ms', 'mean_itl_ms' ] result = { 'backend': args.backend, 'dataset_name': 
args.dataset_name, 'request_rate': request_rate, 'total_input_tokens': metrics.total_input, 'total_output_tokens': metrics.total_output, 'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms, 'output_throughput': metrics.output_throughput, 'sharegpt_output_len': args.sharegpt_output_len, 'random_input_len': args.random_input_len, 'random_output_len': args.random_output_len, 'random_range_ratio': args.random_range_ratio, 'duration': benchmark_duration, 'completed': metrics.completed, 'request_throughput': metrics.request_throughput, 'input_throughput': metrics.input_throughput, 'mean_ttft_ms': metrics.mean_ttft_ms, 'mean_tpot_ms': metrics.mean_tpot_ms, 'mean_itl_ms': metrics.mean_itl_ms, } else: print(f'Error running benchmark for request rate: {request_rate}') print('-' * 30) # Determine output file name if args.output_file: output_file_name = args.output_file else: now = datetime.now().strftime('%m%d') if args.dataset_name == 'random': output_file_name = f'{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl' # noqa else: output_file_name = f'{args.backend}_{now}_{args.num_prompts}_sharegpt.csv' # noqa # Append results to a CSV file file_exists = os.path.isfile(output_file_name) with open(output_file_name, mode='a', newline='') as f: writer = csv.DictWriter(f, fieldnames=FIELD_NAMES) if not file_exists: writer.writeheader() writer.writerow(result) result = { 'duration': benchmark_duration, 'completed': metrics.completed, 'total_input_tokens': metrics.total_input, 'total_output_tokens': metrics.total_output, 'total_output_tokens_retokenized': metrics.total_output_retokenized, 'request_throughput': metrics.request_throughput, 'input_throughput': metrics.input_throughput, 'output_throughput': metrics.output_throughput, 'mean_ttft_ms': metrics.mean_ttft_ms, 'median_ttft_ms': metrics.median_ttft_ms, 'std_ttft_ms': metrics.std_ttft_ms, 'p99_ttft_ms': metrics.p99_ttft_ms, 'mean_tpot_ms': metrics.mean_tpot_ms, 'median_tpot_ms': 
metrics.median_tpot_ms, 'std_tpot_ms': metrics.std_tpot_ms, 'p99_tpot_ms': metrics.p99_tpot_ms, 'mean_itl_ms': metrics.mean_itl_ms, 'median_itl_ms': metrics.median_itl_ms, 'std_itl_ms': metrics.std_itl_ms, 'p99_itl_ms': metrics.p99_itl_ms, 'input_lens': [output.prompt_len for output in outputs], 'output_lens': output_lens, 'ttfts': [output.ttft for output in outputs], 'itls': [output.itl for output in outputs], 'generated_texts': [output.generated_text for output in outputs], 'errors': [output.error for output in outputs], 'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms, 'median_e2e_latency_ms': metrics.median_e2e_latency_ms, } return result def parse_request_rate_range(request_rate_range): if len(request_rate_range.split(',')) == 3: start, stop, step = map(int, request_rate_range.split(',')) return list(range(start, stop, step)) else: return list(map(int, request_rate_range.split(','))) def check_chat_template(model_path): try: tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) return 'chat_template' in tokenizer.init_kwargs except Exception as e: print(f'Fail to load tokenizer config with error={e}') return False def run_benchmark(args_: argparse.Namespace): global args args = args_ # Set global environments set_ulimit() random.seed(args.seed) np.random.seed(args.seed) extra_request_body = {} if args.extra_request_body: extra_request_body = json.loads(args.extra_request_body) # Set url if args.port is None: args.port = { 'sglang': 30000, 'sglang-native': 30000, 'sglang-oai': 30000, 'sglang-oai-chat': 30000, 'lmdeploy': 23333, 'lmdeploy-chat': 23333, 'vllm': 8000, 'vllm-chat': 8000, 'trt': 8000, 'gserver': 9988, }.get(args.backend, 30000) model_url = (f'{args.base_url}/v1/models' if args.base_url else f'http://{args.host}:{args.port}/v1/models') if args.backend in ['sglang', 'sglang-native']: api_url = (f'{args.base_url}/generate' if args.base_url else f'http://{args.host}:{args.port}/generate') elif args.backend in ['sglang-oai', 
'vllm', 'lmdeploy']: api_url = (f'{args.base_url}/v1/completions' if args.base_url else f'http://{args.host}:{args.port}/v1/completions') elif args.backend in ['lmdeploy-chat', 'vllm-chat', 'sglang-oai-chat']: api_url = (f'{args.base_url}/v1/chat/completions' if args.base_url else f'http://{args.host}:{args.port}/v1/chat/completions') elif args.backend == 'trt': api_url = ( f'{args.base_url}/v2/models/ensemble/generate_stream' if args.base_url else f'http://{args.host}:{args.port}/v2/models/ensemble/generate_stream' # noqa ) if args.model is None: print('Please provide a model using `--model` when using ' '`trt` backend.') sys.exit(1) elif args.backend == 'gserver': api_url = args.base_url if args.base_url else \ f'{args.host}:{args.port}' args.model = args.model or 'default' # Get model name if args.model is None: try: response = requests.get(model_url) model_list = response.json().get('data', []) args.model = model_list[0]['id'] if model_list else None except Exception as e: print(f'Failed to fetch model from {model_url}. Error: {e}') print('Please specify the correct host and port using ' '`--host` and `--port`.') sys.exit(1) # Read dataset backend = args.backend model_id = args.model model_path = args.model_path if args.model_path is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else model_path if args.model is None: print('No model specified or found. 
Please provide a model ' 'using `--model`.') sys.exit(1) if not check_chat_template(model_path): print('\nWARNING It is recommended to use the `Chat` or `Instruct` ' 'model for benchmarking.\n' 'Because when the tokenizer counts the output tokens, if ' 'there is gibberish, it might count incorrectly.\n') print(f'{args}\n') tokenizer = get_tokenizer(tokenizer_id) if args.dataset_name == 'sharegpt': assert args.random_input_len is None and args.random_output_len is None input_requests = sample_sharegpt_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == 'random': assert args.random_input_len is not None and \ args.random_output_len is not None input_requests = sample_random_requests( input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, range_ratio=args.random_range_ratio, tokenizer=tokenizer, dataset_path=args.dataset_path, ) elif args.dataset_name == 'image': processor = get_processor(model_path) input_requests = sample_image_requests( num_requests=args.num_prompts, image_count=args.image_count, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, processor=processor, image_content=args.image_content, image_format=args.image_format, image_resolution=args.image_resolution, backend=args.backend, ) else: raise ValueError(f'Unknown dataset: {args.dataset_name}') if not args.multi: return asyncio.run( benchmark( backend=backend, api_url=api_url, model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, extra_request_body=extra_request_body, )) else: # Benchmark multiple rps. 
# TODO: use a fixed duration to compute num_prompts request_rates = parse_request_rate_range(args.request_rate_range) for rate in request_rates: asyncio.run( benchmark( backend=backend, api_url=api_url, model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, request_rate=rate, disable_tqdm=args.disable_tqdm, extra_request_body=extra_request_body, )) def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) if current_soft < target_soft_limit: try: resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: print(f'Fail to set RLIMIT_NOFILE: {e}') if __name__ == '__main__': parser = ArgumentParser(description='Benchmark the online serving throughput.') parser.add_argument( '--backend', type=str, choices=list(ASYNC_REQUEST_FUNCS.keys()), default='sglang', help='Must specify a backend, depending on the LLM Inference Engine.', ) parser.add_argument( '--base-url', type=str, default=None, help='Server or API base url if not using http host and port.', ) parser.add_argument('--host', type=str, default='0.0.0.0', help='Default host is 0.0.0.0.') parser.add_argument( '--port', type=int, help='If not set, the default port is configured according to its ' 'default value for different LLM Inference Engines.', ) parser.add_argument( '--dataset-name', type=str, default='sharegpt', choices=['sharegpt', 'random', 'image'], help='Name of the dataset to benchmark on.', ) parser.add_argument('--dataset-path', type=str, default='', help='Path to the dataset.') parser.add_argument( '--model', type=str, help='Name or path of the model. If not set, the default model will ' 'request /v1/models for conf.', ) parser.add_argument( '--model-path', type=str, help='Path to the model. If not set, the default model will be model', ) parser.add_argument( '--tokenizer', type=str, help='Name or path of the tokenizer. 
If not set, using the model ' 'conf.', ) parser.add_argument( '--num-prompts', type=int, default=1000, help='Number of prompts to process. Default is 1000.', ) parser.add_argument( '--sharegpt-output-len', type=int, default=None, help='Output length for each request. Overrides the output length ' 'from the ShareGPT dataset.', ) parser.add_argument( '--random-input-len', type=int, help='Number of input tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-output-len', type=int, help='Number of output tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-range-ratio', type=float, default=0.0, help='Range of sampled ratio of input/output length, ' 'used only for random dataset.', ) # image dataset args parser.add_argument( '--image-count', type=int, default=1, help='Number of images per request (only available with the image dataset)', ) parser.add_argument( '--image-resolution', type=str, default='1080p', help=('Resolution of images for image dataset. ' "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."), ) parser.add_argument( '--image-format', type=str, default='jpeg', help=('Format of images for image dataset. ' 'Supports jpeg and png.'), ) parser.add_argument( '--image-content', type=str, default='random', help=('Content for images for image dataset. ' 'Supports random and blank.'), ) parser.add_argument( '--request-rate', type=float, default=float('inf'), help='Number of requests per second. If this is inf, then all the ' 'requests are sent at time 0. Otherwise, we use Poisson process to ' 'synthesize the request arrival times. Default is inf.', ) parser.add_argument('--seed', type=int, default=1, help='The random seed.') parser.add_argument( '--multi', action='store_true', help='Use request rate range rather than single value.', ) parser.add_argument( '--request-rate-range', type=str, default='2,34,2', help='Range of request rates in the format start,stop,step. 
Default ' 'is 2,34,2. It also supports a list of request rates, requiring ' 'the parameters to not equal three.', ) parser.add_argument('--output-file', type=str, help='Output JSONL file name.') parser.add_argument( '--disable-tqdm', action='store_true', help='Specify to disable tqdm progress bar.', ) parser.add_argument( '--disable-stream', action='store_true', help='Disable streaming mode.', ) parser.add_argument( '--disable-ignore-eos', action='store_true', help='Disable ignoring EOS.', ) parser.add_argument( '--extra-request-body', metavar='{"key1": "value1", "key2": "value2"}', type=str, help='Append given JSON object to the request payload. You can use ' 'this to specify additional generate params like sampling params.', ) parser.add_argument( '--disable-warmup', action='store_true', default=None, help='Disable a warmup request before the benchmark. ', ) args = parser.parse_args() run_benchmark(args) ================================================ FILE: benchmark/profile_throughput.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
import argparse import asyncio import json import os import random from queue import Queue from typing import List, Optional, Tuple, Union import numpy as np from tqdm import tqdm from transformers import PreTrainedTokenizerBase from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.profiler import Profiler, Session from lmdeploy.tokenizer import DetokenizeState, Tokenizer from lmdeploy.utils import get_logger get_logger('lmdeploy').setLevel('ERROR') os.environ['TM_LOG_LEVEL'] = 'ERROR' def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError('output_len too small') # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) prompt_len = len(prompt_token_ids) output_len = (len(completion_token_ids) if fixed_output_len is None else fixed_output_len) if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None): # Prune too long sequences. 
continue filtered_dataset.append((prompt, prompt_len, output_len)) print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}') print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}') return filtered_dataset def sample_random_requests( input_len: int, output_len: int, num_prompts: int, range_ratio: float, tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, size=num_prompts, ) output_lens = np.random.randint( int(output_len * range_ratio), output_len + 1, size=num_prompts, ) if True: # Sample token ids from ShareGPT and repeat/truncate them to # satisfy the input_lens # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data['conversations']) >= 2] # Only keep the first two turns of each conversation. dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] # remove the empty prompt dataset = [(query, answer) for query, answer in dataset if len(query) > 0] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short input_requests: List[Tuple[str, int, int]] = [] for i in range(num_prompts): # Tokenize the prompts and completions. prompt = dataset[i][0] prompt_token_ids = tokenizer.encode(prompt) prompt_len = len(prompt_token_ids) if prompt_len > input_lens[i]: input_ids = prompt_token_ids[:input_lens[i]] else: ratio = (input_lens[i] + prompt_len - 1) // prompt_len input_ids = (prompt_token_ids * ratio)[:input_lens[i]] prompt = tokenizer.decode(input_ids) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) else: # Sample token ids from random integers. # This can cause some NaN issues. 
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) print(f'#Input tokens: {np.sum(input_lens)}') print(f'#Output tokens: {np.sum(output_lens)}') return input_requests class Engine: def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, TurbomindEngineConfig]): self.tokenizer = Tokenizer(model_path) if isinstance(engine_config, TurbomindEngineConfig): from lmdeploy.turbomind import TurboMind tm_model = TurboMind.from_pretrained(model_path, engine_config=engine_config) self.backend = 'turbomind' elif isinstance(engine_config, PytorchEngineConfig): from lmdeploy.pytorch.engine import Engine as PytorchEngine tm_model = PytorchEngine.from_pretrained(model_path, engine_config=engine_config) self.backend = 'pytorch' self.tm_model = tm_model self.pbar = None async def _inference(self, req_queue: Queue, session_id: int, temperature: float, top_p: float, top_k: int, stream_output: bool, skip_tokenize: bool, skip_detokenize: bool, concurrency: int): model_inst = self.tm_model.create_instance() sess: Session = None for prompt, _, output_seqlen, cancel_after, sess in iter(req_queue.get_nowait, None): sess.tick(0) if skip_tokenize: input_ids = prompt else: input_ids = self.tokenizer(prompt).input_ids state = DetokenizeState(len(input_ids)) n_token = 0 token_ids = input_ids.copy() generator = model_inst.async_stream_infer(session_id, input_ids=input_ids, gen_config=GenerationConfig(max_new_tokens=output_seqlen, temperature=temperature, top_p=top_p, top_k=top_k, ignore_eos=True), sequence_start=True, sequence_end=True, stream_output=stream_output) try: async for outputs in generator: n_token += len(outputs.token_ids) token_ids += outputs.token_ids if not skip_detokenize: _, state = 
self.tokenizer.detokenize_incrementally(token_ids, state) sess.tick(n_token) if n_token > cancel_after: break sess.finish(Session.SUCCESS) finally: await generator.aclose() # for pytorch engine to restart a session if self.backend == 'pytorch': await model_inst.async_end(session_id) self.pbar.update(1) session_id += concurrency def process_request(self, requests, profiler: Profiler, concurrency, temperature, top_p, top_k, stream_output, skip_tokenize, skip_detokenize, cancel_rate): req_queue = Queue() # feed request to q for prompt, input_len, output_len in requests: cancel_after = output_len + 1 if cancel_rate > 0: if random.random() < cancel_rate: cancel_after = random.randint(0, cancel_after) sess = profiler.new_session(input_len, output_len) req = [prompt, input_len, output_len, cancel_after, sess] if skip_tokenize: req[0] = self.tokenizer.encode(prompt) req_queue.put(req) for i in range(concurrency): req_queue.put(None) # start threads tasks = [] for i in range(concurrency): task = self._inference(req_queue, i, temperature, top_p, top_k, stream_output, skip_tokenize, skip_detokenize, concurrency) tasks.append(task) async def _gather_tasks(tasks): profiler.start() ret = await asyncio.gather(*tasks) profiler.finish() return ret self.pbar = tqdm(total=len(requests)) asyncio.run(_gather_tasks(tasks)) self.pbar.close() def parse_args(): parser = argparse.ArgumentParser(description='Benchmark the request throughput of lmdeploy ' 'in localhost', formatter_class=DefaultsAndTypesHelpFormatter) parser.add_argument('dataset', type=str, help='the path dataset') parser.add_argument('model_path', type=str, help='the path of the model in localhost or ' 'the repo_id of the model in huggingface.co') parser.add_argument('-c', '--concurrency', type=int, help='Number of working threads to process the sampled prompts', default=256) parser.add_argument('-n', '--num-prompts', type=int, help='Number of prompts to process', default=5000) parser.add_argument('--no-stream-output', 
action='store_true', help='Use stream output') parser.add_argument('--skip-tokenize', action='store_true', help='Pre-tokenize input prompts before starting') parser.add_argument('--skip-detokenize', action='store_true', help='Skip detokenizing output tokens') parser.add_argument('--cancel-rate', type=float, help='Possibility of a request being canceled', default=0) parser.add_argument('--use-uvloop', action='store_true') parser.add_argument('--csv', type=str, help='Where to save the result.', default='./profile_throughput.csv') parser.add_argument('--seed', type=int, default=0, help='Seed used in sampling prompts from dataset') parser.add_argument('--distributed-executor-backend', type=str, default=None, choices=['uni', 'mp', 'ray'], help='backend of executor backend') parser.add_argument('--dataset-name', type=str, default='sharegpt', choices=['sharegpt', 'random'], help='Name of the dataset to benchmark on.') parser.add_argument( '--sharegpt-output-len', type=int, default=None, help='Output length for each request. 
Overrides the output length ' 'from the ShareGPT dataset.', ) parser.add_argument( '--random-input-len', type=int, help='Number of input tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-output-len', type=int, help='Number of output tokens per request, used only for random ' 'dataset.', ) parser.add_argument( '--random-range-ratio', type=float, default=0.0, help='Range of sampled ratio of input/output length, ' 'used only for random dataset.', ) # other args ArgumentHelper.top_p(parser) ArgumentHelper.temperature(parser) ArgumentHelper.top_k(parser) ArgumentHelper.backend(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) ArgumentHelper.dllm_block_length(pt_group) ArgumentHelper.dllm_unmasking_strategy(pt_group) ArgumentHelper.dllm_denoising_steps(pt_group) ArgumentHelper.dllm_confidence_threshold(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0) dtype_act = ArgumentHelper.dtype(pt_group) # turbomind engine args tb_group = parser.add_argument_group('TurboMind engine argument') tb_group._group_actions.append(tp_act) tb_group._group_actions.append(cache_count_act) tb_group._group_actions.append(cache_block_seq_len_act) tb_group._group_actions.append(prefix_caching_act) tb_group._group_actions.append(quant_policy_act) tb_group._group_actions.append(dtype_act) ArgumentHelper.dp(tb_group) ArgumentHelper.cp(tb_group) ArgumentHelper.model_format(tb_group, default='hf') ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.async_(tb_group) ArgumentHelper.communicator(tb_group) args = parser.parse_args() return args def main(): args = 
parse_args() random.seed(args.seed) if args.backend == 'turbomind': engine_config = TurbomindEngineConfig( max_batch_size=args.concurrency // args.dp, tp=args.tp, dp=args.dp, cp=args.cp, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, model_format=args.model_format, quant_policy=args.quant_policy, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, async_=args.async_, enable_prefix_caching=args.enable_prefix_caching, dtype=args.dtype, communicator=args.communicator, ) elif args.backend == 'pytorch': engine_config = PytorchEngineConfig( cache_max_entry_count=args.cache_max_entry_count, block_size=args.cache_block_seq_len, max_batch_size=args.concurrency, tp=args.tp, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, quant_policy=args.quant_policy, dtype=args.dtype, distributed_executor_backend=args.distributed_executor_backend, dllm_block_length=args.dllm_block_length, dllm_unmasking_strategy=args.dllm_unmasking_strategy, dllm_denoising_steps=args.dllm_denoising_steps, dllm_confidence_threshold=args.dllm_confidence_threshold, ) if args.use_uvloop: import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) engine = Engine(args.model_path, engine_config) if args.dataset_name == 'sharegpt': assert args.random_input_len is None and args.random_output_len is None requests = sample_sharegpt_requests( dataset_path=args.dataset, num_requests=args.num_prompts, tokenizer=engine.tokenizer.model.model, fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == 'random': assert args.random_input_len is not None and \ args.random_output_len is not None requests = sample_random_requests( input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, range_ratio=args.random_range_ratio, tokenizer=engine.tokenizer.model.model, dataset_path=args.dataset, ) else: raise ValueError(f'Unknown dataset: 
{args.dataset_name}') stream_output = not args.no_stream_output profiler = Profiler(stream_output, [50, 75, 95, 99]) engine.process_request(requests, profiler, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, concurrency=args.concurrency if args.concurrency < args.num_prompts else args.num_prompts, stream_output=not args.no_stream_output, skip_tokenize=args.skip_tokenize, skip_detokenize=args.skip_detokenize, cancel_rate=args.cancel_rate) hyperparams = [('Concurrency', args.concurrency), ('Cancel rate', args.cancel_rate), ('Stream output', str(stream_output).lower()), ('Skip tokenize', str(args.skip_tokenize).lower()), ('Skip detokenize', str(args.skip_detokenize).lower())] profiler.compute_metrics() profiler.summarize(title='Profile Throughput', hyperparams=hyperparams) if args.csv: profiler.save_csv(args.csv, ( ('backend', args.backend), ('bs', args.concurrency), ('dataset_name', args.dataset_name), ('sharegpt_output_len', args.sharegpt_output_len), ('random_input_len', args.random_input_len), ('random_output_len', args.random_output_len), ('random_range_ratio', args.random_range_ratio), ('num_prompts', args.num_prompts), )) if __name__ == '__main__': main() ================================================ FILE: builder/manywheel/Dockerfile_2014 ================================================ # WARNING: CentOS 7 is out of date since 6/30/2024, we should use the following one in the future # FROM quay.io/pypa/manylinux_2_28_x86_64 as base FROM quay.io/pypa/manylinux2014_x86_64 as base ARG BASE_CUDA_VERSION=11.8 ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 RUN sed -i 's|^mirrorlist=|#mirrorlist=|g' /etc/yum.repos.d/CentOS-*.repo && \ sed -i 's|^#baseurl=http://mirror.centos.org|baseurl=https://vault.centos.org|g' /etc/yum.repos.d/CentOS-*.repo && \ yum install -y \ wget \ rapidjson-devel \ glog-devel && \ yum clean all ENV 
LD_LIBRARY_PATH=/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:$LD_LIBRARY_PATH FROM base as cuda COPY manywheel/scripts/install_cuda.sh /tmp/install_cuda.sh RUN bash /tmp/install_cuda.sh ${BASE_CUDA_VERSION} && rm /tmp/install_cuda.sh FROM base as conda COPY manywheel/scripts/install_conda.sh /tmp/install_conda.sh RUN bash /tmp/install_conda.sh && rm /tmp/install_conda.sh # Accept Anaconda's Terms of Service to avoid `CondaToSNonInteractiveError` RUN /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r RUN PY_VERSIONS=(3.10 3.11 3.12 3.13) && \ for pyver in "${PY_VERSIONS[@]}"; do \ /opt/conda/bin/conda create -n py${pyver//./} python=${pyver} -yq && \ /opt/conda/envs/py${pyver//./}/bin/pip install -i 'https://mirrors.aliyun.com/pypi/simple/' --no-cache-dir pybind11; \ done && \ /opt/conda/bin/conda clean -ya FROM base as cuda_final COPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION} RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda ENV PATH=/usr/local/cuda/bin:$PATH COPY --from=conda /opt/conda /opt/conda RUN /opt/conda/bin/conda init bash ================================================ FILE: builder/manywheel/README.md ================================================ # LMDeploy Build System ## Building lmdeploy builder images To build all lmdeploy builder images, such as "lmdeploy-builder:cuda11.8", ""lmdeploy-builder:cuda12.4", execute: ```bash ./build_all_lmdeploy_builders.sh # Build and push images (for CI/CD) WITH_PUSH=true ./build_all_lmdeploy_builders.sh ``` For custom builds with specific versions: ```bash MANY_LINUX_VERSION=2014 GPU_ARCH_VERSION=12.4 ./build_lmdeploy_builder.sh ``` ## Build lmdeploy wheels Compile all wheel packages: ```bash ./build_all_wheel.sh ``` ================================================ FILE: 
builder/manywheel/build_all_lmdeploy_builders.sh
================================================
#!/usr/bin/env bash

set -eou pipefail

TOPDIR=$(git rev-parse --show-toplevel)/builder

# Build one builder image per supported CUDA version.
for cuda_version in 12.4 12.6 12.8; do
    MANY_LINUX_VERSION=2014 GPU_ARCH_VERSION="${cuda_version}" "${TOPDIR}/manywheel/build_lmdeploy_builder.sh"
done


================================================
FILE: builder/manywheel/build_all_wheel.sh
================================================
#!/usr/bin/env bash

set -eou pipefail

TOPDIR=$(git rev-parse --show-toplevel)/builder

# Space-separated list of CUDA versions; left unquoted in the loop on purpose
# so that several versions can be given at once.
CUDA_VER=${CUDA_VER:-12.8}

PLAT_NAME=manylinux2014_x86_64
for cuver in ${CUDA_VER}; do
    DOCKER_TAG=cuda${cuver}
    OUTPUT_FOLDER=cuda${cuver}_dist
    for pyver in py310 py311 py312 py313; do
        bash "${TOPDIR}/manywheel/build_wheel.sh" "${pyver}" "${PLAT_NAME}" "${DOCKER_TAG}" "${OUTPUT_FOLDER}" \
            |& tee "${PLAT_NAME}.${pyver}.cuda${cuver}.log.txt"
    done
done


================================================
FILE: builder/manywheel/build_lmdeploy_builder.sh
================================================
#!/usr/bin/env bash

set -eou pipefail

TOPDIR=$(git rev-parse --show-toplevel)/builder

# Required: fail early with a clear message instead of building a mis-tagged image.
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:?GPU_ARCH_VERSION must be set, e.g. GPU_ARCH_VERSION=12.8}
WITH_PUSH=${WITH_PUSH:-}
TARGET=cuda_final
DOCKER_TAG=cuda${GPU_ARCH_VERSION}
DOCKER_IMAGE=openmmlab/lmdeploy-builder:${DOCKER_TAG}
# MANY_LINUX_VERSION is optional. `:-` keeps `set -u` from aborting when it is
# unset, so the "no suffix" fallback can actually be reached.
DOCKERFILE_SUFFIX=$([[ -n ${MANY_LINUX_VERSION:-} ]] && echo "_${MANY_LINUX_VERSION}" || echo "")

# List of all build arguments (format: KEY=VALUE)
# Empty values will be automatically filtered out later
BUILD_ARGS=(
    "BASE_CUDA_VERSION=${GPU_ARCH_VERSION}"
    "DEVTOOLSET_VERSION=9"
    "HTTPS_PROXY=${HTTPS_PROXY:-}"
    "HTTP_PROXY=${HTTP_PROXY:-}"
    # Add more parameters here if needed
)

# Base Docker build command arguments
docker_build_args=(
    -t "${DOCKER_IMAGE}"
    --target "${TARGET}"
    -f "${TOPDIR}/manywheel/Dockerfile${DOCKERFILE_SUFFIX}"
)

# Process build arguments: filter empty values and format as --build-arg
for arg in "${BUILD_ARGS[@]}"; do
    IFS='=' read -r key value <<< "$arg" # Split KEY=VALUE
    if [[ -n "$value" ]]; then # Only add non-empty values
        docker_build_args+=(--build-arg "$arg")
    fi
done

(
    set -x
    DOCKER_BUILDKIT=1 docker build "${docker_build_args[@]}" "${TOPDIR}"
)

if [[ "${WITH_PUSH}" == true ]]; then
    (
        set -x
        docker push "${DOCKER_IMAGE}"
    )
fi


================================================
FILE: builder/manywheel/build_wheel.sh
================================================
#!/usr/bin/env bash
set -eux

PYTHON_VERSION="$1"
PLAT_NAME="$2"
DOCKER_TAG="$3"
OUTPUT_DIR="$4"

DOCKER_IMAGE="openmmlab/lmdeploy-builder:${DOCKER_TAG}"
# Assign and export separately so a failing `id` is not masked by `export`.
USERID=$(id -u)
GROUPID=$(id -g)
export USERID GROUPID

cd "$(dirname "$0")"  # move inside the script directory
mkdir -p "${OUTPUT_DIR}"
docker pull "${DOCKER_IMAGE}"
docker run --rm -it \
    --env PYTHON_VERSION="${PYTHON_VERSION}" \
    --env PLAT_NAME="${PLAT_NAME}" \
    --env USERID="${USERID}" \
    --env GROUPID="${GROUPID}" \
    --volume "$(pwd)/../../:/lmdeploy" \
    --volume "$(pwd)/${OUTPUT_DIR}:/lmdeploy_build" \
    --volume "$(pwd)/entrypoint_build.sh:/entrypoint_build.sh" \
    --entrypoint /entrypoint_build.sh \
    "${DOCKER_IMAGE}"


================================================
FILE: builder/manywheel/entrypoint_build.sh
================================================
#!/usr/bin/env bash

set -eux

export PYTHON_VERSION=$PYTHON_VERSION
export PLAT_NAME=$PLAT_NAME
export USERID=${USERID}
export GROUPID=${GROUPID}
export NCCL_INCLUDE_DIR=/usr/local/cuda/include
export NCCL_LIB_DIR=/usr/local/cuda/lib64

source /opt/conda/bin/activate
conda activate $PYTHON_VERSION

cd lmdeploy
pip install build change-wheel-version
python -m build --wheel -o /tmpbuild/

# Retag each built wheel with the requested platform. Operate on "$file"
# rather than a /tmpbuild/*.whl glob: with --delete-old-wheel the glob would
# also match wheels already rewritten by earlier iterations.
for file in $(find /tmpbuild/ -name "*.whl")
do
    platform_tag="$(basename "$file" | cut -d- -f3-4)-${PLAT_NAME}"
    change_wheel_version "$file" --delete-old-wheel --platform-tag "${platform_tag}"
done
chown "${USERID}:${GROUPID}" /tmpbuild/*
mv /tmpbuild/* /lmdeploy_build/


================================================
FILE: builder/manywheel/scripts/install_conda.sh
================================================ #!/bin/bash set -ex wget -q https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh chmod +x Miniconda3-latest-Linux-x86_64.sh bash ./Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda rm Miniconda3-latest-Linux-x86_64.sh ================================================ FILE: builder/manywheel/scripts/install_cuda.sh ================================================ #!/bin/bash set -ex function install_118 { echo "Installing CUDA 11.8 and NCCL 2.15" rm -rf /usr/local/cuda-11.8 /usr/local/cuda # install CUDA 11.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run chmod +x cuda_11.8.0_520.61.05_linux.run ./cuda_11.8.0_520.61.05_linux.run --toolkit --silent rm -f cuda_11.8.0_520.61.05_linux.run rm -f /usr/local/cuda && ln -s /usr/local/cuda-11.8 /usr/local/cuda # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses mkdir tmp_nccl && cd tmp_nccl wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.15.5/nccl_2.15.5-1+cuda11.8_x86_64.txz tar xf nccl_2.15.5-1+cuda11.8_x86_64.txz cp -a nccl_2.15.5-1+cuda11.8_x86_64/include/* /usr/local/cuda/include/ cp -a nccl_2.15.5-1+cuda11.8_x86_64/lib/* /usr/local/cuda/lib64/ cd .. 
rm -rf tmp_nccl ldconfig } function install_121 { echo "Installing CUDA 12.1 and NCCL 2.18.1" rm -rf /usr/local/cuda-12.1 /usr/local/cuda # install CUDA 12.1.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run chmod +x cuda_12.1.0_530.30.02_linux.run ./cuda_12.1.0_530.30.02_linux.run --toolkit --silent rm -f cuda_12.1.0_530.30.02_linux.run rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.1 /usr/local/cuda # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses mkdir tmp_nccl && cd tmp_nccl wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.18.1/nccl_2.18.1-1+cuda12.1_x86_64.txz tar xf nccl_2.18.1-1+cuda12.1_x86_64.txz cp -a nccl_2.18.1-1+cuda12.1_x86_64/include/* /usr/local/cuda/include/ cp -a nccl_2.18.1-1+cuda12.1_x86_64/lib/* /usr/local/cuda/lib64/ cd .. rm -rf tmp_nccl ldconfig } function install_124 { echo "Installing CUDA 12.4 and NCCL 2.25.1" rm -rf /usr/local/cuda-12.4 /usr/local/cuda # install CUDA 12.4.1 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run chmod +x cuda_12.4.1_550.54.15_linux.run ./cuda_12.4.1_550.54.15_linux.run --toolkit --silent rm -f cuda_12.4.1_550.54.15_linux.run rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.4 /usr/local/cuda # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses mkdir tmp_nccl && cd tmp_nccl wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.25.1/nccl_2.25.1-1+cuda12.4_x86_64.txz tar xf nccl_2.25.1-1+cuda12.4_x86_64.txz cp -a nccl_2.25.1-1+cuda12.4_x86_64/include/* /usr/local/cuda/include/ cp -a nccl_2.25.1-1+cuda12.4_x86_64/lib/* /usr/local/cuda/lib64/ cd .. 
rm -rf tmp_nccl ldconfig } function install_126 { echo "Installing CUDA 12.6 and NCCL 2.24.3" rm -rf /usr/local/cuda-12.6 /usr/local/cuda # install CUDA 12.6.3 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run chmod +x cuda_12.6.3_560.35.05_linux.run ./cuda_12.6.3_560.35.05_linux.run --toolkit --silent rm -f cuda_12.6.3_560.35.05_linux.run rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.6 /usr/local/cuda # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses mkdir tmp_nccl && cd tmp_nccl wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.24.3/nccl_2.24.3-1+cuda12.6_x86_64.txz tar xf nccl_2.24.3-1+cuda12.6_x86_64.txz cp -a nccl_2.24.3-1+cuda12.6_x86_64/include/* /usr/local/cuda/include/ cp -a nccl_2.24.3-1+cuda12.6_x86_64/lib/* /usr/local/cuda/lib64/ cd .. rm -rf tmp_nccl ldconfig } function install_128 { echo "Installing CUDA 12.8 and NCCL 2.25.1" rm -rf /usr/local/cuda-12.8 /usr/local/cuda # install CUDA 12.8.1 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run chmod +x cuda_12.8.1_570.124.06_linux.run ./cuda_12.8.1_570.124.06_linux.run --toolkit --silent rm -f cuda_12.8.1_570.124.06_linux.run rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.8 /usr/local/cuda # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses mkdir tmp_nccl && cd tmp_nccl wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.25.1/nccl_2.25.1-1+cuda12.8_x86_64.txz tar xf nccl_2.25.1-1+cuda12.8_x86_64.txz cp -a nccl_2.25.1-1+cuda12.8_x86_64/include/* /usr/local/cuda/include/ cp -a nccl_2.25.1-1+cuda12.8_x86_64/lib/* /usr/local/cuda/lib64/ cd .. 
rm -rf tmp_nccl ldconfig } if test $# -eq 0 then echo "doesn't provide cuda version"; exit 1; fi # idiomatic parameter and option handling in sh while test $# -gt 0 do case "$1" in 11.8) install_118 ;; 12.1) install_121 ;; 12.4) install_124 ;; 12.6) install_126 ;; 12.8) install_128 ;; *) echo "bad argument $1"; exit 1 ;; esac shift done ================================================ FILE: builder/manywheel/scripts/install_openmpi.sh ================================================ #!/bin/bash set -ex wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz tar xf openmpi-4.1.5.tar.gz cd openmpi-4.1.5 ./configure --prefix=/usr/local/mpi make -j$(nproc) make install ================================================ FILE: builder/windows/README.md ================================================ # Build lmdeploy on windows ## Requirements - [CMake 3.17+](https://github.com/Kitware/CMake/releases) - [Visual Studio 2019+](https://visualstudio.microsoft.com/downloads/) - [CUDA Toolkit 11.8+](https://developer.nvidia.com/cuda-toolkit-archive) ## Build lmdeploy wheel ```powershell pip install build python -m build --wheel ``` ================================================ FILE: builder/windows/generate.ps1 ================================================ cmake .. -A x64 -T "v143,cuda=$env:CUDA_PATH" ` -DCMAKE_BUILD_TYPE=Release ` -DCMAKE_INSTALL_PREFIX=install ` -DBUILD_PY_FFI=ON ` -DBUILD_MULTI_GPU=OFF ` -DUSE_NVTX=OFF ` -DBUILD_TEST="$env:BUILD_TEST" ================================================ FILE: builder/windows/setup_cuda.ps1 ================================================ # Copyright (c) OpenMMLab. All rights reserved. # Adapted from https://github.com/thewh1teagle/vibe/blob/5d7b75568ca65ab635bdf0ce912bbc975a043066/scripts/setup_cuda.ps1 $CUDA_VERSION_FULL = $env:INPUT_CUDA_VERSION # v12.1.0 or v11.8.0 # Make sure CUDA_VERSION_FULL is set and valid, otherwise error. 
# Validate CUDA version, extracting components via regex.
# The named groups populate $Matches.major / .minor / .patch used below.
$cuda_ver_matched = $CUDA_VERSION_FULL -match "^(?<major>[1-9][0-9]*)\.(?<minor>[0-9]+)\.(?<patch>[0-9]+)$"
if(-not $cuda_ver_matched){
    Write-Output "Invalid CUDA version specified, <major>.<minor>.<patch> required. '$CUDA_VERSION_FULL'."
    exit 1
}
$CUDA_MAJOR=$Matches.major
$CUDA_MINOR=$Matches.minor
$CUDA_PATCH=$Matches.patch

Write-Output "Selected CUDA version: $CUDA_VERSION_FULL"

$src = "cuda"
$dst = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$($CUDA_MAJOR).$($CUDA_MINOR)"
$installer = "cuda.exe"

# Only the versions listed here have a known installer URL.
if ($CUDA_VERSION_FULL -eq "12.1.0") {
    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_531.14_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "11.8.0") {
    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "12.5.0") {
    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "12.6.2") {
    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.94_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "12.8.1") {
    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe"
} else {
    Write-Output "Unsupported CUDA version specified"
    exit 1
}

# Download cuda
Write-Output "Downloading CUDA from: $downloadUrl"
if (-not (Test-Path -Path $installer)) {
    Write-Output "Downloading CUDA installer..."
    # If the file does not exist, download it
    & "C:\msys64\usr\bin\wget" $downloadUrl -O $installer -q
}

# Extract cuda
if (-not (Test-Path -Path $src -Type Container)) {
    # Extract CUDA using 7-Zip
    Write-Output "Extracting CUDA using 7-Zip..."
    mkdir "$src"
    & 'C:\Program Files\7-Zip\7z' x $installer -o"$src"
}

# Create destination directory if it doesn't exist
if (-Not (Test-Path -Path $dst)) {
    Write-Output "Creating destination directory: $dst"
    New-Item -Path $dst -ItemType Directory
}

# Get directories to process from the source path
$directories = Get-ChildItem -Directory -Path $src
# Top-level files worth keeping; every other loose file is skipped.
$whitelist = @("CUDA_Toolkit_Release_Notes.txt", "DOCS", "EULA.txt", "LICENSE", "README", "version.json")

foreach ($dir in $directories) {
    # Get all subdirectories and files in the current directory
    $items = Get-ChildItem -Path (Join-Path $src $dir.Name)

    foreach ($item in $items) {
        if ($item.PSIsContainer) {
            # If the item is a directory, copy its contents
            Write-Output "Copying contents of directory $($item.FullName) to $dst"
            Copy-Item -Path "$($item.FullName)\*" -Destination $dst -Recurse -Force
        } else {
            if ($whitelist -contains $item.Name) {
                Write-Output "Copying file $($item.FullName) to $dst"
                Copy-Item -Path $item.FullName -Destination $dst -Force
            }
        }
    }
}

# Add msbuild cuda extensions
$msBuildExtensions = (Get-ChildItem "$src\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions").fullname
(Get-ChildItem 'C:\Program Files\Microsoft Visual Studio\2022\*\MSBuild\Microsoft\VC\*\BuildCustomizations').FullName | ForEach-Object {
    $destination = $_
    $msBuildExtensions | ForEach-Object {
        $extension = $_
        Copy-Item $extension -Destination $destination -Force
        Write-Output "Copied $extension to $destination"
    }
}

$CUDA_FLAGS="-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH=1"

# Add to Github env
Write-Output "Setting environment variables for GitHub Actions..."

Write-Output "CUDA_PATH=$dst"
Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst"
Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)"
Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL"

Write-Output "CUDA_PATH=$dst" >> $env:GITHUB_ENV
Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst" >> $env:GITHUB_ENV
Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" >> $env:GITHUB_ENV
Write-Output "CudaToolkitDir=$dst" >> $env:GITHUB_ENV
Write-Output "CMAKE_CUDA_COMPILER=$dst\bin\nvcc.exe" >> $env:GITHUB_ENV
Write-Output "NVCC_APPEND_FLAGS=$CUDA_FLAGS" >> $env:GITHUB_ENV
Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL" >> $env:GITHUB_ENV

Write-Output "Setup completed."


================================================
FILE: cmake/Modules/FindNCCL.cmake
================================================
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# From PyTorch:
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# From Caffe2:
#
# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
#
# All contributions by Facebook:
# Copyright (c) 2016 Facebook Inc.
#
# All contributions by Google:
# Copyright (c) 2015 Google Inc.
# All rights reserved.
#
# All contributions by Yangqing Jia:
# Copyright (c) 2015 Yangqing Jia
# All rights reserved.
# # All contributions by Kakao Brain: # Copyright 2019-2020 Kakao Brain # # All contributions from Caffe: # Copyright(c) 2013, 2014, 2015, the respective contributors # All rights reserved. # # All other contributions: # Copyright(c) 2015, 2016 the respective contributors # All rights reserved. # # Caffe2 uses a copyright model similar to Caffe: each contributor holds # copyright over their contributions to Caffe2. The project versioning records # all such contribution and copyright details. If a contributor wants to further # mark their specific copyright on a particular contribution, they should # indicate their copyright solely in the commit message of the change when it is # committed. # # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America # and IDIAP Research Institute nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # # Find the nccl libraries # # The following variables are optionally searched for defaults # NCCL_ROOT: Base directory where all NCCL components are found # NCCL_INCLUDE_DIR: Directory where NCCL header is found # NCCL_LIB_DIR: Directory where NCCL library is found # # The following are set after configuration is done: # NCCL_FOUND # NCCL_INCLUDE_DIRS # NCCL_LIBRARIES # # The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks # install NCCL in the same location as the CUDA toolkit. # See https://github.com/caffe2/caffe2/issues/1601 set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with") if ($ENV{NCCL_ROOT_DIR}) message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.") endif() list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) # Compatibility layer for CMake <3.12. NCCL_ROOT will be taken into account when searching paths and libraries for CMake >=3.12. list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) find_path(NCCL_INCLUDE_DIRS NAMES nccl.h HINTS ${NCCL_INCLUDE_DIR}) if (USE_STATIC_NCCL) MESSAGE(STATUS "USE_STATIC_NCCL is set.
Linking with static NCCL library.") SET(NCCL_LIBNAME "nccl_static") if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) endif() else() SET(NCCL_LIBNAME "nccl") if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) endif() endif() find_library(NCCL_LIBRARIES NAMES ${NCCL_LIBNAME} HINTS ${NCCL_LIB_DIR}) include(FindPackageHandleStandardArgs) find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) if(NCCL_FOUND) # obtaining NCCL version and some sanity checks set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) include(CheckCXXSymbolExists) check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) if (NCCL_VERSION_DEFINED) set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") file(WRITE ${file} " #include #include int main() { std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; int x; ncclGetVersion(&x); return x == NCCL_VERSION_CODE; } ") try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" LINK_LIBRARIES ${NCCL_LIBRARIES}) if (NOT NCCL_VERSION_MATCHED) message(FATAL_ERROR "Found NCCL header version and library version do not match! 
\ (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") endif() message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") else() # message(STATUS "NCCL version < 2.3.5-5") endif () set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) endif() ================================================ FILE: cmake/TritonTurboMindBackendConfig.cmake.in ================================================ # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. include(CMakeFindDependencyMacro) get_filename_component( TRITONPYTORCHBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH ) list(APPEND CMAKE_MODULE_PATH ${TRITONPYTORCHBACKEND_CMAKE_DIR}) if(NOT TARGET TritonPyTorchBackend::triton-pytorch-backend) include("${TRITONPYTORCHBACKEND_CMAKE_DIR}/TritonPyTorchBackendTargets.cmake") endif() set(TRITONPYTORCHBACKEND_LIBRARIES TritonPyTorchBackend::triton-pytorch-backend) ================================================ FILE: cmake/TurboMindConfig.cmake.in ================================================ # Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. 
# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. include(CMakeFindDependencyMacro) get_filename_component( TURBOMIND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH ) list(APPEND CMAKE_MODULE_PATH ${TURBOMIND_CMAKE_DIR}) if(NOT TARGET transformer-shared) include("${TURBOMIND_CMAKE_DIR}/TurboMindTargets.cmake") endif() set(TURBOMIND_LIBRARIES transformer-shared) ================================================ FILE: cmake/yaml-cpp_cmake_policy.patch ================================================ diff --git a/CMakeLists.txt b/CMakeLists.txt index 46dc180..b746ac1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ # 3.5 is actually available almost everywhere, but this a good minimum -cmake_minimum_required(VERSION 3.4) +cmake_minimum_required(VERSION 3.5) # enable MSVC_RUNTIME_LIBRARY target property # see https://cmake.org/cmake/help/latest/policy/CMP0091.html ================================================ FILE: debug.sh ================================================ #!/bin/bash -e builder="-G Ninja" if [ "$1" == "make" ]; then builder="" fi cmake ${builder} .. 
\ -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DCMAKE_INSTALL_PREFIX=./install \ -DBUILD_PY_FFI=ON \ -DBUILD_MULTI_GPU=ON \ -DCMAKE_CUDA_FLAGS="-lineinfo" \ -DUSE_NVTX=ON \ -DPYTHON_EXECUTABLE=$(which python3) \ -DFETCHCONTENT_QUIET=OFF \ -DBUILD_TEST=ON ================================================ FILE: docker/Dockerfile ================================================ # Base images ARG IMAGE_TYPE=final ARG CUDA_VERSION=cu12 FROM nvidia/cuda:13.0.2-devel-ubuntu22.04 AS cu13 ENV CUDA_VERSION_SHORT=cu130 FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8 ENV CUDA_VERSION_SHORT=cu128 FROM nvidia/cuda:12.6.3-devel-ubuntu22.04 AS cu12 ENV CUDA_VERSION_SHORT=cu126 # Builder image FROM ${CUDA_VERSION} AS dev ARG PYTHON_VERSION=3.10 ENV PATH=/opt/py3/bin:/root/.local/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV TZ=Etc/UTC RUN --mount=type=cache,target=/root/.cache \ sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list && \ apt-get update -y && \ apt-get install -y --no-install-recommends \ tzdata wget curl openssh-server ssh sudo git-core \ libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1 \ libssl-dev pkg-config vim rapidjson-dev libgoogle-glog-dev gdb && \ apt-get clean -y && \ rm -rf /var/lib/apt/lists/* && \ wget -qO- https://astral.sh/uv/install.sh | sh && \ uv venv -p python${PYTHON_VERSION} --seed /opt/py3 && \ pip install --upgrade pip build FROM dev AS builder # Should be in the lmdeploy root directory when building docker image COPY . 
/opt/lmdeploy WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache \ pip install -r requirements/runtime_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} RUN --mount=type=cache,target=/root/.cache \ docker/build.sh RUN --mount=type=cache,target=/root/.cache \ docker/prepare_wheel.sh # Runtime image FROM nvidia/cuda:13.0.2-base-ubuntu22.04 AS cu13-base ENV CUDA_VERSION_SHORT=cu130 FROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS cu12.8-base ENV CUDA_VERSION_SHORT=cu128 FROM nvidia/cuda:12.6.3-base-ubuntu22.04 AS cu12-base ENV CUDA_VERSION_SHORT=cu126 FROM ${CUDA_VERSION}-base AS final ARG PYTHON_VERSION=3.10 # Some dependencies such as timm(required by InternVL models) are missed in the docker image # We need to install them via pip. Since these dependencies are listed in requirements/serve.txt, # we copy the requirements directory here. COPY requirements /tmp/requirements COPY docker/install.sh /tmp/install.sh RUN --mount=type=cache,target=/root/.cache \ --mount=type=cache,target=/wheels,from=builder,source=/wheels \ /tmp/install.sh # explicitly set ptxas path for triton ENV PATH=/opt/py3/bin:$PATH ENV TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas FROM ${IMAGE_TYPE} ================================================ FILE: docker/Dockerfile.jetson ================================================ # Base images FROM nvcr.io/nvidia/l4t-base:r36.2.0 ENV CUDA_VER=12.6 \ PYTHON_VERSION=3.10 \ PATH=/opt/py3/bin:/root/.local/bin:/usr/local/cuda/bin:${PATH} RUN --mount=type=cache,target=/root/.cache \ --mount=type=cache,target=/tmp/download \ export CUDA_SUFFIX=$(echo $CUDA_VER | sed 's/\./-/g') && \ cd /tmp/download && \ mkdir -p /opt/nvidia/l4t-packages/ && \ touch /opt/nvidia/l4t-packages/.nv-l4t-disable-boot-fw-update-in-preinstall && \ wget -q "https://repo.download.nvidia.com/jetson/t234/pool/main/n/nvidia-l4t-core/nvidia-l4t-core_36.2.0-20231218214829_arm64.deb" && \ wget -q 
"https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64/cuda-keyring_1.1-1_all.deb" && \ yes | dpkg -i nvidia-l4t-core_*.deb cuda-keyring_*.deb && \ rm -rf *.deb *.deb.* && \ apt update -y && \ apt-get install -y --no-install-recommends \ cuda-toolkit-${CUDA_SUFFIX} cuda-compat-${CUDA_SUFFIX} libcudnn9-cuda-12 libcusparselt0 cudss \ git libopenblas-dev python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \ apt-get clean -y && \ rm -rf /var/lib/apt/lists/* && \ python${PYTHON_VERSION} -m venv /opt/py3 && \ mkdir -p /wheels # Should be in the lmdeploy root directory when building docker image COPY . /opt/lmdeploy WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache \ --mount=type=cache,target=/opt/pytorch \ pip install build change-wheel-version && \ python -m build -w -o /wheels -v . && \ change_wheel_version --local-version cu126 --delete-old-wheel /wheels/lmdeploy*.whl && \ pip install -v /wheels/lmdeploy*.whl --index-url https://pypi.jetson-ai-lab.io/jp6/cu126/+simple/ ================================================ FILE: docker/Dockerfile_ascend_a2_300i ================================================ # DOCKER_BUILDKIT=1 docker build --build-arg ASCEND_DEVICE_TYPE=ascend_a2 \ # --build-arg DLINFER_TAG=main --build-arg LMDEPLOY_TAG=main --network=host \ # -t lmdeploy_dlinfer:a2 -f Dockerfile_ascend_a2_300i . 
ARG ASCEND_DEVICE_TYPE=ascend_a2 ARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub FROM ${ASCEND_HUB}/cann:8.3.rc1-910b-ubuntu22.04-py3.11 AS ascend_a2_base FROM ${ASCEND_HUB}/cann:8.3.rc1-310p-ubuntu22.04-py3.11 AS ascend_300i_base FROM ${ASCEND_DEVICE_TYPE}_base AS builder ENV DEBIAN_FRONTEND=noninteractive RUN apt update -y && \ apt install -y libjemalloc-dev git && \ apt clean && rm -rf /var/lib/apt/lists/* ENV HCCL_CONNECT_TIMEOUT=7200 \ PYTORCH_NPU_ALLOC_CONF="expandable_segments:True" \ HCCL_OP_EXPANSION_MODE="AIV" \ LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so:$LD_PRELOAD ARG DLINFER_TAG=main ARG LMDEPLOY_TAG=main RUN --mount=type=cache,target=/root/.cache \ pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \ pip install --no-cache-dir torch==2.8.0 torch-npu==2.8.0 torchvision==0.23.0 && \ TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \ LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG} ================================================ FILE: docker/Dockerfile_ascend_a3 ================================================ # DOCKER_BUILDKIT=1 docker build --build-arg ASCEND_DEVICE_TYPE=ascend_a3 \ # --build-arg DLINFER_TAG=main --build-arg LMDEPLOY_TAG=main --network=host \ # -t lmdeploy_dlinfer:a3 -f Dockerfile_ascend_a3 .
ARG ASCEND_DEVICE_TYPE=ascend_a3 ARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub FROM ${ASCEND_HUB}/cann:8.5.0-a3-openeuler24.03-py3.11 AS ascend_a3_base FROM ${ASCEND_DEVICE_TYPE}_base AS builder ENV DEBIAN_FRONTEND=noninteractive RUN dnf update -y && \ dnf install -y jemalloc jemalloc-devel && \ dnf clean all && rm -rf /var/cache/dnf ENV HCCL_CONNECT_TIMEOUT=7200 \ PYTORCH_NPU_ALLOC_CONF="expandable_segments:True" \ HCCL_OP_EXPANSION_MODE="AIV" \ LD_PRELOAD=/usr/lib64/libjemalloc.so.2:$LD_PRELOAD ARG DLINFER_TAG=main ARG LMDEPLOY_TAG=main RUN --mount=type=cache,target=/root/.cache \ pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \ pip install --no-cache-dir torch==2.9.0 torch-npu==2.9.0 torchvision==0.24.0 && \ TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \ LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG} ================================================ FILE: docker/Dockerfile_dev ================================================ FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8 # environment variables ENV DEBIAN_FRONTEND=noninteractive \ TZ=Etc/UTC \ PATH=/opt/py3/bin:/root/.local/bin:${PATH} \ TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \ CUDA_VERSION_SHORT=cu128 # Install dependencies and create python virtual environment RUN --mount=type=cache,target=/var/cache/apt \ --mount=type=cache,target=/root/.cache \ sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list && \ apt-get update -y && \ apt-get install -y --no-install-recommends \ tzdata wget curl openssh-server ssh sudo git-core \ libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 \ libibverbs-dev rdma-core libmlx5-1 libssl-dev pkg-config \ vim rapidjson-dev libgoogle-glog-dev gdb cmake build-essential \ python3-dev ninja-build 
htop tree jq unzip && \ apt-get clean -y && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ # install UV wget -qO- https://astral.sh/uv/install.sh | sh && \ # create Python virtual environment uv venv -p python3.12 --seed /opt/py3 # Should be in the lmdeploy root directory when building docker image COPY . /opt/lmdeploy WORKDIR /opt/lmdeploy # install lmdeploy and its dependencies RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/cu128 && \ uv pip install -e . RUN --mount=type=cache,target=/root/.cache/uv \ docker/prepare_wheel.sh RUN --mount=type=cache,target=/root/.cache/uv \ cp -r requirements /tmp/requirements && \ docker/install.sh # Clean up to reduce image size RUN uv cache clean && \ rm -rf /wheels /tmp/* /var/tmp/* /root/.cache/uv/* && \ find /opt/lmdeploy -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true && \ find /opt/lmdeploy -type f -name "*.pyc" -delete 2>/dev/null || true ================================================ FILE: docker/InternVL_Dockerfile ================================================ ARG CUDA_VERSION=cu12 FROM openmmlab/lmdeploy:latest-cu12 AS cu12 ENV CUDA_VERSION_SHORT=cu123 FROM openmmlab/lmdeploy:latest-cu11 AS cu11 ENV CUDA_VERSION_SHORT=cu118 FROM ${CUDA_VERSION} AS final RUN python3 -m pip install timm!=1.0.23 RUN python3 -m pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+${CUDA_VERSION_SHORT}torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ================================================ FILE: docker/Qwen2VL_Dockerfile ================================================ ARG CUDA_VERSION=cu12 FROM openmmlab/lmdeploy:latest-cu12 AS cu12 ENV CUDA_VERSION_SHORT=cu123 FROM openmmlab/lmdeploy:latest-cu11 AS cu11 ENV CUDA_VERSION_SHORT=cu118 FROM ${CUDA_VERSION} AS final # we use transformers to load vision part of qwen2_vl and it needs transformers > v4.44.2 RUN python3 -m pip 
install git+https://github.com/huggingface/transformers.git RUN python3 -m pip install qwen_vl_utils ================================================ FILE: docker/build.sh ================================================ #!/bin/bash -ex mkdir -p /wheels if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then pip install nvidia-nccl-cu13 else pip install nvidia-nccl-cu12 fi python3 -m build -w -o /wheels -v . ================================================ FILE: docker/install.sh ================================================ #!/bin/bash -ex # Skip system setup if virtual env already exists (e.g., in dev image) if [ ! -f "/opt/py3/bin/python" ]; then # install system packages export DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list apt-get update -y apt-get install -y --no-install-recommends \ tzdata wget curl ssh sudo git-core vim libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1 if [[ ${PYTHON_VERSION} != "3.10" ]]; then apt-get install -y --no-install-recommends software-properties-common add-apt-repository -y ppa:deadsnakes/ppa apt-get update -y fi # install python, create virtual env apt-get install -y --no-install-recommends \ python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv pushd /opt >/dev/null python${PYTHON_VERSION} -m venv py3 popd >/dev/null # install CUDA build tools if [[ "${CUDA_VERSION_SHORT}" = "cu126" ]]; then apt-get install -y --no-install-recommends cuda-minimal-build-12-6 numactl dkms elif [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then apt-get install -y --no-install-recommends cuda-minimal-build-12-8 numactl dkms elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then apt-get install -y --no-install-recommends cuda-minimal-build-13-0 numactl dkms fi apt-get clean -y rm -rf /var/lib/apt/lists/* fi # install GDRCopy debs if [ "$(ls -A /wheels/*.deb 2>/dev/null)" ]; then dpkg -i /wheels/*.deb fi # 
install python packages export PATH=/opt/py3/bin:$PATH pip install -U pip wheel setuptools if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then pip install nvidia-nvshmem-cu13==3.4.5 else pip install nvidia-nvshmem-cu12==3.4.5 fi pip install /wheels/*.whl pip install dlblas==0.0.7 dlslime==0.0.2.post1 # install pre-built flash attention 3 wheel TORCH_VER=$(python3 -c "import torch; print(''.join(torch.__version__.split('+')[0].split('.')))") pip install ninja einops packaging FA3_WHEELS_URL="https://windreamer.github.io/flash-attention3-wheels/${CUDA_VERSION_SHORT}_torch${TORCH_VER}" pip install --no-index flash_attn_3 --find-links ${FA3_WHEELS_URL} # install requirements/serve.txt dependencies such as timm if [ -f /tmp/requirements/serve.txt ]; then pip install -r /tmp/requirements/serve.txt fi if [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then # As described in https://github.com/InternLM/lmdeploy/pull/4313, # window registration may cause memory leaks in NCCL 2.27, NCCL 2.28+ resolves the issue, # but turbomind engine will use nccl GIN for EP in future, which is brought in since 2.29 pip install "nvidia-nccl-cu12>2.29" fi ================================================ FILE: docker/prepare_wheel.sh ================================================ #!/bin/bash -ex export PATH=/opt/py3/bin:$PATH pip install "cmake<4.0" wheel ninja setuptools packaging if [[ ${PYTHON_VERSION} = "3.13" ]]; then curl https://sh.rustup.rs -sSf | sh -s -- -y . 
"$HOME/.cargo/env" pip install setuptools_rust pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/google/sentencepiece.git@v0.2.0#subdirectory=python" fi GDRCOPY_VERSION=2.5.1 DEEP_EP_VERSION=9af0e0d # v1.2.1 DEEP_GEMM_VERSION=c9f8b34 # v2.1.1.post3 FLASH_MLA_VERSION=1408756 # no release, pick the latest commit # DeepEP if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then export CPLUS_INCLUDE_PATH="/usr/local/cuda/include/cccl":${CPLUS_INCLUDE_PATH} pip install nvidia-nvshmem-cu13==3.4.5 else pip install nvidia-nvshmem-cu12==3.4.5 fi pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepEP.git@${DEEP_EP_VERSION}" # DeepGEMM pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepGEMM.git@${DEEP_GEMM_VERSION}" # FlashMLA # sm100 compilation for Flash MLA requires NVCC 12.9 or higher FLASH_MLA_DISABLE_SM100=1 pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/FlashMLA.git@${FLASH_MLA_VERSION}" # GDRCopy debs apt-get update -y \ && apt-get install -y --no-install-recommends build-essential devscripts debhelper fakeroot pkg-config dkms wget -q https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \ && tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \ && cd gdrcopy-${GDRCOPY_VERSION}/packages \ && CUDA=/usr/local/cuda ./build-deb-packages.sh \ && mv ./*.deb /wheels # Clean up build artifacts cd / && rm -rf gdrcopy-${GDRCOPY_VERSION} apt-get clean -y && rm -rf /var/lib/apt/lists/* ================================================ FILE: docs/en/.readthedocs.yaml ================================================ version: 2 formats: all build: os: "ubuntu-22.04" tools: python: "3.10" sphinx: configuration: docs/en/conf.py python: install: - requirements: requirements/docs.txt - requirements: requirements/readthedocs.txt ================================================ FILE: 
docs/en/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/en/_static/css/readthedocs.css ================================================ table.autosummary td { width: 50% } img.align-center { display: block; margin-left: auto; margin-right: auto; } ================================================ FILE: docs/en/advance/chat_template.md ================================================ # Customized chat template The effect of the applied chat template can be observed by **setting log level** `INFO`. LMDeploy supports two methods of adding chat templates: - One approach is to utilize an existing conversation template by directly configuring a JSON file like the following. ```json { "model_name": "your awesome chat template name", "system": "<|im_start|>system\n", "meta_instruction": "You are a robot developed by LMDeploy.", "eosys": "<|im_end|>\n", "user": "<|im_start|>user\n", "eoh": "<|im_end|>\n", "assistant": "<|im_start|>assistant\n", "eoa": "<|im_end|>", "separator": "\n", "capability": "chat", "stop_words": ["<|im_end|>"] } ``` The new chat template would be applied like this: ``` {system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}{assistant_content}{eoa}{separator}{user}... ``` When using the CLI tool, you can pass in a custom chat template with `--chat-template`, for example. 
```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE} ``` You can also pass it in through the Python API, for example. ```python from lmdeploy import ChatTemplateConfig, serve serve('internlm/internlm2_5-7b-chat', chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}')) ``` - Another approach is to customize a Python chat template class like the existing LMDeploy chat templates. It can be used directly after successful registration. The advantages are a high degree of customization and strong controllability. Below is an example of registering an LMDeploy chat template. ```python from lmdeploy.model import MODELS, BaseChatTemplate @MODELS.register_module(name='customized_model') class CustomizedModel(BaseChatTemplate): """A customized chat template.""" def __init__(self, system='<|im_start|>system\n', meta_instruction='You are a robot developed by LMDeploy.', user='<|im_start|>user\n', assistant='<|im_start|>assistant\n', eosys='<|im_end|>\n', eoh='<|im_end|>\n', eoa='<|im_end|>', separator='\n', stop_words=['<|im_end|>', '<|action_end|>']): super().__init__(system=system, meta_instruction=meta_instruction, eosys=eosys, user=user, eoh=eoh, assistant=assistant, eoa=eoa, separator=separator, stop_words=stop_words) from lmdeploy import ChatTemplateConfig, pipeline messages = [{'role': 'user', 'content': 'who are you?'}] pipe = pipeline('internlm/internlm2_5-7b-chat', chat_template_config=ChatTemplateConfig('customized_model')) for response in pipe.stream_infer(messages): print(response.text, end='') ``` In this example, we register an LMDeploy chat template whose meta instruction tells the model that it was developed by LMDeploy, so when the user asks who the model is, the model will answer that it was created by LMDeploy.
================================================ FILE: docs/en/advance/context_parallel.md ================================================ # Context Parallel When the memory on a single GPU is insufficient to deploy a model, it is often deployed using tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. If you want to deploy with `TP > num_key_value_heads`, the kv-heads should be duplicated to meet the divisibility requirement. However, this has two disadvantages: 1. The amount of available kv_cache is halved, which reduces the maximum supported session length. 2. The maximum inference batch size is reduced, leading to lower throughput. To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of kv-heads, but this introduces data imbalance. To eliminate data imbalance, TurboMind supports context parallelism (CP), which allows kv_cache to be stored interleaved on different cp_ranks. See the example below: ``` cp_size=2, prompt_len=5, generation_len=4 kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 kv_cache stored on cp_rank1: 1, 3, 5, 7 ``` ## Usage Taking Intern-S1 / Qwen3-235B-A22B as an example, both models have `num_key_value_heads` equal to 4. If you want to deploy with `TP=8` and avoid duplication of kv_cache, you can deploy in the following way: ``` lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 ``` ================================================ FILE: docs/en/advance/debug_turbomind.md ================================================ # How to debug Turbomind Turbomind is implemented in C++, which is not as easy to debug as Python. This document provides basic methods for debugging Turbomind. ## Prerequisite First, complete the local compilation according to the commands in [Install from source](../get_started/installation.md).
## Configure Python debug environment Since many large companies currently use Centos 7 for online production environments, we will use Centos 7 as an example to illustrate the process. ### Obtain `glibc` and `python3` versions ```bash rpm -qa | grep glibc rpm -qa | grep python3 ``` The result should be similar to this: ``` [username@hostname workdir]# rpm -qa | grep glibc glibc-2.17-325.el7_9.x86_64 glibc-common-2.17-325.el7_9.x86_64 glibc-headers-2.17-325.el7_9.x86_64 glibc-devel-2.17-325.el7_9.x86_64 [username@hostname workdir]# rpm -qa | grep python3 python3-pip-9.0.3-8.el7.noarch python3-rpm-macros-3-34.el7.noarch python3-rpm-generators-6-2.el7.noarch python3-setuptools-39.2.0-10.el7.noarch python3-3.6.8-21.el7_9.x86_64 python3-devel-3.6.8-21.el7_9.x86_64 python3.6.4-sre-1.el6.x86_64 ``` Based on the information above, we can see that the version of `glibc` is `2.17-325.el7_9.x86_64` and the version of `python3` is `3.6.8-21.el7_9.x86_64`. ### Download and install `debuginfo` library Download `glibc-debuginfo-common-2.17-325.el7.x86_64.rpm`, `glibc-debuginfo-2.17-325.el7.x86_64.rpm`, and `python3-debuginfo-3.6.8-21.el7.x86_64.rpm` from http://debuginfo.centos.org/7/x86_64. ```bash rpm -ivh glibc-debuginfo-common-2.17-325.el7.x86_64.rpm rpm -ivh glibc-debuginfo-2.17-325.el7.x86_64.rpm rpm -ivh python3-debuginfo-3.6.8-21.el7.x86_64.rpm ``` ### Upgrade GDB ```bash sudo yum install devtoolset-10 -y echo "source scl_source enable devtoolset-10" >> ~/.bashrc source ~/.bashrc ``` ### Verification ```bash gdb python3 ``` The output should be similar to this: ``` [username@hostname workdir]# gdb python3 GNU gdb (GDB) Red Hat Enterprise Linux 9.2-10.el7 Copyright (C) 2020 Free Software Foundation, Inc. License GPLv3+: GNU GPL version 3 or later This is free software: you are free to change and redistribute it. There is NO WARRANTY, to the extent permitted by law. Type "show copying" and "show warranty" for details. This GDB was configured as "x86_64-redhat-linux-gnu". 
Type "show configuration" for configuration details. For bug reporting instructions, please see: . Find the GDB manual and other documentation resources online at: . For help, type "help". Type "apropos word" to search for commands related to "word"... Reading symbols from python3... (gdb) ``` If it shows `Reading symbols from python3`, the configuration has been successful. For other operating systems, please refer to [DebuggingWithGdb](https://wiki.python.org/moin/DebuggingWithGdb). ## Set up symbolic links After setting up symbolic links, there is no need to install it locally with `pip` every time. ```bash # Change directory to lmdeploy, e.g. cd /workdir/lmdeploy # Since it has been built in the build directory # Link the lib directory cd lmdeploy && ln -s ../build/lib . && cd .. # (Optional) Link compile_commands.json for clangd index ln -s build/compile_commands.json . ``` ## Start debugging ````bash # Use gdb to start the API server with Llama-2-13b-chat-hf, e.g. gdb --args python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf # Set directories in gdb Reading symbols from python3... (gdb) set directories /workdir/lmdeploy # Set a breakpoint using the relative path, e.g. (gdb) b src/turbomind/models/llama/BlockManager.cc:104 # When it shows # ``` # No source file named src/turbomind/models/llama/BlockManager.cc. # Make breakpoint pending on future shared library load? (y or [n]) # ``` # Just type `y` and press enter # Run (gdb) r # (Optional) Use https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py to send a request python3 profile_restful_api.py --backend lmdeploy --dataset-path /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --num_prompts 1 ```` ## Using GDB Refer to [GDB Execution Commands](https://lldb.llvm.org/use/map.html) and happy debugging. 
================================================ FILE: docs/en/advance/long_context.md ================================================ # Context length extrapolation Long text extrapolation refers to the ability of an LLM to handle data longer than its training text during inference. The TurboMind engine now supports [LlamaDynamicNTKScalingRotaryEmbedding](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L178), and the implementation is consistent with Hugging Face transformers. ## Usage You can enable context length extrapolation by modifying `TurbomindEngineConfig`. Set `session_len` to the expected length and change `rope_scaling_factor` to a number no less than 1.0. Take `internlm2_5-7b-chat-1m` as an example, which supports a context length of up to **1 million tokens**: ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=1000000, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config) prompt = 'Use a long prompt to replace this sentence' gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) response = pipe(prompt, gen_config=gen_config) print(response) ``` ## Evaluation We use several methods to evaluate the long-context inference ability of LMDeploy, including [passkey retrieval](#passkey-retrieval), [needle in a haystack](#needle-in-a-haystack) and computing [perplexity](#perplexity). ### Passkey Retrieval You can try the following code to test how often LMDeploy can retrieve the special key.
```python import numpy as np from lmdeploy import pipeline from lmdeploy import TurbomindEngineConfig import time session_len = 1000000 backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config) def passkey_retrieval(session_len, n_round=5): # create long context input tok = pipe.tokenizer task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.' garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.' for _ in range(n_round): start = time.perf_counter() n_times = (session_len - 1000) // len(tok.encode(garbage)) n_garbage_prefix = np.random.randint(0, n_times) n_garbage_suffix = n_times - n_garbage_prefix garbage_prefix = ' '.join([garbage] * n_garbage_prefix) garbage_suffix = ' '.join([garbage] * n_garbage_suffix) pass_key = np.random.randint(1, 50000) information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.' # noqa: E501 final_question = 'What is the pass key? The pass key is' lines = [ task_description, garbage_prefix, information_line, garbage_suffix, final_question, ] # inference prompt = ' '.join(lines) response = pipe([prompt]) print(pass_key, response) end = time.perf_counter() print(f'duration: {end - start} s') passkey_retrieval(session_len, 5) ``` This test takes approximately 364 seconds per round when conducted on A100-80G GPUs ### Needle In A Haystack [OpenCompass](https://github.com/open-compass/opencompass) offers very useful tools to perform needle-in-a-haystack evaluation. For specific instructions, please refer to the [guide](https://github.com/open-compass/opencompass/blob/main/docs/en/advanced_guides/needleinahaystack_eval.md). 
### Perplexity The following code demonstrates how to use LMDeploy to calculate perplexity. ```python from transformers import AutoTokenizer from lmdeploy import TurbomindEngineConfig, pipeline import numpy as np # load model and tokenizer model_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m' backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=1000000, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline(model_repoid_or_path, backend_config=backend_config) tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True) # get perplexity text = 'Use a long prompt to replace this sentence' input_ids = tokenizer.encode(text) ppl = pipe.get_ppl(input_ids)[0] print(ppl) ``` ================================================ FILE: docs/en/advance/metrics.md ================================================ # Production Metrics LMDeploy exposes a set of metrics via Prometheus, and provides visualization via Grafana. ## Setup Guide This section describes how to set up the monitoring stack (Prometheus + Grafana) provided in the `lmdeploy/monitoring` directory. ## Prerequisites - [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) installed - LMDeploy server running with the metrics system enabled ## Usage (DP = 1) 1. **Start your LMDeploy server with metrics enabled** ``` lmdeploy serve api_server Qwen/Qwen2.5-7B-Instruct --enable-metrics ``` Replace the model path according to your needs. By default, the metrics endpoint will be available at `http://<server_ip>:23333/metrics`. 2. **Navigate to the monitoring directory** ``` cd lmdeploy/monitoring ``` 3. **Start the monitoring stack** ``` docker compose up ``` This command starts Prometheus and Grafana in the foreground; run `docker compose up -d` instead to start them in the background. 4. **Access the monitoring interfaces** - Prometheus: Open your web browser and go to http://localhost:9090. - Grafana: Open your web browser and go to http://localhost:3000. 5.
**Log in to Grafana** - Default Username: `admin` - Default Password: `admin` You will be prompted to change the password upon your first login. 6. **View the Dashboard** The LMDeploy dashboard is pre-configured and should be available automatically. ## Usage (DP > 1) 1. **Start your LMDeploy server with metrics enabled** As an example, we use the model `Qwen/Qwen2.5-7B-Instruct` with `DP=2, TP=2`. Start the service as follows: ```bash # Proxy server lmdeploy serve proxy --server-port 8000 --routing-strategy 'min_expected_latency' --serving-strategy Hybrid --log-level INFO # API server LMDEPLOY_DP_MASTER_ADDR=127.0.0.1 \ LMDEPLOY_DP_MASTER_PORT=29555 \ lmdeploy serve api_server \ Qwen/Qwen2.5-7B-Instruct \ --backend pytorch \ --tp 2 \ --dp 2 \ --proxy-url http://0.0.0.0:8000 \ --nnodes 1 \ --node-rank 0 \ --enable-metrics ``` You should be able to see multiple API servers added to the proxy server list. Details can be found in `lmdeploy/serve/proxy/proxy_config.json`. For example, you may have the following API servers: ``` http://$host_ip:$api_server_port1 http://$host_ip:$api_server_port2 ``` 2. **Modify the Prometheus configuration** When `DP > 1`, LMDeploy will launch one API server for each DP rank. If you want to monitor a specific API server, e.g. `http://$host_ip:$api_server_port1`, modify the configuration file `lmdeploy/monitoring/prometheus.yaml` as follows. > Note that you should use the actual host machine IP instead of `127.0.0.1` here, since LMDeploy starts the API server using the actual host IP when `DP > 1` ``` global: scrape_interval: 5s evaluation_interval: 30s scrape_configs: - job_name: lmdeploy static_configs: - targets: - '$host_ip:$api_server_port1' # <= Modify this ``` 3. **Navigate to the monitoring folder and perform the same steps as described above** ## Troubleshooting 1. **Port conflicts** Check if any services are occupying ports `23333` (LMDeploy server port), `9090` (Prometheus port), or `3000` (Grafana port). 
You can either stop the services occupying those ports or modify the config files as follows: - Modify LMDeploy server port for Prometheus scrape In `lmdeploy/monitoring/prometheus.yaml` ``` global: scrape_interval: 5s evaluation_interval: 30s scrape_configs: - job_name: lmdeploy static_configs: - targets: - '127.0.0.1:23333' # <= Modify this LMDeploy server port 23333, need to match the running server port ``` - Modify Prometheus port In `lmdeploy/monitoring/grafana/datasources/datasource.yaml` ``` apiVersion: 1 datasources: - name: Prometheus type: prometheus access: proxy url: http://localhost:9090 # <= Modify this Prometheus interface port 9090 isDefault: true editable: false ``` - Modify Grafana port: In `lmdeploy/monitoring/docker-compose.yaml`, for example, change the port to `3090`. Option 1: Add `GF_SERVER_HTTP_PORT` to the environment section. ``` environment: - GF_AUTH_ANONYMOUS_ENABLED=true - GF_SERVER_HTTP_PORT=3090 # <= Add this line ``` Option 2: Use port mapping. ``` grafana: image: grafana/grafana:latest container_name: grafana ports: - "3090:3000" # <= Host:Container port mapping ``` 2. **No data on the dashboard** - Create traffic Send some requests to the LMDeploy server to generate traffic, for example: ``` python3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` After refreshing, you should be able to see data on the dashboard. ================================================ FILE: docs/en/advance/pytorch_multinodes.md ================================================ # PyTorchEngine Multi-Node Deployment Guide To support larger-scale model deployment requirements, PyTorchEngine provides multi-node deployment support. Below are the detailed steps for deploying a `tp=16` model across two 8-GPU nodes. ## 1. Create Docker Containers (Optional) To ensure consistency across the cluster environment, it is recommended to use Docker to set up the cluster.
Create containers on each node as follows: ```bash docker run -it \ --network host \ -v $MODEL_PATH:$CONTAINER_MODEL_PATH \ openmmlab/lmdeploy:latest ``` > \[!IMPORTANT\] > Ensure that the model is placed in the same directory on all node containers. ## 2. Set Up the Cluster Using Ray ### 2.1 Start the Head Node Select one node as the **head node** and run the following command in its container: ```bash ray start --head --port=$DRIVER_PORT ``` ### 2.2 Join the Cluster On the other nodes, use the following command in their containers to join the cluster created by the head node: ```bash ray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT ``` Run `ray status` on the head node to check the cluster. > \[!IMPORTANT\] > Ensure that `DRIVER_NODE_ADDR` is the address of the head node and `DRIVER_PORT` matches the port number used during the head node initialization. ## 3. Use LMDeploy Interfaces In the head node's container, you can use all functionalities of PyTorchEngine as usual. ### 3.1 Start the Server ```bash lmdeploy serve api_server \ $CONTAINER_MODEL_PATH \ --backend pytorch \ --tp 16 ``` ### 3.2 Use the Pipeline ```python from lmdeploy import pipeline, PytorchEngineConfig if __name__ == '__main__': model_path = '/path/to/model' backend_config = PytorchEngineConfig(tp=16) with pipeline(model_path, backend_config=backend_config) as pipe: outputs = pipe('Hakuna Matata') ``` > \[!NOTE\] > PyTorchEngine will automatically choose the appropriate launch method (single-node/multi-node) based on the `tp` parameter and the number of devices available in the cluster. If you want to enforce the use of the Ray cluster, you can configure `distributed_executor_backend='ray'` in `PytorchEngineConfig` or use the environment variable `LMDEPLOY_EXECUTOR_BACKEND=ray`.
______________________________________________________________________ By following the steps above, you can successfully deploy PyTorchEngine in a multi-node environment and leverage the Ray cluster for distributed computing. > \[!WARNING\] > To achieve better performance, we recommend configuring a higher-quality network environment (such as [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand)) to improve engine efficiency. ================================================ FILE: docs/en/advance/pytorch_multithread.md ================================================ # PyTorchEngine Multithread We have removed `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency by using the **service API** or **coroutines** whenever possible, for example: ```python import asyncio from lmdeploy import pipeline, PytorchEngineConfig event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) model_path = 'Llama-3.2-1B-Instruct' pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) async def _gather_output(): tasks = [ pipe.async_batch_infer('Hakuna Matata'), pipe.async_batch_infer('giraffes are heartless creatures'), ] return await asyncio.gather(*tasks) output = asyncio.run(_gather_output()) print(output[0].text) print(output[1].text) ``` If you do need multithreading, you can wrap the pipeline as shown below: ```python import threading from queue import Queue import asyncio from lmdeploy import pipeline, PytorchEngineConfig model_path = 'Llama-3.2-1B-Instruct' async def _batch_infer(inque: Queue, outque: Queue, pipe): while True: if inque.empty(): await asyncio.sleep(0) continue input = inque.get_nowait() output = await pipe.async_batch_infer(input) outque.put(output) def server(inques, outques): event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) for inque, outque in
zip(inques, outques): event_loop.create_task(_batch_infer(inque, outque, pipe)) event_loop.run_forever() def client(inque, outque, message): inque.put(message) print(outque.get().text) inques = [Queue(), Queue()] outques = [Queue(), Queue()] t_server = threading.Thread(target=server, args=(inques, outques)) t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata')) t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures')) t_server.start() t_client0.start() t_client1.start() t_client0.join() t_client1.join() ``` > \[!WARNING\] > This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance. ================================================ FILE: docs/en/advance/pytorch_new_model.md ================================================ # lmdeploy.pytorch New Model Support lmdeploy.pytorch is designed to simplify the support for new models and the development of prototypes. Users can adapt new models according to their own needs. ## Model Support ### Configuration Loading (Optional) lmdeploy.pytorch initializes the engine based on the model's config file. If the parameter naming of the model to be integrated differs from common models in transformers, parsing errors may occur. A custom ConfigBuilder can be added to parse the configuration. 
```python # lmdeploy/pytorch/configurations/gemma.py from lmdeploy.pytorch.config import ModelConfig from .builder import AutoModelConfigBuilder class GemmaModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): # Check if hf_config is suitable for this builder return hf_config.model_type in ['gemma', 'gemma2'] @classmethod def build(cls, hf_config, model_path: str = None): # Use the hf_config loaded by transformers # Construct the ModelConfig for the pytorch engine return ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, num_attention_heads=hf_config.num_attention_heads, num_key_value_heads=hf_config.num_key_value_heads, bos_token_id=hf_config.bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=hf_config.head_dim, vocab_size=hf_config.vocab_size) ``` The `lmdeploy.pytorch.check_env.check_model` function can be used to verify if the configuration can be parsed correctly. ### Implementing the Model After ensuring that the configuration can be parsed correctly, you can start implementing the model logic. Taking the implementation of llama as an example, we need to create the model using the configuration file from transformers. 
```python class LlamaForCausalLM(nn.Module): # Constructor, builds the model with the given config # ctx_mgr is the context manager, which can be used to pass engine configurations or additional parameters def __init__(self, config: LlamaConfig, ctx_mgr: StepContextManager, dtype: torch.dtype = None, device: torch.device = None): super().__init__() self.config = config self.ctx_mgr = ctx_mgr # build LLamaModel self.model = LlamaModel(config, dtype=dtype, device=device) # build lm_head self.lm_head = build_rowwise_linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device) # Model inference function # It is recommended to use the same parameters as below def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, **kwargs, ): hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) logits = self.lm_head(hidden_states) logits = logits.float() return logits ``` In addition to these, the following content needs to be added: ```python class LlamaForCausalLM(nn.Module): ... # Indicates whether the model supports cudagraph # Can be a callable object, receiving forward inputs # Dynamically determines if cudagraph is supported support_cuda_graph = True # Builds model inputs # Returns a dictionary, the keys of which must be inputs to forward def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], inputs_embeds: Optional[torch.Tensor] = None, context: StepContext = None, ): ... # Loads weights # The model's inputs are key-value pairs of the state dict def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ... ``` We have encapsulated many fused operators to simplify the model construction. These operators better support various functions such as tensor parallelism and quantization. 
We encourage developers to use these ops as much as possible. ```python # Using predefined build_merged_colwise_linear, SiluAndMul, build_rowwise_linear # Helps us build the model faster and without worrying about tensor parallelism, quantization, etc. class LlamaMLP(nn.Module): def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None): super().__init__() quantization_config = getattr(config, 'quantization_config', None) # gate up self.gate_up_proj = build_merged_colwise_linear( config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=config.mlp_bias, dtype=dtype, device=device, quant_config=quantization_config, is_tp=True, ) # silu and mul self.act_fn = SiluAndMul(inplace=True) # down self.down_proj = build_rowwise_linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias, quant_config=quantization_config, dtype=dtype, device=device, is_tp=True) def forward(self, x): """forward.""" gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.down_proj(act) ``` ### Model Registration To ensure that the new model implementation can be used, we also need to register the model in `lmdeploy/pytorch/models/module_map.py` ```python MODULE_MAP.update({ 'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', }) ``` If you do not wish to modify the model source code, you can also pass a custom module map from the outside, making it easier to integrate into other projects. ``` from lmdeploy import PytorchEngineConfig, pipeline backend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py') generator = pipeline(model_path, backend_config=backend_config) ``` ================================================ FILE: docs/en/advance/pytorch_profiling.md ================================================ # PyTorchEngine Profiling We provide multiple profilers to analyze the performance of PyTorchEngine.
## PyTorch Profiler We have integrated the PyTorch Profiler. You can enable it by setting environment variables when launching the pipeline or API server: ```bash # enable profile cpu export LMDEPLOY_PROFILE_CPU=1 # enable profile cuda export LMDEPLOY_PROFILE_CUDA=1 # start profiling after 3 seconds export LMDEPLOY_PROFILE_DELAY=3 # profile for 10 seconds export LMDEPLOY_PROFILE_DURATION=10 # prefix path to save profile files export LMDEPLOY_PROFILE_OUT_PREFIX="/path/to/save/profile_" ``` After the program exits, the profiling data will be saved to the path specified by `LMDEPLOY_PROFILE_OUT_PREFIX` for performance analysis. ## Nsight Systems We also support using Nsight Systems to profile NVIDIA devices. ### Single GPU For single-GPU scenarios, simply use `nsys profile`: ```bash nsys profile python your_script.py ``` ### Multi-GPU When using multi-GPU solutions like DP/TP/EP, set the following environment variables: ```bash # enable nsight system export LMDEPLOY_RAY_NSYS_ENABLE=1 # prefix path to save profile files export LMDEPLOY_RAY_NSYS_OUT_PREFIX="/path/to/save/profile_" ``` Then launch the script or API server as usual (Do **NOT** use nsys profile here). The profiling results will be saved under `LMDEPLOY_RAY_NSYS_OUT_PREFIX`. If `LMDEPLOY_RAY_NSYS_OUT_PREFIX` is not configured, you can find the results in `/tmp/ray/session_xxx/nsight`. ## Ray timeline We use `ray` to support multi-device deployment. You can get the ray timeline by setting the environment variables below. ```bash export LMDEPLOY_RAY_TIMELINE_ENABLE=1 export LMDEPLOY_RAY_TIMELINE_OUT_PATH="/path/to/save/timeline.json" ``` ================================================ FILE: docs/en/advance/spec_decoding.md ================================================ # Speculative Decoding Speculative decoding is an optimization technique that introduces a lightweight draft model to propose multiple next tokens; the main model then verifies them in a single forward pass and accepts the longest matching prefix.
Compared with standard auto-regressive decoding, this method lets the system generate multiple tokens at once. > \[!NOTE\] > This is an experimental feature in lmdeploy. ## Examples Here are some examples. ### Eagle 3 #### Prepare Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) ```shell git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git cd flash-attention/hopper python setup.py install ``` #### pipeline ```python from lmdeploy import PytorchEngineConfig, pipeline from lmdeploy.messages import SpeculativeConfig if __name__ == '__main__': model_path = 'meta-llama/Llama-3.1-8B-Instruct' spec_cfg = SpeculativeConfig( method='eagle3', num_speculative_tokens=3, model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B', ) pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` #### serving ```shell lmdeploy serve api_server \ meta-llama/Llama-3.1-8B-Instruct \ --backend pytorch \ --server-port 24545 \ --speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \ --speculative-algorithm eagle3 \ --speculative-num-draft-tokens 3 \ --max-batch-size 128 \ --enable-metrics ``` ### Deepseek MTP #### Prepare Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation) ```shell git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla cd flash-mla git submodule update --init --recursive pip install -v .
``` #### pipeline ```python from lmdeploy import PytorchEngineConfig, pipeline from lmdeploy.messages import SpeculativeConfig if __name__ == '__main__': model_path = 'deepseek-ai/DeepSeek-V3' spec_cfg = SpeculativeConfig( method='deepseek_mtp', num_speculative_tokens=3, ) pipe = pipeline(model_path, backend_config=PytorchEngineConfig(tp=16, max_batch_size=128), speculative_config=spec_cfg) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` #### serving ```shell lmdeploy serve api_server \ deepseek-ai/DeepSeek-V3 \ --backend pytorch \ --server-port 24545 \ --tp 16 \ --speculative-algorithm deepseek_mtp \ --speculative-num-draft-tokens 3 \ --max-batch-size 128 \ --enable-metrics ``` ================================================ FILE: docs/en/advance/structed_output.md ================================================ # Structured output Structured output, also known as guided decoding, forces the model to generate text that exactly matches a user-supplied JSON schema, grammar, or regex. Both the PyTorch and Turbomind backends now support structured (schema-constrained) generation. Below are examples for the pipeline API and the API server. 
## pipeline ```python from lmdeploy import pipeline from lmdeploy.messages import GenerationConfig, PytorchEngineConfig model = 'internlm/internlm2-chat-1_8b' guide = { 'type': 'object', 'properties': { 'name': { 'type': 'string' }, 'skills': { 'type': 'array', 'items': { 'type': 'string', 'maxLength': 10 }, 'minItems': 3 }, 'work history': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'company': { 'type': 'string' }, 'duration': { 'type': 'string' } }, 'required': ['company'] } } }, 'required': ['name', 'skills', 'work history'] } pipe = pipeline(model, backend_config=PytorchEngineConfig(), log_level='INFO') gen_config = GenerationConfig( response_format=dict(type='json_schema', json_schema=dict(name='test', schema=guide))) response = pipe(['Make a self introduction please.'], gen_config=gen_config) print(response) ``` ## api_server Firstly, start the api_server service for the InternLM2 model. ```shell lmdeploy serve api_server internlm/internlm2-chat-1_8b --backend pytorch ``` The client can test using OpenAI’s python package: The output result is a response in JSON format. 
```python from openai import OpenAI guide = { 'type': 'object', 'properties': { 'name': { 'type': 'string' }, 'skills': { 'type': 'array', 'items': { 'type': 'string', 'maxLength': 10 }, 'minItems': 3 }, 'work history': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'company': { 'type': 'string' }, 'duration': { 'type': 'string' } }, 'required': ['company'] } } }, 'required': ['name', 'skills', 'work history'] } response_format=dict(type='json_schema', json_schema=dict(name='test',schema=guide)) messages = [{'role': 'user', 'content': 'Make a self-introduction please.'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, response_format=response_format, top_p=0.8) print(response) ``` ================================================ FILE: docs/en/advance/update_weights.md ================================================ # Update Weights LMDeploy supports updating model weights online for scenarios such as RL training. Here are the steps to do so. ## Step 1: Launch server For the PyTorch backend, you have to add `--distributed-executor-backend ray`. ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend ``` ## Step 2: Offload weights & kv cache Before updating model weights, the server should offload the weights and kv cache.
```python from lmdeploy.utils import serialize_state_dict import requests BASE_URL = 'http://0.0.0.0:23333' api_key = 'sk-xxx' headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } # offload weights and kv cache with level=2 response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2)) assert response.status_code == 200, response.status_code # wake up weights so that the server is ready to receive new weights response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights'])) assert response.status_code == 200, response.status_code ``` ## Step 3: Update weights Split the model weights into multiple segments and update them through the `update_weights` endpoint. ```python segmented_state_dict: List[Dict[str, torch.Tensor]] = ... num_segment = len(segmented_state_dict) for seg_idx in range(num_segment): serialized_data = serialize_state_dict(segmented_state_dict[seg_idx]) data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1) response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data) assert response.status_code == 200, f"response.status_code = {response.status_code}" ``` **Note**: For the PyTorch backend, LMDeploy also supports flattened bucket tensors: ```python from lmdeploy.utils import serialize_state_dict, FlattenedTensorBucket, FlattenedTensorMetadata segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict) for seg_idx in range(num_segment): named_tensors = list(segmented_state_dict[seg_idx].items()) bucket = FlattenedTensorBucket(named_tensors=named_tensors) metadata = bucket.get_metadata() flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata) serialized_data = serialize_state_dict(flattened_tensor_data) data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1, load_format='flattened_bucket') response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data) assert response.status_code == 200, f"response.status_code = {response.status_code}" ``` ## Step 4: Wakeup server After updating the model weights, the server should onload the KV cache and resume serving with the updated weights. ```python response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache'])) assert response.status_code == 200, response.status_code ``` ================================================ FILE: docs/en/api/cli.rst ================================================ Command-line Tools =================== .. sphinx_argparse_cli:: :module: lmdeploy.cli :func: run :hook: :prog: lmdeploy ================================================ FILE: docs/en/api/openapi.rst ================================================ OpenAPI Endpoints ================== .. currentmodule:: lmdeploy OpenAI Compatible API Endpoints ------------------------------- .. openapi:: ../_static/openai.yaml :request: :examples: Proxy Server API ---------------- .. openapi:: ../_static/proxy.yaml :request: :examples: ================================================ FILE: docs/en/api/pipeline.rst ================================================ Inference pipeline ================== .. currentmodule:: lmdeploy Pipeline -------- .. autofunction:: pipeline ..
autoclass:: Pipeline :undoc-members: :show-inheritance: :members: __init__, infer, stream_infer, chat, get_ppl :member-order: bysource Config ------------------- .. autoclass:: PytorchEngineConfig .. autoclass:: TurbomindEngineConfig .. autoclass:: GenerationConfig .. autoclass:: ChatTemplateConfig ================================================ FILE: docs/en/benchmark/a100_fp16.md ================================================ # TurboMind Benchmark on A100 All the following results are tested on A100-80G(x8) CUDA 11.8. The tested lmdeploy version is `v0.2.0`. ## Request Throughput Benchmark - `batch`: the max batch size during inference - `tp`: the number of GPU cards for tensor parallelism - `num_prompts`: the number of prompts, i.e. the number of requests - `RPS`: **R**equest **P**er **S**econd - `FTL`: **F**irst **T**oken **L**atency ### FP16 | model | batch | tp | num_prompts | RPS | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | throughput(out tok/s) | throughput(total tok/s) | | ------------ | ----- | --- | ----------- | ------ | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | --------------------- | ----------------------- | | llama2-7b | 256 | 1 | 3000 | 14.556 | 0.526 | 0.092 | 4.652 | 0.066 | 0.101 | 0.155 | 0.220 | 3387.419 | 6981.159 | | llama2-13b | 128 | 1 | 3000 | 7.950 | 0.352 | 0.075 | 4.193 | 0.051 | 0.067 | 0.138 | 0.202 | 1850.145 | 3812.978 | | internlm-20b | 128 | 2 | 3000 | 10.291 | 0.287 | 0.073 | 3.845 | 0.053 | 0.072 | 0.113 | 0.161 | 2053.266 | 4345.057 | | llama2-70b | 256 | 4 | 3000 | 7.231 | 1.075 | 0.139 | 14.524 | 0.102 | 0.153 | 0.292 | 0.482 | 1682.738 | 3467.969 | ## Static Inference Benchmark - `batch`: the max batch size during inference - `tp`: the number of GPU cards for tensor parallelism - `prompt_tokens`: the number of input tokens - `output_tokens`: the number of generated tokens - `throughput`: the number of generated tokens per second - `FTL`: **F**irst
**T**oken **L**atency ### FP16 llama2-7b | batch | tp | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | | ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | | 1 | 1 | 1 | 128 | 100.02 | 76.55 | 0.011 | 0.01 | 0.011 | 0.009 | 0.009 | 0.01 | 0.011 | | 1 | 1 | 128 | 128 | 102.21 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | | 1 | 1 | 128 | 2048 | 98.92 | 76.59 | 0.022 | 0.022 | 0.022 | 0.01 | 0.01 | 0.01 | 0.01 | | 1 | 1 | 2048 | 128 | 86.1 | 76.77 | 0.139 | 0.139 | 0.14 | 0.01 | 0.01 | 0.01 | 0.011 | | 1 | 1 | 2048 | 2048 | 93.78 | 76.77 | 0.14 | 0.139 | 0.141 | 0.011 | 0.011 | 0.011 | 0.011 | | 16 | 1 | 1 | 128 | 1504.72 | 76.59 | 0.021 | 0.011 | 0.031 | 0.01 | 0.011 | 0.011 | 0.013 | | 16 | 1 | 128 | 128 | 1272.47 | 76.77 | 0.129 | 0.023 | 0.149 | 0.011 | 0.011 | 0.012 | 0.014 | | 16 | 1 | 128 | 2048 | 1010.62 | 76.77 | 0.13 | 0.023 | 0.144 | 0.015 | 0.018 | 0.02 | 0.021 | | 16 | 1 | 2048 | 128 | 348.87 | 78.3 | 2.897 | 0.143 | 3.576 | 0.02 | 0.021 | 0.022 | 0.025 | | 16 | 1 | 2048 | 2048 | 601.63 | 78.3 | 2.678 | 0.142 | 3.084 | 0.025 | 0.028 | 0.03 | 0.031 | | 32 | 1 | 1 | 128 | 2136.73 | 76.62 | 0.079 | 0.014 | 0.725 | 0.011 | 0.012 | 0.013 | 0.021 | | 32 | 1 | 128 | 128 | 2125.47 | 76.99 | 0.214 | 0.022 | 0.359 | 0.012 | 0.013 | 0.014 | 0.035 | | 32 | 1 | 128 | 2048 | 1462.12 | 76.99 | 0.2 | 0.026 | 0.269 | 0.021 | 0.026 | 0.031 | 0.033 | | 32 | 1 | 2048 | 128 | 450.43 | 78.3 | 4.288 | 0.143 | 5.267 | 0.031 | 0.032 | 0.034 | 0.161 | | 32 | 1 | 2048 | 2048 | 733.34 | 78.34 | 4.118 | 0.19 | 5.429 | 0.04 | 0.045 | 0.05 | 0.053 | | 64 | 1 | 1 | 128 | 4154.81 | 76.71 | 0.042 | 0.013 | 0.21 | 0.012 | 0.018 | 0.028 | 0.041 | | 64 | 1 | 128 | 128 | 3024.07 | 77.43 | 0.44 | 0.026 | 1.061 | 0.014 | 0.018 | 0.026 | 0.158 | | 64 | 1 | 128 | 2048 | 1852.06 | 
77.96 | 0.535 | 0.027 | 1.231 | 0.03 | 0.041 | 0.048 | 0.053 | | 64 | 1 | 2048 | 128 | 493.46 | 78.4 | 6.59 | 0.142 | 16.235 | 0.046 | 0.049 | 0.055 | 0.767 | | 64 | 1 | 2048 | 2048 | 755.65 | 78.4 | 39.105 | 0.142 | 116.285 | 0.047 | 0.049 | 0.051 | 0.207 | ================================================ FILE: docs/en/benchmark/benchmark.md ================================================ # Benchmark Please install the lmdeploy precompiled package and download the script and the test dataset: ```shell pip install lmdeploy # clone the repo to get the benchmark script git clone --depth=1 https://github.com/InternLM/lmdeploy cd lmdeploy # switch to the tag corresponding to the installed version: git fetch --tags # Check the installed lmdeploy version: pip show lmdeploy | grep Version # Then, check out the corresponding tag (replace <version> with the version string): git checkout <version> # download the test dataset wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` ## Benchmark offline pipeline API ```shell python3 benchmark/profile_pipeline_api.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct ``` For a comprehensive list of available arguments, please execute `python3 benchmark/profile_pipeline_api.py -h` ## Benchmark offline engine API ```shell python3 benchmark/profile_throughput.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct ``` Detailed argument specification can be retrieved by running `python3 benchmark/profile_throughput.py -h` ## Benchmark online serving Launch the server first (you may refer to [this guide](../llm/api_server.md)) and run the following command: ```shell python3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` For detailed argument specification of `profile_restful_api.py`, please run the help command `python3
benchmark/profile_restful_api.py -h`. ================================================ FILE: docs/en/benchmark/evaluate_with_opencompass.md ================================================ # Model Evaluation Guide This document describes how to evaluate a model's capabilities on academic datasets using OpenCompass and LMDeploy. The complete evaluation process consists of two main stages: inference stage and evaluation stage. During the inference stage, the target model is first deployed as an inference service using LMDeploy. OpenCompass then sends dataset content as requests to this service and collects the generated responses. In the evaluation stage, the OpenCompass evaluation model `opencompass/CompassVerifier-32B` is deployed as a service via LMDeploy. OpenCompass subsequently submits the inference results to this service to obtain the final evaluation scores. If sufficient computational resources are available, please refer to the [End-to-End Evaluation](#end-to-end-evaluation) section for complete workflow execution. Otherwise, we recommend following the [Step-by-Step Evaluation](#step-by-step-evaluation) section to execute both stages sequentially. ## Environment Setup ```shell pip install lmdeploy pip install "opencompass[full]" # Download the lmdeploy source code, which will be used in subsequent steps to access eval script and configuration git clone --depth=1 https://github.com/InternLM/lmdeploy.git ``` It is recommended to install LMDeploy and OpenCompass in separate Python virtual environments to avoid potential dependency conflicts. ## End-to-End Evaluation 1. **Deploy Target Model** ```shell lmdeploy serve api_server <model_path> --server-port 10000 <--other-options> ``` 2. **Deploy Evaluation Model (Judger)** ```shell lmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 ``` 3. **Generate Evaluation Configuration and Execute** ```shell cd {the/root/path/of/lmdeploy/repo} ## Specify the dataset path.
OC will download the datasets automatically if they are ## not found in the path export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache python eval/eval.py {task_name} \ --mode all \ --api-server http://{api-server-ip}:10000 \ --judger-server http://{judger-server-ip}:20000 \ -w {oc_output_dir} ``` For detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`. After the evaluation completes, results are saved in `{oc_output_dir}/{yyyymmdd_hhmmss}`, where `{yyyymmdd_hhmmss}` represents the task timestamp. ## Step-by-Step Evaluation ### Inference Stage This stage generates model responses for the dataset. 1. **Deploy Target Model** ```shell lmdeploy serve api_server <model_path> --server-port 10000 <--other-options> ``` 2. **Generate Inference Configuration and Execute** ```shell cd {the/root/path/of/lmdeploy/repo} ## Specify the dataset path. OC will download the datasets automatically if they are ## not found in the path export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets # Run inference task python eval/eval.py {task_name} \ --mode infer \ --api-server http://{api-server-ip}:10000 \ -w {oc_output_dir} ``` For detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`. ### Evaluation Stage This stage uses the evaluation model (Judger) to assess the quality of inference results. 1. **Deploy Evaluation Model (Judger)** ```shell lmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 --session-len 65536 ``` 2. **Generate Evaluation Configuration and Execute** ```shell cd {the/root/path/of/lmdeploy/repo} ## Specify the dataset path.
OC will download the datasets automatically if they are ## not found in the path export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets # Run evaluation task opencompass /path/to/judger_config.py -m eval -w {oc_output_dir} -r {yyyymmdd_hhmmss} ``` Important Notes: - `task_name` must be identical to the one used in the inference stage - The `oc_output_dir` specified with `-w` must match the directory used in the inference stage - The `-r` parameter indicates "previous outputs & results" and should specify the timestamp directory generated during the inference stage (the subdirectory under `{oc_output_dir}`) For detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`. ================================================ FILE: docs/en/benchmark/evaluate_with_vlmevalkit.md ================================================ # Multi-Modal Model Evaluation Guide This document describes how to evaluate multi-modal models' capabilities using VLMEvalKit and LMDeploy. ## Environment Setup ```shell pip install lmdeploy git clone https://github.com/open-compass/VLMEvalKit.git cd VLMEvalKit && pip install -e . ``` It is recommended to install LMDeploy and VLMEvalKit in separate Python virtual environments to avoid potential dependency conflicts. ## Evaluations 1. **Deploy Large Multi-Modality Models (LMMs)** ```shell lmdeploy serve api_server <model_path> --server-port 23333 <--other-options> ``` 2. **Configure the Evaluation Settings** Modify `VLMEvalKit/vlmeval/config.py` and add the following LMDeploy API configuration to the `api_models` dictionary. The `<task_name>` is a custom name for your evaluation task (e.g., `lmdeploy_qwen3vl-4b`). The `model` parameter should match the `<model_path>` used in the `lmdeploy serve` command. ```python # filepath: VLMEvalKit/vlmeval/config.py # ...existing code...
api_models = { # lmdeploy api ..., "<task_name>": partial( LMDeployAPI, api_base="http://0.0.0.0:23333/v1/chat/completions", model="<model_path>", retry=4, timeout=1200, temperature=0.7, # modify if needed max_new_tokens=16384, # modify if needed ), ... } # ...existing code... ``` 3. **Start Evaluations** ```shell cd VLMEvalKit python run.py --data OCRBench --model <task_name> --api-nproc 16 --reuse --verbose --api 123 ``` The `<task_name>` should match the one used in the above config file. Parameter explanations: - `--data`: Specify the dataset for evaluation (e.g., `OCRBench`). - `--model`: Specify the model name, which must match the `<task_name>` in your `config.py`. - `--api-nproc`: Specify the number of parallel API calls. - `--reuse`: Reuse previous inference results to avoid re-running completed evaluations. - `--verbose`: Enable verbose logging. ================================================ FILE: docs/en/conf.py ================================================ # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here.
# import os import sys from pathlib import Path from fastapi import FastAPI from fastapi.responses import Response from yaml import safe_dump sys.path.insert(0, os.path.abspath('../..')) from lmdeploy.serve.openai.api_server import router # noqa: E402 from lmdeploy.serve.proxy.proxy import app as proxy_server # noqa: E402 version_file = '../../lmdeploy/version.py' with open(version_file, 'r') as f: exec(compile(f.read(), version_file, 'exec')) __version__ = locals()['__version__'] # -- Project information ----------------------------------------------------- project = 'lmdeploy' copyright = '2021-2024, OpenMMLab' author = 'LMDeploy Authors' # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags release = __version__ # -- Generate OpenAPI Spec ----------------------------------------------------- openai_server = FastAPI() openai_server.include_router(router) @openai_server.get('/metrics', response_class=Response, responses={ 200: { 'content': { 'text/plain': {} }, 'description': 'Prometheus metrics data' }, 404: { 'description': 'Metrics Endpoint not enabled' } }) def metrics(): """**[Optional]** Prometheus metrics endpoint.""" pass spec_dir = Path('_static') spec_dir.mkdir(exist_ok=True) with open(spec_dir / 'openai.yaml', 'w', encoding='utf-8') as f: f.write(safe_dump(openai_server.openapi())) with open(spec_dir / 'proxy.yaml', 'w', encoding='utf-8') as f: f.write(safe_dump(proxy_server.openapi())) # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ 'myst_parser', 'sphinx_argparse_cli', 'sphinx.ext.autodoc', 'sphinx.ext.autosectionlabel', 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', 'sphinx_copybutton', 'sphinx_tabs.tabs', 'sphinxcontrib.mermaid', 'sphinxcontrib.openapi', ] # yapf: disable autosectionlabel_prefix_document = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { '.rst': 'restructuredtext', '.md': 'markdown', } # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = 'en' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # # html_theme = 'sphinx_rtd_theme' html_theme = 'sphinx_book_theme' html_logo = '_static/image/lmdeploy-logo.svg' html_title = project html_copy_source = True html_last_updated_fmt = '' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
# html_theme_options = { 'path_to_docs': 'docs/en', 'repository_url': 'https://github.com/InternLM/lmdeploy', 'repository_branch': 'main', # 'show_navbar_depth': 3, # 'navigation_depth': 4, # 'collapse_navigation': False, 'use_edit_page_button': True, 'use_source_button': True, 'use_issues_button': True, 'use_repository_button': True, 'use_download_button': True, 'use_sidenotes': True, # 'show_toc_level': 2, # "icon_links": [ # { # "name": "切换至简体中文", # "url": "https://lmdeploy.readthedocs.io/en/latest", # "icon": "https://img.shields.io/badge/Doc-%E7%AE%80%E4%BD%93%E4%B8%AD%E6%96%87-blue", # noqa: #501 # "type": "url", # }, # ], } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] html_css_files = ['css/readthedocs.css'] # Enable ::: for my_st myst_enable_extensions = [ 'dollarmath', 'amsmath', 'deflist', # "html_admonition", # "html_image", 'colon_fence', # "smartquotes", # "replacements", # "linkify", # "substitution", ] myst_heading_anchors = 5 # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'lmdeploydoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. 
# # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'lmdeploy.tex', 'lmdeploy Documentation', 'LMDeploy Contributors', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, 'lmdeploy', 'lmdeploy Documentation', [author], 1)] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'lmdeploy', 'lmdeploy Documentation', author, 'lmdeploy', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] # -- Extension configuration ------------------------------------------------- # Ignore >>> when copying code copybutton_prompt_text = r'>>> |\.\.\. ' copybutton_prompt_is_regexp = True autodoc_preserve_defaults = True navigation_with_keys = False # Mock out external dependencies here, # otherwise the autodoc pages may be blank. 
autodoc_mock_imports = [ 'torch', 'torchvision', 'transformers', '_turbomind', 'triton', ] autodoc_type_aliases = {'PydanticDataclass': 'pydantic.dataclasses.PydanticDataclass'} intersphinx_mapping = { 'python': ('https://docs.python.org/3.10', None), 'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest', None), 'pillow': ('https://pillow.readthedocs.io/en/stable', None), 'numpy': ('https://numpy.org/doc/stable', None), 'torch': ('https://pytorch.org/docs/stable', None), 'torchvision': ('https://pytorch.org/vision/stable', None), } ================================================ FILE: docs/en/faq.md ================================================ # FAQ ## ModuleNotFoundError ### No module named 'mmengine.config.lazy' There is probably a cached mmengine in your local host. Try to install its latest version. ```shell pip install --upgrade mmengine ``` ### No module named '\_turbomind' It may have been caused by the following reasons. 1. You haven't installed lmdeploy's precompiled package. `_turbomind` is the pybind package of c++ turbomind, which involves compilation. It is recommended that you install the precompiled one. ```shell pip install lmdeploy[all] ``` 2. If you have installed it and still encounter this issue, it is probably because you are executing turbomind-related command in the root directory of lmdeploy source code. Switching to another directory will fix it. But if you are a developer, you often need to develop and compile locally. The efficiency of installing whl every time is too low. You can specify the path of lib after compilation through symbolic links. ```shell # mkdir and build locally mkdir bld && cd bld && bash ../generate.sh && ninja -j$(nproc) # go to the lmdeploy subdirectory from bld and set symbolic links cd ../lmdeploy && ln -s ../bld/lib . # go to the lmdeploy root directory cd .. 
# use the python command such as check_env python3 -m lmdeploy check_env ``` If you still have problems finding the turbomind shared library (`.so`), there may be multiple Python environments on your local machine, and the Python version used at compile time may not match the one used at runtime. In this case, set `PYTHON_EXECUTABLE` in `lmdeploy/generate.sh` according to your environment, such as `-DPYTHON_EXECUTABLE=/usr/local/bin/python3`, and then recompile. ## Libs ### libnccl.so.2 not found Make sure you have installed lmdeploy (>=v0.0.5) through `pip install lmdeploy[all]`. If the issue still exists after lmdeploy installation, add the path of `libnccl.so.2` to the environment variable LD_LIBRARY_PATH. ```shell # Get the location of nvidia-nccl-cu11 package pip show nvidia-nccl-cu11|grep Location # insert the path of "libnccl.so.2" to LD_LIBRARY_PATH export LD_LIBRARY_PATH={Location}/nvidia/nccl/lib:$LD_LIBRARY_PATH ``` ### symbol cudaFreeAsync version libcudart.so.11.0 not defined in file libcudart.so.11.0 with link time reference It's probably due to an outdated CUDA toolkit. The LMDeploy runtime requires CUDA 11.2 or later. ## Inference ### RuntimeError: \[TM\]\[ERROR\] CUDA runtime error: out of memory /workspace/lmdeploy/src/turbomind/utils/allocator.h This is usually due to a disproportionately large memory ratio for the k/v cache, which is dictated by `TurbomindEngineConfig.cache_max_entry_count`. The meaning of this parameter varies slightly across lmdeploy versions.
For specifics, please refer to the source code for the [detailed notes](https://github.com/InternLM/lmdeploy/blob/52419bd5b6fb419a5e3aaf3c3b4dea874b17e094/lmdeploy/messages.py#L107). If you encounter this issue while using the pipeline interface, please reduce the `cache_max_entry_count` in `TurbomindEngineConfig` as follows: ```python from lmdeploy import pipeline, TurbomindEngineConfig backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` If OOM occurs when you run CLI tools, please pass `--cache-max-entry-count` to decrease the k/v cache memory ratio. For example: ```shell # chat command lmdeploy chat internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2 # server command lmdeploy serve api_server internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2 ``` ## Serve ### Api Server Fetch Timeout The image URL fetch timeout for the API server can be configured via the environment variable `LMDEPLOY_FETCH_TIMEOUT`. By default, requests may take up to 10 seconds before timing out. See [lmdeploy/vl/utils.py](https://github.com/InternLM/lmdeploy/blob/7b6876eafcb842633e0efe8baabe5906d7beeeea/lmdeploy/vl/utils.py#L31) for usage. ## Quantization ### RuntimeError: \[enforce fail at inline_container.cc:337\] . unexpected pos 4566829760 vs 4566829656 Please check your disk space. This error is due to insufficient disk space when saving weights, which might be encountered when quantizing the 70B model. ### ModuleNotFoundError: No module named 'flash_attn' Quantizing `qwen` requires the installation of `flash-attn`. However, based on feedback from community users, `flash-attn` can be challenging to install. Therefore, we have removed it from lmdeploy dependencies and now recommend that users install it manually as needed.
================================================ FILE: docs/en/get_started/ascend/get_started.md ================================================ # Get Started with Huawei Ascend We currently support running lmdeploy on **Atlas 800T A3, Atlas 800T A2 and Atlas 300I Duo**. The usage of lmdeploy on a Huawei Ascend device is almost the same as its usage on CUDA with PytorchEngine in lmdeploy. Please read the original [Get Started](../get_started.md) guide before reading this tutorial. Here is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms). > \[!IMPORTANT\] > We have uploaded docker images with KUNPENG CPU support to aliyun. > Please pull the image for your device with one of the following commands: > > Atlas 800T A3: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a3-latest` > > (Atlas 800T A3 currently supports only the Qwen-series with eager mode.) > > Atlas 800T A2: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest` > > 300I Duo: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:300i-duo-latest` > > (Atlas 300I Duo currently works only with graph mode.) > > To build the environment yourself, refer to the Dockerfiles [here](../../../../docker).
## Offline batch inference ### LLM inference Set `device_type="ascend"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="ascend")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM inference Set `device_type="ascend"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='ascend')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## Online serving ### Serve a LLM model Add `--device ascend` to the serve command. ```bash lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat ``` Run the following command to launch a docker container for lmdeploy LLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat" ``` ### Serve a VLM model Add `--device ascend` to the serve command. ```bash lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B ``` Run the following command to launch a docker container for lmdeploy VLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B" ``` ## Inference with Command line Interface Add `--device ascend` to the chat command.
```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device ascend ``` Run the following command to launch a docker container and chat with lmdeploy inside it: ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy chat --backend pytorch --device ascend internlm/internlm2_5-7b-chat" ``` ## Quantization ### w4a16 AWQ Run the following commands to quantize weights on Atlas 800T A2. ```bash lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` Please check [supported_models](../../supported_models/supported_models.md) before using this feature. ### w8a8 SMOOTH_QUANT Run the following commands to quantize weights on Atlas 800T A2. ```bash lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu ``` Please check [supported_models](../../supported_models/supported_models.md) before using this feature. ### int8 KV-cache Quantization The Ascend backend supports offline int8 KV-cache quantization in eager mode. Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details. ## Limitations on 300I Duo 1. Only dtype=float16 is supported. 2. Only graph mode is supported; do not add `--eager-mode`. ================================================ FILE: docs/en/get_started/camb/get_started.md ================================================ # Cambricon The usage of lmdeploy on a Cambricon device is almost the same as its usage on CUDA with PytorchEngine in lmdeploy. Please read the original [Get Started](../get_started.md) guide before reading this tutorial. Here is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms). > \[!IMPORTANT\] > We have uploaded a docker image to aliyun.
> Please try to pull the image with the following command: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest` > \[!IMPORTANT\] > Currently, launching multi-device inference on Cambricon accelerators requires manually starting Ray. > > Below is an example for a 2-device setup: > > ```shell > export MLU_VISIBLE_DEVICES=0,1 > ray start --head --resources='{"MLU": 2}' > ``` ## Offline batch inference ### LLM inference Set `device_type="camb"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="camb")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM inference Set `device_type="camb"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='camb')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## Online serving ### Serve an LLM model Add `--device camb` to the serve command.
```bash lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat ``` Run the following command to launch a docker container for lmdeploy LLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat" ``` ### Serve a VLM model Add `--device camb` to the serve command. ```bash lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B ``` Run the following command to launch a docker container for lmdeploy VLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B" ``` ## Inference with Command line Interface Add `--device camb` to the chat command. ```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device camb ``` Run the following command to start an lmdeploy chat session inside the container: ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy chat --backend pytorch --device camb internlm/internlm2_5-7b-chat" ``` ================================================ FILE: docs/en/get_started/get_started.md ================================================ # Quick Start This tutorial shows the usage of LMDeploy on the CUDA platform: - Offline inference of LLM and VLM models - Serving an LLM or VLM model with the OpenAI-compatible server - Chatting interactively with an LLM model through the console CLI Before reading further, please ensure that you have installed lmdeploy as outlined in the [installation guide](installation.md). ## Offline batch inference ### LLM inference ```python from lmdeploy import pipeline pipe = pipeline('internlm/internlm2_5-7b-chat') response = pipe(['Hi, pls intro
yourself', 'Shanghai is']) print(response) ``` When constructing the `pipeline`, if no inference engine is explicitly chosen between the TurboMind Engine and the PyTorch Engine, LMDeploy will automatically assign one based on [their respective capabilities](../supported_models/supported_models.md), with the TurboMind Engine taking precedence by default. However, you have the option to manually select an engine. For instance, ```python from lmdeploy import pipeline, TurbomindEngineConfig pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=TurbomindEngineConfig( max_batch_size=32, enable_prefix_caching=True, cache_max_entry_count=0.8, session_len=8192, )) ``` or, ```python from lmdeploy import pipeline, PytorchEngineConfig pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=PytorchEngineConfig( max_batch_size=32, enable_prefix_caching=True, cache_max_entry_count=0.8, session_len=8192, )) ``` ```{note} The parameter "cache_max_entry_count" significantly influences the GPU memory usage. It specifies the proportion of FREE GPU memory occupied by the K/V cache after the model weights are loaded. The default value is 0.8. The K/V cache memory is allocated once and reused repeatedly, which is why the built pipeline and the "api_server" described in the next section are observed to consume a substantial amount of GPU memory. If you encounter an Out-of-Memory (OOM) error, you may need to consider lowering the value of "cache_max_entry_count". ``` When using the callable `pipe()` to perform token generation with given prompts, you can set the sampling parameters via `GenerationConfig` as below: ```python from lmdeploy import GenerationConfig, pipeline pipe = pipeline('internlm/internlm2_5-7b-chat') prompts = ['Hi, pls intro yourself', 'Shanghai is'] response = pipe(prompts, gen_config=GenerationConfig( max_new_tokens=1024, top_p=0.8, top_k=40, temperature=0.6 )) ``` In the `GenerationConfig`, `top_k=1` or `temperature=0.0` indicates greedy search.
For more information about the pipeline, please read the [detailed tutorial](../llm/pipeline.md). ### VLM inference The usage of the VLM inference pipeline is akin to that of LLMs, with the additional capability of processing image data with the pipeline. For example, you can utilize the following code snippet to perform the inference with an InternVL model: ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` In the VLM pipeline, the default image processing batch size is 1. This can be adjusted by `VisionConfig`. For instance, you might set it like this: ```python from lmdeploy import pipeline, VisionConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B', vision_config=VisionConfig( max_batch_size=8 )) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` However, the larger the image batch size, the greater the risk of an OOM error, because the LLM component within the VLM model pre-allocates a massive amount of memory in advance. We encourage you to manually choose between the TurboMind Engine and the PyTorch Engine based on their respective capabilities, as detailed in [the supported-models matrix](../supported_models/supported_models.md). Additionally, follow the instructions in the [LLM Inference](#llm-inference) section to reduce the values of memory-related parameters. ## Serving Following the previous [offline batch inference](#offline-batch-inference) section, this part presents the respective serving methods for LLMs and VLMs.
### Serve an LLM model ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat ``` This command will launch an OpenAI-compatible server on localhost at port `23333`. You can specify a different server port by using the `--server-port` option. For more options, consult the help documentation by running `lmdeploy serve api_server --help`. Most of these options align with the engine configuration. To access the service, you can use the official OpenAI Python package (install it with `pip install openai`). Below is an example demonstrating how to use the `v1/chat/completions` endpoint: ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": " provide three suggestions about time management"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` We encourage you to refer to the detailed guides for more comprehensive information about [serving with Docker](../llm/api_server.md), [function calls](../llm/api_server_tools.md) and other topics. ### Serve a VLM model ```shell lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` ```{note} LMDeploy reuses the vision component from upstream VLM repositories. Each upstream VLM model may have different dependencies. Consequently, LMDeploy has decided not to include the dependencies of the upstream VLM repositories in its own dependency list. If you encounter an "ImportError" when using LMDeploy for inference with VLM models, please install the relevant dependencies yourself.
``` After the service is launched successfully, you can access the VLM service in a manner similar to how you would access the `gptv4` service by modifying the `api_key` and `base_url` parameters: ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ## Inference with Command line Interface LMDeploy offers a very convenient CLI tool for users to chat with the LLM model locally. For example: ```shell lmdeploy chat internlm/internlm2_5-7b-chat --backend turbomind ``` It is designed to assist users in checking and verifying whether LMDeploy supports their model, whether the chat template is applied correctly, and whether the inference results are delivered smoothly. Another tool, `lmdeploy check_env`, aims to gather the essential environment information. It is crucial when reporting an issue to us, as it helps us diagnose and resolve the problem more effectively. If you have any doubt about their usage, you can try using the `--help` option to obtain detailed information. ================================================ FILE: docs/en/get_started/index.rst ================================================ On Other Platforms ================================= .. 
toctree:: :maxdepth: 1 :caption: OtherPF ascend/get_started.md maca/get_started.md camb/get_started.md ================================================ FILE: docs/en/get_started/installation.md ================================================ # Installation LMDeploy is a Python library for compressing, deploying, and serving Large Language Models (LLMs) and Vision-Language Models (VLMs). Its core inference engines are the TurboMind Engine and the PyTorch Engine. The former is developed in C++ and CUDA, striving for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to lower the barriers for developers. It supports LLM and VLM deployment on both Linux and Windows platforms, with a minimum requirement of CUDA 11.3. Furthermore, it is compatible with the following NVIDIA GPUs: - Volta(sm70): V100 - Turing(sm75): 20 series, T4 - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 - Ada Lovelace(sm89): 40 series ## Install with pip (Recommend) It is recommended to install lmdeploy using pip in a conda environment (python 3.10 - 3.13): ```shell conda create -n lmdeploy python=3.10 -y conda activate lmdeploy pip install lmdeploy ``` The default prebuilt package is compiled with **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy with: ```shell export LMDEPLOY_VERSION=0.12.2 export PYTHON_VERSION=310 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` ## Install from source By default, LMDeploy will build with NVIDIA CUDA support, utilizing both the Turbomind and PyTorch backends. Before installing LMDeploy, ensure you have successfully installed the CUDA Toolkit.
Once the CUDA toolkit is successfully set up, you can build and install LMDeploy with a single command: ```shell pip install git+https://github.com/InternLM/lmdeploy.git ``` You can also explicitly disable the Turbomind backend to avoid CUDA compilation by setting the `DISABLE_TURBOMIND` environment variable: ```shell DISABLE_TURBOMIND=1 pip install git+https://github.com/InternLM/lmdeploy.git ``` If you prefer a specific version instead of the `main` branch of LMDeploy, you can specify it in your command: ```shell pip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.11.0.zip ``` If you want to build LMDeploy with support for Ascend, Cambricon, or MACA, install LMDeploy with the corresponding `LMDEPLOY_TARGET_DEVICE` environment variable. LMDeploy also supports installation on AMD GPUs with ROCm. ```shell #The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies: docker run -it \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ --device=/dev/kfd \ --device=/dev/dri \ --group-add video \ --ipc=host \ --network=host \ --shm-size 32G \ -v /root:/workspace \ rocm/pytorch:latest #Once inside the container, install LMDeploy with ROCm support: LMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git ``` ================================================ FILE: docs/en/get_started/maca/get_started.md ================================================ # MetaX-tech The usage of lmdeploy on a MetaX-tech device is almost the same as its usage on CUDA with PytorchEngine in lmdeploy. Please read the original [Get Started](../get_started.md) guide before reading this tutorial. Here is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms). > \[!IMPORTANT\] > We have uploaded a docker image to aliyun. 
> Please try to pull the image with the following command: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest` ## Offline batch inference ### LLM inference Set `device_type="maca"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="maca")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM inference Set `device_type="maca"` in the `PytorchEngineConfig`: ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='maca')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## Online serving ### Serve an LLM model Add `--device maca` to the serve command.
```bash lmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat ``` Run the following command to launch a docker container for lmdeploy LLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat" ``` ### Serve a VLM model Add `--device maca` to the serve command. ```bash lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B ``` Run the following command to launch a docker container for lmdeploy VLM serving: ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B" ``` ## Inference with Command line Interface Add `--device maca` to the chat command. ```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device maca ``` Run the following command to start an lmdeploy chat session inside the container: ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy chat --backend pytorch --device maca internlm/internlm2_5-7b-chat" ``` ================================================ FILE: docs/en/index.rst ================================================ Welcome to LMDeploy's tutorials! ==================================== .. figure:: ./_static/image/lmdeploy-logo.svg :width: 50% :align: center :alt: LMDeploy :class: no-scaled-link .. raw:: html

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.

Star Watch Fork

LMDeploy has the following core features: * **Efficient Inference**: LMDeploy delivers up to 1.8x higher request throughput than vLLM, by introducing key features like persistent batch(a.k.a. continuous batching), blocked KV cache, dynamic split&fuse, tensor parallelism, high-performance CUDA kernels and so on. * **Effective Quantization**: LMDeploy supports weight-only and k/v quantization, and the 4-bit inference performance is 2.4x higher than FP16. The quantization quality has been confirmed via OpenCompass evaluation. * **Effortless Distribution Server**: Leveraging the request distribution service, LMDeploy facilitates an easy and efficient deployment of multi-model services across multiple machines and cards. * **Excellent Compatibility**: LMDeploy supports `KV Cache Quant `_, `AWQ `_ and `Automatic Prefix Caching `_ to be used simultaneously. Documentation ------------- .. _get_started: .. toctree:: :maxdepth: 1 :caption: Get Started get_started/installation.md get_started/get_started.md get_started/index.rst .. _supported_models: .. toctree:: :maxdepth: 1 :caption: Models supported_models/supported_models.md supported_models/reward_models.md .. _llm_deployment: .. toctree:: :maxdepth: 1 :caption: Large Language Models(LLMs) Deployment llm/pipeline.md llm/api_server.md llm/api_server_tools.md llm/api_server_reasoning.md llm/api_server_lora.md llm/proxy_server.md .. _vlm_deployment: .. toctree:: :maxdepth: 1 :caption: Vision-Language Models(VLMs) Deployment multi_modal/vl_pipeline.md multi_modal/api_server_vl.md multi_modal/index.rst .. _quantization: .. toctree:: :maxdepth: 1 :caption: Quantization quantization/w4a16.md quantization/w8a8.md quantization/kv_quant.md quantization/llm_compressor.md .. _benchmark: .. toctree:: :maxdepth: 1 :caption: Benchmark benchmark/benchmark.md benchmark/evaluate_with_opencompass.md benchmark/evaluate_with_vlmevalkit.md .. 
toctree:: :maxdepth: 1 :caption: Advanced Guide inference/turbomind.md inference/pytorch.md advance/pytorch_new_model.md advance/long_context.md advance/chat_template.md advance/debug_turbomind.md advance/structed_output.md advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md advance/context_parallel.md advance/spec_decoding.md advance/update_weights.md .. toctree:: :maxdepth: 1 :caption: API Reference api/pipeline.rst api/openapi.rst api/cli.rst Indices and tables ================== * :ref:`genindex` * :ref:`search` * :ref:`routingtable` ================================================ FILE: docs/en/inference/load_hf.md ================================================ # Load huggingface model directly Starting from v0.1.0, TurboMind can pre-process model parameters on the fly while loading them from huggingface-style models. ## Supported model type Currently, TurboMind supports loading two types of models: 1. An lmdeploy-quantized model hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc. 2. Other LM models on huggingface.co, such as Qwen/Qwen-7B-Chat ## Usage ### 1) A lmdeploy-quantized model For models quantized by `lmdeploy.lite`, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc.: ``` repo_id=internlm/internlm-chat-20b-4bit model_name=internlm-chat-20b # or # repo_id=/path/to/downloaded_model # Inference by TurboMind lmdeploy chat $repo_id --model-name $model_name # Serving with Restful API lmdeploy serve api_server $repo_id --model-name $model_name --tp 1 ``` ### 2) Other LM models This applies to other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat. The models supported by LMDeploy can be listed with `lmdeploy list`.
``` repo_id=Qwen/Qwen-7B-Chat model_name=qwen-7b # or # repo_id=/path/to/Qwen-7B-Chat/local_path # Inference by TurboMind lmdeploy chat $repo_id --model-name $model_name # Serving with Restful API lmdeploy serve api_server $repo_id --model-name $model_name --tp 1 ``` ================================================ FILE: docs/en/inference/pytorch.md ================================================ # Architecture of lmdeploy.pytorch `lmdeploy.pytorch` is an inference engine in LMDeploy that offers a developer-friendly framework to users interested in deploying their own models and developing new features. ## Design ![pytorch arch](https://github.com/grimoire/lmdeploy/blob/media/lmdeploy_pytorch_arch.png?raw=true) ## API `lmdeploy.pytorch` shares service interfaces with `Turbomind`, and the inference service is implemented by `Engine` and `EngineInstance`. `EngineInstance` acts as the sender of inference requests, encapsulating and sending requests to the `Engine` to achieve streaming inference. The inference interface of `EngineInstance` is thread-safe, allowing instances in different threads to initiate requests simultaneously. The `Engine` will automatically perform batch processing based on the current system resources. The `Engine` is the request receiver and executor. It contains the following modules: - `ModelAgent` serves as a wrapper for the model, handling tasks such as loading model/adapters, managing the cache, and implementing tensor parallelism. - The `Scheduler` functions as the sequence manager, determining the sequences and adapters to participate in the current step, and subsequently allocating resources for them. - `RequestManager` is tasked with sending and receiving requests, acting as the bridge between the `Engine` and `EngineInstance`. ## Engine The Engine responds to requests in a sub-thread, following this looping sequence: 1. Get new requests through `RequestManager`. These requests are cached for now. 2.
The `Scheduler` performs scheduling, deciding which cached requests should be processed and allocating resources for them. 3. `ModelAgent` swaps the caches according to the information provided by the Scheduler, then performs inference with the patched model. 4. The `Scheduler` updates the status of requests based on the inference results from `ModelAgent`. 5. `RequestManager` responds to the sender (`EngineInstance`), and the process returns to step 1. Now, let's delve deeper into the modules that participate in these steps. ### Scheduler In LLM inference, caching history key and value states is a common practice to prevent redundant computation. However, as history lengths vary in a batch of sequences, we need to pad the caches to enable batched inference. Unfortunately, this padding can lead to significant memory waste, limiting the transformer's performance. [vLLM](https://docs.vllm.ai) employs a paging-based strategy, allocating caches in page blocks to minimize extra memory usage. Our Scheduler module in the Engine shares a similar design, allocating resources based on sequence length in blocks and evicting unused blocks to support larger batching and longer session lengths. Additionally, we support [S-LoRA](https://github.com/S-LoRA/S-LoRA), which enables the use of multiple LoRA adapters with limited memory. ### ModelAgent `lmdeploy.pytorch` supports Tensor Parallelism, which leads to complex model initialization, cache allocation, and weight partitioning. ModelAgent is designed to abstract these complexities, allowing the Engine to focus solely on maintaining the pipeline. ModelAgent consists of two components: 1. **patched_model**: This is the transformer model after patching. In comparison to the original model, the patched model incorporates additional features such as Tensor Parallelism, quantization, and high-performance kernels. 2. **cache_engine**: This component manages the caches.
It receives commands from the Scheduler and performs host-device page swaps. Only GPU blocks are utilized for caching key/value pairs and adapters. ## Features `lmdeploy.pytorch` supports new features including: - **Continuous Batching**: As the sequence length in a batch may vary, padding is often necessary for batching inference. However, large padding can lead to additional memory usage and unnecessary computation. To address this, we employ continuous batching, where all sequences are concatenated into a single long sequence to avoid padding. - **Tensor Parallelism**: The GPU memory usage of LLM might exceed the capacity of a single GPU. Tensor parallelism is utilized to accommodate such models on multiple devices. Each device handles parts of the model simultaneously, and the results are gathered to ensure correctness. - **S-LoRA**: LoRA adapters can be used to train LLM on devices with limited memory. While it's common practice to merge adapters into the model weights before deployment, loading multiple adapters in this way can consume a significant amount of memory. We support S-LoRA, where adapters are paged and swapped in when necessary. Special kernels are developed to support inference with unmerged adapters, enabling the loading of various adapters efficiently. - **Quantization**: Model quantization involves performing computations with low precision. `lmdeploy.pytorch` supports w8a8 quantization. For more details, refer to [w8a8](../quantization/w8a8.md). ================================================ FILE: docs/en/inference/turbomind.md ================================================ # Architecture of TurboMind TurboMind is an inference engine that supports high throughput inference for conversational LLMs. It's based on NVIDIA's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). Major features of TurboMind include an efficient LLaMa implementation, the persistent batch inference model and an extendable KV cache manager. 
## High level overview of TurboMind ``` +--------------------+ | API | +--------------------+ | ^ request | | stream callback v | +--------------------+ fetch +-------------------+ | Persistent Batch | <-------> | KV Cache Manager | +--------------------+ update +-------------------+ ^ | v +------------------------+ | LLaMA implementation | +------------------------+ | FT kernels & utilities | +------------------------+ ``` ## Persistent Batch You may recognize this feature as "continuous batching" in other repos. But during the concurrent development of the feature, we modeled the inference of a conversational LLM as a persistently running batch whose lifetime spans the entire serving process, hence the name "persistent batch". To put it simply: - The persistent batch has N pre-configured batch slots. - Requests join the batch when there are free slots available. A batch slot is released and can be reused once the generation of the requested tokens is finished. - __On cache-hits (see below), history tokens don't need to be decoded in every round of a conversation; generation of response tokens will start instantly.__ - The batch grows or shrinks automatically to minimize unnecessary computations. ## KV Cache Manager The [KV cache manager](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/SequenceManager.h) of TurboMind is a memory-pool-like object that also implements an LRU policy, so that it can be viewed as a form of __cache of KV caches__. It works in the following way: - All device memory required for KV cache is allocated by the manager. A fixed number of slots is pre-configured to match the memory size of the system. Each slot corresponds to the memory required by the KV cache of a single sequence. The allocation chunk size can be configured to implement a pre-allocate/on-demand style allocation policy (or something in-between).
- When space for the KV cache of a new sequence is requested but no free slots are left in the pool, the least recently used sequence is evicted from the cache and its device memory is directly reused by the new sequence. However, this is not the end of the story. - Fetching a sequence that currently resides in one of the slots is a _cache-hit_: the history KV cache is returned directly and no context decoding is needed. - Victim (evicted) sequences are not erased entirely but converted to their most compact form, i.e. token IDs. When the same sequence id is fetched later (_cache-miss_), the token IDs will be decoded by the FMHA-backed context decoder and converted back to KV cache. - The eviction and conversion are handled automatically inside TurboMind and are thus transparent to users. __From the user's perspective, systems that use TurboMind have access to infinite device memory.__ ## LLaMa implementation Our implementation of the LLaMa family models is adapted from the GPT-NeoX model in FasterTransformer. In addition to basic refactoring and modifications to support the LLaMa family, we made some improvements to enable high performance inference of conversational models, most importantly: - To support fast context decoding in multi-round conversations, we replaced the attention implementation in the context decoder with a [cutlass](https://github.com/NVIDIA/cutlass)-based FMHA implementation that supports mismatched Q/K lengths. - We introduced indirect buffer pointers in both context FMHA and generation FMHA to support the discontinuity in KV cache within the batch. - To support concurrent inference with persistent batch, a new synchronization mechanism was designed to orchestrate the worker threads running in tensor parallel mode. - To maximize the throughput, we implemented INT8 KV cache support to increase the max batch size. It's effective because in real-world serving scenarios, KV cache costs more memory and consumes more memory bandwidth than weights or other activations.
- We resolved an NCCL hang issue when running multiple model instances in TP mode within a single process; NCCL APIs are now guarded by host-side synchronization barriers. ## API TurboMind supports a Python API that enables streaming output and tensor parallel mode. ## Difference between FasterTransformer and TurboMind Apart from the features described above, there are still many minor differences that we don't cover in this document. Notably, many capabilities of FT are dropped in TurboMind because of the difference in objectives (e.g. prefix prompt, beam search, context embedding, sparse GEMM, GPT/T5/other model families, etc.). ## FAQ ### Supporting Huggingface models For historical reasons, TurboMind's weight layout is based on [the original LLaMa implementation](https://github.com/facebookresearch/llama) (differing only by a transpose). The implementation in huggingface transformers uses a [different layout](https://github.com/huggingface/transformers/blob/45025d92f815675e483f32812caa28cce3a960e7/src/transformers/models/llama/convert_llama_weights_to_hf.py#L123C76-L123C76) for `W_q` and `W_k`, which is handled in [deploy.py](https://github.com/InternLM/lmdeploy/blob/ff4648a1d09e5aec74cf70efef35bfaeeac552e0/lmdeploy/serve/turbomind/deploy.py#L398). ================================================ FILE: docs/en/inference/turbomind_config.md ================================================ # TurboMind Config TurboMind is one of the inference engines of LMDeploy. When using it for model inference, you need to convert the input model into a TurboMind model. In the TurboMind model folder, besides model weight files, the TurboMind model also includes some other files, among which the most important is the configuration file `triton_models/weights/config.ini` that is closely related to inference performance. If you are using LMDeploy version 0.0.x, please refer to the [turbomind 1.0 config](#turbomind-10-config) section to learn the relevant content in the configuration.
Otherwise, please read [turbomind 2.x config](#turbomind-2x-config) to familiarize yourself with the configuration details. ## TurboMind 2.x config Take the `llama-2-7b-chat` model as an example. In TurboMind 2.x, its config.ini content is as follows: ```toml [llama] model_name = "llama2" tensor_para_size = 1 head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 session_len = 4104 weight_type = "fp16" rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 group_size = 0 max_batch_size = 64 max_context_token_num = 1 step_length = 1 cache_max_entry_count = 0.5 cache_block_seq_len = 128 cache_chunk_size = 1 enable_prefix_caching = false quant_policy = 0 max_position_embeddings = 2048 rope_scaling_factor = 0.0 use_logn_attn = 0 ``` These parameters are composed of model attributes and inference parameters. Model attributes include the number of layers, the number of heads, dimensions, etc., and they are **not modifiable**. ```toml model_name = "llama2" head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 ``` Compared to TurboMind 1.0, the model attribute part of the config remains the same, while the inference parameters have changed. In the following sections, we will focus on introducing the inference parameters. ### data type `weight_type` and `group_size` are the relevant parameters, **which cannot be modified**. `weight_type` represents the data type of weights. Currently, `fp16` and `int4` are supported. `int4` represents 4bit weights. When `weight_type` is `int4`, `group_size` means the group size used when quantizing weights with `awq`. The LMDeploy prebuilt package includes kernels with `group size = 128`. ### batch size The maximum batch size is still set through `max_batch_size`. 
But its default value has been changed from 32 to 64, and `max_batch_size` is no longer related to `cache_max_entry_count`. ### k/v cache size k/v cache memory is determined by `cache_block_seq_len` and `cache_max_entry_count`. TurboMind 2.x has implemented Paged Attention, managing the k/v cache in blocks. `cache_block_seq_len` represents the length of the token sequence in a k/v block, with a default value of 128. TurboMind calculates the memory size of the k/v block according to the following formula: ``` cache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type) ``` For the llama2-7b model, when storing k/v as the `half` type, the memory of a k/v block is: `128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB` The meaning of `cache_max_entry_count` varies depending on its value: - When it's a decimal between (0, 1), `cache_max_entry_count` represents the percentage of memory used by k/v blocks. For example, if turbomind launches on an A100-80G GPU with `cache_max_entry_count` being `0.5`, the total memory used by the k/v blocks is `80 * 0.5 = 40G`. - For lmdeploy versions later than v0.2.1, `cache_max_entry_count` determines the percentage of **free memory** for k/v blocks, defaulting to `0.8`. For example, with Turbomind on an A100-80G GPU running a 13b model, the memory for k/v blocks would be `(80 - 26) * 0.8 = 43.2G`, utilizing 80% of the free 54G. - When it's an integer > 0, it represents the total number of k/v blocks. The `cache_chunk_size` indicates the size of the k/v cache chunk to be allocated each time new k/v cache blocks are needed. Different values represent different meanings: - When it is an integer > 0, `cache_chunk_size` number of k/v cache blocks are allocated. - When the value is -1, `cache_max_entry_count` number of k/v cache blocks are allocated. - When the value is 0, `sqrt(cache_max_entry_count)` number of k/v cache blocks are allocated. 
### prefix caching switch Prefix caching feature can be controlled by setting the `enable_prefix_caching` parameter. When set to `True`, it indicates that the feature is enabled, and when set to `False`, it indicates that the feature is disabled. The default value is `False`. Prefix caching feature is mainly applicable to scenarios where multiple requests have the same prompt prefix (such as system prompt). The k/v blocks of this identical prefix part will be cached and reused by multiple requests, thereby saving the overhead of redundant computations and improving inference performance. The longer the identical prompt prefix, the greater the performance improvement. Since k/v block is the smallest granularity for reuse in prefix caching, if the identical prompt prefix is less than one block (prefix length \< cache_block_seq_len), there will be no improvement in inference performance. ### kv quantization and inference switch - `quant_policy=4` means 4bit k/v quantization and inference - `quant_policy=8` indicates 8bit k/v quantization and inference Please refer to [kv quant](../quantization/kv_quant.md) for detailed guide. ### long context switch By setting `rope_scaling_factor = 1.0`, you can enable the Dynamic NTK option of RoPE, which allows the model to use long-text input and output. Regarding the principle of Dynamic NTK, please refer to: 1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases 2. https://kexue.fm/archives/9675 You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`. 
## TurboMind 1.0 config Taking the `llama-2-7b-chat` model as an example, in TurboMind 1.0, its `config.ini` content is as follows: ```toml [llama] model_name = "llama2" tensor_para_size = 1 head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 session_len = 4104 weight_type = "fp16" rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 group_size = 0 max_batch_size = 32 max_context_token_num = 4 step_length = 1 cache_max_entry_count = 48 cache_chunk_size = 1 use_context_fmha = 1 quant_policy = 0 max_position_embeddings = 2048 use_dynamic_ntk = 0 use_logn_attn = 0 ``` These parameters are composed of model attributes and inference parameters. Model attributes include the number of layers, the number of heads, dimensions, etc., and they are **not modifiable**. ```toml model_name = "llama2" head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 ``` In the following sections, we will focus on introducing the inference parameters. ### data type `weight_type` and `group_size` are the relevant parameters, **which cannot be modified**. `weight_type` represents the data type of weights. Currently, `fp16` and `int4` are supported. `int4` represents 4bit weights. When `weight_type` is `int4`, `group_size` means the group size used when quantizing weights with `awq`. In LMDeploy prebuilt package, kernels with `group size = 128` are included. ### batch size `max_batch_size` determines the max size of a batch during inference. In general, the larger the batch size is, the higher the throughput is. But make sure that `max_batch_size <= cache_max_entry_count` ### k/v cache size TurboMind allocates k/v cache memory based on `session_len`, `cache_chunk_size`, and `cache_max_entry_count`. 
- `session_len` denotes the maximum length of a sequence, i.e., the size of the context window. - `cache_chunk_size` indicates the size of k/v sequences to be allocated when new sequences are added. - `cache_max_entry_count` signifies the maximum number of k/v sequences that can be cached. ### kv int8 switch When initiating 8bit k/v inference, change `quant_policy = 4` and `use_context_fmha = 0`. Please refer to [kv int8](../quantization/kv_quant.md) for a guide. ### long context switch By setting `use_dynamic_ntk = 1`, you can enable the Dynamic NTK option of RoPE, which allows the model to use long-text input and output. Regarding the principle of Dynamic NTK, please refer to: 1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases 2. https://kexue.fm/archives/9675 You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`. ================================================ FILE: docs/en/llm/api_server.md ================================================ # OpenAI Compatible Server This article primarily discusses the deployment of a single LLM model across multiple GPUs on a single node, providing a service that is compatible with the OpenAI interface, as well as the usage of the service API. For the sake of convenience, we refer to this service as `api_server`. Regarding parallel services with multiple models, please refer to the guide about [Request Distribution Server](proxy_server.md). In the following sections, we will first introduce methods for starting the service, choosing the appropriate one based on your application scenario. Next, we focus on the definition of the service's RESTful API, explore the various ways to interact with the interface, and demonstrate how to try the service through the Swagger UI or LMDeploy CLI tools. 
## Launch Service Taking the [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) model hosted on huggingface hub as an example, you can choose one of the following methods to start the service. ### Option 1: Launching with lmdeploy CLI ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 ``` The arguments of `api_server` can be viewed through the command `lmdeploy serve api_server -h`, for instance, `--tp` to set tensor parallelism, `--session-len` to specify the max length of the context window, `--cache-max-entry-count` to adjust the GPU mem ratio for k/v cache etc. ### Option 2: Deploying with docker With the LMDeploy [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags), you can run an OpenAI-compatible server as follows: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=<secret>" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server internlm/internlm2_5-7b-chat ``` The parameters of `api_server` are the same as those mentioned in the "[option 1](#option-1-launching-with-lmdeploy-cli)" section. ### Option 3: Deploying to Kubernetes cluster Connect to a running Kubernetes cluster and deploy the internlm2_5-7b-chat model service with the [kubectl](https://kubernetes.io/docs/reference/kubectl/) command-line tool (replace `<your token>` with your huggingface hub token): ```shell sed 's/{{HUGGING_FACE_HUB_TOKEN}}/<your token>/' k8s/deployment.yaml | kubectl create -f - \ && kubectl create -f k8s/service.yaml ``` In the example above, the model data is placed on the local disk of the node (hostPath). Consider replacing it with high-availability shared storage if multiple replicas are desired; the storage can be mounted into the container using a [PersistentVolume](https://kubernetes.io/docs/concepts/storage/persistent-volumes/). 
## RESTful API LMDeploy's RESTful API is compatible with the following three OpenAI interfaces: - /v1/chat/completions - /v1/models - /v1/completions After launching the service successfully, you can browse and try out the offered RESTful APIs through the website `http://0.0.0.0:23333`, as shown in the image below. ![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) If you need to integrate the service into your own projects or products, we recommend the following approach: ### Integrate with `OpenAI` Here is an example of interaction with the endpoint `v1/chat/completions` service via the openai package. Before running it, please install the openai package by `pip install openai`: ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": " provide three suggestions about time management"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` If you want to use async functions, you may try the following example: ```python import asyncio from openai import AsyncOpenAI async def main(): client = AsyncOpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_cards = await client.models.list()._get_page() response = await client.chat.completions.create( model=model_cards.data[0].id, messages=[ { 'role': 'system', 'content': 'You are a helpful assistant.' }, { 'role': 'user', 'content': ' provide three suggestions about time management' }, ], temperature=0.8, top_p=0.8) print(response) asyncio.run(main()) ``` You can invoke other OpenAI interfaces using similar methods. 
For more detailed information, please refer to the [OpenAI API guide](https://platform.openai.com/docs/guides/text-generation). ### Integrate with lmdeploy `APIClient` Below are some examples demonstrating how to access the service through `APIClient`. If you want to use the `/v1/chat/completions` endpoint, you can try the following code: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient('http://{server_ip}:{server_port}') model_name = api_client.available_models[0] messages = [{"role": "user", "content": "Say this is a test!"}] for item in api_client.chat_completions_v1(model=model_name, messages=messages): print(item) ``` For the `/v1/completions` endpoint, you can try: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient('http://{server_ip}:{server_port}') model_name = api_client.available_models[0] for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` ### Tools Please refer to [api_server_tools](./api_server_tools.md). ### Integrate with Java/Golang/Rust You may use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to generate a Java/Rust/Golang client from `http://{server_ip}:{server_port}/openapi.json`. Here is an example: ```shell $ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust $ ls rust/* rust/Cargo.toml rust/git_push.sh rust/README.md rust/docs: ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md rust/src: apis lib.rs models ``` ### Integrate with cURL cURL is a tool for observing the output of the RESTful APIs. 
- list served models `v1/models` ```bash curl http://{server_ip}:{server_port}/v1/models ``` - chat `v1/chat/completions` ```bash curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", "messages": [{"role": "user", "content": "Hello! How are you?"}] }' ``` - text completions `v1/completions` ```shell curl http://{server_ip}:{server_port}/v1/completions \ -H 'Content-Type: application/json' \ -d '{ "model": "llama", "prompt": "two steps to build a house:" }' ``` ## Launch multiple api servers The following are two steps to launch multiple api servers through torchrun. Just create a Python script with the following code. 1. Launch the proxy server through `lmdeploy serve proxy`. Get the correct proxy server URL. 2. Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **Note**: Please do not use `0.0.0.0:8000` here; instead, input the real IP address, for example `11.25.34.55:8000`. 
```python import os import socket from typing import List, Literal import fire def get_host_ip(): try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect(('8.8.8.8', 80)) ip = s.getsockname()[0] finally: s.close() return ip def main(model_path: str, tp: int = 1, proxy_url: str = 'http://0.0.0.0:8000', port: int = 23333, backend: Literal['turbomind', 'pytorch'] = 'turbomind'): local_rank = int(os.environ.get('LOCAL_RANK', -1)) world_size = int(os.environ.get('WORLD_SIZE', -1)) local_ip = get_host_ip() if isinstance(port, List): assert len(port) == world_size port = port[local_rank] else: port += local_rank * 10 if (world_size - local_rank) % tp == 0: rank_list = ','.join([str(local_rank + i) for i in range(tp)]) command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ f'--server-name {local_ip} --server-port {port} --tp {tp} '\ f'--proxy-url {proxy_url} --backend {backend}' print(f'running command: {command}') os.system(command) if __name__ == '__main__': fire.Fire(main) ``` ## FAQ 1. When the user gets `"finish_reason":"length"`, it means the session is too long to be continued. The session length can be modified by passing `--session-len` to api_server. 2. When OOM occurs on the server side, please reduce the `cache_max_entry_count` of `backend_config` when launching the service. 3. Regarding the stop words, we only support characters that encode into a single index. Furthermore, there may be multiple indexes that decode into results containing the stop word. In such cases, if the number of these indexes is too large, we will only use the index encoded by the tokenizer. If you want to use a stop symbol that encodes into multiple indexes, you may consider performing string matching on the streaming client side. Once a successful match is found, you can then break out of the streaming loop. 4. To customize a chat template, please refer to [chat_template.md](../advance/chat_template.md). 
================================================ FILE: docs/en/llm/api_server_lora.md ================================================ # Serving LoRA ## Launch LoRA LoRA is currently only supported by the PyTorch backend. Its deployment process is similar to that of other models, and you can view the commands using `lmdeploy serve api_server -h`. Among the parameters supported by the PyTorch backend, there are configuration options for LoRA. ``` PyTorch engine arguments: --adapters [ADAPTERS [ADAPTERS ...]] Used to set path(s) of lora adapter(s). One can input key-value pairs in xxx=yyy format for multiple lora adapters. If only have one adapter, one can only input the path of the adapter.. Default: None. Type: str ``` The user only needs to pass the Hugging Face model path of the LoRA weights in the form of a dictionary (`name=path` pairs) to `--adapters`. ```shell lmdeploy serve api_server THUDM/chatglm2-6b --adapters mylora=chenchi/lora-chatglm2-6b-guodegang ``` After the service starts, you can find two available model names in the Swagger UI: `THUDM/chatglm2-6b` and `mylora`. The latter is the key in the `--adapters` dictionary. ## Client usage ### CLI When using the OpenAI endpoint, the `model` parameter can be used to select either the base model or a specific LoRA weight for inference. The following example chooses to use the provided `chenchi/lora-chatglm2-6b-guodegang` for inference. 
```shell curl -X 'POST' \ 'http://localhost:23334/v1/chat/completions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "model": "mylora", "messages": [ { "content": "hi", "role": "user" } ] }' ``` And here is the output: ```json { "id": "2", "object": "chat.completion", "created": 1721377275, "model": "mylora", "choices": [ { "index": 0, "message": { "role": "assistant", "content": " 很高兴哪有什么赶凳儿?(按东北语说的“起早哇”),哦,东北人都学会外语了?", "tool_calls": null }, "logprobs": null, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 17, "total_tokens": 43, "completion_tokens": 26 } } ``` ### python ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = 'mylora' response = client.chat.completions.create( model=model_name, messages=[ {"role": "user", "content": "hi"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` The printed response content is: ``` ChatCompletion(id='4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' 很高兴能够见到你哪,我也在辐射区开了个愣儿,你呢,还活着。', role='assistant', function_call=None, tool_calls=None))], created=1721377497, model='mylora', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=22, prompt_tokens=17, total_tokens=39)) ``` ================================================ FILE: docs/en/llm/api_server_reasoning.md ================================================ # Reasoning Outputs For models that support reasoning capabilities, such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), LMDeploy supports parsing the reasoning results in the service and separately records the reasoning content using `reasoning_content`. ## Examples ### DeepSeek R1 We can start the DeepSeek R1 model's api_server service just like launching other models. The difference is that we need to specify the `--reasoning-parser` parameter. 
``` lmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1 ``` Then, we can call the service's functionality from the client: ```python from openai import OpenAI openai_api_key = "Your API key" openai_api_base = "http://0.0.0.0:23333/v1" client = OpenAI( api_key=openai_api_key, base_url=openai_api_base, ) models = client.models.list() model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] response = client.chat.completions.create(model=model, messages=messages, stream=True) for stream_response in response: print('reasoning content: ',stream_response.choices[0].delta.reasoning_content) print('content: ', stream_response.choices[0].delta.content) response = client.chat.completions.create(model=model, messages=messages, stream=False) reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content print("reasoning_content:", reasoning_content) print("content:", content) ``` ## Custom parser You only need to add a similar parser class in `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py`. ```python # import the required packages from typing import Sequence, Union, Tuple, Optional from lmdeploy.serve.openai.reasoning_parser import ( ReasoningParser, ReasoningParserManager) from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaMessage) # define a reasoning parser and register it to lmdeploy # the name list in register_module can be used # in --reasoning-parser. 
@ReasoningParserManager.register_module(["example"]) class ExampleParser(ReasoningParser): def __init__(self, tokenizer: object): super().__init__(tokenizer) def extract_reasoning_content_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting reasoning from an incomplete response; for use when handling reasoning calls and streaming. Has to be an instance method because it requires state - the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> Tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. Used for non-streaming responses where we have the entire model response available before sending to the client. Args: model_output (str): The model-generated string to extract reasoning content from. request (ChatCompletionRequest): The request object that was used to generate the model_output. Returns: reasoning_content (str | None): The reasoning content. final_output (str | None): The content. """ ``` Similarly, the command to start the service becomes: ``` lmdeploy serve api_server $model_path --reasoning-parser example ``` ================================================ FILE: docs/en/llm/api_server_tools.md ================================================ # Tools Calling LMDeploy supports tool calling for InternLM2, InternLM2.5, Llama 3.1 and Qwen2.5 models. Please use `--tool-call-parser` to specify which parser to use when launching the api_server. Supported names are: 1. internlm 2. qwen 3. llama3 ## Single Round Invocation Please start the model service before running the following example.
```python from openai import OpenAI tools = [ { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, } } ] messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] client = OpenAI(api_key='YOUR_API_KEY',base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) ``` ## Multiple Round Invocation ### InternLM A complete toolchain invocation process can be demonstrated through the following example. ```python from openai import OpenAI def add(a: int, b: int): return a + b def mul(a: int, b: int): return a * b tools = [{ 'type': 'function', 'function': { 'name': 'add', 'description': 'Compute the sum of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': { 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }, { 'type': 'function', 'function': { 'name': 'mul', 'description': 'Calculate the product of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': { 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }] messages = [{'role': 'user', 'content': 'Compute (3+5)*2'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) func1_name = 
response.choices[0].message.tool_calls[0].function.name func1_args = response.choices[0].message.tool_calls[0].function.arguments func1_out = eval(f'{func1_name}(**{func1_args})') print(func1_out) messages.append(response.choices[0].message) messages.append({ 'role': 'tool', 'content': f'3+5={func1_out}', 'tool_call_id': response.choices[0].message.tool_calls[0].id }) response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) func2_name = response.choices[0].message.tool_calls[0].function.name func2_args = response.choices[0].message.tool_calls[0].function.arguments func2_out = eval(f'{func2_name}(**{func2_args})') print(func2_out) ``` Using the InternLM2-Chat-7B model to execute the above example, the following results will be printed. ``` ChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"a": 3, "b": 5}', name='add'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=263, total_tokens=288)) 8 ChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='1', function=Function(arguments='{"a": 8, "b": 2}', name='mul'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=293, total_tokens=318)) 16 ``` ### Llama 3.1 Meta announces in [Llama3's official user 
guide](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) that, > There are three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: > > 1. Brave Search: Tool call to perform web searches. > 2. Wolfram Alpha: Tool call to perform complex mathematical calculations. > 3. Code Interpreter: Enables the model to output python code. Additionally, it cautions: "**Note:** We recommend using Llama 70B-instruct or Llama 405B-instruct for applications that combine conversation and tool calling. Llama 8B-Instruct can not reliably maintain a conversation alongside tool calling definitions. It can be used for zero-shot tool calling, but tool instructions should be removed for regular conversations between the model and the user." Therefore, we use [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) to show how to invoke tool calling through the LMDeploy `api_server`. On an A100-SXM-80G node, you can start the service as follows: ```shell lmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4 ``` For an in-depth understanding of the api_server, please refer to the detailed documentation available [here](./api_server.md). The following code snippet demonstrates how to utilize the 'Wolfram Alpha' tool. It is assumed that you have already registered on the [Wolfram Alpha](https://www.wolframalpha.com) website and obtained an API key.
Please ensure that you have a valid API key to access the services provided by Wolfram Alpha ```python from openai import OpenAI import requests def request_llama3_1_service(messages): client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False) return response.choices[0].message.content # The role of "system" MUST be specified, including the required tools messages = [ { "role": "system", "content": "Environment: ipython\nTools: wolfram_alpha\n\n Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\nYou are a helpful Assistant." # noqa }, { "role": "user", "content": "Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0" # noqa } ] # send request to the api_server of llama3.1-70b and get the response # the "assistant_response" is supposed to be: # <|python_tag|>wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0") assistant_response = request_llama3_1_service(messages) print(assistant_response) # Call the API of Wolfram Alpha with the query generated by the model app_id = 'YOUR-Wolfram-Alpha-API-KEY' params = { "input": assistant_response, "appid": app_id, "format": "plaintext", "output": "json", } wolframalpha_response = requests.get( "https://api.wolframalpha.com/v2/query", params=params ) wolframalpha_response = wolframalpha_response.json() # Append the contents obtained by the model and the wolframalpha's API # to "messages", and send it again to the api_server messages += [ { "role": "assistant", "content": assistant_response }, { "role": "ipython", "content": wolframalpha_response } ] assistant_response = request_llama3_1_service(messages) print(assistant_response) ``` ### Qwen2.5 Qwen2.5 supports multi tool calling, which means that multiple tool requests can be initiated in one request ```python from openai import OpenAI import json def 
get_current_temperature(location: str, unit: str = "celsius"): """Get current temperature at a location. Args: location: The location to get the temperature for, in the format "City, State, Country". unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) Returns: the temperature, the location, and the unit in a dict """ return { "temperature": 26.1, "location": location, "unit": unit, } def get_temperature_date(location: str, date: str, unit: str = "celsius"): """Get temperature at a location and date. Args: location: The location to get the temperature for, in the format "City, State, Country". date: The date to get the temperature for, in the format "Year-Month-Day". unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) Returns: the temperature, the location, the date and the unit in a dict """ return { "temperature": 25.9, "location": location, "date": date, "unit": unit, } def get_function_by_name(name): if name == "get_current_temperature": return get_current_temperature if name == "get_temperature_date": return get_temperature_date tools = [{ 'type': 'function', 'function': { 'name': 'get_current_temperature', 'description': 'Get current temperature at a location.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' }, 'unit': { 'type': 'string', 'enum': [ 'celsius', 'fahrenheit' ], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' } }, 'required': [ 'location' ] } } }, { 'type': 'function', 'function': { 'name': 'get_temperature_date', 'description': 'Get temperature at a location and date.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' 
}, 'date': { 'type': 'string', 'description': 'The date to get the temperature for, in the format \'Year-Month-Day\'.' }, 'unit': { 'type': 'string', 'enum': [ 'celsius', 'fahrenheit' ], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' } }, 'required': [ 'location', 'date' ] } } }] messages = [{'role': 'user', 'content': 'Today is 2024-11-14, What\'s the temperature in San Francisco now? How about tomorrow?'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response.choices[0].message.tool_calls) messages.append(response.choices[0].message) for tool_call in response.choices[0].message.tool_calls: tool_call_args = json.loads(tool_call.function.arguments) tool_call_result = get_function_by_name(tool_call.function.name)(**tool_call_args) messages.append({ 'role': 'tool', 'name': tool_call.function.name, 'content': tool_call_result, 'tool_call_id': tool_call.id }) response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response.choices[0].message.content) ``` Using the Qwen2.5-14B-Instruct, similar results can be obtained as follows ``` [ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"location": "San Francisco, California, USA"}', name='get_current_temperature'), type='function'), ChatCompletionMessageToolCall(id='1', function=Function(arguments='{"location": "San Francisco, California, USA", "date": "2024-11-15"}', name='get_temperature_date'), type='function')] The current temperature in San Francisco, California, USA is 26.1°C. For tomorrow, 2024-11-15, the temperature is expected to be 25.9°C. 
``` It is important to note that in scenarios involving multiple tool calls, the order of the tool call results can affect the response quality. This is because the `tool_call_id` is not yet correctly provided to the LLM. ================================================ FILE: docs/en/llm/codellama.md ================================================ # codellama ## Introduction [codellama](https://github.com/facebookresearch/codellama) features enhanced coding capabilities. It can generate code and natural language about code, from both code and natural language prompts (e.g., “Write me a function that outputs the fibonacci sequence”). It can also be used for code completion and debugging. It supports many of the most popular programming languages used today, including Python, C++, Java, PHP, Typescript (Javascript), C#, Bash and more. There are three sizes (7b, 13b, 34b) as well as three flavours (base model, Python fine-tuned, and instruction tuned) released on [HuggingFace](https://huggingface.co/codellama).
| Base Model | Python | Instruct | | ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | | [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf) | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf) | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) | | [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) | | [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | [codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) | The correspondence between the model and capabilities is: | models | code completion | infilling | instructions / chat | python specialist | | ---------- | --------------- | ----------------- | ------------------- | ----------------- | | Base Model | Y | Y(7B,13B), N(34B) | N | N | | Python | Y | N | N | Y | | Instruct | Y | Y(7B,13B), N(34B) | Y | N | ## Inference Based on the above table, this section shows how to utilize CodeLlama's capabilities by examples ### Completion ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='completion' )) response = pipe( 'import socket\n\ndef ping_exponential_backoff(host: str):', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) 
print(response.text) ``` ### Infilling ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='infilling' )) prompt = """ def remove_non_ascii(s: str) -> str: \"\"\" \"\"\" return result """ response = pipe( prompt, gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95, max_new_tokens=500 ) ) print(response.text) ``` ### Chat ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-Instruct-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='chat' )) response = pipe( 'implement quick sort in C++', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) print(response.text) ``` ### Python specialist ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-Python-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='python' )) response = pipe( 'implement quick sort', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) print(response.text) ``` ## Quantization TBD ## Serving Prepare a chat template json file, for instance "codellama.json", with the following content: ```json { "model_name": "codellama", "capability": "completion" } ``` Then launch the service as follows: ```shell lmdeploy serve api_server meta-llama/CodeLlama-7b-Instruct-hf --chat-template codellama.json ``` After the service is launched successfully, you can access the service with `openai` package: ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "user", "content": "import socket\n\ndef ping_exponential_backoff(host: str):"}, ], temperature=0.1, 
top_p=0.95, max_tokens=500 ) print(response) ``` For detailed information about the api_server, please refer to the [guide](../llm/api_server.md). ================================================ FILE: docs/en/llm/pipeline.md ================================================ # Offline Inference Pipeline In this tutorial, we will present a list of examples to introduce the usage of `lmdeploy.pipeline`. You can find the detailed pipeline API in [this](https://lmdeploy.readthedocs.io/en/latest/api/pipeline.html) guide. ## Usage ### A 'Hello, world' example ```python from lmdeploy import pipeline pipe = pipeline('internlm/internlm2_5-7b-chat') response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` In this example, the pipeline by default allocates a predetermined percentage of GPU memory for storing k/v cache. The ratio is dictated by the parameter `TurbomindEngineConfig.cache_max_entry_count`. The strategy for setting the k/v cache ratio has changed over the evolution of LMDeploy. The change history is as follows: 1. `v0.2.0 <= lmdeploy <= v0.2.1` `TurbomindEngineConfig.cache_max_entry_count` defaults to 0.5, indicating that 50% of the GPU's **total memory** is allocated for k/v cache. Out Of Memory (OOM) errors may occur if a 7B model is deployed on a GPU with less than 40G of memory. If you encounter an OOM error, please decrease the ratio of the k/v cache occupation as follows: ```python from lmdeploy import pipeline, TurbomindEngineConfig # decrease the ratio of the k/v cache occupation to 20% backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` 2. `lmdeploy > v0.2.1` The allocation strategy for k/v cache is changed to reserve space from the **GPU free memory** proportionally.
The ratio `TurbomindEngineConfig.cache_max_entry_count` has been adjusted to 0.8 by default. If OOM error happens, similar to the method mentioned above, please consider reducing the ratio value to decrease the memory usage of the k/v cache. ### Set tensor parallelism ```python from lmdeploy import pipeline, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` ### Set sampling parameters ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) print(response) ``` ### Apply OpenAI format prompt ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] response = pipe(prompts, gen_config=gen_config) print(response) ``` ### Apply streaming output ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] for item in pipe.stream_infer(prompts, gen_config=gen_config): print(item) ``` ### Get 
logits for generated tokens ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('internlm/internlm2_5-7b-chat') gen_config=GenerationConfig(output_logits='generation', max_new_tokens=10) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) logits = [x.logits for x in response] ``` ### Get last layer's hidden states for generated tokens ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('internlm/internlm2_5-7b-chat') gen_config=GenerationConfig(output_last_hidden_state='generation', max_new_tokens=10) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) hidden_states = [x.last_hidden_state for x in response] ``` ### Calculate ppl ```python from transformers import AutoTokenizer from lmdeploy import pipeline model_repoid_or_path = 'internlm/internlm2_5-7b-chat' pipe = pipeline(model_repoid_or_path) tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True) messages = [ {"role": "user", "content": "Hello, how are you?"}, ] input_ids = tokenizer.apply_chat_template(messages) # ppl is a list of float numbers ppl = pipe.get_ppl(input_ids) print(ppl) ``` ```{note} - When input_ids is too long, an OOM (Out Of Memory) error may occur. 
Please apply it with caution - get_ppl returns the cross entropy loss without applying the exponential operation afterwards ``` ### Use PyTorchEngine ```shell pip install "triton>=2.1.0" ``` ```python from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig backend_config = PytorchEngineConfig(session_len=2048) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] response = pipe(prompts, gen_config=gen_config) print(response) ``` ### Inference with LoRA ```python from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig backend_config = PytorchEngineConfig(session_len=2048, adapters=dict(lora_name_1='chenchi/lora-chatglm2-6b-guodegang')) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('THUDM/chatglm2-6b', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': '您猜怎么着' }]] response = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1') print(response) ``` ### Release pipeline You can release the pipeline explicitly by calling its `close()` method, or alternatively, use the `with` statement as demonstrated below: ```python from lmdeploy import pipeline with pipeline('internlm/internlm2_5-7b-chat') as pipe: response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` ## FAQs - **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**. If you get this error for tp>1 with the PyTorch backend, please make sure the python script has the following: ```python if __name__ == '__main__': ``` Generally, in the context of multi-threading or multi-processing, it might be necessary to ensure that initialization code is executed only once.
In this case, `if __name__ == '__main__':` can help to ensure that this initialization code is run only in the main program, and not repeated in each newly created process or thread. - To customize a chat template, please refer to [chat_template.md](../advance/chat_template.md). - If the LoRA weights have a corresponding chat template, you can first register the chat template to lmdeploy, and then use the chat template name as the adapter name. ================================================ FILE: docs/en/llm/proxy_server.md ================================================ # Request Distributor Server The request distributor service can parallelize multiple api_server services. Users only need to access the proxy URL, and they can indirectly access different api_server services. The proxy service will automatically distribute requests internally, achieving load balancing. ## Startup Start the proxy service: ```shell lmdeploy serve proxy --server-name {server_name} --server-port {server_port} --routing-strategy "min_expected_latency" --serving-strategy Hybrid ``` After startup is successful, the URL of the proxy service will also be printed by the script. Access this URL in your browser to open the Swagger UI. Subsequently, users can register an `api_server` with the proxy service directly at startup by using the `--proxy-url` option. For example: `lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`. In this way, users can access the services of the `api_server` through the proxy node, and the usage of the proxy node is exactly the same as that of the `api_server`, both of which are compatible with the OpenAI format. - /v1/models - /v1/chat/completions - /v1/completions ## Node Management Through the Swagger UI, we can see multiple APIs.
Those related to api_server node management include: - /nodes/status - /nodes/add - /nodes/remove They respectively represent viewing all api_server service nodes, adding a certain node, and deleting a certain node. ### Node Management through curl ```shell curl -X 'GET' \ 'http://localhost:8000/nodes/status' \ -H 'accept: application/json' ``` ```shell curl -X 'POST' \ 'http://localhost:8000/nodes/add' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "url": "http://0.0.0.0:23333" }' ``` ```shell curl -X 'POST' \ 'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \ -H 'accept: application/json' \ -d '' ``` ### Node Management through python ```python # query all nodes import requests url = 'http://localhost:8000/nodes/status' headers = {'accept': 'application/json'} response = requests.get(url, headers=headers) print(response.text) ``` ```python # add a new node import requests url = 'http://localhost:8000/nodes/add' headers = { 'accept': 'application/json', 'Content-Type': 'application/json' } data = {"url": "http://0.0.0.0:23333"} response = requests.post(url, headers=headers, json=data) print(response.text) ``` ```python # delete a node import requests url = 'http://localhost:8000/nodes/remove' headers = {'accept': 'application/json',} params = {'node_url': 'http://0.0.0.0:23333',} response = requests.post(url, headers=headers, data='', params=params) print(response.text) ``` ## Serving Strategy LMDeploy currently supports two serving strategies: - Hybrid: Does not distinguish between Prefill and Decoding instances, following the traditional inference deployment mode. - DistServe: Separates Prefill and Decoding instances, deploying them on different service nodes to achieve more flexible and efficient resource scheduling and scalability. 
## Dispatch Strategy The current distribution strategies of the proxy service are as follows: - random: dispatches based on the ability of each api_server node provided by the user to process requests. The greater the request throughput, the more likely it is to be allocated. Nodes that do not provide throughput are treated according to the average throughput of other nodes. - min_expected_latency: allocates based on the number of requests currently waiting to be processed on each node, and the throughput capability of each node, calculating the expected time required to complete the response. The shortest one gets allocated. Nodes that do not provide throughput are treated similarly. - min_observed_latency: allocates based on the average time required to handle a certain number of past requests on each node. The one with the shortest time gets allocated. ================================================ FILE: docs/en/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/en/multi_modal/api_server_vl.md ================================================ # OpenAI Compatible Server This article primarily discusses the deployment of a single large vision language model across multiple GPUs on a single node, providing a service that is compatible with the OpenAI interface, as well as the usage of the service API. For the sake of convenience, we refer to this service as `api_server`. Regarding parallel services with multiple models, please refer to the guide about [Request Distribution Server](../llm/proxy_server.md). In the following sections, we will first introduce two methods for starting the service, so that you can choose the appropriate one based on your application scenario. Next, we focus on the definition of the service's RESTful API, explore the various ways to interact with the interface, and demonstrate how to try the service through the Swagger UI or LMDeploy CLI tools. ## Launch Service Taking the [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model hosted on the Hugging Face hub as an example, you can choose one of the following methods to start the service. ### Option 1: Launching with lmdeploy CLI ```shell lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b --server-port 23333 ``` The arguments of `api_server` can be viewed through the command `lmdeploy serve api_server -h`, for instance, `--tp` to set tensor parallelism, `--session-len` to specify the max length of the context window, `--cache-max-entry-count` to adjust the GPU mem ratio for k/v cache etc.
### Option 2: Deploying with docker With the LMDeploy [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags), you can run an OpenAI compatible server as follows: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b ``` The parameters of `api_server` are the same as those mentioned in the "[option 1](#option-1-launching-with-lmdeploy-cli)" section. Each model may require specific dependencies not included in the Docker image. If you run into issues, you may need to install those yourself on a case-by-case basis. If in doubt, refer to the specific model's project for documentation. For example, for Llava: ``` FROM openmmlab/lmdeploy:latest RUN apt-get update && apt-get install -y python3 python3-pip git WORKDIR /app RUN pip3 install --upgrade pip RUN pip3 install timm RUN pip3 install git+https://github.com/haotian-liu/LLaVA.git --no-deps COPY . . CMD ["lmdeploy", "serve", "api_server", "liuhaotian/llava-v1.6-34b"] ``` ## RESTful API LMDeploy's RESTful API is compatible with the following three OpenAI interfaces: - /v1/chat/completions - /v1/models - /v1/completions The interface for image interaction is `/v1/chat/completions`, which is consistent with OpenAI. After launching the service successfully, you can get an overview of the offered RESTful APIs and try them out on the website `http://0.0.0.0:23333`, as shown in the image below. ![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) If you need to integrate the service into your own projects or products, we recommend the following approach: ### Integrate with `OpenAI` Here is an example of interacting with the `v1/chat/completions` endpoint via the openai package.
Before running it, please install the openai package with `pip install openai`. ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ### Integrate with lmdeploy `APIClient` Below are some examples demonstrating how to access the service through `APIClient`. If you want to use the `/v1/chat/completions` endpoint, you can try the following code: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient(f'http://0.0.0.0:23333') model_name = api_client.available_models[0] messages = [{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }] }] for item in api_client.chat_completions_v1(model=model_name, messages=messages): print(item) ``` ### Integrate with Java/Golang/Rust You may use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` into a Java/Rust/Golang client.
Here is an example: ```shell $ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust $ ls rust/* rust/Cargo.toml rust/git_push.sh rust/README.md rust/docs: ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md rust/src: apis lib.rs models ``` ================================================ FILE: docs/en/multi_modal/cogvlm.md ================================================ # CogVLM ## Introduction CogVLM is a powerful open-source visual language model (VLM). LMDeploy supports CogVLM-17B models like [THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf) and CogVLM2-19B models like [THUDM/cogvlm2-llama3-chat-19B](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) in PyTorch engine. ## Quick Start Install LMDeploy by following the [installation guide](../get_started/installation.md) ### Prepare When deploying the **CogVLM** model using LMDeploy, it is necessary to download the model first, as the **CogVLM** model repository does not include the tokenizer model. However, this step is not required for **CogVLM2**. Taking one **CogVLM** model `cogvlm-chat-hf` as an example, you can prepare it as follows: ```shell huggingface-cli download THUDM/cogvlm-chat-hf --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False huggingface-cli download lmsys/vicuna-7b-v1.5 special_tokens_map.json tokenizer.model tokenizer_config.json --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False ``` ### Offline inference pipeline The following sample code shows the basic usage of VLM pipeline. 
For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('cogvlm-chat-hf') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/en/multi_modal/deepseek_vl2.md ================================================ # DeepSeek-VL2 ## Introduction DeepSeek-VL2 is an advanced series of large Mixture-of-Experts (MoE) Vision-Language Models that significantly improves upon its predecessor, DeepSeek-VL. DeepSeek-VL2 demonstrates superior capabilities across various tasks, including but not limited to visual question answering, optical character recognition, document/table/chart understanding, and visual grounding. LMDeploy supports [deepseek-vl2-tiny](https://huggingface.co/deepseek-ai/deepseek-vl2-tiny), [deepseek-vl2-small](https://huggingface.co/deepseek-ai/deepseek-vl2-small) and [deepseek-vl2](https://huggingface.co/deepseek-ai/deepseek-vl2) in the PyTorch engine. ## Quick Start Install LMDeploy by following the [installation guide](../get_started/installation.md). ### Prepare When deploying the **DeepSeek-VL2** model using LMDeploy, you must install the official GitHub repository and related third-party libraries. This is because LMDeploy reuses the image processing functions provided in the official repository. ``` pip install git+https://github.com/deepseek-ai/DeepSeek-VL2.git --no-deps pip install attrdict timm 'transformers<4.48.0' ``` Note that it may fail with `transformers>=4.48.0`, as reported in this [issue](https://github.com/deepseek-ai/DeepSeek-VL2/issues/45). ### Offline inference pipeline The following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).
To construct valid DeepSeek-VL2 prompts with image inputs, users should insert `<IMAGE_TOKEN>` manually. ```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('deepseek-ai/deepseek-vl2-tiny') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/en/multi_modal/gemma3.md ================================================ # Gemma3 ## Introduction Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. Gemma 3 models are multimodal, handling text and image input and generating text output, with open weights for both pre-trained variants and instruction-tuned variants. Gemma 3 has a large, 128K context window, multilingual support in over 140 languages, and is available in more sizes than previous versions. Gemma 3 models are well-suited for a variety of text generation and image understanding tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as laptops, desktops or your own cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone. ## Quick Start Install LMDeploy by following the [installation guide](../get_started/installation.md). ### Prepare When deploying the **Gemma3** model using LMDeploy, please install the latest transformers. ### Offline inference pipeline The following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).
```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('google/gemma-3-12b-it') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/en/multi_modal/index.rst ================================================ Vision-Language Models ================================= .. toctree:: :maxdepth: 2 :caption: Examples deepseek_vl2.md llava.md internvl.md xcomposer2d5.md cogvlm.md minicpmv.md phi3.md qwen2_vl.md qwen2_5_vl.md molmo.md gemma3.md ================================================ FILE: docs/en/multi_modal/internvl.md ================================================ # InternVL LMDeploy supports the following InternVL series of models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :-------------------: | :-----------: | :------------------------: | | InternVL | 13B-19B | TurboMind | | InternVL1.5 | 2B-26B | TurboMind, PyTorch | | InternVL2 | 4B | PyTorch | | InternVL2 | 1B-2B, 8B-76B | TurboMind, PyTorch | | InternVL2.5/2.5-MPO/3 | 1B-78B | TurboMind, PyTorch | | Mono-InternVL | 2B | PyTorch | The next chapter demonstrates how to deploy an InternVL model using LMDeploy, with [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md), and install other packages that InternVL2 needs ```shell pip install timm # It is recommended to find the whl package that matches the environment from the releases on https://github.com/Dao-AILab/flash-attention. pip install flash-attn ``` Or, you can build a docker image to set up the inference environment. 
If the CUDA version on your host machine is `>=12.4`, you can run: ``` docker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile ``` Otherwise, you can go with: ```shell git clone https://github.com/InternLM/lmdeploy.git cd lmdeploy docker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile ``` ## Offline inference The following sample code shows the basic usage of VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` More examples are listed below:
multi-image multi-round conversation, combined images ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
multi-image multi-round conversation, separate images ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
video multi-round conversation ```python import numpy as np from lmdeploy import pipeline, GenerationConfig from decord import VideoReader, cpu from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import encode_image_base64 from PIL import Image pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array([ int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments) ]) return frame_indices def load_video(video_path, bound=None, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 fps = float(vr.get_avg_fps()) pixel_values_list, num_patches_list = [], [] frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) imgs = [] for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') imgs.append(img) return imgs video_path = 'red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' for i in range(len(imgs)): question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' question += 'What is the red panda doing?' content = [{'type': 'text', 'text': question}] for img in imgs: content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}}) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='Describe this video in detail. Don\'t repeat.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## Online serving You can launch the server by the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` You can also start the service using the aforementioned built docker image: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:internvl \ lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` The docker compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:internvl ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server OpenGVLab/InternVL2-8B deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` Then, you can execute the startup command as below: ```shell docker-compose up -d ``` If you find the following logs after running `docker logs -f lmdeploy`, it means the service launches successfully. ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`. 
More information about `api_server` as well as how to access the service can be found [here](api_server_vl.md) ================================================ FILE: docs/en/multi_modal/llava.md ================================================ # LLaVA LMDeploy supports the following llava series of models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :----------------------------------: | :--: | :------------------------: | | llava-hf/llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | | llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | | liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | | liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | The next chapter demonstrates how to deploy a LLaVA model using LMDeploy, with [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example. ```{note} The PyTorch engine removed support for the original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf ``` ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md). Alternatively, you can use the official docker image: ```shell docker pull openmmlab/lmdeploy:latest ``` ## Offline inference The following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline from lmdeploy.vl import load_image pipe = pipeline("llava-hf/llava-interleave-qwen-7b-hf", backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5), gen_config=GenerationConfig(max_new_tokens=512)) image = load_image('https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg') prompt = 'Describe the image.'
print(f'prompt:{prompt}') response = pipe((prompt, image)) print(response) ``` More examples are listed below:
multi-image multi-round conversation, combined images ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## Online serving You can launch the server by the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf ``` You can also start the service using the aforementioned built docker image: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf ``` The docker compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:latest ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` Then, you can execute the startup command as below: ```shell docker-compose up -d ``` If you find the following logs after running `docker logs -f lmdeploy`, it means the service launches successfully. ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`. 
More information about `api_server` as well as how to access the service can be found [here](api_server_vl.md) ================================================ FILE: docs/en/multi_modal/minicpmv.md ================================================ # MiniCPM-V LMDeploy supports the following MiniCPM-V series of models, which are detailed in the table below: | Model | Supported Inference Engine | | :------------------: | :------------------------: | | MiniCPM-Llama3-V-2_5 | TurboMind | | MiniCPM-V-2_6 | TurboMind | The next chapter demonstrates how to deploy a MiniCPM-V model using LMDeploy, with [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md). ## Offline inference The following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('openbmb/MiniCPM-V-2_6') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` More examples are listed below:
Chat with multiple images ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
In-context few-shot learning ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') question = "production date" messages = [ dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='example1.jpg')), ]), dict(role='assistant', content='2023.08.04'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='example2.jpg')), ]), dict(role='assistant', content='2007.04.24'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='test.jpg')), ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
Chat with video ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl import encode_image_base64 import torch from PIL import Image from transformers import AutoModel, AutoTokenizer from decord import VideoReader, cpu # pip install decord pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') MAX_NUM_FRAMES=64 # if cuda OOM set a smaller number def encode_video(video_path): def uniform_sample(l, n): gap = len(l) / n idxs = [int(i * gap + gap / 2) for i in range(n)] return [l[i] for i in idxs] vr = VideoReader(video_path, ctx=cpu(0)) sample_fps = round(vr.get_avg_fps() / 1) # FPS frame_idx = [i for i in range(0, len(vr), sample_fps)] if len(frame_idx) > MAX_NUM_FRAMES: frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) frames = vr.get_batch(frame_idx).asnumpy() frames = [Image.fromarray(v.astype('uint8')) for v in frames] print('num frames:', len(frames)) return frames video_path="video_test.mp4" frames = encode_video(video_path) question = "Describe the video" content=[dict(type='text', text=question)] for frame in frames: content.append(dict(type='image_url', image_url=dict(use_image_id=False, max_slice_nums=2, url=f'data:image/jpeg;base64,{encode_image_base64(frame)}'))) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
## Online serving You can launch the server by the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server openbmb/MiniCPM-V-2_6 ``` You can also start the service using the official lmdeploy docker image: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server openbmb/MiniCPM-V-2_6 ``` The docker compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:latest ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server openbmb/MiniCPM-V-2_6 deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` Then, you can execute the startup command as below: ```shell docker-compose up -d ``` If you find the following logs after running `docker logs -f lmdeploy`, it means the service launches successfully. ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`. 
More information about `api_server` as well as how to access the service can be found [here](api_server_vl.md) ================================================ FILE: docs/en/multi_modal/molmo.md ================================================ # Molmo LMDeploy supports the following molmo series of models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :-------------: | :--: | :------------------------: | | Molmo-7B-D-0924 | 7B | TurboMind | | Molmo-72B-0924 | 72B | TurboMind | The next chapter demonstrates how to deploy a molmo model using LMDeploy, with [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md). ## Offline inference The following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('allenai/Molmo-7B-D-0924') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` More examples are listed below:
multi-image multi-round conversation, combined images ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(do_sample=False)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(do_sample=False)) ```
## Online serving You can launch the server with the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server allenai/Molmo-7B-D-0924 ``` You can also start the service using the docker image: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server allenai/Molmo-7B-D-0924 ``` If you see the following logs, the service has launched successfully. ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`. More information about `api_server` as well as how to access the service can be found [here](api_server_vl.md) ================================================ FILE: docs/en/multi_modal/phi3.md ================================================ # Phi-3 Vision ## Introduction [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) is a family of small language and multi-modal models from Microsoft. LMDeploy supports the following multi-modal models.
| Model | Size | Supported Inference Engine | | :-------------------------------------------------------------------------------------------------: | :--: | :------------------------: | | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) | 4.2B | PyTorch | | [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) | 4.2B | PyTorch | The next chapter demonstrates how to deploy a Phi-3 model using LMDeploy, with [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md) and install the dependency [Flash-Attention](https://github.com/Dao-AILab/flash-attention) ```shell # It is recommended to find the whl package that matches the environment from the releases on https://github.com/Dao-AILab/flash-attention. pip install flash-attn ``` ## Offline inference The following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('microsoft/Phi-3.5-vision-instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## Online serving ### Launch Service You can launch the server with the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server microsoft/Phi-3.5-vision-instruct ``` ### Integrate with `OpenAI` Here is an example of interacting with the `v1/chat/completions` endpoint via the openai package.
Before running it, please install the openai package by `pip install openai` ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ================================================ FILE: docs/en/multi_modal/qwen2_5_vl.md ================================================ # Qwen2.5-VL LMDeploy supports the following Qwen-VL series of models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :--------: | :--------------: | :------------------------: | | Qwen2.5-VL | 3B, 7B, 32B, 72B | PyTorch | The next chapter demonstrates how to deploy a Qwen-VL model using LMDeploy, with [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md), and install other packages that Qwen2.5-VL needs ```shell # Qwen2.5-VL requires the latest transformers (transformers >= 4.49.0) pip install git+https://github.com/huggingface/transformers # It's highly recommended to use `[decord]` feature for faster video loading. pip install qwen-vl-utils[decord]==0.0.8 ``` ## Offline inference The following sample code shows the basic usage of the VLM pipeline. 
For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` More examples are listed below:
multi-image multi-round conversation, combined images ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
image resolution for performance boost ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') min_pixels = 64 * 28 * 28 max_pixels = 64 * 28 * 28 messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
video multi-round conversation ```python import numpy as np from lmdeploy import pipeline, GenerationConfig from decord import VideoReader, cpu from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import encode_image_base64 from PIL import Image pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array([ int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments) ]) return frame_indices def load_video(video_path, bound=None, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 fps = float(vr.get_avg_fps()) pixel_values_list, num_patches_list = [], [] frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) imgs = [] for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') imgs.append(img) return imgs video_path = 'red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' for i in range(len(imgs)): question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' question += 'What is the red panda doing?' content = [{'type': 'text', 'text': question}] for img in imgs: content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}}) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='Describe this video in detail. Don\'t repeat.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
================================================ FILE: docs/en/multi_modal/qwen2_vl.md ================================================ # Qwen2-VL LMDeploy supports the following Qwen-VL series of models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | | Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | The next chapter demonstrates how to deploy a Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example. ## Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that Qwen2-VL needs: ```shell pip install qwen_vl_utils ``` Or, you can build a docker image to set up the inference environment. If the CUDA version on your host machine is `>=12.4`, you can run: ```shell git clone https://github.com/InternLM/lmdeploy.git cd lmdeploy docker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile ``` Otherwise, you can go with: ```shell docker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile ``` ## Offline inference The following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md). ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` More examples are listed below:
multi-image multi-round conversation, combined images ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
image resolution for performance boost ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO') min_pixels = 64 * 28 * 28 max_pixels = 64 * 28 * 28 messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## Online serving You can launch the server by the `lmdeploy serve api_server` CLI: ```shell lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct ``` You can also start the service using the aforementioned built docker image: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:qwen2vl \ lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct ``` The docker compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:qwen2vl ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` Then, you can execute the startup command as below: ```shell docker-compose up -d ``` If you find the following logs after running `docker logs -f lmdeploy`, it means the service launches successfully. ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`. 
More information about `api_server`, as well as how to access the service, can be found [here](api_server_vl.md). ================================================ FILE: docs/en/multi_modal/vl_pipeline.md ================================================ # Offline Inference Pipeline LMDeploy abstracts the complex inference process of multi-modal Vision-Language Models (VLM) into an easy-to-use pipeline, similar to the Large Language Model (LLM) inference [pipeline](../llm/pipeline.md). The supported models are listed [here](../supported_models/supported_models.md). We genuinely invite the community to contribute new VLM support to LMDeploy. Your involvement is truly appreciated. This article showcases the VLM pipeline using the [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) model as a case study. You'll learn about the simplest ways to leverage the pipeline and how to gradually unlock more advanced features by adjusting engine parameters and generation arguments, such as tensor parallelism, context window sizing, random sampling, and chat template customization. Moreover, we will provide practical inference examples tailored to scenarios with multiple images, batch prompts, etc. Using the pipeline interface to infer other VLM models is similar, with the main difference being the configuration and installation dependencies of the models. Refer to [this page](https://lmdeploy.readthedocs.io/en/latest/multi_modal/index.html) for the environment installation and configuration methods of different models. ## A 'Hello, world' example ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` If an `ImportError` occurs while running this example, please install the required dependency packages as prompted.
In the above example, the inference prompt is a tuple structure consisting of (prompt, image). Besides this structure, the pipeline also supports prompts in the OpenAI format: ```python from lmdeploy import pipeline pipe = pipeline('OpenGVLab/InternVL2_5-8B') prompts = [ { 'role': 'user', 'content': [ {'type': 'text', 'text': 'describe this image'}, {'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}} ] } ] response = pipe(prompts) print(response) ``` ### Set tensor parallelism Tensor parallelism can be activated by setting the engine parameter `tp`: ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(tp=2)) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### Set context window size When creating the pipeline, you can customize the size of the context window by setting the engine parameter `session_len`.
```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### Set sampling parameters You can change the default sampling parameters of the pipeline by passing `GenerationConfig`: ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(tp=2, session_len=8192)) gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image), gen_config=gen_config) print(response) ``` ### Customize image token position By default, LMDeploy inserts the special image token into the user prompt following the chat template defined by the upstream algorithm repository. However, for certain models where the image token's position is unrestricted, such as deepseek-vl, or when users require a customized image token placement, manual insertion of the special image token into the prompt is necessary. LMDeploy uses `<IMAGE_TOKEN>` (exposed as `lmdeploy.vl.constants.IMAGE_TOKEN`) as the special image token. ```python from lmdeploy import pipeline from lmdeploy.vl import load_image from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('deepseek-ai/deepseek-vl-1.3b-chat') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image{IMAGE_TOKEN}', image)) print(response) ``` ### Set chat template While performing inference, LMDeploy identifies an appropriate chat template from its built-in collection based on the model path and subsequently applies this template to the input prompts.
However, when a chat template cannot be inferred from the model path, users have to specify it. For example, [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) employs the ['llava-v1'](https://github.com/haotian-liu/LLaVA/blob/v1.2.2/llava/conversation.py#L325-L335) chat template. If the model is stored in a custom folder instead of the official 'llava-v1.5-7b', the user needs to specify the template by passing 'llava-v1' to `ChatTemplateConfig` as follows: ```python from lmdeploy import pipeline, ChatTemplateConfig from lmdeploy.vl import load_image pipe = pipeline('local_model_folder', chat_template_config=ChatTemplateConfig(model_name='llava-v1')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` For more information about customizing a chat template, please refer to [this](../advance/chat_template.md) guide. ### Setting vision model parameters The default parameters of the vision model can be modified by setting `VisionConfig`. ```python from lmdeploy import pipeline, VisionConfig from lmdeploy.vl import load_image vision_config=VisionConfig(max_batch_size=16) pipe = pipeline('liuhaotian/llava-v1.5-7b', vision_config=vision_config) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### Output logits for generated tokens ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image), gen_config=GenerationConfig(output_logits='generation')) logits = response.logits print(logits) ``` ## Multi-images inference When dealing with multiple images, you can put them all in one list.
Keep in mind that multiple images will lead to a higher number of input tokens, and as a result, the size of the [context window](#set-context-window-size) typically needs to be increased. ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image_urls=[ 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg', 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg' ] images = [load_image(img_url) for img_url in image_urls] response = pipe(('describe these images', images)) print(response) ``` ## Batch prompts inference Conducting inference with batch prompts is quite straightforward; just place them within a list structure: ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image_urls=[ "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg", "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg" ] prompts = [('describe this image', load_image(img_url)) for img_url in image_urls] response = pipe(prompts) print(response) ``` ## Multi-turn conversation There are two ways to conduct multi-turn conversations with the pipeline. One is to construct messages in the OpenAI format and use the method introduced above; the other is to use the `pipeline.chat` interface.
```python from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg') gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8) sess = pipe.chat(('describe this image', image), gen_config=gen_config) print(sess.response.text) sess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config) print(sess.response.text) ``` ## Release pipeline You can release the pipeline explicitly by calling its `close()` method, or alternatively, use the `with` statement as demonstrated below: ```python from lmdeploy import pipeline from lmdeploy import pipeline from lmdeploy.vl import load_image with pipeline('OpenGVLab/InternVL2_5-8B') as pipe: image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) # Clear the torch cache and perform garbage collection if needed import torch import gc torch.cuda.empty_cache() gc.collect() ``` ================================================ FILE: docs/en/multi_modal/xcomposer2d5.md ================================================ # InternLM-XComposer-2.5 ## Introduction [InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) excels in various text-image comprehension and composition applications, achieving GPT-4V level capabilities with a mere 7B LLM backend. IXC-2.5 is trained with 24K interleaved image-text contexts and can seamlessly extend to 96K long contexts via RoPE extrapolation. This long-context capability allows IXC-2.5 to perform exceptionally well in tasks requiring extensive input and output contexts.
LMDeploy supports the model [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) in the TurboMind engine. ## Quick Start ### Installation Please install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that InternLM-XComposer-2.5 needs: ```shell pip install decord ``` ### Offline inference pipeline The following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md). ```python from lmdeploy import pipeline from lmdeploy.vl import load_image from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('internlm/internlm-xcomposer2d5-7b') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` ## LoRA Model InternLM-XComposer-2.5 provides LoRA weights trained for webpage creation and article writing. As the TurboMind backend doesn't support S-LoRA, only one LoRA model can be deployed at a time, and the LoRA weights need to be merged into the base model before deployment. LMDeploy provides the corresponding conversion script, which is used as follows: ```shell export HF_MODEL=internlm/internlm-xcomposer2d5-7b export WORK_DIR=internlm/internlm-xcomposer2d5-7b-web export TASK=web python -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK ``` ## Quantization The following takes the base model as an example to show the quantization method. If you want to use a LoRA model, please merge it first as described in the previous section. ```shell export HF_MODEL=internlm/internlm-xcomposer2d5-7b export WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit lmdeploy lite auto_awq \ $HF_MODEL \ --work-dir $WORK_DIR ``` ## More examples
Video Understanding The following example demonstrates usage with the `pipeline.chat` API. Other interface APIs also support inference but require manual splicing of the conversation content. ```python from lmdeploy import pipeline, GenerationConfig from transformers.dynamic_module_utils import get_class_from_dynamic_module HF_MODEL = 'internlm/internlm-xcomposer2d5-7b' load_video = get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL) frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL) Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL) get_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL) video = load_video('liuxiang.mp4') # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4 img = frame2img(video, get_font()) img = Video_transform(img) pipe = pipeline(HF_MODEL) gen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0) query = 'Here are some frames of a video. Describe this video in detail' sess = pipe.chat((query, img), gen_config=gen_config) print(sess.response.text) query = 'tell me the athlete code of Liu Xiang' sess = pipe.chat(query, session=sess, gen_config=gen_config) print(sess.response.text) ```
Multi-Image ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import load_image query = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one' urls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg', 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg', 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg'] images = [load_image(url) for url in urls] pipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO') output = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939)) ``` Since LMDeploy does not support beam search, the generated results will be quite different from those using beam search with transformers. It is recommended to turn off top_k or use a larger top_k sampling to increase diversity.
Instruction to Webpage Please first convert the web model using the instructions above. ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO') pipe.chat_template.meta_instruction = None query = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.' output = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048)) ``` When testing with transformers, we found that if `repetition_penalty` is set and `num_beams` is 1, there is a high probability that decoding never stops. As LMDeploy does not support beam search, it is recommended to turn off `repetition_penalty` when using LMDeploy for inference.
Write Article Please first convert the write model using the instructions above. ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO') pipe.chat_template.meta_instruction = None query = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence' output = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192)) ```
================================================ FILE: docs/en/quantization/kv_quant.md ================================================ # INT4/INT8 KV Cache Since v0.4.0, LMDeploy has supported **online** key-value (kv) cache quantization with int4 and int8 numerical precision, utilizing an asymmetric quantization method that is applied on a per-head, per-token basis. The original kv offline quantization method has been removed. Intuitively, quantization is beneficial for increasing the number of kv blocks. Compared to fp16, the number of kv blocks can be increased by 4 times with int4 kv and by 2 times with int8 kv. This means that under the same memory conditions, the system can support a significantly increased number of concurrent operations after kv quantization, thereby ultimately enhancing throughput. However, quantization typically causes some loss of model accuracy. We have used OpenCompass to evaluate the accuracy of several models after applying int4/int8 quantization. int8 kv preserves the accuracy, while int4 kv shows a slight loss. The detailed results are presented in the [Evaluation](#evaluation) section. You can refer to the information and choose wisely based on your requirements. LMDeploy inference with quantized kv supports the following NVIDIA GPU models: - Volta architecture (sm70): V100 - Turing architecture (sm75): 20 series, T4 - Ampere architecture (sm80, sm86): 30 series, A10, A16, A30, A100 - Ada Lovelace architecture (sm89): 40 series - Hopper architecture (sm90): H100, H200 In summary, LMDeploy kv quantization has the following advantages: 1. Data-free online quantization 2. Supports all NVIDIA GPU models with Volta architecture (sm70) and above 3. KV int8 quantization has almost lossless accuracy, and KV int4 quantization accuracy is within an acceptable range 4.
Efficient inference: with int8/int4 kv quantization applied to llama2-7b, RPS is improved by around 30% and 40%, respectively, compared to fp16. In the next section, we will take the `internlm2_5-7b-chat` model as an example to introduce kv quantization and inference with LMDeploy. But before that, please ensure that lmdeploy is installed. ```shell pip install lmdeploy ``` ## Usage Applying kv quantization and inference via LMDeploy is quite straightforward. Simply set the `quant_policy` parameter. **LMDeploy specifies that `quant_policy=4` stands for 4-bit kv, whereas `quant_policy=8` indicates 8-bit kv.** ### Offline inference ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig(quant_policy=8) pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` ### Serving ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --quant-policy 8 ``` ## Evaluation We apply LMDeploy's kv quantization to several LLMs and use OpenCompass to evaluate the inference accuracy.
The results are shown in the table below: | - | - | - | llama2-7b-chat | - | - | internlm2-chat-7b | - | - | internlm2.5-chat-7b | - | - | qwen1.5-7b-chat | - | - | | ----------- | ------- | ------------- | -------------- | ------- | ------- | ----------------- | ------- | ------- | ------------------- | ------- | ------- | --------------- | ------- | ------- | | dataset | version | metric | kv fp16 | kv int8 | kv int4 | kv fp16 | kv int8 | kv int4 | kv fp16 | kv int8 | kv int4 | kv fp16 | kv int8 | kv int4 | | ceval | - | naive_average | 28.42 | 27.96 | 27.58 | 60.45 | 60.88 | 60.28 | 78.06 | 77.87 | 77.05 | 70.56 | 70.49 | 68.62 | | mmlu | - | naive_average | 35.64 | 35.58 | 34.79 | 63.91 | 64 | 62.36 | 72.30 | 72.27 | 71.17 | 61.48 | 61.56 | 60.65 | | triviaqa | 2121ce | score | 56.09 | 56.13 | 53.71 | 58.73 | 58.7 | 58.18 | 65.09 | 64.87 | 63.28 | 44.62 | 44.77 | 44.04 | | gsm8k | 1d7fe4 | accuracy | 28.2 | 28.05 | 27.37 | 70.13 | 69.75 | 66.87 | 85.67 | 85.44 | 83.78 | 54.97 | 56.41 | 54.74 | | race-middle | 9a54b6 | accuracy | 41.57 | 41.78 | 41.23 | 88.93 | 88.93 | 88.93 | 92.76 | 92.83 | 92.55 | 87.33 | 87.26 | 86.28 | | race-high | 9a54b6 | accuracy | 39.65 | 39.77 | 40.77 | 85.33 | 85.31 | 84.62 | 90.51 | 90.42 | 90.42 | 82.53 | 82.59 | 82.02 | For detailed evaluation methods, please refer to [this](../benchmark/evaluate_with_opencompass.md) guide. Remember to pass `quant_policy` to the inference engine in the config file. ## Performance | model | kv type | test settings | RPS | v.s.
kv fp16 | | ----------------- | ------- | ---------------------------------------- | ----- | ------------ | | llama2-chat-7b | fp16 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 14.98 | 1.0 | | - | int8 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 19.01 | 1.27 | | - | int4 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 20.81 | 1.39 | | llama2-chat-13b | fp16 | tp1 / ratio 0.9 / bs 128 / prompts 10000 | 8.55 | 1.0 | | - | int8 | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 10.96 | 1.28 | | - | int4 | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 11.91 | 1.39 | | internlm2-chat-7b | fp16 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 24.13 | 1.0 | | - | int8 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.28 | 1.05 | | - | int4 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.80 | 1.07 | The performance data is obtained by `benchmark/profile_throughput.py` ================================================ FILE: docs/en/quantization/llm_compressor.md ================================================ # llm-compressor Support This guide aims to introduce how to use LMDeploy's TurboMind inference engine to run models quantized by the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) tool. Currently supported `llm-compressor` quantization types include: - int4 quantization (e.g., AWQ, GPTQ) These quantized models can run via the TurboMind engine on the following NVIDIA GPU architectures: | Compute Capability | Micro-architecture | GPUs | | ------------------ | ------------------ | ------------------------------- | | 7.0 | Volta | V100 | | 7.2 | Volta | Jetson Xavier | | 7.5 | Turing | GeForce RTX 20 series, T4 | | 8.0 | Ampere | A100, A800, A30 | | 8.6 | Ampere | GeForce RTX 30 series, A40, A10 | | 8.7 | Ampere | Jetson Orin | | 8.9 | Ada Lovelace | GeForce RTX 40 series, L40, L20 | | 9.0 | Hopper | H20, H200, H100, GH200 | | 12.0 | Blackwell | GeForce RTX 50 series | LMDeploy will continue to follow up and expand support for the `llm-compressor` project. 
The remainder of this document consists of the following sections: - [Model Quantization](#model-quantization) - [Model Deployment](#model-deployment) - [Accuracy Evaluation](#accuracy-evaluation) ## Model Quantization `llm-compressor` provides a wealth of model quantization [examples](https://github.com/vllm-project/llm-compressor/tree/main/examples). Please refer to its tutorials and select a quantization algorithm supported by LMDeploy to quantize your model. LMDeploy also provides a built-in [script](https://github.com/InternLM/lmdeploy/blob/main/examples/lite/qwen3_30b_a3b_awq.py) for AWQ quantization of **Qwen3-30B-A3B** using `llm-compressor` for your reference: ```shell # Create conda environment conda create -n lmdeploy python=3.10 -y conda activate lmdeploy # Install llm-compressor pip install llmcompressor # Clone lmdeploy source code and run the quantization example git clone https://github.com/InternLM/lmdeploy cd lmdeploy python examples/lite/qwen3_30b_a3b_awq.py --work-dir ./qwen3_30b_a3b_4bit ``` In the following sections, we will use this quantized model (saved in `./qwen3_30b_a3b_4bit`) as an example to introduce model deployment and accuracy evaluation methods. ## Model Deployment ### Offline Inference With the quantized model, offline batch processing can be implemented with just a few lines of code: ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig() with pipeline("./qwen3_30b_a3b_4bit", backend_config=engine_config) as pipe: response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` For a detailed introduction to the pipeline, please refer to [this guide](https://lmdeploy.readthedocs.io/en/latest/llm/pipeline.html). ### Online Serving LMDeploy's `api_server` can wrap the model as a service with a single command. The provided RESTful APIs are compatible with OpenAI interfaces.
Below is an example of starting the service: ```shell lmdeploy serve api_server ./qwen3_30b_a3b_4bit --backend turbomind ``` The default service port is 23333. After the server starts, you can access the service via the OpenAI SDK. For command arguments and methods to access the service, please read [this](https://lmdeploy.readthedocs.io/en/latest/llm/api_server.html) document. ## Accuracy Evaluation After deploying AWQ symmetric/asymmetric quantized models of Qwen3-8B (Dense) and Qwen3-30B-A3B (MoE) as services via LMDeploy, we evaluated their accuracy on several academic datasets using [opencompass](https://github.com/open-compass/opencompass). Results indicate that, for Qwen3-8B, asymmetric quantization generally outperforms symmetric quantization, while Qwen3-30B-A3B shows no substantial difference between symmetric and asymmetric quantization. Compared with BF16, Qwen3-8B shows a smaller accuracy gap under both symmetric and asymmetric quantization than Qwen3-30B-A3B. Compared with BF16, accuracy drops significantly on long-output datasets such as aime2025 (avg 17,635 tokens) and LCB (avg 14,157 tokens), while on medium/short-output datasets like ifeval (avg 1,885 tokens) and mmlu_pro (avg 2,826 tokens), the accuracy remains in line with expectations. | dataset | Qwen3-8B | | | Qwen3-30B-A3B | | | | ----------------- | -------- | ------- | -------- | ------------- | ------- | -------- | | | bf16 | awq sym | awq asym | bf16 | awq sym | awq asym | | ifeval | 85.58 | 83.73 | 85.77 | 86.32 | 84.10 | 84.29 | | hle | 5.05 | 5.05 | 5.24 | 7.00 | 5.47 | 5.65 | | gpqa | 59.97 | 56.57 | 59.47 | 61.74 | 57.95 | 57.07 | | aime2025 | 69.48 | 64.38 | 63.96 | 73.44 | 64.79 | 66.67 | | mmlu_pro | 73.69 | 71.73 | 72.34 | 77.85 | 75.77 | 75.69 | | LCBCodeGeneration | 50.86 | 44.10 | 46.95 | 56.67 | 50.86 | 49.24 | For reproduction methods, please refer to [this](https://lmdeploy.readthedocs.io/en/latest/benchmark/evaluate_with_opencompass.html) document.
================================================ FILE: docs/en/quantization/w4a16.md ================================================ # AWQ/GPTQ LMDeploy TurboMind engine supports the inference of 4bit quantized models that are quantized both by [AWQ](https://arxiv.org/abs/2306.00978) and [GPTQ](https://github.com/AutoGPTQ/AutoGPTQ), but its quantization module only supports the AWQ quantization algorithm. The following NVIDIA GPUs are available for AWQ/GPTQ INT4 inference: - V100(sm70): V100 - Turing(sm75): 20 series, T4 - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 - Ada Lovelace(sm89): 40 series Before proceeding with the quantization and inference, please ensure that lmdeploy is installed by following the [installation guide](../get_started/installation.md) The remainder of this article is structured into the following sections: - [Quantization](#quantization) - [Evaluation](#evaluation) - [Inference](#inference) - [Service](#service) - [Performance](#performance) ## Quantization A single command execution is all it takes to quantize the model. The resulting quantized weights are then stored in the $WORK_DIR directory. ```shell export HF_MODEL=internlm/internlm2_5-7b-chat export WORK_DIR=internlm/internlm2_5-7b-chat-4bit lmdeploy lite auto_awq \ $HF_MODEL \ --calib-dataset 'wikitext2' \ --calib-samples 128 \ --calib-seqlen 2048 \ --w-bits 4 \ --w-group-size 128 \ --batch-size 1 \ --work-dir $WORK_DIR ``` Typically, the above command doesn't require filling in optional parameters, as the defaults usually suffice. For instance, when quantizing the [internlm/internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) model, the command can be condensed as: ```shell lmdeploy lite auto_awq internlm/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-4bit ``` **Note:** - We recommend that you specify the --work-dir parameter, including the model name as demonstrated in the example above. 
This helps LMDeploy fuzzy-match the --work-dir with an appropriate built-in chat template. Otherwise, you will have to designate the chat template during inference. - If the quantized model’s accuracy is compromised, it is recommended to enable --search-scale for re-quantization and increase the --batch-size, for example, to 8. When --search-scale is enabled, the quantization process will take more time. The --batch-size affects memory usage and can be adjusted to fit your hardware. Upon completing quantization, you can engage with the model efficiently using a variety of handy tools. For example, you can initiate a conversation with it via the command line: ```shell lmdeploy chat ./internlm2_5-7b-chat-4bit --model-format awq ``` ## Evaluation Please refer to [OpenCompass](https://opencompass.readthedocs.io/en/latest/index.html) for model evaluation with LMDeploy. Here is the [guide](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html). ## Inference Try the following code to perform batched offline inference with the quantized model: ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig(model_format='awq') pipe = pipeline("./internlm2_5-7b-chat-4bit", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` For more information about the pipeline parameters, please refer to [this guide](../llm/pipeline.md).
In addition to performing inference with the quantized model on localhost, LMDeploy can also run 4-bit AWQ-quantized models available on the Hugging Face Hub, such as models from the [lmdeploy space](https://huggingface.co/lmdeploy) and [TheBloke space](https://huggingface.co/TheBloke): ```python # inference with models from lmdeploy space from lmdeploy import pipeline, TurbomindEngineConfig pipe = pipeline("lmdeploy/llama2-chat-70b-4bit", backend_config=TurbomindEngineConfig(model_format='awq', tp=4)) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) # inference with models from thebloke space from lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig pipe = pipeline("TheBloke/LLaMA2-13B-Tiefighter-AWQ", backend_config=TurbomindEngineConfig(model_format='awq'), chat_template_config=ChatTemplateConfig(model_name='llama2') ) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` ## Service LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service: ```shell lmdeploy serve api_server ./internlm2_5-7b-chat-4bit --backend turbomind --model-format awq ``` The default port of `api_server` is `23333`. After the server is launched, you can communicate with the server from a terminal through `api_client`: ```shell lmdeploy serve api_client http://0.0.0.0:23333 ``` You can browse and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md). ## Performance We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quantization on an NVIDIA GeForce RTX 4090. We measured the token generation throughput (tokens/s) with a single prompt token, generating 512 tokens. All results are measured for single-batch inference.
| model | llm-awq | mlc-llm | turbomind | | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | ## FAQs 1. Out of Memory error during quantization due to insufficient GPU memory: This can be addressed by reducing the parameter `--calib-seqlen`, increasing the parameter `--calib-samples`, and setting `--batch-size` to 1. ================================================ FILE: docs/en/quantization/w8a8.md ================================================ # SmoothQuant LMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). For GPUs such as Nvidia H100, lmdeploy also supports 8-bit floating point (FP8). The following NVIDIA GPUs support INT8 and FP8 inference, respectively: - INT8 - V100(sm70): V100 - Turing(sm75): 20 series, T4 - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 - Ada Lovelace(sm89): 40 series - Hopper(sm90): H100 - FP8 - Ada Lovelace(sm89): 40 series - Hopper(sm90): H100 First of all, run the following command to install lmdeploy: ```shell pip install lmdeploy[all] ``` ## 8-bit Weight Quantization Performing 8-bit weight quantization involves three steps: 1. **Smooth Weights**: Start by smoothing the weights of the Large Language Model (LLM). This process makes the weights more amenable to quantization. 2. **Replace Modules**: Locate DecoderLayers and replace the modules RMSNorm and nn.Linear with QRMSNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file. 3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model. lmdeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three tasks detailed above. Note that the argument `--quant-dtype` determines whether you are doing int8 or fp8 weight quantization.
For more information about CLI usage, run `lmdeploy lite smooth_quant --help`. Here are two examples: - int8 ```shell lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8 ``` - fp8 ```shell lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8 ``` ## Inference Try the following code to perform batched offline inference with the quantized model: ```python from lmdeploy import pipeline, PytorchEngineConfig engine_config = PytorchEngineConfig(tp=1) pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` ## Service LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service: ```shell lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch ``` The default port of `api_server` is `23333`. After the server is launched, you can communicate with the server from a terminal through `api_client`: ```shell lmdeploy serve api_client http://0.0.0.0:23333 ``` You can browse and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).
================================================ FILE: docs/en/supported_models/reward_models.md ================================================ # Reward Models LMDeploy supports reward models, which are detailed in the table below: | Model | Size | Supported Inference Engine | | :--------------: | :-----------: | :------------------------: | | Qwen2.5-Math-RM | 72B | PyTorch | | InternLM2-Reward | 1.8B, 7B, 20B | PyTorch | | POLAR | 1.8B, 7B | PyTorch | ## Offline Inference We take `internlm/internlm2-1_8b-reward` as an example: ```python from transformers import AutoTokenizer from lmdeploy import pipeline, PytorchEngineConfig model_path = "internlm/internlm2-1_8b-reward" chat = [ {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."}, {"role": "user", "content": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"}, {"role": "assistant", "content": "To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to follow these steps:\n\n1. Calculate the total number of eggs laid by the ducks each day.\n2. Determine how many eggs Janet eats and bakes for herself each day.\n3. Find out how many eggs are left to be sold.\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\n\nLet's start with the first step:\n\n1. Janet's ducks lay 16 eggs per day.\n\nNext, we calculate how many eggs Janet eats and bakes for herself each day:\n\n2. Janet eats 3 eggs for breakfast every morning.\n3. 
Janet bakes 4 eggs for her friends every day.\n\nSo, the total number of eggs Janet eats and bakes for herself each day is:\n\\[ 3 + 4 = 7 \\text{ eggs} \\]\n\nNow, we find out how many eggs are left to be sold:\n\\[ 16 - 7 = 9 \\text{ eggs} \\]\n\nFinally, we calculate the revenue from selling the remaining eggs at $2 per egg:\n\\[ 9 \\times 2 = 18 \\text{ dollars} \\]\n\nTherefore, Janet makes 18 dollars every day at the farmers' market."} ] tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) conversation_str = tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=False ) input_ids = tokenizer.encode( conversation_str, add_special_tokens=False ) if __name__ == '__main__': engine_config = PytorchEngineConfig(tp=1) with pipeline(model_path, backend_config=engine_config) as pipe: score = pipe.get_reward_score(input_ids) print(f'score: {score}') ``` ## Online Inference Start the API server: ```bash lmdeploy serve api_server internlm/internlm2-1_8b-reward --backend pytorch ``` Get the reward score from the `/pooling` API endpoint: ``` curl http://0.0.0.0:23333/pooling \ -H "Content-Type: application/json" \ -d '{ "model": "internlm/internlm2-1_8b-reward", "input": "Who are you?" }' ``` ================================================ FILE: docs/en/supported_models/supported_models.md ================================================ # Supported Models The following tables detail the models supported by LMDeploy's TurboMind engine and PyTorch engine across different platforms.
## TurboMind on CUDA Platform | Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | | :------------------------------: | :--------------: | :--: | :-------: | :-----: | :-----: | :---: | | Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | | InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Intern-S1 | 241B | MLLM | Yes | Yes | Yes | No | | Intern-S1-mini | 8.3B | MLLM | Yes | Yes | Yes | No | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | | Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | | Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | | Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | | Qwen3 | 0.6B-235B | LLM | Yes | Yes | Yes\* | Yes\* | | Qwen3.5\[3\] | 0.8B-397B | MLLM | Yes | Yes | No | Yes | | Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | | DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | | DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | | Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | | Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | | InternVL2\[2\] | 1 - 2B, 8B - 76B | MLLM 
| Yes | Yes\* | Yes\* | Yes | | InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | | InternVL3\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | | InternVL3.5\[3\] | 1 - 241BA28B | MLLM | Yes | Yes\* | Yes\* | No | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | | GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | | Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | | gpt-oss | 20B,120B | LLM | Yes | Yes | Yes | Yes | "-" means not verified yet. ```{note} * [1] The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. * [2] When the head_dim of a model is not 128, such as llama3.2-1B, qwen2-0.5B and internvl2-1B, turbomind doesn't support its kv cache 4/8 bit quantization and inference * [3] TurboMind does not currently support the vision encoder for the Qwen3.5 series. 
``` ## PyTorchEngine on CUDA Platform | Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | | :----------------------------: | :-------------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | | Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama4 | Scout, Maverick | MLLM | Yes | Yes | Yes | - | - | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes | | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | | Intern-S1 | 241B | MLLM | Yes | Yes | Yes | Yes | - | | Intern-S1-mini | 8.3B | MLLM | Yes | Yes | Yes | Yes | - | | Intern-S1-Pro | 1TB | MLLM | Yes | - | - | - | No | | Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | | Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | | ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | | QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* | | QWen3-Next | 80B | LLM | Yes | No | No | No | No | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | | QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No | | QWen3.5 | 0.8B-397B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | 
DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | | MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | | Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | | Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | | Phi-4-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No | | InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - | | InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - | | InternVL3 | 1B-78B | MLLM | Yes | Yes | Yes | - | - | | InternVL3.5 | 1B-241BA28B | MLLM | Yes | Yes | Yes | No | No | | Mono-InternVL\[1\] | 2B | MLLM | Yes | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | | Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | | Gemma3 | 1B-27B | MLLM | Yes | Yes | Yes | - | - | | GLM-4 | 9B | LLM | Yes | Yes | Yes | No | No | | GLM-4-0414 | 9B | LLM | Yes | Yes | Yes | - | - | | GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes | | GLM-4.1V-Thinking | 9B | MLLM | Yes | Yes | Yes | - | - | | GLM-4.5 | 355B | LLM | Yes | Yes | Yes | - | - | | GLM-4.5-Air | 106B | LLM | Yes | Yes | Yes | - | - | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | | SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | | GLM-4.7-Flash | 30B | LLM | 
Yes | No | No | No | No | | GLM-5 | 754B | LLM | Yes | No | No | No | No | ```{note} * [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. * [2] PyTorch engine removes the support of original llava models after v0.6.4. Please use their corresponding transformers models instead, which can be found in https://huggingface.co/llava-hf Starting from version 0.11.1, PytorchEngine no longer provides support for mllama. ``` ## PyTorchEngine on Other Platforms | | | | Atlas 800T A2 | Atlas 800T A2 | Atlas 800T A2 | Atlas 800T A2 | Atlas 300I Duo | Atlas 800T A3 | Maca C500 | Cambricon | | :------------: | :-------: | :--: | :--------------: | :--------------: | :-----------: | :-----------: | :------------: | :--------------: | :-------: | :-------: | | Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) | FP16(graph) | FP16/BF16(eager) | BF/FP16 | BF/FP16 | | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | - | Yes | Yes | Yes | | Llama3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | Llama3.1 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | Mixtral | 8x7B | LLM | Yes | Yes | No | No | Yes | - | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | - | No | No | - | - | Yes | - | | QWen2(.5) | 7B | LLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | - | - | - | - | Yes | No | | QWen2.5-VL | 3B - 72B | MLLM | Yes | Yes | - | - | Yes | - | Yes | No | | QWen2-MoE | A14.57B | LLM | Yes | - | No | No | - | - | Yes | - | | QWen3 | 0.6B-235B | LLM | Yes | Yes | No | No | Yes | Yes | Yes | Yes | | DeepSeek-V2 | 16B | LLM | No | Yes | No | No | - | - | - | - | | InternVL(v1.5) | 2B-26B | MLLM | Yes | - 
| Yes | Yes | - | - | Yes | - | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | InternVL2.5 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | InternVL3 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | CogVLM2-chat | 19B | MLLM | Yes | No | - | - | - | - | Yes | - | | GLM4V | 9B | MLLM | Yes | No | - | - | - | - | - | - | ================================================ FILE: docs/zh_cn/.readthedocs.yaml ================================================ version: 2 formats: all build: os: "ubuntu-22.04" tools: python: "3.10" sphinx: configuration: docs/zh_cn/conf.py python: install: - requirements: requirements/docs.txt - requirements: requirements/readthedocs.txt ================================================ FILE: docs/zh_cn/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/zh_cn/_static/css/readthedocs.css ================================================ table.autosummary td { width: 50% } img.align-center { display: block; margin-left: auto; margin-right: auto; } ================================================ FILE: docs/zh_cn/advance/chat_template.md ================================================ # 自定义对话模板 被应用的对话模板效果,可以通过设置日志等级为`INFO`进行观测。 LMDeploy 支持两种添加对话模板的形式: - 一种是利用现有对话模板,直接配置一个如下的 json 文件使用。 ```json { "model_name": "your awesome chat template name", "system": "<|im_start|>system\n", "meta_instruction": "You are a robot developed by LMDeploy.", "eosys": "<|im_end|>\n", "user": "<|im_start|>user\n", "eoh": "<|im_end|>\n", "assistant": "<|im_start|>assistant\n", "eoa": "<|im_end|>", "separator": "\n", "capability": "chat", "stop_words": ["<|im_end|>"] } ``` 这样一个模板将会以下面的形式进行拼接。 ``` {system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}{assistant_content}{eoa}{separator}{user}... 
``` 在使用 CLI 工具时,可以通过 `--chat-template` 传入自定义对话模板,比如: ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE} ``` 也可以在通过接口函数传入,比如: ```python from lmdeploy import ChatTemplateConfig, serve serve('internlm/internlm2_5-7b-chat', chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}')) ``` - 一种是以 LMDeploy 现有对话模板,自定义一个python对话模板类,注册成功后直接用即可。优点是自定义程度高,可控性强。 下面是一个注册 LMDeploy 对话模板的例子: ```python from lmdeploy.model import MODELS, BaseChatTemplate @MODELS.register_module(name='customized_model') class CustomizedModel(BaseChatTemplate): """A customized chat template.""" def __init__(self, system='<|im_start|>system\n', meta_instruction='You are a robot developed by LMDeploy.', user='<|im_start|>user\n', assistant='<|im_start|>assistant\n', eosys='<|im_end|>\n', eoh='<|im_end|>\n', eoa='<|im_end|>', separator='\n', stop_words=['<|im_end|>', '<|action_end|>']): super().__init__(system=system, meta_instruction=meta_instruction, eosys=eosys, user=user, eoh=eoh, assistant=assistant, eoa=eoa, separator=separator, stop_words=stop_words) from lmdeploy import ChatTemplateConfig, pipeline messages = [{'role': 'user', 'content': 'who are you?'}] pipe = pipeline('internlm/internlm2_5-7b-chat', chat_template_config=ChatTemplateConfig('customized_model')) for response in pipe.stream_infer(messages): print(response.text, end='') ``` 在这个例子中,我们注册了一个 LMDeploy 的对话模板,该模板将模型设置为由 LMDeploy 创造,所以当用户提问模型是谁的时候,模型就会回答由 LMDeploy 所创。 ================================================ FILE: docs/zh_cn/advance/context_parallel.md ================================================ # 序列并行 在单卡显存不足以部署模型的时候,通常会以 `TP` 的方式进行部署,而这一般要求 `num_key_value_heads` 被 `TP` 整除。如果要以 `TP > num_key_value_heads` 的方式进行部署,需要创建 kv-heads 的副本,以满足整除需求。但是这样会有两个缺点: 1. 可用的 kvcache 数量减半,进而减少请求最大推理长度 2. 
降低推理的最大 batch 数量,减少吞吐量。 为了解决这个问题,TurboMind 推理后端支持设置 `attn_dp_size`,避免了创建 kv-heads 的副本,但是这会引入数据的不均衡性。为了消除数据的不均衡,TurboMind 支持了序列并行,支持将 kv_cache 交错存储到不同的 cp_rank 上,例如 ``` cp_size=2, prompt_len=5, generation_len=4 kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 kv_cache stored on cp_rank1: 1, 3, 5, 7 ``` ## 使用说明 以 `Intern-S1` / `Qwen3-235B-A22B` 为例,它们的 `num_key_value_heads` 为 4,若要用 `TP=8` 的方式部署,并避免 kv_cache 的拷贝,可以用如下的方式部署 ``` lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 ``` ================================================ FILE: docs/zh_cn/advance/debug_turbomind.md ================================================ # 如何调试 Turbomind Turbomind 使用 C++ 实现,不像 Python 一样易于调试。该文档提供了调试 Turbomind 的基本方法。 ## 前置工作 首先,根据构建[命令](../get_started/installation.md)完成源码编译和安装。 ## 配置 Python 调试环境 由于目前许多大公司在线上生产环境中使用 Centos 7,我们将以 Centos 7 为例来说明配置过程。 ### 获取 `glibc` 和 `python3` 的版本 ```bash rpm -qa | grep glibc rpm -qa | grep python3 ``` 结果类似于这样: ``` [username@hostname workdir]# rpm -qa | grep glibc glibc-2.17-325.el7_9.x86_64 glibc-common-2.17-325.el7_9.x86_64 glibc-headers-2.17-325.el7_9.x86_64 glibc-devel-2.17-325.el7_9.x86_64 [username@hostname workdir]# rpm -qa | grep python3 python3-pip-9.0.3-8.el7.noarch python3-rpm-macros-3-34.el7.noarch python3-rpm-generators-6-2.el7.noarch python3-setuptools-39.2.0-10.el7.noarch python3-3.6.8-21.el7_9.x86_64 python3-devel-3.6.8-21.el7_9.x86_64 python3.6.4-sre-1.el6.x86_64 ``` 根据上述信息,我们可以看到 `glibc` 的版本是 `2.17-325.el7_9.x86_64`,`python3` 的版本是 `3.6.8-21.el7_9.x86_64`。 ### 下载并安装 `debuginfo` 库 从 http://debuginfo.centos.org/7/x86_64 下载 `glibc-debuginfo-common-2.17-325.el7.x86_64.rpm`、`glibc-debuginfo-2.17-325.el7.x86_64.rpm` 和 `python3-debuginfo-3.6.8-21.el7.x86_64.rpm`。 ```bash rpm -ivh glibc-debuginfo-common-2.17-325.el7.x86_64.rpm rpm -ivh glibc-debuginfo-2.17-325.el7.x86_64.rpm rpm -ivh python3-debuginfo-3.6.8-21.el7.x86_64.rpm ``` ### 升级 GDB ```bash sudo yum install devtoolset-10 -y echo
"source scl_source enable devtoolset-10" >> ~/.bashrc source ~/.bashrc ``` ### 验证 ```bash gdb python3 ``` 输出类似于这样: ``` [username@hostname workdir]# gdb python3 GNU gdb (GDB) Red Hat Enterprise Linux 9.2-10.el7 Copyright (C) 2020 Free Software Foundation, Inc. License GPLv3+: GNU GPL version 3 or later This is free software: you are free to change and redistribute it. There is NO WARRANTY, to the extent permitted by law. Type "show copying" and "show warranty" for details. This GDB was configured as "x86_64-redhat-linux-gnu". Type "show configuration" for configuration details. For bug reporting instructions, please see: . Find the GDB manual and other documentation resources online at: . For help, type "help". Type "apropos word" to search for commands related to "word"... Reading symbols from python3... (gdb) ``` 如果显示 `Reading symbols from python3`,说明配置成功。 对于其他操作系统,请参考 [DebuggingWithGdb](https://wiki.python.org/moin/DebuggingWithGdb)。 ## 设置符号链接 设置符号链接后,不需要每次都通过 `pip` 进行本地安装。 ```bash # 更改目录到 lmdeploy,例如 cd /workdir/lmdeploy # 因为编译文件在 build 文件夹中 # 设置 lib 的软链接 cd lmdeploy && ln -s ../build/lib . && cd .. # (可选)创建 compile_commands.json 软链接,用于 clangd 构建 index ln -s build/compile_commands.json . ``` ## 开始调试 ````bash # 使用 gdb 启动 API Server,例如 gdb --args python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf # 在 gdb 中设置 lmdeploy 文件夹路径 Reading symbols from python3... (gdb) set directories /workdir/lmdeploy # 使用相对路径设置断点,例如 (gdb) b src/turbomind/models/llama/BlockManager.cc:104 # 当出现 # ``` # No source file named src/turbomind/models/llama/BlockManager.cc. # Make breakpoint pending on future shared library load? 
(y or [n]) # ``` # 输入 y 并回车 # 运行 (gdb) r # (可选) 使用 https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py 发送请求 python3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json ```` ## 使用 GDB 参考 [GDB Execution Commands](https://lldb.llvm.org/use/map.html) 进行调试。 ================================================ FILE: docs/zh_cn/advance/long_context.md ================================================ # 长文本外推 长文本外推指 LLM 推理时处理比训练文本更长数据的能力。TurboMind 引擎目前支持 [LlamaDynamicNTKScalingRotaryEmbedding](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L178), 并与 HuggingFace 的实现对齐。 ## 如何使用 如果要直接加载 HuggingFace 格式的模型,可以通过修改 TurbomindEngineConfig 参数的方式赋予模型外推能力。将 `session_len` 修改为外推的长度,并将 `rope_scaling_factor` 修改为不小于 1.0 的值。 以具有 **1M 上下文长度**的`internlm2_5-7b-chat-1m`为例,可以使用如下方式,激活长文本推理能力: ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=1000000, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config) prompt = 'Use a long prompt to replace this sentence' gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) response = pipe(prompt, gen_config=gen_config) print(response) ``` ## 评测 我们使用多种方式评测 LMDeploy 长文本推理能力,分别是 [passkey retrieval 实验](#passkey-retrieval)、[大海捞针实验](#大海捞针) 和[计算困惑度](#困惑度) ### Passkey Retrieval 执行如下代码,可以测试在长文本中找到特殊 key 成功和失败的次数 ```python import numpy as np from lmdeploy import pipeline from lmdeploy import TurbomindEngineConfig import time session_len = 1000000 backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=session_len, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config) def passkey_retrieval(session_len, n_round=5): # create long 
context input tok = pipe.tokenizer task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.' garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.' for _ in range(n_round): start = time.perf_counter() n_times = (session_len - 1000) // len(tok.encode(garbage)) n_garbage_prefix = np.random.randint(0, n_times) n_garbage_suffix = n_times - n_garbage_prefix garbage_prefix = ' '.join([garbage] * n_garbage_prefix) garbage_suffix = ' '.join([garbage] * n_garbage_suffix) pass_key = np.random.randint(1, 50000) information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.' # noqa: E501 final_question = 'What is the pass key? The pass key is' lines = [ task_description, garbage_prefix, information_line, garbage_suffix, final_question, ] # inference prompt = ' '.join(lines) response = pipe([prompt]) print(pass_key, response) end = time.perf_counter() print(f'duration: {end - start} s') passkey_retrieval(session_len, 5) ``` 在 A100-80G GPU上,执行上述实验,每轮测试大约需要 364 秒 ### 大海捞针 可使用 OpenCompass 进行测评,具体使用方法,请参考[文档](https://github.com/open-compass/opencompass/blob/main/docs/zh_cn/advanced_guides/needleinahaystack_eval.md) ### 困惑度 下面展示使用 LMDeploy 计算困惑度的用法 ```python from transformers import AutoTokenizer from lmdeploy import TurbomindEngineConfig, pipeline import numpy as np # load model and tokenizer model_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m' backend_config = TurbomindEngineConfig( rope_scaling_factor=2.5, session_len=1000000, max_batch_size=1, cache_max_entry_count=0.7, tp=4) pipe = pipeline(model_repoid_or_path, backend_config=backend_config) tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True) # get perplexity text = 'Use a long prompt to replace this sentence' input_ids = tokenizer.encode(text) ppl = pipe.get_ppl(input_ids)[0] print(ppl) ```
================================================ FILE: docs/zh_cn/advance/metrics.md ================================================ # 生产环境指标监控 LMDeploy 通过 Prometheus 暴露监控指标,并通过 Grafana 提供可视化界面。 ## 配置指南 本节介绍如何设置 `lmdeploy/monitoring` 目录中提供的监控套件(Prometheus + Grafana) ## 前提条件 - 已安装 [Docker](https://docs.docker.com/engine/install/) 和 [Docker Compose](https://docs.docker.com/compose/install/) - 已启用指标系统的 LMDeploy 服务正在运行 ## 使用说明 (DP = 1) 1. **启动已启用指标的 LMDeploy 服务** ``` lmdeploy serve api_server Qwen/Qwen2.5-7B-Instruct --enable-metrics ``` 请根据需求替换模型路径。默认 metrics endpoint 位于 `http://<lmdeploy_server_ip>:23333/metrics`。 2. **进入监控目录** ``` cd lmdeploy/monitoring ``` 3. **启动监控套件** ``` docker compose up -d ``` 此命令将在后台启动 Prometheus 和 Grafana。 4. **访问监控界面** - Prometheus:浏览器访问 http://localhost:9090. - Grafana:浏览器访问 http://localhost:3000. 5. **登录 Grafana** - 默认用户名:`admin` - 默认密码:`admin` (首次登录后会提示修改密码) 6. **查看仪表盘** 预配置的 LMDeploy 仪表盘将自动加载。 ## 使用说明 (DP > 1) 1. **启动已启用指标的 LMDeploy 服务** 以模型 `Qwen/Qwen2.5-7B-Instruct` 为例,使用 `DP=2,TP=2` 启动服务: ```bash # Proxy server lmdeploy serve proxy --server-port 8000 --routing-strategy 'min_expected_latency' --serving-strategy Hybrid --log-level INFO # API server LMDEPLOY_DP_MASTER_ADDR=127.0.0.1 \ LMDEPLOY_DP_MASTER_PORT=29555 \ lmdeploy serve api_server \ Qwen/Qwen2.5-7B-Instruct \ --backend pytorch \ --tp 2 \ --dp 2 \ --proxy-url http://0.0.0.0:8000 \ --nnodes 1 \ --node-rank 0 \ --enable-metrics ``` 您应该能在代理服务器列表中看到多个 API 服务实例。详细信息可以在 `lmdeploy/serve/proxy/proxy_config.json` 中找到。 例如,您可能会看到如下 API 服务地址: ``` http://$host_ip:$api_server_port1 http://$host_ip:$api_server_port2 ``` 2.
**修改 Prometheus 配置** 当 DP > 1 时,LMDeploy 会为每个 DP Rank 启动一个 API 服务。如果你想监控其中某个 API 服务,例如:`http://$host_ip:$api_server_port1`,请修改配置文件 `lmdeploy/monitoring/prometheus.yaml` 如下所示。 > 注意:这里应使用实际主机的 IP 地址而非 127.0.0.1,因为当 DP > 1 时,LMDeploy 是通过实际主机 IP 启动 API 服务的。 ``` global: scrape_interval: 5s evaluation_interval: 30s scrape_configs: - job_name: lmdeploy static_configs: - targets: - '$host_ip:$api_server_port1' # <= 修改此处 ``` 3. **进入监控目录并执行上述相同步骤** ## 故障排除 1. **端口冲突** 检查端口 `23333` (LMDeploy 服务端口)、`9090` (Prometheus 端口) 或 `3000` (Grafana 端口) 是否被占用。解决方案,关闭冲突的端口或如下修改配置文件: - 修改 Prometheus 抓取的 LMDeploy 服务端口 在 `lmdeploy/monitoring/prometheus.yaml` 中 ``` global: scrape_interval: 5s evaluation_interval: 30s scrape_configs: - job_name: lmdeploy static_configs: - targets: - '127.0.0.1:23333' # <= 修改此处的 LMDeploy 服务端口 23333,需与实际运行端口一致 ``` - 修改 Prometheus 端口 在 `lmdeploy/monitoring/grafana/datasources/datasource.yaml` 中 ``` apiVersion: 1 datasources: - name: Prometheus type: prometheus access: proxy url: http://localhost:9090 # <= 修改此处的 Prometheus 接口端口 9090 isDefault: true editable: false ``` - 修改 Grafana 端口 在 `lmdeploy/monitoring/docker-compose.yaml` 中操作(例如改为 3090 端口): 方案一:在环境变量中添加 `GF_SERVER_HTTP_PORT` ``` environment: - GF_AUTH_ANONYMOUS_ENABLED=true - GF_SERVER_HTTP_PORT=3090 # <= 添加此行 ``` 方案二:使用端口映射 ``` grafana: image: grafana/grafana:latest container_name: grafana ports: - "3090:3000" # <= 主机端口:容器端口映射 ``` - **仪表盘无数据** 尝试向 LMDeploy 服务发送请求生成流量: ``` python3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` 刷新后仪表盘应显示数据。 ================================================ FILE: docs/zh_cn/advance/pytorch_multinodes.md ================================================ # PyTorchEngine 多节点部署指南 为了支持更大规模的模型部署需求,PyTorchEngine 提供了多节点部署的支持。以下是如何在两个8卡节点上部署 tp=16 模型的详细步骤。 ## 1. 
创建 Docker 容器(可选) 为了确保集群环境的一致性,建议使用 Docker 搭建集群。在每个节点上创建容器: ```bash docker run -it \ --network host \ -v $MODEL_PATH:$CONTAINER_MODEL_PATH \ openmmlab/lmdeploy:latest ``` > \[!IMPORTANT\] > 请确保将模型放置在各个节点容器的相同目录中。 ## 2. 使用 ray 搭建集群 ### 2.1 启动主节点 选择其中一个节点做为`主节点`,并在该节点的容器中运行以下命令: ```bash ray start --head --port=$DRIVER_PORT ``` ### 2.2 加入集群 在其他节点的容器中,使用以下命令加入主节点所在的集群: ```bash ray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT ``` 完成后可以在主节点使用 `ray status` 查看集群状态,确保所有节点都被成功加入集群。 > \[!IMPORTANT\] > 请确保 `DRIVER_NODE_ADDR` 为主节点的地址,`DRIVER_PORT` 与主节点初始化时使用的端口号一致。 ## 3. 使用 LMDeploy 接口 在主节点的容器中,您可以正常使用 PyTorchEngine 的所有功能。 ### 3.1 启动服务 API ```bash lmdeploy serve api_server \ $CONTAINER_MODEL_PATH \ --backend pytorch \ --tp 16 ``` ### 3.2 使用 pipeline 接口 ```python from lmdeploy import pipeline, PytorchEngineConfig if __name__ == '__main__': model_path = '/path/to/model' backend_config = PytorchEngineConfig(tp=16) with pipeline(model_path, backend_config=backend_config) as pipe: outputs = pipe('Hakuna Matata') ``` > \[!NOTE\] > PytorchEngine 会根据 tp 数以及集群上的设备数量自动选择合适的启动方式(单机/多机)。如果希望强制使用 ray 集群,可以配置 `PytorchEngineConfig` 中的 `distributed_executor_backend='ray'` 或使用环境变量 `LMDEPLOY_EXECUTOR_BACKEND=ray`。 通过以上步骤,您可以成功在多节点环境中部署 PyTorchEngine,并利用 Ray 集群进行分布式计算。 > \[!WARNING\] > 为了能够得到更好的性能,我们建议用户配置更好的网络环境(比如使用 [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand))以提高引擎运行效率 ================================================ FILE: docs/zh_cn/advance/pytorch_multithread.md ================================================ # PyTorchEngine 多线程推理 自 [PR2907](https://github.com/InternLM/lmdeploy/pull/2907) 起,我们废除了 PytorchEngine 的 thread_safe 模式以保证引擎能够更高效的运行。我们鼓励用户尽可能使用**服务接口**或**协程**来实现高并发, 如果你确实有多线程推理的需求,那么可以进行简单的封装,来实现类似的效果。 ```python import threading from queue import Queue import asyncio from lmdeploy import pipeline, PytorchEngineConfig model_path = 'Llama-3.2-1B-Instruct' async def _batch_infer(inque: Queue, outque: Queue, pipe): while True: if inque.empty(): await asyncio.sleep(0) continue 
input = inque.get_nowait() output = await pipe.async_batch_infer(input) outque.put(output) def server(inques, outques): event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) pipe = pipeline(model_path, backend_config=PytorchEngineConfig()) for inque, outque in zip(inques, outques): event_loop.create_task(_batch_infer(inque, outque, pipe)) event_loop.run_forever() def client(inque, outque, message): inque.put(message) print(outque.get().text) inques = [Queue(), Queue()] outques = [Queue(), Queue()] t_server = threading.Thread(target=server, args=(inques, outques)) t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata')) t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures')) t_server.start() t_client0.start() t_client1.start() t_client0.join() t_client1.join() ``` > \[!WARNING\] > 我们不鼓励这样实现,多线程会带来额外的开销,使得推理性能不稳定 ================================================ FILE: docs/zh_cn/advance/pytorch_new_model.md ================================================ # lmdeploy.pytorch 新模型支持 lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,用户可以根据自己的需求适配新的模型。 ## 模型支持 ### 配置加载(可选) lmdeploy.pytorch 会根据模型的参数初始化引擎,如果需要接入的模型的参数命名与 transformers 中常见模型不同,可能存在解析错误的情况。可以添加自定义的 ConfigBuilder 来解析配置 ```python # lmdeploy/pytorch/configurations/gemma.py from lmdeploy.pytorch.config import ModelConfig from .builder import AutoModelConfigBuilder class GemmaModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): # 判断 hf_config 是否适配该 builder return hf_config.model_type in ['gemma', 'gemma2'] @classmethod def build(cls, hf_config, model_path: str = None): # 使用 transformers 加载的 hf_config # 构造 pytorch engine 的 ModelConfig return ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, num_attention_heads=hf_config.num_attention_heads, num_key_value_heads=hf_config.num_key_value_heads, bos_token_id=hf_config.bos_token_id, 
eos_token_id=hf_config.eos_token_id, head_dim=hf_config.head_dim, vocab_size=hf_config.vocab_size) ``` 可以使用 `lmdeploy.pytorch.check_env.check_model` 函数验证配置是否能够正确解析 ### 实现模型 在确保能够正确解析配置后,就可以开始实现模型逻辑。以 llama 的实现为例,我们需要通过 transformers 的配置文件创建模型 ```python class LlamaForCausalLM(nn.Module): # 构造函数,通过传入的 config 搭建模型 # ctx_mgr 是上下文管理器,可以通过它传入引擎配置或额外参数 def __init__(self, config: LlamaConfig, ctx_mgr: StepContextManager, dtype: torch.dtype = None, device: torch.device = None): super().__init__() self.config = config self.ctx_mgr = ctx_mgr # build LLamaModel self.model = LlamaModel(config, dtype=dtype, device=device) # build lm_head self.lm_head = build_rowwise_linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device) # 模型推理函数 # 推荐尽可能使用与下面相同的参数 def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, **kwargs, ): hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) logits = self.lm_head(hidden_states) logits = logits.float() return logits ``` 除了这些以外,还有如下内容需要添加 ```python class LlamaForCausalLM(nn.Module): ... # 标注该模型是否支持 cudagraph # 可以是一个 callable 对象,接收 forward 输入 # 动态判断是否支持 cudagraph support_cuda_graph = True # 构建模型输入 # 返回词典,词典的 key 必须是 forward 的输入 def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], inputs_embeds: Optional[torch.Tensor] = None, context: StepContext = None, ): ... # 加载权重 # 模型的输入是 state dict 的 key value 对 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ... 
``` 我们封装了许多融合算子以简化模型的搭建。这些算子能够更好地支持 tensor 并行、量化等各种功能,我们鼓励开发者尽可能使用这些 op 进行开发。 ```python # 使用预定义的 build_merged_colwise_linear, SiluAndMul, build_rowwise_linear # 可以帮助我们更快搭建模型,并且不用关心 tensor 并行、量化等细节 class LlamaMLP(nn.Module): def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None): super().__init__() quantization_config = getattr(config, 'quantization_config', None) # gate up self.gate_up_proj = build_merged_colwise_linear( config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=config.mlp_bias, dtype=dtype, device=device, quant_config=quantization_config, is_tp=True, ) # silu and mul self.act_fn = SiluAndMul(inplace=True) # down self.down_proj = build_rowwise_linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias, quant_config=quantization_config, dtype=dtype, device=device, is_tp=True) def forward(self, x): """forward.""" gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.down_proj(act) ``` ### 模型注册 为了能够让开发的模型实现可以正常使用,我们还需要在 `lmdeploy/pytorch/models/module_map.py` 中注册该模型 ```python MODULE_MAP.update({ 'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', }) ``` 如果你不希望修改模型源码,也可以从外部传入自定义的 module map,方便整合进其他项目中 ```python from lmdeploy import PytorchEngineConfig, pipeline backend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py') generator = pipeline(model_path, backend_config=backend_config) ``` ================================================ FILE: docs/zh_cn/advance/pytorch_profiling.md ================================================ # PyTorchEngine 性能分析 我们提供了数种分析 PytorchEngine 性能的方式 ## PyTorch Profiler 我们集成了 PyTorch Profiler,可以在启动 pipeline 或 api server 时添加环境变量: ```bash # enable profile cpu export LMDEPLOY_PROFILE_CPU=1 # enable profile cuda export LMDEPLOY_PROFILE_CUDA=1 # profile would start after 3 seconds export LMDEPLOY_PROFILE_DELAY=3 # profile 10 seconds export LMDEPLOY_PROFILE_DURATION=10 # prefix path to save
profile files export LMDEPLOY_PROFILE_OUT_PREFIX="/path/to/save/profile_" ``` 这样在退出程序后,统计信息会被存储在 `LMDEPLOY_PROFILE_OUT_PREFIX` 指定的地址,方便进行性能分析。 ## Nsight System 我们也支持使用 Nsight System 分析 NVIDIA 设备的性能。 ### 单卡 单卡情况下比较简单,可以直接使用 `nsys profile`: ```bash nsys profile python your_script.py ``` ### 多卡 当启用了 DP/TP/EP 等多卡方案时,可以设置环境变量 ```bash # enable nsight system export LMDEPLOY_RAY_NSYS_ENABLE=1 # prefix path to save profile files export LMDEPLOY_RAY_NSYS_OUT_PREFIX="/path/to/save/profile_" ``` 然后正常启动脚本或 api server 即可(注意**不要**添加 `nsys profile`) 这样 profile 的结果就会被保存在 `LMDEPLOY_RAY_NSYS_OUT_PREFIX` 下,如果没有配置 `LMDEPLOY_RAY_NSYS_OUT_PREFIX`,可以在 `/tmp/ray/session_xxx/nsight` 目录下找到。 ## Ray timeline 我们使用 ray 实现多卡支持,如果希望查看 ray timeline,可以配置如下环境变量: ```bash export LMDEPLOY_RAY_TIMELINE_ENABLE=1 export LMDEPLOY_RAY_TIMELINE_OUT_PATH="/path/to/save/timeline.json" ``` ================================================ FILE: docs/zh_cn/advance/spec_decoding.md ================================================ # Speculative Decoding 投机解码是一种优化技术,它通过引入轻量级草稿模型来预测多个后续token,再由主模型在前向推理过程中验证并选择最长的匹配 token 序列。与标准的自回归解码相比,这种方法可使系统一次性生成多个token。 > \[!NOTE\] > 请注意,这是lmdeploy中的实验性功能。 ## 示例 请参考如下使用示例。 ### Eagle 3 #### 安装依赖 安装 [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) ```shell git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git cd flash-attention/hopper python setup.py install ``` #### pipeline ```python from lmdeploy import PytorchEngineConfig, pipeline from lmdeploy.messages import SpeculativeConfig if __name__ == '__main__': model_path = 'meta-llama/Llama-3.1-8B-Instruct' spec_cfg = SpeculativeConfig( method='eagle3', num_speculative_tokens=3, model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B', ) pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` #### serving ```shell lmdeploy serve
api_server \ meta-llama/Llama-3.1-8B-Instruct \ --backend pytorch \ --server-port 24545 \ --speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \ --speculative-algorithm eagle3 \ --speculative-num-draft-tokens 3 \ --max-batch-size 128 \ --enable-metrics ``` ### Deepseek MTP #### 安装依赖 安装 [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation) ```shell git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla cd flash-mla git submodule update --init --recursive pip install -v . ``` #### pipeline ```python from lmdeploy import PytorchEngineConfig, pipeline from lmdeploy.messages import SpeculativeConfig if __name__ == '__main__': model_path = 'deepseek-ai/DeepSeek-V3' spec_cfg = SpeculativeConfig( method='deepseek_mtp', num_speculative_tokens=3, ) pipe = pipeline(model_path, backend_config=PytorchEngineConfig(tp=16, max_batch_size=128), speculative_config=spec_cfg) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` #### serving ```shell lmdeploy serve api_server \ deepseek-ai/DeepSeek-V3 \ --backend pytorch \ --server-port 24545 \ --tp 16 \ --speculative-algorithm deepseek_mtp \ --speculative-num-draft-tokens 3 \ --max-batch-size 128 \ --enable-metrics ``` ================================================ FILE: docs/zh_cn/advance/structed_output.md ================================================ # 结构化输出 结构化输出(也称为引导解码)会强制模型生成与用户提供的 JSON 模式、语法或正则表达式完全匹配的文本。 当前,PyTorch 与 Turbomind 两个后端均已支持这种(受模式约束的)结构化生成。 以下分别为 pipeline API 和 API 服务的使用示例。 ## pipeline ```python from lmdeploy import pipeline from lmdeploy.messages import GenerationConfig, PytorchEngineConfig model = 'internlm/internlm2-chat-1_8b' guide = { 'type': 'object', 'properties': { 'name': { 'type': 'string' }, 'skills': { 'type': 'array', 'items': { 'type': 'string', 'maxLength': 10 }, 'minItems': 3 }, 'work history': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'company': { 'type': 'string' }, 'duration': { 'type':
'string' } }, 'required': ['company'] } } }, 'required': ['name', 'skills', 'work history'] } pipe = pipeline(model, backend_config=PytorchEngineConfig(), log_level='INFO') gen_config = GenerationConfig( response_format=dict(type='json_schema', json_schema=dict(name='test', schema=guide))) response = pipe(['Make a self introduction please.'], gen_config=gen_config) print(response) ``` ## api_server 首先,启动 InternLM2 模型的 api_server 服务。 ```shell lmdeploy serve api_server internlm/internlm2-chat-1_8b --backend pytorch ``` 客户端可以使用 OpenAI 的 python 包进行测试: ```python from openai import OpenAI guide = { 'type': 'object', 'properties': { 'name': { 'type': 'string' }, 'skills': { 'type': 'array', 'items': { 'type': 'string', 'maxLength': 10 }, 'minItems': 3 }, 'work history': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'company': { 'type': 'string' }, 'duration': { 'type': 'string' } }, 'required': ['company'] } } }, 'required': ['name', 'skills', 'work history'] } response_format=dict(type='json_schema', json_schema=dict(name='test',schema=guide)) messages = [{'role': 'user', 'content': 'Make a self-introduction please.'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, response_format=response_format, top_p=0.8) print(response) ``` 输出结果是一个 json 格式的回答。 ================================================ FILE: docs/zh_cn/advance/update_weights.md ================================================ # 权重更新 LMDeploy支持在线权重更新,方便RL训练等场景下的使用。以下是权重更新的步骤: ## 步骤 1: 启动服务 对于 pytorch 后端,需要添加 `--distributed-executor-backend ray` 参数。
```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend ``` ## 步骤 2: 卸载权重和KV缓存 在权重更新前,需要调用API卸载权重和KV缓存,使推理引擎处于可更新状态: ```python from lmdeploy.utils import serialize_state_dict import requests BASE_URL = 'http://0.0.0.0:23333' api_key = 'sk-xxx' headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } # offloads weights and kv cache with level=2 response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2)) assert response.status_code == 200, response.status_code # wake up weights, the server is ready for update weights response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights'])) assert response.status_code == 200, response.status_code ``` ## 步骤 3: 更新权重 将模型权重切分后调用`update_weights`API进行更新。 ```python segmented_state_dict: List[Dict[str, torch.Tensor]] = ... num_segment = len(segmented_state_dict) for seg_idx in range(num_segment): serialized_data = serialize_state_dict(segmented_state_dict[seg_idx]) data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1) response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data) assert response.status_code == 200, f"response.status_code = {response.status_code}" ``` **注意**: 对于pytorch推理后端,lmdeploy还支持扁平化桶张量(flattened bucket tensor)传输方式: ```python from lmdeploy.utils import serialize_state_dict, FlattenedTensorBucket, FlattenedTensorMetadata segmented_state_dict: List[Dict[str, torch.Tensor]] = ... 
num_segment = len(segmented_state_dict) for seg_idx in range(num_segment): named_tensors = list(segmented_state_dict[seg_idx].items()) bucket = FlattenedTensorBucket(named_tensors=named_tensors) metadata = bucket.get_metadata() flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata) serialized_data = serialize_state_dict(flattened_tensor_data) data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1, load_format='flattened_bucket') response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data) assert response.status_code == 200, f"response.status_code = {response.status_code}" ``` ## 步骤 4: 唤醒引擎 权重更新后,调用API构建KV缓存,唤醒引擎,重新提供推理服务。 ```python response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache'])) assert response.status_code == 200, response.status_code ``` ================================================ FILE: docs/zh_cn/api/cli.rst ================================================ 命令行工具 =========== .. sphinx_argparse_cli:: :module: lmdeploy.cli :func: run :hook: :prog: lmdeploy ================================================ FILE: docs/zh_cn/api/openapi.rst ================================================ OpenAPI 接口 ============ .. currentmodule:: lmdeploy OpenAI 兼容服务器接口 ---------------------- .. openapi:: ../_static/openai.yaml :request: :examples: Proxy 服务器接口 ----------------- .. openapi:: ../_static/proxy.yaml :request: :examples: ================================================ FILE: docs/zh_cn/api/pipeline.rst ================================================ 推理 pipeline ================== .. currentmodule:: lmdeploy Pipeline -------- .. autofunction:: pipeline .. autoclass:: Pipeline :undoc-members: :show-inheritance: :members: __init__, infer, stream_infer, chat, get_ppl :member-order: bysource Config ------------------- .. autoclass:: PytorchEngineConfig .. autoclass:: TurbomindEngineConfig .. autoclass:: GenerationConfig .. 
autoclass:: ChatTemplateConfig ================================================ FILE: docs/zh_cn/benchmark/benchmark.md ================================================ # 性能测试 测试之前,请安装 lmdeploy 预编译包,并下载测试脚本和数据。 ```shell pip install lmdeploy # 下载 lmdeploy 源码,获取其中的性能测试脚本 git clone --depth=1 https://github.com/InternLM/lmdeploy cd lmdeploy # 切换到与已安装 lmdeploy 版本对应的 tag: git fetch --tags # 查看已安装 lmdeploy 的版本: pip show lmdeploy | grep Version # 切换到对应的 tag(将 <version> 替换为实际的版本号): git checkout v<version> # 下载测试数据 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` ## 测试 pipeline 接口 ```shell python3 benchmark/profile_pipeline_api.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct ``` 可通过 `python3 benchmark/profile_pipeline_api.py -h` 查看脚本中的参数详情 ## 测试推理引擎接口 ```shell python3 benchmark/profile_throughput.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct ``` 可通过 `python3 benchmark/profile_throughput.py -h` 查看脚本中的参数详情 ## 测试 api_server 性能 启动模型服务(可以参考[这里](../llm/api_server.md))。接着,使用下面的命令: ```shell python3 benchmark/profile_restful_api.py --backend lmdeploy --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ``` 关于 `profile_restful_api.py`的帮助信息,可以通过`python3 benchmark/profile_restful_api.py -h`查阅 ================================================ FILE: docs/zh_cn/benchmark/evaluate_with_opencompass.md ================================================ # 模型评测指南 本文档介绍如何使用 OpenCompass 和 LMDeploy 对模型在学术数据集上的能力进行评测。完整的评测流程包含两个主要阶段:推理阶段和评判阶段。 在推理阶段,首先通过 LMDeploy 将待评测模型部署为推理服务,随后使用 OpenCompass 将数据集内容作为请求发送至该服务,并获取模型生成的结果。 在评判阶段,需将 OpenCompass 提供的评测模型 `opencompass/CompassVerifier-32B` 通过 LMDeploy 部署为服务,再使用 OpenCompass 将推理阶段生成的结果提交至该服务,从而获得最终的评测结果。 若评测资源充足,建议参考[端到端评测](#端到端评测)章节执行完整流程;若资源有限,则建议按照[逐步评测](#逐步评测)章节依次执行两个阶段。 ## 环境准备 ```shell pip install lmdeploy pip install "opencompass[full]" # 下载 lmdeploy 源码,在后续步骤中会使用到 eval/* 中的评测脚本和配置文件 git clone --depth=1
https://github.com/InternLM/lmdeploy.git ``` 建议将 LMDeploy 和 OpenCompass 安装在不同的 Python 虚拟环境中,以避免可能的依赖冲突。 ## 端到端评测 1. **部署待评测模型** ```shell lmdeploy serve api_server <model_path> --server-port 10000 <--other-options> ``` 2. **部署评测模型(Judger)** ```shell lmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 --session-len 65536 ``` 3. **生成评测配置并执行评测** ```shell cd {the/root/path/of/lmdeploy/repo} ## 指定数据集路径。如果在路径下没有找到评测数据集,OC会自动下载 export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache python eval/eval.py {task_name} \ --mode all \ --api-server http://{api-server-ip}:10000 \ --judger-server http://{judger-server-ip}:20000 \ -w {oc_output_dir} ``` 关于 `eval.py` 的详细使用方法,比如指定评测集,请通过 `python eval/eval.py --help` 查阅。 评测任务完成后,结果将保存在 `{oc_output_dir}/{yyyymmdd_hhmmss}` 目录中,其中 `{yyyymmdd_hhmmss}` 为任务执行的时间戳。 ## 逐步评测 ### 推理阶段 本阶段用于生成模型对数据集的回答结果。 1. **部署待评测模型** ```shell lmdeploy serve api_server <model_path> --server-port 10000 <--other-options> ``` 2. **生成推理配置并执行推理** ```shell cd {the/root/path/of/lmdeploy/repo} ## 指定数据集路径。如果在路径下没有找到评测数据集,OC会自动下载 export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache # 执行推理任务 python eval/eval.py {task_name} \ --mode infer \ --api-server http://{api-server-ip}:10000 \ -w {oc_output_dir} ``` 关于 `eval.py` 的详细使用方法,比如指定评测集,请通过 `python eval/eval.py --help` 查阅。 推理完成后,结果将保存在 `{oc_output_dir}/{yyyymmdd_hhmmss}` 目录中,其中 `{yyyymmdd_hhmmss}` 为任务执行的时间戳。 ### 评判阶段 本阶段由评测模型(Judger)对推理阶段生成的结果进行判断。 1. **部署评测模型(Judger)** ```shell lmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 ``` 2.
**生成评判配置并执行评判** ```shell cd {the/root/path/of/lmdeploy/repo} ## 指定数据集路径。如果在路径下没有找到评测数据集,OC会自动下载 export HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets export COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache # 执行评测任务 python eval/eval.py {task_name} \ --mode eval \ --judger-server http://{judger-server-ip}:20000 \ -w {oc_output_dir} -r {yyyymmdd_hhmmss} ``` 注意事项: - `task_name` 必须与推理阶段的任务名称保持一致 - `-w` 参数指定的输出目录 `oc_output_dir` 需与推理阶段一致 - `-r` 参数用于指定“之前的输出与结果”,应填入推理阶段生成的时间戳目录名,即 `{oc_output_dir}` 下的子目录名称 关于 `eval.py` 的详细使用方法,比如指定评测集,请通过 `python eval/eval.py --help` 查阅。 ================================================ FILE: docs/zh_cn/benchmark/evaluate_with_vlmevalkit.md ================================================ # 多模态模型评测指南 本文档介绍如何使用 VLMEvalKit 和 LMDeploy 评测多模态模型能力。 ## 环境准备 ```shell pip install lmdeploy git clone https://github.com/open-compass/VLMEvalKit.git cd VLMEvalKit && pip install -e . ``` 建议在不同的 Python 虚拟环境中分别安装 LMDeploy 和 VLMEvalKit,以避免潜在的依赖冲突。 ## 评测 1. **部署大语言多模态模型 (LMMs)** ```shell lmdeploy serve api_server <model_path> --server-port 23333 <--other-options> ``` 2. **配置评测设置** 修改 `VLMEvalKit/vlmeval/config.py`,在 `api_models` 字典中添加以下 LMDeploy API 配置。 `<task_name>` 是您评测任务的自定义名称(例如 `lmdeploy_qwen3vl-4b`)。`model` 参数应与 `lmdeploy serve` 命令中使用的 `<model_path>` 保持一致。 ```python # filepath: VLMEvalKit/vlmeval/config.py # ...existing code... api_models = { # lmdeploy api ..., "<task_name>": partial( LMDeployAPI, api_base="http://0.0.0.0:23333/v1/chat/completions", model="<model_path>", retry=4, timeout=1200, temperature=0.7, # modify if needed max_new_tokens=16384, # modify if needed ), ... } # ...existing code... ``` 3.
**开始评测** ```shell cd VLMEvalKit python run.py --data OCRBench --model --api-nproc 16 --reuse --verbose --api 123 ``` `` 应与上述配置文件中使用的名称保持一致。 参数说明: - `--data`: 指定用于评测的数据集(例如 `OCRBench`)。 - `--model`: 指定模型名称,必须与您在 `config.py` 中设置的 `` 匹配。 - `--api-nproc`: 指定并行的 API 调用数量。 - `--reuse`: 复用先前的推理结果,以避免重新运行已完成的评测。 - `--verbose`: 启用详细日志记录。 ================================================ FILE: docs/zh_cn/conf.py ================================================ # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys from pathlib import Path from fastapi import FastAPI from fastapi.responses import Response from yaml import safe_dump sys.path.insert(0, os.path.abspath('../..')) from lmdeploy.serve.openai.api_server import router # noqa: E402 from lmdeploy.serve.proxy.proxy import app as proxy_server # noqa: E402 version_file = '../../lmdeploy/version.py' with open(version_file, 'r') as f: exec(compile(f.read(), version_file, 'exec')) __version__ = locals()['__version__'] # -- Project information ----------------------------------------------------- project = 'lmdeploy' copyright = '2021-2024, OpenMMLab' author = 'LMDeploy Authors' # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags release = __version__ # -- Generate OpenAPI Spec ----------------------------------------------------- openai_server = FastAPI() openai_server.include_router(router) @openai_server.get('/metrics', response_class=Response, responses={ 200: { 'content': { 
'text/plain': {} }, 'description': 'Prometheus metrics data' }, 404: { 'description': 'Metrics Endpoint not enabled' } }) def metrics(): """**[Optional]** Prometheus metrics endpoint.""" pass spec_dir = Path('_static') spec_dir.mkdir(exist_ok=True) with open(spec_dir / 'openai.yaml', 'w', encoding='utf-8') as f: f.write(safe_dump(openai_server.openapi())) with open(spec_dir / 'proxy.yaml', 'w', encoding='utf-8') as f: f.write(safe_dump(proxy_server.openapi())) # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'myst_parser', 'sphinx_argparse_cli', 'sphinx.ext.autodoc', 'sphinx.ext.autosectionlabel', 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', 'sphinx_copybutton', 'sphinx_tabs.tabs', 'sphinxcontrib.mermaid', 'sphinxcontrib.openapi', ] # yapf: disable autosectionlabel_prefix_document = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { '.rst': 'restructuredtext', '.md': 'markdown', } # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = 'zh_CN' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # # html_theme = 'sphinx_rtd_theme' html_theme = 'sphinx_book_theme' html_logo = '_static/image/lmdeploy-logo.svg' html_title = project html_copy_source = True html_last_updated_fmt = '' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = { 'path_to_docs': 'docs/zh_cn', 'repository_url': 'https://github.com/InternLM/lmdeploy', 'repository_branch': 'main', # 'show_navbar_depth': 3, # 'navigation_depth': 4, # 'collapse_navigation': True, 'use_edit_page_button': True, 'use_source_button': True, 'use_issues_button': True, 'use_repository_button': True, 'use_download_button': True, 'use_sidenotes': True, # 'show_toc_level': 2, # "icon_links": [ # { # "name": "Switch to English", # "url": "https://lmdeploy.readthedocs.io/en/latest", # "icon": "https://img.shields.io/badge/Doc-English-blue", # "type": "url", # }, # ], } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] html_css_files = ['css/readthedocs.css'] # Enable ::: for my_st myst_enable_extensions = [ 'dollarmath', 'amsmath', 'deflist', # "html_admonition", # "html_image", 'colon_fence', # "smartquotes", # "replacements", # "linkify", # "substitution", ] myst_heading_anchors = 5 # Custom sidebar templates, must be a dictionary that maps document names # to template names. 
# # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'lmdeploydoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'lmdeploy.tex', 'lmdeploy Documentation', 'LMDeploy Contributors', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, 'lmdeploy', 'lmdeploy Documentation', [author], 1)] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'lmdeploy', 'lmdeploy Documentation', author, 'lmdeploy', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. 
# # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] # -- Extension configuration ------------------------------------------------- # Ignore >>> when copying code copybutton_prompt_text = r'>>> |\.\.\. ' copybutton_prompt_is_regexp = True autodoc_preserve_defaults = True navigation_with_keys = False # Mock out external dependencies here, # otherwise the autodoc pages may be blank. autodoc_mock_imports = [ 'torch', 'torchvision', 'transformers', '_turbomind', 'triton', ] autodoc_type_aliases = {'PydanticDataclass': 'pydantic.dataclasses.PydanticDataclass'} intersphinx_mapping = { 'python': ('https://docs.python.org/3.10', None), 'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest', None), 'pillow': ('https://pillow.readthedocs.io/en/stable', None), 'numpy': ('https://numpy.org/doc/stable', None), 'torch': ('https://pytorch.org/docs/stable', None), 'torchvision': ('https://pytorch.org/vision/stable', None), } ================================================ FILE: docs/zh_cn/faq.md ================================================ # 常见问题 ## ModuleNotFoundError ### No module named 'mmengine.config.lazy' 可能是因为已经有旧版本的mmengine缓存在了本机。更新到最新版本应该可以解决这个问题。 ```shell pip install --upgrade mmengine ``` ### No module named '\_turbomind' 可能是因为: 1. 您没有安装 lmdeploy 的预编译包。`_turbomind`是 turbomind c++ 的 pybind部分,涉及到编译。推荐您直接安装预编译包。 ```shell pip install lmdeploy[all] ``` 2. 如果已经安装了,还是出现这个问题,请检查下执行目录。不要在 lmdeploy 的源码根目录下执行 python -m lmdeploy.turbomind.\*下的package,换到其他目录下执行。 但是如果您是开发人员,通常需要在本地进行开发和编译。每次安装 whl 的效率太低了。您可以通过符号链接在编译后指定 lib 的路径。 ```shell # 创建 bld 和进行本地编译 mkdir bld && cd bld && bash ../generate.sh && ninja -j$(nproc) # 从 bld 中切到 lmdeploy 子目录并设置软链接 cd ../lmdeploy && ln -s ../bld/lib . # 切换到 lmdeploy 根目录 cd .. 
# 使用 python command 比如 check_env python3 -m lmdeploy check_env ``` 如果您仍然遇到在本地机器上找不到 turbomind so 的问题,这意味着您的本地机器上可能存在多个 Python 环境,并且在编译和执行过程中 Python 的版本不匹配。在这种情况下,您需要根据实际情况设置 `lmdeploy/generate.sh` 中的 `PYTHON_EXECUTABLE`,例如 `-DPYTHON_EXECUTABLE=/usr/local/bin/python3`,并且需要重新编译。 ## Libs ### libnccl.so.2 not found 确保通过 `pip install lmdeploy[all]` 安装了 lmdeploy (>=v0.0.5)。 如果安装之后,问题还存在,那么就把`libnccl.so.2`的路径加入到环境变量 LD_LIBRARY_PATH 中。 ```shell # 获取nvidia-nccl-cu11 package的安装目录 pip show nvidia-nccl-cu11|grep Location # 把"libnccl.so.2"的路径加入到 LD_LIBRARY_PATH export LD_LIBRARY_PATH={Location}/nvidia/nccl/lib:$LD_LIBRARY_PATH ``` ### symbol cudaFreeAsync version libcudart.so.11.0 not defined in file libcudart.so.11.0 with link time reference 很可能是机器上的 cuda 版本太低导致的。LMDeploy运行时要求 cuda 不低于 11.2 ## 推理 ### RuntimeError: \[TM\]\[ERROR\] CUDA runtime error: out of memory /workspace/lmdeploy/src/turbomind/utils/allocator.h 通常这是因为 k/v cache内存比例过大导致的。比例的控制参数是 `TurbomindEngineConfig.cache_max_entry_count`。该参数在不同版本的 lmdeploy中,含义略有不同。具体请参考代码中的[演进说明](https://github.com/InternLM/lmdeploy/blob/52419bd5b6fb419a5e3aaf3c3b4dea874b17e094/lmdeploy/messages.py#L107) 如果在使用 pipeline 接口遇到该问题,请调低比例,比如 ```python from lmdeploy import pipeline, TurbomindEngineConfig backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` 如果在使用 CLI 工具时遇到此问题,请传入参数`--cache-max-entry-count`,调低 k/v cache缓存使用比例。比如, ```shell # chat 命令 lmdeploy chat internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2 # server 命令 lmdeploy serve api_server internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2 ``` ## 服务 ### Api 服务器获取超时 API 服务器的图像 URL 获取超时可通过环境变量 `LMDEPLOY_FETCH_TIMEOUT` 进行配置。默认情况下,请求可能需要长达 10 秒才会超时。 请参阅 [lmdeploy/vl/utils.py](https://github.com/InternLM/lmdeploy/blob/7b6876eafcb842633e0efe8baabe5906d7beeeea/lmdeploy/vl/utils.py#L31) 了解用法。 ## 量化 ### 
RuntimeError: \[enforce fail at inline_container.cc:337\] . unexpected pos 4566829760 vs 4566829656 请检查你的硬盘空间。 这个错误是因为保存权重时硬盘空间不足导致的,在量化 70B 模型时可能会遇到。 ### ModuleNotFoundError: No module named 'flash_attn' 量化 `qwen` 模型需要安装 `flash-attn`。但是,根据社区用户的反馈,`flash-attn` 比较难安装。所以,lmdeploy 已从依赖列表中移除 `flash-attn`,用户在用到的时候,可以进行手动安装。 ================================================ FILE: docs/zh_cn/get_started/ascend/get_started.md ================================================ # 华为昇腾 我们基于 LMDeploy 的 PytorchEngine,增加了华为昇腾设备的支持,目前支持的型号是**Atlas 800T A3,Atlas 800T A2和Atlas 300I Duo**。在华为昇腾上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前,请先阅读原版的[快速开始](../get_started.md)。 支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台). > \[!IMPORTANT\] > 我们已经在阿里云上提供了构建完成的鲲鹏CPU版本的镜像。 > 请使用下面的命令来拉取镜像: > > Atlas 800T A3: > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a3-latest` > (Atlas 800T A3 目前仅支持 Qwen 系列模型在算子模式下运行) > > Atlas 800T A2: > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest` > > Atlas 300I Duo: > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:300i-duo-latest` > (Atlas 300I Duo目前只支持非eager模式) > > 如果您希望自己构建环境,请参考[这里](../../../../docker)的dockerfile来自己构建。 ## 离线批处理 ### LLM 推理 将`device_type="ascend"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="ascend")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM 推理 将`device_type="ascend"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='ascend')) image = 
load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## 在线服务 ### LLM 模型服务 将`--device ascend`加入到服务启动命令中。 ```bash lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat ``` 也可以运行以下命令启动容器运行LLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat" ``` ### VLM 模型服务 将`--device ascend`加入到服务启动命令中。 ```bash lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B ``` 也可以运行以下命令启动容器运行VLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B" ``` ## 使用命令行与LLM模型对话 将`--device ascend`加入到服务启动命令中。 ```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device ascend ``` 也可以运行以下命令使启动容器后开启lmdeploy聊天 ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \     bash -i -c "lmdeploy chat --backend pytorch --device ascend internlm/internlm2_5-7b-chat" ``` ## 量化 ### w4a16 AWQ 运行下面的代码可以在Atlas 800T A2上对权重进行W4A16量化。 ```bash lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu ``` 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 ### w8a8 SMOOTH_QUANT 运行下面的代码可以在Atlas 800T A2上对权重进行W8A8量化。 ```bash lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu ``` 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 ### int8 KV-cache 量化 昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。 详细使用方式请请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。 ## Atlas 300I Duo上的限制 1. 只支持dtype=float16。 2. 
只支持图模式,请不要加上--eager-mode。 ================================================ FILE: docs/zh_cn/get_started/camb/get_started.md ================================================ # 寒武纪云端加速卡 我们基于 LMDeploy 的 PytorchEngine,增加了寒武纪云端加速卡设备的支持。所以,在寒武纪云端加速卡上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前,请先阅读原版的[快速开始](../get_started.md)。 支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台). > \[!IMPORTANT\] > 我们已经在阿里云上提供了构建完成的寒武纪云端加速卡镜像。 > 请使用下面的命令来拉取镜像: > > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest` > \[!IMPORTANT\] > 目前寒武纪加速卡上启动多卡推理需要手动启动ray。下面是一个2卡的例子: > > ```shell > export MLU_VISIBLE_DEVICES=0,1 > ray start --head --resources='{"MLU": 2}' > ``` ## 离线批处理 ### LLM 推理 将`device_type="camb"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="camb")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM 推理 将`device_type="camb"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='camb')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## 在线服务 ### LLM 模型服务 将`--device camb`加入到服务启动命令中。 ```bash lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat ``` 也可以运行以下命令启动容器运行LLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat" ``` ### VLM 模型服务 将`--device camb`加入到服务启动命令中。 
```bash lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B ``` 也可以运行以下命令启动容器运行VLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B" ``` ## 使用命令行与LLM模型对话 将`--device camb`加入到服务启动命令中。 ```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device camb ``` 也可以运行以下命令使启动容器后开启lmdeploy聊天 ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \     bash -i -c "lmdeploy chat --backend pytorch --device camb internlm/internlm2_5-7b-chat" ``` ================================================ FILE: docs/zh_cn/get_started/get_started.md ================================================ # 快速开始 LMDeploy提供了快速安装、模型量化、离线批处理、在线推理服务等功能。每个功能只需简单的几行代码或者命令就可以完成。 本教程将展示 LMDeploy 在以下几方面的使用方法: - LLM 模型和 VLM 模型的离线推理 - 搭建与 OpenAI 接口兼容的 LLM 或 VLM 模型服务 - 通过控制台命令行与 LLM 模型进行交互式聊天 在继续阅读之前,请确保你已经按照[安装指南](installation.md)安装了 lmdeploy。 ## 离线批处理 ### LLM 推理 ```python import lmdeploy pipe = lmdeploy.pipeline("internlm/internlm2_5-7b-chat") response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` 在构造 `pipeline` 时,如果没有指定使用 TurboMind 引擎或 PyTorch 引擎进行推理,LMDeploy 将根据[它们各自的能力](../supported_models/supported_models.md)自动分配一个,默认优先使用 TurboMind 引擎。 然而,你可以选择手动选择一个引擎。例如, ```python from lmdeploy import pipeline, TurbomindEngineConfig pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=TurbomindEngineConfig( max_batch_size=32, enable_prefix_caching=True, cache_max_entry_count=0.8, session_len=8192, )) ``` 或者, ```python from lmdeploy import pipeline, PytorchEngineConfig pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=PytorchEngineConfig( max_batch_size=32, enable_prefix_caching=True, cache_max_entry_count=0.8, session_len=8192, )) ``` ```{note} 参数 "cache_max_entry_count" 显著影响 GPU 内存占用。它表示加载模型权重后 K/V 
缓存占用的空闲 GPU 内存的比例。 默认值是 0.8。K/V 缓存分配方式是一次性申请,重复性使用,这就是为什么 pipeline 以及下文中的 api_server 在启动后会消耗大量 GPU 内存。 如果你遇到内存不足(OOM)的错误,可能需要考虑降低 cache_max_entry_count 的值。 ``` 当使用 `pipe()` 为提示词生成 token 时,你可以通过 `GenerationConfig` 设置采样参数,如下所示: ```python from lmdeploy import GenerationConfig, pipeline pipe = pipeline('internlm/internlm2_5-7b-chat') prompts = ['Hi, pls intro yourself', 'Shanghai is'] response = pipe(prompts, gen_config=GenerationConfig( max_new_tokens=1024, top_p=0.8, top_k=40, temperature=0.6 )) ``` 在 `GenerationConfig` 中,`top_k=1` 或 `temperature=0.0` 表示贪心搜索。 有关 pipeline 的更多信息,请参考[这里](../llm/pipeline.md) ### VLM 推理 VLM 推理 pipeline 与 LLM 类似,但增加了使用 pipeline 处理图像数据的能力。例如,你可以使用以下代码片段对 InternVL 模型进行推理: ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` 在 VLM pipeline 中,默认的图像处理批量大小是 1。这可以通过 `VisionConfig` 调整。例如,你可以这样设置: ```python from lmdeploy import pipeline, VisionConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B', vision_config=VisionConfig( max_batch_size=8 )) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` 然而,图像批量大小越大,OOM 错误的风险越大,因为 VLM 模型中的 LLM 部分会预先分配大量内存。 VLM pipeline 对于推理引擎的选择方式与 LLM pipeline 类似。你可以参考 [LLM 推理](#llm-推理)并结合两个引擎支持的 VLM 模型列表,手动选择和配置推理引擎。 ## 模型服务 类似前文[离线批量推理](#离线批处理),我们在本章节分别介绍 LLM 和 VLM 模型服务的构建方法。 ### LLM 模型服务 ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat ``` 此命令将在本地主机上的端口 `23333` 启动一个与 OpenAI 接口兼容的模型推理服务。你可以使用 `--server-port` 选项指定不同的服务器端口。 更多选项,请通过运行 `lmdeploy serve api_server --help` 查阅帮助文档。这些选项大多与引擎配置一致。 要访问服务,你可以使用官方的 OpenAI Python 包 `pip install openai`。以下是演示如何使用入口点 v1/chat/completions 的示例: ```python from openai import OpenAI client = OpenAI( 
api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": " provide three suggestions about time management"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` 我们鼓励你参考详细指南,了解关于[使用 Docker 部署服务](../llm/api_server.md)、[工具调用](../llm/api_server_tools.md)和其他更多功能的信息。 ### VLM 模型服务 ```shell lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` ```{note} LMDeploy 复用了上游 VLM 仓库的视觉组件。而每个上游的 VLM 模型,它们的视觉模型可能互不相同,依赖库也各有区别。 因此,LMDeploy 决定不在自身的依赖列表中加入上游 VLM 库的依赖。如果你在使用 LMDeploy 推理 VLM 模型时出现 "ImportError" 的问题,请自行安装相关的依赖。 ``` 服务成功启动后,你可以以类似访问 OpenAI GPT-4V 服务的方式访问 VLM 服务: ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', # A dummy api_key is required base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ## 使用命令行与 LLM 模型对话 LMDeploy 提供了一个非常方便的 CLI 工具,供用户与 LLM 模型进行本地聊天。例如: ```shell lmdeploy chat internlm/internlm2_5-7b-chat --backend turbomind ``` 它的设计目的是帮助用户检查和验证 LMDeploy 是否支持提供的模型,聊天模板是否被正确应用,以及推理结果是否正确。 另外,`lmdeploy check_env` 收集基本的环境信息。在给 LMDeploy 提交问题报告时,这非常重要,因为它有助于我们更有效地诊断和解决问题。 如果你对它们的使用方法有任何疑问,你可以尝试使用 `--help` 选项获取详细信息。 ================================================ FILE: docs/zh_cn/get_started/index.rst ================================================ 其他软硬件平台 ================================= .. 
toctree:: :maxdepth: 1 :caption: OtherPF ascend/get_started.md maca/get_started.md camb/get_started.md ================================================ FILE: docs/zh_cn/get_started/installation.md ================================================ # 安装 LMDeploy 是一个用于大型语言模型(LLMs)和视觉-语言模型(VLMs)压缩、部署和服务的 Python 库。 其核心推理引擎包括 TurboMind 引擎和 PyTorch 引擎。前者由 C++ 和 CUDA 开发,致力于推理性能的优化,而后者完全由 Python 开发,旨在降低开发者的门槛。 LMDeploy 支持在 Linux 和 Windows 平台上部署 LLMs 和 VLMs,最低要求 CUDA 版本为 11.3。此外,它还与以下 NVIDIA GPU 兼容: Volta(sm70): V100 Turing(sm75): 20 系列,T4 Ampere(sm80,sm86): 30 系列,A10, A16, A30, A100 Ada Lovelace(sm89): 40 系列 ## 使用 pip 安装(推荐) 我们推荐在一个干净的conda环境下(python3.9 - 3.13),安装 lmdeploy: ```shell conda create -n lmdeploy python=3.10 -y conda activate lmdeploy pip install lmdeploy ``` 默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy: ```shell export LMDEPLOY_VERSION=0.12.2 export PYTHON_VERSION=310 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` ## 从源码安装 默认情况下,LMDeploy 将面向 NVIDIA CUDA 环境进行编译安装,并同时启用 Turbomind 和 PyTorch 两种后端引擎。在安装 LMDeploy 之前,请确保已成功安装 CUDA 工具包。 成功安装 CUDA 工具包后,您可以使用以下单行命令构建并安装 LMDeploy: ```shell pip install git+https://github.com/InternLM/lmdeploy.git ``` 您还可以通过设置 `DISABLE_TURBOMIND` 环境变量,显式禁用 Turbomind 后端,以避免 CUDA 编译: ```shell DISABLE_TURBOMIND=1 pip install git+https://github.com/InternLM/lmdeploy.git ``` 如果您希望使用特定版本,而不是 LMDeploy 的 `main` 分支,可以在命令行中指定: ```shell pip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.11.0.zip ``` 如果您希望构建支持昇腾、寒武纪或沐曦的 LMDeploy,请使用相应的 `LMDEPLOY_TARGET_DEVICE` 环境变量进行安装。 LMDeploy 也支持在 AMD GPU 的 ROCm 环境中安装。 ```shell #The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies: docker run -it \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ --device=/dev/kfd \ 
--device=/dev/dri \ --group-add video \ --ipc=host \ --network=host \ --shm-size 32G \ -v /root:/workspace \ rocm/pytorch:latest #Once inside the container, install LMDeploy with ROCm support: LMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git ``` ================================================ FILE: docs/zh_cn/get_started/maca/get_started.md ================================================ # 沐曦C500 我们基于 LMDeploy 的 PytorchEngine,增加了沐曦C500设备的支持。所以,在沐曦上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前,请先阅读原版的[快速开始](../get_started.md)。 支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台). > \[!IMPORTANT\] > 我们已经在阿里云上提供了构建完成的沐曦的镜像。 > 请使用下面的命令来拉取镜像: > `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest` ## 离线批处理 ### LLM 推理 将`device_type="maca"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline from lmdeploy import PytorchEngineConfig pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=PytorchEngineConfig(tp=1, device_type="maca")) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question) print(response) ``` ### VLM 推理 将`device_type="maca"`加入`PytorchEngineConfig`的参数中。 ```python from lmdeploy import pipeline, PytorchEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-2B', backend_config=PytorchEngineConfig(tp=1, device_type='maca')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## 在线服务 ### LLM 模型服务 将`--device maca`加入到服务启动命令中。 ```bash lmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat ``` 也可以运行以下命令启动容器运行LLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device 
maca internlm/internlm2_5-7b-chat" ``` ### VLM 模型服务 将`--device maca`加入到服务启动命令中。 ```bash lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B ``` 也可以运行以下命令启动容器运行VLM模型服务。 ```bash docker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B" ``` ## 使用命令行与LLM模型对话 将`--device maca`加入到服务启动命令中。 ```bash lmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device maca ``` 也可以运行以下命令使启动容器后开启lmdeploy聊天 ```bash docker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \     bash -i -c "lmdeploy chat --backend pytorch --device maca internlm/internlm2_5-7b-chat" ``` ================================================ FILE: docs/zh_cn/index.rst ================================================ 欢迎来到 LMDeploy 的中文教程! ==================================== .. figure:: ./_static/image/lmdeploy-logo.svg :width: 50% :align: center :alt: LMDeploy :class: no-scaled-link .. raw:: html

LMDeploy 是一个高效且友好的 LLMs 模型部署工具箱,功能涵盖了量化、推理和服务

Star Watch Fork

LMDeploy 工具箱提供以下核心功能: - **高效的推理:** LMDeploy 开发了 Persistent Batch(即 Continuous Batch),Blocked K/V Cache,动态拆分和融合,张量并行,高效的计算 kernel等重要特性。推理性能是 vLLM 的 1.8 倍 - **可靠的量化:** LMDeploy 支持权重量化和 k/v 量化。4bit 模型推理效率是 FP16 下的 2.4 倍。量化模型的可靠性已通过 OpenCompass 评测得到充分验证。 - **便捷的服务:** 通过请求分发服务,LMDeploy 支持多模型在多机、多卡上的推理服务。 - **卓越的兼容性:** LMDeploy 支持 `KV Cache 量化 `_, `AWQ `_ 和 `Automatic Prefix Caching `_ 同时使用。 中文文档 -------- .. _快速上手: .. toctree:: :maxdepth: 2 :caption: 快速上手 get_started/installation.md get_started/get_started.md get_started/index.rst .. _支持的模型: .. toctree:: :maxdepth: 1 :caption: 模型列表 supported_models/supported_models.md supported_models/reward_models.md .. _llm_部署: .. toctree:: :maxdepth: 1 :caption: 大语言模型(LLMs)部署 llm/pipeline.md llm/api_server.md llm/api_server_tools.md llm/api_server_reasoning.md llm/api_server_lora.md llm/proxy_server.md .. _vlm_部署: .. toctree:: :maxdepth: 1 :caption: 视觉-语言模型(VLMs)部署 multi_modal/vl_pipeline.md multi_modal/api_server_vl.md multi_modal/index.rst .. _量化: .. toctree:: :maxdepth: 1 :caption: 量化 quantization/w4a16.md quantization/w8a8.md quantization/kv_quant.md quantization/llm_compressor.md .. _测试基准: .. toctree:: :maxdepth: 1 :caption: 测试基准 benchmark/benchmark.md benchmark/evaluate_with_opencompass.md benchmark/evaluate_with_vlmevalkit.md .. toctree:: :maxdepth: 1 :caption: 进阶指南 inference/turbomind.md inference/pytorch.md advance/pytorch_new_model.md advance/long_context.md advance/chat_template.md advance/debug_turbomind.md advance/structed_output.md advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md advance/context_parallel.md advance/spec_decoding.md advance/update_weights.md .. 
toctree:: :maxdepth: 1 :caption: API 文档 api/pipeline.rst api/openapi.rst api/cli.rst 索引与表格 ================== * :ref:`genindex` * :ref:`search` * :ref:`routingtable` ================================================ FILE: docs/zh_cn/inference/load_hf.md ================================================ # 直接读取 huggingface 模型 从 v0.1.0 开始,TurboMind 添加了直接读取 Huggingface 格式权重的能力。 ## 支持的类型 目前,TurboMind 支持加载两种类型的模型: 1. 在 huggingface.co 上面通过 lmdeploy 量化的模型,如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit) 2. huggingface.co 上面其他 LM 模型,如Qwen/Qwen-7B-Chat ## 使用方式 ### 1) 通过 lmdeploy 量化的模型 对于通过 `lmdeploy.lite` 量化的模型,TurboMind 可以直接加载,比如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit). ``` repo_id=internlm/internlm-chat-20b-4bit model_name=internlm-chat-20b # or # repo_id=/path/to/downloaded_model # Inference by TurboMind lmdeploy chat $repo_id --model-name $model_name # Serving with Restful API lmdeploy serve api_server $repo_id --model-name $model_name --tp 1 ``` ### 2) 其他的 LM 模型 其他 LM 模型比如 Qwen/Qwen-7B-Chat, baichuan-inc/Baichuan2-7B-Chat。LMDeploy 模型支持情况可通过 `lmdeploy list` 查看。 ``` repo_id=Qwen/Qwen-7B-Chat model_name=qwen-7b # or # repo_id=/path/to/Qwen-7B-Chat/local_path # Inference by TurboMind lmdeploy chat $repo_id --model-name $model_name # Serving with Restful API lmdeploy serve api_server $repo_id --model-name $model_name --tp 1 ``` ================================================ FILE: docs/zh_cn/inference/pytorch.md ================================================ # lmdeploy.pytorch 架构 `lmdeploy.pytorch` 是 LMDeploy 提供的推理后端之一。与着重于性能的 turbomind 相比,lmdeploy.pytorch 以较小的性能开销为代价,提供了一套更容易开发与扩展的大模型推理实现。 ## 设计 ![pytorch arch](https://github.com/grimoire/lmdeploy/blob/media/lmdeploy_pytorch_arch.png?raw=true) ## API lmdeploy.pytorch 可以与 turbomind 共享同样的服务接口,这些服务接口通过 Engine 
与 EngineInstance 与 lmdeploy.pytorch 进行交互。 EngineInstance 是推理请求的发起者,它会将推理请求组织成特定格式发送给 Engine,以此实现流式推理。EngineInstance 的推理接口是线程安全的,服务发起者可以在不同线程中启动各自的 EngineInstance,Engine 会根据当前资源与推理请求自动进行 batch 化处理。 Engine 是推理请求的接收与执行者。它包含如下的组件来完成这项任务: - ModelAgent 对象负责模型的加载、缓存管理以及 tensor parallelism 的管理。 - Scheduler 对象负责 session 的管理,sequence 与 lora adapter 所需要的资源的分配。 - RequestManager 负责请求的发送与接收,可以通过它与 EngineInstance 交互。 ## Engine 为了应对异步推理请求,Engine 在启动后会维护一个线程,循环如下操作: 1. 通过 RequestManager 读取请求,对各种请求进行分类处理。 2. Scheduler 规划哪些请求可以被处理,以及它们所需的缓存和 adapters。 3. ModelAgent 根据步骤 2. 得到的信息为输入分配资源,然后使用 patch 后的模型进行推理 4. Scheduler 根据推理结果更新请求状态 5. RequestManager 将输出返回给发送者(EngineInstance),回到步骤 1. 下面我们将介绍上述步骤中用到的几个重要组件 ### Scheduler 在进行大模型的推理时,通常会把 attention 的历史输入 key 和 value 缓存起来,以避免在未来的推理中进行重复计算。这种情况下如果要进行多 batch 的推理,由于不同数据的序列长度可能不同,kv 会进行大量的填充,浪费很多显存资源,也限制了模型的并发推理能力上限。 [vLLM](https://docs.vllm.ai) 提出了一种 paging 策略,以 page block 为单位为 key value 分配缓存,这样就可以避免由于 padding 导致的显存浪费。 lmdeploy.pytorch 中的 Scheduler 也遵循同样的设计,根据请求的长度合理分配所需的资源,并撤出暂时不使用的资源以保证存储资源的高效利用。 lmdeploy.pytorch 还添加了对 [S-LoRA](https://github.com/S-LoRA/S-LoRA) 的支持,S-LoRA 是一种对单模型多 adapter 的支持方案。LoRA 在推理时通常会把 adapter 融合进模型权重当中,同时使用多个 adapter 会导致显存使用量的激增;S-LoRA 不对 adapter 进行融合,通过使用 unified paging,在推理时动态换入需要使用的 adapter,大幅降低了使用 adapter 的显存开销。Scheduler 中也实现了相关的功能,让用户可以更方便地使用自己的 adapter. ### ModelAgent lmdeploy.pytorch 中对 Tensor Parallelism(TP)进行了支持,不同的 TP 参数对模型的构造、权重处理、分配 cache 都存在影响。ModelAgent 对这些内容进行了封装,让 Engine 不用再关心这部分细节。 ModelAgent 有两个重要组件: 1. patched_model 是更新后的 transformer 模型,更新后的模型添加了各种功能的支持,包括更高性能的子模块实现、TP、量化等等 2. 
cache_engine 是缓存的分配与交换模块。它接收来自 scheduler 的交换请求,执行 host-device 间显存交换,adapter 加载等工作 ## 特性 - **Continuous Batching**: 由于输入序列的长度不一样,batching 通常需要对输入进行 padding,这种 padding 会导致后续运算的计算量增加、影响速度,也会使得显存的占用大幅增加。遵循许多其他成熟框架的方案,lmdeploy.pytorch 采用了 continuous batching 的方式对输入做了连续化处理,避免了多余的资源占用。 - **Tensor Parallelism**: 大模型可能会占用远超一张显卡的显存量,为了支持这样的大模型的推理,我们实现了 Tensor 并发,模型的权重会被分布在不同的设备中,每张 GPU 设备负责一部分计算,减少了单卡显存占用,也充分利用了多显卡的计算优势。 - **S-LoRA**: LoRA adapter 可以帮助我们使用有限的显存来调优大模型,S-LoRA 可以帮助我们在有限的显存中同时使用复数个 LoRA 权重,扩展模型的能力。 - **Quantization**: 量化可以帮助我们进一步减少显存占用,提高推理性能。lmdeploy.pytorch 分支中添加了 w8a8 模型量化的支持,可以阅读 [w8a8](../quantization/w8a8.md) 了解更多细节。 ================================================ FILE: docs/zh_cn/inference/turbomind.md ================================================ # TurboMind 框架 TurboMind 是一款关于 LLM 推理的高效推理引擎,基于英伟达的 [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) 研发而成。它的主要功能包括:LLaMa 结构模型的支持,persistent batch 推理模式和可扩展的 KV 缓存管理器。 ## TurboMind 结构 ``` +--------------------+ | API | +--------------------+ | ^ 请 求 | | 流式回调 v | +--------------------+ 获取 +-------------------+ | Persistent Batch | <-------> | KV Cache 管理器 | +--------------------+ 更新 +-------------------+ ^ | v +------------------------+ | LLaMa推理实现 | +------------------------+ | FT kernels & utilities | +------------------------+ ``` ## Persistent Batch 你也许在别的项目中看到这项机制的另一个名字: `continuous batching` 。在开发这个功能时,我们将对话式 LLM 的推理建模为一个持续运行的 batch ,其生命周期跨越整个服务过程,故将其命名为 `persistent batch` 。简单来说是这样实现的: - 该功能会预先准备好 N 个 batch slots。 - 当有空闲 slots 时, 请求就会加入到 batch 中。当请求对应的 tokens 都生成完毕后,对应的 batch slot 会立刻被释放,接收新的请求。 - **当一个 sequence 命中缓存时(见下文),它的历史 token 不必在每轮中都进行解码,所以它的 token 生成过程会即刻开始**。 - 整个 batch 会自动扩缩容来避免不必要的计算。 ## KV 缓存管理器 TurboMind 的 [KV 缓存管理器](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/SequenceManager.h) 是一个内存池类型的对象,并且在其中加入了 LRU 的实现,这样整个管理器可以被看作是一个 **KV 缓存的缓存**。大致工作方式如下: - KV 缓存由管理器分配。管理器会根据预先配置好的 slot 数量开辟空间。每个 slot 对应于一个 sequence 所需的 KV 
缓存。分配的内存块大小可通过配置来实现预分配或者按需分配(或介于两者之间)。 - 当有新的请求,但是缓存池中没有空闲 slot时,根据 LRU 机制,管理器会剔除最近最少使用的 sequence,把它占据的 slot 分给新的请求。此外: - sequence 获取到 slot 时,类似于缓存命中。它在缓存中的历史KV会被直接返回,而无需再进行 context decoding。 - 被剔除的 sequences 不会被完全删除,而是会被转换成最简洁的形式,例如 token IDs 。当之后获取到相同的 sequence id 时 (即 _cache-miss_ 状态),这些 token IDs 将被 FMHA 的 context decoder 解码并被转回 KV 缓存。 - 剔除和转换均由 TurboMind 内部自动管理,因此对用户来说是透明的。__从用户的使用角度来看,使用了 TurboMind 的系统就像是可以访问无限的设备内存__。 ## TurboMind 的 LLaMa 实现 我们对 LLaMa 系列模型的实现是从 FasterTransformer 中的 GPT-NeoX 模型修改而来的。除了对 LLaMa 系列进行基本重构和修改外,我们还做了一些改进以实现会话模型的高性能推理,其中最重要的是: - 支持多轮对话中的快速文本解码。我们用基于 [cutlass](https://github.com/NVIDIA/cutlass) 的 FMHA 实现替代了 context decoder 中的注意力机制实现,从而支持了 Q/K 长度不匹配的情况。 - 我们在 context FMHA 和 generation FMHA 中都加入了间接缓冲指针,支持 batch 中不连续的 KV 缓存。 - 为了支持 persistent batch 的并发推理,我们设计了新的同步机制来协调在张量并行模式下的工作线程。 - 我们实现了 INT8 KV cache,降低了内存开销,提高了批处理大小和系统吞吐量。这在实际场景中非常有用,因为相比权重和其他激活,KV cache 会消耗更多的内存和内存带宽。 - 我们解决了单个进程内多个模型实例在 TP 模式下运行时 NCCL 卡住的问题。NCCL APIs 现由 host 端的同步 barriers 保护。 ## API TurboMind 的 Python API 支持流式结果返回和张量并行模式。 ## TurboMind 和 FasterTransformer 的区别 除了上文中提到的功能外,TurboMind 相较于 FasterTransformer 还有不少差别。譬如不少 FasterTransformer 的功能在 TurboMind 中都被去掉了,这其中包括前缀提示词、 beam search 、上下文 embedding、稀疏化 GEMM 操作和对应 GPT 或 T5 等结构的模型的支持等等。 ## FAQ ### 对 Huggingface 模型的支持 因为历史因素, TurboMind 的权重设计是基于 [LLaMa 的官方实现](https://github.com/facebookresearch/llama) 完成的,两者只相差一个转置操作。但是 Huggingface 版本的实现却是[另一种形式](https://github.com/huggingface/transformers/blob/45025d92f815675e483f32812caa28cce3a960e7/src/transformers/models/llama/convert_llama_weights_to_hf.py#L123C76-L123C76),两种权重实现方式在 `W_q` 和 `W_k` 上的区别我们在 [deploy.py](https://github.com/InternLM/lmdeploy/blob/ff4648a1d09e5aec74cf70efef35bfaeeac552e0/lmdeploy/serve/turbomind/deploy.py#L398) 进行了适配处理,用户可前往查看。 ================================================ FILE: docs/zh_cn/inference/turbomind_config.md ================================================ # TurboMind 配置 TurboMind 是 LMDeploy 的推理引擎,在用它推理 LLM 模型时,需要把输入模型转成 TurboMind 模型。在 TurboMind 
的模型文件夹中,除模型权重外,TurboMind 模型还包括其他一些文件,其中最重要的是和推理性能息息相关的配置文件`triton_models/weights/config.ini`。 如果你使用的是 LMDeploy 0.0.x 版本,请参考[turbomind 1.0 配置](#turbomind-10-配置)章节,了解配置中的相关内容。如果使用的是 LMDeploy 0.1.x 版本,请阅读[turbomind 2.x 配置](#turbomind-2x-配置)了解配置细节。 ## TurboMind 2.x 配置 以 `llama-2-7b-chat` 模型为例,在 TurboMind 2.x 中,它的`config.ini`内容如下: ```toml [llama] model_name = "llama2" tensor_para_size = 1 head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 session_len = 4104 weight_type = "fp16" rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 group_size = 0 max_batch_size = 64 max_context_token_num = 1 step_length = 1 cache_max_entry_count = 0.5 cache_block_seq_len = 128 cache_chunk_size = 1 enable_prefix_caching = false quant_policy = 0 max_position_embeddings = 2048 rope_scaling_factor = 0.0 use_logn_attn = 0 ``` 这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等,它们**不可修改** ```toml model_name = "llama2" head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 ``` 和 TurboMind 1.0 config 相比,TurboMind 2.x config 中的模型属性部分和 1.0 一致,但推理参数发生了变化。 在接下来的章节中,我们重点介绍推理参数。 ### 数据类型 和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。 `weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时,`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前,在 LMDeploy 的预编译包中,使用的是 `group_size = 128`。 ### 批处理大小 仍通过 `max_batch_size` 设置最大批处理量。默认值由原来的 32 改成 64。 在 TurboMind 2.x 中,`max_batch_size` 和 `cache_max_entry_count`无关。 ### k/v 缓存大小 `cache_block_seq_len` 和 `cache_max_entry_count` 用来调节 k/v cache 的内存大小。 TurboMind 2.x 实现了 Paged Attention,按块管理 k/v cache。 `cache_block_seq_len` 表示一块 k/v block 可以存放的 token 序列长度,默认 128。TurboMind 按照以下公式计算 k/v block 的内存大小: ``` cache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type) ``` 对于 llama2-7b 模型来说,以 half 
类型存放 k/v 时,一块 k/v block 的内存为:`128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB` `cache_max_entry_count` 根据取值不同,表示不同的含义: - 当值为 (0, 1) 之间的小数时,`cache_max_entry_count` 表示 k/v block 使用的内存百分比。比如 A100-80G 显卡内存是80G,当`cache_max_entry_count`为0.5时,表示 k/v block 使用的内存总量为 80 * 0.5 = 40G - 当 lmdeploy 版本大于 0.2.1 时,`cache_max_entry_count` 将**空闲**内存的百分比用于 k/v blocks,默认值为 `0.8`。例如,在 A100-80G GPU 上运行 Turbomind 加载 13b 模型时,k/v blocks 使用的内存为 `(80 - 26) * 0.8 = 43.2G`,即利用剩余 54G 中的 80% - 当值为 > 1 的整数时,表示 k/v block 数量 `cache_chunk_size` 表示在每次需要新的 k/v cache 块时,开辟 k/v cache 块的大小。不同的取值,表示不同的含义: - 当值为 > 0 的整数时,开辟 `cache_chunk_size` 个 k/v cache 块 - 当值为 -1 时,开辟 `cache_max_entry_count` 个 k/v cache 块 - 当值为 0 时,开辟 `sqrt(cache_max_entry_count)` 个 k/v cache 块 ### 前缀缓存开关 `enable_prefix_caching`是前缀缓存(Prefix Caching)功能的开关。值为`True`时表示开启,`False`表示关闭,默认为`False`。 前缀缓存功能主要适用于多个请求具有相同的prompt前缀(比如system prompt)的场景,该相同前缀部分的 k/v block 会被缓存起来,被多个请求重复利用,从而节省了重复计算的开销,提高推理性能。相同prompt前缀长度越长,性能提升越大。 由于前缀缓存对 k/v 重复利用的最小粒度是block,如果相同prompt前缀不足一个block(前缀长度\<`cache_block_seq_len`),则推理性能不会有提升。 ### kv 量化推理开关 `quant_policy`是 kv 量化和推理开关。 - `quant_policy=4` 代表 4bit k/v 量化和推理 - `quant_policy=8` 代表 8bit k/v 量化和推理 具体使用方法,请参考 [kv quant](../quantization/kv_quant.md) 部署文档 ### 外推能力开关 默认 `rope_scaling_factor = 0` 不具备外推能力。设置为 1.0,可以开启 RoPE 的 Dynamic NTK 功能,支持长文本推理。 关于 Dynamic NTK 的原理,详细请参考: 1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases 2. 
https://kexue.fm/archives/9675 设置 `use_logn_attn = 1`,可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。 ## TurboMind 1.0 配置 以 `llama-2-7b-chat` 模型为例,在 TurboMind 1.0 中,它的`config.ini`内容如下: ```toml [llama] model_name = "llama2" tensor_para_size = 1 head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 session_len = 4104 weight_type = "fp16" rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 group_size = 0 max_batch_size = 32 max_context_token_num = 4 step_length = 1 cache_max_entry_count = 48 cache_chunk_size = 1 use_context_fmha = 1 quant_policy = 0 max_position_embeddings = 2048 use_dynamic_ntk = 0 use_logn_attn = 0 ``` 这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等,它们**不可修改** ```toml model_name = "llama2" head_num = 32 kv_head_num = 32 vocab_size = 32000 num_layer = 32 inter_size = 11008 norm_eps = 1e-06 attn_bias = 0 start_id = 1 end_id = 2 rotary_embedding = 128 rope_theta = 10000.0 size_per_head = 128 ``` 在接下来的章节中,我们重点介绍推理参数。 ### 数据类型 和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。 `weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时,`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前,在 LMDeploy 的预编译包中,使用的是 `group_size = 128`。 ### 批处理大小 可通过`max_batch_size`调节推理时最大的 batch 数。一般,batch 越大吞吐量越高。但务必保证 `max_batch_size <= cache_max_entry_count` ### k/v cache 大小 TurboMind 根据 `session_len`、 `cache_chunk_size` 和 `cache_max_entry_count` 开辟 k/v cache 内存。 - `session_len` 表示一个序列的最大长度,即 context window 的大小。 - `cache_chunk_size` 表示当新增对话序列时,每次要开辟多少个序列的 k/v cache - `cache_max_entry_count` 表示最多缓存多少个对话序列 ### kv int8 开关 当启动 8bit k/v 推理时,需要修改参数 `quant_policy` 和 `use_context_fmha`。详细内容请查阅 [kv int8](../quantization/kv_quant.md) 部署文档。 ### 外推能力开关 设置 `use_dynamic_ntk = 1`,可以开启 RoPE 的 Dynamic NTK 选项,支持长文本推理。 关于 Dynamic NTK 的原理,详细请参考: 1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases 2. 
https://kexue.fm/archives/9675 设置 `use_logn_attn = 1`,可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。 ================================================ FILE: docs/zh_cn/llm/api_server.md ================================================ # 部署 LLM 类 openai 服务 本文主要介绍单个模型在单机多卡环境下,部署兼容 openai 接口服务的方式,以及服务接口的用法。为行文方便,我们把该服务称为 `api_server`。对于多模型的并行服务,请阅读[请求分发服务器](./proxy_server.md)一文。 在这篇文章中, 我们首先介绍服务启动的三种方法,你可以根据应用场景选择合适的方式。 其次,我们重点介绍服务的 RESTful API 定义,以及接口使用的方式,并展示如何通过 Swagger UI、LMDeploy CLI 工具体验服务功能 ## 启动服务 以 huggingface hub 上的 [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) 模型为例,你可以任选以下方式之一,启动推理服务。 ### 方式一:使用 lmdeploy cli 工具 ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 ``` api_server 启动时的参数可以通过命令行`lmdeploy serve api_server -h`查看。 比如,`--tp` 设置张量并行,`--session-len` 设置推理的最大上下文窗口长度,`--cache-max-entry-count` 调整 k/v cache 的内存使用比例等等。 ### 方式二:使用 docker 使用 LMDeploy 官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags),可以运行兼容 OpenAI 的服务。下面是使用示例: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=<secret>" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server internlm/internlm2_5-7b-chat ``` 在这个例子中,`lmdeploy serve api_server` 的命令参数与方式一一致。 每个模型可能需要 Docker 镜像中未包含的特定依赖项。如果遇到问题,您可能需要根据具体情况自行安装这些依赖项。如有疑问,请参阅特定模型的项目以获取文档。 例如,对于 Llava ``` FROM openmmlab/lmdeploy:latest RUN apt-get update && apt-get install -y python3 python3-pip git WORKDIR /app RUN pip3 install --upgrade pip RUN pip3 install timm RUN pip3 install git+https://github.com/haotian-liu/LLaVA.git --no-deps COPY . . 
CMD ["lmdeploy", "serve", "api_server", "liuhaotian/llava-v1.6-34b"] ``` ### 方式三:部署到Kubernetes集群 使用[kubectl](https://kubernetes.io/docs/reference/kubectl/)命令行工具,连接到一个运行中的Kubernetes集群并部署internlm2_5-7b-chat模型服务。下面是使用示例(需要将`<your token>`替换为你的huggingface hub token): ```shell sed 's/{{HUGGING_FACE_HUB_TOKEN}}/<your token>/' k8s/deployment.yaml | kubectl create -f - \ && kubectl create -f k8s/service.yaml ``` 示例中模型数据来源于node上的本地磁盘(hostPath),多副本部署时考虑替换为高可用共享存储,通过[PersistentVolume](https://kubernetes.io/docs/concepts/storage/persistent-volumes/)方式挂载到容器中。 ## RESTful API LMDeploy 的 RESTful API 兼容了 OpenAI 以下 3 个接口: - /v1/chat/completions - /v1/models - /v1/completions 服务启动后,你可以在浏览器中打开网页 http://0.0.0.0:23333,通过 Swagger UI 查看接口的详细说明,并且也可以直接在网页上操作,体验每个接口的用法,如下图所示。 ![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) 若需要把服务集成到自己的项目或者产品中,我们推荐以下用法: ### 使用 openai 接口 以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前,请先安装 openai 包: `pip install openai`。 ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": " provide three suggestions about time management"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` 如果你想使用异步的接口,可以尝试下面的例子: ```python import asyncio from openai import AsyncOpenAI async def main(): client = AsyncOpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_cards = await client.models.list()._get_page() response = await client.chat.completions.create( model=model_cards.data[0].id, messages=[ { 'role': 'system', 'content': 'You are a helpful assistant.' 
}, { 'role': 'user', 'content': ' provide three suggestions about time management' }, ], temperature=0.8, top_p=0.8) print(response) asyncio.run(main()) ``` 关于其他 openai 接口的调用,也可以如法炮制。详情请参考 openai 官方[文档](https://platform.openai.com/docs/guides/text-generation) ### 使用 lmdeploy `APIClient` 接口 如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient(f'http://{server_ip}:{server_port}') model_name = api_client.available_models[0] messages = [{"role": "user", "content": "Say this is a test!"}] for item in api_client.chat_completions_v1(model=model_name, messages=messages): print(item) ``` 如果你想用 `/v1/completions` 接口,你可以尝试: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient(f'http://{server_ip}:{server_port}') model_name = api_client.available_models[0] for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` ### 工具调用 参考 [api_server_tools](./api_server_tools.md)。 ### 使用 Java/Golang/Rust 可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。 下面是一个使用示例: ```shell $ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust $ ls rust/* rust/Cargo.toml rust/git_push.sh rust/README.md rust/docs: ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md rust/src: apis lib.rs models ``` ### 使用 cURL cURL 也可以用于查看 API 的输出结果 - 查看模型列表 `v1/models` ```bash curl http://{server_ip}:{server_port}/v1/models ``` - 对话 `v1/chat/completions` ```bash curl http://{server_ip}:{server_port}/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "internlm-chat-7b", "messages": [{"role": "user", "content": "Hello! 
How are you?"}] }' ``` - 文本补全 `v1/completions` ```shell curl http://{server_ip}:{server_port}/v1/completions \ -H 'Content-Type: application/json' \ -d '{ "model": "llama", "prompt": "two steps to build a house:" }' ``` ## 同时启动多个 api_server 两步直接启动多机多卡服务。先用下面的代码创建一个启动脚本。然后: 1. 启动代理服务 `lmdeploy serve proxy`。 2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**: 多机多卡不要用默认 url `0.0.0.0:8000`,我们需要输入真实ip对应的地址,如:`11.25.34.55:8000`。多机情况下,因为不需要子节点间的通信,所以并不需要用户指定 torchrun 的 `--nnodes` 等参数,只要能保证每个节点执行一次单节点的 torchrun 就行。 ```python import os import socket from typing import List, Literal import fire def get_host_ip(): try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect(('8.8.8.8', 80)) ip = s.getsockname()[0] finally: s.close() return ip def main(model_path: str, tp: int = 1, proxy_url: str = 'http://0.0.0.0:8000', port: int = 23333, backend: Literal['turbomind', 'pytorch'] = 'turbomind'): local_rank = int(os.environ.get('LOCAL_RANK', -1)) world_size = int(os.environ.get('WORLD_SIZE', -1)) local_ip = get_host_ip() if isinstance(port, List): assert len(port) == world_size port = port[local_rank] else: port += local_rank * 10 if (world_size - local_rank) % tp == 0: rank_list = ','.join([str(local_rank + i) for i in range(tp)]) command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\ f'--server-name {local_ip} --server-port {port} --tp {tp} '\ f'--proxy-url {proxy_url} --backend {backend}' print(f'running command: {command}') os.system(command) if __name__ == '__main__': fire.Fire(main) ``` ### 示例 为了进一步展示如何在集群环境中使用多机多卡服务。下面提供一个在火山云的用例: ```shell #!/bin/bash # 激活 conda 环境 source /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env export HOME=/path/to/your/home # 获取主节点IP地址(假设 MLP_WORKER_0_HOST 是主节点的IP) MASTER_IP=${MLP_WORKER_0_HOST} # 检查是否为主节点 if [ "${MLP_ROLE_INDEX}" -eq 0 ]; then # 启动 lmdeploy serve proxy 
并放入后台 echo "Starting lmdeploy serve proxy on master node..." PROXY_PORT=8000 lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} & else # 这里我们默认调度平台同时启动了所有机器,否则要sleep一会,等待 proxy 启动成功 echo "Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}." fi # 启动 torchrun 并放入后台 # 再次强调多机环境下并不需要传--nnodes 或者 --master-addr 等参数,相当于每个机器上执行一次单节点的 torchrun 即可。 torchrun \ --nproc_per_node=${MLP_WORKER_GPU} \ /path/to/script.py \ InternLM/internlm2-chat-1_8b 8 http://${MASTER_IP}:${PROXY_PORT} # 打印主机的IP地址 echo "Host IP addresses:" hostname -I ``` ## FAQ 1. 当返回结果结束原因为 `"finish_reason":"length"`,这表示会话长度超过最大值。如需调整会话支持的最大长度,可以通过启动`api_server`时,设置`--session-len`参数大小。 2. 当服务端显存 OOM 时,可以适当减小启动服务时 `--cache-max-entry-count` 参数(即 `backend_config` 中的 `cache_max_entry_count`)的大小 3. 关于停止符,我们只支持编码后为单个 index 的字符。此外,可能存在多种 index 都会解码出带有停止符的结果。对于这种情况,如果这些 index 数量太多,我们只会采用 tokenizer 编码出的 index。而如果你想要编码后为多个 index 的停止符,可以考虑在流式客户端做字符串匹配,匹配成功后跳出流式循环即可。 4. 自定义对话模板,请参考[chat_template.md](../advance/chat_template.md) ================================================ FILE: docs/zh_cn/llm/api_server_lora.md ================================================ # LoRA 推理服务 ## 启动 LoRA 服务 LoRA 目前只有 pytorch 后端支持。它的服务化,和其他模型服务化一样,命令都可以用 `lmdeploy serve api_server -h` 查看。其中 pytorch 后端支持的参数就有 LoRA 的配置内容。 ``` PyTorch engine arguments: --adapters [ADAPTERS [ADAPTERS ...]] Used to set path(s) of lora adapter(s). One can input key-value pairs in xxx=yyy format for multiple lora adapters. If only have one adapter, one can only input the path of the adapter.. Default: None. 
Type: str ``` 用户只需要将 lora 权重的 huggingface 模型路径以 `名称=路径` 的键值对形式传入 `--adapters` 即可。 ```shell lmdeploy serve api_server THUDM/chatglm2-6b --adapters mylora=chenchi/lora-chatglm2-6b-guodegang ``` 服务启动后,可以在 Swagger UI 中查询到两个可用的模型名字:“THUDM/chatglm2-6b” 和 “mylora”。后者是 `--adapters` 键值对中的 key。 ## 客户端使用 ### CLI 使用时,OpenAI 接口参数 `model` 可以用来选择使用基础模型还是某个 lora 权重用于推理。下面的例子就选择使用了传入的 `chenchi/lora-chatglm2-6b-guodegang` 用于推理。 ```shell curl -X 'POST' \ 'http://localhost:23333/v1/chat/completions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "model": "mylora", "messages": [ { "content": "hi", "role": "user" } ] }' ``` 可以得到该 lora 权重特有的回复: ```json { "id": "2", "object": "chat.completion", "created": 1721377275, "model": "mylora", "choices": [ { "index": 0, "message": { "role": "assistant", "content": " 很高兴哪有什么赶凳儿?(按东北语说的“起早哇”),哦,东北人都学会外语了?", "tool_calls": null }, "logprobs": null, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 17, "total_tokens": 43, "completion_tokens": 26 } } ``` ### python ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = 'mylora' response = client.chat.completions.create( model=model_name, messages=[ {"role": "user", "content": "hi"}, ], temperature=0.8, top_p=0.8 ) print(response) ``` 打印的响应内容为: ``` ChatCompletion(id='4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' 很高兴能够见到你哪,我也在辐射区开了个愣儿,你呢,还活着。', role='assistant', function_call=None, tool_calls=None))], created=1721377497, model='mylora', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=22, prompt_tokens=17, total_tokens=39)) ``` ================================================ FILE: docs/zh_cn/llm/api_server_reasoning.md ================================================ # Reasoning Outputs 对于支持推理能力的模型,比如 [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1),LMDeploy 
支持在服务中将推理的结果解析出来,并单独用 reasoning_content 记录推理内容。 ## 使用示例 ### DeepSeek R1 我们可以像启动其他模型的 api_server 服务一样启动 DeepSeek R1 的模型,只是不同的是,我们需要指定 `--reasoning-parser`。 在 `--reasoning-parser` 传参里,我们需要指定具体的 parser。 ``` lmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1 ``` 然后,我们就可以在客户端调用这个服务的功能: ```python from openai import OpenAI openai_api_key = "Your API key" openai_api_base = "http://0.0.0.0:23333/v1" client = OpenAI( api_key=openai_api_key, base_url=openai_api_base, ) models = client.models.list() model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] response = client.chat.completions.create(model=model, messages=messages, stream=True) for stream_response in response: print('reasoning content: ',stream_response.choices[0].delta.reasoning_content) print('content: ', stream_response.choices[0].delta.content) response = client.chat.completions.create(model=model, messages=messages, stream=False) reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content print("reasoning_content:", reasoning_content) print("content:", content) ``` ## 自定义 parser 只需要在 `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py` 中添加一个类似的 parser 类即可。 ```python # import the required packages from typing import Sequence, Union, Tuple, Optional from lmdeploy.serve.openai.reasoning_parser import ( ReasoningParser, ReasoningParserManager) from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaMessage) # define a reasoning parser and register it to lmdeploy # the name list in register_module can be used # in --reasoning-parser. 
@ReasoningParserManager.register_module(["example"]) class ExampleParser(ReasoningParser): def __init__(self, tokenizer: object): super().__init__(tokenizer) def extract_reasoning_content_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting reasoning from an incomplete response; for use when handling reasoning calls and streaming. Has to be an instance method because it requires state - the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> Tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. Used for non-streaming responses where we have the entire model response available before sending to the client. Args: model_output (str): The model-generated string to extract reasoning content from. request (ChatCompletionRequest): the request object that was used to generate the model_output. Returns: reasoning_content (str | None): The reasoning content. final_output (str | None): The content. """ ``` 类似地,启动服务的命令就变成了: ``` lmdeploy serve api_server $model_path --reasoning-parser example ``` ================================================ FILE: docs/zh_cn/llm/api_server_tools.md ================================================ # Tools LMDeploy 支持 InternLM2, InternLM2.5, Llama3.1 和 Qwen2.5 模型的工具调用。请在启动 api_server 的时候使用 `--tool-call-parser` 指定 parser 名字。以下是支持的名字: 1. internlm 2. qwen 3.
llama3 ## 单轮调用 启动好模型的服务后,运行下面 demo 即可。 ```python from openai import OpenAI tools = [ { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, } } ] messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] client = OpenAI(api_key='YOUR_API_KEY',base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) ``` ## 多轮调用 ### InternLM 一个完整的工具链调用过程可以通过下面的例子展示。 ```python from openai import OpenAI def add(a: int, b: int): return a + b def mul(a: int, b: int): return a * b tools = [{ 'type': 'function', 'function': { 'name': 'add', 'description': 'Compute the sum of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': { 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }, { 'type': 'function', 'function': { 'name': 'mul', 'description': 'Calculate the product of two numbers', 'parameters': { 'type': 'object', 'properties': { 'a': { 'type': 'int', 'description': 'A number', }, 'b': { 'type': 'int', 'description': 'A number', }, }, 'required': ['a', 'b'], }, } }] messages = [{'role': 'user', 'content': 'Compute (3+5)*2'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) func1_name = response.choices[0].message.tool_calls[0].function.name func1_args = 
response.choices[0].message.tool_calls[0].function.arguments func1_out = eval(f'{func1_name}(**{func1_args})') print(func1_out) messages.append(response.choices[0].message) messages.append({ 'role': 'tool', 'content': f'3+5={func1_out}', 'tool_call_id': response.choices[0].message.tool_calls[0].id }) response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response) func2_name = response.choices[0].message.tool_calls[0].function.name func2_args = response.choices[0].message.tool_calls[0].function.arguments func2_out = eval(f'{func2_name}(**{func2_args})') print(func2_out) ``` 实际使用 InternLM2-Chat-7B 模型执行上述例子,可以得到下面的结果: ``` ChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"a": 3, "b": 5}', name='add'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=263, total_tokens=288)) 8 ChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='1', function=Function(arguments='{"a": 8, "b": 2}', name='mul'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=293, total_tokens=318)) 16 ``` ### Llama3.1 Meta 在 [Llama3 的官方用户指南](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1)中宣布(注:下文为原文的中文翻译): > 有三个内置工具(brave_search、wolfram_alpha 和 code interpreter)可以使用系统提示词打开: > > 1. Brave Search:执行网络搜索的工具调用。 > 2. 
Wolfram Alpha:执行复杂数学计算的工具调用。 > 3. Code Interpreter:使模型能够输出 Python 代码的功能。 此外,它还警告说:“注意: 我们建议使用 Llama 70B-instruct 或 Llama 405B-instruct 用于结合对话和工具调用的应用。Llama 8B-Instruct 无法可靠地在工具调用定义的同时维持对话。它可以用于零样本工具调用,但在模型和用户之间的常规对话中,应移除工具指令。”(注:引号中内容为原文的中文翻译) 因此,我们使用 [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) 来展示如何通过 LMDeploy的`api_server`调用模型的工具能力. 在 A100-SXM-80G 节点上,可以按照以下方式启动服务: ```shell lmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4 ``` 有关 api_server 的详细介绍,请参考[此处](./api_server.md)的详细文档。 以下代码示例展示了如何使用 "Wolfram Alpha" 工具。假设你已经在[Wolfram Alpha](https://www.wolframalpha.com) 网站上注册并获取了 API 密钥。请确保拥有一个有效的 API 密钥,以便访问 Wolfram Alpha 提供的服务。 ```python from openai import OpenAI import requests def request_llama3_1_service(messages): client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False) return response.choices[0].message.content # The role of "system" MUST be specified, including the required tools messages = [ { "role": "system", "content": "Environment: ipython\nTools: wolfram_alpha\n\n Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\nYou are a helpful Assistant." 
# noqa }, { "role": "user", "content": "Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0" # noqa } ] # send request to the api_server of llama3.1-70b and get the response # the "assistant_response" is supposed to be: # <|python_tag|>wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0") assistant_response = request_llama3_1_service(messages) print(assistant_response) # Call the API of Wolfram Alpha with the query generated by the model app_id = 'YOUR-Wolfram-Alpha-API-KEY' params = { "input": assistant_response, "appid": app_id, "format": "plaintext", "output": "json", } wolframalpha_response = requests.get( "https://api.wolframalpha.com/v2/query", params=params ) wolframalpha_response = wolframalpha_response.json() # Append the contents obtained by the model and the wolframalpha's API # to "messages", and send it again to the api_server messages += [ { "role": "assistant", "content": assistant_response }, { "role": "ipython", "content": wolframalpha_response } ] assistant_response = request_llama3_1_service(messages) print(assistant_response) ``` ### Qwen2.5 Qwen2.5 支持了多工具调用,这意味着可以在一次请求中可能发起多个工具请求 ```python from openai import OpenAI import json def get_current_temperature(location: str, unit: str = "celsius"): """Get current temperature at a location. Args: location: The location to get the temperature for, in the format "City, State, Country". unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) Returns: the temperature, the location, and the unit in a dict """ return { "temperature": 26.1, "location": location, "unit": unit, } def get_temperature_date(location: str, date: str, unit: str = "celsius"): """Get temperature at a location and date. Args: location: The location to get the temperature for, in the format "City, State, Country". date: The date to get the temperature for, in the format "Year-Month-Day". unit: The unit to return the temperature in. Defaults to "celsius". 
(choices: ["celsius", "fahrenheit"]) Returns: the temperature, the location, the date and the unit in a dict """ return { "temperature": 25.9, "location": location, "date": date, "unit": unit, } def get_function_by_name(name): if name == "get_current_temperature": return get_current_temperature if name == "get_temperature_date": return get_temperature_date tools = [{ 'type': 'function', 'function': { 'name': 'get_current_temperature', 'description': 'Get current temperature at a location.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' }, 'unit': { 'type': 'string', 'enum': [ 'celsius', 'fahrenheit' ], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' } }, 'required': [ 'location' ] } } }, { 'type': 'function', 'function': { 'name': 'get_temperature_date', 'description': 'Get temperature at a location and date.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' }, 'date': { 'type': 'string', 'description': 'The date to get the temperature for, in the format \'Year-Month-Day\'.' }, 'unit': { 'type': 'string', 'enum': [ 'celsius', 'fahrenheit' ], 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' } }, 'required': [ 'location', 'date' ] } } }] messages = [{'role': 'user', 'content': 'Today is 2024-11-14, What\'s the temperature in San Francisco now? 
How about tomorrow?'}] client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response.choices[0].message.tool_calls) messages.append(response.choices[0].message) for tool_call in response.choices[0].message.tool_calls: tool_call_args = json.loads(tool_call.function.arguments) tool_call_result = get_function_by_name(tool_call.function.name)(**tool_call_args) messages.append({ 'role': 'tool', 'name': tool_call.function.name, 'content': tool_call_result, 'tool_call_id': tool_call.id }) response = client.chat.completions.create( model=model_name, messages=messages, temperature=0.8, top_p=0.8, stream=False, tools=tools) print(response.choices[0].message.content) ``` 使用Qwen2.5-14B-Instruct,可以得到以下类似结果 ``` [ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"location": "San Francisco, California, USA"}', name='get_current_temperature'), type='function'), ChatCompletionMessageToolCall(id='1', function=Function(arguments='{"location": "San Francisco, California, USA", "date": "2024-11-15"}', name='get_temperature_date'), type='function')] The current temperature in San Francisco, California, USA is 26.1°C. For tomorrow, 2024-11-15, the temperature is expected to be 25.9°C. ``` 需要注意的是,多工具调用的情况下,工具调用的结果顺序会影响回答的效果,tool_call_id并没有正确给到LLM. 
================================================ FILE: docs/zh_cn/llm/codellama.md ================================================ # Code Llama ## 模型介绍 [codellama](https://github.com/facebookresearch/codellama) 支持很多种编程语言,包括 Python, C++, Java, PHP, Typescript (Javascript), C#, Bash 等等。具备代码续写、代码填空、对话、python专项等 4 种能力。 它在 [HuggingFace](https://huggingface.co/codellama) 上发布了基座模型,Python模型和指令微调模型: | 基座模型 | Python微调模型 | 指令模型 | | ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | | [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf) | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf) | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) | | [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) | | [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | [codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) | 模型和能力的对应关系为: | 模型 | 代码续写 | 代码填空 | 对话 | Python专项 | | -------------- | -------- | ----------------- | ---- | ---------- | | 基座模型 | Y | Y(7B,13B), N(34B) | N | N | | Python微调模型 | Y | N | N | Y | | 指令微调模型 | Y | Y(7B,13B), N(34B) | Y | N | ## 推理 根据前文模型的能力表,在本小节中,我们讲通过具体的示例展示使用 CodeLlama 各能力的方法 ### 代码续写 ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-hf', 
chat_template_config=ChatTemplateConfig( model_name='codellama', capability='completion' )) response = pipe( 'import socket\n\ndef ping_exponential_backoff(host: str):', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) print(response.text) ``` ### 代码填空 ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='infilling' )) prompt = """ def remove_non_ascii(s: str) -> str: \"\"\" \"\"\" return result """ response = pipe( prompt, gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95, max_new_tokens=500 ) ) print(response.text) ``` ### 对话 ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-Instruct-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='chat' )) response = pipe( 'implement quick sort in C++', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) print(response.text) ``` ### Python 专项 ```python from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig pipe = pipeline('meta-llama/CodeLlama-7b-Python-hf', chat_template_config=ChatTemplateConfig( model_name='codellama', capability='python' )) response = pipe( 'implement quick sort', gen_config=GenerationConfig( top_k=10, temperature=0.1, top_p=0.95 ) ) print(response.text) ``` ## 量化 TBD ## 服务 准备好对话模板文件,比如说“codellama.json”,参考如下示例,填写 CodeLlama 的能力: ```json { "model_name": "codellama", "capability": "completion" } ``` 然后,启动推理服务: ```shell lmdeploy serve api_server meta-llama/CodeLlama-7b-Instruct-hf --chat-template codellama.json ``` 在服务启动成功后,可以通过`openai`客户端接口,访问服务: ```python from openai import OpenAI client = OpenAI( api_key='YOUR_API_KEY', base_url="http://0.0.0.0:23333/v1" ) model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[ {"role": "user", 
"content": "import socket\n\ndef ping_exponential_backoff(host: str):"}, ], temperature=0.1, top_p=0.95, max_tokens=500 ) print(response) ``` 关于 api_server 的详细介绍,请参考[这份](../llm/api_server.md)文档。 ================================================ FILE: docs/zh_cn/llm/pipeline.md ================================================ # LLM 离线推理 pipeline 本文通过一些例子展示 pipeline 的基本用法。 pipeline API 详细的接口说明,请阅读[此处](https://lmdeploy.readthedocs.io/zh-cn/latest/api/pipeline.html) ## 使用方法 ### "Hello, world" 示例 ```python from lmdeploy import pipeline pipe = pipeline('internlm/internlm2_5-7b-chat') response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` 在这个例子中,pipeline 默认申请一定比例显存,用来存储推理过程中产生的 k/v。比例由参数 `TurbomindEngineConfig.cache_max_entry_count` 控制。 LMDeploy 在研发过程中,k/v cache 比例的设定策略有变更,以下为变更记录: 1. `v0.2.0 <= lmdeploy <= v0.2.1` 默认比例为 0.5,表示 **GPU总显存**的 50% 被分配给 k/v cache。 对于 7B 模型来说,如果显存小于 40G,会出现 OOM。当遇到 OOM 时,请按照下面的方法,酌情降低 k/v cache 占比: ```python from lmdeploy import pipeline, TurbomindEngineConfig # 调低 k/v cache内存占比调整为总显存的 20% backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` 2. 
`lmdeploy > v0.2.1` 分配策略改为从**空闲显存**中按比例为 k/v cache 开辟空间。默认比例值调整为 0.8。如果遇到 OOM,类似上面的方法,请酌情减少比例值,降低 k/v cache 的内存占用量。 ### 设置多卡并行 ```python from lmdeploy import pipeline, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` ### 设置随机采样参数 ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) print(response) ``` ### 使用 OpenAI 格式的 prompt ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] response = pipe(prompts, gen_config=gen_config) print(response) ``` ### 流式输出 ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig backend_config = TurbomindEngineConfig(tp=2) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] for item in pipe.stream_infer(prompts, gen_config=gen_config): print(item) ``` ### 获取生成 token 的 logits ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('internlm/internlm2_5-7b-chat') gen_config=GenerationConfig(output_logits='generation',
max_new_tokens=10) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) logits = [x.logits for x in response] ``` ### 获取生成 token 最后一层的 hidden_states ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('internlm/internlm2_5-7b-chat') gen_config=GenerationConfig(output_last_hidden_state='generation', max_new_tokens=10) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) hidden_states = [x.last_hidden_state for x in response] ``` ### 计算 ppl ```python from transformers import AutoTokenizer from lmdeploy import pipeline model_repoid_or_path = 'internlm/internlm2_5-7b-chat' pipe = pipeline(model_repoid_or_path) tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True) messages = [ {"role": "user", "content": "Hello, how are you?"}, ] input_ids = tokenizer.apply_chat_template(messages) # logits is a list of tensor logits = pipe.get_logits(input_ids) print(logits) # ppl is a list of float numbers ppl = pipe.get_ppl(input_ids) print(ppl) ``` ```{note} 当 input_ids 过长时,可能会出现 OOM 错误,请谨慎使用。 get_ppl 返回的是 cross entropy loss,没有在之后加 exp 操作。 ``` ### 使用 PyTorchEngine 需要先安装 triton ```shell pip install "triton>=2.1.0" ``` ```python from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig backend_config = PytorchEngineConfig(session_len=2048) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8, max_new_tokens=1024) pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': 'Hi, pls intro yourself' }], [{ 'role': 'user', 'content': 'Shanghai is' }]] response = pipe(prompts, gen_config=gen_config) print(response) ``` ### LoRA 模型推理 ```python from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig backend_config = PytorchEngineConfig(session_len=2048, adapters=dict(lora_name_1='chenchi/lora-chatglm2-6b-guodegang')) gen_config = GenerationConfig(top_p=0.8, top_k=40, temperature=0.8,
max_new_tokens=1024) pipe = pipeline('THUDM/chatglm2-6b', backend_config=backend_config) prompts = [[{ 'role': 'user', 'content': '您猜怎么着' }]] response = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1') print(response) ``` ### 释放 pipeline 您可以通过调用其 `close()` 方法来显式释放 pipeline,或者,也可以使用 `with` 语句,如下所示: ```python from lmdeploy import pipeline with pipeline('internlm/internlm2_5-7b-chat') as pipe: response = pipe(['Hi, pls intro yourself', 'Shanghai is']) print(response) ``` ## 常见问题 - **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**. 如果你在使用 tp>1 和 pytorch 后端时遇到了这个错误,请确保 python 脚本中有下面内容作为入口: ```python if __name__ == '__main__': ``` 一般来说,在多线程或多进程上下文中,可能需要确保初始化代码只执行一次。这时候,`if __name__ == '__main__':` 可以帮助确保这些初始化代码只在主程序执行,而不会在每个新创建的进程或线程中重复执行。 - 自定义对话模板,请参考[chat_template.md](../advance/chat_template.md) - 如果 lora 的权重有对应的对话模板,可以先将对话模板注册到 lmdeploy,然后将对话模板名作为 adapter 名使用即可 ================================================ FILE: docs/zh_cn/llm/proxy_server.md ================================================ # 请求分发服务器 请求分发服务可以将多个 api_server 服务并联起来。用户只需要访问代理 URL,就可以间接访问不同的 api_server 服务。代理服务内部会自动分发请求,做到负载均衡。 ## 启动 启动代理服务: ```shell lmdeploy serve proxy --server-name {server_name} --server-port {server_port} --routing-strategy "min_expected_latency" --serving-strategy Hybrid ``` 启动成功后,代理服务的 URL 也会被脚本打印。浏览器访问这个 URL,可以打开 Swagger UI。 随后,用户可以在启动 api_server 服务的时候,通过 `--proxy-url` 参数将其直接添加到代理服务中。例如:`lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`。 这样,用户可以通过代理节点访问 api_server 的服务,代理节点的使用方式和 api_server 一模一样,都是兼容 OpenAI 的形式。 - /v1/models - /v1/chat/completions - /v1/completions ## 节点管理 通过 Swagger UI,我们可以看到多个 API。其中,和 api_server 节点管理相关的有: - /nodes/status - /nodes/add - /nodes/remove 它们分别用于查看所有的 api_server 服务节点、增加某个节点、删除某个节点。最直接的使用方式是在浏览器里操作,也可以通过命令行或者 python 操作。 ### 通过 command 增删查 ```shell curl -X 'GET' \ 'http://localhost:8000/nodes/status' \ -H
'accept: application/json' ``` ```shell curl -X 'POST' \ 'http://localhost:8000/nodes/add' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "url": "http://0.0.0.0:23333" }' ``` ```shell curl -X 'POST' \ 'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \ -H 'accept: application/json' \ -d '' ``` ### 通过 python 脚本增删查 ```python # 查询所有节点 import requests url = 'http://localhost:8000/nodes/status' headers = {'accept': 'application/json'} response = requests.get(url, headers=headers) print(response.text) ``` ```python # 添加新节点 import requests url = 'http://localhost:8000/nodes/add' headers = { 'accept': 'application/json', 'Content-Type': 'application/json' } data = {"url": "http://0.0.0.0:23333"} response = requests.post(url, headers=headers, json=data) print(response.text) ``` ```python # 删除某个节点 import requests url = 'http://localhost:8000/nodes/remove' headers = {'accept': 'application/json',} params = {'node_url': 'http://0.0.0.0:23333',} response = requests.post(url, headers=headers, data='', params=params) print(response.text) ``` ## 服务策略 LMDeploy 当前支持混合部署服务(Hybrid),以及 PD 分离部署服务(DistServe)。 - Hybrid: 不区分 Prefill 和 Decoding 实例,即传统的推理部署模式。 - DistServe: 将 Prefill 和 Decoding 实例分离,部署在不同的服务节点上以实现更灵活高效的资源调度和扩展。 ## 分发策略 代理服务目前的分发策略如下: - random: 根据用户提供的各个 api_server 节点处理请求的能力,进行加权随机选择。处理请求的吞吐量越大,就越有可能被分配。对于未提供吞吐量的节点,将按照其他节点的平均吞吐量对待。 - min_expected_latency: 根据每个节点当前尚未处理完的请求和各个节点的吞吐能力,计算预期完成响应所需的时间,时间最短的节点将被分配。未提供吞吐量的节点,处理方式同上。 - min_observed_latency: 根据每个节点过去一定数量请求的平均处理用时,用时最短的节点将被分配。 ================================================ FILE: docs/zh_cn/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found.
Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/zh_cn/multi_modal/api_server_vl.md ================================================ # 部署 VLM 类 openai 服务 本文主要介绍单个VL模型在单机多卡环境下,部署兼容 openai 接口服务的方式,以及服务接口的用法。为行文方便,我们将该服务称为 `api_server`。对于多模型的并行服务,请阅读[请求分发服务器](../llm/proxy_server.md)一文。 在这篇文章中,我们首先介绍服务启动的两种方法,你可以根据应用场景选择合适的方法。 其次,我们重点介绍服务的 RESTful API 定义,以及接口使用的方式,并展示如何通过 Swagger UI、LMDeploy CLI 工具体验服务功能。 ## 启动服务 以 huggingface hub 上的 [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) 模型为例,你可以任选以下方式之一,启动推理服务。 ### 方式一:使用 lmdeploy cli 工具 ```shell lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b --server-port 23333 ``` api_server 启动时的参数可以通过命令行`lmdeploy serve api_server -h`查看。 比如,`--tp` 设置张量并行,`--session-len` 设置推理的最大上下文窗口长度,`--cache-max-entry-count` 调整 k/v cache 的内存使用比例等等。 ### 方式二:使用 docker 使用 LMDeploy 官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags),可以运行兼容 OpenAI 的服务。下面是使用示例: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b ``` 在这个例子中,`lmdeploy serve api_server` 的命令参数与方式一一致。 ## RESTful API LMDeploy 的 RESTful API 兼容了 OpenAI 以下 3 个接口: - /v1/chat/completions - /v1/models - /v1/completions 其中使用图片交互的接口是 `/v1/chat/completions`,与 OpenAI 的一致。 服务启动后,你可以在浏览器中打开网页 http://0.0.0.0:23333,通过 Swagger UI 查看接口的详细说明,并且也可以直接在网页上操作,体验每个接口的用法,如下图所示。
![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) 若需要把服务集成到自己的项目或者产品中,我们推荐以下用法: ### 使用 openai 接口 以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前,请先安装 openai 包: `pip install openai`。 ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ### 使用 lmdeploy `APIClient` 接口 如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码: ```python from lmdeploy.serve.openai.api_client import APIClient api_client = APIClient(f'http://0.0.0.0:23333') model_name = api_client.available_models[0] messages = [{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }] }] for item in api_client.chat_completions_v1(model=model_name, messages=messages): print(item) ``` ### 使用 Java/Golang/Rust 可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。 下面是一个使用示例: ```shell $ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust $ ls rust/* rust/Cargo.toml rust/git_push.sh rust/README.md rust/docs: ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md rust/src: apis lib.rs models ``` ================================================ FILE: 
docs/zh_cn/multi_modal/cogvlm.md ================================================ # cogvlm ## 简介 CogVLM 是一个强大的开源视觉语言模型(VLM). LMDeploy 已在PyTorch后端支持 CogVLM-17B 模型 [THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf) 和 CogVLM2-19B 模型如[THUDM/cogvlm2-llama3-chat-19B](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) ## 快速开始 请参考[安装文档](../get_started/installation.md)安装 LMDeploy ### 准备 当使用LMDeploy部署 **CogVLM** 模型时,需要下载模型至本地目录。由于 **CogVLM** 模型使用外部Tokenizer,因而需要将相关文件下载至模型目录。然而对于**CogVLM2**模型,则可跳过此步骤。 以 **CogVLM** 模型 `cogvlm-chat-hf` 为例,可执行如下脚本下载模型: ```shell huggingface-cli download THUDM/cogvlm-chat-hf --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False huggingface-cli download lmsys/vicuna-7b-v1.5 special_tokens_map.json tokenizer.model tokenizer_config.json --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False ``` ### 离线推理 pipeline 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('cogvlm-chat-hf') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/zh_cn/multi_modal/deepseek_vl2.md ================================================ # DeepSeek-VL2 ## 简介 DeepSeek-VL2 是一系列先进的 MoE 视觉-语言模型,相较于其前身 DeepSeek-VL 有了显著的改进。 DeepSeek-VL2 在各种任务中展现出卓越的能力,包括但不限于视觉问答、OCR、文档/表格/图表理解以及视觉定位。 LMDeploy 目前在 Pytorch 引擎中支持 [deepseek-vl2-tiny](https://huggingface.co/deepseek-ai/deepseek-vl2-tiny), [deepseek-vl2-small](https://huggingface.co/deepseek-ai/deepseek-vl2-small) 和 [deepseek-vl2](https://huggingface.co/deepseek-ai/deepseek-vl2) 。 ## 快速开始 请参考[安装文档](../get_started/installation.md)安装 LMDeploy。 ### 准备 在使用 LMDeploy 部署 **DeepSeek-VL2** 模型时,您必须安装官方的 GitHub 仓库以及一些相关的第三方库。这是因为 LMDeploy 会复用官方仓库中提供的图像处理功能。 ``` pip install 
git+https://github.com/deepseek-ai/DeepSeek-VL2.git --no-deps pip install attrdict timm 'transformers<4.48.0' ``` 值得注意的是,如果使用 transformers>=4.48.0,可能会出现失败的情况,详情可以参考此 [Issue](https://github.com/deepseek-ai/DeepSeek-VL2/issues/45)。 ### 离线推理 pipeline 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)。 为了构建有效的、包含图像输入的 DeepSeek-VL2 提示词,用户应手动插入 `<IMAGE_TOKEN>`。 ```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('deepseek-ai/deepseek-vl2-tiny') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/zh_cn/multi_modal/gemma3.md ================================================ # Gemma3 ## 简介 Gemma 是 Google 推出的轻量级、最先进的开放模型系列,采用与创建 Gemini 模型相同的研究和技术构建而成。Gemma3 模型是多模态模型,可处理文本和图像输入并生成文本输出,对预训练和指令微调均具有开源的权重。Gemma3 具有 128K 的大型上下文窗口,支持 140 多种语言,并且比以前的版本提供更多尺寸。Gemma3 模型非常适合各种文本生成和图像理解任务,包括问答、总结和推理。它们的尺寸相对较小,因此可以将其部署在资源有限的环境中,例如笔记本电脑、台式机或您自己的云基础设施,从而让每个人都能轻松访问最先进的 AI 模型,并帮助促进创新。 ## 快速开始 请参考[安装文档](../get_started/installation.md)安装 LMDeploy。 ### 准备 在使用 LMDeploy 部署 **Gemma3** 模型时,请安装最新的 transformers。 ### 离线推理 pipeline 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)。 ```python from lmdeploy import pipeline from lmdeploy.vl import load_image if __name__ == "__main__": pipe = pipeline('google/gemma-3-12b-it') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ================================================ FILE: docs/zh_cn/multi_modal/index.rst ================================================ 视觉语言模型 ================================= ..
toctree:: :maxdepth: 2 :caption: 示例 deepseek_vl2.md llava.md internvl.md xcomposer2d5.md cogvlm.md minicpmv.md phi3.md qwen2_vl.md qwen2_5_vl.md molmo.md gemma3.md ================================================ FILE: docs/zh_cn/multi_modal/internvl.md ================================================ # InternVL LMDeploy 支持 InternVL 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :-------------------: | :-----------: | :------------------------: | | InternVL | 13B-19B | TurboMind | | InternVL1.5 | 2B-26B | TurboMind, PyTorch | | InternVL2 | 4B | PyTorch | | InternVL2 | 1B-2B, 8B-76B | TurboMind, PyTorch | | InternVL2.5/2.5-MPO/3 | 1B-78B | TurboMind, PyTorch | | Mono-InternVL | 2B | PyTorch | 本文将以[InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B)为例,演示使用 LMDeploy 部署 InternVL 系列模型的方法。 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy,并安装上游 InternVL 模型库所需的依赖。 ```shell pip install timm # 建议从https://github.com/Dao-AILab/flash-attention/releases寻找和环境匹配的whl包 pip install flash-attn ``` 或者,你可以为 InternVL 的推理构建 docker image。如果宿主机器上的 CUDA 版本 `>=12.4`,你可以执行如下命令构建镜像: ``` git clone https://github.com/InternLM/lmdeploy.git cd lmdeploy docker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile ``` 否则的话,可以基于 LMDeploy cu11 的镜像来构建: ```shell docker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile ``` ## 离线推理 以下是使用 pipeline 进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` 更多例子如下:
多图多轮对话,拼接图像 ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\nDescribe the two images in detail.'), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
多图多轮对话,独立图像 ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\nDescribe the two images in detail.'), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
视频多轮对话 ```python import numpy as np from lmdeploy import pipeline, GenerationConfig from decord import VideoReader, cpu from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import encode_image_base64 from PIL import Image pipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO') def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array([ int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments) ]) return frame_indices def load_video(video_path, bound=None, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 fps = float(vr.get_avg_fps()) frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) imgs = [] for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') imgs.append(img) return imgs video_path = 'red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' for i in range(len(imgs)): question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' question += 'What is the red panda doing?' content = [{'type': 'text', 'text': question}] for img in imgs: content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}}) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='Describe this video in detail. Don\'t repeat.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## 在线服务 你可以通过 `lmdeploy serve api_server` CLI 工具启动服务: ```shell lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` 也可以基于前文构建的 docker image 启动服务: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:internvl \ lmdeploy serve api_server OpenGVLab/InternVL2-8B ``` Docker compose 的方式也是一种选择。在 LMDeploy 代码库的根目录下创建`docker-compose.yml`文件,内容参考如下: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:internvl ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server OpenGVLab/InternVL2-8B deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` 然后,你就可以执行命令启动服务了: ```shell docker-compose up -d ``` 通过`docker logs -f lmdeploy`可以查看启动的日志信息,如果发现类似下方的日志信息,就表明服务启动成功了。 ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. 
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` 有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。 关于 `api_server` 更多的介绍,以及访问 `api_server` 的方法,请阅读[此处](api_server_vl.md) ================================================ FILE: docs/zh_cn/multi_modal/llava.md ================================================ # LLaVA LMDeploy 支持以下 LLaVA 系列模型,具体如下表所示: | 模型 | 大小 | 支持的推理引擎 | | :----------------------------------: | :--: | :----------------: | | llava-hf/Llava-interleave-qwen-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-1.5-7b-hf | 7B | TurboMind, PyTorch | | llava-hf/llava-v1.6-mistral-7b-hf | 7B | PyTorch | | llava-hf/llava-v1.6-vicuna-7b-hf | 7B | PyTorch | | liuhaotian/llava-v1.6-vicuna-7b | 7B | TurboMind | | liuhaotian/llava-v1.6-mistral-7b | 7B | TurboMind | 接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型,并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。 ```{note} 自 0.6.4 之后,PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到 ``` ## 安装 请按照[安装指南](../get_started/installation.md)安装 LMDeploy。 或者,您也可以使用官方的 Docker 镜像: ```shell docker pull openmmlab/lmdeploy:latest ``` ## 离线推理 以下示例代码展示了 VLM pipeline 的基本用法。有关详细信息,请参考 [VLM 离线推理流程](./vl_pipeline.md)。 ```python from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline from lmdeploy.vl import load_image pipe = pipeline("llava-hf/llava-interleave-qwen-7b-hf", backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5), gen_config=GenerationConfig(max_new_tokens=512)) image = load_image('https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg') prompt = 'Describe the image.' print(f'prompt:{prompt}') response = pipe((prompt, image)) print(response) ``` 更多示例:
多图片多轮对话,组合图片 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## 在线服务 可以使用 `lmdeploy serve api_server` CLI 启动服务器: ```shell lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf ``` 或者,使用前面提到的 Docker 镜像启动服务: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf ``` 采用 Docker Compose 部署也是一种常见选择。在 lmdeploy 项目的根目录创建 `docker-compose.yml` 文件,如下: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:latest ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` 然后,可以执行以下命令启动服务: ```shell docker-compose up -d ``` 当运行 `docker logs -f lmdeploy` 后看到如下日志,说明服务启动成功: ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. 
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` 可以通过 `lmdeploy serve api_server -h` 查看 `lmdeploy serve api_server` 的参数详情。 关于 `api_server` 以及如何访问服务的更多信息可以在[这里](api_server_vl.md)找到。 ================================================ FILE: docs/zh_cn/multi_modal/minicpmv.md ================================================ # MiniCPM-V LMDeploy 支持 MiniCPM-V 系列模型,具体如下: | Model | Supported Inference Engine | | :------------------: | :------------------------: | | MiniCPM-Llama3-V-2_5 | TurboMind | | MiniCPM-V-2_6 | TurboMind | 本文将以[MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)为例,演示使用 LMDeploy 部署 MiniCPM-V 系列模型的方法 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy。 ## 离线推理 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('openbmb/MiniCPM-V-2_6') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` 更多例子如下:
多张图片,多轮对话 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')), dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
上下文小样本学习 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') question = "production date" messages = [ dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='example1.jpg')), ]), dict(role='assistant', content='2023.08.04'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='example2.jpg')), ]), dict(role='assistant', content='2007.04.24'), dict(role='user', content=[ dict(type='text', text=question), dict(type='image_url', image_url=dict(url='test.jpg')), ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
视频对话 ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl import encode_image_base64 import torch from PIL import Image from transformers import AutoModel, AutoTokenizer from decord import VideoReader, cpu # pip install decord pipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO') MAX_NUM_FRAMES=64 # if cuda OOM set a smaller number def encode_video(video_path): def uniform_sample(l, n): gap = len(l) / n idxs = [int(i * gap + gap / 2) for i in range(n)] return [l[i] for i in idxs] vr = VideoReader(video_path, ctx=cpu(0)) sample_fps = round(vr.get_avg_fps() / 1) # FPS frame_idx = [i for i in range(0, len(vr), sample_fps)] if len(frame_idx) > MAX_NUM_FRAMES: frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) frames = vr.get_batch(frame_idx).asnumpy() frames = [Image.fromarray(v.astype('uint8')) for v in frames] print('num frames:', len(frames)) return frames video_path="video_test.mp4" frames = encode_video(video_path) question = "Describe the video" content=[dict(type='text', text=question)] for frame in frames: content.append(dict(type='image_url', image_url=dict(use_image_id=False, max_slice_nums=2, url=f'data:image/jpeg;base64,{encode_image_base64(frame)}'))) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) print(out.text) ```
## 在线服务 你可以通过 `lmdeploy serve api_server` CLI 工具启动服务: ```shell lmdeploy serve api_server openbmb/MiniCPM-V-2_6 ``` 也可以基于 LMDeploy 的 docker 启动服务: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server openbmb/MiniCPM-V-2_6 ``` Docker compose 的方式也是一种选择。在 LMDeploy 代码库的根目录下创建`docker-compose.yml`文件,内容参考如下: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:latest ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server openbmb/MiniCPM-V-2_6 deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` 然后,你就可以执行命令启动服务了: ```shell docker-compose up -d ``` 通过`docker logs -f lmdeploy`可以查看启动的日志信息,如果发现类似下方的日志信息,就表明服务启动成功了。 ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. 
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` 有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。 关于 `api_server` 更多的介绍,以及访问 `api_server` 的方法,请阅读[此处](api_server_vl.md) ================================================ FILE: docs/zh_cn/multi_modal/molmo.md ================================================ # Molmo LMDeploy 支持 Molmo 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :-------------: | :--: | :------------------------: | | Molmo-7B-D-0924 | 7B | TurboMind | | Molmo-72B-0924 | 72B | TurboMind | 本文将以[Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) 为例,演示使用 LMDeploy 部署 Molmo 系列模型的方法。 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy。 ## 离线推理 以下是使用 pipeline 进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('allenai/Molmo-7B-D-0924') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` 更多例子如下:
多图多轮对话 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## 在线服务 你可以通过 `lmdeploy serve api_server` CLI 工具启动服务: ```shell lmdeploy serve api_server allenai/Molmo-7B-D-0924 ``` 也可以基于 docker image 启动服务: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:latest \ lmdeploy serve api_server allenai/Molmo-7B-D-0924 ``` 如果日志中有如下信息,就表明服务启动成功了。 ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` 有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。 关于 `api_server` 更多的介绍,以及访问 `api_server` 的方法,请阅读[此处](api_server_vl.md) ================================================ FILE: docs/zh_cn/multi_modal/phi3.md ================================================ # Phi-3 Vision ## 简介 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) 是微软发布的轻量级系列模型,LMDeploy支持了其中的多模态模型如下: | Model | Size | Supported Inference Engine | | :-------------------------------------------------------------------------------------------------: | :--: | :------------------------: | | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) | 4.2B | PyTorch | | [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) | 4.2B | PyTorch | 本文将以[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)为例,演示使用 LMDeploy 部署 Phi-3 系列多模态模型的方法 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy,并安装该模型的依赖。 ```shell # 建议从https://github.com/Dao-AILab/flash-attention/releases寻找和环境匹配的whl包 pip install flash-attn ``` 
## 离线推理 pipeline 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('microsoft/Phi-3.5-vision-instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## 在线服务 ### 服务启动 你可以通过 `lmdeploy serve api_server` CLI 工具启动服务: ```shell lmdeploy serve api_server microsoft/Phi-3.5-vision-instruct ``` ### 使用 openai 接口 以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前,请先安装 openai 包: `pip install openai`。 ```python from openai import OpenAI client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1') model_name = client.models.list().data[0].id response = client.chat.completions.create( model=model_name, messages=[{ 'role': 'user', 'content': [{ 'type': 'text', 'text': 'Describe the image please', }, { 'type': 'image_url', 'image_url': { 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', }, }], }], temperature=0.8, top_p=0.8) print(response) ``` ================================================ FILE: docs/zh_cn/multi_modal/qwen2_5_vl.md ================================================ # Qwen2.5-VL LMDeploy 支持 Qwen-VL 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :--------: | :--------------: | :------------------------: | | Qwen2.5-VL | 3B, 7B, 32B, 72B | PyTorch | 本文将以[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2.5-VL 系列模型的方法 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy,并安装上游 Qwen2.5-VL 模型库所需的依赖。 ```shell # Qwen2.5-VL requires the latest transformers (transformers >= 4.49.0) pip install git+https://github.com/huggingface/transformers # It's highly recommended to use `[decord]` feature for faster video loading. 
pip install qwen-vl-utils[decord]==0.0.8 ``` ## 离线推理 以下是使用 pipeline 进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` 更多例子如下:
多图多轮对话 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
控制图片分辨率,加速推理 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') min_pixels = 64 * 28 * 28 max_pixels = 64 * 28 * 28 messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
视频多轮对话 ```python import numpy as np from lmdeploy import pipeline, GenerationConfig from decord import VideoReader, cpu from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import encode_image_base64 from PIL import Image pipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO') def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array([ int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments) ]) return frame_indices def load_video(video_path, bound=None, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 fps = float(vr.get_avg_fps()) pixel_values_list, num_patches_list = [], [] frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) imgs = [] for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') imgs.append(img) return imgs video_path = 'red-panda.mp4' imgs = load_video(video_path, num_segments=8) question = '' for i in range(len(imgs)): question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' question += 'What is the red panda doing?' content = [{'type': 'text', 'text': question}] for img in imgs: content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}}) messages = [dict(role='user', content=content)] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='Describe this video in detail. Don\'t repeat.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
================================================ FILE: docs/zh_cn/multi_modal/qwen2_vl.md ================================================ # Qwen2-VL LMDeploy 支持 Qwen-VL 系列模型,具体如下: | Model | Size | Supported Inference Engine | | :----------: | :----: | :------------------------: | | Qwen-VL-Chat | - | TurboMind | | Qwen2-VL | 2B, 7B | PyTorch | 本文将以[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)为例,演示使用 LMDeploy 部署 Qwen2-VL 系列模型的方法。 ## 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy,并安装上游 Qwen2-VL 模型库所需的依赖。 ```shell pip install qwen_vl_utils ``` 或者,你可以为 Qwen2-VL 的推理构建 docker image。如果宿主机器上的 CUDA 版本 `>=12.4`,你可以执行如下命令构建镜像: ``` git clone https://github.com/InternLM/lmdeploy.git cd lmdeploy docker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile ``` 否则的话,可以基于 LMDeploy cu11 的镜像来构建: ```shell docker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile ``` ## 离线推理 以下是使用 pipeline 进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image', image)) print(response) ``` 更多例子如下:
多图多轮对话 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO') messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
控制图片分辨率,加速推理 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO') min_pixels = 64 * 28 * 28 max_pixels = 64 * 28 * 28 messages = [ dict(role='user', content=[ dict(type='text', text='Describe the two images in detail.'), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')), dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg')) ]) ] out = pipe(messages, gen_config=GenerationConfig(top_k=1)) messages.append(dict(role='assistant', content=out.text)) messages.append(dict(role='user', content='What are the similarities and differences between these two images.')) out = pipe(messages, gen_config=GenerationConfig(top_k=1)) ```
## 在线服务 你可以通过 `lmdeploy serve api_server` CLI 工具启动服务: ```shell lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct ``` 也可以基于前文构建的 docker image 启动服务: ```shell docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ -p 23333:23333 \ --ipc=host \ openmmlab/lmdeploy:qwen2vl \ lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct ``` Docker compose 的方式也是一种选择。在 LMDeploy 代码库的根目录下创建`docker-compose.yml`文件,内容参考如下: ```yaml version: '3.5' services: lmdeploy: container_name: lmdeploy image: openmmlab/lmdeploy:qwen2vl ports: - "23333:23333" environment: HUGGING_FACE_HUB_TOKEN: volumes: - ~/.cache/huggingface:/root/.cache/huggingface stdin_open: true tty: true ipc: host command: lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct deploy: resources: reservations: devices: - driver: nvidia count: "all" capabilities: [gpu] ``` 然后,你就可以执行命令启动服务了: ```shell docker-compose up -d ``` 通过`docker logs -f lmdeploy`可以查看启动的日志信息,如果发现类似下方的日志信息,就表明服务启动成功了。 ```text HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!! INFO: Started server process [2439] INFO: Waiting for application startup. INFO: Application startup complete. 
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) ``` 有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。 关于 `api_server` 更多的介绍,以及访问 `api_server` 的方法,请阅读[此处](api_server_vl.md) ================================================ FILE: docs/zh_cn/multi_modal/vl_pipeline.md ================================================ # VLM 离线推理 pipeline LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单好用的 pipeline。它的用法与大语言模型(LLM)推理 [pipeline](../llm/pipeline.md) 类似。 在[这个列表中](../supported_models/supported_models.md),你可以查阅每个推理引擎支持的 VLM 模型。我们诚挚邀请社区在 LMDeploy 中添加更多 VLM 模型。 本文将以 [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) 模型为例,展示 VLM pipeline 的用法。你将了解它的最基础用法,以及如何通过调整引擎参数和生成条件来逐步解锁更多高级特性,如张量并行,上下文窗口大小调整,随机采样,以及对话模板的定制。 此外,我们还提供针对多图、批量提示词等场景的实际推理示例。 使用 pipeline 接口推理其他 VLM 模型,大同小异,主要区别在于模型依赖的配置和安装。你可以阅读[此处](https://lmdeploy.readthedocs.io/zh-cn/latest/multi_modal/),查看不同模型的环境安装和配置方式 ## "Hello, world" 示例 ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` 如果在执行这个用例时,出现 `ImportError` 的错误,请按照提示安装相关的依赖包。 上面的例子中,推理时的提示词是 (prompt, image) 的 tuple 结构。除了这种结构外,pipeline 支持 openai 格式的提示词: ```python from lmdeploy import pipeline pipe = pipeline('OpenGVLab/InternVL2_5-8B') prompts = [ { 'role': 'user', 'content': [ {'type': 'text', 'text': 'describe this image'}, {'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}} ] } ] response = pipe(prompts) print(response) ``` ### 设置多卡并行 设置引擎参数 `tp`,可激活多卡并行能力 ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(tp=2)) image = 
load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### 设置上下文长度 创建 pipeline 时,通过设置引擎参数 `session_len`,可以定制上下文窗口的最大长度。 ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### 设置随机采样参数 可通过传入 `GenerationConfig` 修改 pipeline 的生成接口中的默认采样参数。 ```python from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(tp=2, session_len=8192)) gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image), gen_config=gen_config) print(response) ``` ### 自定义图片 token 的位置 默认情况下,LMDeploy 会根据算法 repo 提供的对话模板将表示图片的特殊 token 插入到 user prompt 中。但在一些模型中(如 deepseek-vl),图片 token 的位置并没有限制;或者用户需要自定义图片 token 插入的位置。这种情况下,用户需要手动将表示图片的 token 插入到 prompt 中。LMDeploy 使用 `<IMAGE_TOKEN>`(即 `lmdeploy.vl.constants.IMAGE_TOKEN`)作为表示图片的特殊 token。 ```python from lmdeploy import pipeline from lmdeploy.vl import load_image from lmdeploy.vl.constants import IMAGE_TOKEN pipe = pipeline('deepseek-ai/deepseek-vl-1.3b-chat') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe((f'describe this image{IMAGE_TOKEN}', image)) print(response) ``` ### 设置对话模板 推理时,LMDeploy 会根据模型路径匹配内置的对话模板,并把对话模板应用到输入的提示词上。如果用户使用的是本地模型,并且模型文件夹名字与官方模型不一致时,需要手动指定对话模板。以 [llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) 为例,官方使用 ['llava-v1'](https://github.com/haotian-liu/LLaVA/blob/v1.2.2/llava/conversation.py#L325-L335) 
对话模版,如果本地文件夹名字不是 `llava-v1.5-7b`,可以按照如下方式使用。 ```python from lmdeploy import pipeline, ChatTemplateConfig from lmdeploy.vl import load_image pipe = pipeline('local_model_folder', chat_template_config=ChatTemplateConfig(model_name='llava-v1')) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` 关于如何自定义对话模版,请参考[这里](../advance/chat_template.md) ### 设置视觉模型参数 可通过设置 `VisionConfig` 修改视觉模型默认参数 ```python from lmdeploy import pipeline, VisionConfig from lmdeploy.vl import load_image vision_config=VisionConfig(max_batch_size=16) pipe = pipeline('liuhaotian/llava-v1.5-7b', vision_config=vision_config) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ### 获取生成 token 的 logits ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image), gen_config=GenerationConfig(output_logits='generation')) logits = response.logits print(logits) ``` ## 多图推理 对于多图的场景,在推理时,只要把它们放在一个列表中即可。不过,多图意味着输入 token 数更多,所以通常需要[增大推理的上下文长度](#设置上下文长度) ```python from lmdeploy import pipeline, TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image_urls=[ 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg', 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg' ] images = [load_image(img_url) for img_url in image_urls] response = pipe(('describe these images', images)) print(response) ``` ## 提示词批处理 做批量提示词推理非常简单,只要把它们放在一个 list 结构中: ```python from lmdeploy import pipeline, 
TurbomindEngineConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image_urls=[ "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg", "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg" ] prompts = [('describe this image', load_image(img_url)) for img_url in image_urls] response = pipe(prompts) print(response) ``` ## 多轮对话 pipeline 进行多轮对话有两种方式,一种是按照 openai 的格式来构造 messages,另外一种是使用 `pipeline.chat` 接口。 ```python from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig from lmdeploy.vl import load_image pipe = pipeline('OpenGVLab/InternVL2_5-8B', backend_config=TurbomindEngineConfig(session_len=8192)) image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg') gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6) sess = pipe.chat(('describe this image', image), gen_config=gen_config) print(sess.response.text) sess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config) print(sess.response.text) ``` ### 释放 pipeline 你可以通过调用 pipeline 的 `close()` 方法来显式释放它,也可以使用 `with` 语句,如下所示: ```python from lmdeploy import pipeline from lmdeploy.vl import load_image with pipeline('OpenGVLab/InternVL2_5-8B') as pipe: image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) # Clear the torch cache and perform garbage collection if needed import torch import gc torch.cuda.empty_cache() gc.collect() ``` ================================================ FILE: docs/zh_cn/multi_modal/xcomposer2d5.md ================================================ # InternLM-XComposer-2.5 ## 简介 [InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) 是基于书生·浦语2大语言模型研发的突破性的图文多模态大模型,仅使用 7B LLM 后端就达到了 GPT-4V 
级别的能力。浦语·灵笔2.5使用24K交错的图像-文本上下文进行训练,通过RoPE外推可以无缝扩展到96K长的上下文。这种长上下文能力使浦语·灵笔2.5在需要广泛输入和输出上下文的任务中表现出色。 LMDeploy 支持了 [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) 模型,通过 TurboMind 引擎推理。 ## 快速开始 ### 安装 请参考[安装文档](../get_started/installation.md)安装 LMDeploy,并安装上游模型库 InternLM-XComposer-2.5 所需的依赖。 ```shell pip install decord ``` ### 离线推理 pipeline 以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md) ```python from lmdeploy import pipeline from lmdeploy.vl import load_image pipe = pipeline('internlm/internlm-xcomposer2d5-7b') image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') response = pipe(('describe this image', image)) print(response) ``` ## Lora 模型 InternLM-XComposer-2.5 针对网页制作和文章创作训练了 LoRA 模型。由于 TurboMind 不支持 slora 特性,同一时间只能部署一个 LoRA 模型,并且需要先对权重进行合并。LMDeploy 提供了相关的转换脚本,使用方式为: ```shell export HF_MODEL=internlm/internlm-xcomposer2d5-7b export WORK_DIR=internlm/internlm-xcomposer2d5-7b-web export TASK=web python -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK ``` ## 量化 下面以 base 模型为例,展示量化的方式。若要量化 LoRA 模型,请先按照上一章节合并得到对应的 LoRA 模型。 ```shell export HF_MODEL=internlm/internlm-xcomposer2d5-7b export WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit lmdeploy lite auto_awq \ $HF_MODEL \ --work-dir $WORK_DIR ``` ## 更多使用例子
Video Understanding 下面以 `pipeline.chat` 为例展示用法,其它接口同样支持推理,需要手动拼接对话内容。 ```python from lmdeploy import pipeline, GenerationConfig from transformers.dynamic_module_utils import get_class_from_dynamic_module HF_MODEL = 'internlm/internlm-xcomposer2d5-7b' load_video = get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL) frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL) Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL) get_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL) video = load_video('liuxiang.mp4') # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4 img = frame2img(video, get_font()) img = Video_transform(img) pipe = pipeline(HF_MODEL) gen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0) query = 'Here are some frames of a video. Describe this video in detail' sess = pipe.chat((query, img), gen_config=gen_config) print(sess.response.text) query = 'tell me the athlete code of Liu Xiang' sess = pipe.chat(query, session=sess, gen_config=gen_config) print(sess.response.text) ```
Multi-Image ```python from lmdeploy import pipeline, GenerationConfig from lmdeploy.vl.constants import IMAGE_TOKEN from lmdeploy.vl import load_image query = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one' urls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg', 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg', 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg'] images = [load_image(url) for url in urls] pipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO') output = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939)) ``` 由于 LMDeploy 不支持 beam search,生成的结果与使用 transformers 的 beam search 相比,会有较大的差异,建议关闭 top_k 或者使用较大的 top_k 采样来增加多样性。
Instruction to Webpage 请先按照上述说明,转换得到 web 模型。 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO') pipe.chat_template.meta_instruction = None query = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.' output = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048)) ``` 使用 transformers 测试时发现,如果设置了 repetition_penalty,在 beam 数为 1 时有较大概率无法停止生成。由于 LMDeploy 不支持 beam search,建议使用 LMDeploy 推理时关闭 repetition_penalty。
Write Article 请先按照上述说明,转换得到 write 模型。 ```python from lmdeploy import pipeline, GenerationConfig pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO') pipe.chat_template.meta_instruction = None query = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence' output = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192)) ```
================================================ FILE: docs/zh_cn/quantization/kv_quant.md ================================================ # Key-Value(KV) Cache 量化 自 v0.4.0 起,LMDeploy 支持**在线** kv cache int4/int8 量化,量化方式为 per-head per-token 的非对称量化。原来的 kv 离线量化方式已移除。 从直观上看,量化 kv 有利于增加 kv block 的数量。与 fp16 相比,int4/int8 kv 的 kv block 分别可以增加到 4 倍和 2 倍。这意味着,在相同的内存条件下,kv 量化后,系统能支撑的并发数可以大幅提升,从而最终提高吞吐量。 但是,通常,量化会伴随一定的模型精度损失。我们使用 opencompass 评测了若干个模型在应用了 int4/int8 量化后的精度,int8 kv 精度几乎无损,int4 kv 略有损失。详细结果放在了[精度评测](#精度评测)章节中。大家可以参考,根据实际需求酌情选择。 LMDeploy kv 4/8 bit 量化和推理支持如下 NVIDIA 显卡型号: - volta 架构(sm70): V100 - 图灵架构(sm75):20系列、T4 - 安培架构(sm80,sm86):30系列、A10、A16、A30、A100 - Ada Lovelace架构(sm89):40 系列 - Hopper 架构(sm90): H100, H200 总结来说,LMDeploy kv 量化具备以下优势: 1. 量化不需要校准数据集 2. 支持 volta 架构(sm70)及以上的所有显卡型号 3. kv int8 量化精度几乎无损,kv int4 量化精度在可接受范围之内 4. 推理高效,在 llama2-7b 上加入 int8/int4 kv 量化,RPS 相较于 fp16 分别提升近 30% 和 40% 接下来,我们以 internlm2_5-7b-chat 模型为例,介绍 kv 量化和推理的若干应用。在此之前,请先安装 lmdeploy: ```shell pip install lmdeploy ``` ## 应用示例 通过 LMDeploy 应用 kv 量化非常简单,只需要设定 `quant_policy` 参数。 **LMDeploy 规定 `quant_policy=4` 表示 kv int4 量化,`quant_policy=8` 表示 kv int8 量化。** ### 离线推理 ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig(quant_policy=8) pipe = pipeline("internlm/internlm2_5-7b-chat", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` ### 推理服务 ```shell lmdeploy serve api_server internlm/internlm2_5-7b-chat --quant-policy 8 ``` ## 精度评测 我们把 lmdeploy 的 kv 量化应用在若干 LLM 模型上,并使用 opencompass 评测推理精度,结果如下表所示: | - | - | - | llama2-7b-chat | - | - | internlm2-chat-7b | - | - | internlm2.5-chat-7b | - | - | qwen1.5-7b-chat | - | - | | ----------- | ------- | ------------- | -------------- | ------- | ------- | ----------------- | ------- | ------- | ------------------- | ------- | ------- | --------------- | ------- | ------- | | dataset | version | metric | kv fp16 | kv int8 | kv int4 | kv fp16 | kv int8 | kv 
int4 | kv fp16 | kv int8 | kv int4 | fp16 | kv int8 | kv int4 | | ceval | - | naive_average | 28.42 | 27.96 | 27.58 | 60.45 | 60.88 | 60.28 | 78.06 | 77.87 | 77.05 | 70.56 | 70.49 | 68.62 | | mmlu | - | naive_average | 35.64 | 35.58 | 34.79 | 63.91 | 64 | 62.36 | 72.30 | 72.27 | 71.17 | 61.48 | 61.56 | 60.65 | | triviaqa | 2121ce | score | 56.09 | 56.13 | 53.71 | 58.73 | 58.7 | 58.18 | 65.09 | 64.87 | 63.28 | 44.62 | 44.77 | 44.04 | | gsm8k | 1d7fe4 | accuracy | 28.2 | 28.05 | 27.37 | 70.13 | 69.75 | 66.87 | 85.67 | 85.44 | 83.78 | 54.97 | 56.41 | 54.74 | | race-middle | 9a54b6 | accuracy | 41.57 | 41.78 | 41.23 | 88.93 | 88.93 | 88.93 | 92.76 | 92.83 | 92.55 | 87.33 | 87.26 | 86.28 | | race-high | 9a54b6 | accuracy | 39.65 | 39.77 | 40.77 | 85.33 | 85.31 | 84.62 | 90.51 | 90.42 | 90.42 | 82.53 | 82.59 | 82.02 | 具体的评测方式可以参考[这份指南](../benchmark/evaluate_with_opencompass.md)。评测时,请在config文件中,为推理引擎添加 `quant_policy` 参数。 ## 推理效率 | model | kv type | test settings | RPS | v.s. kv fp16 | | ----------------- | ------- | ---------------------------------------- | ----- | ------------ | | llama2-chat-7b | fp16 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 14.98 | 1.0 | | - | int8 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 19.01 | 1.27 | | - | int4 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 20.81 | 1.39 | | llama2-chat-13b | fp16 | tp1 / ratio 0.9 / bs 128 / prompts 10000 | 8.55 | 1.0 | | - | int8 | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 10.96 | 1.28 | | - | int4 | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 11.91 | 1.39 | | internlm2-chat-7b | fp16 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 24.13 | 1.0 | | - | int8 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.28 | 1.05 | | - | int4 | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.80 | 1.07 | 上述结果使用的测试脚本是 `benchmark/profile_throughput.py` ================================================ FILE: docs/zh_cn/quantization/llm_compressor.md ================================================ # llm-compressor 支持 本指南旨在介绍如何使用 
LMDeploy 的 TurboMind 推理引擎,运行经由 [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 工具量化后的模型。 目前支持的 `llm-compressor` 量化模型包括: - int4 量化(例如 AWQ、GPTQ) 上述量化模型通过 TurboMind 引擎可以在以下 NVIDIA GPU 架构上运行: | Compute Capability | Micro-architecture | GPUs | | ------------------ | ------------------ | ------------------------------- | | 7.0 | Volta | V100 | | 7.2 | Volta | Jetson Xavier | | 7.5 | Turing | GeForce RTX 20 series, T4 | | 8.0 | Ampere | A100, A800, A30 | | 8.6 | Ampere | GeForce RTX 30 series, A40, A10 | | 8.7 | Ampere | Jetson Orin | | 8.9 | Ada Lovelace | GeForce RTX 40 series, L40, L20 | | 9.0 | Hopper | H20, H200, H100, GH200 | | 12.0 | Blackwell | GeForce RTX 50 series | LMDeploy 将持续跟进并扩展对 `llm-compressor` 项目的支持。 本文的其余部分由以下章节组成: - [模型量化](#模型量化) - [模型部署](#模型部署) - [精度评测](#精度评测) ## 模型量化 `llm-compressor` 提供了丰富的模型量化[用例](https://github.com/vllm-project/llm-compressor/tree/main/examples),请参考其教程选择 LMDeploy 支持的量化算法,完成模型量化工作。 LMDeploy 也内置了通过 `llm-compressor` 对 Qwen3-30B-A3B 进行 AWQ 量化的[脚本](https://github.com/InternLM/lmdeploy/blob/main/examples/lite/qwen3_30b_a3b_awq.py),供大家进行参考: ```shell # 创建 conda 环境 conda create -n lmdeploy python=3.10 -y conda activate lmdeploy # 安装 llm-compressor pip install llmcompressor # 下载 lmdeploy 源码,运行量化用例 git clone https://github.com/InternLM/lmdeploy cd lmdeploy python examples/lite/qwen3_30b_a3b_awq.py --work-dir ./qwen3_30b_a3b_4bit ``` 在接下来的章节中,我们以此量化模型(保存于 `./qwen3_30b_a3b_4bit`)为例,介绍模型部署、精度评测等方法。 ## 模型部署 ### 离线推理 量化后的模型,通过以下几行简单的代码,可以实现离线批处理: ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig() with pipeline("./qwen3_30b_a3b_4bit", backend_config=engine_config) as pipe: response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` 关于 pipeline 的详细介绍,请参考[这里](https://lmdeploy.readthedocs.io/zh-cn/latest/llm/pipeline.html) ### 在线服务 LMDeploy api_server 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例: ```shell lmdeploy serve api_server ./qwen3_30b_a3b_4bit --backend turbomind 
``` 服务默认端口是23333。在 server 启动后,你可以通过 openai SDK 访问服务。关于服务的命令参数,以及访问服务的方式,可以阅读[这份](https://lmdeploy.readthedocs.io/zh-cn/latest/llm/api_server.html)文档 ## 精度评测 我们将 Qwen3-8B (Dense) 与 Qwen3-30B-A3B (MoE) 的 AWQ 对称/非对称量化模型通过 LMDeploy 部署为服务,并使用 [opencompass](https://github.com/open-compass/opencompass) 在多个学术数据集上评测。结果显示:Qwen3-8B 的非对称量化整体优于对称量化,而 Qwen3-30B-A3B 在两种量化方式间差异不显著;Qwen3-8B 在对称/非对称量化下与 BF16 模型的精度差异小于 Qwen3-30B-A3B。与 BF16 相比,量化模型在长输出数据集,比如 aime2025 (平均 17,635 tokens)、LCB (平均 14,157 tokens),精度下降更明显;在中短输出数据集,比如 ifeval (平均 1,885 tokens),mmlu_pro (平均 2,826),精度符合预期。 | dataset | Qwen3-8B | | | Qwen3-30B-A3B | | | | ----------------- | -------- | ------- | -------- | ------------- | ------- | -------- | | | bf16 | awq sym | awq asym | bf16 | awq sym | awq asym | | ifeval | 85.58 | 83.73 | 85.77 | 86.32 | 84.10 | 84.29 | | hle | 5.05 | 5.05 | 5.24 | 7.00 | 5.47 | 5.65 | | gpqa | 59.97 | 56.57 | 59.47 | 61.74 | 57.95 | 57.07 | | aime2025 | 69.48 | 64.38 | 63.96 | 73.44 | 64.79 | 66.67 | | mmlu_pro | 73.69 | 71.73 | 72.34 | 77.85 | 75.77 | 75.69 | | LCBCodeGeneration | 50.86 | 44.10 | 46.95 | 56.67 | 50.86 | 49.24 | 复现方式可以参考[这份](https://lmdeploy.readthedocs.io/zh-cn/latest/benchmark/evaluate_with_opencompass.html)文档 ================================================ FILE: docs/zh_cn/quantization/w4a16.md ================================================ # INT4 模型量化和部署 LMDeploy TurboMind 引擎支持由 [AWQ](https://arxiv.org/abs/2306.00978) 和 [GPTQ](https://github.com/AutoGPTQ/AutoGPTQ) 两种量化方法量化的 4bit 模型的推理。然而,LMDeploy 量化模块目前仅支持 AWQ 量化算法。 可用于 AWQ/GPTQ INT4 推理的 NVIDIA GPU 包括: - V100(sm70): V100 - Turing(sm75): 20 系列,T4 - Ampere(sm80,sm86): 30 系列,A10, A16, A30, A100 - Ada Lovelace(sm89): 40 系列 在进行量化和推理之前,请确保按照[安装指南](../get_started/installation.md)安装了 lmdeploy。 本文的其余部分由以下章节组成: - [模型量化](#模型量化) - [模型评测](#模型评测) - [模型推理](#模型推理) - [推理服务](#推理服务) - [推理性能](#推理性能) ## 模型量化 仅需执行一条命令,就可以完成模型量化工作。量化结束后,权重文件存放在 `$WORK_DIR` 下。 ```shell export HF_MODEL=internlm/internlm2_5-7b-chat export 
WORK_DIR=internlm/internlm2_5-7b-chat-4bit lmdeploy lite auto_awq \ $HF_MODEL \ --calib-dataset 'wikitext2' \ --calib-samples 128 \ --calib-seqlen 2048 \ --w-bits 4 \ --w-group-size 128 \ --batch-size 1 \ --work-dir $WORK_DIR ``` 绝大多数情况下,在执行上述命令时,可选参数可不用填写,使用默认的即可。比如量化 [internlm/internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) 模型,命令可以简化为: ```shell lmdeploy lite auto_awq internlm/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-4bit ``` **Note:** - 我们建议 --work-dir 参数带有模型名字,就像上面的例子展示的那样。这样在推理时,就不用指定对话模板了。因为推理接口会以模糊搜索方式,选出和 --work-dir 近似的对话模板 - 如果量化模型精度有损,建议开启 --search-scale 重新量化,并调大 --batch-size,比如 8。search_scale 开启后,量化过程会比较耗时。--batch-size 会影响内存占用量,可以根据实际情况,酌情调整。 量化后的模型,可以用一些工具快速验证对话效果。 比如,直接在控制台和模型对话, ```shell lmdeploy chat ./internlm2_5-7b-chat-4bit --model-format awq ``` ## 模型评测 我们使用 [OpenCompass](https://opencompass.readthedocs.io/zh-cn/latest/index.html) 评测量化模型在各个维度上的能力。方法请参考[此处](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/evaluation_lmdeploy.html) ## 模型推理 量化后的模型,通过以下几行简单的代码,可以实现离线推理: ```python from lmdeploy import pipeline, TurbomindEngineConfig engine_config = TurbomindEngineConfig(model_format='awq') pipe = pipeline("./internlm2_5-7b-chat-4bit", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` 关于 pipeline 的详细介绍,请参考[这里](../llm/pipeline.md) 除了推理本地量化模型外,LMDeploy 还支持直接推理 huggingface hub 上的通过 AWQ 量化的 4bit 权重模型,比如 [lmdeploy 空间](https://huggingface.co/lmdeploy)和 [TheBloke 空间](https://huggingface.co/TheBloke)下的模型。 ```python # 推理 lmdeploy 空间下的模型 from lmdeploy import pipeline, TurbomindEngineConfig pipe = pipeline("lmdeploy/llama2-chat-70b-4bit", backend_config=TurbomindEngineConfig(model_format='awq', tp=4)) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) # 推理 TheBloke 空间下的模型(试试codellama行不行) from lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig pipe = pipeline("TheBloke/LLaMA2-13B-Tiefighter-AWQ", 
backend_config=TurbomindEngineConfig(model_format='awq'), chat_template_config=ChatTemplateConfig(model_name='llama2') ) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` ## 推理服务 LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例: ```shell lmdeploy serve api_server ./internlm2_5-7b-chat-4bit --backend turbomind --model-format awq ``` 服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话: ```shell lmdeploy serve api_client http://0.0.0.0:23333 ``` 还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](../llm/api_server.md),了解各接口的定义和使用方法。 ## 推理性能 我们在 NVIDIA GeForce RTX 4090 上分别测试了 4-bit Llama-2-7B-chat 和 Llama-2-13B-chat 模型的 token 生成速度。测试配置为 batch size = 1,(prompt_tokens, completion_tokens) = (1, 512) | model | llm-awq | mlc-llm | turbomind | | ---------------- | ------- | ------- | --------- | | Llama-2-7B-chat | 112.9 | 159.4 | 206.4 | | Llama-2-13B-chat | N/A | 90.7 | 115.8 | ## 快速问答 1. 量化时出现 Out of Memory 显存不够:可以通过减小传参 `--calib-seqlen`,增大传参 `--calib-samples`,并使用 `--batch-size` 为 1。 2. 量化时,无法链接huggingface并下载数据集。可以尝试使用镜像,`export HF_ENDPOINT=https://hf-mirror.com`。 ================================================ FILE: docs/zh_cn/quantization/w8a8.md ================================================ # W8A8 LLM 模型部署 LMDeploy 提供了使用 8-bit 整数(INT8)和浮点数(FP8)对神经网络模型进行量化和推理的功能。 可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为: - INT8 - V100(sm70): V100 - Turing(sm75): 20 series, T4 - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100 - Ada Lovelace(sm89): 40 series - Hopper(sm90): H100 - FP8 - Ada Lovelace(sm89): 40 series - Hopper(sm90): H100 首先,执行如下命令安装lmdeploy: ```shell pip install lmdeploy[all] ``` ## 8-bit 权重量化 进行 8-bit 权重量化需要经历以下三步: 1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。 2. **模块替换**:使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。 3. 
**保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。 lmdeploy 提供了命令行工具 `lmdeploy lite smooth_quant` 实现了以上三个步骤。并且其中命令行参数 `--quant-dtype` 可以用来控制是进行8-bit整数还是浮点数类型的量化。更多命令行工具使用方式,请执行 `lmdeploy lite smooth_quant --help` 查看。 以下示例演示了进行 int8 或 fp8 的量化命令。 - int8 ```shell lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8 ``` - fp8 ```shell lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8 ``` ## 模型推理 量化后的模型,通过以下几行简单的代码,可以实现离线推理: ```python from lmdeploy import pipeline, PytorchEngineConfig engine_config = PytorchEngineConfig(tp=1) pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config) response = pipe(["Hi, pls intro yourself", "Shanghai is"]) print(response) ``` 关于 pipeline 的详细介绍,请参考[这里](../llm/pipeline.md) ## 推理服务 LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例: ```shell lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch ``` 服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话: ```shell lmdeploy serve api_client http://0.0.0.0:23333 ``` 还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](../llm/api_server.md),了解各接口的定义和使用方法。 ================================================ FILE: docs/zh_cn/supported_models/reward_models.md ================================================ # 奖励模型 LMDeploy 支持以下奖励模型: | Model | Size | Supported Inference Engine | | :--------------: | :-----------: | :------------------------: | | Qwen2.5-Math-RM | 72B | PyTorch | | InternLM2-Reward | 1.8B, 7B, 20B | PyTorch | | POLAR | 1.8B, 7B | PyTorch | ## 离线推理 我们以 `internlm/internlm2-1_8b-reward` 为例: ```python from transformers import AutoTokenizer from lmdeploy import pipeline, PytorchEngineConfig model_path = "internlm/internlm2-1_8b-reward" chat = [ {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."}, {"role": "user", "content": "Janet’s ducks lay 16 
eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"}, {"role": "assistant", "content": "To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to follow these steps:\n\n1. Calculate the total number of eggs laid by the ducks each day.\n2. Determine how many eggs Janet eats and bakes for herself each day.\n3. Find out how many eggs are left to be sold.\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\n\nLet's start with the first step:\n\n1. Janet's ducks lay 16 eggs per day.\n\nNext, we calculate how many eggs Janet eats and bakes for herself each day:\n\n2. Janet eats 3 eggs for breakfast every morning.\n3. Janet bakes 4 eggs for her friends every day.\n\nSo, the total number of eggs Janet eats and bakes for herself each day is:\n\\[ 3 + 4 = 7 \\text{ eggs} \\]\n\nNow, we find out how many eggs are left to be sold:\n\\[ 16 - 7 = 9 \\text{ eggs} \\]\n\nFinally, we calculate the revenue from selling the remaining eggs at $2 per egg:\n\\[ 9 \\times 2 = 18 \\text{ dollars} \\]\n\nTherefore, Janet makes 18 dollars every day at the farmers' market."} ] tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) conversation_str = tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=False ) input_ids = tokenizer.encode( conversation_str, add_special_tokens=False ) if __name__ == '__main__': engine_config = PytorchEngineConfig(tp=1) with pipeline(model_path, backend_config=engine_config) as pipe: score = pipe.get_reward_score(input_ids) print(f'score: {score}') ``` ## 在线推理 启动 API 服务: ```bash lmdeploy serve api_server internlm/internlm2-1_8b-reward --backend pytorch ``` 通过 `/pooling` 接口获取奖励分数: ``` curl http://0.0.0.0:23333/pooling \ -H "Content-Type: 
application/json" \ -d '{ "model": "internlm/internlm2-1_8b-reward", "input": "Who are you?" }' ``` ================================================ FILE: docs/zh_cn/supported_models/supported_models.md ================================================ # 支持的模型 以下列表分别为 LMDeploy TurboMind 引擎和 PyTorch 引擎在不同软硬件平台下支持的模型 ## TurboMind CUDA 平台 | Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | | :------------------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: | | Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | | InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes | | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Intern-S1 | 241B | MLLM | Yes | Yes | Yes | No | | Intern-S1-mini | 8.3B | MLLM | Yes | Yes | Yes | No | | Intern-S1-Pro | 1TB | MLLM | Yes | - | - | No | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | | Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | | Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | | Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | | Qwen3 | 0.6B-235B | LLM | Yes | Yes | Yes\* | Yes | | Qwen3.5\[3\] | 0.8B-397B | LLM | Yes | Yes | No | Yes | | Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | | DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | | DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | 
Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | | Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | | Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No | | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | | InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | | InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | | InternVL3\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes | | InternVL3.5\[3\] | 1 - 241BA28B | MLLM | Yes | Yes\* | Yes\* | No | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | | GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | | Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | | gpt-oss | 20B,120B | LLM | Yes | Yes | Yes | Yes | “-” 表示还没有验证。 ```{note} * [1] turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine * [2] 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等 * [3] turbomind 目前暂不支持 Qwen3.5 系列的视觉编码器。 ``` ## PyTorchEngine CUDA 平台 | Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 | | :----------------------------: | :-------------: | :--: | :-------: | :-----: | :-----: | :--: | :---: | | Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes | | Llama4 | Scout, Maverick | MLLM | Yes | Yes | Yes | - | - | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes 
| | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | | Intern-S1 | 241B | MLLM | Yes | Yes | Yes | Yes | - | | Intern-S1-mini | 8.3B | MLLM | Yes | Yes | Yes | Yes | - | | Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No | | Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No | | ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No | | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No | | QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes | | QWen3.5 | 0.8B-397B | MLLM | Yes | No | No | No | No | | QWen3-Next | 80B | LLM | Yes | No | No | No | No | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | | QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | | MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | | Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | | Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - | | Phi-4-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes | | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | 
LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No | | InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - | | InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - | | InternVL3 | 1B-78B | MLLM | Yes | Yes | Yes | - | - | | InternVL3.5 | 1B-241BA28B | MLLM | Yes | Yes | Yes | No | No | | Mono-InternVL\[1\] | 2B | MLLM | Yes\* | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | | Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - | | Gemma3 | 1B-27B | MLLM | Yes | Yes | Yes | - | - | | GLM-4 | 9B | LLM | Yes | Yes | Yes | No | No | | GLM-4-0414 | 9B | LLM | Yes | Yes | Yes | - | - | | GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes | | GLM-4.1V-Thinking | 9B | MLLM | Yes | Yes | Yes | - | - | | GLM-4.5 | 355B | LLM | Yes | Yes | Yes | - | - | | GLM-4.5-Air | 106B | LLM | Yes | Yes | Yes | - | - | | GLM-4.7-Flash | 30B | LLM | Yes | No | No | No | No | | GLM-5 | 754B | LLM | Yes | No | No | No | No | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - | | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | | SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | ```{note} * [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16 * [2] 自 0.6.4 之后,PyTorch 引擎移除了对 llava 模型原始格式的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到 自 0.11.1 起,PytorchEngine 移除了 mllama 的支持 ``` ## PyTorchEngine 其他平台 | | | | Atlas 800T A2 | Atlas 800T A2 | Atlas 800T A2 | Atlas 800T A2 | Atlas 300I Duo | Atlas 800T A3 | Maca C500 | Cambricon | | :------------: | :-------: | :--: | :--------------: | :--------------: | :-----------: | :-----------: | :------------: | :--------------: | :-------: | :-------: | | Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) | FP16(graph) | FP16/BF16(eager) | BF/FP16 | BF/FP16 | | Llama2 | 7B - 70B | LLM | 
Yes | Yes | Yes | Yes | - | Yes | Yes | Yes | | Llama3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | Llama3.1 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | Mixtral | 8x7B | LLM | Yes | Yes | No | No | Yes | - | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | - | No | No | - | - | Yes | - | | QWen2(.5) | 7B | LLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | - | - | - | - | Yes | No | | QWen2.5-VL | 3B - 72B | MLLM | Yes | Yes | - | - | Yes | - | Yes | No | | QWen2-MoE | A14.57B | LLM | Yes | - | No | No | - | - | Yes | - | | QWen3 | 0.6B-235B | LLM | Yes | Yes | No | No | Yes | Yes | Yes | Yes | | DeepSeek-V2 | 16B | LLM | No | Yes | No | No | - | - | - | - | | InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | Yes | - | - | Yes | - | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | InternVL2.5 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | InternVL3 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes | | CogVLM2-chat | 19B | MLLM | Yes | No | - | - | - | - | Yes | - | | GLM4V | 9B | MLLM | Yes | No | - | - | - | - | - | - | ================================================ FILE: eval/config.py ================================================ # flake8: noqa from mmengine.config import read_base from opencompass.models import OpenAISDK from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner from opencompass.runners import LocalRunner from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask from opencompass.utils.text_postprocessors import extract_non_reasoning_content # Dataset Configurations with read_base(): # Datasets from 
opencompass.configs.datasets.aime2025.aime2025_llmjudge_academic import aime2025_datasets from opencompass.configs.datasets.gpqa.gpqa_cascade_eval_academic import gpqa_datasets from opencompass.configs.datasets.HLE.hle_llmverify_academic import hle_datasets from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets from opencompass.configs.datasets.livecodebench.livecodebench_v6_academic import LCBCodeGeneration_dataset from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_nocot_genericllmeval_gen_08c1de import mmlu_pro_datasets # Summary Groups from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups # datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + [LCBCodeGeneration_dataset] # TASK_TAG = '' # NOTE: the empty / 'http://' placeholder values in this file are rewritten by eval/eval.py with plain string replacement, so keep their exact spelling. API_SERVER_ADDR = 'http://' SERVED_MODEL_PATH = '' models = [ dict(abbr=TASK_TAG, key='dummy', openai_api_base=f'{API_SERVER_ADDR}/v1', type=OpenAISDK, path=SERVED_MODEL_PATH, temperature=0.6, meta_template=dict(round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), ], ), query_per_second=10, max_out_len=64000, max_seq_len=65536, batch_size=32, retry=10, pred_postprocessor=dict(type=extract_non_reasoning_content), verbose=False) ] JUDGER_ADDR = 'http://' JUDGER_MODEL_PATH = '' judge_cfg = dict( abbr='CompassVerifier', type=OpenAISDK, path=JUDGER_MODEL_PATH, key='YOUR_API_KEY', openai_api_base=f'{JUDGER_ADDR}/v1', meta_template=dict(round=[ dict(role='HUMAN', api_role='HUMAN'), dict(role='BOT', api_role='BOT', generate=True), ]), query_per_second=8, batch_size=32, temperature=0.001, max_out_len=8192, max_seq_len=65536, mode='mid', ) # Attach the shared judge_cfg to every evaluator, and to any nested llm_evaluator, that already declares a judge_cfg entry. for item in datasets: if 'judge_cfg' in item['eval_cfg']['evaluator']: item['eval_cfg']['evaluator']['judge_cfg'] = judge_cfg if 'llm_evaluator' in item['eval_cfg']['evaluator'].keys( ) and 'judge_cfg' in
item['eval_cfg']['evaluator']['llm_evaluator']: item['eval_cfg']['evaluator']['llm_evaluator']['judge_cfg'] = judge_cfg ####################################################################### # Dataset Summarizer # ####################################################################### core_summary_groups = [ { 'name': 'core_average', 'subsets': [ ['IFEval', 'Prompt-level-strict-accuracy'], ['hle_llmjudge', 'accuracy'], ['aime2025_repeat_32', 'accuracy (32 runs average)'], ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'], ['mmlu_pro', 'naive_average'], ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'], ], }, ] summarizer = dict( dataset_abbrs=[ ['core_average', 'naive_average'], '', 'Instruction Following', ['IFEval', 'Prompt-level-strict-accuracy'], '', 'General Reasoning', ['hle_llmjudge', 'accuracy'], ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'], '', 'Math Calculation', ['aime2025_repeat_32', 'accuracy (32 runs average)'], '', 'Knowledge', ['mmlu_pro', 'naive_average'], '', 'Code', ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'], ], summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []), ) ####################################################################### # Inference/Evaluation Configuration # ####################################################################### # infer with local runner infer = dict( partitioner=dict(type=NumWorkerPartitioner, num_worker=8), runner=dict( type=LocalRunner, max_num_workers=16, retry=0, # Modify if needed task=dict(type=OpenICLInferTask), ), ) # eval with local runner eval = dict( partitioner=dict(type=NaivePartitioner, n=10), runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLEvalTask)), ) ================================================ FILE: eval/eval.py ================================================ import argparse import os import signal import subprocess import sys from datetime import datetime class ProcessManager: 
"""Manager for subprocess execution with proper signal handling.""" def __init__(self): self.process = None self.original_handlers = {} def __enter__(self): """Context manager entry - setup signal handlers""" # Save original signal handlers self.original_handlers[signal.SIGINT] = signal.getsignal(signal.SIGINT) self.original_handlers[signal.SIGTERM] = signal.getsignal(signal.SIGTERM) # Register new signal handlers signal.signal(signal.SIGINT, self._signal_handler) signal.signal(signal.SIGTERM, self._signal_handler) return self def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit - restore original signal handlers""" # Restore original signal handlers for sig, handler in self.original_handlers.items(): signal.signal(sig, handler) def _signal_handler(self, sig, frame): """Handle termination signals.""" signal_name = 'SIGINT' if sig == signal.SIGINT else 'SIGTERM' print(f'\nReceived {signal_name}, cleaning up subprocess...') self.cleanup() sys.exit(0) def start_process(self, cmd): self.process = subprocess.Popen(cmd) return self.process def cleanup(self): if self.process and self.process.poll() is None: print('Terminating subprocess...') self.process.terminate() try: self.process.wait(timeout=5) print('Subprocess terminated successfully') except subprocess.TimeoutExpired: print('Subprocess did not terminate normally, forcing kill...') self.process.kill() self.process.wait() print('Subprocess killed') def read_config(): """Get configuration content from config file in script directory. 
Returns: str: Configuration file content, returns None if reading fails """ script_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(script_dir, 'config.py') # Read config file content try: with open(config_path, 'r', encoding='utf-8') as f: config_content = f.read() return config_content except FileNotFoundError: print(f'Error: Config file not found at {config_path}') return None except Exception as e: print(f'Error reading config file: {e}') return None def update_datasets(config, datasets): """Update datasets part in config according to datasets list. Args: config (str): Original configuration content datasets (list[str]): List of dataset names to include Returns: str: Updated configuration content """ if 'all' in datasets: # datasets part of the config file specifies all datasets, no need to update return config selected_datasets = [] if 'code' in datasets: selected_datasets.append('[LCBCodeGeneration_dataset]') datasets.remove('code') for d in datasets: selected_datasets.append(f'{d}_datasets') selected_datasets = ' + '.join(selected_datasets) selected_datasets = f'datasets = {selected_datasets}' # replace datasets part in config start_tag = '# ' end_tag = '# ' start_index = config.find(start_tag) end_index = config.find(end_tag) if start_index == -1 or end_index == -1: raise ValueError('replace tag not found in config file') end_index += len(end_tag) replacement = f'{start_tag}\n{selected_datasets}\n{end_tag}' result = config[:start_index] + replacement + config[end_index:] return result def get_model_name_from_server(server: str, tag: str) -> str: from openai import OpenAI try: client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{server}/v1') model_name = client.models.list().data[0].id return model_name except Exception as e: raise RuntimeError(f'Failed to get model name from {tag}_server {server}: {e}') def save_config(work_dir: str, config: str): """Save configuration content to a file in the specified directory. 
Args: work_dir (str): Directory to save the configuration file config (str): Configuration content to save """ if not work_dir: return os.makedirs(work_dir, exist_ok=True) output_file = os.path.join(work_dir, 'config.py') with open(output_file, 'w', encoding='utf-8') as f: f.write(config) print(f'Config written to {output_file}') def perform_evaluation(config, api_server, judger_server, mode, work_dir, reuse): """Perform model evaluation by opencompass. Args: config (str): Configuration content api_server (str): API server address for inference judger_server (str): Judger server address for evaluation mode (str): Running mode selection, options: infer, eval, all, config work_dir (str): Output directory for evaluation results. If not specified, config will not be saved and execution will not be performed. reuse (str): Whether to reuse existing results """ if mode in ['infer', 'all', 'config']: served_model_name = get_model_name_from_server(api_server, 'api') config = config.replace("SERVED_MODEL_PATH = ''", f"SERVED_MODEL_PATH = '{served_model_name}'") if mode in ['eval', 'all', 'config']: judger_model_name = get_model_name_from_server(judger_server, 'judger') config = config.replace("JUDGER_MODEL_PATH = ''", f"JUDGER_MODEL_PATH = '{judger_model_name}'") # write updated config to work_dir if work_dir: save_config(work_dir, config) if mode == 'config': return else: print(config) return # execute opencompass command cmd = ['opencompass', f'{work_dir}/config.py', '-m', mode, '-w', work_dir] if reuse: # reuse previous outputs & results. If reuse is a string, it indicates a specific timestamp. try: datetime.strptime(reuse, '%Y%m%d_%H%M%S') cmd.extend(['-r', str(reuse)]) except ValueError as e: print(e) raise ValueError(f'Invalid reuse timestamp format: {reuse}. 
Expected format: YYYYMMDD_HHMMSS') from e try: print(f'Executing command: {" ".join(cmd)}') # result = subprocess.run(cmd, text=True, check=True) # return result with ProcessManager() as manager: process = manager.start_process(cmd) result = process.wait() return subprocess.CompletedProcess(cmd, result) except Exception as e: print(f'Executing commanded failed with {e}') return def main(): parser = argparse.ArgumentParser(description='Perform model evaluation') parser.add_argument('task_name', type=str, help='The name of an evaluation task') parser.add_argument('-a', '--api-server', type=str, default='', help='API server address for inference') parser.add_argument('-j', '--judger-server', type=str, default='', help='Judger server address for evaluation') dataset_choices = ['aime2025', 'gpqa', 'ifeval', 'code', 'mmlu_pro', 'hle', 'all'] parser.add_argument('-d', '--datasets', nargs='+', choices=dataset_choices, default=['all'], help=f"List of datasets. Available options: {', '.join(dataset_choices)}. " 'Use "all" to include all datasets.') parser.add_argument('-w', '--work-dir', type=str, default='', help='Output directory of evaluation. If not specified, outputs will not be saved.') parser.add_argument('-r', '--reuse', nargs='?', type=str, const='latest', help='Reuse previous outputs & results, and run any missing jobs presented in the config. ' 'If its argument is not specified, the latest results in the work_dir will be reused. ' 'The argument should also be a specific timestamp, e.g. 20230516_144254') parser.add_argument('-m', '--mode', type=str, help='Running mode selection. ' 'all: complete pipeline including both inference and evaluation (default). ' 'infer: only perform model inference to generate results. ' 'eval: only evaluate previously generated results. 
' 'config: generate configuration files without execution.', choices=['all', 'infer', 'eval', 'config'], default='all') args = parser.parse_args() task_name = args.task_name api_server = args.api_server judger_server = args.judger_server datasets = args.datasets mode = args.mode work_dir = args.work_dir # Process server addresses if api_server and not api_server.startswith('http'): api_server = f'http://{api_server}' if judger_server and not judger_server.startswith('http'): judger_server = f'http://{judger_server}' # read config file config = read_config() # update task name in config config = config.replace("TASK_TAG = ''", f"TASK_TAG = '{task_name}'") # update datasets part of config according to args.datasets config = update_datasets(config, datasets) # update api_server part of config according to args.api_server if api_server: config = config.replace("API_SERVER_ADDR = 'http://'", f"API_SERVER_ADDR = '{api_server}'") if judger_server: # update judger_server part of config according to args.judger_server config = config.replace("JUDGER_ADDR = 'http://'", f"JUDGER_ADDR = '{judger_server}'") # perform evaluation perform_evaluation(config, api_server, judger_server, mode, work_dir, args.reuse) if __name__ == '__main__': main() ================================================ FILE: examples/lite/qwen3_30b_a3b_awq.py ================================================ import argparse from datasets import load_dataset from llmcompressor import oneshot from llmcompressor.modifiers.awq import AWQModifier from transformers import AutoModelForCausalLM, AutoTokenizer def parse_args(): parser = argparse.ArgumentParser(description='Run AWQ quantization for Qwen3 model') parser.add_argument('--work-dir', type=str, default='./qwen3_30b_a3b_awq', required=True, help='The directory to save the quantized model') parser.add_argument('--model-id', type=str, default='Qwen/Qwen3-30B-A3B', help='The Hugging Face model ID to quantize') return parser.parse_args() def main(): # 1. 
Parse command-line args args = parse_args() MODEL_ID = args.model_id SAVE_DIR = args.work_dir print(f'Loading model: {MODEL_ID}') print(f'Saving to: {SAVE_DIR}') # 2. Load model and tokenizer model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype='auto', device_map='auto', trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) # 3. Prepare calibration dataset DATASET_ID = 'neuralmagic/calibration' DATASET_SPLIT = 'train' NUM_CALIBRATION_SAMPLES = 256 MAX_SEQUENCE_LENGTH = 512 def get_calib_dataset(tokenizer): ds = load_dataset( DATASET_ID, 'LLM', split=f'{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]', ) def preprocess(example): messages = [] for message in example['messages']: if message['role'] == 'user': messages.append({'role': 'user', 'content': message['content']}) elif message['role'] == 'assistant': messages.append({'role': 'assistant', 'content': message['content']}) return tokenizer( tokenizer.apply_chat_template( messages, tokenize=False, ), padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, ) ds = (ds.shuffle(seed=42).map(preprocess, remove_columns=ds.column_names).select(range(NUM_CALIBRATION_SAMPLES))) return ds # 4. Configure quant args (W4A16_ASYM AWQ) recipe = [ AWQModifier( ignore=['lm_head', 're:.*mlp.gate$'], scheme='W4A16_ASYM', targets=['Linear'], duo_scaling='both', ), ] # 5. Run quantization print('Starting quantization...') oneshot( model=model, dataset=get_calib_dataset(tokenizer), recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, log_dir=None, ) # 6.
Save quantized model print('Saving model...') model.save_pretrained(SAVE_DIR) tokenizer.save_pretrained(SAVE_DIR) print(f'Successfully saved to {SAVE_DIR}') if __name__ == '__main__': main() ================================================ FILE: examples/lite/qwen3_30b_a3b_gptq.py ================================================ import argparse from datasets import load_dataset from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier from transformers import AutoModelForCausalLM, AutoTokenizer def parse_args(): parser = argparse.ArgumentParser(description='Run GPTQ quantization for Qwen3 model') # NOTE(review): required=True means the default of --work-dir below is never used parser.add_argument('--work-dir', type=str, default='./qwen3_30b_a3b_gptq', required=True, help='The directory to save the quantized model') parser.add_argument('--model-id', type=str, default='Qwen/Qwen3-30B-A3B', help='The Hugging Face model ID to quantize') return parser.parse_args() def main(): # 1. Parse command-line args args = parse_args() MODEL_ID = args.model_id SAVE_DIR = args.work_dir print(f'Loading model: {MODEL_ID}') print(f'Saving to: {SAVE_DIR}') # 2. Load model and tokenizer model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype='auto', device_map='auto', trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) # 3.
Prepare calibration dataset DATASET_ID = 'neuralmagic/calibration' DATASET_SPLIT = 'train' NUM_CALIBRATION_SAMPLES = 256 MAX_SEQUENCE_LENGTH = 512 def get_calib_dataset(tokenizer): ds = load_dataset( DATASET_ID, 'LLM', split=f'{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]', ) def preprocess(example): messages = [] for message in example['messages']: if message['role'] == 'user': messages.append({'role': 'user', 'content': message['content']}) elif message['role'] == 'assistant': messages.append({'role': 'assistant', 'content': message['content']}) return tokenizer( tokenizer.apply_chat_template( messages, tokenize=False, ), padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, ) ds = (ds.shuffle(seed=42).map(preprocess, remove_columns=ds.column_names).select(range(NUM_CALIBRATION_SAMPLES))) return ds # 4. Configure quant args (W4A16_ASYM GPTQ) recipe = [ GPTQModifier(targets='Linear', scheme='W4A16_ASYM', ignore=['lm_head', 're:.*mlp.gate$']), ] # 5. Run quantization print('Starting quantization...') oneshot( model=model, dataset=get_calib_dataset(tokenizer), recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, log_dir=None, ) # 6. Save quantized model print('Saving model...') model.save_pretrained(SAVE_DIR) tokenizer.save_pretrained(SAVE_DIR) print(f'Successfully saved to {SAVE_DIR}') if __name__ == '__main__': main() ================================================ FILE: generate.sh ================================================ #!/bin/bash WORKSPACE_PATH=$(dirname "$(readlink -f "$0")") builder="-G Ninja" if [ "$1" == "make" ]; then builder="" fi cmake ${builder} ..
\ -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DCMAKE_INSTALL_PREFIX=${WORKSPACE_PATH}/install \ -DBUILD_PY_FFI=ON \ -DBUILD_MULTI_GPU=ON \ -DCMAKE_CUDA_FLAGS="-lineinfo" \ -DUSE_NVTX=ON \ -DFETCHCONTENT_QUIET=OFF ================================================ FILE: k8s/deployment.yaml ================================================ apiVersion: apps/v1 kind: Deployment metadata: labels: app: internlm2-chat-7b name: internlm2-chat-7b spec: replicas: 1 selector: matchLabels: app: internlm2-chat-7b strategy: {} template: metadata: labels: app: internlm2-chat-7b spec: containers: - name: internlm2-chat-7b image: openmmlab/lmdeploy:latest command: - /bin/sh - -c args: - "lmdeploy serve api_server internlm/internlm2-chat-7b --server-port 23333" env: - name: HUGGING_FACE_HUB_TOKEN value: "{{HUGGING_FACE_HUB_TOKEN}}" ports: - containerPort: 23333 protocol: TCP name: main resources: limits: cpu: "16" memory: 64Gi nvidia.com/gpu: "1" requests: cpu: "16" memory: 64Gi nvidia.com/gpu: "1" readinessProbe: failureThreshold: 3 initialDelaySeconds: 400 periodSeconds: 10 successThreshold: 1 tcpSocket: port: main timeoutSeconds: 1 livenessProbe: failureThreshold: 3 initialDelaySeconds: 900 periodSeconds: 20 successThreshold: 1 tcpSocket: port: main timeoutSeconds: 1 volumeMounts: - mountPath: /root/.cache/huggingface name: model-data - mountPath: /dev/shm name: dshm volumes: - name: model-data hostPath: path: /root/.cache/huggingface type: DirectoryOrCreate - emptyDir: medium: Memory name: dshm ================================================ FILE: k8s/service.yaml ================================================ apiVersion: v1 kind: Service metadata: labels: app: internlm2-chat-7b name: internlm2-chat-7b-svc spec: ports: - name: main port: 23333 protocol: TCP targetPort: main selector: app: internlm2-chat-7b type: ClusterIP ================================================ FILE: lmdeploy/__init__.py ================================================ # 
Copyright (c) OpenMMLab. All rights reserved. from .api import client, pipeline, serve from .messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, VisionConfig from .model import ChatTemplateConfig from .pipeline import Pipeline from .tokenizer import Tokenizer from .version import __version__, version_info __all__ = [ 'pipeline', 'serve', 'client', 'Tokenizer', 'GenerationConfig', '__version__', 'version_info', 'ChatTemplateConfig', 'PytorchEngineConfig', 'TurbomindEngineConfig', 'VisionConfig', 'Pipeline' ] ================================================ FILE: lmdeploy/__main__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .cli import run if __name__ == '__main__': run() ================================================ FILE: lmdeploy/api.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from __future__ import annotations from typing import TYPE_CHECKING, List, Literal from typing_extensions import deprecated from .pipeline import Pipeline if TYPE_CHECKING: from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig from .model import ChatTemplateConfig def pipeline(model_path: str, backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None, chat_template_config: 'ChatTemplateConfig' | None = None, log_level: str = 'WARNING', max_log_len: int | None = None, speculative_config: 'SpeculativeConfig' | None = None, **kwargs): """ Args: model_path: the path of a model. It could be one of the following options: - i) A local directory path of a turbomind model which is converted by ``lmdeploy convert`` command or download from ii) and iii). - ii) The model_id of a lmdeploy-quantized model hosted inside a model repo on huggingface.co, such as ``InternLM/internlm-chat-20b-4bit``, ``lmdeploy/llama2-chat-70b-4bit``, etc. 
- iii) The model_id of a model hosted inside a model repo on huggingface.co, such as ``internlm/internlm-chat-7b``, ``Qwen/Qwen-7B-Chat``, ``baichuan-inc/Baichuan2-7B-Chat`` and so on. backend_config: backend config instance. Default to None. chat_template_config: chat template configuration. Default to None. log_level: set log level whose value among [``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``] max_log_len: Max number of prompt characters or prompt tokens being printed in log Examples: .. code-block:: python # LLM import lmdeploy pipe = lmdeploy.pipeline('internlm/internlm-chat-7b') response = pipe(['hi','say this is a test']) print(response) # VLM from lmdeploy.vl import load_image from lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig pipe = pipeline('liuhaotian/llava-v1.5-7b', backend_config=TurbomindEngineConfig(session_len=8192), chat_template_config=ChatTemplateConfig(model_name='vicuna')) im = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg') response = pipe([('describe this image', [im])]) print(response) """ # noqa E501 return Pipeline(model_path, backend_config=backend_config, chat_template_config=chat_template_config, log_level=log_level, max_log_len=max_log_len, speculative_config=speculative_config, **kwargs) @deprecated('This function is no longer available. Please use CLI command "lmdeploy serve api_server" instead.') def serve(model_path: str, model_name: str | None = None, backend: Literal['turbomind', 'pytorch'] = 'turbomind', backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None, chat_template_config: 'ChatTemplateConfig' | None = None, server_name: str = '0.0.0.0', server_port: int = 23333, log_level: str = 'ERROR', api_keys: List[str] | str | None = None, ssl: bool = False, **kwargs): """This function is deprecated and no longer available. .. deprecated:: This function has been removed. Please use alternative methods. 
This will run the api_server in a subprocess. """ # noqa E501 raise NotImplementedError("The 'serve' function is no longer available. " 'This function has been deprecated and removed.') @deprecated('This function is no longer available. Please use "from lmdeploy.serve import APIClient" instead.') def client(api_server_url: str = 'http://0.0.0.0:23333', api_key: str | None = None, **kwargs): """This function is deprecated and no longer available. .. deprecated:: This function has been removed. Please use ``from lmdeploy.serve import APIClient`` instead. Args: api_server_url: communicating address ``http://:`` of api_server api_key: api key. Default to None, which means no api key will be used. Return: Chatbot for LLaMA series models with turbomind as inference engine. """ raise NotImplementedError("The 'client' function is no longer available. This function has been deprecated. " ' Please use "from lmdeploy.serve import APIClient" instead.') ================================================ FILE: lmdeploy/archs.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import os from typing import Dict, List, Literal, Tuple from transformers import AutoConfig from .messages import PytorchEngineConfig, TurbomindEngineConfig from .utils import get_logger logger = get_logger('lmdeploy') def autoget_backend(model_path: str) -> Literal['turbomind', 'pytorch']: """Get backend type in auto backend mode. Args: model_path (str): the path of a model. It could be one of the following options: - i) A local directory path of a turbomind model which is converted by `lmdeploy convert` command or download from ii) and iii). - ii) The model_id of a lmdeploy-quantized model hosted inside a model repo on huggingface.co, such as "InternLM/internlm-chat-20b-4bit", "lmdeploy/llama2-chat-70b-4bit", etc. 
- iii) The model_id of a model hosted inside a model repo on huggingface.co, such as "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. Returns: str: the backend type. """ turbomind_has = False is_turbomind_installed = True try: from lmdeploy.turbomind.supported_models import is_supported as is_supported_turbomind turbomind_has = is_supported_turbomind(model_path) except ImportError: is_turbomind_installed = False if is_turbomind_installed: if not turbomind_has: logger.warning('Fallback to pytorch engine because ' f'{model_path!r} not supported by turbomind' ' engine.') else: logger.warning('Fallback to pytorch engine because turbomind engine is not ' 'installed correctly. If you insist to use turbomind engine, ' 'you may need to reinstall lmdeploy from pypi or build from ' 'source and try again.') backend = 'turbomind' if turbomind_has else 'pytorch' return backend def autoget_backend_config( model_path: str, backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None ) -> Tuple[Literal['turbomind', 'pytorch'], PytorchEngineConfig | TurbomindEngineConfig]: """Get backend config automatically. Args: model_path (str): The input model path. backend_config (TurbomindEngineConfig | PytorchEngineConfig): The input backend config. Default to None. Returns: (PytorchEngineConfig | TurbomindEngineConfig): The auto-determined backend engine config. 
""" from dataclasses import asdict if isinstance(backend_config, PytorchEngineConfig): return 'pytorch', backend_config backend = autoget_backend(model_path) config = PytorchEngineConfig() if backend == 'pytorch' else TurbomindEngineConfig() if backend_config is not None: if type(backend_config) == type(config): config = backend_config else: data = asdict(backend_config) for k, v in data.items(): if v and hasattr(config, k): setattr(config, k, v) # map attributes with different names if type(backend_config) is TurbomindEngineConfig: config.block_size = backend_config.cache_block_seq_len else: config.cache_block_seq_len = backend_config.block_size return backend, config def check_vl_llm(backend: str, config: dict) -> bool: """Check if the model is a vl model from model config.""" if 'auto_map' in config: for _, v in config['auto_map'].items(): if 'InternLMXComposer2ForCausalLM' in v: return True if 'language_config' in config and 'vision_config' in config and config['language_config'].get( 'architectures', [None])[0] == 'DeepseekV2ForCausalLM': return True arch = config['architectures'][0] supported_archs = set([ 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', 'CogVLMForCausalLM', 'InternLMXComposer2ForCausalLM', 'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration' ]) if arch == 'QWenLMHeadModel' and 'visual' in config: return True elif arch == 'MultiModalityCausalLM' and 
'language_config' in config: return True elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config: return True elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] and backend == 'turbomind': return False elif arch in supported_archs: return True return False def get_task(backend: str, model_path: str): """Get pipeline type and pipeline class from model config.""" from lmdeploy.serve.core import AsyncEngine if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')): # workspace model return 'llm', AsyncEngine _, config = get_model_arch(model_path) if check_vl_llm(backend, config.to_dict()): from lmdeploy.serve.core import VLAsyncEngine return 'vlm', VLAsyncEngine # default task, pipeline_class return 'llm', AsyncEngine def get_model_arch(model_path: str): """Get a model's architecture and configuration. Args: model_path(str): the model path """ try: cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) except Exception as e: # noqa from transformers import PretrainedConfig cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) _cfg = cfg.to_dict() if _cfg.get('architectures', None): arch = _cfg['architectures'][0] if _cfg.get('auto_map'): for _, v in _cfg['auto_map'].items(): if 'InternLMXComposer2ForCausalLM' in v: arch = 'InternLMXComposer2ForCausalLM' elif _cfg.get('auto_map', None) and 'AutoModelForCausalLM' in _cfg['auto_map']: arch = _cfg['auto_map']['AutoModelForCausalLM'].split('.')[-1] elif _cfg.get('language_config', None) and _cfg['language_config'].get( 'auto_map', None) and 'AutoModelForCausalLM' in _cfg['language_config']['auto_map']: arch = _cfg['language_config']['auto_map']['AutoModelForCausalLM'].split('.')[-1] else: raise RuntimeError(f'Could not find model architecture from config: {_cfg}') return arch, cfg def search_nested_config(config, key): """Recursively searches for the value associated with the given key in a nested 
configuration of a model.""" if isinstance(config, Dict): for k, v in config.items(): if k == key: return v if isinstance(v, (Dict, List)): result = search_nested_config(v, key) if result is not None: return result elif isinstance(config, List): for item in config: result = search_nested_config(item, key) if result is not None: return result return None ================================================ FILE: lmdeploy/cli/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .entrypoint import run __all__ = ['run'] ================================================ FILE: lmdeploy/cli/chat.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import closing import fire from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline from lmdeploy.archs import autoget_backend def input_prompt(): """Input a prompt in the consolo interface.""" print('\ndouble enter to end input >>> ', end='') sentinel = '' # ends when this string is seen return '\n'.join(iter(input, sentinel)) def build_pipe(model_path, backend, **kwargs): engine_config = None if kwargs.get('enable_prefix_caching', False): print('interactive chat cannot be used when prefix caching is enabled') exit(-1) if backend == 'turbomind': engine_config = TurbomindEngineConfig() for key, value in kwargs.items(): if hasattr(TurbomindEngineConfig, key): setattr(engine_config, key, value) else: engine_config = PytorchEngineConfig() for key, value in kwargs.items(): key = 'device_type' if key == 'device' else key if hasattr(PytorchEngineConfig, key): setattr(engine_config, key, value) if kwargs.get('adapters', None): from .utils import get_lora_adapters adapters = get_lora_adapters(kwargs['adapters']) engine_config.adapters = adapters # disable metrics to avoid installing prometheus_client, which is not needed # in interactive chat engine_config.enable_metrics = False # set 
chat template config chat_template = kwargs.get('chat_template', None) chat_template_config = None if chat_template: from .utils import get_chat_template chat_template_config = get_chat_template(chat_template, model_path) pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, log_level='ERROR', **kwargs) return pipe def build_gen_config(**kwargs): gen_config = GenerationConfig(do_sample=True, max_new_tokens=4096) for key, value in kwargs.items(): if hasattr(GenerationConfig, key): setattr(gen_config, key, value) return gen_config def get_adapter_name(adapters=None, **kwargs): if adapters is None: return None from .utils import get_lora_adapters adapters = get_lora_adapters(adapters) return list(adapters.keys())[0] def main(model_path, backend, **kwargs): if backend != 'pytorch': # set auto backend mode backend = autoget_backend(model_path) quit = False with build_pipe(model_path, backend, **kwargs) as pipe: gen_config = build_gen_config(**kwargs) adapter_name = get_adapter_name(**kwargs) while not quit: with closing(pipe.session()) as sess: while True: try: prompt = input_prompt() except KeyboardInterrupt: quit = True break if prompt == 'end': sess.close() break if prompt == 'exit': quit = True break if prompt.strip() == '': continue resps = pipe.chat(prompt, session=sess, gen_config=gen_config, adapter_name=adapter_name, stream_response=True) try: for resp in resps: print(resp.text, end='', flush=True) except KeyboardInterrupt: sess.abort() else: print('exiting...') if __name__ == '__main__': fire.Fire(main) ================================================ FILE: lmdeploy/cli/cli.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
import os from ..version import __version__ from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args, get_speculative_config) class CLI(object): _desc = 'The CLI provides a unified API for converting, ' \ 'compressing and deploying large language models.' parser = FlexibleArgumentParser(prog='lmdeploy', description=_desc, add_help=True) parser.add_argument('-v', '--version', action='version', version=__version__) subparsers = parser.add_subparsers(title='Commands', description='lmdeploy has following commands:', dest='command') @staticmethod def add_parser_chat(): """Add parser for list command.""" parser = CLI.subparsers.add_parser('chat', formatter_class=DefaultsAndTypesHelpFormatter, description=CLI.chat.__doc__, help=CLI.chat.__doc__) parser.set_defaults(run=CLI.chat) parser.add_argument('model_path', type=str, help='The path of a model. it could be one of the following ' 'options: - i) a local directory path of a turbomind model' ' which is converted by `lmdeploy convert` command or ' 'download from ii) and iii). - ii) the model_id of a ' 'lmdeploy-quantized model hosted inside a model repo on ' 'huggingface.co, such as "internlm/internlm-chat-20b-4bit",' ' "lmdeploy/llama2-chat-70b-4bit", etc. 
- iii) the model_id' ' of a model hosted inside a model repo on huggingface.co,' ' such as "internlm/internlm-chat-7b", "qwen/qwen-7b-chat "' ', "baichuan-inc/baichuan2-7b-chat" and so on') # common args ArgumentHelper.backend(parser) # chat template args ArgumentHelper.chat_template(parser) # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) ArgumentHelper.dllm_block_length(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(cache_max_entry_act) tb_group._group_actions.append(prefix_caching_act) tb_group._group_actions.append(quant_policy) ArgumentHelper.model_format(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) ArgumentHelper.cp(tb_group) ArgumentHelper.async_(tb_group) # speculative decoding ArgumentHelper.add_spec_group(parser) @staticmethod def add_parser_checkenv(): """Add parser for check_env command.""" parser = CLI.subparsers.add_parser('check_env', formatter_class=DefaultsAndTypesHelpFormatter, description=CLI.check_env.__doc__, help=CLI.check_env.__doc__) parser.set_defaults(run=CLI.check_env) parser.add_argument('--dump-file', type=str, default=None, help='The file path to save env info. 
Only ' 'support file format in `json`, `yml`,' ' `pkl`') @staticmethod def check_env(args): """Check the environmental information.""" import importlib import mmengine from mmengine.utils import get_git_hash from mmengine.utils.dl_utils import collect_env from lmdeploy.version import __version__ env_info = collect_env() env_info['LMDeploy'] = __version__ + '+' + get_git_hash()[:7] # remove some unnecessary info remove_reqs = ['MMEngine', 'OpenCV'] for req in remove_reqs: if req in env_info: env_info.pop(req) # extra important dependencies extra_reqs = ['transformers', 'fastapi', 'pydantic', 'triton'] for req in extra_reqs: try: env_info[req] = importlib.import_module(req).__version__ except Exception: env_info[req] = 'Not Found' def get_gpu_topo(): import subprocess import sys if sys.platform.startswith('linux'): try: res = subprocess.run(['nvidia-smi', 'topo', '-m'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) if res.returncode == 0: return '\n' + res.stdout else: return None except FileNotFoundError: return None else: return None gpu_topo = get_gpu_topo() if gpu_topo is not None: env_info['NVIDIA Topology'] = gpu_topo # print env info for k, v in env_info.items(): print(f'{k}: {v}') # dump to local file dump_file = args.dump_file if dump_file is not None: work_dir, _ = os.path.split(dump_file) if work_dir: os.makedirs(work_dir, exist_ok=True) mmengine.dump(env_info, dump_file) @staticmethod def chat(args): from .chat import main kwargs = convert_args(args) speculative_config = get_speculative_config(args) to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens'] for key in to_remove: kwargs.pop(key) kwargs['speculative_config'] = speculative_config main(**kwargs) @staticmethod def add_parsers(): """Add all parsers.""" CLI.add_parser_checkenv() CLI.add_parser_chat() ================================================ FILE: lmdeploy/cli/entrypoint.py ================================================ # 
Copyright (c) OpenMMLab. All rights reserved. import os import sys from .cli import CLI from .lite import SubCliLite from .serve import SubCliServe def run(): """The entry point of running LMDeploy CLI.""" args = sys.argv[1:] CLI.add_parsers() SubCliServe.add_parsers() SubCliLite.add_parsers() parser = CLI.parser args = parser.parse_args() if hasattr(args, 'model_name'): # if `model_name` is not specified, use the model_path instead. The # 'model_path' could be a a local path, or a repo id from hub args.model_name = args.model_name if args.model_name else \ args.model_path if 'run' in dir(args): from lmdeploy.utils import get_model model_path = getattr(args, 'model_path', None) revision = getattr(args, 'revision', None) download_dir = getattr(args, 'download_dir', None) if model_path is not None and not os.path.exists(args.model_path): args.model_path = get_model(args.model_path, download_dir=download_dir, revision=revision) model_path_or_server = getattr(args, 'model_path_or_server', None) if model_path_or_server is not None and (':' not in model_path_or_server and not os.path.exists(model_path_or_server)): args.model_path_or_server = get_model(args.model_path_or_server, download_dir=download_dir, revision=revision) args.run(args) else: try: args.print_help() except AttributeError: command = args.command if command == 'serve': SubCliServe.parser.print_help() elif command == 'lite': SubCliLite.parser.print_help() else: parser.print_help() ================================================ FILE: lmdeploy/cli/lite.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from .cli import CLI from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args class SubCliLite(object): """CLI for compressing LLMs.""" _help = 'Compressing and accelerating LLMs with lmdeploy.lite module' _desc = _help parser = CLI.subparsers.add_parser( 'lite', help=_help, description=_desc, ) subparsers = parser.add_subparsers(title='Commands', description='This group has the following commands:') @staticmethod def add_parser_auto_awq(): """Add parser for auto_awq command.""" parser = SubCliLite.subparsers.add_parser('auto_awq', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliLite.auto_awq.__doc__, help=SubCliLite.auto_awq.__doc__) parser.set_defaults(run=SubCliLite.auto_awq) parser.add_argument('model', type=str, help='The path of model in hf format') ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) ArgumentHelper.work_dir(parser) ArgumentHelper.calib_dataset(parser) ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) ArgumentHelper.dtype(parser) parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)') parser.add_argument('--w-bits', type=int, default=4, help='Bit number for weight quantization') parser.add_argument('--w-sym', action='store_true', help='Whether to do symmetric quantization') parser.add_argument('--w-group-size', type=int, default=128, help='Group size for weight quantization statistics') @staticmethod def add_parser_auto_gptq(): """Add parser for auto_gptq command.""" parser = SubCliLite.subparsers.add_parser('auto_gptq', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliLite.auto_gptq.__doc__, help=SubCliLite.auto_gptq.__doc__) parser.set_defaults(run=SubCliLite.auto_gptq) parser.add_argument('model', type=str, help='The path of model in hf format') ArgumentHelper.revision(parser) ArgumentHelper.work_dir(parser) 
ArgumentHelper.calib_dataset(parser) ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.dtype(parser) parser.add_argument('--w-bits', type=int, default=4, help='Bit number for weight quantization') parser.add_argument('--w-group-size', type=int, default=128, help='Group size for weight quantization statistics') @staticmethod def add_parser_calibrate(): """Add parser for calibrate command.""" parser = SubCliLite.subparsers.add_parser('calibrate', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliLite.calibrate.__doc__, help=SubCliLite.calibrate.__doc__) parser.set_defaults(run=SubCliLite.calibrate) parser.add_argument('model', type=str, help='The name or path of the model to be loaded') ArgumentHelper.work_dir(parser) ArgumentHelper.calib_dataset(parser) ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) ArgumentHelper.dtype(parser) @staticmethod def add_parser_smooth_quant(): """Add parser for smooth_quant command.""" parser = SubCliLite.subparsers.add_parser('smooth_quant', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliLite.smooth_quant.__doc__, help=SubCliLite.smooth_quant.__doc__) parser.set_defaults(run=SubCliLite.smooth_quant) parser.add_argument('model', type=str, help='The name or path of the model to be loaded') parser.add_argument('--work-dir', type=str, default='./work_dir', help='The working directory for outputs. 
defaults to "./work_dir"') parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)') ArgumentHelper.calib_dataset(parser) ArgumentHelper.calib_samples(parser) ArgumentHelper.calib_seqlen(parser) ArgumentHelper.calib_batchsize(parser) ArgumentHelper.calib_search_scale(parser) ArgumentHelper.dtype(parser) ArgumentHelper.quant_dtype(parser) ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) @staticmethod def auto_awq(args): """Perform weight quantization using AWQ algorithm.""" from lmdeploy.lite.apis.auto_awq import auto_awq kwargs = convert_args(args) auto_awq(**kwargs) @staticmethod def auto_gptq(args): """Perform weight quantization using GPTQ algorithm.""" from lmdeploy.lite.apis.gptq import auto_gptq kwargs = convert_args(args) auto_gptq(**kwargs) @staticmethod def calibrate(args): """Perform calibration on a given dataset.""" from lmdeploy.lite.apis.calibrate import calibrate kwargs = convert_args(args) calibrate(**kwargs) @staticmethod def smooth_quant(args): """Perform w8a8 quantization using SmoothQuant.""" from lmdeploy.lite.apis.smooth_quant import smooth_quant kwargs = convert_args(args) smooth_quant(**kwargs) @staticmethod def add_parsers(): """Add all parsers.""" SubCliLite.add_parser_auto_awq() SubCliLite.add_parser_auto_gptq() SubCliLite.add_parser_calibrate() SubCliLite.add_parser_smooth_quant() ================================================ FILE: lmdeploy/cli/serve.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend from lmdeploy.utils import get_max_batch_size from .cli import CLI from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters, get_speculative_config) class SubCliServe: """Serve LLMs and interact on terminal.""" _help = 'Serve LLMs with openai API' _desc = _help parser = CLI.subparsers.add_parser( 'serve', help=_help, description=_desc, ) subparsers = parser.add_subparsers(title='Commands', description='This group has the following commands:') @staticmethod def add_parser_api_server(): """Add parser for api_server command.""" parser = SubCliServe.subparsers.add_parser('api_server', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliServe.api_server.__doc__, help=SubCliServe.api_server.__doc__) parser.set_defaults(run=SubCliServe.api_server) parser.add_argument('model_path', type=str, help='The path of a model. it could be one of the following ' 'options: - i) a local directory path of a turbomind model' ' which is converted by `lmdeploy convert` command or ' 'download from ii) and iii). - ii) the model_id of a ' 'lmdeploy-quantized model hosted inside a model repo on ' 'huggingface.co, such as "internlm/internlm-chat-20b-4bit",' ' "lmdeploy/llama2-chat-70b-4bit", etc. 
- iii) the model_id' ' of a model hosted inside a model repo on huggingface.co,' ' such as "internlm/internlm-chat-7b", "qwen/qwen-7b-chat "' ', "baichuan-inc/baichuan2-7b-chat" and so on') parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host ip for serving') parser.add_argument('--server-port', type=int, default=23333, help='Server port') parser.add_argument('--allow-origins', nargs='+', type=str, default=['*'], help='A list of allowed origins for cors') parser.add_argument('--allow-credentials', action='store_true', help='Whether to allow credentials for cors') parser.add_argument('--allow-methods', nargs='+', type=str, default=['*'], help='A list of allowed http methods for cors') parser.add_argument('--allow-headers', nargs='+', type=str, default=['*'], help='A list of allowed http headers for cors') parser.add_argument('--proxy-url', type=str, default=None, help='The proxy url for api server.') parser.add_argument('--max-concurrent-requests', type=int, default=None, help='This refers to the number of concurrent requests that ' 'the server can handle. The server is designed to process the ' 'engine’s tasks once the maximum number of concurrent requests is ' 'reached, regardless of any additional requests sent by clients ' 'concurrently during that time. 
Default to None.') # common args ArgumentHelper.backend(parser) ArgumentHelper.log_level(parser) ArgumentHelper.api_keys(parser) ArgumentHelper.ssl(parser) ArgumentHelper.model_name(parser) ArgumentHelper.max_log_len(parser) ArgumentHelper.disable_fastapi_docs(parser) ArgumentHelper.allow_terminate_by_client(parser) ArgumentHelper.enable_abort_handling(parser) # chat template args ArgumentHelper.chat_template(parser) # parsers ArgumentHelper.tool_call_parser(parser) ArgumentHelper.reasoning_parser(parser) # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) ArgumentHelper.disable_vision_encoder(pt_group) ArgumentHelper.logprobs_mode(pt_group) ArgumentHelper.dllm_block_length(pt_group) ArgumentHelper.dllm_unmasking_strategy(pt_group) ArgumentHelper.dllm_denoising_steps(pt_group) ArgumentHelper.dllm_confidence_threshold(pt_group) ArgumentHelper.enable_return_routed_experts(pt_group) ArgumentHelper.distributed_executor_backend(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) max_batch_size_act = ArgumentHelper.max_batch_size(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) model_format = ArgumentHelper.model_format(pt_group) hf_overrides = ArgumentHelper.hf_overrides(pt_group) disable_metrics = ArgumentHelper.disable_metrics(pt_group) dp = ArgumentHelper.dp(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) 
ArgumentHelper.enable_eplb(pt_group) ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) # multi-node serving args node_rank_act = ArgumentHelper.node_rank(pt_group) num_nodes_act = ArgumentHelper.num_nodes(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(max_batch_size_act) tb_group._group_actions.append(cache_max_entry_act) tb_group._group_actions.append(cache_block_seq_len_act) tb_group._group_actions.append(prefix_caching_act) tb_group._group_actions.append(max_prefill_token_num_act) tb_group._group_actions.append(quant_policy) tb_group._group_actions.append(model_format) tb_group._group_actions.append(num_nodes_act) tb_group._group_actions.append(node_rank_act) tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(disable_metrics) tb_group._group_actions.append(dp) ArgumentHelper.cp(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.async_(tb_group) ArgumentHelper.communicator(tb_group) ArgumentHelper.dist_init_addr(tb_group) # vlm args vision_group = parser.add_argument_group('Vision model arguments') ArgumentHelper.vision_max_batch_size(vision_group) # spec decode ArgumentHelper.add_spec_group(parser) @staticmethod def add_parser_proxy(): """Add parser for proxy server command.""" parser = SubCliServe.subparsers.add_parser('proxy', formatter_class=DefaultsAndTypesHelpFormatter, description=SubCliServe.proxy.__doc__, help=SubCliServe.proxy.__doc__) parser.set_defaults(run=SubCliServe.proxy) parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host ip for proxy serving') parser.add_argument('--server-port', type=int, default=8000, help='Server port of the proxy') 
parser.add_argument('--serving-strategy', type=str, choices=['Hybrid', 'DistServe'], default='Hybrid', help='the strategy to serve, Hybrid for colocating Prefill and Decode' 'workloads into same engine, DistServe for Prefill-Decode Disaggregation') parser.add_argument('--dummy-prefill', action='store_true', help='dummy prefill for performance profiler') parser.add_argument('--routing-strategy', type=str, choices=['random', 'min_expected_latency', 'min_observed_latency'], default='min_expected_latency', help='the strategy to dispatch requests to nodes') parser.add_argument('--disable-cache-status', action='store_true', help='Whether to disable cache status of the ' 'proxy. If set, the proxy will forget the status ' 'of the previous time') # For Disaggregation parser.add_argument('--migration-protocol', type=str, choices=['RDMA', 'NVLINK'], default='RDMA', help='transport protocol of KV migration') parser.add_argument('--link-type', type=str, choices=['RoCE', 'IB'], default='RoCE', help='RDMA Link Type') parser.add_argument('--disable-gdr', action='store_true', help='with GPU Direct Memory Access') ArgumentHelper.api_keys(parser) ArgumentHelper.ssl(parser) ArgumentHelper.log_level(parser) @staticmethod def api_server(args): """Serve LLMs with restful api using fastapi.""" from lmdeploy.archs import autoget_backend max_batch_size = args.max_batch_size if args.max_batch_size \ else get_max_batch_size(args.device) backend = args.backend if backend != 'pytorch': # set auto backend mode backend = autoget_backend(args.model_path) if backend == 'pytorch': from lmdeploy.messages import PytorchEngineConfig adapters = get_lora_adapters(args.adapters) backend_config = PytorchEngineConfig( dtype=args.dtype, tp=args.tp, dp=args.dp, ep=args.ep, max_batch_size=max_batch_size, cache_max_entry_count=args.cache_max_entry_count, block_size=args.cache_block_seq_len, session_len=args.session_len, adapters=adapters, enable_prefix_caching=args.enable_prefix_caching, 
device_type=args.device, quant_policy=args.quant_policy, eager_mode=args.eager_mode, max_prefill_token_num=args.max_prefill_token_num, enable_microbatch=args.enable_microbatch, enable_eplb=args.enable_eplb, enable_metrics=not args.disable_metrics, role=EngineRole[args.role], migration_backend=MigrationBackend[args.migration_backend], model_format=args.model_format, hf_overrides=args.hf_overrides, disable_vision_encoder=args.disable_vision_encoder, logprobs_mode=args.logprobs_mode, dllm_block_length=args.dllm_block_length, dllm_unmasking_strategy=args.dllm_unmasking_strategy, dllm_denoising_steps=args.dllm_denoising_steps, dllm_confidence_threshold=args.dllm_confidence_threshold, enable_return_routed_experts=args.enable_return_routed_experts, distributed_executor_backend=args.distributed_executor_backend, ) else: from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, dp=args.dp, cp=args.cp, nnodes=args.nnodes, node_rank=args.node_rank, dist_init_addr=args.dist_init_addr, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, quant_policy=args.quant_policy, rope_scaling_factor=args.rope_scaling_factor, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, enable_prefix_caching=args.enable_prefix_caching, max_prefill_token_num=args.max_prefill_token_num, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, async_=args.async_, communicator=args.communicator, enable_metrics=not args.disable_metrics, hf_overrides=args.hf_overrides) chat_template_config = get_chat_template(args.chat_template, args.model_path) speculative_config = get_speculative_config(args) from lmdeploy.messages import VisionConfig vision_config = VisionConfig(args.vision_max_batch_size) if args.dp == 1 or backend == 'turbomind': from lmdeploy.serve.openai.api_server import serve as run_api_server run_api_server( 
args.model_path, model_name=args.model_name, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, vision_config=vision_config, server_name=args.server_name, server_port=args.server_port, allow_origins=args.allow_origins, allow_credentials=args.allow_credentials, allow_methods=args.allow_methods, allow_headers=args.allow_headers, allow_terminate_by_client=args.allow_terminate_by_client, enable_abort_handling=args.enable_abort_handling, log_level=args.log_level.upper(), api_keys=args.api_keys, ssl=args.ssl, proxy_url=args.proxy_url, max_log_len=args.max_log_len, disable_fastapi_docs=args.disable_fastapi_docs, max_concurrent_requests=args.max_concurrent_requests, reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, ) else: from lmdeploy.serve.openai.launch_server import launch_server launch_server( args.nnodes, args.node_rank, args.model_path, model_name=args.model_name, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, vision_config=vision_config, server_name=args.server_name, server_port=args.server_port, allow_origins=args.allow_origins, allow_credentials=args.allow_credentials, allow_methods=args.allow_methods, allow_headers=args.allow_headers, allow_terminate_by_client=args.allow_terminate_by_client, enable_abort_handling=args.enable_abort_handling, log_level=args.log_level.upper(), api_keys=args.api_keys, ssl=args.ssl, proxy_url=args.proxy_url, max_log_len=args.max_log_len, disable_fastapi_docs=args.disable_fastapi_docs, max_concurrent_requests=args.max_concurrent_requests, reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, ) @staticmethod def proxy(args): """Proxy server that manages distributed api_server nodes.""" from lmdeploy.serve.proxy.proxy import proxy kwargs = convert_args(args) proxy(**kwargs) @staticmethod def add_parsers(): 
SubCliServe.add_parser_api_server() SubCliServe.add_parser_proxy() ================================================ FILE: lmdeploy/cli/utils.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import argparse import json import re import sys from collections import defaultdict from typing import Any, List from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') class DefaultsAndTypesHelpFormatter(argparse.HelpFormatter): """Formatter to output default value and type in help information.""" def _get_help_string(self, action): """Add default and type info into help.""" help = action.help if '%(default)' not in action.help: if action.default is not argparse.SUPPRESS: defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] if (action.option_strings or action.nargs in defaulting_nargs) and 'default' not in help.lower(): if not help.endswith('.'): help += '.' help += ' Default: %(default)s' if action.type: if not help.endswith('.'): help += '.' help += ' Type: %(type)s' return help def convert_args(args): """Convert args to dict format.""" special_names = ['run', 'command'] kwargs = {k[0]: k[1] for k in args._get_kwargs() if k[0] not in special_names} return kwargs def get_lora_adapters(adapters: List[str]): """Parse lora adapers from cli input. Args: adapters (List[str]): CLI input string of lora adapter path(s). Returns: Dict[str,str] or None: Parsed lora adapter path(s). """ if not adapters: return None n = len(adapters) output = {} if n == 1: name = 'default' path = adapters[0].strip() if '=' in path: name, path = path.split('=', 1) output[name] = path else: for pair in adapters: assert '=' in pair, f'Multiple lora paths must in format of ' \ f'xxx=yyy. 
But given: {pair}' name, path = pair.strip().split('=', 1) assert name not in output, f'Multiple lora paths with repeated lora name: {name}' output[name] = path return output def get_chat_template(chat_template: str, model_path: str = None): """Get chat template config. Args: chat_template(str): it could be a builtin chat template name, or a chat template json file model_path(str): the model path, used to check deprecated chat template names """ import os from lmdeploy.model import ChatTemplateConfig if chat_template: if os.path.isfile(chat_template): return ChatTemplateConfig.from_json(chat_template) else: from lmdeploy.model import DEPRECATED_CHAT_TEMPLATE_NAMES, MODELS, REMOVED_CHAT_TEMPLATE_NAMES if chat_template in REMOVED_CHAT_TEMPLATE_NAMES: raise ValueError(f"The chat template '{chat_template}' has been removed. " f'Please refer to the latest chat templates in ' f'https://lmdeploy.readthedocs.io/en/latest/advance/chat_template.html') if chat_template in DEPRECATED_CHAT_TEMPLATE_NAMES: logger.warning(f"The chat template '{chat_template}' is deprecated and fallback to hf chat template.") chat_template = 'hf' assert chat_template in MODELS.module_dict.keys(), \ f"chat template '{chat_template}' is not " \ f'registered. 
The builtin chat templates are: ' \ f'{MODELS.module_dict.keys()}' return ChatTemplateConfig(model_name=chat_template, model_path=model_path) else: return None def get_speculative_config(args): """Get speculative config from args.""" from lmdeploy.messages import SpeculativeConfig speculative_config = None if args.speculative_algorithm is not None: speculative_config = SpeculativeConfig( method=args.speculative_algorithm, model=args.speculative_draft_model, num_speculative_tokens=args.speculative_num_draft_tokens, ) return speculative_config class ArgumentHelper: """Helper class to add unified argument.""" @staticmethod def model_name(parser): """Add argument model_name to parser.""" return parser.add_argument('--model-name', type=str, default=None, help='The name of the served model. It can be accessed ' 'by the RESTful API `/v1/models`. If it is not specified, ' '`model_path` will be adopted') @staticmethod def dtype(parser, default: str = 'auto'): return parser.add_argument('--dtype', type=str, default=default, choices=['auto', 'float16', 'bfloat16'], help='data type for model weights and activations. ' 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models. This option will be ignored if ' 'the model is a quantized model') @staticmethod def quant_dtype(parser, default: str = 'int8'): return parser.add_argument('--quant-dtype', type=str, default=default, choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'], help='data type for the quantized model weights and activations.' 'Note "fp8" is the short version of "float8_e4m3fn"') @staticmethod def model_format(parser, default: str = None): return parser.add_argument('--model-format', type=str, default=default, choices=['hf', 'awq', 'gptq', 'fp8', 'mxfp4'], help='The format of input model. 
`hf` means `hf_llama`, ' '`awq` represents the quantized model by AWQ,' ' and `gptq` refers to the quantized model by GPTQ') @staticmethod def revision(parser, default: str = None): return parser.add_argument('--revision', type=str, default=default, help='The specific model version to use. ' 'It can be a branch name, a tag name, or a commit id. ' 'If unspecified, will use the default version.') @staticmethod def download_dir(parser, default: str = None): return parser.add_argument('--download-dir', type=str, default=default, help='Directory to download and load the weights, ' 'default to the default cache directory of huggingface.') @staticmethod def tp(parser): """Add argument tp to parser.""" return parser.add_argument('--tp', type=int, default=1, help='GPU number used in tensor parallelism. Should be 2^n') @staticmethod def dp(parser): """Add argument dp to parser.""" return parser.add_argument('--dp', type=int, default=1, help='data parallelism. dp_rank is required when pytorch engine is used.') @staticmethod def ep(parser): """Add argument ep to parser.""" return parser.add_argument('--ep', type=int, default=1, help='expert parallelism. 
dp is required when pytorch engine is used.') @staticmethod def cp(parser): """Add argument cp to parser.""" return parser.add_argument( '--cp', type=int, default=1, help='context parallelism size in attention for turbomind backend, tp must be a multiple of cp.') @staticmethod def dp_rank(parser): """Add argument dp_rank to parser.""" return parser.add_argument('--dp-rank', type=int, default=0, help='data parallelism rank, all ranks between 0 ~ dp should be created.') @staticmethod def node_rank(parser): """Add argument node_rank to parser.""" return parser.add_argument('--node-rank', type=int, default=0, help='The current node rank.') @staticmethod def num_nodes(parser): """Add argument num_nodes to parser.""" return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums') @staticmethod def dist_init_addr(parser): """Add argument dist_init_addr to parser.""" return parser.add_argument('--dist-init-addr', type=str, default=None) @staticmethod def session_id(parser): """Add argument session_id to parser.""" return parser.add_argument('--session-id', type=int, default=1, help='The identical id of a session') @staticmethod def session_len(parser, default: int = None): return parser.add_argument('--session-len', type=int, default=default, help='The max session length of a sequence') @staticmethod def max_batch_size(parser): """Add argument max_batch_size to parser.""" return parser.add_argument('--max-batch-size', type=int, default=None, help='Maximum batch size. If not specified, the engine will ' 'automatically set it according to the device') @staticmethod def quant_policy(parser, default: int = 0): """Add argument quant_policy to parser.""" return parser.add_argument('--quant-policy', type=int, default=0, choices=[0, 4, 8], help='Quantize kv or not. 
0: no quant; 4: 4bit kv; 8: 8bit kv') @staticmethod def rope_scaling_factor(parser): """Add argument rope_scaling_factor to parser.""" return parser.add_argument('--rope-scaling-factor', type=float, default=0.0, help='Rope scaling factor') @staticmethod def hf_overrides(parser): """Add argument hf_overrides to parser.""" return parser.add_argument('--hf-overrides', type=json.loads, default=None, help='Extra arguments to be forwarded to the HuggingFace config.') @staticmethod def use_logn_attn(parser): """Add argument use_logn_attn to parser.""" return parser.add_argument('--use-logn-attn', action='store_true', default=False, help='Whether to use logn attention scaling') @staticmethod def block_size(parser): """Add argument block_size to parser.""" return parser.add_argument('--block-size', type=int, default=64, help='The block size for paging cache') @staticmethod def top_p(parser): """Add argument top_p to parser.""" return parser.add_argument('--top-p', type=float, default=0.8, help='An alternative to sampling with temperature,' ' called nucleus sampling, where the model ' 'considers the results of the tokens with ' 'top_p probability mass') @staticmethod def top_k(parser): """Add argument top_k to parser.""" return parser.add_argument('--top-k', type=int, default=1, help='An alternative to sampling with temperature, ' 'where the model considers the top_k tokens ' 'with the highest probability') @staticmethod def temperature(parser, default: float = 0.8): return parser.add_argument('-temp', '--temperature', type=float, default=default, help='Sampling temperature') @staticmethod def repetition_penalty(parser): """Add argument repetition_penalty to parser.""" return parser.add_argument('--repetition-penalty', type=float, default=1.0, help='Parameter to penalize repetition') @staticmethod def log_level(parser): """Add argument log_level to parser.""" import logging return parser.add_argument('--log-level', type=str, default='WARNING', 
choices=list(logging._nameToLevel.keys()), help='Set the log level') @staticmethod def api_keys(parser): return parser.add_argument( '--api-keys', type=str, nargs='*', default=None, help='Optional list of space separated API keys', ) @staticmethod def ssl(parser): return parser.add_argument( '--ssl', action='store_true', required=False, default=False, help='Enable SSL. Requires OS Environment variables' " 'SSL_KEYFILE' and 'SSL_CERTFILE'", ) @staticmethod def backend(parser): """Add argument backend to parser.""" return parser.add_argument('--backend', type=str, default='turbomind', choices=['pytorch', 'turbomind'], help='Set the inference backend') @staticmethod def stream_output(parser): """Add argument stream_output to parser.""" return parser.add_argument('--stream-output', action='store_true', help='Indicator for streaming output or not') @staticmethod def calib_dataset(parser): """Add argument calib_dataset to parser.""" return parser.add_argument( '--calib-dataset', type=str, default='wikitext2', choices=['wikitext2', 'c4', 'pileval', 'gsm8k', 'neuralmagic_calibration', 'open-platypus', 'openwebtext'], help='The calibration dataset name.') @staticmethod def calib_samples(parser): """Add argument calib_samples to parser.""" return parser.add_argument('--calib-samples', type=int, default=128, help='The number of samples for calibration') @staticmethod def calib_seqlen(parser): """Add argument calib_seqlen to parser.""" return parser.add_argument('--calib-seqlen', type=int, default=2048, help='The sequence length for calibration') @staticmethod def calib_batchsize(parser): """Add argument batch_size to parser.""" return parser.add_argument( '--batch-size', type=int, default=1, help=\ 'The batch size for running the calib samples. Low GPU mem requires small batch_size. 
Large batch_size reduces the calibration time while costs more VRAM' # noqa ) @staticmethod def calib_search_scale(parser): """Add argument search_scale to parser.""" return parser.add_argument( '--search-scale', action='store_true', default=False, help=\ 'Whether search scale ratio. Default to be disabled, which means only smooth quant with 0.5 ratio will be applied' # noqa ) @staticmethod def device(parser, default: str = 'cuda', choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']): """Add argument device to parser.""" return parser.add_argument('--device', type=str, default=default, choices=choices, help='The device type of running') @staticmethod def chat_template(parser): """Add chat template config to parser.""" return parser.add_argument( '--chat-template', type=str, default=None, help=\ 'A JSON file or string that specifies the chat template configuration. ' # noqa 'Please refer to https://lmdeploy.readthedocs.io/en/latest/advance/chat_template.html for the specification' # noqa ) @staticmethod def reasoning_parser(parser): """Add reasoning parser to parser.""" from lmdeploy.serve.openai.reasoning_parser import ReasoningParserManager return parser.add_argument( '--reasoning-parser', type=str, default=None, help=f'The registered reasoning parser name from {ReasoningParserManager.module_dict.keys()}. ' 'Default to None.') @staticmethod def tool_call_parser(parser): """Add tool call parser to parser.""" from lmdeploy.serve.openai.tool_parser import ToolParserManager return parser.add_argument( '--tool-call-parser', type=str, default=None, help=f'The registered tool parser name {ToolParserManager.module_dict.keys()}. 
Default to None.') @staticmethod def allow_terminate_by_client(parser): """Add argument allow_terminate_by_client to parser.""" return parser.add_argument('--allow-terminate-by-client', action='store_true', default=False, help='Enable server to be terminated by request from client') @staticmethod def enable_abort_handling(parser): """Add --enable-abort-handling argument to configure server abort request processing.""" return parser.add_argument('--enable-abort-handling', action='store_true', default=False, help='Enable server to handle client abort requests') @staticmethod def cache_max_entry_count(parser): """Add argument cache_max_entry_count to parser.""" return parser.add_argument('--cache-max-entry-count', type=float, default=0.8, help='The percentage of free gpu memory occupied by the k/v ' 'cache, excluding weights ') @staticmethod def adapters(parser): """Add argument adapters to parser.""" return parser.add_argument('--adapters', nargs='*', type=str, default=None, help='Used to set path(s) of lora adapter(s). One can input ' 'key-value pairs in xxx=yyy format for multiple lora ' 'adapters. If only have one adapter, one can only input ' 'the path of the adapter.') @staticmethod def work_dir(parser): """Add argument work_dir to parser.""" return parser.add_argument('--work-dir', type=str, default='./work_dir', help='The working directory to save results') @staticmethod def cache_block_seq_len(parser): """Add argument cache_block_seq_len to parser.""" return parser.add_argument('--cache-block-seq-len', type=int, default=64, help='The length of the token sequence in a k/v block. ' 'For Turbomind Engine, if the GPU compute capability ' 'is >= 8.0, it should be a multiple of 32, otherwise ' 'it should be a multiple of 64. 
For Pytorch Engine, ' 'if Lora Adapter is specified, this parameter will ' 'be ignored') @staticmethod def enable_prefix_caching(parser): """Add argument enable_prefix_caching to parser.""" return parser.add_argument('--enable-prefix-caching', action='store_true', default=False, help='Enable cache and match prefix') @staticmethod def num_tokens_per_iter(parser): return parser.add_argument('--num-tokens-per-iter', type=int, default=0, help='the number of tokens processed in a forward pass') @staticmethod def max_prefill_iters(parser): return parser.add_argument('--max-prefill-iters', type=int, default=1, help='the max number of forward passes in prefill stage') @staticmethod def async_(parser): return parser.add_argument('--async', type=int, default=1, choices=[0, 1], dest='async_', help='Enable async execution (default: 1, enabled). ' 'Set to 0 to disable async mode, 1 to enable it.') @staticmethod def max_prefill_token_num(parser): return parser.add_argument('--max-prefill-token-num', type=int, default=8192, help='the max number of tokens per iteration during prefill') @staticmethod def vision_max_batch_size(parser): return parser.add_argument('--vision-max-batch-size', type=int, default=1, help='the vision model batch size') @staticmethod def max_log_len(parser): return parser.add_argument('--max-log-len', type=int, default=None, help='Max number of prompt characters or prompt tokens being ' 'printed in log. Default: Unlimited') @staticmethod def disable_fastapi_docs(parser): return parser.add_argument('--disable-fastapi-docs', action='store_true', default=False, help="Disable FastAPI's OpenAPI schema," ' Swagger UI, and ReDoc endpoint') @staticmethod def eager_mode(parser): """Add argument eager_mode to parser.""" return parser.add_argument('--eager-mode', action='store_true', default=False, help='Whether to enable eager mode. 
' 'If True, cuda graph would be disabled') @staticmethod def communicator(parser): return parser.add_argument('--communicator', type=str, default='nccl', choices=['nccl', 'native', 'cuda-ipc'], help='Communication backend for multi-GPU inference. The "native" option is ' 'deprecated and serves as an alias for "cuda-ipc"') @staticmethod def enable_microbatch(parser): """Add argument enable_microbatch to parser.""" return parser.add_argument('--enable-microbatch', action='store_true', help='enable microbatch for specified model') @staticmethod def enable_eplb(parser): """Add argument enable_eplb to parser.""" return parser.add_argument('--enable-eplb', action='store_true', help='enable eplb for specified model') @staticmethod def disable_metrics(parser): """Add argument disable_metrics to parser.""" return parser.add_argument('--disable-metrics', action='store_true', default=False, help='disable metrics system') # For Disaggregation @staticmethod def role(parser): return parser.add_argument('--role', type=str, default='Hybrid', choices=['Hybrid', 'Prefill', 'Decode'], help='Hybrid for Non-Disaggregated Engine; ' 'Prefill for Disaggregated Prefill Engine; ' 'Decode for Disaggregated Decode Engine') @staticmethod def migration_backend(parser): return parser.add_argument('--migration-backend', type=str, default='DLSlime', choices=['DLSlime', 'Mooncake'], help='kvcache migration management backend when PD disaggregation') @staticmethod def disable_vision_encoder(parser): """Disable loading vision encoder.""" return parser.add_argument('--disable-vision-encoder', action='store_true', default=False, help='disable multimodal encoder') @staticmethod def logprobs_mode(parser): """The mode of logprobs.""" return parser.add_argument('--logprobs-mode', type=str, default=None, choices=[None, 'raw_logits', 'raw_logprobs'], help='The mode of logprobs.') @staticmethod def dllm_block_length(parser): """dllm_block_length for dllm.""" return parser.add_argument('--dllm-block-length', 
type=int, default=None, help='Block length for dllm') @staticmethod def dllm_unmasking_strategy(parser): """Dllm unmasking strategy.""" return parser.add_argument('--dllm-unmasking-strategy', type=str, default='low_confidence_dynamic', choices=['low_confidence_dynamic', 'low_confidence_static', 'sequential'], help='The unmasking strategy for dllm.') @staticmethod def dllm_denoising_steps(parser): """Dllm denoising steps.""" return parser.add_argument('--dllm-denoising-steps', type=int, default=None, help='The number of denoising steps for dllm.') @staticmethod def dllm_confidence_threshold(parser): """Dllm confidence threshold.""" return parser.add_argument('--dllm-confidence-threshold', type=float, default=0.85, help='The confidence threshold for dllm.') @staticmethod def enable_return_routed_experts(parser): """Add argument return routed experts to parser.""" return parser.add_argument('--enable-return-routed-experts', action='store_true', default=False, help='Whether to output routed expert ids for replay') @staticmethod def add_spec_group(parser): spec_group = parser.add_argument_group('Speculative decoding arguments') spec_group.add_argument('--speculative-algorithm', type=str, default=None, choices=['eagle', 'eagle3', 'deepseek_mtp'], help='The speculative algorithm to use. 
`None` means speculative decoding is disabled') spec_group.add_argument('--speculative-draft-model', type=str, default=None, help='The path to speculative draft model') spec_group.add_argument('--speculative-num-draft-tokens', type=int, default=1, help='The number of speculative tokens to generate per step') return spec_group @staticmethod def distributed_executor_backend(parser): """Distributed_executor_backend.""" return parser.add_argument('--distributed-executor-backend', type=str, default=None, choices=['uni', 'mp', 'ray'], help='The distributed executor backend for pytorch engine.') # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py class FlexibleArgumentParser(argparse.ArgumentParser): """"More flexible argument parser.""" def parse_args(self, args=None, namespace=None): # If args is not provided, use arguments from the command line if args is None: args = sys.argv[1:] def repl(match: re.Match) -> str: """Replaces underscores with dashes in the matched string.""" return match.group(0).replace('_', '-') # Everything between the first -- and the first . pattern = re.compile(r'(?<=--)[^\.]*') # Convert underscores to dashes and vice versa in argument names processed_args = [] for arg in args: if arg.startswith('--'): if '=' in arg: key, value = arg.split('=', 1) key = pattern.sub(repl, key, count=1) processed_args.append(f'{key}={value}') else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: # allow -O flag to be used without space, e.g. 
-O3 processed_args.append('-O') processed_args.append(arg[2:]) else: processed_args.append(arg) def _try_convert(value: str): """Try to convert string to float or int.""" if not isinstance(value, str): return value # try loads from json try: return json.loads(value) except json.JSONDecodeError: pass return value def create_nested_dict(keys: list[str], value: str): """Creates a nested dictionary from a list of keys and a value. For example, `keys = ["a", "b", "c"]` and `value = 1` will create: `{"a": {"b": {"c": 1}}}` """ nested_dict: Any = _try_convert(value) for key in reversed(keys): nested_dict = {key: nested_dict} return nested_dict def recursive_dict_update(original: dict, update: dict): """Recursively updates a dictionary with another dictionary.""" for k, v in update.items(): if isinstance(v, dict) and isinstance(original.get(k), dict): recursive_dict_update(original[k], v) else: original[k] = v delete = set() dict_args: dict[str, dict] = defaultdict(dict) for i, processed_arg in enumerate(processed_args): if processed_arg.startswith('--') and '.' in processed_arg: if '=' in processed_arg: processed_arg, value = processed_arg.split('=', 1) if '.' not in processed_arg: # False positive, . 
was only in the value continue else: value = processed_args[i + 1] delete.add(i + 1) key, *keys = processed_arg.split('.') # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) recursive_dict_update(dict_args[key], arg_dict) delete.add(i) # Filter out the dict args we set to None processed_args = [a for i, a in enumerate(processed_args) if i not in delete] # Add the dict args back as if they were originally passed as JSON for dict_arg, dict_value in dict_args.items(): processed_args.append(dict_arg) processed_args.append(json.dumps(dict_value)) return super().parse_args(processed_args, namespace) ================================================ FILE: lmdeploy/lite/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .apis import * # noqa: F401,F403 from .quantization import * # noqa: F401,F403 from .utils import * # noqa: F401,F403 ================================================ FILE: lmdeploy/lite/apis/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. ================================================ FILE: lmdeploy/lite/apis/auto_awq.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
import os import os.path as osp import shutil from typing import Literal import torch from torch import nn from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers from lmdeploy.lite.utils import collect_target_modules from lmdeploy.utils import try_import_deeplink from .calibrate import LAYER_TYPE_MAP, calibrate def save_vl_model(vl_model, model_path, dst_path): vl_model.save_pretrained(dst_path, safe_serialization=True) candidate = [ 'preprocessor_config.json', 'processor_config.json', 'vit', 'generation_config.json', 'added_tokens.json' ] for name in candidate: tmp_path = osp.join(model_path, name) if osp.exists(tmp_path): if osp.isfile(tmp_path): shutil.copy(tmp_path, osp.join(dst_path, name)) elif osp.isdir(tmp_path): shutil.copytree(tmp_path, osp.join(dst_path, name)) # AutoProcessor files allfiles = os.listdir(model_path) for file in allfiles: if not file.endswith('.py'): continue copy_src = osp.join(model_path, file) copy_dst = osp.join(dst_path, file) if not osp.exists(copy_dst): shutil.copyfile(copy_src, copy_dst) def auto_awq(model: str, work_dir: str = './work_dir', calib_dataset: str = 'wikitext2', calib_samples: int = 128, batch_size: int = 1, calib_seqlen: int = 2048, w_bits: int = 4, w_sym: bool = False, w_group_size: int = 128, search_scale: bool = False, device: str = 'cuda', revision: str = None, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', download_dir: str = None): """Perform weight quantization using AWQ algorithm. Args: model (str): The path of model in hf format. work_dir (str): The working directory to save results. calib_dataset (str): The calibration dataset name. Defaults to 'wikitext2'. calib_samples (int): The number of samples for calibration. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. calib_seqlen (int): The sequence length for calibration. 
w_bits (int): Bit number for weight quantization. w_sym (bool): Whether to do symmetric quantization. w_group_size (int): Group size for weight quantization statistics. search_scale (bool): Whether search scale ratio. Default to False, which means only smooth quant with 0.5 ratio will be applied. device (str): Device type of running. revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. dtype (str): Data type for loading model weights and calib infer. download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. """ try_import_deeplink(device) if not osp.exists(model): print(f'can\'t find model from local_path {model}, ' 'try to download from remote') from lmdeploy.utils import get_model model = get_model(model, revision=revision, download_dir=download_dir) model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset, calib_samples, calib_seqlen, work_dir, device, w_bits=w_bits, w_group_size=w_group_size, search_scale=search_scale, dtype=dtype, batch_size=batch_size) layer_type = LAYER_TYPE_MAP[type(model).__name__] fc2fcs = FC_FCS_MAP[layer_type] norm2fcs = NORM_FCS_MAP[layer_type] input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True) layers = collect_target_modules(model, layer_type) fcs = {} for l_name, layer in layers.items(): name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name) fcs.update(name2fc) if search_scale: awq_ratios = input_stats['ratios'] act_scales = input_stats['absmean'] awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, w_group_size, device) else: act_scales = input_stats['absmax'] smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device) quant_weights(model, fcs, w_bits, w_sym, w_group_size, device) quantization_config = dict(quant_method='awq', version='gemm', bits=w_bits, group_size=w_group_size, 
zero_point=not w_sym) model.config.update(dict(quantization_config=quantization_config)) if vl_model: save_vl_model(vl_model, model_path, work_dir) else: model.save_pretrained(work_dir, safe_serialization=True) tokenizer.save_pretrained(work_dir) if __name__ == '__main__': import fire fire.Fire(auto_awq) ================================================ FILE: lmdeploy/lite/apis/calibrate.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path from typing import Literal, Union import torch from torch import nn from transformers import AutoTokenizer from lmdeploy.archs import get_task from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2 from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders, load_hf_from_pretrained from lmdeploy.vl.model.builder import load_vl_model LAYER_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMDecoderLayer', 'InternLM2ForCausalLM': 'InternLM2DecoderLayer', 'InternLM3ForCausalLM': 'InternLM3DecoderLayer', 'QWenLMHeadModel': 'QWenBlock', 'Qwen2ForCausalLM': 'Qwen2DecoderLayer', 'Qwen3ForCausalLM': 'Qwen3DecoderLayer', 'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B 'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaDecoderLayer', 'LlavaLlamaForCausalLM': 'LlamaDecoderLayer', 'MGMLlamaForCausalLM': 'LlamaDecoderLayer', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2DecoderLayer', 'Phi3ForCausalLM': 'Phi3DecoderLayer', 'ChatGLMForConditionalGeneration': 'GLMBlock', 'MixtralForCausalLM': 'MixtralDecoderLayer', 'Qwen2VLForConditionalGeneration': 'Qwen2VLDecoderLayer', 'Qwen2_5_VLForConditionalGeneration': 'Qwen2_5_VLDecoderLayer', 'MistralForCausalLM': 'MistralDecoderLayer', } NORM_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMRMSNorm', 'InternLM2ForCausalLM': 'InternLM2RMSNorm', 'InternLM3ForCausalLM': 'InternLM3RMSNorm', 'QWenLMHeadModel': 'RMSNorm', 'Qwen2ForCausalLM': 'Qwen2RMSNorm', 
'Qwen3ForCausalLM': 'Qwen3RMSNorm', 'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B 'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B 'LlamaForCausalLM': 'LlamaRMSNorm', 'LlavaLlamaForCausalLM': 'LlamaRMSNorm', 'MGMLlamaForCausalLM': 'LlamaRMSNorm', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2RMSNorm', 'Phi3ForCausalLM': 'Phi3RMSNorm', 'ChatGLMForConditionalGeneration': 'RMSNorm', 'MixtralForCausalLM': 'MixtralRMSNorm', 'Qwen2VLForConditionalGeneration': 'Qwen2RMSNorm', 'Qwen2_5_VLForConditionalGeneration': 'Qwen2RMSNorm', 'MistralForCausalLM': 'MistralRMSNorm', } HEAD_NAME_MAP = { 'InternLMForCausalLM': 'lm_head', 'InternLM2ForCausalLM': 'output', 'InternLM3ForCausalLM': 'output', 'QWenLMHeadModel': 'lm_head', 'Qwen2ForCausalLM': 'lm_head', 'Qwen3ForCausalLM': 'lm_head', 'BaiChuanForCausalLM': 'lm_head', # Baichuan 7B 'BaichuanForCausalLM': 'lm_head', # Baichuan2 7B 'LlamaForCausalLM': 'lm_head', 'LlavaLlamaForCausalLM': 'lm_head', 'MGMLlamaForCausalLM': 'lm_head', # mini gemini 'InternLMXComposer2ForCausalLM': 'output', 'Phi3ForCausalLM': 'lm_head', 'ChatGLMForConditionalGeneration': 'output_layer', 'MixtralForCausalLM': 'lm_head', 'Qwen2VLForConditionalGeneration': 'lm_head', 'Qwen2_5_VLForConditionalGeneration': 'lm_head', 'MistralForCausalLM': 'lm_head', } def _prepare_for_calibrate(model: nn.Module, layer_type: Union[str, type], head_name: str = 'lm_head', device: str = 'cuda', prefix: str = '') -> None: """Prepare the model for calibration by moving specific modules to CPU. This function goes through each child of a given model and checks whether it is an instance of a certain layer type or has the name equal to `head_name`. If yes, it moves the module to CPU, otherwise to the specified device (default is CUDA). If the child contains the target layer type in its sub-modules, the function performs the same operation recursively. Parameters ---------- model : nn.Module The PyTorch model to prepare for calibration. 
layer_type : Union[str, Type] The type of the layer to be moved to CPU. Can be either a string of class name or the class type itself. head_name : str, optional The name of the module to be moved to CPU. Default is 'lm_head'. device : str, optional The device to which modules not matching the `layer_type` or `head_name` will be moved. Default is 'cuda'. prefix : str, optional The prefix used when printing the names of the moved modules. Default is ''. Raises ------ TypeError If `layer_type` is neither a string nor a type. """ for name, child in model.named_children(): # Check if the child is an instance of the given layer type if isinstance(layer_type, str): is_layer = type(child).__name__ == layer_type elif isinstance(layer_type, type): is_layer = isinstance(child, layer_type) else: raise TypeError('layer_type should be a string (class name) or a type') # Check if the child contains the target module type contain_layer = len(collect_target_modules(child, layer_type, [head_name]).keys()) > 0 # Check if the child matches the head name is_head = name == head_name # skip moving head layer to CPU when tie_word_embeddings is True is_head = is_head and not getattr(model.config, 'tie_word_embeddings', False) mod_name = f'{prefix}.{name}' if prefix else name # If the child is either an instance of the layer type or has the # head name, move it to CPU, otherwise move it to the specified device if is_layer or is_head: child.to('cpu') print(f'Move {mod_name} to CPU.') elif contain_layer: _prepare_for_calibrate(child, layer_type, head_name, device, mod_name) else: child.to(device) print(f'Move {mod_name} to GPU.') # TODO to be removed def make_compatible_internvl_config(model_path): """Patch model.config since after transformers v4.45.0, InternVL models can't use `save_pretrained`""" from lmdeploy.archs import get_model_arch arch, _ = get_model_arch(model_path) if arch == 'InternVLChatModel': import transformers from packaging import version if 
version.parse(transformers.__version__) >= version.parse('4.45.0'): def _get_non_default_generation_parameters(self): return {} from transformers import PretrainedConfig PretrainedConfig._get_non_default_generation_parameters = _get_non_default_generation_parameters # noqa def update_moe_mapping(model, model_type): """Update moe mapping.""" from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP # get experts num num_experts = 0 for n, m in model.named_modules(): if type(m).__name__ == LAYER_TYPE_MAP[model_type]: fc2fcs = FC_FCS_MAP[LAYER_TYPE_MAP[model_type]] for k, v in fc2fcs.items(): if '{i}' in k: break num_experts = len(m.get_submodule(k.split('.{i}')[0])) break # update FC_FCS_MAP updated_fc2fcs = dict() for prev_fc, post_fc in fc2fcs.items(): if '{i}' in prev_fc: for i in range(num_experts): updated_fc2fcs.update({prev_fc.format(i=i): [v.format(i=i) for v in post_fc]}) else: updated_fc2fcs.update({prev_fc: post_fc}) FC_FCS_MAP[LAYER_TYPE_MAP[model_type]] = updated_fc2fcs # update NORM_FCS_MAP norm2fcs = NORM_FCS_MAP[LAYER_TYPE_MAP[model_type]] updated_norm2fcs = dict() for norm, fc in norm2fcs.items(): updated_norm2fcs.update({norm: list(set([v.format(i=i) for v in fc for i in range(num_experts)]))}) NORM_FCS_MAP[LAYER_TYPE_MAP[model_type]] = updated_norm2fcs def calibrate(model: str, calib_dataset: str = 'wikitext2', calib_samples: int = 128, calib_seqlen: int = 2048, work_dir: str = './work_dir', device: str = 'cuda', w_bits: int = 4, w_group_size: int = 128, search_scale: bool = False, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', batch_size: int = 1) -> None: """The main function for loading the model and performing calibration on a given dataset. Args: model (str): The name or path of the model to be loaded. calib_dataset (str, optional): The calibration dataset name. Defaults to 'wikitext2'. calib_samples (int, optional): The number of samples for calibration. Defaults to 128. 
calib_seqlen (int, optional): The sequence length for calibration. Defaults to 2048. work_dir (str): The working directory for outputs. Defaults to './work_dir'. device (str, optional): The device to be used for calculation. Defaults to 'cuda'. w_bits (int): Bit number for weight quantization. w_group_size (int): Group size for weight quantization statistics. search_scale (bool): Whether search scale ratio. Default to False, which means only smooth quant with 0.5 ratio will be applied. dtype (str): Data type for loading model weights and calib infer. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. Returns: model (nn.Module): The loaded huggingface model. tokenizer : The loaded hugginface tokenizer. work_dir (str): The working directory for outputs. """ assert calib_dataset in ['wikitext2', 'c4', 'pileval', 'gsm8k', 'neuralmagic_calibration', 'open-platypus', 'openwebtext'], \ 'Support only `wikitext2`, `c4`, `pileval`, `gsm8k`, ' \ '`neuralmagic_calibration`, `open-platypus`, `openwebtext`.' model_type, _ = get_task(backend='turbomind', model_path=model) make_compatible_internvl_config(model) # Load tokenizer and configuration tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) if model_type == 'llm': model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True) vl_model = None elif model_type == 'vlm': vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model model = vl_model if hasattr(vl_model, 'language_model'): # deepseek-vl, ... model = vl_model.language_model if hasattr(vl_model, 'llm'): # MiniCPMV, ... 
model = vl_model.llm model.config.use_cache = False if dtype == 'float16': model.half() elif dtype == 'bfloat16': assert torch.cuda.is_bf16_supported( ), 'your device does not support bfloat16 please set --dtype float16' # noqa model.to(torch.bfloat16) elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16: print('Warning: we cast model to float16 to prevent OOM. You' ' may enforce it bfloat16 by `--dtype bfloat16`') model.half() model.eval() model_type = type(model).__name__ if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: raise RuntimeError(f'Currently, quantification and calibration of {model_type} are ' f'not supported. The supported model types are ' f"{', '.join(LAYER_TYPE_MAP.keys())}.") if model_type in ['MixtralForCausalLM']: update_moe_mapping(model, model_type) if model_type == 'QWenLMHeadModel': try: import flash_attn # noqa: F401 except ImportError: raise RuntimeError('When using Qwen, you need to `pip install flash-attn` first, ' 'otherwise calibration and quantification will not work ' 'properly.') layer_type = LAYER_TYPE_MAP[type(model).__name__] norm_type = NORM_TYPE_MAP[type(model).__name__] _prepare_for_calibrate(model, layer_type, HEAD_NAME_MAP[type(model).__name__], device) print('Loading calibrate dataset ...') calib_loader = get_calib_loaders(calib_dataset, tokenizer, nsamples=calib_samples, seqlen=calib_seqlen) # Initialize calibration context if search_scale: calib_ctx = CalibrationContextV2(model, tokenizer, layer_type=layer_type, norm_type=norm_type, device=device, w_bits=w_bits, w_group_size=w_group_size, batch_size=batch_size, search_scale=search_scale) else: calib_ctx = CalibrationContext(model, tokenizer, layer_type=layer_type, norm_type=norm_type, batch_size=batch_size, device=device) with calib_ctx: all_data = torch.cat(calib_loader).to(device) calib_ctx.calibrate(all_data) # Create work directory if not exists work_dir = Path(work_dir) work_dir.mkdir(parents=True, exist_ok=True) 
calib_ctx.export(work_dir) return vl_model, model, tokenizer, work_dir if __name__ == '__main__': import fire fire.Fire(calibrate) ================================================ FILE: lmdeploy/lite/apis/get_small_sharded_hf.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import argparse import copy import json import os import shutil import torch from mmengine.utils import mkdir_or_exist def parse_args(): parser = argparse.ArgumentParser(description='Convert a hugging face model to the smallest sharded one') parser.add_argument('src_dir', help='the directory of the model') parser.add_argument('dst_dir', help='the directory to save the new model') args = parser.parse_args() return args def main(): args = parse_args() mkdir_or_exist(args.dst_dir) all_files = os.listdir(args.src_dir) for name in all_files: if not name.startswith(('pytorch_model', '.')): src_path = os.path.join(args.src_dir, name) dst_path = os.path.join(args.dst_dir, name) shutil.copy(src_path, dst_path) with open(os.path.join(args.src_dir, 'pytorch_model.bin.index.json')) as f: index = json.load(f) n_shard = len(index['weight_map']) new_index = copy.deepcopy(index) new_index['weight_map'] = {} cnt = 1 checkpoints = set(index['weight_map'].values()) for ckpt in checkpoints: state_dict = torch.load(os.path.join(args.src_dir, ckpt), map_location='cuda', weights_only=True) keys = sorted(list(state_dict.keys())) for k in keys: new_state_dict_name = 'pytorch_model-{:05d}-of-{:05d}.bin'.format(cnt, n_shard) new_index['weight_map'][k] = new_state_dict_name new_state_dict = {k: state_dict[k]} torch.save(new_state_dict, os.path.join(args.dst_dir, new_state_dict_name)) cnt += 1 del state_dict torch.cuda.empty_cache() with open(os.path.join(args.dst_dir, 'pytorch_model.bin.index.json'), 'w') as f: json.dump(new_index, f) assert new_index['weight_map'].keys() == index['weight_map'].keys(), 'Mismatch on `weight_map`!' 
if __name__ == '__main__': main() ================================================ FILE: lmdeploy/lite/apis/gptq.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import logging from typing import Literal import torch from transformers import AutoConfig, AutoTokenizer from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders def auto_gptq(model: str, work_dir: str = './work_dir', w_bits: int = 4, w_group_size: int = 128, calib_dataset: str = 'wikitext2', calib_samples: int = 128, calib_seqlen: int = 2048, batch_size: int = 1, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', revision: str = None): """Perform weight quantization using AWQ algorithm. Args: model (str): The path of model in hf format. work_dir (str): The working directory to save results. calib_dataset (str): The calibration dataset name. Defaults to 'wikitext2'. calib_samples (int): The number of samples for calibration. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. calib_seqlen (int): The sequence length for calibration. w_bits (int): Bit number for weight quantization. w_group_size (int): Group size for weight quantization statistics. dtype (str): Data type for loading model weights and calib infer. revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. 
""" try: from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig except Exception: raise ImportError('To use auto_gptq, please install auto-gptq by ' 'pip install auto-gptq') logging.basicConfig( format='%(asctime)s %(levelname)s [%(name)s] %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S', ) # support internlm2 from auto_gptq.modeling import GPTQ_CAUSAL_LM_MODEL_MAP from auto_gptq.modeling._const import SUPPORTED_MODELS from ..modeling.internlm2_gptq import InternLM2GPTQForCausalLM from ..modeling.internlm3_gptq import InternLM3GPTQForCausalLM SUPPORTED_MODELS.append('internlm2') GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm2=InternLM2GPTQForCausalLM)) SUPPORTED_MODELS.append('internlm3') GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm3=InternLM3GPTQForCausalLM)) pretrained_model_dir = model quantized_model_dir = work_dir tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, trust_remote_code=True) print('Loading calibrate dataset ...') calib_loader = get_calib_loaders(calib_dataset, tokenizer, nsamples=calib_samples, seqlen=calib_seqlen) attention_mask = [1] * calib_seqlen examples = [dict(input_ids=data.flatten().tolist(), attention_mask=attention_mask) for data in calib_loader] quantize_config = BaseQuantizeConfig( bits=w_bits, # quantize model to 4-bit group_size=w_group_size, # it is recommended to set the value to 128 desc_act=False, # lmdeploy only supports False sym=True, # lmdeploy only supports True ) # load un-quantized model, by default, # the model will always be loaded into CPU memory hf_config = AutoConfig.from_pretrained(pretrained_model_dir, revision=revision, trust_remote_code=True) torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) if dtype == 'float16': torch_dtype = torch.float16 elif dtype == 'bfloat16': torch_dtype = torch.bfloat16 model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, revision=revision, torch_dtype=torch_dtype, trust_remote_code=True).cuda() # quantize model, 
the examples should be list of dict whose keys # can only be "input_ids" and "attention_mask" model.quantize(examples, batch_size=batch_size) # save quantized model model.save_quantized(quantized_model_dir) tokenizer.save_pretrained(quantized_model_dir) if __name__ == '__main__': import fire fire.Fire(auto_gptq) ================================================ FILE: lmdeploy/lite/apis/smooth_quant.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from typing import Literal import fire import torch from torch import nn from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers from lmdeploy.lite.utils import collect_target_modules from lmdeploy.pytorch.models import QLinear, QRMSNorm from lmdeploy.utils import try_import_deeplink def smooth_quant(model: str, work_dir: str = './work_dir', calib_dataset: str = 'wikitext2', calib_samples: int = 128, calib_seqlen: int = 2048, search_scale: bool = False, batch_size: int = 1, w_bits: int = 8, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', device: str = 'cuda', quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8', revision: str = None, download_dir: str = None): try_import_deeplink(device) if quant_dtype == 'fp8': quant_dtype = 'float8_e4m3fn' quant_dtype = getattr(torch, quant_dtype, torch.int8) if quant_dtype.is_floating_point: q_dtype_info = torch.finfo(quant_dtype) else: q_dtype_info = torch.iinfo(quant_dtype) assert q_dtype_info.bits == w_bits if not osp.exists(model): print(f'can\'t find model from local_path {model}, ' 'try to download from remote') from lmdeploy.utils import get_model model = get_model(model, revision=revision, download_dir=download_dir) model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset, calib_samples, calib_seqlen, work_dir, 
device, w_bits=w_bits, w_group_size=-1, search_scale=search_scale, dtype=dtype, batch_size=batch_size) # calibrate function exports the calibration statistics # (inputs, outputs, keys and values) to `work_dir`. inp_stats = torch.load(work_dir / 'inputs_stats.pth', weights_only=True) act_scales = inp_stats['absmax'] model_type = type(model).__name__ if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: raise RuntimeError(f'Currently, quantification and calibration of {model_type} are ' f'not supported. The supported model types are ' f"{', '.join(LAYER_TYPE_MAP.keys())}.") if model_type == 'QWenLMHeadModel': try: import flash_attn # noqa: F401 except ImportError: raise RuntimeError('When using Qwen, you need to `pip install flash-attn` first, ' 'otherwise calibration and quantification will not work ' 'properly.') layer_type = LAYER_TYPE_MAP[type(model).__name__] norm_type = NORM_TYPE_MAP[type(model).__name__] fc2fcs = FC_FCS_MAP[layer_type] norm2fcs = NORM_FCS_MAP[layer_type] layers = collect_target_modules(model, layer_type) fcs = {} for l_name, layer in layers.items(): name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name) fcs.update(name2fc) if search_scale: awq_ratios = inp_stats['ratios'] act_scales = inp_stats['absmean'] awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, -1, device) else: smooth_layers(layers, fc2fcs, norm2fcs, act_scales, -1, device) rmsnorms = collect_target_modules(model, norm_type) for name, linear in fcs.items(): if skipped_module(name): continue linear.to(device) q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_linear) linear.to('cpu') q_linear.to('cpu') torch.cuda.empty_cache() for name, norm in rmsnorms.items(): if skipped_module(name): continue norm.to(device) q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype) parent_name, _, child_name = 
name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_norm) norm.to('cpu') q_norm.to('cpu') torch.cuda.empty_cache() quant_dtype_s = str(quant_dtype).split('.')[1] model.config.update(dict(quantization_config=dict(quant_method='smooth_quant', quant_dtype=f'{quant_dtype_s}'))) if vl_model: from .auto_awq import save_vl_model save_vl_model(vl_model, model_path, work_dir) else: model.save_pretrained(work_dir, safe_serialization=True) tokenizer.save_pretrained(work_dir) if __name__ == '__main__': fire.Fire(smooth_quant) ================================================ FILE: lmdeploy/lite/defaults.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from torch import nn OFFLOAD_MOD = (nn.Linear, ) KV_CACHE_SIGNATURE = 'past_key_value' ================================================ FILE: lmdeploy/lite/modeling/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. ================================================ FILE: lmdeploy/lite/modeling/internlm2_gptq.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from auto_gptq.modeling import BaseGPTQForCausalLM class InternLM2GPTQForCausalLM(BaseGPTQForCausalLM): layer_type = 'InternLM2DecoderLayer' layers_block_name = 'model.layers' outside_layer_modules = ['model.tok_embeddings', 'model.norm'] inside_layer_modules = [ ['attention.wqkv'], ['attention.wo'], ['feed_forward.w3', 'feed_forward.w1'], ['feed_forward.w2'], ] ================================================ FILE: lmdeploy/lite/modeling/internlm3_gptq.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from auto_gptq.modeling import BaseGPTQForCausalLM class InternLM3GPTQForCausalLM(BaseGPTQForCausalLM): layer_type = 'InternLM3DecoderLayer' layers_block_name = 'model.layers' outside_layer_modules = ['model.embed_tokens', 'model.norm'] inside_layer_modules = [ ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj'], ] ================================================ FILE: lmdeploy/lite/quantization/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .activation import ActivationObserver, KVCacheObserver from .calibration import CalibrationContext, CalibrationContextV2 from .weight import WeightQuantizer __all__ = ['WeightQuantizer', 'ActivationObserver', 'KVCacheObserver', 'CalibrationContext', 'CalibrationContextV2'] ================================================ FILE: lmdeploy/lite/quantization/activation/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .observer import ActivationObserver, KVCacheObserver __all__ = ['ActivationObserver', 'KVCacheObserver'] ================================================ FILE: lmdeploy/lite/quantization/activation/observer.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import torch from lmdeploy.lite.utils.global_avail import GlobalAvailMixin class KVCacheObserver(GlobalAvailMixin): """A class to observe and record the max, min, and absolute max value of given tensor.""" def __init__(self, num_head: int, head_dim: int) -> None: """Constructor for KVCacheObserver. 
Args: num_head : Number of heads head_dim : Dimension of each head """ self.num_head = num_head self.head_dim = head_dim self.max_val = torch.full((num_head, head_dim), -torch.inf, dtype=torch.float16) self.min_val = torch.full((num_head, head_dim), torch.inf, dtype=torch.float16) self.absmax_val = torch.full((num_head, head_dim), 0, dtype=torch.float16) @torch.no_grad() def observe(self, x: torch.Tensor) -> None: """Function to observe the input tensor and update the max, min, and absolute max values. Args: x : Input tensor """ assert len(x.shape) == 4 if x.size(2) == self.num_head and x.size(3) == self.head_dim: # layout: (bs, seqlen, heads, dims) x = x elif x.size(1) == self.num_head and x.size(3) == self.head_dim: # layout: (bs, heads, seqlen, dims) x = x.transpose(1, 2) else: raise RuntimeError cur_max = x.flatten(0, 1).max(0)[0].cpu() cur_min = x.flatten(0, 1).min(0)[0].cpu() cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu() self.max_val = torch.maximum(self.max_val, cur_max) self.min_val = torch.minimum(self.min_val, cur_min) self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) class ActivationObserver(GlobalAvailMixin): """A class to observe and record the max, min, mean, absolute max, and absolute mean value of a given tensor. Also keeps track of the number of batches observed. """ observed = False def __init__(self, dim: int) -> None: """Constructor for ActivationObserver. 
Args: dim : Dimension of the tensor """ self.dim = dim self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16) self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16) self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16) self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16) self.mean_val = torch.full((dim, ), 0, dtype=torch.float16) self.num_batches_tracked = 0 self.value = None self.ratio = None self.num_ratio_tracked = 0 @classmethod def disable(cls): """To avoid recomputation in search scale process.""" cls.observed = True @classmethod def enable(cls): """To avoid recomputation in search scale process.""" cls.observed = False @torch.no_grad() def observe(self, x: torch.Tensor, save_input: bool = False) -> None: """Function to observe the input tensor and update the max, min, mean, absolute max, absolute mean values and number of batches tracked. Args: x : Input tensor """ assert torch.isnan(x).sum() == 0 if self.observed: return assert x.size(-1) == self.dim cur_val = x.flatten(0, 1) if any([s == 0 for s in cur_val.shape]): return cur_max = cur_val.max(0)[0].cpu() cur_min = cur_val.min(0)[0].cpu() cur_mean = cur_val.mean(0).cpu() cur_abs = cur_val.abs() cur_absmax = cur_abs.max(0)[0].cpu() cur_absmean = cur_abs.mean(0).cpu() self.max_val = torch.maximum(self.max_val, cur_max) self.min_val = torch.minimum(self.min_val, cur_min) self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) if save_input: self.value = x # Update mean and absmean value with accumulated sum divided # by total number of batches self.mean_val = ((self.mean_val * self.num_batches_tracked + cur_mean) / (self.num_batches_tracked + 1)) self.absmean_val = ((self.absmean_val * self.num_batches_tracked + cur_absmean) / (self.num_batches_tracked + 1)) # Increment the count of batches tracked self.num_batches_tracked += 1 @torch.no_grad() def save_ratio(self, ratio: float) -> None: if self.ratio is None: self.ratio = 0 self.ratio = (self.ratio * 
self.num_ratio_tracked + ratio) / (self.num_ratio_tracked + 1) self.num_ratio_tracked += 1 ================================================ FILE: lmdeploy/lite/quantization/awq.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import List import torch # Maps that describe the structure of your model. NORM_FCS_MAP = { 'LlamaDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'InternLMDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'InternLM2DecoderLayer': { 'attention_norm': ['attention.wqkv'], 'ffn_norm': ['feed_forward.w1', 'feed_forward.w3'] }, 'InternLM3DecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'QWenBlock': { 'ln_1': ['attn.c_attn'], 'ln_2': ['mlp.w1', 'mlp.w2'] }, 'Qwen2DecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'Qwen3DecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'DecoderLayer': { 'input_layernorm': ['self_attn.W_pack'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'Phi3DecoderLayer': { 'input_layernorm': ['self_attn.qkv_proj'], 'post_attention_layernorm': ['mlp.gate_up_proj'] }, 'GLMBlock': { 'input_layernorm': ['self_attention.query_key_value'], 'post_attention_layernorm': ['mlp.dense_h_to_4h'] }, 'MixtralDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1', 
'block_sparse_moe.experts.{i}.w3'] }, 'Qwen2VLDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'Qwen2_5_VLDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, 'MistralDecoderLayer': { 'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] }, } FC_FCS_MAP = { 'LlamaDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'InternLMDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'InternLM2DecoderLayer': { 'feed_forward.w3': ['feed_forward.w2'] }, 'InternLM3DecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'QWenBlock': { 'attn.c_attn': ['attn.c_proj'], 'mlp.w1': ['mlp.c_proj'] }, 'Qwen2DecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'Qwen3DecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'DecoderLayer': { 'self_attn.W_pack': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'Phi3DecoderLayer': { 'self_attn.qkv_proj': ['self_attn.o_proj'], 'mlp.gate_up_proj': ['mlp.down_proj'] }, 'GLMBlock': { # 'self_attention.query_key_value': ['self_attention.dense'] # 'mlp.dense_h_to_4h': ['mlp.dense_4h_to_h'] }, 'MixtralDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'block_sparse_moe.experts.{i}.w3': ['block_sparse_moe.experts.{i}.w2'] }, 'Qwen2VLDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'Qwen2_5_VLDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] }, 'MistralDecoderLayer': { 'self_attn.v_proj': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] 
} } SKIPPED_MODULE = ['lora', 'block_sparse_moe.gate'] def skipped_module(name: str): """Whether the module should be skipped from quantization.""" for m in SKIPPED_MODULE: if m in name: return True return False @torch.no_grad() def get_weight_scale(weight, q_group_size=-1): org_shape = weight.shape if q_group_size > 0: weight = weight.view(-1, q_group_size) abs_weight = weight.abs() abs_weight_amax = abs_weight.amax(dim=1, keepdim=True) if abs_weight_amax.min().item() == 0: print('weight.amax.min is zero, clamping weight.amax to 1e-4') abs_weight_amax = abs_weight_amax.clamp(min=1e-4) scale = abs_weight / abs_weight_amax scale = scale.view(org_shape) scale = scale.mean(0) return scale @torch.no_grad() def smooth_ln_fcs(ln: torch.nn.Module, fcs: List[torch.nn.Module], act_scales: torch.Tensor, group_size: int = -1, alpha: float = 0.5) -> torch.Tensor: """Smooth weights of a layer normalization and its fully connected layers. :param ln: Layer Normalization module :param fcs: List of Fully Connected modules :param act_scales: Activation scales :param alpha: Scaling factor (default is 0.5) :return: Scales """ device, dtype = fcs[0].weight.device, fcs[0].weight.dtype # If zeros exist within the weight of the layer norm, it becomes # unnecessary to perform smooth quantization at the positions where # these zeros occur. 
zero_positions = (ln.weight == 0).nonzero(as_tuple=True)[0] nonzero_positions = (ln.weight != 0).nonzero(as_tuple=True)[0] act_scales = act_scales.to(device=device, dtype=dtype) concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) w_scales_pow = w_scales.pow(1 - alpha) if w_scales_pow.min().item() == 0: print('w_scales.pow(1 - alpha).min is zero, ' 'clamping w_scales.pow(1 - alpha) to 1e-4') w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales[nonzero_positions].max() * scales[nonzero_positions].min()).sqrt() scales[zero_positions] = 1 ln.weight.div_(scales) if hasattr(ln, 'bias'): ln.bias.div_(scales) for fc in fcs: fc.weight.mul_(scales.view(1, -1)) for p in ln.parameters(): assert torch.isnan(p).sum() == 0 for fc in fcs: for p in fc.parameters(): assert torch.isnan(p).sum() == 0 return scales @torch.no_grad() def smooth_fc_fcs(pre_fc: torch.nn.Module, fcs: List[torch.nn.Module], act_scales: torch.Tensor, group_size: int = -1, alpha: float = 0.5) -> torch.Tensor: """Smooth weights of a fully connected layer and its downstream layers. 
:param pre_fc: Previous Fully Connected layer :param fcs: List of Fully Connected modules :param act_scales: Activation scales :param alpha: Scaling factor (default is 0.5) :return: Scales """ device, dtype = pre_fc.weight.device, pre_fc.weight.dtype size_a = act_scales.size(0) size_pre_fc = pre_fc.weight.size(0) # (for llama2) use group query attention, pre_fc is v_proj, fc is o_proj if size_pre_fc < size_a and size_a % size_pre_fc == 0: return act_scales = act_scales.to(device=device, dtype=dtype) concat_w = torch.cat([fc.weight for fc in fcs], dim=0) w_scales = get_weight_scale(concat_w, group_size) w_scales_pow = w_scales.pow(1 - alpha) if w_scales_pow.min().item() == 0: print('w_scales.pow(1 - alpha).min is zero, ' 'clamping w_scales.pow(1 - alpha) to 1e-4') w_scales_pow = w_scales_pow.clamp(min=1e-4) scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4).to(device).to(dtype) scales = scales / (scales.max() * scales.min()).sqrt() # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale # phi3 fused qkv and gate_up if size_pre_fc > size_a and size_pre_fc % size_a == 0 \ and size_pre_fc // size_a in [2, 3]: pre_fc.weight[-size_a:].div_(scales.view(-1, 1)) if getattr(pre_fc, 'bias', None) is not None: pre_fc.bias[-size_a:].div_(scales) else: pre_fc.weight.div_(scales.view(-1, 1)) if getattr(pre_fc, 'bias', None) is not None: pre_fc.bias.div_(scales) for fc in fcs: fc.weight.mul_(scales.view(1, -1)) for p in pre_fc.parameters(): assert torch.isnan(p).sum() == 0 for fc in fcs: for p in fc.parameters(): assert torch.isnan(p).sum() == 0 return scales def check_awq_supported(layer_type): """Check if the smooth function is supported by inspecting layer type.""" norm_fcs_found = False fc_fcs_found = False if isinstance(layer_type, str): if layer_type in NORM_FCS_MAP: norm_fcs_found = True if layer_type in FC_FCS_MAP: fc_fcs_found = True elif isinstance(layer_type, type): if layer_type.__name__ in NORM_FCS_MAP: norm_fcs_found = True if 
layer_type.__name__ in FC_FCS_MAP: fc_fcs_found = True else: raise NotImplementedError if not norm_fcs_found: raise NotImplementedError if not fc_fcs_found: raise NotImplementedError def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'): """Quantize the weights of the target model's linear layers.""" from lmdeploy.lite.quantization import WeightQuantizer from lmdeploy.lite.quantization.modules import WeightOnlyQLinear from lmdeploy.lite.utils import QParams for name, fc in fcs.items(): fc.to(device) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) pack_or_skip = 'packed' if skipped_module(name): q_linear = fc pack_or_skip = 'skipped' else: quantizer = WeightQuantizer(bits, symmetry, 'per_group', group_size) fc.weight.data, scales, zeros = pseudo_quantize_tensor(fc.weight.data, bits, group_size, return_scale_zeros=True) q_linear = WeightOnlyQLinear.from_linear(fc, quantizer, qparams=QParams(scales, zeros)) setattr(parent, child_name, q_linear) fc.to('cpu') torch.cuda.empty_cache() print(f'{name} weight {pack_or_skip}.') def smooth_layers(layers, fc2fcs, norm2fcs, a_scales, group_size=-1, device='cuda'): """Apply weight smoothing based on input scales.""" for l_name, layer in layers.items(): layer.to(device) submodule_names = [name for name, _ in layer.named_modules()] for ln_name, fc_names in norm2fcs.items(): a_name = [f'{l_name}.{n}' for n in fc_names if n in submodule_names][0] ln = layer.get_submodule(ln_name) fcs = [layer.get_submodule(n) for n in fc_names if n in submodule_names] smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size) for f_name, fc_names in fc2fcs.items(): a_name = [f'{l_name}.{n}' for n in fc_names if n in submodule_names][0] fc = layer.get_submodule(f_name) fcs = [layer.get_submodule(n) for n in fc_names if n in submodule_names] smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size) layer.to('cpu') torch.cuda.empty_cache() max_memory = torch.cuda.max_memory_allocated(device=device) 
/ 1024 / 1024 / 1024 print(f'{l_name} smooth weight done.' f' max gpu memory: {max_memory:.2f} GB') def pseudo_quantize_tensor(w, w_bit=8, w_group_size=-1, return_scale_zeros=False): """Pseudo quantize tensor.""" org_w_shape = w.shape if w_group_size > 0: assert org_w_shape[-1] % w_group_size == 0 w = w.reshape(-1, w_group_size) assert w.dim() == 2 max_val = w.amax(dim=1, keepdim=True) min_val = w.amin(dim=1, keepdim=True) max_int = 2**w_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-5) / max_int zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) assert torch.isnan(scales).sum() == 0 assert torch.isnan(w).sum() == 0 q_w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) w = (q_w - zeros) * scales assert torch.isnan(w).sum() == 0 if return_scale_zeros: zeros = zeros.view(org_w_shape[0], org_w_shape[-1] // w_group_size, -1) scales = scales.view(org_w_shape[0], org_w_shape[-1] // w_group_size, -1) q_w = q_w.reshape(org_w_shape) return q_w, scales, zeros w = w.reshape(org_w_shape) return w def awq_layers(layers, fc2fcs, norm2fcs, a_scales, a_ratios=None, group_size=-1, device='cuda'): """Apply awq based on input scales.""" for l_name, layer in layers.items(): layer.to(device) for ln_name, fc_names in norm2fcs.items(): a_name = [f'{l_name}.{n}' for n in fc_names][0] ratios = [a_ratios[f'{l_name}.{n}'] for n in fc_names] ratio = [s for s in ratios if s is not None][0] ln = layer.get_submodule(ln_name) fcs = [layer.get_submodule(n) for n in fc_names] smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size, ratio) for f_name, fc_names in fc2fcs.items(): a_name = [f'{l_name}.{n}' for n in fc_names][0] ratios = [a_ratios[f'{l_name}.{n}'] for n in fc_names] ratios = [s for s in ratios if s is not None] ratio = 0.5 if not len(ratios) else ratios[0] fc = layer.get_submodule(f_name) fcs = [layer.get_submodule(n) for n in fc_names] smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size, ratio) layer.to('cpu') torch.cuda.empty_cache() 
max_memory = torch.cuda.max_memory_allocated(device=device) / 1024 / 1024 / 1024 print(f'{l_name} smooth weight done.' f' max gpu memory: {max_memory:.2f} GB') ================================================ FILE: lmdeploy/lite/quantization/calibration.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from functools import partial from typing import Union import torch from torch import nn from transformers import PreTrainedTokenizer from lmdeploy.lite.quantization.activation import ActivationObserver from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP from lmdeploy.lite.utils import (bimap_name_mod, collect_target_modules, concat_decoder_layer_outputs, split_decoder_layer_inputs) class CalibrationContext(): """Calibration context manager for model quantization. Parameters: - model: The target model to be calibrated and quantized - tokenizer: The tokenizer used in the model training - layer_type: Layer type to be targeted for calibration - norm_type: Normalization type used for calibration - device: Device on which model is to be calibrated ('cpu' or 'cuda') """ inp_obs_group = 'inputs' out_obs_group = 'outputs' def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer, layer_type: Union[str, type], norm_type: Union[str, type], batch_size: int = 1, device: str = 'cuda', **kwargs) -> None: """Initiate calibration context. Args: model (nn.Module): Model to be calibrated. tokenizer (PreTrainedTokenizer): Tokenizer of the given model. layer_type (Union[str, type]): Type of the layers to be observed. norm_type (Union[str, type]): Norm type used in the model. batch_size (int): The batch size for running the calib samples. Low GPU mem requires small batch_size. Large batch_size reduces the calibration time while costs more VRAM. device (str, optional): Device where the model should run. Defaults to 'cuda'. 
""" self.layer_type = layer_type self.norm_type = norm_type self.batch_size = batch_size num_kv_heads, num_attn_heads = self._guess_num_heads(model) self.num_kv_heads = num_kv_heads self.head_dim = model.config.hidden_size // num_attn_heads self.model = model self.tokenizer = tokenizer # Collect modules to observe self.name2layer = collect_target_modules(self.model, layer_type) self.name2fc = {} for l_name, layer in self.name2layer.items(): name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name) self.name2fc.update(name2fc) self.name2norm = collect_target_modules(self.model, norm_type) maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm]) self.name2mod, self.mod2name = maps # Initialize observers self._init_input_observers(self.name2fc) self._init_output_observers(self.name2norm) self._init_output_observers(self.name2fc) self.device = device def _guess_num_heads(self, model): if hasattr(model.config, 'num_key_value_heads'): num_kv_heads = model.config.num_key_value_heads else: num_kv_heads = model.config.num_attention_heads num_attn_heads = model.config.num_attention_heads return num_kv_heads, num_attn_heads def _init_input_observers(self, name2mod): """Initialize input observers for given modules.""" for name, mod in name2mod.items(): obs = ActivationObserver(mod.weight.size(-1)) obs.global_available(name, group=self.inp_obs_group) def _init_output_observers(self, name2mod): """Initialize output observers for given modules.""" for name, mod in name2mod.items(): obs = ActivationObserver(mod.weight.size(0)) obs.global_available(name, group=self.out_obs_group) def _insert_input_observers(self): """Insert input observers into the target modules. This function registers a forward pre-hook on each target module to observe the inputs. 
""" def _input_hook(mod: nn.Module, inp: torch.Tensor): m_name = self.mod2name[mod] obs = ActivationObserver.find(m_name, group=self.inp_obs_group) obs.observe(inp[0]) group = ActivationObserver.find_group(self.inp_obs_group) for name in group.keys(): mod = self.name2mod[name] hook_fn = mod.register_forward_pre_hook(_input_hook) self._hooks.append(hook_fn) def _insert_output_observers(self): """Insert output observers into the target modules. This function registers a forward hook on each target module to observe the outputs. """ def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor): m_name = self.mod2name[mod] obs = ActivationObserver.find(m_name, group=self.out_obs_group) obs.observe(out) group = ActivationObserver.find_group(self.out_obs_group) for name in group.keys(): mod = self.name2mod[name] hook_fn = mod.register_forward_hook(_output_hook) self._hooks.append(hook_fn) def _wrap_decoder_layers(self): """Method to wrap the decoder layers' forward functions for observing their key/value cache during batched forward passes.""" def _forward(mod, *args, **kwargs): mod.to(self.device) batch_args, batch_kwargs = split_decoder_layer_inputs(self.batch_size, *args, **kwargs) batch_outputs = [] samples = len(batch_args) m_name = self.mod2name[mod] for i in range(len(batch_args)): batch_outputs.append(self._ori_forwards[mod](*batch_args[i], **batch_kwargs[i])) outputs = concat_decoder_layer_outputs(batch_outputs) del batch_outputs, batch_args, batch_kwargs, args mod.to('cpu') torch.cuda.empty_cache() max_memory = torch.cuda.max_memory_allocated(device=self.device) / 1024 / 1024 / 1024 print(f'{m_name}, samples: {samples}, ' f'max gpu memory: {max_memory:.2f} GB') return outputs for layer in self.name2layer.values(): self._ori_forwards[layer] = layer.forward layer.forward = partial(_forward, layer) def collect_inputs_stats(self): """Collect statistics (min, max, absmax values) of the observed inputs. Returns a dictionary with these collected stats. 
""" inputs_stats = {'max': {}, 'min': {}, 'mean': {}, 'absmax': {}, 'absmean': {}} obs_group = ActivationObserver.find_group(self.inp_obs_group) for name, obs in obs_group.items(): inputs_stats['max'][name] = obs.max_val inputs_stats['min'][name] = obs.min_val inputs_stats['mean'][name] = obs.mean_val inputs_stats['absmax'][name] = obs.absmax_val inputs_stats['absmean'][name] = obs.absmean_val return inputs_stats def collect_outputs_stats(self): """Collect statistics (min, max, absmax values) of the observed outputs. Returns a dictionary with these collected stats. """ outputs_stats = {'max': {}, 'min': {}, 'mean': {}, 'absmax': {}, 'absmean': {}} obs_group = ActivationObserver.find_group(self.out_obs_group) for name, obs in obs_group.items(): outputs_stats['max'][name] = obs.max_val outputs_stats['min'][name] = obs.min_val outputs_stats['mean'][name] = obs.mean_val outputs_stats['absmax'][name] = obs.absmax_val outputs_stats['absmean'][name] = obs.absmean_val return outputs_stats def export(self, out_dir): """Export the calibration statistics (inputs, outputs, keys and values) to specified directory. Args: out_dir (Union[str, Path]): The directory path where the stats will be saved. 
""" inp_stats = self.collect_inputs_stats() torch.save(inp_stats, out_dir / 'inputs_stats.pth') torch.cuda.empty_cache() out_stats = self.collect_outputs_stats() torch.save(out_stats, out_dir / 'outputs_stats.pth') torch.cuda.empty_cache() def calibrate(self, data): """Forward pass through the model in inference mode with given data.""" if type(self.model).__name__ in ('QWenLMHeadModel', 'ChatGLMForConditionalGeneration'): model = self.model.transformer else: model = self.model.model with torch.inference_mode(): _ = model(data.to(self.device)) torch.cuda.empty_cache() def __enter__(self): """Prepares the Calibration object for a 'with' statement by registering hooks and wrapping layer forward methods.""" self._hooks = list() self._ori_forwards = {} for layer in self.name2layer.values(): self._ori_forwards[layer] = layer.forward self._insert_input_observers() self._insert_output_observers() self._wrap_decoder_layers() def __exit__(self, exc_type, exc_value, traceback): """Clean up after a 'with' statement by removing registered hooks, restoring original forward methods, and if no exception occurred, collecting all gathered statistics and saving them.""" for h in self._hooks: h.remove() for layer in self.name2layer.values(): layer.forward = self._ori_forwards[layer] @torch.no_grad() def auto_scale_block(module, module_kwargs, w_bit, w_group_size, input_feat, mod_name): if 'use_cache' in module_kwargs: module_kwargs.pop('use_cache') # find the best scale ratio def _search_module_scale(block, linears2scale: list, x, kwargs={}): x = x.to(next(block.parameters()).device) with torch.no_grad(): org_out = block(x, **kwargs) if isinstance(org_out, tuple): org_out = org_out[0] x_max = x.abs().view(-1, x.shape[-1]).mean(0) best_error = float('inf') best_ratio = -1 n_grid = 20 history = [] concat_w = torch.cat([_m.weight for _m in linears2scale], dim=0) from .awq import get_weight_scale, pseudo_quantize_tensor w_mean = get_weight_scale(concat_w, w_group_size) org_sd = {k: 
v.cpu() for k, v in block.state_dict().items()} for ratio in range(0, n_grid): ratio = ratio / n_grid w_mean_pow = w_mean.pow(1 - ratio) if w_mean_pow.min().item() == 0: print('w_mean.pow(1 - ratio).min is zero, ' 'clamping w_mean.pow(1 - ratio) to 1e-4') w_mean_pow = w_mean_pow.clamp(min=1e-4) scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() for fc in linears2scale: fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) fc.weight.data = pseudo_quantize_tensor(fc.weight.data, w_bit, w_group_size) / (scales.view(1, -1)) out = block(x, **kwargs) if isinstance(out, tuple): out = out[0] # float prevents overflow loss = (org_out - out).float().pow(2).mean().item() history.append(loss) if loss < best_error: best_error = loss best_ratio = ratio block.load_state_dict(org_sd) if best_ratio == -1: print(history) raise Exception return best_ratio def _auto_get_scale(layers, inp, module2inspect=None, kwargs={}): # module2inspect: if given, we will check the output diff of # this module instead of layers if module2inspect is None: assert len(layers) == 1 module2inspect = layers[0] # internlm-xcomposer2-vl applies plora, which requires im_mask arg if module2inspect._get_name() == 'InternLM2MLP': from inspect import signature if 'im_mask' in signature(module2inspect.forward).parameters: kwargs['im_mask'] = None best_ratio = _search_module_scale(module2inspect, layers, inp.value, kwargs) inp.save_ratio(best_ratio) for i, (prev_name, layer_names) in enumerate(NORM_FCS_MAP[module._get_name()].items()): # attention input _auto_get_scale( layers=[module.get_submodule(name) for name in layer_names], inp=input_feat[f'{mod_name}.{layer_names[0]}'], module2inspect=module.get_submodule(layer_names[0].split('.')[0]), kwargs=module_kwargs if i == 0 else {}, # only attention input need ) for prev_name, layer_names in FC_FCS_MAP[module._get_name()].items(): # attention input _auto_get_scale( 
layers=[module.get_submodule(name) for name in layer_names], inp=input_feat[f'{mod_name}.{layer_names[0]}'], ) class CalibrationContextV2(CalibrationContext): def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer, layer_type: Union[str, type], norm_type: Union[str, type], batch_size: int = 1, device: str = 'cuda', search_scale: bool = True, w_bits: int = 4, w_group_size: int = 128, **kwargs) -> None: super().__init__(model, tokenizer, layer_type, norm_type, batch_size, device) self.w_bits = w_bits self.w_group_size = w_group_size self.search_scale = search_scale def _insert_input_observers(self): """Insert input observers into the target modules. This function registers a forward pre-hook on each target module to observe the inputs. """ def _input_hook(mod: nn.Module, inp: torch.Tensor): m_name = self.mod2name[mod] obs = ActivationObserver.find(m_name, group=self.inp_obs_group) obs.observe(inp[0], self.search_scale) group = ActivationObserver.find_group(self.inp_obs_group) for name in group.keys(): mod = self.name2mod[name] hook_fn = mod.register_forward_pre_hook(_input_hook) self._hooks.append(hook_fn) def export(self, out_dir): """Export the calibration statistics (inputs, outputs, keys and values) to specified directory. Args: out_dir (Union[str, Path]): The directory path where the stats will be saved. 
""" inputs_stats = { 'max': {}, 'min': {}, 'mean': {}, 'absmax': {}, 'absmean': {}, 'ratios': {}, } obs_group = ActivationObserver.find_group(self.inp_obs_group) for name, obs in obs_group.items(): inputs_stats['max'][name] = obs.max_val inputs_stats['min'][name] = obs.min_val inputs_stats['mean'][name] = obs.mean_val inputs_stats['absmax'][name] = obs.absmax_val inputs_stats['absmean'][name] = obs.absmean_val inputs_stats['ratios'][name] = obs.ratio torch.save(inputs_stats, out_dir / 'inputs_stats.pth') torch.cuda.empty_cache() def _wrap_decoder_layers_for_search(self): """Method to wrap the decoder layers' forward functions for observing their key/value cache during batched forward passes.""" @torch.no_grad() def _forward(mod, *args, **kwargs): mod.to(self.device) batch_args, batch_kwargs = split_decoder_layer_inputs(self.batch_size, *args, **kwargs) batch_outputs = [] samples = len(batch_args) m_name = self.mod2name[mod] for i in range(len(batch_args)): batch_outputs.append(self._ori_forwards[mod](*batch_args[i], **batch_kwargs[i])) obs_group = ActivationObserver.find_group(self.inp_obs_group) mod_name = self.mod2name[mod] ActivationObserver.disable() auto_scale_block(mod, batch_kwargs[i], self.w_bits, self.w_group_size, obs_group, mod_name) ActivationObserver.enable() for key, item in obs_group.items(): if key.startswith(f'{mod_name}.') and item.value is not None: item.value.cpu() del item.value outputs = concat_decoder_layer_outputs(batch_outputs) del batch_outputs, batch_args, batch_kwargs, args mod.cpu() import gc gc.collect() torch.cuda.empty_cache() max_memory = torch.cuda.max_memory_allocated(device=self.device) / (1 << 30) print(f'{m_name}, samples: {samples}, ' f'max gpu memory: {max_memory:.2f} GB') return outputs for layer in self.name2layer.values(): self._ori_forwards[layer] = layer.forward layer.forward = partial(_forward, layer) layer.cpu() def __enter__(self): """Prepares the Calibration object for a 'with' statement by registering hooks and 
wrapping layer forward methods.""" self._hooks = list() self._insert_input_observers() self._ori_forwards = {} for layer in self.name2layer.values(): self._ori_forwards[layer] = layer.forward if self.search_scale: self._wrap_decoder_layers_for_search() ================================================ FILE: lmdeploy/lite/quantization/modules/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .linear import WeightOnlyQLinear __all__ = ['WeightOnlyQLinear'] ================================================ FILE: lmdeploy/lite/quantization/modules/linear.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Optional, Type, TypeVar import torch from torch import nn from lmdeploy.lite.utils.cal_qparams import QParams try: import awq_inference_engine except ModuleNotFoundError: awq_inference_engine = None class WeightOnlyQLinear(nn.Module): """This class implements weight only quantization linear. Args: w_bit (int): number of bits for quantization. symmetry (bool): If true, use symmetric quantization, otherwise use asymmetric quantization. group_size (int): size of the quantization group. in_features (int): size of each input sample. out_features (int): size of each output sample. bias (Tensor, optional): Defaults to None. 
""" def __init__( self, in_features: int, out_features: int, bias: Optional[torch.Tensor] = True, w_bit: int = 4, symmetry: bool = False, group_size: int = 128, ) -> None: super().__init__() if w_bit not in [2, 4, 8]: raise NotImplementedError('Only 2,4,8 bit are supported for now.') self.in_features = in_features self.out_features = out_features self.w_bit = w_bit self.group_size = group_size if group_size != -1 else in_features assert self.in_features % self.group_size == 0 assert out_features % (32 // self.w_bit) == 0 w_pack_oc = out_features // (32 // self.w_bit) w_inc = in_features weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32) self.register_buffer('qweight', weight) if bias: self.register_buffer('bias', torch.zeros(out_features)) else: self.bias = None s_inc = in_features // self.group_size s_oc = out_features scales = torch.zeros((s_inc, s_oc), dtype=torch.float16) self.register_buffer('scales', scales) if not symmetry: z_inc = in_features // self.group_size z_oc = out_features // (32 // self.w_bit) zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32) self.register_buffer('qzeros', zeros) else: self.qzeros = None @classmethod def from_linear(cls: Type['WeightOnlyQLinear'], linear: nn.Linear, quantizer: TypeVar('Quantizer'), awq_layout: bool = True, qparams: Optional[QParams] = None) -> 'WeightOnlyQLinear': """Create a WeightOnlyQLinear object from a PyTorch Linear object. Args: linear (nn.Linear): PyTorch Linear object. quantizer (Quantizer): Object that handles quantization. awq_layout (bool): AWQ layout. Defaults to True. Returns: WeightOnlyQLinear: A WeightOnlyQLinear object. 
""" device = linear.weight.device w_bit = quantizer.bits pack_num = 32 // w_bit if awq_layout: assert w_bit == 4 pack_order = [0, 2, 4, 6, 1, 3, 5, 7] else: pack_order = torch.arange(pack_num) group_size = quantizer.group_size symmetry = quantizer.symmetry in_features = linear.in_features out_features = linear.out_features bias = False if linear.bias is None else True qlinear = cls(in_features, out_features, bias, w_bit, symmetry, group_size) qlinear.bias = linear.bias if qparams is None: qparams = quantizer.calculate_qparams(linear.weight) i32_w = quantizer.quant(linear.weight, qparams, real=True) else: i32_w = linear.weight.to(torch.int32) i32_w = i32_w.t().contiguous() pack_int_w = torch.zeros_like(qlinear.qweight).to(device) for col in range(pack_int_w.shape[1]): for i in range(pack_num): pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]] pack_int_w[:, col] |= pack_int_w_col << (i * w_bit) qlinear.qweight = pack_int_w qlinear.scales = qparams.scales.squeeze(-1).t().contiguous() if qparams.zero_points is not None: zeros = qparams.zero_points.to(torch.int32).to(device) zeros = zeros.squeeze(-1).t().contiguous() pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device) for col in range(pack_int_zeros.shape[1]): for i in range(pack_num): qzero_col = zeros[:, col * pack_num + pack_order[i]] pack_int_zeros[:, col] |= qzero_col << (i * w_bit) qlinear.qzeros = pack_int_zeros qlinear.to('cpu') return qlinear @torch.no_grad() def forward(self, x): if awq_inference_engine is None: raise RuntimeError('Run the following command to install ' 'the kernel for 4bit inference\n\n' 'git clone https://github.com/mit-han-lab/llm-awq.git\n' 'cd awq/kernels\n' 'python setup.py install\n') out_shape = x.shape[:-1] + (self.out_features, ) inputs = x.reshape(-1, x.shape[-1]) out = awq_inference_engine.gemm_forward_cuda(inputs.half(), self.qweight, self.scales.half(), self.qzeros, self.group_size) out = out + self.bias if self.bias is not None else out return 
out.reshape(out_shape) ================================================ FILE: lmdeploy/lite/quantization/weight/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from .quantizer import WeightQuantizer __all__ = ['WeightQuantizer'] ================================================ FILE: lmdeploy/lite/quantization/weight/quant_utils.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Optional, Sequence, Union import torch def _aligned_size(a, b): return (a + b - 1) // b * b def fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor: bits_x = x.view(torch.int32) exp_x = (bits_x >> 23) & 0xFF man_bits = bits_x & ((1 << 23) - 1) result = (exp_x - 127).to(torch.int32) result = result + torch.where(man_bits != 0, 1, 0) return result.to(torch.int32) def fast_pow2_torch(x: torch.Tensor) -> torch.Tensor: bits_x = (x + 127) << 23 return bits_x.view(torch.float32) def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor: return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max)) def _get_quant_scaling(weight: torch.Tensor, fp8_dtype: torch.dtype, dim: Union[int, Sequence[int]], scale_fmt: Optional[str] = None): """Get the scaling factor for FP8 quantization.""" finfo = torch.finfo(fp8_dtype) fmax = finfo.max amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float() if scale_fmt == 'ue8m0': return fast_round_scale_torch(amax, fmax) else: # default scaling = amax / fmax return scaling def quant_blocked_fp8(weight: torch.Tensor, fp8_dtype: torch.dtype, block_size: int = 128, scale_fmt: Optional[str] = None): """Quantize the weight tensor to blocked FP8 format.""" assert scale_fmt in (None, 'ue8m0'), f'Unsupported scale_fmt: {scale_fmt}' weight_shape = weight.shape K, N = weight_shape[-2:] aligned_k = _aligned_size(K, block_size) aligned_n = _aligned_size(N, block_size) # fill the weight tensor with zeros if it is not 
aligned if aligned_k != K or aligned_n != N: new_weight = weight.new_zeros(weight_shape[:-2] + (aligned_k, aligned_n)) new_weight[..., :K, :N] = weight weight = new_weight aligned_shape = weight.shape # reverse pixel shuffle weight = weight.unflatten(-2, (-1, block_size)).unflatten(-1, (-1, block_size)) weight = weight.to(torch.float32) # get scaling scaling = _get_quant_scaling(weight, fp8_dtype, dim=(-3, -1), scale_fmt=scale_fmt) # get quantized weight quantized_weight = weight / scaling quantized_weight = quantized_weight.to(fp8_dtype) quantized_weight = quantized_weight.view(aligned_shape) quantized_weight = quantized_weight[..., :K, :N] # reshape scaling scaling = scaling.squeeze(-3, -1) return quantized_weight, scaling ================================================ FILE: lmdeploy/lite/quantization/weight/quantizer.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Callable, Dict, Optional import torch from lmdeploy.lite.utils import (QParams, cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax, cal_qparams_per_group_absmax, cal_qparams_per_group_minmax, cal_qparams_per_tensor_absmax, cal_qparams_per_tensor_minmax, precise_round) from lmdeploy.lite.utils.global_avail import GlobalAvailMixin class WeightQuantizer(GlobalAvailMixin): """A class for performing weight quantization of neural networks. The WeightQuantizer class provides various methods to quantize the weights of a neural network. This helps in reducing the memory requirements and computational complexity of the model, potentially offering faster inference and lower power consumption. Attributes: bits (int): The bit width for quantization. symmetry (bool): If True, use absmax scaling; if False, use min-max scaling. granularity (str): The granularity of quantization. Available options are 'per_channel', 'per_tensor', and 'per_group'. 
group_size (Optional[int]): If using 'per_group' quantization, this is the number of channels in each group. Example: # Instantiate the weight quantizer with specific quantization settings quantizer = WeightQuantizer(bits=8, symmetry=True, granularity='per_tensor') # Calculate the quantization parameters for given weights qparams = quantizer.calculate_qparams(weights) # Perform fake quantization on the weights quantized_weights = quantizer.fake_quant(weights, qparams) """ CAL_FUNC_MAP: Dict[str, Dict[str, Callable]] = { 'per_group': { 'absmax': cal_qparams_per_group_absmax, 'minmax': cal_qparams_per_group_minmax, }, 'per_channel': { 'absmax': cal_qparams_per_channel_absmax, 'minmax': cal_qparams_per_channel_minmax, }, 'per_tensor': { 'absmax': cal_qparams_per_tensor_absmax, 'minmax': cal_qparams_per_tensor_minmax, }, } def __init__(self, bits: int, symmetry: bool, granularity: str, group_size: Optional[int] = -1): assert bits in [4, 8], "The 'bits' argument must be either 4 or 8." self.bits = bits if granularity not in ['per_channel', 'per_tensor', 'per_group']: raise NotImplementedError("The 'granularity' argument must be one of 'per_channel', " "'per_tensor', or 'per_group'.") self.granularity = granularity if self.granularity == 'per_group': assert group_size > 0, \ "The 'group_size' argument must be greater than 0." self.group_size = group_size # If symmetry is True, use absmax to compute scales # If symmetry is False, use minmax to compute scales and zeor-points self.symmetry = symmetry self.observer = 'absmax' if symmetry else 'minmax' def calculate_qparams(self, weight: torch.Tensor) -> QParams: """Calculate the quantization parameters for the given weight tensor. Args: weight (torch.Tensor): The weight tensor with shape (out_features, in_features). Returns: QParams: A namedtuple containing 'scales' and 'zero_points'. 
""" cal_func = self.CAL_FUNC_MAP[self.granularity][self.observer] if self.granularity == 'per_group': return cal_func(weight, self.bits, self.group_size) else: return cal_func(weight, self.bits) def quant(self, weight: torch.Tensor, qparams: Optional[QParams] = None, real: bool = False) -> torch.Tensor: """Perform fake quantization on the given weight tensor. Args: weight (torch.Tensor): The weight tensor with shape (out_features, in_features). qparams (Optional[QParams]): A namedtuple containing 'scales' and 'zero_points'. real (bool): If True, return the tensor with quantized type. Returns: torch.Tensor: The fake quantized weight tensor. """ float_w = weight.float() if qparams is None: qparams = self.calculate_qparams(float_w) scales = qparams.scales zero_points = qparams.zero_points out_c, in_c = weight.shape # Reshape the weights if using per_group quantization # per tensor scales shape: [1] # per channel scales shape: [out_c, 1] # per group scales shape: [out_c, in_c//group_size, 1] if len(scales.shape) > 2: # scales shape: [out_c, in_c//group_size, 1] float_w = float_w.reshape(out_c, scales.shape[1], -1) if zero_points is None: assert self.symmetry real_qweight = (float_w / scales).round() fake_qweight = real_qweight * scales else: assert not self.symmetry real_qweight = precise_round((float_w - float_w.min(-1, keepdim=True)[0]) / scales) fake_qweight = (real_qweight - zero_points) * scales if len(scales.shape) > 2: real_qweight = real_qweight.reshape(out_c, in_c) fake_qweight = fake_qweight.reshape(out_c, in_c) if real: return real_qweight.to(torch.int32) else: return fake_qweight.to(weight.dtype) ================================================ FILE: lmdeploy/lite/utils/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from .batch_split import concat_decoder_layer_outputs, split_decoder_layer_inputs from .cal_qparams import (QParams, cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax, cal_qparams_per_group_absmax, cal_qparams_per_group_minmax, cal_qparams_per_tensor_absmax, cal_qparams_per_tensor_minmax, precise_round) from .calib_dataloader import get_calib_loaders from .collect import bimap_name_mod, collect_target_modules, collect_target_weights from .global_avail import GlobalAvailMixin from .load import load_hf_from_pretrained __all__ = [ 'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax', 'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax', 'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax', 'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round', 'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs', 'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained' ] ================================================ FILE: lmdeploy/lite/utils/batch_split.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Dict, List, Tuple, Union import torch def split_decoder_layer_inputs(batch_size, *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any]) -> Tuple[List[List[Any]], List[Dict[str, Any]]]: """This function splits batched decoder layer inputs into individual elements. Args: *args (Union[torch.Tensor, Any]): Positional arguments which could be a mix of tensors and other types. **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could be a mix of tensors and other types. Returns: Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two lists, one for positional arguments, one for keyword arguments. Each list contains individual elements from the batch. 
""" if not isinstance(args[0], torch.Tensor): raise ValueError('The first argument must be a Tensor') bs = args[0].size(0) batch_args = [] batch_kwargs = [] for i in range(0, bs, batch_size): new_args = [] # Iterate over each argument. If it's a torch.Tensor and its first # dimension equals the batch size, then get the value corresponding # to the current index, else directly add the whole value. for val in args: if isinstance(val, torch.Tensor) and val.size(0) == bs: new_args.append(val[i:i + batch_size]) else: new_args.append(val) new_kwargs = {} # Execute the same operation for the keyword arguments. for name, val in kwargs.items(): if isinstance(val, torch.Tensor) and val.size(0) == bs: new_kwargs[name] = val[i:i + batch_size] elif isinstance(val, torch.Tensor) and len(val.shape) > 1 and val.size(1) == bs: # qwen2-vl new_kwargs[name] = val[:, i:i + batch_size] elif name == 'position_embeddings' and isinstance(val, Tuple) and len( val[0].shape) > 1 and val[0].size(1) == bs: # qwen2-vl new_kwargs[name] = (val[0][:, i:i + batch_size], val[1][:, i:i + batch_size]) else: new_kwargs[name] = val batch_args.append(new_args) batch_kwargs.append(new_kwargs) return batch_args, batch_kwargs def concat_decoder_layer_outputs(batch_outputs: List[Any]) -> Any: """This function concatenates individual decoder layer outputs into a batched output. Args: batch_outputs (List[Any]): A list, where each tuple represents the output from an individual element in the batch. Returns: Any: Batched output. """ output_is_tuple = True if not isinstance(batch_outputs[0], tuple): output_is_tuple = False batch_outputs = [(output, ) for output in batch_outputs] num_returns = len(batch_outputs[0]) def is_past_key_value(data: Any) -> bool: """Check whether data is a past key-value pair. Args: data (Any): The data to check. Returns: bool: True if data is a past key-value pair, False otherwise. 
""" flag = isinstance(data, tuple) flag = flag and len(data) == 2 flag = flag and isinstance(data[0], torch.Tensor) flag = flag and isinstance(data[1], torch.Tensor) return flag new_outputs = [] # Iterate over all types of return values. for i in range(num_returns): # Check if the current element is a past key-value pair. flag = is_past_key_value(batch_outputs[0][i]) if flag: # Concatenate the keys and values separately. key = torch.cat([out[i][0] for out in batch_outputs]) value = torch.cat([out[i][1] for out in batch_outputs]) out_i = (key, value) elif batch_outputs[0][i] is None: # glm4 out_i = None else: # If it's not a past key-value pair, concatenate directly. out_i = torch.cat([out[i] for out in batch_outputs]) new_outputs.append(out_i) if output_is_tuple: return tuple(new_outputs) else: return new_outputs[0] ================================================ FILE: lmdeploy/lite/utils/cal_qparams.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import NamedTuple, Optional import torch class QParams(NamedTuple): """A class to hold the quantization parameters.""" scales: torch.Tensor zero_points: Optional[torch.Tensor] @torch.no_grad() def precise_round(x): return x.sign() * (x.abs() + 0.5).floor() @torch.no_grad() def cal_qparams_per_channel_absmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for each channel using absolute max value.""" float_w = w.float() absmax = float_w.abs().max(dim=-1, keepdim=True)[0] q_max = 2**(n_bits - 1) - 1 scales = absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax else: return QParams(scales=scales, zero_points=None) @torch.no_grad() def cal_qparams_per_channel_minmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for each channel using min and max values.""" float_w = w.float() w_min = float_w.min(dim=-1, keepdim=True)[0] w_max = float_w.max(dim=-1, keepdim=True)[0] q_max = 2**n_bits - 1 scales = (w_max - w_min) scales = scales.div_(q_max) zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) else: return QParams(scales=scales, zero_points=zero_points) @torch.no_grad() def cal_qparams_per_group_absmax(w: torch.Tensor, n_bits: int, group_size: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for each group using absolute max value.""" outc, inc = w.shape assert inc >= group_size, \ 'Input channels should be greater than or equal to group_size.' assert inc % group_size == 0, \ 'Input channels should be divisible by group_size.' 
float_w = w.float() absmax = float_w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0] q_max = 2**(n_bits - 1) - 1 scales = absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax else: return QParams(scales=scales, zero_points=None) @torch.no_grad() def cal_qparams_per_group_minmax(w: torch.Tensor, n_bits: int, group_size: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for each group using min and max values.""" outc, inc = w.shape assert inc >= group_size, \ 'Input channels should be greater than or equal to group_size.' assert inc % group_size == 0, \ 'Input channels should be divisible by group_size.' float_w = w.float() w_group_wise = float_w.reshape(outc, -1, group_size) w_min = w_group_wise.min(dim=-1, keepdim=True)[0] w_max = w_group_wise.max(dim=-1, keepdim=True)[0] q_max = 2**n_bits - 1 scales = (w_max - w_min) scales = scales.div_(q_max) zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) else: return QParams(scales=scales, zero_points=zero_points) @torch.no_grad() def cal_qparams_per_tensor_minmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for the entire tensor using min and max values.""" float_w = w.float() w_min = float_w.min() w_max = float_w.max() q_max = 2**n_bits - 1 scales = (w_max - w_min) scales = scales.clamp_(min=1e-5).div_(q_max) zero_points = precise_round(-w_min / scales) if return_stats: return QParams(scales=scales, zero_points=zero_points), (w_min, w_max) else: return QParams(scales=scales, zero_points=zero_points) @torch.no_grad() def cal_qparams_per_tensor_absmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams: """Calculate quantization parameters for the entire tensor using absolute max value.""" float_w = w.float() absmax = float_w.abs().max() q_max = 2**(n_bits - 1) - 1 scales = 
absmax.div(q_max) if return_stats: return QParams(scales=scales, zero_points=None), absmax else: return QParams(scales=scales, zero_points=None) ================================================ FILE: lmdeploy/lite/utils/calib_dataloader.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np import torch NUM_LOADED_SAMPLES = 30000 def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) # adapted from https://github.com/vllm-project/llm-compressor/blob/main/tests/testing_utils.py def process_dataset(ds, tokenizer, max_seq_length): """Helper function to preprocess and tokenize a dataset according to presets. Args: ds: Language dataset to preprocess and tokenize. tokenizer: Tokenizer to encode text. max_seq_length: Maximum sequence length of samples. Returns: ds: Tokenized dataset. """ ds_name = ds.info.dataset_name.lower() if ds_name == 'gsm8k': def tokenize(sample): return tokenizer( sample['question'], padding=False, max_length=max_seq_length, truncation=True, add_special_tokens=False, ) elif ds_name == 'open-platypus': # use the output rather than the instruction def tokenize(sample): messages = [{ 'role': 'user', 'content': sample['instruction'] + ' ' + sample['input'] }, { 'role': 'assistant', 'content': sample['output'] }] return tokenizer( tokenizer.apply_chat_template( messages, tokenize=False, ), padding=False, max_length=max_seq_length, truncation=True, add_special_tokens=False, ) # "neuralmagic/calibration" elif ds_name == 'calibration': def tokenize(sample): messages = [] for message in sample['messages']: if message['role'] == 'user': messages.append({'role': 'user', 'content': message['content']}) elif message['role'] == 'assistant': messages.append({'role': 'assistant', 'content': message['content']}) return tokenizer( tokenizer.apply_chat_template( messages, tokenize=False, ), padding=False, max_length=max_seq_length, truncation=True, add_special_tokens=False, ) elif 
ds_name == 'openwebtext': def tokenize(sample): return tokenizer( sample['text'], padding=False, max_length=max_seq_length, truncation=True, add_special_tokens=False, ) else: raise NotImplementedError(f'Cannot preprocess dataset {ds.info.dataset_name} ' f'Only `gsm8k`, `open-platypus`, `calibration`, `openwebtext` ' f'are supported by preprocess. ') ds = ds.map(tokenize, remove_columns=ds.column_names) return ds def get_wikitext2(dataset, tokenizer, nsamples, seed, seqlen): """Load Wikitext-2 train and test datasets and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ trainenc = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') import random random.seed(seed) trainloader = [] for _ in range(nsamples): i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) j = i + seqlen inp = trainenc.input_ids[:, i:j] trainloader.append(inp) return trainloader def get_c4(dataset, tokenizer, nsamples, seed, seqlen): """Load C4 train and validation datasets and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ import random random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(dataset) - 1) trainenc = tokenizer(dataset[i]['text'], return_tensors='pt') if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) j = i + seqlen inp = trainenc.input_ids[:, i:j] trainloader.append(inp) return trainloader def get_pileval(dataset, tokenizer, nsamples, seed, seqlen=512): """Load pileval train dataset and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. 
nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ # pileval samples have far fewer tokens than seqlen; recompute how many # train items to select so it can still yield enough samples after concatenation. samples_encode = [] lengths = [] for data in dataset: ids = tokenizer.encode(data['text'].strip()) if not ids or len(ids) > 512: continue samples_encode.append(torch.tensor([ids])) lengths.append(len(ids)) if len(samples_encode) >= len(dataset): break avg_tokens = sum(lengths) / len(lengths) needed_samples = max(1, int((seqlen * nsamples) // avg_tokens)) dataset = dataset.shuffle(seed=seed) samples = [] n_run = 0 for data in dataset: line = data['text'] line = line.strip() line_encoded = tokenizer.encode(line) if len(line_encoded) > 512: continue sample = torch.tensor([line_encoded]) if sample.numel() == 0: continue samples.append(sample) n_run += 1 if n_run == needed_samples: break # now concatenate all samples and split according to block size cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // seqlen print(f' * Split into {n_split} blocks') return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)] def get_gsm8k(dataset, tokenizer, nsamples, seed, seqlen): """Load GSM8K train and test datasets and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ dataset = dataset.shuffle(seed=seed) dataset = process_dataset(dataset, tokenizer, seqlen) # GSM8K samples have far fewer tokens than seqlen; recompute how many # train items to select so it can still yield enough samples after concatenation. 
lengths = torch.tensor([len(sample['input_ids']) for sample in dataset], dtype=torch.long) avg_tokens = lengths.sum().item() // len(dataset) needed_samples = max(1, int((seqlen * nsamples) // avg_tokens)) samples = [] n_run = 0 for i in range(len(dataset)): line = dataset[i]['input_ids'] sample = torch.tensor([line]) if sample.numel() == 0: continue samples.append(sample) n_run += 1 if n_run == needed_samples: break cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // seqlen print(f' * Split into {n_split} blocks') return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)] def get_neuralmagic_calibration(dataset, tokenizer, nsamples, seed, seqlen): """Load neuralmagic_calibration train and test datasets and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ dataset = dataset.shuffle(seed=seed) dataset = process_dataset(dataset, tokenizer, seqlen) # neuralmagic_calibration samples have far fewer tokens than seqlen; recompute how many # train items to select so it can still yield enough samples after concatenation. 
lengths = torch.tensor([len(sample['input_ids']) for sample in dataset], dtype=torch.long) avg_tokens = lengths.sum().item() / len(dataset) needed_samples = max(1, int((seqlen * nsamples) // avg_tokens)) samples = [] n_run = 0 for i in range(len(dataset)): line = dataset[i]['input_ids'] sample = torch.tensor([line]) if sample.numel() == 0: continue samples.append(sample) n_run += 1 if n_run == needed_samples: break cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // seqlen print(f' * Split into {n_split} blocks') return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)] def get_open_platypus(dataset, tokenizer, nsamples, seed, seqlen): """Load open-platypus train and test datasets and tokenize. Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ dataset = dataset.shuffle(seed=seed) dataset = process_dataset(dataset, tokenizer, seqlen) # open-platypus samples have far fewer tokens than seqlen; recompute how many # train items to select so it can still yield enough samples after concatenation. lengths = torch.tensor([len(sample['input_ids']) for sample in dataset], dtype=torch.long) avg_tokens = lengths.sum().item() / len(dataset) needed_samples = max(1, int((seqlen * nsamples) // avg_tokens)) samples = [] n_run = 0 for i in range(len(dataset)): line = dataset[i]['input_ids'] sample = torch.tensor([line]) if sample.numel() == 0: continue samples.append(sample) n_run += 1 if n_run == needed_samples: break cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // seqlen print(f' * Split into {n_split} blocks') return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)] def get_openwebtext(dataset, tokenizer, nsamples, seed, seqlen): """Load openwebtext train and test datasets and tokenize. 
Args: dataset: calib dataset tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. """ dataset = dataset.shuffle(seed=seed) dataset = process_dataset(dataset, tokenizer, seqlen) import random random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(dataset) - 1) trainenc = dataset[i] if len(trainenc['input_ids']) >= seqlen: break i = random.randint(0, len(trainenc['input_ids']) - seqlen) j = i + seqlen inp = trainenc['input_ids'][i:j] inp = torch.tensor([inp]) trainloader.append(inp) return trainloader def get_calib_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048): """Get calibration data loaders for a dataset. Args: name: Dataset name ('wikitext2', 'c4', 'pileval', 'gsm8k', 'neuralmagic_calibration', 'open-platypus', 'openwebtext'). tokenizer: Tokenizer to encode text. nsamples: Number of samples to take from train set. seed: Random seed for sampling. seqlen: Maximum sequence length. Returns: List of sampled and tokenized training examples. 
""" from datasets import VerificationMode, load_dataset if 'wikitext2' in name: dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') return get_wikitext2(dataset, tokenizer, nsamples, seed, seqlen) if 'c4' in name: dataset = load_dataset('allenai/c4', 'en', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', verification_mode=VerificationMode.NO_CHECKS) return get_c4(dataset, tokenizer, nsamples, seed, seqlen) if 'pileval' in name: from datasets.builder import DatasetGenerationError try: dataset = load_dataset('mit-han-lab/pile-val-backup', split=f'validation[:{NUM_LOADED_SAMPLES}]') except DatasetGenerationError: raise InterruptedError('There have been some issues when generating ' 'the dataset, you could try to download it ' 'locally first, and replace the `data_files`' 'with local addresses or use other datasets ' '(c4, wiki, ptb).') return get_pileval(dataset, tokenizer, nsamples, seed, seqlen) if 'gsm8k' in name: dataset = load_dataset('openai/gsm8k', 'main', split='train') return get_gsm8k(dataset, tokenizer, nsamples, seed, seqlen) if 'neuralmagic_calibration' in name: dataset = load_dataset('neuralmagic/calibration', 'LLM', split='train') return get_neuralmagic_calibration(dataset, tokenizer, nsamples, seed, seqlen) if 'open-platypus' in name: dataset = load_dataset('garage-bAInd/Open-Platypus', split='train') return get_open_platypus(dataset, tokenizer, nsamples, seed, seqlen) if 'openwebtext' in name: dataset = load_dataset('Skylion007/openwebtext', data_files={'train': 'plain_text/train-00000-of-00080.parquet'}, split=f'train[:{NUM_LOADED_SAMPLES}]', verification_mode=VerificationMode.NO_CHECKS) return get_openwebtext(dataset, tokenizer, nsamples, seed, seqlen) ================================================ FILE: lmdeploy/lite/utils/collect.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import Dict, List, Tuple, Union from torch import nn def collect_target_modules(model: nn.Module, target: Union[str, type], skip_names: List[str] = [], prefix: str = '') -> Dict[str, nn.Module]: """Collects the specific target modules from the model. Args: model : The PyTorch module from which to collect the target modules. target : The specific target to be collected. It can be a class of a module or the name of a module. skip_names : List of names of modules to be skipped during collection. prefix : A string to be added as a prefix to the module names. Returns: A dictionary mapping from module names to module instances. """ if not isinstance(target, (type, str)): raise TypeError('Target must be a string (name of the module) ' 'or a type (class of the module)') def _is_target(n, m): if isinstance(target, str): return target == type(m).__name__ and n not in skip_names return isinstance(m, target) and n not in skip_names name2mod = {} for name, mod in model.named_modules(): m_name = f'{prefix}.{name}' if prefix else name if _is_target(name, mod): name2mod[m_name] = mod return name2mod def collect_target_weights(model: nn.Module, target: Union[str, type], skip_names: List[str]) -> Dict[str, nn.Module]: """Collects weights of the specific target modules from the model. Args: model : The PyTorch module from which to collect the weights of target modules. target : The specific target whose weights to be collected. It can be a class of a module or the name of a module. skip_names : Names of modules to be skipped during weight collection. Returns: A dictionary mapping from module instances to their corresponding weights. 
""" named_modules = collect_target_modules(model, target, skip_names) mod2weight = {} for _, mod in named_modules.items(): assert hasattr(mod, 'weight'), "The module does not have a 'weight' attribute" mod2weight[mod] = mod.weight return mod2weight def bimap_name_mod(name2mod_mappings: List[Dict[str, nn.Module]]) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]: """Generates bidirectional maps from module names to module instances and vice versa. Args: name2mod_mappings : List of dictionaries each mapping from module names to module instances. Returns: Two dictionaries providing bidirectional mappings between module names and module instances. """ name2mod = {} mod2name = {} for mapping in name2mod_mappings: mod2name.update({v: k for k, v in mapping.items()}) name2mod.update(mapping) return name2mod, mod2name ================================================ FILE: lmdeploy/lite/utils/global_avail.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Union from torch import nn class GlobalAvailMixin: """Mixin class to make instances globally available.""" _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {'default': {}} def global_available(self, key: Union[str, nn.Module] = 'default', group: str = 'default') -> None: """Make the instance globally available. Args: key (Union[str, nn.Module], optional): Key to save the instance. Defaults to 'default'. group (str, optional): Group to save the instance. Defaults to 'default'. """ self._save_instance(self, key, group) @classmethod def _save_instance(cls, instance: 'GlobalAvailMixin', key: Union[str, nn.Module] = 'default', group: str = 'default') -> None: """Save the instance. Args: instance (GlobalAvailMixin): Instance to save. key (Union[str, nn.Module], optional): Key to save the instance. Defaults to 'default'. group (str, optional): Group to save the instance. Defaults to 'default'. 
""" if group not in cls._instances: assert isinstance(group, str) cls._instances[group] = {} cls._instances[group][key] = instance @classmethod def find(cls, key: Union[str, nn.Module] = 'default', group: str = 'default') -> Union[None, 'GlobalAvailMixin']: """Find an instance by its key and group. Args: key (Union[str, nn.Module], optional): Key of the instance. Defaults to 'default'. group (str, optional): Group of the instance. Defaults to 'default'. Returns: Union[None, GlobalAvailMixin]: The found instance, or None if it does not exist. """ return cls._instances.get(group, {}).get(key) @classmethod def find_group(cls, group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']: """Find all instances in a group. Args: group (str): Group of the instances. Returns: Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in the group. """ return cls._instances.get(group, {}) @classmethod def instances(cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]: """Get all instances.""" return cls._instances ================================================ FILE: lmdeploy/lite/utils/load.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import Literal import torch from transformers import AutoConfig, AutoModelForCausalLM class LoadNoInit: """Initialize model without parameter initialization.""" def __init__(self): self.constant_ = torch.nn.init.constant_ self.zeros_ = torch.nn.init.zeros_ self.ones_ = torch.nn.init.ones_ self.uniform_ = torch.nn.init.uniform_ self.normal_ = torch.nn.init.normal_ self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_ self.kaiming_normal_ = torch.nn.init.kaiming_normal_ self.tensor_normal_ = torch.Tensor.normal_ def __enter__(self, *args, **kwargs): """Replace initializers with no-op.""" torch.nn.init.constant_ = lambda *args, **kwargs: None torch.nn.init.zeros_ = lambda *args, **kwargs: None torch.nn.init.ones_ = lambda *args, **kwargs: None torch.nn.init.uniform_ = lambda *args, **kwargs: None torch.nn.init.normal_ = lambda *args, **kwargs: None torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None torch.Tensor.normal_ = lambda *args, **kwargs: None def __exit__(self, *args, **kwargs): """Recover.""" torch.nn.init.constant_ = self.constant_ torch.nn.init.zeros_ = self.zeros_ torch.nn.init.ones_ = self.ones_ torch.nn.init.uniform_ = self.uniform_ torch.nn.init.normal_ = self.normal_ torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_ torch.nn.init.kaiming_normal_ = self.kaiming_normal_ torch.Tensor.normal_ = self.tensor_normal_ def load_hf_from_pretrained(pretrained_model_name_or_path, dtype: Literal['float16', 'bfloat16', 'auto'], **kwargs): if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported(): raise RuntimeError('Your device does not supports bf16(bfloat16), ' 'please change to fp16(float16)') kwargs.pop('config', None) hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) # HACK hard code for qwen, other configs do not have the `fp16` attribute. 
if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'): if dtype == 'bfloat16': hf_config.bf16 = True else: hf_config.fp16 = True torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16) if dtype == 'bfloat16': torch_dtype = torch.bfloat16 elif dtype == 'float16': torch_dtype = torch.float16 elif dtype == 'auto' and torch_dtype == torch.bfloat16: print('Warning: we cast model to float16 to prevent OOM. ' 'You may enforce it bfloat16 by `--dtype bfloat16`') torch_dtype = torch.float16 with LoadNoInit(): # Load model model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, config=hf_config, torch_dtype=torch_dtype, **kwargs) model.config.use_cache = False return model ================================================ FILE: lmdeploy/lite/utils/memory_efficient.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import inspect import re import warnings from contextlib import contextmanager from functools import partial from typing import List import torch from torch import nn from lmdeploy.lite.defaults import KV_CACHE_SIGNATURE, OFFLOAD_MOD def extract_return_values(module: nn.Module) -> List[str]: """Extracts return values from given module's forward method. 
Args: module (nn.Module): Module to inspect Returns: list[str]: List of return values """ last_line = inspect.getsource(module.forward).rstrip('\n').split('\n')[-1] pattern = r'return ([\w\s,]+)' match = re.search(pattern, last_line) if match: return_values = match.group(1).split(',') return [value.strip() for value in return_values] else: return [] def find_kv_cache_idx(module: nn.Module) -> int: """Finds index of kv cache signature in module's forward parameters.""" signatures = list(inspect.signature(module.forward).parameters.keys()) if KV_CACHE_SIGNATURE not in signatures: raise ValueError(f'{KV_CACHE_SIGNATURE} not in signatures of ' f'{type(module)} forward.') return signatures.index(KV_CACHE_SIGNATURE) def find_modules_by_return_value(model: nn.Module, value: str) -> List[nn.Module]: """Finds modules in model that return given value. Args: model (nn.Module): Model to inspect value (str): Return value to search for Returns: list[nn.Module]: List of matching modules Raises: ValueError: If no matching modules found """ modules = [] for name, module in model.named_modules(): returns = extract_return_values(module) if value in returns: print(f'Found {name} returning {value}') modules.append(module) if not modules: error_msg = f'No modules found returning {value}. ' error_msg += 'Please check if the default KV_CACHE_SIGNATURE ' error_msg += f"'{KV_CACHE_SIGNATURE}' matches what is used in your " error_msg += 'model code. If not, you can modify KV_CACHE_SIGNATURE ' error_msg += 'in `lmdeploy.lite.defaults`.' raise ValueError(error_msg) return modules @contextmanager def offload_kv_cache(model: nn.Module, device: str = 'cuda') -> None: """Offloads kv cache to given device during forward pass. 
Args: model (nn.Module): Model for inference device (str): Device to offload to Yields: None """ modules = find_modules_by_return_value(model, KV_CACHE_SIGNATURE) original_forwards = {mod: mod.forward for mod in modules} input_idxs = {mod: find_kv_cache_idx(mod) for mod in modules} output_idxs = {mod: extract_return_values(mod).index(KV_CACHE_SIGNATURE) for mod in modules} def wrap_forward(module, *args, **kwargs): idx = input_idxs[module] if idx >= len(args): # kv cache in kwargs if KV_CACHE_SIGNATURE in kwargs: if kwargs[KV_CACHE_SIGNATURE]: kwargs[KV_CACHE_SIGNATURE] = kwargs[KV_CACHE_SIGNATURE].to(device) else: raise ValueError(f'No kv cache input found at index {idx}') else: # kv cache in args args = list(args) args[idx] = args[idx].to(device) args = tuple(args) result = original_forwards[module](*args, **kwargs) result = list(result) idx = output_idxs[module] # Move kv cache outputs back to CPU key = result[idx][0].to('cpu') value = result[idx][1].to('cpu') torch.cuda.empty_cache() result[idx] = (key, value) result = tuple(result) return result try: for module in modules: original_forwards[module] = module.forward module.forward = partial(wrap_forward, module) yield finally: for module in modules: module.forward = original_forwards[module] del original_forwards[module] @contextmanager def offload_weights(model: nn.Module, device: str = 'cuda') -> None: """Offloads specified modules to given device during forward pass. 
Args: model (nn.Module): Model for inference device (str): Device to offload to Yields: None """ target_modules = OFFLOAD_MOD def before_forward(module: nn.Module, inp: torch.Tensor): module.to(device) def after_forward(module: nn.Module, inp: torch.Tensor, out: torch.Tensor): module.to('cpu') torch.cuda.empty_cache() def _to_device(m, spec_modules, dev): if len(spec_modules) == 0 or len(list(m.children())) == 0: m.to(dev) return for child in m.children(): if isinstance(child, spec_modules): child.to('cpu') else: _to_device(child, spec_modules, dev) # m.to(dev) warnings.warn('By default, offloading will be done on ' '`nn.Linear`. You can add modules which want offload to ' 'the `lmdeploy.lite.defaults.OFFLOAD_MOD`.') target = OFFLOAD_MOD _to_device(model, target, device) handles = [] for module in model.modules(): if isinstance(module, target_modules): handle1 = module.register_forward_pre_hook(before_forward) handle2 = module.register_forward_hook(after_forward) handles.extend([handle1, handle2]) try: yield finally: for handle in handles: handle.remove() model.to('cpu') torch.cuda.empty_cache() @contextmanager def memory_efficient_inference(model: nn.Module, offload: bool = True, device: str = 'cuda') -> None: """Memory efficient inference context manager. Moves model to device for inference, with option to offload specific modules. 
Args: model (nn.Module): Model for inference offload (bool): Whether to offload modules device (str): Device for inference Yields: None """ if offload: warnings.warn('Using offload mode - modules defined in OFFLOAD_MOD ' 'will be moved to GPU during forward pass only.') warnings.warn('Using offload mode will incur performance penalty due to ' 'frequent CPU-GPU data transfers.') with torch.inference_mode(): with offload_kv_cache(model, device): with offload_weights(model, device): yield else: model.to(device) with torch.inference_mode(): yield ================================================ FILE: lmdeploy/logger.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. # modify from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/logger.py # noqa from typing import List, Optional from .messages import GenerationConfig from .utils import get_logger logger = get_logger('lmdeploy') class RequestLogger: """A class responsible for logging requests, ensuring that logs do not exceed a specified maximum length. Args: max_log_len (Optional[int]): The maximum length of the log entries. If None, no maximum length is enforced. 
""" def __init__(self, max_log_len: Optional[int]) -> None: self.max_log_len = max_log_len def log_prompt(self, session_id: int, prompt: str) -> None: if not isinstance(prompt, str): # Prompt may be a GPT4V message with base64 images; # logging might be impractical due to length return if self.max_log_len is not None: if prompt is not None: prompt = prompt[:self.max_log_len] logger.info(f'session={session_id}, ' f'prompt={prompt!r}') def log_inputs(self, session_id: int, prompt: Optional[str], prompt_token_ids: Optional[List[int]], gen_config: GenerationConfig, adapter_name: str) -> None: max_log_len = self.max_log_len input_tokens = len(prompt_token_ids) if max_log_len is not None: if prompt is not None: prompt = prompt[:max_log_len] if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] logger.info(f'session={session_id}, ' f'adapter_name={adapter_name}, ' f'input_tokens={input_tokens}, ' f'gen_config={gen_config}, ' f'prompt={prompt!r}, ' f'prompt_token_id={prompt_token_ids}') ================================================ FILE: lmdeploy/messages.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import enum import time from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Literal, Optional import torch from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest from .tokenizer import Tokenizer from .utils import get_logger logger = get_logger('lmdeploy') LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] """LogitsProcessor is a function that takes a tensor of input_ids, the logits tensor for the next token, and returns a modified tensor of logits to sample from.""" @dataclass class GenerationConfig: """Generation parameters used by inference engines. 
Args: n: Define how many chat completion choices to generate for each input message. **Only 1** is supported now. max_new_tokens: The maximum number of tokens that can be generated in the chat completion do_sample: Whether or not to use sampling, use greedy decoding otherwise. Default to be False. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass top_k: An alternative to sampling with temperature, where the model considers the top_k tokens with the highest probability min_p: Minimum token probability, which will be scaled by the probability of the most likely token. It must be a value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range (use the opposite of normal `top_p` values) temperature: Sampling temperature repetition_penalty: Penalty to prevent the model from generating repeated words or phrases. A value larger than 1 discourages repetition ignore_eos: Indicator to ignore the eos_token_id or not random_seed: Seed used when sampling a token stop_words: Words that stop generating further tokens bad_words: Words that the engine will never generate stop_token_ids: List of tokens that stop the generation when they are generated. The returned output will not contain the stop tokens. bad_token_ids: List of tokens that the engine will never generate. min_new_tokens: The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. skip_special_tokens: Whether or not to remove special tokens in the decoding. Default to be True. spaces_between_special_tokens: Whether or not to add spaces around special tokens. The behavior of Fast tokenizers is to have this to False. This is setup to True in slow tokenizers. logprobs: Number of log probabilities to return per output token. response_format: Generate responses according to given formatting. Examples: .. 
code-block:: json { "type": "json_schema", "json_schema": { "name": "test", "schema": { "properties": { "name": { "type": "string" } }, "required": ["name"], "type": "object" } } } or, .. code-block:: json { "type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}" } logits_processors: Custom logit processors. repetition_ngram_size: The size of n-grams to consider for repetition early stop. repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop. """ n: int = 1 max_new_tokens: int = 512 do_sample: bool = False top_p: float = 1.0 top_k: int = 50 min_p: float = 0.0 temperature: float = 0.8 repetition_penalty: float = 1.0 ignore_eos: bool = False random_seed: int = None stop_words: List[str] = None bad_words: List[str] = None stop_token_ids: List[int] = None bad_token_ids: List[int] = None min_new_tokens: int = None skip_special_tokens: bool = True spaces_between_special_tokens: bool = True logprobs: int = None response_format: Optional[Dict] = None logits_processors: Optional[List[LogitsProcessor]] = None output_logits: Literal['all', 'generation'] = None output_last_hidden_state: Literal['all', 'generation'] = None include_stop_str_in_output: bool = False # for disaggregation with_cache: bool = False preserve_cache: bool = False migration_request: Optional[MigrationRequest] = None # router replay return_routed_experts: bool = False # ngram, generation would stop if latest [size] tokens are repeated for [threshold] times repetition_ngram_size: int = 0 repetition_ngram_threshold: int = 0 def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to stop_token_ids/bad_token_ids.""" def special_word_token_ids(words): if words is not None: assert isinstance(words, List) and \ all(isinstance(elem, str) for elem in words), \ f'stop_words must be a list of str but got {type(words)}' indexes = [] for word in words: indexes += tokenizer.indexes_containing_token(word) 
return indexes return None stop_token_ids = special_word_token_ids(self.stop_words) or [] bad_token_ids = special_word_token_ids(self.bad_words) or [] stop_token_ids.extend(self.stop_token_ids or []) bad_token_ids.extend(self.bad_token_ids or []) self.stop_token_ids = list(set(stop_token_ids)) or None self.bad_token_ids = list(set(bad_token_ids)) or None def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id): """Update the stop_token_ids.""" stop_token_ids = set(self.stop_token_ids or []) # add tokenizer's eos_token_id if tokenizer_eos_token_id is not None: stop_token_ids.add(tokenizer_eos_token_id) # add eos_token_id from model's generation_config.json file if there # is any. eos_token_id = generation_config.get('eos_token_id') if eos_token_id is not None: if isinstance(eos_token_id, int): stop_token_ids.add(eos_token_id) else: stop_token_ids.update(eos_token_id) self.stop_token_ids = list(stop_token_ids) def __post_init__(self): """Check input validation.""" assert type(self.n) == int and self.n > 0, 'n is not a positive integer' assert self.top_p >= 0 and self.top_p <= 1 # [0, 1] assert self.top_k >= 0, 'top_k can not be a negative integer' assert self.temperature >= 0 and self.temperature <= 2 # [0,2] assert 0 <= self.min_p <= 1, \ f'min_p should be in range [0, 1], but found {self.min_p}' @pydantic_dataclass class TurbomindEngineConfig: """TurboMind Engine config. Args: dtype: data type for model weights and activations. It can be one of the following values, ['auto', 'float16', 'bfloat16'] The `auto` option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. model_format: the layout of the deployed model. It can be one of the following values [hf, awq, gptq],`hf` meaning huggingface model(.bin, .safetensors), `awq` and `gptq` meaning the quantized model by AWQ and GPTQ, respectively. If it is not specified, i.e. 
None, it will be extracted from the input model tp: the number of GPU cards used in tensor parallelism, default to 1 session_len: the max session length of a sequence, default to None max_batch_size: the max batch size during inference. If it is not specified, the engine will automatically set it according to the device cache_max_entry_count: the percentage of gpu memory occupied by the k/v cache. For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it defaults to 0.5, depicting the percentage of TOTAL GPU memory to be allocated to the k/v cache. For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory to be reserved for the k/v cache. When it's an integer > 0, it represents the total number of k/v blocks. cache_chunk_size: The policy to apply for KV block from the block manager, default to -1. cache_block_seq_len: the length of the token sequence in a k/v block, default to 64 enable_prefix_caching: enable cache prompts for block reuse, default to False quant_policy: default to 0. When k/v is quantized into 4 or 8 bit, set it to 4 or 8, respectively rope_scaling_factor: scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention use_logn_attn: whether or not to use log attn: default to False download_dir: Directory to download and load the weights, default to the default cache directory of huggingface. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. max_prefill_token_num: the number of tokens each iteration during prefill, default to 8192 num_tokens_per_iter: the number of tokens processed in each forward pass. 
Working with `max_prefill_iters` enables the "Dynamic SplitFuse"-like scheduling max_prefill_iters: the max number of forward pass during prefill stage async_: enable async execution, default to 1 (enabled) devices: the used devices empty_init: Whether to load the model weights, you should set it to True if you want to update weights after create the pipeline hf_overrides: Huggingface overrides for the model. It can be used to override the default config of the model enable_metrics: enable metrics system """ dtype: str = 'auto' model_format: Optional[str] = None tp: int = 1 dp: int = 1 cp: int = 1 device_num: int = None attn_tp_size: int = None attn_cp_size: int = None attn_dp_size: int = None mlp_tp_size: int = None mlp_dp_size: int = None outer_dp_size: int = None nnodes: int = 1 node_rank: int = 0 dist_init_addr: Optional[str] = None devices: List[int] = None session_len: Optional[int] = None max_batch_size: int = None cache_max_entry_count: float = 0.8 cache_chunk_size: int = -1 cache_block_seq_len: int = 64 enable_prefix_caching: bool = False quant_policy: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: bool = False download_dir: Optional[str] = None revision: Optional[str] = None max_prefill_token_num: int = 8192 num_tokens_per_iter: int = 0 max_prefill_iters: int = 1 async_: int = 1 devices: Optional[List[int]] = None empty_init: bool = False communicator: str = 'nccl' hf_overrides: Optional[Dict[str, Any]] = None enable_metrics: bool = True def __post_init__(self): """Check input validation.""" assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'tp must be a positive integer' assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count' assert self.quant_policy in (0, 4, 8), 'invalid quant_policy' assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor' assert self.max_prefill_token_num >= 0, \ 'invalid max_prefill_token_num' assert self.num_tokens_per_iter >= 0, 'invalid num_tokens_per_iter' assert 
self.async_ in (0, 1), 'async_ must be 0 (disabled) or 1 (enabled)' @dataclass class PytorchEngineConfig: """PyTorch Engine Config. Args: dtype: data type for model weights and activations. It can be one of the following values, ['auto', 'float16', 'bfloat16'] The `auto` option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. tp: Tensor Parallelism. default 1. dp: Data Parallelism. default 1. dp_rank: rank of dp. ep: Expert Parallelism. default 1. session_len: Max session length. Default None. max_batch_size: Max batch size. If it is not specified, the engine will automatically set it according to the device attn_tp_size: tp size for attention, only works for dp>1 mlp_tp_size: tp size for mlp, only works for dp>1 moe_tp_size: tp size for moe, only works for dp>1 cache_max_entry_count: the percentage of gpu memory occupied by the k/v cache. For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory to be reserved for the k/v cache prefill_interval: Interval to perform prefill, Default 16. block_size: paging cache block size, default 64. num_cpu_blocks: Num cpu blocks. If num is 0, cache would be allocate according to current environment. num_gpu_blocks: Num gpu blocks. If num is 0, cache would be allocate according to current environment. adapters: The path configs to lora adapters. max_prefill_token_num: tokens per iteration. thread_safe: thread safe engine instance. enable_prefix_caching: Enable token match and sharing caches. device_type: The inference device type, options ['cuda'] eager_mode: Enable "eager" mode or not custom_module_map: nn module map customized by users. Once provided, the original nn modules of the model will be substituted by the mapping ones download_dir: Directory to download and load the weights, default to the default cache directory of huggingface. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. 
If unspecified, will use the default version. quant_policy: default to 0. When k/v is quantized into 4 or 8 bit, set it to 4 or 8, respectively distributed_executor_backend: backend of distributed backend, options: ['uni', 'mp', 'ray'] empty_init: Whether to load the model weights, you should set it to True if you want to update weights after create the pipeline enable_microbatch: enable microbatch for specified model enable_eplb: enable eplb for specified model enable_metrics: enable metrics system role: role of engin, options: ['Hybrid', 'Prefill', 'Decode']. Default to `EngineRole.Hybrid`. migration_backend: migration backend. options: ['DLSlime']. Default to `MigrationBackend.DLSlime`. enable_mp_engine: run engine in multi-process mode. mp_engine_backend: backend of mp engine, options: ['mp', 'ray']. Default to `mp`. model_format: weight quantization policy, options: ['fp8']. hf_overrides: Huggingface overrides for the model. It can be used to override the default config of the model, disable_vision_encoder: Whether to disable loading vision encoder. Default to False. logprobs_mode: The mode of logprob, options: ['raw_logits', 'raw_logprobs'] dllm_block_length: Block size of block diffusion model. dllm_unmasking_strategy: Dllm unmasking strategy, options: ['low_confidence_dynamic', 'low_confidence_static', 'sequential']. dllm_denoising_steps: Dllm denoising steps. dllm_confidence_threshold: dllm unmasking threshold for dynamic unmasking. 
""" dtype: str = 'auto' tp: int = 1 dp: int = 1 dp_rank: int = 0 ep: int = 1 session_len: int = None max_batch_size: int = None attn_tp_size: int = None mlp_tp_size: int = None moe_tp_size: int = None cache_max_entry_count: float = 0.8 prefill_interval: int = 16 block_size: int = 64 num_cpu_blocks: int = 0 num_gpu_blocks: int = 0 adapters: Dict[str, str] = None max_prefill_token_num: int = 4096 thread_safe: bool = False enable_prefix_caching: bool = False device_type: str = 'cuda' eager_mode: bool = False custom_module_map: Dict[str, str] = None download_dir: str = None revision: str = None quant_policy: Literal[0, 4, 8] = 0 distributed_executor_backend: str = None empty_init: bool = False enable_microbatch: bool = False enable_eplb: bool = False enable_mp_engine: bool = False mp_engine_backend: str = 'mp' model_format: str = None enable_metrics: bool = True hf_overrides: Optional[Dict[str, Any]] = None disable_vision_encoder: bool = False logprobs_mode: str = None # router replay enable_return_routed_experts: bool = False enable_transfer_obj_ref: bool = False # dllm dllm_block_length: int = None dllm_unmasking_strategy: str = 'low_confidence_dynamic' dllm_denoising_steps: int = None dllm_confidence_threshold: float = 0.85 role: EngineRole = EngineRole.Hybrid migration_backend: MigrationBackend = MigrationBackend.DLSlime def __post_init__(self): """Check input validation.""" assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'invalid tp' assert self.dp >= 1, 'invalid dp' assert self.ep >= 1, 'invalid ep' assert 0 < self.cache_max_entry_count < 1, \ 'invalid cache_max_entry_count' assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks' assert self.max_prefill_token_num >= 0, \ 'invalid max_prefill_token_num' assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' assert self.quant_policy in (0, 4, 8), 'invalid quant_policy' assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}') assert 
self.block_size >= 16 and (self.block_size & (self.block_size - 1)) == 0, \ f'block_size must be >= 16 and a power of 2, but got {self.block_size}' if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']: assert False, \ 'kv cache quantization only works for CUDA and ASCEND.' if self.device_type == 'camb' and self.block_size != 16: self.block_size = 16 logger.warning('Currently, camb device requires block size to be 16, \ setting block size to 16') class ResponseType(enum.Enum): """Response type.""" SUCCESS = enum.auto() FINISH = enum.auto() ENGINE_STOP_ERROR = enum.auto() SESSION_REPEAT = enum.auto() SESSION_NOT_EXIST = enum.auto() HANDLER_NOT_EXIST = enum.auto() INPUT_LENGTH_ERROR = enum.auto() INTERNAL_ENGINE_ERROR = enum.auto() CANCEL = enum.auto() PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE = enum.auto() NO_QUEUE = enum.auto() @dataclass class Response: """Pack all response information together. Args: text: the response text from the server. If the output text is an empty str and the finish_reason is length, it means the session length is reached. generate_token_len: the response token length. input_token_len: the input prompt token length. Note that it may contains chat template part. session_id: the id for running the session. finish_reason: the reason the model stopped generating tokens. This will be 'stop' if the model hit a natural stop point or a provided stop sequence, 'length' if the maximum number of tokens specified in the request was reached. token_ids:: the output token ids. logprobs:: the top logprobs for each output position. 
index: it refers to the position index of the input request batch """ text: str generate_token_len: int input_token_len: int finish_reason: Optional[Literal['stop', 'length']] = None token_ids: List[int] = field(default_factory=list) logprobs: List[Dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None index: int = 0 routed_experts: Any = None def __str__(self): return f'text={self.text}\n{self._format_none_text_fields()}' def __repr__(self): return f'text={self.text!r}\n{self._format_none_text_fields()}' def _format_none_text_fields(self): fields = [] fields.append(f'input_token_len={self.input_token_len}') fields.append(f'generate_token_len={self.generate_token_len}') fields.append(f'finish_reason="{self.finish_reason}"') fields.append(f'token_ids={self.token_ids}') fields.append(f'logprobs={self.logprobs}') # Helper function to format tensor information def _format_tensor(name: str, tensor: Optional[torch.Tensor]) -> List[str]: if tensor is None: return [f'{name}=None'] try: return [f'{name}.shape={tensor.shape}', f'{name}={tensor}'] except: # noqa # in case tensor is not torch.Tensor or has no shape return [f'{name}={tensor}'] # Format tensor fields fields.extend(_format_tensor('logits', self.logits)) fields.extend(_format_tensor('last_hidden_state', self.last_hidden_state)) fields.extend(_format_tensor('routed_experts', self.routed_experts)) return '\n'.join(fields) def extend(self, other: 'Response') -> 'Response': """Extend this response with another response. This method merges the content of another Response into this one, similar to list.extend(). The text, token_ids, and logprobs are concatenated, while other fields are updated from the other response. Args: other: Another Response to append to this one. Returns: Self (for method chaining). 
""" self.text += other.text self.generate_token_len = other.generate_token_len self.input_token_len = other.input_token_len self.finish_reason = other.finish_reason self.index = other.index if other.token_ids: self.token_ids += other.token_ids if other.logprobs: self.logprobs = self.logprobs or [] self.logprobs += other.logprobs self.routed_experts = other.routed_experts return self # modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py class EventType(enum.IntEnum): """The type of request event. QUEUED - when the request was enqued by the engine SCHEDULED - when the request was first scheduled for execution PREEMPTED - the request has been put back in the waiting queue in order to make room for other requests to complete. It will be re-scheduled in future and re-start its prefill phase """ QUEUED = 1 SCHEDULED = 2 PREEMPTED = 3 # FIXME, currently ignored for simplicity # modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py @dataclass class EngineEvent: """A timestamped engine event associated with a request. Attributes: type: the type of an event associated with a request during its life cycle timestamp: the WALL-CLOCK time when the event happens. """ type: EventType timestamp: float @classmethod def new_event(cls, event_type: EventType, timestamp: Optional[float] = None) -> 'EngineEvent': # Timestamps MUST use wall-clock time (time.time()) to maintain consistency # between csrc(std::chrono::system_clock) and python timestamp = time.time() if timestamp is None else timestamp return cls(event_type, timestamp) @dataclass class ScheduleMetrics: active_seqs: int = 0 waiting_seqs: int = 0 total_blocks: int = 0 active_blocks: int = 0 cached_blocks: int = 0 free_blocks: int = 0 prefix_cache_hit_rate: float = 0 @dataclass class RequestMetrics: """Basic metrics for a request. Attributes: token_timestamp: A wall-clock time when a token is generated. engine_events: List of engine events during inference. 
""" token_timestamp: float = 0.0 engine_events: List[EngineEvent] = field(default_factory=list) spec_info: Optional[Dict[str, Any]] = None @dataclass class EngineOutput: """Engine output from turbomind/pytorch engine. Args: status: the response type. token_ids: the newly generated token ids in each iteration. logprobs: the top logprobs for each output position. cache_block_ids: send cache blocks back for migration in Disaggregated LLM Serving when Prefill Engine is Done. req_metrics: request metrics information """ status: ResponseType token_ids: List[int] logprobs: List[Dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None cache_block_ids: Optional[List[int]] = None req_metrics: Optional[RequestMetrics] = None routed_experts: torch.Tensor = None @dataclass class VisionConfig: """Vision model configs. Args: max_batch_size: the max image size passed to the model, since some models will use image patch, the actual running batch could be larger than this value. thread_safe: Specifies whether the engine instance is thread-safe. Please set it to True when using the pipeline in a multi-threaded environment. """ max_batch_size: int = 1 thread_safe: bool = False @dataclass class SpeculativeConfig: """Speculative decoding config. Args: method: the speculative decoding method. model: the path of speculative model. num_speculative_tokens: number of generated token of draft model per step """ method: str model: str = '' num_speculative_tokens: int = 1 ================================================ FILE: lmdeploy/metrics/__init__.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. ================================================ FILE: lmdeploy/metrics/loggers.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
# adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/loggers.py import time from abc import ABC, abstractmethod from datetime import datetime from typing import List import numpy as np from lmdeploy.metrics.stats import IterationStats, RequestStats, SchedulerStats, SpeculativeDecodingStats from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') class StatLoggerBase(ABC): @abstractmethod def record_schedule(self, stats: SchedulerStats) -> None: ... @abstractmethod def record_iteration(self, stats: IterationStats) -> None: ... @abstractmethod def record_specdecode(self, stats: SpeculativeDecodingStats) -> None: ... def log(self): # noqa pass class LoggingStatLogger(StatLoggerBase): def __init__(self, dp_rank: int = 0): self.dp_rank = dp_rank self._reset(time.perf_counter()) self.last_scheduler_stats = SchedulerStats() def _reset(self, now): self.last_log_time = now self.total_prompt_tokens = 0 self.total_generation_tokens = 0 # spec decode self.num_drafts: int = 0 self.num_draft_tokens: int = 0 self.num_accepted_tokens: int = 0 self.num_accepted_tokens_per_pos: np.ndarray = None def record_schedule(self, stats: SchedulerStats): self.last_scheduler_stats = stats def record_iteration(self, stats: IterationStats): # In the first iteration of a sequence, stats.prompt_tokens is the # prompt token number of a sequence. In subsequent iterations, # the value is 0. 
This enables cumulative counting in `total_prompt_tokens` self.total_prompt_tokens += stats.prompt_tokens self.total_generation_tokens += stats.new_generation_tokens def record_specdecode(self, stats: SpeculativeDecodingStats): """Record spec decoding stats.""" if stats.num_drafts <= 0: return if self.num_accepted_tokens_per_pos is None: self.num_accepted_tokens_per_pos = np.zeros(stats.num_spec_tokens) self.num_drafts += stats.num_drafts self.num_draft_tokens += stats.num_draft_tokens self.num_accepted_tokens += stats.num_accepted_tokens self.num_accepted_tokens_per_pos += stats.num_accepted_tokens_per_pos def record_finish(self, stats: RequestStats): pass def get_spec_msg(self): """Get spec decoding logging msg.""" if self.num_drafts == 0: return None draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens * 100 if self.num_draft_tokens > 0 else float('nan')) # conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (self.num_accepted_tokens / self.num_drafts) acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates) log_msg = ('SpecDecoding metrics: ' f'Draft acceptance rate: {draft_acceptance_rate:.2f}%, ' f'Mean acceptance length: {mean_acceptance_length:.2f}, ' f'Accepted: {self.num_accepted_tokens} tokens, ' f'Drafted: {self.num_draft_tokens} tokens, ' f'Per-position acceptance rate: {rates_str}') return log_msg def log(self): now = time.perf_counter() # skip logging if no tokens were processed if self.total_prompt_tokens == 0 and self.total_generation_tokens == 0: self._reset(now) return # derive log information prompt_throughput = self.total_prompt_tokens / (now - self.last_log_time) generation_throughput = self.total_generation_tokens / (now - self.last_log_time) scheduler_stats = self.last_scheduler_stats scheduler_stats.num_api_waiting_reqs = scheduler_stats.num_total_reqs - \ scheduler_stats.num_completed_reqs - 
scheduler_stats.num_api_routed_reqs spec_msg = self.get_spec_msg() # format and print log_msg = ( f"[{datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')} DP{self.dp_rank}] " f'Avg thr (in/out): {prompt_throughput:.1f} / {generation_throughput:.1f} tokens/s, ' f'API server (completed/routed/waiting): {scheduler_stats.num_completed_reqs} / ' f'{scheduler_stats.num_api_routed_reqs} / {scheduler_stats.num_api_waiting_reqs}, ' f'Engine (running/waiting): {scheduler_stats.num_running_reqs} / {scheduler_stats.num_waiting_reqs}, ' f'KV cache: {scheduler_stats.gpu_cache_usage * 100 :.1f}%, ') if scheduler_stats.prefix_cache_hit_rate != 0: log_msg += f'Prefix cache hit rate: {scheduler_stats.prefix_cache_hit_rate * 100 :.1f}%, ' if spec_msg is not None: log_msg += spec_msg print(log_msg, flush=True) self._reset(now) class PrometheusStatLogger(StatLoggerBase): def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0): try: import prometheus_client prometheus_client.disable_created_metrics() # disable noisy creation timestamp gauge in prometheus except ImportError: raise ImportError( 'To use metrics system , please install prometheus_client by `pip install prometheus_client`') self.dp_rank = dp_rank # unregister any existing lmdeploy collectors for collector in list(prometheus_client.REGISTRY._collector_to_names): if hasattr(collector, '_name') and 'lmdeploy' in collector._name: prometheus_client.REGISTRY.unregister(collector) # config information self.info_backend_config = prometheus_client.Info(name='lmdeploy:backend_config', documentation='information of backend_config') labelnames = ['model_name', 'engine'] labelvalues = [model_name, str(dp_rank)] # # Scheduler stats # self.gauge_scheduler_completed = prometheus_client.Gauge(name='lmdeploy:num_requests_completed', documentation='Number of current completed requests.', labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_api_routed = prometheus_client.Gauge( 
name='lmdeploy:num_api_requests_routed', documentation='Number of requests routed to request handles.', labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_api_waiting = prometheus_client.Gauge( name='lmdeploy:num_api_requests_waiting', documentation='Number of requests waiting for free request handles.', labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_running = prometheus_client.Gauge( name='lmdeploy:num_requests_running', documentation='Number of requests in model execution batches.', labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = prometheus_client.Gauge( name='lmdeploy:num_requests_waiting', documentation='Number of requests waiting to be processed.', labelnames=labelnames).labels(*labelvalues) # # GPU cache # self.gauge_gpu_cache_usage = prometheus_client.Gauge( name='lmdeploy:gpu_cache_usage_perc', documentation='GPU KV-cache usage. 1 means 100 percent usage.', labelnames=labelnames).labels(*labelvalues) # # Counters # self.counter_prompt_tokens = prometheus_client.Counter(name='lmdeploy:prompt_tokens_total', documentation='Number of prefill tokens processed.', labelnames=labelnames).labels(*labelvalues) self.counter_generation_tokens = prometheus_client.Counter( name='lmdeploy:generation_tokens_total', documentation='Number of generation tokens processed.', labelnames=labelnames).labels(*labelvalues) from lmdeploy.messages import ResponseType self.counter_request_success: dict[ResponseType, prometheus_client.Counter] = {} counter_request_success_base = prometheus_client.Counter( name='lmdeploy:request_success_total', documentation='Count of successfully processed requests.', labelnames=labelnames + ['finished_reason']) for reason in ResponseType: self.counter_request_success[reason] = counter_request_success_base.labels(*(labelvalues + [str(reason)])) # # Histograms of counts # self.histogram_num_prompt_tokens_request = \ prometheus_client.Histogram( name='lmdeploy:request_prompt_tokens', 
documentation='Number of prefill tokens processed.', buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) self.histogram_num_generation_tokens_request = \ prometheus_client.Histogram( name='lmdeploy:request_generation_tokens', documentation='Number of generation tokens processed.', buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) self.histogram_iteration_tokens = \ prometheus_client.Histogram( name='lmdeploy:iteration_tokens_total', documentation='Histogram of number of tokens per engine_step.', buckets=[ 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 ], labelnames=labelnames).labels(*labelvalues) # # Histogram of timing intervals # self.histogram_time_to_first_token = \ prometheus_client.Histogram( name='lmdeploy:time_to_first_token_seconds', documentation='Histogram of time to first token in seconds.', buckets=[ 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0 ], labelnames=labelnames).labels(*labelvalues) self.histogram_time_per_output_token = \ prometheus_client.Histogram( name='lmdeploy:time_per_output_token_seconds', documentation='Histogram of time per output token in seconds.', buckets=[ 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 ], labelnames=labelnames).labels(*labelvalues) self.histogram_iter_token_latency = \ prometheus_client.Histogram( name='lmdeploy:iter_token_latency', documentation='Histogram of inter-token latency', buckets=[ 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 ], labelnames=labelnames).labels(*labelvalues) request_latency_buckets = [ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 ] self.histogram_e2e_time_request = \ prometheus_client.Histogram( 
name='lmdeploy:e2e_request_latency_seconds', documentation='Histogram of e2e request latency in seconds.', buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_queue_time_request = \ prometheus_client.Histogram( name='lmdeploy:request_queue_time_seconds', documentation='Histogram of time spent in WAITING phase for request.', buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_inference_time_request = \ prometheus_client.Histogram( name='lmdeploy:request_inference_time_seconds', documentation='Histogram of time spent in RUNNING phase for request.', buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_prefill_time_request = \ prometheus_client.Histogram( name='lmdeploy:request_prefill_time_seconds', documentation='Histogram of time spent in PREFILL phase for request.', buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) self.histogram_decode_time_request = \ prometheus_client.Histogram( name='lmdeploy:request_decode_time_seconds', documentation='Histogram of time spent in DECODE phase for request.', buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) def record_schedule(self, stats: SchedulerStats) -> None: """Report schedule metrics to prometheus.""" self.gauge_scheduler_completed.set(stats.num_completed_reqs) self.gauge_scheduler_api_routed.set(stats.num_api_routed_reqs) self.gauge_scheduler_api_waiting.set(stats.num_total_reqs - stats.num_completed_reqs - stats.num_api_routed_reqs) self.gauge_scheduler_running.set(stats.num_running_reqs) self.gauge_scheduler_waiting.set(stats.num_waiting_reqs) self.gauge_gpu_cache_usage.set(stats.gpu_cache_usage) def record_iteration(self, stats: IterationStats) -> None: """Report token-related metrics to prometheus.""" self.counter_prompt_tokens.inc(stats.prompt_tokens) self.counter_generation_tokens.inc(stats.new_generation_tokens) 
self.histogram_iteration_tokens.observe(stats.prompt_tokens + stats.new_generation_tokens) if stats.ttft: self.histogram_time_to_first_token.observe(stats.ttft) if stats.tpot: self.histogram_time_per_output_token.observe(stats.tpot) if stats.itl: self.histogram_iter_token_latency.observe(stats.itl) def record_finish(self, stats: RequestStats) -> None: self.counter_request_success[stats.finish_reason].inc() self.histogram_e2e_time_request.observe(stats.e2e_latency) self.histogram_queue_time_request.observe(stats.queued_time_interval) self.histogram_prefill_time_request.observe(stats.prefill_time_interval) self.histogram_inference_time_request.observe(stats.inference_time_interval) self.histogram_decode_time_request.observe(stats.decode_time_interval) self.histogram_num_prompt_tokens_request.observe(stats.prompt_tokens) self.histogram_num_generation_tokens_request.observe(stats.generation_tokens) def record_specdecode(self, stats: SpeculativeDecodingStats) -> None: pass def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: """Builds a list of buckets with increasing powers of 10 multiplied by mantissa values until the value exceeds the specified maximum.""" exponent = 0 buckets: List[int] = [] while True: for m in mantissa_lst: value = m * 10**exponent if value <= max_value: buckets.append(value) else: return buckets exponent += 1 def build_1_2_5_buckets(max_value: int) -> List[int]: """ Example: >>> build_1_2_5_buckets(100) [1, 2, 5, 10, 20, 50, 100] """ return build_buckets([1, 2, 5], max_value) ================================================ FILE: lmdeploy/metrics/metrics_processor.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio from lmdeploy.messages import ResponseType, ScheduleMetrics from lmdeploy.pytorch.utils import singleton from lmdeploy.utils import get_logger from .stats import SchedulerStats logger = get_logger('lmdeploy') @singleton class MetricsProcessor(): """Metrics processor.""" def __init__(self): """Init metrics processor.""" self.enable_metrics: bool = False self.scheduler_stats = SchedulerStats() self.stat_loggers = [] self.metrics_queue: asyncio.Queue = None self.metrics_handler: asyncio.Task = None def start_metrics_handler(self, enable_metrics: bool): """Start metrics handler.""" self.enable_metrics = enable_metrics if enable_metrics and self.metrics_handler is None: self.metrics_queue = asyncio.Queue() self.metrics_handler = asyncio.create_task(self._run_metrics_handler()) logger.info('Metrics handler task started.') async def stop_metrics_handler(self): """Stop metrics handler.""" if self.metrics_handler is not None: self.metrics_handler.cancel() try: await self.metrics_handler except asyncio.CancelledError: pass # Expected cancellation finally: self.metrics_handler = None logger.info('Metrics handler task stopped.') async def _run_metrics_handler(self): """A background task that consumes and processes metrics data.""" while True: try: # fetch data from the queue update_data = await self.metrics_queue.get() outputs, req_stats, iteration_stats, specdecode_stats = update_data # update request stats if outputs and outputs.req_metrics: # when users visit "/abort_request" endpoint, `req_metrics` might be None req_stats.update_from_events(outputs.req_metrics.engine_events) # update iteration stats # some attributes of req_stats will also be updated, e.g., lastest_token_time iteration_stats.update_from_output(outputs, req_stats) # update spec decode stats if specdecode_stats is not None: specdecode_stats.update_from_output(outputs) # record iteration stats for stat_logger in self.stat_loggers: stat_logger.record_iteration(iteration_stats) if 
specdecode_stats is not None: stat_logger.record_specdecode(specdecode_stats) # record finished request stats if outputs.status == ResponseType.FINISH: for stat_logger in self.stat_loggers: stat_logger.record_finish(req_stats) self.metrics_queue.task_done() except asyncio.CancelledError: break except Exception as e: logger.exception(f'Metrics handler background task failed: {e}') async def update_schedule_stats(self, schedule_metrics: ScheduleMetrics): """Update schedule stats.""" self.scheduler_stats.update_from_schedule_metrics(schedule_metrics) # record schedule stats for stat_logger in self.stat_loggers: stat_logger.record_schedule(self.scheduler_stats) def queue_update(self, update_data: tuple): """Queue update.""" if not self.enable_metrics or self.metrics_queue is None: return self.metrics_queue.put_nowait(update_data) def increase_total_requests(self): """Increase total requests.""" self.scheduler_stats.num_total_reqs += 1 def increase_completed_requests(self): """Increase completed requests.""" self.scheduler_stats.num_completed_reqs += 1 def increase_api_routed_requests(self): """Increase API routed requests.""" self.scheduler_stats.num_api_routed_reqs += 1 def decrease_api_routed_requests(self): """Decrease API routed requests.""" self.scheduler_stats.num_api_routed_reqs -= 1 metrics_processor = MetricsProcessor() ================================================ FILE: lmdeploy/metrics/stats.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/stats.py import time from dataclasses import dataclass from typing import List, Optional import numpy as np from lmdeploy.messages import EngineEvent, EngineOutput, ResponseType, ScheduleMetrics @dataclass class SchedulerStats: """Stats associated with the scheduler. 
Desc: Dataflow: client --> API server --> Engine core API server total = completed + uncompleted = completed + (api_routed + api_waiting) Engine core total = running + waiting = api_routed Attributes: num_total_reqs: API server, the number of all requests received since server start. num_completed_reqs: API server, the number of successfully completed requests since server start. num_api_routed_reqs: API server, the number of requests routed to request handles. num_api_waiting_reqs: API server, the number of requests waiting for free request handles. num_running_reqs: Engine core, currently executing requests. num_waiting_reqs: Engine core, requests queued waiting for execution. gpu_cache_usage: Fraction of GPU KV blocks utilized (0.0 to 1.0). prefix_cache_hit_rate: Prefix caching hit rate. """ # api server num_total_reqs: int = 0 num_completed_reqs: int = 0 num_api_routed_reqs: int = 0 num_api_waiting_reqs: int = 0 # engine core num_running_reqs: int = 0 num_waiting_reqs: int = 0 gpu_cache_usage: float = 0.0 prefix_cache_hit_rate: float = 0.0 def __repr__(self): return ('SchedulerStats(\n' f' num_total_reqs={self.num_total_reqs},\n' f' num_completed_reqs={self.num_completed_reqs},\n' f' num_api_routed_reqs={self.num_api_routed_reqs},\n' f' num_api_waiting_reqs={self.num_api_waiting_reqs},\n' f' num_running_reqs={self.num_running_reqs},\n' f' num_waiting_reqs={self.num_waiting_reqs},\n' f' gpu_cache_usage={self.gpu_cache_usage:.6f},\n' f' prefix_cache_hit_rate={self.prefix_cache_hit_rate:.6f},\n' ')') def update_from_schedule_metrics(self, scheduled_metrics: ScheduleMetrics): self.num_running_reqs = scheduled_metrics.active_seqs self.num_waiting_reqs = scheduled_metrics.waiting_seqs self.gpu_cache_usage = 1.0 - (scheduled_metrics.free_blocks / scheduled_metrics.total_blocks) self.prefix_cache_hit_rate = scheduled_metrics.prefix_cache_hit_rate class RequestStats: """Stats associated with a request.""" def __init__(self, arrival_time: float = None, prompt_tokens: int 
= 0): """Initialize the stats of a request. Args: arrival_time (float, optional): The timestamp when the request arrives. If not provided, the current time will be used. Defaults to None. prompt_tokens (int, optional): The number of tokens in the prompt. Defaults to 0. Attributes: generation_tokens (int): The number of tokens generated during the request inference. It will be updated by IterationStats.update_from_output. queued_time (float): Time when the request is put to the inference engine's queue. It will be updated according the EngineEvent. scheduled_time (float): Time when the request is scheduled to run. It will be updated according the EngineEvent. first_token_time (float): Time when the first token is generated. It will be updated by IterationStats.update_from_output. lastest_token_time (float): Time when the latest token is generated. It will be updated by IterationStats.update_from_output. finish_time (float): Time when a request finishes generation. It will be updated by IterationStats.update_from_output. finish_reason (ResponseType): The reason why the request finished. 
""" self.arrival_time = time.time() if arrival_time is None else arrival_time self.prompt_tokens = prompt_tokens self.generation_tokens: int = 0 self.queued_time: float = 0.0 self.scheduled_time: float = 0.0 self.first_token_time: float = 0.0 self.lastest_token_time: float = 0.0 self.finish_time: float = 0.0 self.finish_reason: ResponseType = None def __repr__(self): return ('RequestStats(\n' f' arrival_time={self.arrival_time:.6f},\n' f' prompt_tokens={self.prompt_tokens},\n' f' generation_tokens={self.generation_tokens},\n' f' queued_time={self.queued_time:.6f},\n' f' scheduled_time={self.scheduled_time:.6f},\n' f' first_token_time={self.first_token_time:.6f},\n' f' latest_token_time={self.lastest_token_time:.6f},\n' ')') def update_from_events(self, engine_events: List[EngineEvent]): # avoid circular dependency from lmdeploy.messages import EventType for event in engine_events: if event.type == EventType.QUEUED: self.queued_time = event.timestamp elif event.type == EventType.SCHEDULED: if self.scheduled_time == 0.0: # ignore preemptions self.scheduled_time = event.timestamp # FIXME: deal with preempted case # elif event.type == EventType.PREEMPTED: # self.num_preempted_reqs += 1 @property def e2e_latency(self) -> float: """End-to-end latency.""" return self.finish_time - self.arrival_time @property def queued_time_interval(self) -> float: """Queued interval is from first QUEUED event to first SCHEDULED.""" return self.scheduled_time - self.queued_time @property def prefill_time_interval(self) -> float: """Prefill interval is from first SCHEDULED to first NEW_TOKEN. Any preemptions during prefill is included in the interval. """ return self.first_token_time - self.scheduled_time @property def decode_time_interval(self) -> float: """Decode interval is from first NEW_TOKEN to last NEW_TOKEN. Any preemptions during decode are included. 
""" return self.finish_time - self.first_token_time @property def inference_time_interval(self) -> float: """Inference interval is from first SCHEDULED to last NEW_TOKEN. Any preemptions during prefill or decode are included. """ return self.finish_time - self.scheduled_time class IterationStats: """Stats associated with one token generation iteration of a request.""" def __init__(self): """Initialize the stats of one iteration. Attributes: iteration_timestamp (float): The timestamp when this iteration finishes. new_generation_tokens (int): The number of newly generated tokens in this iteration. prompt_tokens (int): The number of prompt tokens processed in this iteration. ttft (float | None): Time to First Token (TTFT). tpot (float | None): Time per Output Token (TPOT). itl (float | None): Iter-Token Latency (ITL). """ self.iteration_timestamp = time.time() self.new_generation_tokens = 0 self.prompt_tokens = 0 self.ttft: Optional[float] = None self.tpot: Optional[float] = None self.itl: Optional[float] = None def __repr__(self): return ('IterationStats(\n' f' iteration_timestamp={self.iteration_timestamp:.6f},\n' f' new_generation_tokens={self.new_generation_tokens},\n' f' prompt_tokens={self.prompt_tokens},\n' f' ttft={self.ttft},\n' f' tpot={self.tpot},\n' f' itl={self.itl},\n' ')') def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" return self.iteration_timestamp - start def update_from_output(self, outputs: EngineOutput, req_stats: RequestStats): """Update the iteration statistics. Args: outputs (EngineOutput): The output from the engine containing information about the current iteration. req_stats (RequestStats): The stats of the request, including timestamps and token counts. 
""" if outputs.req_metrics is None: # when users visit "/abort_request" endpoint, `req_metrics` might be None return new_generation_tokens = len(outputs.token_ids) if new_generation_tokens == 0: return self.new_generation_tokens = new_generation_tokens if req_stats.first_token_time == 0: # the first token is generated in this iteration req_stats.first_token_time = outputs.req_metrics.token_timestamp self.prompt_tokens = req_stats.prompt_tokens self.ttft = self._time_since(req_stats.arrival_time) else: self.itl = self._time_since(req_stats.lastest_token_time) self.tpot = self._time_since(req_stats.lastest_token_time) / self.new_generation_tokens req_stats.lastest_token_time = outputs.req_metrics.token_timestamp req_stats.generation_tokens += new_generation_tokens if outputs.status != ResponseType.SUCCESS: req_stats.finish_reason = outputs.status req_stats.finish_time = self.iteration_timestamp # modify from vllm @dataclass class SpeculativeDecodingStats: """Speculative decoding stats.""" num_spec_tokens: int num_drafts: int = 0 num_draft_tokens: int = 0 num_accepted_tokens: int = 0 num_accepted_tokens_per_pos: np.ndarray = None def __post_init__(self): assert self.num_spec_tokens > 0 self.num_accepted_tokens_per_pos = np.zeros(self.num_spec_tokens) def update_from_output(self, outputs: EngineOutput): """Update from engine output.""" spec_info = getattr(outputs.req_metrics, 'spec_info', None) if spec_info: self.num_drafts += 1 self.num_draft_tokens += spec_info['num_draft_tokens'] self.num_accepted_tokens += spec_info['num_accepted_tokens'] self.num_accepted_tokens_per_pos[:spec_info['num_accepted_tokens']] += 1 def update_per_draft(self, num_draft_tokens: int, num_accepted_tokens: int): """Update with per draft stats.""" if num_draft_tokens > 0: self.num_drafts += 1 self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens self.num_accepted_tokens_per_pos[:num_accepted_tokens] += 1 def __repr__(self): draft_acceptance_rate = 
(self.num_accepted_tokens / self.num_draft_tokens * 100 if self.num_draft_tokens > 0 else float('nan')) # conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (self.num_accepted_tokens / self.num_drafts) if self.num_drafts > 0 else float('nan') acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts if self.num_drafts > 0 else [ float('nan') ] * self.num_accepted_tokens rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates) return ('SpeculativeDecodingStats(' f'num_spec_tokens={self.num_spec_tokens}, ' f'num_drafts={self.num_drafts}, ' f'num_draft_tokens={self.num_draft_tokens}, ' f'num_accepted_tokens={self.num_accepted_tokens}, ' f'draft_acceptance_rate={draft_acceptance_rate:.2f}%, ' f'mean_acceptance_length={mean_acceptance_length:.2f}, ' f'per_position_acceptance_rate={rates_str})') ================================================ FILE: lmdeploy/model.py ================================================ # Copyright (c) OpenMMLab. All rights reserved. import dataclasses import json import uuid from typing import List, Literal, Optional, Union from mmengine import Registry from lmdeploy.archs import get_model_arch from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') MODELS = Registry('model', locations=['lmdeploy.model']) def random_uuid() -> str: """Return a random uuid.""" return str(uuid.uuid4().hex) def get_text(content: Union[str, List[dict]]): """Within the OpenAI API, the content field may be specified as either a string or a list of ChatCompletionContentPartTextParam (defined in openai). When a list is provided, lmdeploy selects the first element to incorporate into the chat template, as the manner in which OpenAI processes lists is not explicitly defined. """ if isinstance(content, str): return content return content[0]['text'] @dataclasses.dataclass class ChatTemplateConfig: """Parameters for chat template. Args: model_name (str): the name of the deployed model. 
Determine which chat template will be applied. All the chat template names: `lmdeploy list` system (str | None): begin of the system prompt meta_instruction (str | None): system prompt eosys (str | None): end of the system prompt user (str | None): begin of the user prompt eoh (str | None): end of the user prompt assistant (str | None): begin of the assistant prompt eoa (str | None): end of the assistant prompt tool (str | None): begin of the tool prompt eotool (str | None): end of the tool prompt capability: ('completion' | 'infilling' | 'chat' | 'python') = None """ # noqa: E501 model_name: str model_path: Optional[str] = None system: Optional[str] = None meta_instruction: Optional[str] = None eosys: Optional[str] = None user: Optional[str] = None eoh: Optional[str] = None assistant: Optional[str] = None eoa: Optional[str] = None tool: Optional[str] = None eotool: Optional[str] = None separator: Optional[str] = None capability: Optional[Literal['completion', 'infilling', 'chat', 'python']] = None stop_words: Optional[List[str]] = None @property def chat_template(self): attrs = {key: value for key, value in dataclasses.asdict(self).items() if value is not None} attrs.pop('model_name', None) if self.model_name in MODELS.module_dict.keys(): model = MODELS.get(self.model_name)(**attrs) else: logger.warning(f'Could not find {self.model_name} in registered models. 
' f'Register {self.model_name} using the BaseChatTemplate.') model = BaseChatTemplate(**attrs) return model def to_json(self, file_path=None): """Convert the dataclass instance to a JSON formatted string and optionally save to a file.""" json_str = json.dumps(dataclasses.asdict(self), ensure_ascii=False, indent=4) if file_path: with open(file_path, 'w', encoding='utf-8') as file: file.write(json_str) return json_str @classmethod def from_json(cls, file_or_string): """Construct a dataclass instance from a JSON file or JSON string.""" try: # Try to open the input_data as a file path with open(file_or_string, 'r', encoding='utf-8') as file: json_data = file.read() except FileNotFoundError: # If it's not a file path, assume it's a JSON string json_data = file_or_string except IOError: # If it's not a file path and not a valid JSON string, raise error raise ValueError('Invalid input. Must be a file path or a valid JSON string.') json_data = json.loads(json_data) if json_data.get('model_name', None) is None: json_data['model_name'] = random_uuid() if json_data['model_name'] not in MODELS.module_dict.keys(): MODELS.register_module(json_data['model_name'], module=BaseChatTemplate) return cls(**json_data) @MODELS.register_module(name='base') class BaseChatTemplate: """Base Chat template.""" def __init__(self, system='', meta_instruction='', eosys='', user='', eoh='', assistant='', eoa='', separator='', tool='', eotool='', capability='chat', stop_words=None, **kwargs): self.system = system self.meta_instruction = meta_instruction self.user = user self.eoh = eoh self.eoa = eoa self.separator = separator self.eosys = eosys self.assistant = assistant self.tool = tool self.eotool = eotool self.stop_words = stop_words self.capability = capability def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the chat template. 
Args: prompt (str): user's input prompt sequence_start (bool): indicator for the first round chat of a session sequence Returns: str: the concatenated prompt """ if self.capability == 'completion': return prompt if sequence_start: # None is different from '' if self.meta_instruction is not None: return f'{self.system}{self.meta_instruction}{self.eosys}' \ f'{self.user}{prompt}{self.eoh}' \ f'{self.assistant}' else: return f'{self.user}{prompt}{self.eoh}' \ f'{self.assistant}' else: return f'{self.separator}{self.user}{prompt}{self.eoh}' \ f'{self.assistant}' def messages2prompt(self, messages, sequence_start=True, **kwargs): """Return the prompt that is concatenated with other elements in the chat template. Args: messages (str | List): user's input prompt Returns: str: the concatenated prompt """ if isinstance(messages, str): return self.get_prompt(messages, sequence_start) box_map = dict(user=self.user, assistant=self.assistant, system=self.system, tool=self.tool) eox_map = dict(user=self.eoh, assistant=self.eoa + self.separator, system=self.eosys, tool=self.eotool) ret = '' if self.meta_instruction is not None and sequence_start: if len(messages) and messages[0]['role'] != 'system': ret += f'{self.system}{self.meta_instruction}{self.eosys}' for message in messages: role = message['role'] content = get_text(message['content']) ret += f'{box_map[role]}{content}{eox_map[role]}' if len(messages) and messages[-1]['role'] == 'assistant' and len(eox_map['assistant']) > 0: return ret[:-len(eox_map['assistant'])] # prefix of response ret += f'{self.assistant}' return ret @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. 
""" return None @MODELS.register_module(name='cogvlm') class CogVLM(BaseChatTemplate): """Chat template of CogVLM model.""" def __init__(self, meta_instruction='', eosys='', user='Question: ', separator='\n', eoh=' ', assistant='Answer:', eoa='', stop_words=[''], **kwargs): super().__init__(meta_instruction=meta_instruction, eosys=eosys, user=user, eoh=eoh, separator=separator, assistant=assistant, eoa=eoa, stop_words=stop_words, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. """ path = model_path.lower() if 'cogvlm' in path and 'cogvlm2' not in path: return 'cogvlm' @MODELS.register_module(name='vicuna') class Vicuna(BaseChatTemplate): """Chat template of vicuna model.""" def __init__( self, meta_instruction="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""", # noqa: E501 eosys=' ', user='USER: ', eoh=' ', assistant='ASSISTANT: ', eoa='', stop_words=[''], **kwargs): super().__init__(meta_instruction=meta_instruction, eosys=eosys, user=user, eoh=eoh, assistant=assistant, eoa=eoa, stop_words=stop_words, **kwargs) def get_prompt(self, prompt, sequence_start=True): if self.capability == 'chat': return super().get_prompt(prompt, sequence_start)[:-1] return super().get_prompt(prompt, sequence_start) def messages2prompt(self, messages, sequence_start=True, **kwargs): if isinstance(messages, str): return self.get_prompt(messages, sequence_start) return super().messages2prompt(messages, sequence_start, **kwargs)[:-1] @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. 
""" path = model_path.lower() if 'vicuna' in path and 'llava' not in path: return 'vicuna' if 'wizardlm' in path: return 'wizardlm' @MODELS.register_module(name='llava-v1') class Llavav1(Vicuna): """Chat template of llava-v1 model.""" def __init__( self, meta_instruction="""A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.""", # noqa: E501 **kwargs): super().__init__(meta_instruction=meta_instruction, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. """ path = model_path.lower() if 'llava' in path and 'v1' in path and 'v1.6-34b' not in path \ and 'mistral' not in path: return 'llava-v1' elif 'llava-1.5' in path: return 'llava-v1' @MODELS.register_module(name='internlm') class InternLMChat7B(BaseChatTemplate): """Chat template of InternLM model.""" def __init__( self, system='<|System|>:', meta_instruction="""You are an AI assistant whose name is InternLM (书生·浦语). - InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless. - InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文. """, # noqa: E501 eosys='\n', user='<|User|>:', eoh='\n', assistant='<|Bot|>:', eoa='', separator='\n', stop_words=[''], **kwargs): super().__init__(system=system, meta_instruction=meta_instruction, eosys=eosys, user=user, eoh=eoh, assistant=assistant, eoa=eoa, separator=separator, stop_words=stop_words, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. 
""" path = model_path.lower() if all([c not in path for c in ['internlm3', 'internlm2', '8k']]) and \ all([c in path for c in ['internlm', 'chat']]): return 'internlm' @MODELS.register_module(name='baichuan2') class Baichuan2(BaseChatTemplate): """Chat template and generation parameters of Baichuan2-7B-Base and Baichuan2-7B-Chat models.""" def __init__(self, user='', assistant='', **kwargs): super().__init__(user=user, assistant=assistant, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. """ path = model_path.lower() if 'baichuan2' in path and 'chat' in path: return 'baichuan2' @MODELS.register_module(name='llama2') class Llama2(BaseChatTemplate): """Chat template of LLaMA2 model.""" def __init__( self, system='[INST] <>\n', meta_instruction="""\ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", # noqa: E501 eosys='\n<>\n\n', assistant=' [/INST] ', eoa='', separator='[INST] ', session_len=4096, **kwargs): super().__init__(system=system, meta_instruction=meta_instruction, eosys=eosys, assistant=assistant, eoa=eoa, separator=separator, session_len=session_len, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: """Return the model_name that was registered to MODELS. Args: model_path (str): the model path used for matching. 
""" if 'llama-2' in model_path.lower() or 'llama2' in model_path.lower(): return 'llama2' @MODELS.register_module(name='codellama') class CodeLlama(Llama2): def __init__(self, meta_instruction='', suffix_first=False, stop_words=None, **kwargs): super().__init__(meta_instruction=meta_instruction, stop_words=stop_words, **kwargs) caps = ['completion', 'infilling', 'chat', 'python'] assert self.capability in caps, \ f'{self.capability} is not supported. ' \ f'The supported capabilities are: {caps}' self.meta_instruction = meta_instruction self.suffix_first = suffix_first self.stop_words = stop_words if self.capability == 'infilling': if self.stop_words is None: self.stop_words = [''] def get_prompt(self, prompt, sequence_start=True): if self.capability == 'infilling': return self._infill_prompt(prompt) elif self.capability == 'chat': return super().get_prompt(prompt, sequence_start) else: # python speicalist return prompt def _infill_prompt(self, prompt): prefix, suffix = prompt.split('') if self.suffix_first: # format as "
 {suf}  {pre}"
            prompt = f'
 {suffix}  {prefix}'
        else:
            # format as "
 {pre} {suf} "
            prompt = f'
 {prefix} {suffix} '
        return prompt

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Map a model path to the ``codellama`` template name.

        Args:
            model_path (str): path or repo id of the model; compared
                case-insensitively.

        Returns:
            Optional[str]: ``'codellama'`` when the lower-cased path contains
            ``'codellama'``, otherwise None.
        """
        lowered = model_path.lower()
        return 'codellama' if 'codellama' in lowered else None


@MODELS.register_module(name='chatglm')
class ChatGLM2(BaseChatTemplate):
    """Chat template of ChatGLM2 models.

    Every user turn is rendered as a "[Round n]" header (n counted from 1
    within a session), two newlines, then ``user + query + eoh + assistant``.
    A completed assistant reply is closed with ``eoa`` before the next round.
    """

    def __init__(self, user='问:', eoh='\n\n', assistant='答:', eoa='\n\n', **kwargs):
        super().__init__(**kwargs)
        # The role markers are kept in private attributes because this template
        # assembles prompts itself instead of using the base-class fields.
        self._user = user
        self._assistant = assistant
        self._eoh = eoh
        self._eoa = eoa
        # Round counter used by incremental ``get_prompt`` calls.
        self.count = 0

    def get_prompt(self, prompt, sequence_start=True):
        """Render one user turn.

        Args:
            prompt (str): the user's input.
            sequence_start (bool): True for the first turn of a session. The
                round counter restarts so that a new session begins at round 1
                instead of continuing the numbering of a previous session.

        Returns:
            str: the round header, user text and the assistant marker.
        """
        # need more check
        # https://github.com/THUDM/ChatGLM2-6B/issues/48
        # [64790, 64792] to be prepended
        if sequence_start:
            self.count = 0
        self.count += 1
        ret = f'[Round {self.count}]\n\n'
        ret += f'{self._user}{prompt}{self._eoh}'
        ret += f'{self._assistant}'
        return ret

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Convert a string or an OpenAI-style message list to a prompt.

        Only ``user`` and ``assistant`` messages are rendered; other roles are
        ignored. Rounds are numbered from 1 per user message.

        Args:
            messages (str | List[dict]): a raw prompt or a list of messages.
            sequence_start (bool): forwarded to ``get_prompt`` for strings.

        Returns:
            str: the concatenated prompt.
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        ret = ''
        count = 0
        last_index = len(messages) - 1
        for index, message in enumerate(messages):
            role = message['role']
            content = get_text(message['content'])
            if role == 'user':
                count += 1
                ret += f'[Round {count}]\n\n'
                ret += f'{self._user}{content}{self._eoh}'
                ret += f'{self._assistant}'
            elif role == 'assistant':
                ret += f'{content}'
                # A completed reply is terminated by eoa before the next round.
                # A trailing assistant message is a response prefix for the model
                # to continue, so it is left open.
                if index < last_index:
                    ret += self._eoa
        return ret

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
        if 'chatglm2' in path:
            return 'chatglm'


@MODELS.register_module(name=['mistral', 'mixtral'])
class MistralChat(BaseChatTemplate):
    """Chat template shared by the Mistral and Mixtral instruct models.

    References:
        `https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1`
        `https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1`
    """

    def __init__(self, user='[INST] ', eoh=' [/INST]', eoa='', **kwargs):
        super().__init__(eoa=eoa, eoh=eoh, user=user, **kwargs)

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Pick ``mistral`` or ``mixtral`` for instruct or llava checkpoints.

        Args:
            model_path (str): path or repo id of the model; compared
                case-insensitively.

        Returns:
            Optional[str]: the registered template name, or None when the
            path names neither family or is neither instruct nor llava.
        """
        lowered = model_path.lower()
        if not ('instruct' in lowered or 'llava' in lowered):
            return None
        # 'mistral' is tested before 'mixtral'.
        for family in ('mistral', 'mixtral'):
            if family in lowered:
                return family
        return None


@MODELS.register_module(name=['internvl-zh'])
class InternVLZH(BaseChatTemplate):
    """Chat template selected for InternVL-Chat v1.1 checkpoints."""

    def __init__(self, user=': ', eoh=' ', assistant=': ', eoa='', **kwargs):
        super().__init__(user=user, eoh=eoh, assistant=assistant, eoa=eoa, **kwargs)

    def get_prompt(self, prompt, sequence_start=True):
        """Build a single-turn prompt.

        In chat mode, the last character of the base prompt is dropped. With
        the default markers, that character is the trailing space of the
        assistant marker.
        """
        if self.capability == 'chat':
            return super().get_prompt(prompt, sequence_start)[:-1]
        return super().get_prompt(prompt, sequence_start)

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Build a prompt from a string or a message list.

        The last character of the base prompt (the trailing assistant-marker
        character) is dropped. The exception is when the base class returned an
        assistant-response prefix: the conversation ends with an assistant
        message and ``eoa + separator`` is non-empty. Trimming in that case
        would cut the user's own content.
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        prompt = super().messages2prompt(messages, sequence_start, **kwargs)
        ends_with_response_prefix = (messages and messages[-1]['role'] == 'assistant'
                                     and len(self.eoa + self.separator) > 0)
        if ends_with_response_prefix:
            return prompt
        return prompt[:-1]

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
        if 'internvl-chat' in path and 'v1-1' in path:
            return 'internvl-zh'


@MODELS.register_module(name=['deepseek-vl'])
class DeepseekVL(BaseChatTemplate):
    """Chat template of DeepSeek-VL chat models."""

    def __init__(
            self,
            meta_instruction="""You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.""",  # noqa: E501
            eosys='\n\n',
            user='User: ',
            eoh='\n\n',
            assistant='Assistant: ',
            eoa='<|end▁of▁sentence|>',
            **kwargs):
        super().__init__(meta_instruction=meta_instruction,
                         eosys=eosys,
                         user=user,
                         eoh=eoh,
                         assistant=assistant,
                         eoa=eoa,
                         **kwargs)

    def get_prompt(self, prompt, sequence_start=True):
        """Build a single-turn prompt.

        In chat mode, the last character is dropped. With the default markers,
        that character is the space at the end of 'Assistant: '.
        """
        if self.capability == 'chat':
            return super().get_prompt(prompt, sequence_start)[:-1]
        return super().get_prompt(prompt, sequence_start)

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Build a prompt from a string or a message list.

        The trailing assistant-marker character is dropped. The exception is
        when the base class returned an assistant-response prefix: the
        conversation ends with an assistant message and ``eoa + separator`` is
        non-empty, which is always true with the default ``eoa``. Trimming
        there would cut the last character of the user-supplied content.
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        prompt = super().messages2prompt(messages, sequence_start, **kwargs)
        ends_with_response_prefix = (messages and messages[-1]['role'] == 'assistant'
                                     and len(self.eoa + self.separator) > 0)
        if ends_with_response_prefix:
            return prompt
        return prompt[:-1]

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
        if 'deepseek-vl' in path and 'chat' in path:
            return 'deepseek-vl'


@MODELS.register_module(name=['deepseek-vl2'])
class DeepseekVL2(BaseChatTemplate):
    """Chat template for DeepSeek-VL2 models.

    Uses ``<|User|>: `` / ``<|Assistant|>: `` role markers with no default
    system prompt. The final character of every rendered prompt is dropped.
    """

    def __init__(self,
                 meta_instruction='',
                 eosys='',
                 user='<|User|>: ',
                 eoh='\n\n',
                 assistant='<|Assistant|>: ',
                 eoa='<|end▁of▁sentence|>',
                 **kwargs):
        template_fields = dict(meta_instruction=meta_instruction,
                               eosys=eosys,
                               user=user,
                               eoh=eoh,
                               assistant=assistant,
                               eoa=eoa)
        super().__init__(**template_fields, **kwargs)

    def get_prompt(self, prompt, sequence_start=True):
        """Render one user prompt, minus its last character."""
        rendered = super().get_prompt(prompt, sequence_start)
        return rendered[:-1]

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Render a message list (strings go through ``get_prompt``), minus
        its last character."""
        if not isinstance(messages, str):
            rendered = super().messages2prompt(messages, sequence_start, **kwargs)
            return rendered[:-1]
        return self.get_prompt(messages, sequence_start)

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.

        Returns:
            Optional[str]: ``'deepseek-vl2'`` if the lower-cased path contains
            ``deepseek-vl2``, otherwise ``None``.
        """
        return 'deepseek-vl2' if 'deepseek-vl2' in model_path.lower() else None


@MODELS.register_module(name=['llava-chatml'])
class ChatmlDirect(BaseChatTemplate):
    """ChatML-style template (``<|im_start|>`` / ``<|im_end|>`` markers)
    used for LLaVA v1.6-34B checkpoints."""

    def __init__(self,
                 system='<|im_start|>system\n',
                 meta_instruction='Answer the questions.',
                 eosys='<|im_end|>',
                 user='<|im_start|>user\n',
                 eoh='<|im_end|>',
                 assistant='<|im_start|>assistant\n',
                 eoa='<|im_end|>',
                 separator='',
                 **kwargs):
        role_fields = dict(meta_instruction=meta_instruction,
                           eosys=eosys,
                           user=user,
                           eoh=eoh,
                           assistant=assistant,
                           eoa=eoa,
                           separator=separator)
        # `system` is forwarded positionally, exactly as before.
        super().__init__(system, **role_fields, **kwargs)

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.

        Returns:
            Optional[str]: ``'llava-chatml'`` when the lower-cased path
            contains both ``llava`` and ``v1.6-34b``, otherwise ``None``.
        """
        lowered = model_path.lower()
        if not all(token in lowered for token in ('llava', 'v1.6-34b')):
            return None
        return 'llava-chatml'


@MODELS.register_module(name=['hf'])
class HFChatTemplate(BaseChatTemplate):
    """Chat template for HuggingFace models with `apply_chat_template` method.

    Prompts are rendered by the tokenizer's own chat template. ``match``
    succeeds for any model whose tokenizer can be loaded and probed, so this
    class MUST be at the end of the @MODELS registry. Otherwise it would
    shadow the dedicated templates registered before it.
    """

    def __init__(self, model_path: str = '', **kwargs):
        """Load the tokenizer and probe its chat template.

        Args:
            model_path (str): path or hub id passed to
                ``AutoTokenizer.from_pretrained`` and ``get_model_arch``.
            **kwargs: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if loading the tokenizer or probing its chat template
                fails for any reason (the original error is chained).
        """
        self.model_path = model_path
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            # Verify if the model can perform apply_chat_template with different roles.
            # NOTE: _user_instruction must run first; _assistant_instruction reads self.user_end.
            self.user_start, self.user_end, _, _ = self._user_instruction()
            self.assistant_start, self.assistant_end, _, _ = self._assistant_instruction()
            _, _, self.sentinel_system_messages, self.sentinel_system_prompt = self._system_instruction()
            self.stop_words = []
            if hasattr(self.tokenizer, 'eos_token') and self.tokenizer.eos_token is not None:
                self.stop_words.append(self.tokenizer.eos_token)
            if hasattr(self.tokenizer, 'eot_token') and self.tokenizer.eot_token is not None:
                self.stop_words.append(self.tokenizer.eot_token)
            arch, _ = get_model_arch(model_path)
            self.is_gpt_oss = arch == 'GptOssForCausalLM'
            if self.is_gpt_oss:
                self.stop_words.append('<|call|>')
        except Exception as e:
            raise ValueError(f'Try apply_chat_template failed: {e}') from e

    def get_prompt(self, prompt, sequence_start=True, **kwargs):
        """Render a single user prompt via ``messages2prompt``."""
        messages = [{'role': 'user', 'content': prompt}]
        return self.messages2prompt(messages, sequence_start, **kwargs)

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Render ``messages`` with the tokenizer's chat template.

        Args:
            messages (str | list[dict]): a plain user string, or a list of
                dicts that each carry ``role`` and ``content`` keys.
            sequence_start (bool): when False (interactive continuation), a
                sentinel system prefix is rendered and stripped so the
                template's default system prompt is not re-emitted.
            **kwargs: forwarded to ``apply_chat_template``. ``enable_thinking``
                and ``reasoning_effort`` are dropped when they are None.

        Returns:
            str: the rendered prompt. If the last message is from the
            assistant, its closing marker is removed so the model continues
            that response.
        """
        if isinstance(messages, str):
            messages = [{'role': 'user', 'content': messages}]
        assert all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages), \
            'Each message should be a dict with "role" and "content" keys.'

        if 'enable_thinking' in kwargs and kwargs['enable_thinking'] is None:
            # Workaround for internlm/Intern-S1: when enable_thinking=None is passed to apply_chat_template,
            # the thinking tag (e.g. `<think>`) is not generated.
            kwargs.pop('enable_thinking')
        if 'reasoning_effort' in kwargs and kwargs['reasoning_effort'] is None:
            kwargs.pop('reasoning_effort')
        add_generation_prompt = messages[-1]['role'] != 'assistant'
        if sequence_start:
            prompt = self.tokenizer.apply_chat_template(messages,
                                                        tokenize=False,
                                                        add_generation_prompt=add_generation_prompt,
                                                        **kwargs)
        else:
            # Use a sentinel position to avoid the influence of default system role in the tokenizer's chat template
            # in interactive chat mode
            messages = self.sentinel_system_messages + messages if self.sentinel_system_messages else messages
            prompt = self.tokenizer.apply_chat_template(messages,
                                                        tokenize=False,
                                                        add_generation_prompt=add_generation_prompt,
                                                        **kwargs)
            # Remove the sentinel part.
            prompt = prompt[len(self.sentinel_system_prompt):] if len(self.sentinel_system_prompt) > 0 else prompt
        if (messages[-1]['role'] == 'assistant' and len(self.assistant_end) > 0
                and prompt.endswith(self.assistant_end)):
            # Strip the closing marker so the model completes the partial response. Only strip it when the
            # rendered prompt really ends with it; otherwise we would cut off genuine content.
            prompt = prompt[:-len(self.assistant_end)]
        if self.is_gpt_oss and not kwargs.get('tools'):
            # for gpt-oss model, remove this seems more conducive to instruction following.
            prompt = prompt.replace('commentary, ', '', 1)
        return prompt

    def _user_instruction(self):
        """Extract user message template markers from the tokenizer's chat
        template.

        Returns:
            tuple: ``(user_start, user_end, messages, prompt)``. ``user_start``
            is everything rendered before the user content (including any
            default system prompt). ``user_end`` is everything after it.
        """

        messages = [{'role': 'user', 'content': 'sentinel'}]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        user_pos = prompt.find('sentinel')
        user_start = prompt[:user_pos]
        user_end = prompt[user_pos + len('sentinel'):]
        return user_start, user_end, messages, prompt

    def _assistant_instruction(self):
        """Extract assistant message template markers from the tokenizer's chat
        template.

        Returns:
            tuple: ``(assistant_start, assistant_end, messages, prompt)``.
        """

        # Some models, such as google/gemma-2-2b-it, require conversation roles to strictly
        # alternate between 'user' and 'assistant' (e.g., user/assistant/user/assistant...).
        # Consequently, we construct test messages containing both user and assistant roles
        # with special tokens, and parse the assistant tag according to user markers and
        # special tokens.
        messages = [{'role': 'user', 'content': 'placeholder'}, {'role': 'assistant', 'content': 'sentinel'}]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        assistant_pos = prompt.find('sentinel')
        # Search for the user end marker only after the user content. A default system prompt
        # injected by the template may end with the same marker and would otherwise match first.
        placeholder_end = prompt.find('placeholder') + len('placeholder')
        user_end_pos = prompt.find(self.user_end, placeholder_end)
        if user_end_pos == -1:
            # The marker is not rendered when an assistant turn follows; start right after the user content.
            start = placeholder_end
        else:
            start = user_end_pos + len(self.user_end)
        assistant_start = prompt[start:assistant_pos]
        assistant_end = prompt[assistant_pos + len('sentinel'):]
        return assistant_start, assistant_end, messages, prompt

    def _system_instruction(self):
        """Extract system message template markers from the tokenizer's chat
        template.

        Returns:
            tuple: ``(system_start, system_end, messages, prompt)``. If the
            template rejects or does not render a system message, returns
            ``(None, None, [], bos_token or '')``.
        """
        messages = [{'role': 'system', 'content': 'sentinel'}]
        try:
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
            system_pos = prompt.find('sentinel')
            if system_pos == -1:
                return None, None, [], self.tokenizer.bos_token or ''
            system_start = prompt[:system_pos]
            system_end = prompt[system_pos + len('sentinel'):]
            return system_start, system_end, messages, prompt
        except Exception:
            # Some models, such as google/gemma-2-2b-it, do not support a system role in the message structure.
            return None, None, [], self.tokenizer.bos_token or ''

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Report whether an HF chat template can be built for ``model_path``.

        Unlike the other templates, this returns a bool (True on success,
        False if construction raises). Callers only test truthiness.
        """
        try:
            cls(model_path)
        except Exception:
            return False
        return True


def get_chat_template(model_path: str, config: Optional[ChatTemplateConfig] = None) -> BaseChatTemplate:
    """Resolve the chat template to use for ``model_path``.

    An explicit ``config`` always wins. Otherwise the templates in
    ``MODELS.module_dict`` are probed in iteration order. The first one whose
    ``match`` returns a truthy value is used, and ``'base'`` is the fallback
    when none match.

    Args:
        model_path (str): the model path.
        config (Optional[ChatTemplateConfig]): the chat template config.
    Returns:
        BaseChatTemplate: the chat template.
    """
    if config is None:
        matched_names = (name for name, template_cls in MODELS.module_dict.items() if template_cls.match(model_path))
        chat_template_name = next(matched_names, 'base')
        config = ChatTemplateConfig(chat_template_name, model_path=model_path)
    return config.chat_template


================================================
FILE: lmdeploy/monitoring/docker-compose.yaml
================================================
# Copied from https://github.com/sgl-project/sglang/blob/main/examples/monitoring/docker-compose.yaml
version: '3'
services:
  prometheus:
    image: prom/prometheus:latest
    container_name: prometheus
    network_mode: host
    volumes:
      - ./prometheus.yaml:/etc/prometheus/prometheus.yml
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.path=/prometheus'

  grafana:
    image: grafana/grafana:latest
    container_name: grafana
    network_mode: host
    volumes:
      - ./grafana/datasources:/etc/grafana/provisioning/datasources
      - ./grafana/dashboards/config:/etc/grafana/provisioning/dashboards
      - ./grafana/dashboards/json:/var/lib/grafana/dashboards
    environment:
      - GF_AUTH_ANONYMOUS_ENABLED=true
      - GF_AUTH_ANONYMOUS_ORG_ROLE=Viewer
      - GF_AUTH_BASIC_ENABLED=false
      - GF_USERS_ALLOW_SIGN_UP=false
      - GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH=/var/lib/grafana/dashboards/lmdeploy-dashboard.json
    depends_on:
      - prometheus


================================================
FILE: lmdeploy/monitoring/grafana/dashboards/config/dashboard.yaml
================================================
apiVersion: 1
providers:
  - name: 'LMDeploy'
    orgId: 1
    folder: 'LMDeploy Monitoring'
    type: file
    disableDeletion: false
    updateIntervalSeconds: 10
    allowUiUpdates: false
    options:
      path: /var/lib/grafana/dashboards


================================================
FILE: lmdeploy/monitoring/grafana/dashboards/json/lmdeploy-dashboard.json
================================================
{
  "_comment": "JSON file adapted from https://github.com/vllm-project/vllm/blob/main/examples/online_serving/prometheus_grafana/grafana.json",
  "annotations": {
    "list": [
      {
        "builtIn": 1,
        "datasource": {
          "type": "grafana",
          "uid": "-- Grafana --"
        },
        "enable": true,
        "hide": true,
        "iconColor": "rgba(0, 211, 255, 1)",
        "name": "Annotations & Alerts",
        "target": {
          "limit": 100,
          "matchAny": false,
          "tags": [],
          "type": "dashboard"
        },
        "type": "dashboard"
      }
    ]
  },
  "description": "Monitoring LMDeploy Inference Server",
  "editable": true,
  "fiscalYearStartMonth": 0,
  "graphTooltip": 0,
  "id": 1,
  "links": [],
  "liveNow": false,
  "panels": [
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "End to end request latency measured in seconds.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "s"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 0,
        "y": 0
      },
      "id": 9,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.99, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P99",
          "range": true,
          "refId": "A",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.95, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P95",
          "range": true,
          "refId": "B",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.9, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P90",
          "range": true,
          "refId": "C",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.5, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P50",
          "range": true,
          "refId": "D",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "editorMode": "code",
          "expr": "rate(lmdeploy:e2e_request_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(lmdeploy:e2e_request_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])",
          "hide": false,
          "instant": false,
          "legendFormat": "Average",
          "range": true,
          "refId": "E"
        }
      ],
      "title": "E2E Request Latency",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "Number of tokens processed per second",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          }
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 12,
        "y": 0
      },
      "id": 8,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "rate(lmdeploy:prompt_tokens_total{model_name=\"$model_name\"}[$__rate_interval])",
          "fullMetaSearch": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "Prompt Tokens/Sec",
          "range": true,
          "refId": "A",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "rate(lmdeploy:generation_tokens_total{model_name=\"$model_name\"}[$__rate_interval])",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "Generation Tokens/Sec",
          "range": true,
          "refId": "B",
          "useBackend": false
        }
      ],
      "title": "Token Throughput",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "TPOT (time per output token) latency in seconds.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "s"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 0,
        "y": 8
      },
      "id": 10,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.99, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P99",
          "range": true,
          "refId": "A",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.95, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P95",
          "range": true,
          "refId": "B",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.9, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P90",
          "range": true,
          "refId": "C",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.5, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P50",
          "range": true,
          "refId": "D",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "editorMode": "code",
          "expr": "rate(lmdeploy:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(lmdeploy:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])",
          "hide": false,
          "instant": false,
          "legendFormat": "Mean",
          "range": true,
          "refId": "E"
        }
      ],
      "title": "Time Per Output Token Latency",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "Inter-token latency in seconds.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "s"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 0,
        "y": 8
      },
      "id": 10,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.99, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P99",
          "range": true,
          "refId": "A",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.95, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P95",
          "range": true,
          "refId": "B",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.9, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P90",
          "range": true,
          "refId": "C",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.5, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P50",
          "range": true,
          "refId": "D",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "editorMode": "code",
          "expr": "rate(lmdeploy:iter_token_latency_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(lmdeploy:iter_token_latency_count{model_name=\"$model_name\"}[$__rate_interval])",
          "hide": false,
          "instant": false,
          "legendFormat": "Mean",
          "range": true,
          "refId": "E"
        }
      ],
      "title": "Inter-Token Latency",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "Number of requests in RUNNING and WAITING states.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "none"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 12,
        "y": 8
      },
      "id": 3,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "lmdeploy:num_requests_running{model_name=\"$model_name\"}",
          "fullMetaSearch": false,
          "includeNullMetadata": true,
          "instant": false,
          "legendFormat": "Num Running",
          "range": true,
          "refId": "C",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "lmdeploy:num_requests_waiting{model_name=\"$model_name\"}",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": true,
          "instant": false,
          "legendFormat": "Num Waiting",
          "range": true,
          "refId": "D",
          "useBackend": false
        }
      ],
      "title": "Scheduler Stats",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "P50, P90, P95, P99, and average TTFT latency in seconds.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "s"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 0,
        "y": 16
      },
      "id": 5,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.99, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P99",
          "range": true,
          "refId": "A",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.95, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P95",
          "range": true,
          "refId": "B",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.9, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P90",
          "range": true,
          "refId": "C",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "histogram_quantile(0.5, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))",
          "fullMetaSearch": false,
          "hide": false,
          "includeNullMetadata": false,
          "instant": false,
          "legendFormat": "P50",
          "range": true,
          "refId": "D",
          "useBackend": false
        },
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "editorMode": "code",
          "expr": "rate(lmdeploy:time_to_first_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(lmdeploy:time_to_first_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])",
          "hide": false,
          "instant": false,
          "legendFormat": "Average",
          "range": true,
          "refId": "E"
        }
      ],
      "title": "Time To First Token Latency",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "Percentage of cache blocks used by LMDeploy.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green",
                "value": null
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          },
          "unit": "percentunit"
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 12,
        "y": 16
      },
      "id": 4,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "editorMode": "code",
          "expr": "lmdeploy:gpu_cache_usage_perc{model_name=\"$model_name\"}",
          "instant": false,
          "legendFormat": "GPU Cache Usage",
          "range": true,
          "refId": "A"
        }
      ],
      "title": "Cache Utilization",
      "type": "timeseries"
    },
    {
      "datasource": {
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.",
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green"
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          }
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 0,
        "y": 32
      },
      "id": 11,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "builder",
          "expr": "sum by(finished_reason) (increase(lmdeploy:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))",
          "fullMetaSearch": false,
          "includeNullMetadata": true,
          "instant": false,
          "interval": "",
          "legendFormat": "__auto",
          "range": true,
          "refId": "A",
          "useBackend": false
        }
      ],
      "title": "Finish Reason",
      "type": "timeseries"
    },
    {
      "datasource": {
        "default": false,
        "type": "prometheus",
        "uid": "${DS_PROMETHEUS}"
      },
      "fieldConfig": {
        "defaults": {
          "color": {
            "mode": "palette-classic"
          },
          "custom": {
            "axisBorderShow": false,
            "axisCenteredZero": false,
            "axisColorMode": "text",
            "axisLabel": "seconds",
            "axisPlacement": "auto",
            "barAlignment": 0,
            "barWidthFactor": 0.6,
            "drawStyle": "line",
            "fillOpacity": 0,
            "gradientMode": "none",
            "hideFrom": {
              "legend": false,
              "tooltip": false,
              "viz": false
            },
            "insertNulls": false,
            "lineInterpolation": "linear",
            "lineWidth": 1,
            "pointSize": 5,
            "scaleDistribution": {
              "type": "linear"
            },
            "showPoints": "auto",
            "spanNulls": false,
            "stacking": {
              "group": "A",
              "mode": "none"
            },
            "thresholdsStyle": {
              "mode": "off"
            }
          },
          "mappings": [],
          "thresholds": {
            "mode": "absolute",
            "steps": [
              {
                "color": "green"
              },
              {
                "color": "red",
                "value": 80
              }
            ]
          }
        },
        "overrides": []
      },
      "gridPos": {
        "h": 8,
        "w": 12,
        "x": 12,
        "y": 32
      },
      "id": 14,
      "options": {
        "legend": {
          "calcs": [],
          "displayMode": "list",
          "placement": "bottom",
          "showLegend": true
        },
        "tooltip": {
          "mode": "single",
          "sort": "none"
        }
      },
      "targets": [
        {
          "datasource": {
            "type": "prometheus",
            "uid": "${DS_PROMETHEUS}"
          },
          "disableTextWrap": false,
          "editorMode": "code",
          "expr": "rate(lmdeploy:request_queue_time_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])",
          "fullMetaSearch": false,
          "includeNullMetadata": true,
          "instant": false,
          "legendFormat": "__auto",
          "range": true,
          "refId": "A",
          "useBackend": false
        }
      ],
      "title": "Queue Time",
      "type": "timeseries"
    }
  ],
  "refresh": "",
  "schemaVersion": 39,
  "tags": [],
  "templating": {
    "list": [
      {
        "current": {
          "selected": false,
          "text": "prometheus",
          "value": "edx8memhpd9tsa"
        },
        "hide": 0,
        "includeAll": false,
        "label": "datasource",
        "multi": false,
        "name": "DS_PROMETHEUS",
        "options": [],
        "query": "prometheus",
        "queryValue": "",
        "refresh": 1,
        "regex": "",
        "skipUrlSync": false,
        "type": "datasource"
      },
      {
        "current": {
          "selected": false,
          "text": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct",
          "value": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct"
        },
        "datasource": {
          "type": "prometheus",
          "uid": "${DS_PROMETHEUS}"
        },
        "definition": "label_values(model_name)",
        "hide": 0,
        "includeAll": false,
        "label": "model_name",
        "multi": false,
        "name": "model_name",
        "options": [],
        "query": {
          "query": "label_values(model_name)",
          "refId": "StandardVariableQuery"
        },
        "refresh": 1,
        "regex": "",
        "skipUrlSync": false,
        "sort": 0,
        "type": "query"
      }
    ]
  },
  "time": {
    "from": "now-5m",
    "to": "now"
  },
  "timepicker": {},
  "timezone": "",
  "title": "LMDeploy",
  "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b",
  "version": 8,
  "weekStart": ""
}


================================================
FILE: lmdeploy/monitoring/grafana/datasources/datasource.yaml
================================================
apiVersion: 1
datasources:
  - name: Prometheus
    type: prometheus
    access: proxy
    url: http://localhost:9090
    isDefault: true
    editable: false


================================================
FILE: lmdeploy/monitoring/prometheus.yaml
================================================
# prometheus.yaml
global:
  scrape_interval: 5s
  evaluation_interval: 30s

scrape_configs:
  - job_name: lmdeploy
    static_configs:
      - targets:
          - '127.0.0.1:23333'


================================================
FILE: lmdeploy/pipeline.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import atexit
import concurrent.futures
import os
from contextlib import closing
from functools import partial
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple

import torch
import tqdm
from typing_extensions import deprecated

from .archs import autoget_backend_config, get_task
from .messages import GenerationConfig, PytorchEngineConfig, Response, SpeculativeConfig, TurbomindEngineConfig
from .model import ChatTemplateConfig
from .serve.processors import MultimodalProcessor
from .utils import get_logger, get_model

if TYPE_CHECKING:
    from PIL.Image import Image

    from .serve.managers import Session

# Package-level logger; Pipeline.__init__ sets its level from the ``log_level`` argument.
logger = get_logger('lmdeploy')


class Pipeline:
    """Pipeline - User-facing API layer for inference."""

    def __init__(self,
                 model_path: str,
                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,
                 chat_template_config: ChatTemplateConfig | None = None,
                 log_level: str = 'WARNING',
                 max_log_len: int | None = None,
                 speculative_config: SpeculativeConfig | None = None,
                 **kwargs):
        """Build a blocking pipeline on top of an asynchronous inference engine.

        Model identifiers that are not existing local paths are resolved with
        ``get_model`` before the engine is created. The engine is then started on
        the event loop of a daemon ``_EventLoopThread`` owned by this pipeline.

        Args:
            model_path: Local path or remote identifier of the model.
            backend_config: Backend configuration. Its ``download_dir`` and
                ``revision`` are used when the model has to be downloaded.
            chat_template_config: Chat template configuration.
            log_level: Level applied to the ``lmdeploy`` logger; also the default
                value of the ``TM_LOG_LEVEL`` environment variable.
            max_log_len: Max number of prompt characters or prompt tokens being printed in log.
            speculative_config: Speculative decoding configuration. Its ``model`` is
                downloaded (into ``download_dir``, without a revision) when it is
                not a local path.
            **kwargs: Forwarded to the engine class selected by ``get_task``.
        """
        os.environ.setdefault('TM_LOG_LEVEL', log_level)
        logger.setLevel(log_level)

        # Resolve the main model to a local directory when it is not one already.
        if not os.path.exists(model_path):
            if backend_config:
                download_dir, revision = backend_config.download_dir, backend_config.revision
            else:
                download_dir, revision = None, None
            model_path = get_model(model_path, download_dir, revision)

        # Same for the speculative (draft) model; no revision is forwarded here.
        draft_model = speculative_config.model if speculative_config else None
        if draft_model and not os.path.exists(draft_model):
            download_dir = backend_config.download_dir if backend_config else None
            speculative_config.model = get_model(draft_model, download_dir)

        # Pick the backend and the engine class matching this model, then build it.
        backend, backend_config = autoget_backend_config(model_path, backend_config)
        _task, engine_cls = get_task(backend, model_path)
        self.async_engine = engine_cls(model_path,
                                       backend=backend,
                                       backend_config=backend_config,
                                       chat_template_config=chat_template_config,
                                       max_log_len=max_log_len,
                                       speculative_config=speculative_config,
                                       **kwargs)

        # The engine loop runs on the event loop of this daemon thread.
        self.internal_thread = _EventLoopThread(daemon=True)
        # Concurrency limiter, created lazily by _get_limiter().
        self.limiter: asyncio.Semaphore | None = None
        self.session_mgr = self.async_engine.session_mgr
        self.backend_config = self.async_engine.backend_config
        self.async_engine.start_loop(self.internal_thread.loop, use_async_api=False)

    def infer(self,
              prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple],
              gen_config: GenerationConfig | List[GenerationConfig] | None = None,
              do_preprocess: bool = True,
              adapter_name: str | None = None,
              use_tqdm: bool = False,
              **kwargs):
        """Run blocking, non-streaming inference on one prompt or a batch of prompts.

        Args:
            prompts: Prompts to inference. It can be a single prompt, a list of prompts, a list of tuples, or a tuple.
                Tuple can be (prompt, image or [images]) or (image or [images], prompt).
            gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s).
                A list must contain exactly one config per prompt.
            do_preprocess(bool): Whether to pre-process messages.
            adapter_name(str | None): Adapter name.
            use_tqdm(bool): Whether to use progress bar.
            **kwargs(dict): Additional keyword arguments, attached to every request.

        Returns:
            For a single prompt, its merged output; otherwise a list holding one
            merged output per request.
        """
        is_single = self._is_single(prompts)
        # format prompts to openai message format, which is a list of dicts
        prompts = MultimodalProcessor.format_prompts(prompts)
        pbar = tqdm.tqdm(total=len(prompts)) if use_tqdm else None
        outputs = []
        try:
            requests = self._request_generator(prompts,
                                               gen_config=gen_config,
                                               do_preprocess=do_preprocess,
                                               adapter_name=adapter_name,
                                               stream_response=False,
                                               **kwargs)
            for g in self._infer(requests, multiplex=False, pbar=pbar):
                res = None
                for out in g:
                    # Fold the partial outputs of one request into a single result.
                    res = res.extend(out) if res else out
                outputs.append(res)
        finally:
            # Compare with None: a tqdm bar whose total is 0 is falsy, so a plain
            # truthiness check would leave the bar of an empty batch unclosed.
            if pbar is not None:
                pbar.close()
        if is_single:
            return outputs[0]
        return outputs

    @deprecated('This method is deprecated. Please use "Pipeline.infer" instead.')
    def batch_infer(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def stream_infer(self,
                     prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple],
                     sessions: 'Session' | List['Session'] | None = None,
                     gen_config: GenerationConfig | List[GenerationConfig] | None = None,
                     do_preprocess: bool = True,
                     adapter_name: str | None = None,
                     stream_response: bool = True,
                     **kwargs):
        """Run inference and expose the outputs as a generator.

        Args:
            prompts(List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple]): Prompts to inference.
                Accepts a single prompt, a list of prompts, a tuple, or a list of tuples, where a tuple is
                (prompt, image or [images]) or (image or [images], prompt).
            sessions(Session | List[Session] | None): One session per prompt. New sessions are taken from
                the session manager when None.
            gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s).
            do_preprocess(bool): Whether to pre-process messages.
            adapter_name(str | None): Adapter name.
            stream_response(bool): When True the outputs are streamed; when False each request runs to
                completion and only its final output is produced. This switch backs the streaming and
                non-streaming modes of Pipeline.chat.
            **kwargs(dict): Additional keyword arguments attached to every request.

        Returns:
            Generator: yields the outputs (instances of `Response`) of the inference.
        """
        messages = MultimodalProcessor.format_prompts(prompts)
        request_iter = self._request_generator(messages,
                                               sessions=sessions,
                                               gen_config=gen_config,
                                               do_preprocess=do_preprocess,
                                               adapter_name=adapter_name,
                                               stream_response=stream_response,
                                               **kwargs)
        # Multiplexed mode: outputs of all requests share one stream.
        return self._infer(request_iter, multiplex=True)

    def close(self):
        """Close the pipeline."""
        self.internal_thread.close()
        self.async_engine.close()

    def chat(self,
             prompt: str | Tuple[str, 'Image' | List['Image']],
             session=None,
             gen_config: GenerationConfig | None = None,
             stream_response=False,
             adapter_name=None,
             **kwargs) -> 'Session' | Iterator:
        """Run one turn of a multi-turn conversation.

        Args:
            prompt (str | tuple): the user prompt, optionally paired with image(s).
            session (Session): the chat session to continue; a new one is taken
                from the session manager when None.
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            stream_response (bool): whether to stream the response.
            adapter_name (str): adapter name.
            **kwargs (dict): additional keyword arguments forwarded to stream_infer.

        Returns:
            A generator of outputs when ``stream_response`` is True; otherwise the
            session, after the turn has run to completion. On successful completion
            ``session.response``, ``session.step`` and ``session.history`` are
            updated; if the turn fails or its generator is closed early, the session
            is aborted.
        """
        if session is None:
            session = self.session_mgr.get()
        session.update(prompt=prompt, response=None)

        messages = MultimodalProcessor.format_prompts(prompt)

        outputs = self.stream_infer(prompts=messages,
                                    sessions=session,
                                    gen_config=gen_config,
                                    stream_response=stream_response,
                                    adapter_name=adapter_name,
                                    multiplex=True,
                                    sequence_start=session.step == 0,
                                    sequence_end=False,
                                    step=session.step,
                                    **kwargs)

        def _turn():
            final = None
            try:
                for chunk in outputs:
                    final = chunk if not final else final.extend(chunk)
                    yield chunk
            except BaseException:
                # BaseException on purpose: it also covers GeneratorExit, so a
                # consumer that stops iterating early aborts the session.
                self._run(coro=session.async_abort())
                raise
            else:
                session.response = final
                session.step += final.generate_token_len + final.input_token_len
                session.history.append((session.prompt, final.text))

        if stream_response:
            return _turn()

        # Non-streaming: drive the turn to completion before returning.
        with closing(_turn()) as turn:
            for _ in turn:
                pass
        session.generator = None
        return session

    def session(self) -> 'Session':
        """Create a new session."""
        return self.session_mgr.get()

    def get_reward_score(self, input_ids: List) -> List[float]:
        """Get reward score.

        Args:
            input_ids(List): a list of token_id or a list of token_id list or token_id tensor
        Return:
            reward score in a list. If the input_ids is a list of token_id, the return value
            is still a list with length 1.
        """
        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
        arch = self.async_engine.arch
        if arch not in supported_reward_models:
            raise ValueError(f'{arch} is not in reward model list: {supported_reward_models}')
        assert isinstance(input_ids, List)
        assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
        # Make input_ids a list of token_id list
        input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids
        logits = self._run(coro=self.async_engine.async_get_logits(input_ids=input_ids)).result()
        logits = [x.squeeze() for x in logits]
        scores = [x[-1].cpu().item() for x in logits]
        return scores

    def get_ppl(self, input_ids: List[int] | List[List[int]]) -> List[float]:
        """Compute perplexity scores for one or more token sequences.

        Sequences may differ in length. They are sorted longest-first and grouped
        by ``_batch_iterator`` under a token budget (``max_input_len``) chosen so
        that fp32 logits take at most about 2 GiB. When the iterator yields
        ``start == end``, that single sequence is scored with
        ``_get_long_text_ppl``; otherwise the slice ``[start, end)`` is scored as a
        batch with ``_get_ppl``. Every acquired session is released, even if
        scoring raises.

        Args:
            input_ids (List[int] | List[List[int]]): one sequence of token ids, or a
                batch of sequences. Every sequence must contain at least 2 tokens.

        Returns:
            List[float]: one perplexity score per sequence, in the caller's input
            order. An empty batch yields an empty list.
        """
        assert isinstance(input_ids, List)
        if not input_ids:
            # Nothing to score; also avoids indexing input_ids[0] below.
            return []
        if isinstance(input_ids[0], int):
            input_ids = [input_ids]
        assert all(len(ids) > 1 for ids in input_ids)

        # TODO: a better way to determine `max_input_len`, at most allocate
        # 2G mem for logits with shape [bs, max_input_len, vocab_size]
        vocab_size = self.async_engine.hf_cfg.vocab_size
        max_input_len = 2 * 1024**3 // (vocab_size * 4)

        # Longest sequences first; `indices[k]` is the input position of the k-th sorted sequence.
        sorted_index_values = sorted(enumerate(len(ids) for ids in input_ids), key=lambda x: x[1], reverse=True)
        sizes = [size for _, size in sorted_index_values]
        indices = [index for index, _ in sorted_index_values]
        logger.info(f'sorted sizes: {sizes}')
        logger.info(f'sorted indices: {indices}')

        result = []
        for (start, end) in self._batch_iterator(sizes, max_input_len):
            logger.info(f'start: {start}, end: {end}')
            if start == end:
                _input_ids = input_ids[indices[start]]
                session = self.session_mgr.get()
                try:
                    res = self._get_long_text_ppl(session, input_ids=_input_ids, max_input_len=max_input_len)
                finally:
                    # Release the session even when scoring fails.
                    self.session_mgr.remove(session)
                result.append(res)
            else:
                _input_ids = [input_ids[indices[i]] for i in range(start, end)]
                sessions = []
                try:
                    # Acquire inside the try so a failure part-way still releases
                    # the sessions obtained so far.
                    for _ in range(start, end):
                        sessions.append(self.session_mgr.get())
                    res = self._get_ppl(
                        sessions=sessions,
                        input_ids=_input_ids,
                        max_input_len=max_input_len,
                    )
                finally:
                    for session in sessions:
                        self.session_mgr.remove(session)
                result.extend(res)

        # Scatter the scores back into the caller's original order.
        output = [None] * len(result)
        for sorted_pos, original_index in enumerate(indices):
            output[original_index] = result[sorted_pos]
        return output

    def __call__(self,
                 prompts: List[str] | str | List[Dict] | List[List[Dict]],
                 gen_config: GenerationConfig | List[GenerationConfig] | None = None,
                 **kwargs):
        return self.infer(prompts, gen_config=gen_config, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @deprecated('This method is deprecated. Please use "AsyncEngine.generate" instead.')
    async def generate(self, *args, **kwargs):
        """Generate responses as an async generator.

        This method delegates to async_engine.generate and forwards all yielded values.
        """
        async for item in self.async_engine.generate(*args, **kwargs):
            yield item

    @staticmethod
    def _is_single(prompts):
        """Check if prompts is a single prompt."""
        return (isinstance(prompts, str) or (isinstance(prompts, tuple) and len(prompts) == 2)
                or (isinstance(prompts, list) and len(prompts) > 0 and isinstance(prompts[0], Dict)))

    def _request_generator(self,
                           prompts: List[str] | str | List[Dict] | List[List[Dict]],
                           sessions: List['Session'] | 'Session' | None = None,
                           gen_config: GenerationConfig | List[GenerationConfig] | None = None,
                           **kwargs):
        """Yield one ``AsyncEngine.generate`` keyword-argument dict per prompt.

        Args:
            prompts: a single prompt (see ``_is_single``) or a list of prompts.
            sessions: one session per prompt, a single session, or ``None`` to
                take a fresh session from ``self.session_mgr`` for each prompt.
            gen_config: one config per prompt, a single config shared by all
                prompts, or ``None`` for a default ``GenerationConfig()``
                (a single instance shared by all prompts).
            **kwargs: extra keyword arguments copied into every request.

        Raises:
            ValueError: if the number of sessions or generation configs does
                not match the number of prompts. Because this is a generator,
                these checks run on the first iteration, not at call time.
        """
        prompt_list = [prompts] if self._is_single(prompts) else prompts
        count = len(prompt_list)

        if sessions is None:
            session_list = [self.session_mgr.get() for _ in prompt_list]
        elif isinstance(sessions, list):
            session_list = sessions
        else:
            session_list = [sessions]
        if count != len(session_list):
            raise ValueError(f'prompts and sessions should have the same length. '
                             f'Got {len(prompt_list)} prompts and {len(session_list)} sessions')

        if gen_config is None:
            config_list = [GenerationConfig()] * count
        elif isinstance(gen_config, list):
            config_list = gen_config
        else:
            config_list = [gen_config] * count
        if count != len(config_list):
            raise ValueError(f'input gen_config length differs from the length of prompts. '
                             f'Got {len(prompt_list)} prompts and {len(config_list)} gen_configs')

        for prompt, cfg, session in zip(prompt_list, config_list, session_list):
            # The session is passed through ``session_id`` for backward
            # compatibility: AsyncEngine.generate declares ``session_id`` in its
            # argument list. This will be removed in the future.
            yield dict(session_id=session, messages=prompt, gen_config=cfg, **kwargs)

    def _get_limiter(self):
        if not self.limiter:
            self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size)
        return self.limiter

    def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]:
        """Run ``requests`` on an event loop and stream the responses back.

        Each request dict is passed as keyword arguments to
        ``self.async_engine.generate``. Concurrency is bounded by the
        semaphore returned from ``self._get_limiter()``.

        Args:
            requests: iterable of keyword-argument dicts for ``generate``.
            multiplex: if True, responses of all requests are pushed onto one
                shared queue in arrival order and the returned iterator yields
                responses directly. If False, the returned iterator yields one
                inner iterator per request (in request order), each yielding
                that request's responses.
            pbar: optional progress bar; ``pbar.update(1)`` is called each time
                a request finishes successfully.
            loop: event loop to run on; defaults to ``self.internal_thread.loop``.

        Returns:
            A blocking iterator shaped as described for ``multiplex``. If a
            request fails, is cancelled, or ``requests`` itself raises, the
            affected iterators end early instead of blocking forever. The error
            surfaces through the done-callback on the submitted future.
        """

        async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
            try:
                async for out in g:
                    que.put(out.to_response(idx))
            finally:
                # Always free the concurrency slot and terminate this request's
                # stream, even on error/cancellation; otherwise the consumer
                # would wait on ``que.get`` forever.
                sem.release()
                if not multiplex:
                    que.put(None)  # sentinel of inner generator
            if pbar:
                pbar.update(1)

        que = Queue()

        async def _infer():
            sem = self._get_limiter()
            tasks = []
            try:
                try:
                    for idx, req in enumerate(requests):
                        await sem.acquire()
                        try:
                            gen = self.async_engine.generate(**req)
                        except BaseException:
                            sem.release()  # the slot was taken but no task will release it
                            raise
                        dst = que if multiplex else Queue()
                        if not multiplex:
                            que.put(iter(dst.get, None))
                        # create a task to send the responses
                        task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
                        tasks.append(task)
                finally:
                    if not multiplex:  # sentinel of outer generator
                        que.put(None)
                await asyncio.gather(*tasks)
            finally:
                if multiplex:
                    que.put(None)  # sentinel of the shared response stream

        loop = loop or self.internal_thread.loop
        # submit the coroutine to async world
        asyncio.run_coroutine_threadsafe(_infer(),
                                         loop).add_done_callback(lambda f: None if f.cancelled() else f.result())

        return iter(que.get, None)

    def _run(self, fn=None, coro=None):
        assert (fn or coro) and not (fn and coro)
        loop = self.internal_thread.loop
        if fn:

            async def _coro():
                return fn()

            coro = _coro()
        return asyncio.run_coroutine_threadsafe(coro, loop)

    def _batch_iterator(self, sizes, max_value):
        """Return an iterator that calculates intervals (start, end) of a
        descend-order list, in which the sum of values in the range is the
        maximum number not less than max_value. By "the sum of values",

        here it means $$len(sizes[start:end]) * sizes[start]$$
        """
        i = 0
        while i < len(sizes):
            current_sum = 0
            start_index = i

            while i < len(sizes) and current_sum + sizes[start_index] <= max_value:
                current_sum += sizes[start_index]
                i += 1

            yield (start_index, i)
            if i > start_index:
                continue
            else:
                i += 1

    def _get_long_text_ppl(self, session, input_ids, max_input_len):
        """Compute the mean token loss of a sequence longer than ``max_input_len``.

        The sequence is fed in consecutive chunks of at most ``max_input_len``
        tokens on the same ``session``; only the first chunk starts the
        sequence. The targets of a chunk are the next tokens, so they may reach
        one token into the following chunk. Per-chunk mean losses are weighted
        by their target counts, giving the mean loss over the ``seq_len - 1``
        predicted tokens.

        Args:
            session: session reused for every chunk (its ``step`` is updated).
            input_ids: token ids (ints) of the whole sequence.
            max_input_len: chunk length; must be smaller than ``len(input_ids)``.

        Returns:
            float: average cross-entropy loss per predicted token.
        """
        assert all(isinstance(_, int) for _ in input_ids)
        seq_len = len(input_ids)
        assert seq_len > max_input_len
        logger.info(f'get long text ppl: seq_len {seq_len}')

        losses = []
        target_counts = []
        for i in range(0, seq_len, max_input_len):
            token_ids = input_ids[i:i + max_input_len]
            # shift token_ids by 1 to the left
            target_ids = input_ids[i + 1:i + 1 + max_input_len]
            if not target_ids:
                # The final chunk is a single trailing token with nothing left
                # to predict (seq_len % max_input_len == 1); its per-token loss
                # is undefined, so stop here.
                break
            session.update(step=i)
            loss = self._get_ppl(sessions=[session],
                                 input_ids=[token_ids],
                                 max_input_len=len(token_ids),
                                 target_ids=[target_ids],
                                 sequence_start=(i == 0),
                                 sequence_end=False)
            losses.extend(loss)
            target_counts.append(len(target_ids))
        loss_sum = sum(loss * count for loss, count in zip(losses, target_counts))
        return loss_sum / sum(target_counts)

    def _get_ppl(self,
                 sessions: List['Session'],
                 input_ids: List[List[int]],
                 max_input_len: int,
                 target_ids=None,
                 sequence_start: bool = True,
                 sequence_end: bool = True):
        """Compute the mean next-token cross-entropy loss of each sequence.

        All sequences are sent in a single ``async_get_logits`` call on the
        internal loop, so their summed length must not exceed ``max_input_len``.

        Args:
            sessions: one session per sequence.
            input_ids: token ids of each sequence.
            max_input_len: upper bound on the summed sequence lengths.
            target_ids: optional targets per sequence. A target list shorter
                than its input gets one ignored padding position appended.
                Defaults to ``input_ids`` shifted left by one.
            sequence_start: forwarded to ``async_get_logits``.
            sequence_end: forwarded to ``async_get_logits``.

        Returns:
            list[float]: mean loss per sequence; ``nan`` for a sequence with no
            target token (e.g. a single-token input), whose loss is undefined.
        """
        assert (isinstance(input_ids, List) and all(isinstance(_, List) for _ in input_ids))
        assert target_ids is None or len(target_ids) == len(input_ids)
        assert len(sessions) == len(input_ids)

        lens = [len(_) for _ in input_ids]
        total_len = sum(lens)
        assert total_len <= max_input_len

        logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
                    f'total_len: {total_len}')
        torch.cuda.empty_cache()

        logits = self._run(coro=self.async_engine.async_get_logits(
            input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result()
        padding_token_id = -100
        if target_ids is None:
            target_ids = [x[1:] + [padding_token_id] for x in input_ids]
        else:
            target_ids = [
                tgt + [padding_token_id] if len(tgt) < len(ids) else tgt for tgt, ids in zip(target_ids, input_ids)
            ]

        result = []
        for _logits, _target_ids in zip(logits, target_ids):
            _logits = _logits.float()
            vocab_size = _logits.shape[-1]
            _target_ids = torch.tensor(_target_ids, dtype=torch.long, device=_logits.device)
            target_count = (_target_ids != padding_token_id).sum().item()
            if target_count == 0:
                # nothing to predict: report nan rather than dividing by zero
                result.append(float('nan'))
                continue
            # compute cross entropy loss
            flat_logits = _logits.contiguous().view(-1, vocab_size)
            flat_target_ids = _target_ids.view(-1)
            flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits,
                                                                 flat_target_ids,
                                                                 reduction='none',
                                                                 ignore_index=padding_token_id)
            result.append(flat_loss_matrix.sum().item() / target_count)
        logger.info(f'ppl result: {result}')
        return result


class _EventLoopThread:

    def __init__(self, daemon=False):
        fut = concurrent.futures.Future()
        self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon)
        self.thread.start()
        self.loop: asyncio.AbstractEventLoop = fut.result()
        self.closed = False
        if daemon:
            atexit.register(self.close)

    def _thread_entry(self, fut):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        fut.set_result(loop)
        try:
            loop.run_forever()
        except BaseException as e:
            logger.error(f'[internal_thread] {type(e).__name__} {e}')
        finally:
            try:
                self._cancel_all_tasks()
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    def _cancel_all_tasks(self):
        """Modified from asyncio/runners.py."""
        to_cancel = asyncio.all_tasks(self.loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        async def _gather():
            await asyncio.gather(*to_cancel, return_exceptions=True)

        self.loop.run_until_complete(_gather())

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                self.loop.call_exception_handler({
                    'message': 'unhandled exception during worker thread shutdown',
                    'exception': task.exception(),
                    'task': task,
                })

    def close(self):
        if self.closed:
            return
        self.closed = True
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()


================================================
FILE: lmdeploy/profiler.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os
import time
from typing import List

import numpy as np


class Session:
    """Timing record of a single benchmark request.

    Every ``tick(n_token)`` appends a ``time.perf_counter()`` timestamp to
    ``ts`` and the given token count to ``ns``; ``finish(status)`` stores the
    final status.
    """

    UNKNOWN = 0
    SUCCESS = 1
    FAIL = 2

    def __init__(self, input_len, req_output_len):
        self.input_len = input_len
        self.req_output_len = req_output_len
        self.status = Session.UNKNOWN
        self.ts = []  # perf_counter() value of every tick
        self.ns = []  # token count passed to every tick

    def tick(self, n_token):
        """Record ``n_token`` together with the current timestamp."""
        now = time.perf_counter()
        self.ts.append(now)
        self.ns.append(n_token)

    def finish(self, status):
        """Store the final ``status`` (compared against the class constants)."""
        self.status = status


class Profiler:
    """Collect per-request :class:`Session` records and report benchmark metrics.

    Typical flow: ``start()``, one ``new_session(...)`` per request,
    ``finish()``, then ``compute_metrics()`` before ``summarize()`` or
    ``save_csv()``.

    Args:
        stream_output: whether responses were streamed. TTFT, inter-token
            latency and tokens-per-tick statistics are only computed and
            reported when True.
        percentages: percentiles reported next to each mean.
    """

    def __init__(self, stream_output: bool, percentages: List[int]):
        self.stream_output = stream_output
        self.percentages = percentages
        self.sessions: List[Session] = []

    def new_session(self, *args, **kwargs):
        """Create a :class:`Session` from the given arguments, track it and return it."""
        session = Session(*args, **kwargs)
        self.sessions.append(session)
        return session

    def start(self):
        """Record the benchmark start time."""
        self.t_start = time.perf_counter()

    def finish(self):
        """Record the wall-clock time elapsed since :meth:`start`."""
        self.elapsed_time = time.perf_counter() - self.t_start

    def compute_metrics(self):
        """Aggregate latency and throughput statistics.

        Only sessions whose status is ``Session.SUCCESS`` and whose last token
        count reached ``req_output_len`` are counted. Empty metric lists are
        replaced by placeholders (``inf`` for latencies, ``0`` for tokens per
        tick) so that means and percentiles are always defined.
        """
        self.ttfts: List[float] = []
        self.tpots: List[float] = []
        self.e2es: List[float] = []
        self.itls: List[float] = []
        self.tpts: List[int] = []
        self.total_output = 0
        self.total_input = 0
        self.success = 0

        for sess in self.sessions:
            if sess.status != Session.SUCCESS or sess.ns[-1] < sess.req_output_len:
                continue
            ts, ns = sess.ts, sess.ns
            self.success += 1
            self.total_input += sess.input_len
            self.total_output += ns[-1]
            self.e2es.append(ts[-1] - ts[0])
            self.ttfts.append(ts[1] - ts[0])
            if ns[-1] > ns[1]:
                self.tpots.append((ts[-1] - ts[1]) / (ns[-1] - ns[1]))
            else:  # no-stream-output: all tokens arrived with the first response
                self.tpots.append((ts[-1] - ts[0]) / (ns[-1] - ns[0]))
            time_deltas = np.diff(ts)
            self.itls.extend(time_deltas[1:])  # the first delta is the TTFT, not an ITL
            self.tpts.extend(np.diff(ns))

        self.output_throughput = self.total_output / self.elapsed_time
        self.input_throughput = self.total_input / self.elapsed_time

        # placeholders keep mean/percentile defined when nothing succeeded
        self.e2es = self.e2es or [float('inf')]
        self.tpots = self.tpots or [float('inf')]
        self.ttfts = self.ttfts or [float('inf')]
        self.itls = self.itls or [float('inf')]
        self.tpts = self.tpts or [0]

        qs = self.percentages

        def mean_and_percentiles(values):
            return np.mean(values), tuple(np.percentile(values, qs))

        self.tpot_mean, self.tpot_stat = mean_and_percentiles(self.tpots)
        self.e2e_mean, self.e2e_stat = mean_and_percentiles(self.e2es)
        if self.stream_output:
            self.ttft_mean, self.ttft_stat = mean_and_percentiles(self.ttfts)
            self.itls_mean, self.itls_stat = mean_and_percentiles(self.itls)
            self.tpts_mean = np.mean(self.tpts)
            self.tpts_stat = tuple(np.percentile(self.tpts, qs).astype(int))

        self.rps = self.success / self.elapsed_time

    def summarize(self, title: str, hyperparams: List = None, header=40, digits=10):
        """Print a table of the metrics produced by :meth:`compute_metrics`.

        Args:
            title: banner text.
            hyperparams: optional ``(name, value)`` pairs printed after the
                request counts.
            header: width of the left-aligned row-name column.
            digits: width of each right-aligned value column.
        """
        width = header + digits * (1 + len(self.percentages))

        def fmt(value):
            if isinstance(value, float):
                return '{:>{d}.3f}'.format(value, d=digits)
            return '{:>{d}}'.format(value, d=digits)

        def tab_row(name, *items):
            cells = ''.join(fmt(item) for item in items)
            print('{:<{p}}{}'.format(name, cells, p=header))

        print('\n{s:{c}^{n}}'.format(s=f' {title} ', n=width, c='='))
        tab_row('Benchmark duration', self.elapsed_time)
        tab_row('Total requests', len(self.sessions))
        tab_row('Successful requests', self.success)
        for key, value in hyperparams or ():
            tab_row(key, value)
        tab_row('Total input tokens', self.total_input)
        tab_row('Total generated tokens', self.total_output)
        tab_row('Input throughput (tok/s)', self.input_throughput)
        tab_row('Output throughput (tok/s)', self.output_throughput)
        tab_row('Request throughput (req/s)', self.rps)
        print('-' * width)
        tab_row('', 'mean', *(f'P{q}' for q in self.percentages))
        tab_row('End-to-end Latency', self.e2e_mean, *self.e2e_stat)
        if self.stream_output:
            tab_row('Time to First Token (TTFT)', self.ttft_mean, *self.ttft_stat)
        tab_row('Time per Output Token (TPOT)', self.tpot_mean, *self.tpot_stat)
        if self.stream_output:
            tab_row('Inter-token Latency (ITL)', self.itls_mean, *self.itls_stat)
            tab_row('Tokens per Tick', self.tpts_mean, *self.tpts_stat)
        print('=' * width)

    def save_csv(self, csv_file: str, hyperparams):
        """Export legacy metrics to CSV.

        Appends one row to ``csv_file``; a header row is written first when the
        file did not exist. ``hyperparams`` is a sequence of ``(name, value)``
        pairs that become the leading columns. TTFT and ITL columns hold ``'-'``
        when ``stream_output`` is False.
        """
        is_new_file = not os.path.isfile(csv_file)

        def to_ms(seconds):
            return f'{seconds * 1000:.3f}'

        metric_columns = [
            'completed',
            'total_input_tokens',
            'total_output_tokens',
            'duration',
            'request_throughput',
            'input_throughput',
            'output_throughput',
            'mean_e2e_latency_ms',
            'mean_ttft_ms',
            'mean_tpot_ms',
            'mean_itl_ms',
        ]
        with open(csv_file, mode='a', newline='') as csvfile:
            writer = csv.writer(csvfile)
            keys, vals = zip(*hyperparams)
            if is_new_file:
                writer.writerow([*keys, *metric_columns])
            writer.writerow([
                *vals,
                self.success,
                self.total_input,
                self.total_output,
                self.elapsed_time,
                f'{self.rps:.3f}',
                f'{self.input_throughput:.3f}',
                f'{self.output_throughput:.3f}',
                to_ms(self.e2e_mean),
                to_ms(self.ttft_mean) if self.stream_output else '-',
                to_ms(self.tpot_mean),
                to_ms(self.itls_mean) if self.stream_output else '-',
            ])


================================================
FILE: lmdeploy/pytorch/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/adapter/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/adapter/adapter.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import re
from typing import Dict, Iterable, List, Tuple

import torch
from torch import nn


def get_ranks_and_scalings(target_name: str, cfgs: Iterable, device: torch.device = None):
    """Get ranks and scalings.

    Collect, for one target module name, the LoRA rank and scaling
    (``lora_alpha / r``) of every adapter config.

    Args:
        target_name: module name looked up in each config's ``target_modules``.
        cfgs: adapter configs exposing ``target_modules``, ``r`` and ``lora_alpha``.
        device: device of the returned tensors.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``ranks`` (int64) and ``scalings``
        (default floating dtype), one entry per config. Configs that do not
        target ``target_name`` contribute rank 0 and scaling 1.0. The dtypes
        are the same regardless of which adapters target the module, including
        when ``cfgs`` is empty.
    """
    ranks = []
    scalings = []
    for cfg in cfgs:
        if target_name not in cfg.target_modules:
            ranks.append(0)
            # float, so an all-untargeted list does not yield an int64 tensor
            scalings.append(1.0)
            continue
        ranks.append(cfg.r)
        scalings.append(float(cfg.lora_alpha / cfg.r))
    ranks = torch.tensor(ranks, dtype=torch.int64, device=device)
    scalings = torch.tensor(scalings, device=device)
    return ranks, scalings


def find_all_target(model: torch.nn.Module, target_name: str):
    """Find all submodules that implement ``target_name``.

    If ``target_name`` appears among the sub-names of the model's
    ``packed_modules_mapping`` (first match in mapping order), the search is
    made for the packed module instead.

    Returns:
        tuple: ``(found_mods, pack_idx)``. ``found_mods`` lists the
        ``(name, module)`` pairs whose dotted name ends with
        ``.<packed name>``. ``pack_idx`` is the position of ``target_name`` in
        the packed module's sub-name list, or ``None`` when it is not packed.
    """
    packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict())
    packed_name, pack_idx = next(
        ((pack_name, sub_names.index(target_name))
         for pack_name, sub_names in packed_modules_mapping.items() if target_name in sub_names),
        (target_name, None),
    )

    suffix = f'.{packed_name}'
    found_mods = [(name, mod) for name, mod in model.named_modules() if name.endswith(suffix)]
    return found_mods, pack_idx


def get_layer_index(key: str, layers_pattern: str = None):
    """Get layer index of the lora linear.

    Args:
        key: dotted module name, e.g. ``'model.layers.3.self_attn.q_proj'``.
        layers_pattern: name, or list of names, of the module list holding the
            layers (e.g. ``'layers'``). Each entry is used as a regular
            expression and must be a complete dotted component, immediately
            followed by ``.<digits>``. When ``None`` or empty, the digits of the
            last ``.<name>.<digits>.`` occurrence in ``key`` are used.

    Returns:
        int | None: the layer index. With ``layers_pattern`` given, ``None`` is
        returned if no pattern matches.

    Raises:
        ValueError: if ``layers_pattern`` is empty and ``key`` contains no
            ``.<name>.<digits>.`` component.
    """
    if isinstance(layers_pattern, str):
        layers_pattern = [layers_pattern]
    if layers_pattern is None or len(layers_pattern) == 0:
        layer_index = re.match(r'.*\.[^.]*\.(\d+)\.', key)
        if layer_index is None:
            raise ValueError(f'Can not find layer index in key: {key}')
        return int(layer_index[1])

    for pattern in layers_pattern:
        # The dot before the pattern is escaped so that the pattern must start
        # a dotted component (``sublayers`` must not match ``layers``).
        layer_index = re.match(rf'.*\.{pattern}\.(\d+)', key)
        if layer_index is not None:
            return int(layer_index[1])
    return None


def _get_reverse_pack_map(model: nn.Module):
    """Get reverse pack map."""
    packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict())
    reverse_map = dict()
    for pack_name, names in packed_modules_mapping.items():
        for name in names:
            reverse_map[name] = pack_name
    return reverse_map


def _get_key_map(reverse_map: Dict[str, str]):
    """Get key map."""
    key_map = dict()
    for name, pack_name in reverse_map.items():
        key = f'.{name}'
        val = f'.{pack_name}.lora_adapters.{name}'
        key_map[key] = val

    return key_map


def load_lora_weights(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
    """Load lora weights.

    Each checkpoint key is expected to look like
    ``base_model.model.<path>.<module>.<lora_A|lora_B>.weight``. The
    ``<module>`` component is rewritten to ``<module>.lora_adapters.<module>``,
    or, for a packed sub-module (e.g. ``q_proj`` inside ``qkv_proj``), to
    ``<packed>.lora_adapters.<module>``. The tensor is then loaded into that
    parameter for ``adapter_id``.
    """
    from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
    prefix_len = len('base_model.model.')
    reverse_map = _get_reverse_pack_map(model)
    key_map = _get_key_map(reverse_map)

    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        name = name[prefix_len:]
        splited_name = name.split('.')
        assert splited_name[-1] == 'weight'
        assert splited_name[-2] in ['lora_A', 'lora_B']
        mod_name = splited_name[-3]
        dot_mod_name = f'.{mod_name}'
        replace_name = key_map.get(dot_mod_name, f'.{mod_name}.lora_adapters.{mod_name}')
        # Rebuild the name with only the module component (third from the end)
        # rewritten. ``str.replace`` would also rewrite any earlier component
        # that happens to start with ``.<mod_name>``.
        param_name = '.'.join([*splited_name[:-3], replace_name[1:], splited_name[-2]])

        param = params_dict[param_name]
        load_weight(param, loaded_weight, adapter_id=adapter_id)


class AdapterManager:
    """Adapter manager.

    Assigns integer ids to adapter names. ``None`` (no adapter) always gets
    id 0, and the configured adapter names get ids 1..N in sorted order.
    """

    def __init__(self, adapters: Dict[str, str]):
        names = sorted(adapters.keys()) if adapters is not None else []
        self.adapter_id_map = {name: idx for idx, name in enumerate([None, *names])}

    def get_adapter_ids(self, names: List[str]):
        """Translate adapter names (``None`` allowed) into their integer ids."""
        id_map = self.adapter_id_map
        return [id_map[name] for name in names]

    def num_adapters(self):
        """Number of ids, including the id 0 reserved for ``None``."""
        return len(self.adapter_id_map)


================================================
FILE: lmdeploy/pytorch/backends/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .base import OpType  # noqa: F401
from .selector import get_backend  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod


class SiluAndMulImpl(ABC):
    """Interface of a fused SiLU-and-multiply activation implementation."""

    @abstractmethod
    def forward(self, x):
        """Apply the fused activation to ``x``; backends must override this."""
        raise NotImplementedError()


class SiluAndMulBuilder(ABC):
    """Factory interface producing :class:`SiluAndMulImpl` instances."""

    @staticmethod
    @abstractmethod
    def build(inplace: bool = False):
        """Return a backend implementation; ``inplace`` is a backend option."""
        raise NotImplementedError()


class GeluAndMulImpl(ABC):
    """Interface of a fused GELU-and-multiply activation implementation."""

    @abstractmethod
    def forward(self, x):
        """Apply the fused activation to ``x``; backends must override this."""
        raise NotImplementedError()


class GeluAndMulBuilder(ABC):
    """Factory interface producing :class:`GeluAndMulImpl` instances."""

    @staticmethod
    @abstractmethod
    def build(approximate: str = 'none'):
        """Return a backend implementation for the given ``approximate`` mode."""
        raise NotImplementedError()


================================================
FILE: lmdeploy/pytorch/backends/apply_rotary_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

from torch import Tensor


class ApplyRotaryEmbImpl(ABC):
    """Interface for applying rotary position embedding to query and key."""

    @abstractmethod
    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """Rotate ``query``/``key`` with the given ``cos``/``sin`` tables."""
        raise NotImplementedError()


class ApplyRotaryEmbBuilder(ABC):
    """Factory interface producing :class:`ApplyRotaryEmbImpl` instances."""

    @staticmethod
    @abstractmethod
    def build():
        """Return a backend implementation."""
        raise NotImplementedError()


================================================
FILE: lmdeploy/pytorch/backends/attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Generic, Literal, TypeVar

import torch


@dataclass
class AttentionMetadata:
    """Base Attention metadata.

    Per-step inputs handed to attention backends (see ``AttentionImpl.forward``).
    Field order defines the positional constructor signature; do not reorder.
    """
    is_decoding: bool  # presumably True for decode steps, False for prefill — confirm
    block_offsets: torch.Tensor  # NOTE(review): presumably the per-sequence KV-cache block table — confirm
    q_start_loc: torch.Tensor = None  # presumably start offset of each sequence's queries — confirm
    q_seqlens: torch.Tensor = None  # presumably query length per sequence — confirm
    kv_seqlens: torch.Tensor = None  # presumably key/value length per sequence — confirm
    fill_seqlens: torch.Tensor = None  # NOTE(review): semantics not visible in this module — confirm
    cu_seqlens_q: torch.Tensor = None  # presumably cumulative query lengths — confirm
    cu_seqlens_k: torch.Tensor = None  # presumably cumulative key lengths — confirm
    quant_policy: Literal[0, 4, 8] = 0  # KV-cache quant policy; presumably 0 = off, 4/8 = bit width — confirm


# Type variable for the backend-specific metadata type consumed by
# AttentionImpl / AttentionBuilder; any AttentionMetadata subclass.
T = TypeVar('T', bound=AttentionMetadata)


class AttentionImpl(ABC, Generic[T]):
    """Attention implementation.

    Base class of attention backends. It stores the hyper-parameters shared by
    every backend and fills in defaults: ``scale`` is ``1 / sqrt(head_size)``,
    ``num_kv_heads`` falls back to ``num_heads`` and ``v_head_size`` falls back
    to ``head_size``. Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = None,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        use_flash_mla: bool = False,
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = 1.0 / (head_size**0.5) if scale is None else scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.v_head_size = head_size if v_head_size is None else v_head_size
        self.alibi = alibi
        self.sliding_window = sliding_window
        self.logit_softcapping = logit_softcapping
        self.causal = causal
        self.use_flash_mla = use_flash_mla
        # filled later through ``set_alibi_slopes``
        self.alibi_slopes = None

    @staticmethod
    @lru_cache(maxsize=4)
    def make_alibi_slopes(head_start: int, head_end: int, num_heads: int, alibi_scale: float, dtype: torch.dtype,
                          device: torch.device):
        """Make alibi slopes for heads ``[head_start, head_end)`` of ``num_heads``.

        With ``p`` the largest power of two not above ``num_heads``, head
        ``h < p`` gets ``2 ** (-8 * (h + 1) / p)``. Each remaining head
        ``h = p + k`` gets ``2 ** (-4 * (2 * k + 1) / p)``, i.e. every other
        slope of a ``2p``-head model. Slopes are multiplied by ``alibi_scale``.
        Results are memoized per argument tuple.
        """
        heads = torch.arange(head_start, head_end, dtype=dtype, device=device)
        pow2 = heads.new_full([1], num_heads).log2().to(torch.int64).exp2()

        in_base = heads < pow2
        heads = torch.where(in_base, heads, (heads - pow2) * 2)
        base = torch.where(in_base, pow2, pow2 * 2)

        # slope_base = 2 ** (-8 / base); slope = slope_base ** (heads + 1)
        slope_base = torch.sub(3, base.log2()).exp2().neg().exp2()
        return slope_base * torch.pow(slope_base, heads) * alibi_scale

    def set_alibi_slopes(self, slopes: torch.Tensor):
        """Attach precomputed alibi slopes to this implementation."""
        self.alibi_slopes = slopes

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: T,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        learnable_sink: torch.Tensor = None,
        nsa_indices: torch.Tensor = None,
        inplace: bool = False,
    ) -> torch.Tensor:
        """Compute attention; every backend must override this."""
        raise NotImplementedError()


class AttentionBuilder(ABC, Generic[T]):
    """Attention implementation builder.

    Factory interface that turns attention hyper-parameters into a backend
    :class:`AttentionImpl` working on metadata of type ``T``.
    """

    @staticmethod
    @abstractmethod
    def build(
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        use_flash_mla: bool = False,
        learnable_sink: bool = False,
        block_sparse_size: int = 1,
        **kwargs,
    ) -> AttentionImpl[T]:
        """Create the backend attention implementation."""
        raise NotImplementedError()


================================================
FILE: lmdeploy/pytorch/backends/awq_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class LinearW4A16Impl(ABC):
    """Interface of a W4A16 linear layer (quantized weights with scales and zeros)."""

    def update_weights(self,
                       qweight: torch.Tensor,
                       scales: torch.Tensor,
                       qzeros: torch.Tensor,
                       bias: Optional[torch.Tensor] = None):
        """Hook to post-process loaded weights; the default returns them unchanged.

        Returns:
            tuple: ``(qweight, scales, qzeros, bias)``.
        """
        unchanged = (qweight, scales, qzeros, bias)
        return unchanged

    @abstractmethod
    def forward(self,
                x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Run the quantized linear; backends must override this."""
        raise NotImplementedError()


class LinearW4A16Builder(ABC):
    """Factory interface producing :class:`LinearW4A16Impl` instances."""

    @staticmethod
    @abstractmethod
    def build(in_features: int,
              out_features: int,
              w_bit: int,
              group_size: int,
              bias: bool = False,
              dtype: torch.dtype = None):
        """Create the backend implementation for the given layer geometry."""
        raise NotImplementedError()


================================================
FILE: lmdeploy/pytorch/backends/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from:
# https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/abstract.py
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig


class OpType(Enum):
    """Layer type enumerate.

    Keys used to ask a backend for an implementation builder
    (``OpsBackend.get_layer_impl_builder(layer_type)``).
    """
    # NOTE(review): values come from auto() and depend on declaration order;
    # append new members at the end if values are ever persisted — confirm.
    PagedAttention = auto()
    FlashAttention = auto()
    Linear = auto()
    RotaryEmbedding = auto()
    ApplyRotaryEmb = auto()
    SiluAndMul = auto()
    GeluAndMul = auto()
    RMSNorm = auto()
    LayerNorm = auto()
    LoRA = auto()
    LinearW8A8 = auto()
    RMSNormW8A8 = auto()
    MultinomialSampling = auto()
    LinearW4A16 = auto()
    SoftmaxTopK = auto()
    FusedMoE = auto()
    FusedMoEW8A8 = auto()
    LinearBlockedF8 = auto()
    FusedMoEBlockedF8 = auto()
    NSAIndexFP8 = auto()
    Embedding = auto()

    # MoE router
    RouterNoauxTC = auto()

    # Gated Delta
    CausalConv1d = auto()
    GatedDeltaRule = auto()


class OpsBackend(ABC):
    """Layer backend abstract.

    Interface each device backend implements. It exposes the backend name, op
    builders, the attention metadata class and KV-cache block shapes, plus
    overridable defaults for step-context updates, graph runners, device
    counting and ray support.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the backend name."""
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """Return the implementation builder for ``layer_type``."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_attention_metadata_cls():
        """Return the attention metadata class used by this backend."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Return the shape of one key-cache block."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Return the shape of one value-cache block."""
        raise NotImplementedError()

    @classmethod
    def update_step_context(cls, step_context):
        """Update StepContext for inference.

        Backends build their attention metadata here; the default returns
        ``step_context`` unchanged.
        """
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                           backend_config: BackendConfig, device: torch.device):
        """Build the default :class:`GraphRunner`."""
        from .graph_runner import GraphRunner
        runner = GraphRunner(model, model_config, cache_config, backend_config, device)
        return runner

    @staticmethod
    def device_count():
        """Number of available devices; ``None`` when the backend does not report it."""
        return None

    @staticmethod
    def support_ray():
        """Whether the backend supports ray; ``False`` by default."""
        return False


================================================
FILE: lmdeploy/pytorch/backends/blockedf8_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearBlockedF8Impl(ABC):
    """Linear BlockedF8 implementation api.

    Holds an optional scale-format tag (``None`` until ``set_scale_fmt`` is
    called) and a pass-through weight post-processing hook.
    """

    def __init__(self):
        self.scale_fmt: Optional[str] = None

    def set_scale_fmt(self, scale_fmt: Optional[str]):
        """Store the scale format tag."""
        self.scale_fmt = scale_fmt

    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Hook to post-process loaded weights; the default returns ``(weight, scale, bias)``."""
        unchanged = (weight, scale, bias)
        return unchanged

    @abstractmethod
    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[dist.ProcessGroup] = None,
                rank: int = 0,
                scatter_size: List[int] = None):
        """Run the linear layer; backends must override this."""
        raise NotImplementedError()


class LinearBlockedF8Builder(ABC):
    """Linear BlockedF8 implementation builder.

    Backends subclass this and implement ``build`` to create their linear
    BlockedF8 implementation object.
    """

    @staticmethod
    @abstractmethod
    def build(in_features: int, out_features: int, bias: bool = True, dtype: Optional[torch.dtype] = None):
        """Build a linear BlockedF8 implementation (must be overridden).

        Args:
            in_features: Input feature size.
            out_features: Output feature size.
            bias: Whether the layer has a bias.
            dtype: Optional dtype hint; ``None`` when not specified.
        """
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/causal_conv1d.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch


class CausalConv1dImpl(ABC):
    """CausalConv1d implementation api."""

    @abstractmethod
    def conv1d_fn(self,
                  x: torch.Tensor,
                  weight: torch.Tensor,
                  bias: torch.Tensor | None = None,
                  seq_idx: torch.Tensor | None = None,
                  return_final_states: bool = False,
                  activation: str | None = None):
        """forward."""
        raise NotImplementedError

    @abstractmethod
    def update_fn(self,
                  x: torch.Tensor,
                  conv_state: torch.Tensor,
                  weight: torch.Tensor,
                  bias: torch.Tensor | None = None,
                  activation: str | None = None,
                  conv_state_indices: torch.Tensor | None = None):
        """Update conv state."""
        raise NotImplementedError


class CausalConv1dBuilder(ABC):
    """Abstract factory for backend-specific causal conv1d implementations."""

    @staticmethod
    @abstractmethod
    def build():
        """Create the causal conv1d implementation; backends must override this."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/cuda/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import CudaOpsBackend  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/cuda/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

from ..activation import SiluAndMulBuilder, SiluAndMulImpl


class TritonSiluAndMulImpl(SiluAndMulImpl):
    """Fused SiLU-and-multiply activation using the Triton ``silu_and_mul`` kernel."""

    def __init__(self, inplace: bool):
        # When True, the first half of the input's last dim is passed to the
        # kernel as its output buffer.
        self.inplace = inplace

    def forward(self, x):
        """Apply the fused activation to ``x``.

        Inputs whose rank is not 2 are flattened to 2-D for the kernel, and
        their leading dimensions are restored on the result.
        """
        lead_shape = x.shape[:-1] if x.dim() != 2 else None
        flat = x if lead_shape is None else x.flatten(0, -2)
        out_buf = flat.chunk(2, -1)[0] if self.inplace else None
        result = silu_and_mul(flat, out_buf)
        if lead_shape is None:
            return result
        return result.unflatten(0, lead_shape)


class TritonSiluAndMulBuilder(SiluAndMulBuilder):
    """Factory producing ``TritonSiluAndMulImpl`` objects."""

    @staticmethod
    def build(inplace: bool = False):
        """Return a Triton SiLU-and-mul implementation configured with ``inplace``."""
        impl = TritonSiluAndMulImpl(inplace)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

from lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


class TritonApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
    """Rotary position embedding applied through the Triton kernel."""

    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """Apply rotary embedding to ``query`` and ``key``.

        With ``inplace=True`` the inputs themselves are passed as the output
        buffers; otherwise new uninitialized tensors shaped like the inputs are
        allocated for the kernel to write into.
        """
        if not inplace:
            q_out, k_out = torch.empty_like(query), torch.empty_like(key)
        else:
            q_out, k_out = query, key
        return apply_rotary_pos_emb(query, key, cos, sin, q_out, k_out)


class TritonApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
    """Factory producing ``TritonApplyRotaryEmbImpl`` objects."""

    @staticmethod
    def build():
        """Return a fresh Triton rotary-embedding implementation."""
        impl = TritonApplyRotaryEmbImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/cuda/attention/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch

from lmdeploy.pytorch.backends.attention import AttentionBuilder
from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata

logger = get_logger('lmdeploy')


def _cuda_version_at_least(version, minimum) -> bool:
    """Return True if a ``'major.minor[...]'`` version string is >= ``minimum``.

    Components are compared numerically, so ``'12.10'`` correctly ranks above
    ``'12.3'`` (a plain string comparison would not). ``None`` (the value of
    ``torch.version.cuda`` on builds without CUDA) yields False. A non-numeric
    component raises ``ValueError``.
    """
    if version is None:
        return False
    parts = tuple(int(part) for part in version.split('.')[:len(minimum)])
    return parts >= tuple(minimum)


use_fa3 = False
try:
    # Now flash-attention only support FA3 for sm90a && cuda >= 12.3
    if (torch.cuda.get_device_capability()[0] == 9) and _cuda_version_at_least(torch.version.cuda, (12, 3)):
        import lmdeploy.pytorch.third_party.flash_attn_interface  # noqa: F401
        assert torch.ops.flash_attn_3 is not None
        use_fa3 = True
except Exception:
    # Any failure (no CUDA device, unparsable version, missing FA3 package)
    # simply leaves FA3 disabled.
    logger.debug('For higher performance, please install FlashAttention-3 '
                 'https://github.com/Dao-AILab/flash-attention')


@functools.lru_cache
def use_fa3_warning():
    """Report whether FA3 is available.

    When it is not, a warning is logged; because the result is cached, the
    warning is emitted at most once.
    """
    available = bool(use_fa3)
    if not available:
        logger.warning('For higher performance, please install FlashAttention-3 '
                       'https://github.com/Dao-AILab/flash-attention')
    return available


@functools.lru_cache
def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int, head_size: int) -> bool:
    """Check if FA3 should be enabled.

    FA3 is enabled when:
    - No alibi
    - No learnable sink
    - block_sparse_size == 1
    - FA3 is available (checked by use_fa3_warning)

    Returns:
        True if FA3 should be enabled, False otherwise.
    """
    enable = not alibi and not learnable_sink and block_sparse_size == 1 and head_size <= 256
    if enable and not use_fa3_warning():
        enable = False
    return enable


def _normalize_sliding_window(sliding_window):
    """Normalize sliding window to tuple format.

    Args:
        sliding_window: None, int, or tuple of (left, right).

    Returns:
        Tuple of (left, right) or (-1, -1) if None.
    """
    if sliding_window is None:
        return (-1, -1)
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)
    return sliding_window


class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
    """Triton attention builder.

    Implementations are chosen in this priority order:
    1. ``use_flash_mla`` set -> ``FlashMLAImpl`` (MLA models)
    2. FA3 usable for this configuration -> ``FA3Impl``
    3. anything else -> ``TritonAttentionImpl``
    """

    @staticmethod
    def build(
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        use_flash_mla: bool = False,
        learnable_sink: bool = False,
        block_sparse_size: int = 1,
        **kwargs,
    ) -> TritonAttentionImpl:
        """Construct the attention implementation for the given configuration.

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            scale: Scaling factor for attention scores.
            num_kv_heads: Number of key-value heads (for GQA).
            v_head_size: Size of value head (for MLA).
            alibi: Whether to use ALiBi positional encoding.
            sliding_window: Sliding window (``None``, int, or ``(left, right)``).
            logit_softcapping: Logit softcapping value (for Gemma 2).
            causal: Whether to use causal attention.
            use_flash_mla: Whether to use the Flash MLA implementation.
            learnable_sink: Whether learnable sink tokens are used (only
                affects implementation selection here).
            block_sparse_size: Block sparse attention size.
            **kwargs: Forwarded to the chosen implementation.

        Returns:
            The selected attention implementation instance.
        """
        shared_kwargs = {
            'num_heads': num_heads,
            'head_size': head_size,
            'scale': scale,
            'num_kv_heads': num_kv_heads,
            'v_head_size': v_head_size,
            'alibi': alibi,
            'sliding_window': _normalize_sliding_window(sliding_window),
            'logit_softcapping': logit_softcapping,
            'causal': causal,
            **kwargs,
        }
        # Evaluated up front for every path, since it may log a one-time
        # FA3 availability warning.
        enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size, head_size)

        if use_flash_mla is True:
            logger.debug('Build FlashMLAImpl Attention')
            from .mla import FlashMLAImpl
            return FlashMLAImpl(use_fa3=use_fa3, **shared_kwargs)

        if enable_fa3:
            logger.debug('Build FA3Impl Attention')
            from .fa3 import FA3Impl
            return FA3Impl(**shared_kwargs)

        logger.debug('Build TritonAttentionImpl Attention')
        return TritonAttentionImpl(block_sparse_size=block_sparse_size, **shared_kwargs)


================================================
FILE: lmdeploy/pytorch/backends/cuda/attention/default.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Literal

import torch

from lmdeploy.pytorch.backends.attention import AttentionImpl, AttentionMetadata
from lmdeploy.utils import get_logger

# Module logger obtained from lmdeploy's shared ``get_logger`` helper.
logger = get_logger('lmdeploy')


@dataclass
class TritonAttentionMetadata(AttentionMetadata):
    """Triton attention metadata.

    This dataclass contains all metadata needed for attention computation
    across different stages (prefill/decoding) and implementations.
    Within this class only ``is_decoding`` and ``block_offsets`` have no
    default; the remaining fields default to ``None`` (``quant_policy``
    defaults to ``0``) and are populated only where a stage/implementation
    needs them.

    Attributes:
        is_decoding: True for decoding stage, False for prefill.
        block_offsets: Block indices for paged KV cache [batch_size, max_blocks].
        q_start_loc: Start location of each query sequence [batch_size].
        q_seqlens: Length of each query sequence [batch_size].
        kv_start_loc: Start location of each KV sequence [batch_size].
        kv_seqlens: Length of each KV sequence [batch_size].
        quant_policy: Quantization policy (0=none, 4=int4, 8=int8/fp8).
        kv_flatten_size: Total size of flattened KV cache.
        tile_scheduler_metadata: Scheduler metadata for Flash MLA.
        num_splits: Number of splits for Flash MLA.
        cu_seqlens_q: Cumulative query sequence lengths [batch_size + 1].
        cu_seqlens_k: Cumulative KV sequence lengths [batch_size + 1].
        scheduler_metadata: Scheduler metadata for FA3.
        max_kv_seqlen: Maximum KV sequence length in the batch.
        max_q_seqlen: Maximum query sequence length in the batch.
    """
    # required fields
    is_decoding: bool
    block_offsets: torch.Tensor
    # per-sequence query / kv layout
    q_start_loc: torch.Tensor = None
    q_seqlens: torch.Tensor = None
    kv_start_loc: torch.Tensor = None
    kv_seqlens: torch.Tensor = None
    # kv cache quantization mode
    quant_policy: Literal[0, 4, 8] = 0
    # used to size the flattened KV buffer during prefill
    kv_flatten_size: int = None
    # flash mla
    tile_scheduler_metadata: torch.Tensor = None
    num_splits: torch.Tensor = None
    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    # flash attn
    scheduler_metadata: torch.Tensor = None
    max_kv_seqlen: int = None
    max_q_seqlen: int = None


def _cdiv(a, b):
    """Perform ceiling division (division rounded up).

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        Ceiling of a / b.
    """
    return (a + b - 1) // b


class TritonAttentionImpl(AttentionImpl[TritonAttentionMetadata]):
    """Triton attention implementation.

    New key/value tokens are first written into the paged KV cache. Attention
    is then computed either directly against the paged cache (decoding) or
    against a flattened copy of the cache with a varlen flash-attention kernel
    (prefill). All kernels come from ``lmdeploy.pytorch.kernels.cuda``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        block_sparse_size: int = 1,
        **kwargs,
    ):
        """Initialize the implementation and bind the Triton kernels.

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            scale: Scaling factor for attention scores.
            num_kv_heads: Number of key-value heads (for GQA).
            v_head_size: Size of value head.
            alibi: Whether to use ALiBi positional encoding (requires ``causal``).
            sliding_window: Sliding window configuration, forwarded to the base class.
            logit_softcapping: Logit softcapping value; values <= 0 disable it.
            causal: Whether to use causal attention.
            block_sparse_size: Block sparse attention size; also used as
                ``max_q_seqlen`` during decoding.
            **kwargs: Forwarded to ``AttentionImpl.__init__``.
        """
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_size=v_head_size,
            alibi=alibi,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            causal=causal,
            **kwargs,
        )
        # Non-positive softcapping means "disabled"; the kernels receive -1 then.
        self.logit_softcapping = -1 if self.logit_softcapping <= 0.0 else self.logit_softcapping
        # ALiBi is only supported together with causal attention.
        assert not (alibi and not causal)

        from lmdeploy.pytorch.kernels.cuda import (fill_kv_cache, flash_attn_varlen_func, flash_attn_with_kvcache,
                                                   flatten_kv_cache)

        # Kernels are stored as attributes; subclasses (e.g. FA3Impl) reuse them.
        self.fill_kv_cache = fill_kv_cache
        self.paged_attention_fwd = flash_attn_with_kvcache
        self.flatten_kv_cache = flatten_kv_cache
        self.flash_attention_fwd = flash_attn_varlen_func

        self.block_sparse_size = block_sparse_size

    def _get_max_q_seqlen(
        self,
        query: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ) -> int:
        """Get max q seqlen.

        Decoding uses ``self.block_sparse_size``. Prefill prefers
        ``attn_metadata.max_q_seqlen`` and otherwise falls back to
        ``query.numel() // (query.size(-1) * query.size(-2))``. Assuming
        ``query`` is laid out as (num_tokens, num_heads, head_dim) (TODO
        confirm), that is the total query token count, an upper bound on any
        single sequence's length.
        """
        if attn_metadata.is_decoding:
            max_q_seqlen = self.block_sparse_size
        else:
            if attn_metadata.max_q_seqlen is not None:
                max_q_seqlen = attn_metadata.max_q_seqlen
            else:
                max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
        return max_q_seqlen

    def _get_fill_meta(
        self,
        key: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
    ):
        """Get fill meta.

        Returns ``(fill_seqlens, fill_max_q_seqlen, fill_q_start_loc)`` used to
        write new tokens into the KV cache. This base version passes the
        query-side metadata through unchanged; ``key`` is unused here and is
        presumably kept for subclass overrides.
        """
        fill_seqlens = attn_metadata.q_seqlens
        fill_max_q_seqlen = max_q_seqlen
        fill_q_start_loc = attn_metadata.q_start_loc
        return fill_seqlens, fill_max_q_seqlen, fill_q_start_loc

    def _fill_kv_cache_impl(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
    ):
        """Fill kv cache.

        Writes ``key``/``value`` into ``k_cache``/``v_cache`` through the
        ``fill_kv_cache`` kernel, using the paged ``block_offsets`` and the
        metadata's ``quant_policy`` (with optional quantization scales/zeros).
        """
        kv_seqlens = attn_metadata.kv_seqlens
        block_offsets = attn_metadata.block_offsets
        quant_policy = attn_metadata.quant_policy

        # fill seqlen args
        fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(
            key,
            attn_metadata,
            max_q_seqlen,
        )

        # fill kv cache
        self.fill_kv_cache(
            key,
            value,
            k_cache,
            v_cache,
            fill_q_start_loc,
            fill_seqlens,
            kv_seq_length=kv_seqlens,
            max_q_seq_length=fill_max_q_seqlen,
            block_offsets=block_offsets,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            quant_policy=quant_policy,
        )

    def _forward_decoding(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        learnable_sink: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for decoding stage.

        Attends directly against the paged KV cache (no flattening).

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.
            learnable_sink: Learnable sink tokens.

        Returns:
            Attention output tensor.
        """
        block_offsets = attn_metadata.block_offsets
        quant_policy = attn_metadata.quant_policy

        attn_output = self.paged_attention_fwd(
            query,
            k_cache,
            v_cache,
            cache_seqlens=attn_metadata.kv_seqlens,
            page_table=block_offsets,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            max_seqlen_q=max_q_seqlen,
            softmax_scale=self.scale,
            softcap=self.logit_softcapping,
            window_size=self.sliding_window,
            # custom args
            sinks=learnable_sink,
            alibi_slopes=self.alibi_slopes,
            quant_policy=quant_policy,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
        )
        return attn_output

    def _forward_prefill(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        learnable_sink: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for prefill stage.

        The paged KV cache is first flattened (and dequantized to
        ``query.dtype`` according to ``quant_policy``) into contiguous K/V
        buffers, then a varlen flash-attention kernel runs over them.

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.
            learnable_sink: Learnable sink tokens.

        Returns:
            Attention output tensor.
        """
        block_offsets = attn_metadata.block_offsets
        kv_start_loc = attn_metadata.kv_start_loc
        kv_seqlens = attn_metadata.kv_seqlens
        kv_flatten_size = attn_metadata.kv_flatten_size
        quant_policy = attn_metadata.quant_policy

        # Prepare flattened KV cache
        # assumes dim 1 of k_cache is the per-block token count — TODO confirm
        BLOCK_BS = k_cache.size(1)
        # pad one more block to avoid invalid kv visit
        out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
        kv_layout = 'hsd'  # custom triton kernel requires 'hsd' while fa3 requires 'shd'

        flatten_k, flatten_v = self.flatten_kv_cache(
            k_cache,
            v_cache,
            kv_seqlens,
            block_offsets,
            start_loc=kv_start_loc,
            out_size=out_size,
            out_dtype=query.dtype,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            quant_policy=quant_policy,
            flatten_kv_layout=kv_layout,
        )

        attn_output = self.flash_attention_fwd(
            query,
            flatten_k,
            flatten_v,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            max_seqlen_q=max_q_seqlen,
            max_seqlen_k=attn_metadata.max_kv_seqlen,
            window_size=self.sliding_window,
            softmax_scale=self.scale,
            softcap=self.logit_softcapping,
            causal=self.causal,
            # custom args
            sinks=learnable_sink,
            alibi_slopes=self.alibi_slopes,
            block_sparse_size=self.block_sparse_size,
            kv_layout=kv_layout,
        )
        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        learnable_sink: torch.Tensor = None,
        inplace: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass for attention computation.

        This method handles both prefill and decoding stages by:
        1. Computing max query sequence length
        2. Filling KV cache if new key/value are provided
        3. Dispatching to appropriate stage-specific method

        Args:
            query: Query tensor.
            key: Key tensor (None for decoding-only).
            value: Value tensor (None for decoding-only).
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata containing stage info and indices.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.
            learnable_sink: Learnable sink tokens.
            inplace: Whether to modify query inplace (unused, kept for compatibility).
            **kwargs: Extra arguments; accepted and ignored.

        Returns:
            Attention output tensor.
        """
        # Shared preparation
        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)

        # Fill KV cache with new key/value if provided
        if key is not None and value is not None:
            self._fill_kv_cache_impl(
                key,
                value,
                k_cache=k_cache,
                v_cache=v_cache,
                attn_metadata=attn_metadata,
                max_q_seqlen=max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
            )

        # Validate alibi configuration
        if self.alibi:
            assert self.alibi_slopes is not None, 'alibi_slopes is not set.'

        # Dispatch to stage-specific forward method
        if attn_metadata.is_decoding:
            return self._forward_decoding(
                query,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
                learnable_sink=learnable_sink,
            )
        else:
            return self._forward_prefill(
                query,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
                learnable_sink=learnable_sink,
            )


================================================
FILE: lmdeploy/pytorch/backends/cuda/attention/fa3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata

# Module logger obtained from lmdeploy's shared ``get_logger`` helper.
logger = get_logger('lmdeploy')


class FA3Impl(TritonAttentionImpl):
    """Flash Attention 3 implementation.

    This implementation leverages Flash Attention 3's optimized kernels for both
    prefill and decoding stages. FA3 provides significant performance improvements
    on Hopper architecture (SM90) with CUDA >= 12.3.

    Key features:
    - Optimized prefill using flash_attn_varlen_func
    - Speculative decoding support with multi-token queries
    - Standard single-token decoding with paged attention

    KV-cache filling, KV flattening and single-token decoding reuse the Triton
    kernels bound by ``TritonAttentionImpl``; multi-token decoding and the
    prefill attention itself call the FA3 kernels.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: tuple = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        **kwargs,
    ):
        """Initialize the base implementation and bind the FA3 kernels.

        Raises:
            AssertionError: If ``alibi`` is requested (not supported by FA3).
        """
        assert alibi is False, 'alibi not supported for FA3'
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_size=v_head_size,
            alibi=alibi,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            causal=causal,
            **kwargs,
        )
        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
        self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
        self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache

    def _get_max_q_seqlen(
        self,
        query: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ) -> int:
        """Get max q seqlen.

        Starts from ``query.numel() // (query.size(-1) * query.size(-2))``
        (the total query token count if ``query`` is (num_tokens, num_heads,
        head_dim) — TODO confirm). During decoding this is divided by the batch
        size, i.e. every sequence is assumed to contribute the same number of
        query tokens.
        """
        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
        if attn_metadata.is_decoding:
            batch_size = attn_metadata.q_seqlens.size(0)
            max_q_seqlen = max_q_seqlen // batch_size
        return max_q_seqlen

    def _normalize_sliding_window(self, sliding_window):
        """Normalize sliding window to tuple format.

        Args:
            sliding_window: Sliding window size (None, int, or tuple).

        Returns:
            Tuple of (left_window, right_window) or (-1, -1) if None.
        """
        if sliding_window is None:
            return (-1, -1)
        if isinstance(sliding_window, int):
            return (sliding_window, sliding_window)
        return sliding_window

    def _decoding_speculative(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
    ) -> torch.Tensor:
        """Speculative decoding with multi-token queries.

        This path handles speculative decoding where multiple tokens are generated
        in parallel (max_q_seqlen > 1). Uses FA3's flash_attn_with_kvcache for
        efficient batched computation.

        NOTE(review): no quantization scales/zeros or ``quant_policy`` are passed
        to the FA3 kernel here, so quantized KV caches are presumably not
        supported on this path — confirm.

        Args:
            query: Query tensor to unflatten.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length (> 1).

        Returns:
            Attention output tensor.
        """
        block_offsets = attn_metadata.block_offsets
        sliding_window = self._normalize_sliding_window(self.sliding_window)

        # Reshape query for batched processing: split the token dim into
        # (batch, max_q_seqlen).
        query = query.unflatten(0, (-1, max_q_seqlen))

        attn_output = self.flash_attn_with_kvcache_v3(
            query,
            k_cache,
            v_cache,
            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
            max_seqlen_q=max_q_seqlen,
            scheduler_metadata=attn_metadata.scheduler_metadata,
            page_table=block_offsets,
            softmax_scale=self.scale,
            causal=self.causal,
            window_size=sliding_window,
            softcap=self.logit_softcapping,
        )
        return attn_output

    def _decoding_standard(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
    ) -> torch.Tensor:
        """Standard single-token decoding.

        This path handles standard decoding where only one token is generated
        per request (max_q_seqlen = 1). It uses the paged-attention kernel
        bound by ``TritonAttentionImpl`` (``self.paged_attention_fwd``).

        Args:
            query: Query tensor (single token per request).
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length (= 1).
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.

        Returns:
            Attention output tensor.
        """
        block_offsets = attn_metadata.block_offsets
        quant_policy = attn_metadata.quant_policy

        attn_output = self.paged_attention_fwd(
            query,
            k_cache,
            v_cache,
            cache_seqlens=attn_metadata.kv_seqlens,
            page_table=block_offsets,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            max_seqlen_q=max_q_seqlen,
            scheduler_metadata=attn_metadata.scheduler_metadata,
            softmax_scale=self.scale,
            causal=self.causal,
            softcap=self.logit_softcapping,
            window_size=self.sliding_window,
            # custom args
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            quant_policy=quant_policy,
        )
        return attn_output

    def _forward_decoding(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for decoding stage.

        Supports two decoding modes:
        1. Speculative decoding: Multiple tokens (max_q_seqlen > 1)
        2. Standard decoding: Single token (max_q_seqlen = 1)

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.

        Returns:
            Attention output tensor.
        """
        if max_q_seqlen > 1:
            return self._decoding_speculative(query, k_cache, v_cache, attn_metadata, max_q_seqlen)
        else:
            return self._decoding_standard(query, k_cache, v_cache, attn_metadata, max_q_seqlen, k_scales_zeros,
                                           v_scales_zeros)

    def _forward_prefill(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        max_q_seqlen: int,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for prefill stage.

        Uses FA3's flash_attn_varlen_func for efficient variable-length attention
        computation during the prefill phase. The paged cache is flattened into
        the 'shd' layout that FA3 expects.

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            max_q_seqlen: Maximum query sequence length.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.

        Returns:
            Attention output tensor.
        """
        block_offsets = attn_metadata.block_offsets
        kv_start_loc = attn_metadata.kv_start_loc
        kv_seqlens = attn_metadata.kv_seqlens
        kv_flatten_size = attn_metadata.kv_flatten_size
        quant_policy = attn_metadata.quant_policy

        # Flatten KV cache for varlen attention
        flatten_k, flatten_v = self.flatten_kv_cache(
            k_cache,
            v_cache,
            kv_seqlens,
            block_offsets,
            start_loc=kv_start_loc,
            out_size=kv_flatten_size,
            out_dtype=query.dtype,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            quant_policy=quant_policy,
            flatten_kv_layout='shd',
        )

        sliding_window = self._normalize_sliding_window(self.sliding_window)

        # NOTE(review): max_seqlen_k is given kv_flatten_size (the flattened
        # buffer size), presumably an upper bound on any one sequence's KV length.
        attn_output = self.flash_attn_varlen_func_v3(
            q=query,
            k=flatten_k,
            v=flatten_v,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            max_seqlen_q=max_q_seqlen,
            max_seqlen_k=kv_flatten_size,
            softmax_scale=self.scale,
            causal=self.causal,
            window_size=sliding_window,
            softcap=self.logit_softcapping,
        )
        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        learnable_sink: torch.Tensor = None,
        inplace: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass for FA3 attention computation.

        This method handles both prefill and decoding stages by:
        1. Computing max query sequence length
        2. Filling KV cache if new key/value are provided
        3. Dispatching to appropriate stage-specific method

        Architecture:
        - Decoding: Supports both speculative (multi-token) and standard (single-token)
        - Prefill: Uses flash_attn_varlen_func for efficient varlen attention

        Args:
            query: Query tensor.
            key: Key tensor (None for decoding-only).
            value: Value tensor (None for decoding-only).
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata containing stage info and indices.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.
            learnable_sink: Learnable sink tokens (unused in FA3).
            inplace: Whether to modify query inplace (unused, kept for compatibility).
            **kwargs: Extra keyword arguments, accepted (and ignored) so the
                signature stays compatible with ``TritonAttentionImpl.forward``.

        Returns:
            Attention output tensor.
        """
        # Shared preparation
        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)

        # Fill KV cache with new key/value if provided
        if key is not None and value is not None:
            self._fill_kv_cache_impl(
                key,
                value,
                k_cache=k_cache,
                v_cache=v_cache,
                attn_metadata=attn_metadata,
                max_q_seqlen=max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
            )

        # Dispatch to stage-specific forward method
        if attn_metadata.is_decoding:
            return self._forward_decoding(
                query,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros,
                v_scales_zeros,
            )
        else:
            return self._forward_prefill(
                query,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros,
                v_scales_zeros,
            )


================================================
FILE: lmdeploy/pytorch/backends/cuda/attention/mla.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import functools

import torch

from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata

logger = get_logger('lmdeploy')


def _cdiv(a, b):
    """Perform div up."""
    return (a + b - 1) // b


def _try_dynamic_compile(func, *args, **kwargs):
    """Try to wrap ``func`` with ``torch.compile(dynamic=True)``.

    The compiled callable is invoked once with ``args``/``kwargs`` as a
    warm-up, so failures raised during (lazy) compilation are caught here.
    The warm-up result is discarded.

    Returns:
        The compiled callable if compiling and the warm-up call succeed,
        otherwise the original eager ``func``.
    """
    try:
        compiled_func = torch.compile(func, dynamic=True)
        compiled_func(*args, **kwargs)
    except Exception:
        # Best-effort: keep working in eager mode, but leave a trace so a
        # silent performance fallback can be diagnosed.
        logger.debug('torch.compile failed for %s, falling back to eager.',
                     getattr(func, '__qualname__', func),
                     exc_info=True)
        return func
    return compiled_func


class NSAIndicesUpdater:
    """Convert NSA top-k indices into the layout flash MLA kernels expect.

    Decoding maps logical token indices to slot ids in the paged KV cache.
    Prefill shifts per-sequence indices by each sequence's start offset in the
    flattened KV buffer (``cu_seqlens_k``). In both cases negative indices mark
    unused entries and come out as ``-1``. The conversion functions are wrapped
    with ``torch.compile`` on first use when that succeeds.
    """

    def __init__(self):
        # Conversion callables, resolved (and possibly compiled) lazily.
        self._update_decode_func = None
        self._update_prefill_func = None

    def _update_decode_impl(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor,
                            block_size: int) -> torch.Tensor:
        """Map logical token indices to physical paged-cache slot ids."""
        invalid = nsa_indices < 0
        # Negative indices would produce a negative gather index; clamp them and mask afterwards.
        logical_block = (nsa_indices // block_size).clamp_min(0)
        physical_block = block_offsets.gather(1, logical_block)
        slot_ids = physical_block * block_size + nsa_indices % block_size
        slot_ids[invalid] = -1
        return slot_ids.unsqueeze(1)

    def update_decode(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor, block_size: int) -> torch.Tensor:
        """Convert decoding indices, compiling the conversion on first call."""
        func = self._update_decode_func
        if func is None:
            func = _try_dynamic_compile(self._update_decode_impl, nsa_indices, block_offsets, block_size)
            self._update_decode_func = func
        return func(nsa_indices, block_offsets, block_size)

    def _update_prefill_impl(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):
        """Shift each token's KV indices by its sequence's start in the flattened KV buffer."""
        invalid = nsa_indices < 0
        # One KV start offset per query token: repeat each sequence start q_seqlens[i] times.
        kv_start_per_token = torch.repeat_interleave(cu_seqlens_k[:-1],
                                                     q_seqlens,
                                                     output_size=nsa_indices.size(0))
        shifted = nsa_indices + kv_start_per_token.unsqueeze(1)
        shifted[invalid] = -1
        return shifted.unsqueeze(1)

    def update_prefill(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):
        """Convert prefill indices, compiling the conversion on first call."""
        func = self._update_prefill_func
        if func is None:
            func = _try_dynamic_compile(self._update_prefill_impl, nsa_indices, q_seqlens, cu_seqlens_k)
            self._update_prefill_func = func
        return func(nsa_indices, q_seqlens, cu_seqlens_k)

    @staticmethod
    @functools.cache
    def build():
        """Return the single shared updater instance."""
        return NSAIndicesUpdater()


class FlashMLAImpl(TritonAttentionImpl):
    """Flash MLA (Multi-head Latent Attention) implementation.

    This implementation supports multiple execution paths:
    - Decoding: Uses flash_mla_with_kvcache with paged KV cache
    - Prefill with NSA: Uses flash_mla_sparse_fwd for sparse attention
    - Prefill with FA3: Uses flash_attn_varlen_func with split q_rope/q_nope
    - Prefill fallback: Uses custom Triton kernel
    """

    # MLA-specific constants
    _MLA_HEAD_ALIGNMENT = 64  # flash_mla_sparse_fwd needs the query head count padded to a multiple of this
    _MLA_NOPE_SIZE = 512  # Width of the non-positional (latent) part of each key row
    _MLA_SCALE_SIZE = 16  # FP8 cache: bytes of per-row scales (read as float32) stored after the nope part

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: tuple = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        use_fa3: bool = False,
        **kwargs,
    ):
        """Create the FlashMLA attention impl.

        Args:
            use_fa3: Use the Flash Attention 3 prefill path instead of the
                Triton fallback (when no NSA indices are given).

        Raises:
            AssertionError: If a sliding window or alibi is requested, or if
                ``num_kv_heads`` is not 1.
            ImportError: If ``flash_mla`` is not installed.
        """
        assert (sliding_window is None
                or all(win == -1 for win in sliding_window)), ('sliding window not supported for FlashMLA')
        assert alibi is False, 'alibi not supported for FlashMLA'
        if logit_softcapping > 0.0:
            logger.warning('logit_softcapping not properly supported for FlashMLA, using -1.0')
            logit_softcapping = -1.0
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_size=v_head_size,
            alibi=alibi,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            causal=causal,
            **kwargs,
        )

        import flash_mla

        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8
        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8
        self.flash_mla_with_kvcache = flash_mla.flash_mla_with_kvcache
        # Resolved lazily by _get_flash_mla_sparse_fwd (only needed for NSA prefill).
        self.flash_mla_sparse_fwd = None
        self.fill_kv_cache_blocked_fp8 = fill_kv_cache_blocked_fp8
        self.flatten_kv_cache_mla_fp8 = flatten_kv_cache_mla_fp8
        assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
        self.use_fa3 = use_fa3

        self.nsa_updater = NSAIndicesUpdater.build()

    def _get_flash_mla_sparse_fwd(self):
        """Return ``flash_mla.flash_mla_sparse_fwd``, resolving and caching it on first use.

        Raises:
            RuntimeError: If it cannot be imported from ``flash_mla``.
        """
        if self.flash_mla_sparse_fwd is not None:
            return self.flash_mla_sparse_fwd

        try:
            import flash_mla
            self.flash_mla_sparse_fwd = flash_mla.flash_mla_sparse_fwd
        except Exception as e:
            # Fail loudly here instead of returning None, which would only
            # surface later as an opaque "'NoneType' object is not callable".
            raise RuntimeError('Can not import flash_mla_sparse_fwd from flash_mla.') from e
        return self.flash_mla_sparse_fwd

    def flash_mla_decoding(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        nsa_indices: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ):
        """Run ``flash_mla_with_kvcache`` over the paged KV cache.

        The flattened query tokens are reshaped to ``(batch, tokens_per_seq, ...)``,
        so this assumes every sequence contributes the same number of query
        tokens. When ``nsa_indices`` is given, it is remapped to paged-cache
        slot ids and attention runs non-causally over those indices.

        Returns:
            Attention output with the batch and per-sequence token dims flattened.
        """
        causal = self.causal
        kv_seqlens = attn_metadata.kv_seqlens
        block_offsets = attn_metadata.block_offsets
        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn

        q_seqlens = attn_metadata.q_seqlens
        batch_size = q_seqlens.size(0)
        # Total query tokens, split evenly across the batch.
        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
        max_q_seqlen = max_q_seqlen // batch_size
        query = query.unflatten(0, (batch_size, max_q_seqlen))
        # Narrow to int32 (presumably what flash_mla expects for cache_seqlens).
        if kv_seqlens.dtype == torch.int64:
            kv_seqlens = kv_seqlens.to(torch.int32)

        # update nsa indice according to flash-mla requirement
        if nsa_indices is not None:
            block_size = k_cache.size(1)
            nsa_indices = self.nsa_updater.update_decode(nsa_indices, block_offsets, block_size)
            causal = False

        attn_output, _ = self.flash_mla_with_kvcache(query,
                                                     k_cache=k_cache,
                                                     block_table=block_offsets,
                                                     cache_seqlens=kv_seqlens,
                                                     head_dim_v=self.v_head_size,
                                                     softmax_scale=self.scale,
                                                     tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
                                                     num_splits=attn_metadata.num_splits,
                                                     causal=causal,
                                                     is_fp8_kvcache=is_fp8_kvcache,
                                                     indices=nsa_indices)

        attn_output = attn_output.flatten(0, 1)
        return attn_output

    def _prefill_sparse(self, query: torch.Tensor, flatten_k: torch.Tensor, nsa_indices: torch.Tensor,
                        attn_metadata: TritonAttentionMetadata) -> torch.Tensor:
        """Sparse prefill using flash_mla_sparse_fwd.

        This path is used when NSA (Non-contiguous Sparse Attention) indices are provided.
        Requires FP8 KV cache and flash_mla library.

        Args:
            query: Query tensor; dim 1 is the head dim.
            flatten_k: Flattened key cache.
            nsa_indices: Sparse attention indices.
            attn_metadata: Attention metadata.

        Returns:
            Attention output tensor, trimmed back to the original head count.

        Raises:
            RuntimeError: If flash_mla_sparse_fwd is unavailable.
        """
        q_seqlens = attn_metadata.q_seqlens
        flash_mla_sparse_fwd = self._get_flash_mla_sparse_fwd()

        num_q_heads = query.size(1)
        # flash_mla_sparse_fwd requires query heads to be multiple of alignment
        if num_q_heads % self._MLA_HEAD_ALIGNMENT != 0:
            padding = self._MLA_HEAD_ALIGNMENT - num_q_heads % self._MLA_HEAD_ALIGNMENT
            # Pad along dim -2 (heads) only.
            query = torch.nn.functional.pad(query, (0, 0, 0, padding))

        nsa_indices = self.nsa_updater.update_prefill(nsa_indices, q_seqlens, attn_metadata.cu_seqlens_k)
        output = flash_mla_sparse_fwd(
            query,
            flatten_k,
            nsa_indices,
            sm_scale=self.scale,
        )
        attn_output = output[0]
        # Drop the padded heads.
        attn_output = attn_output[:, :num_q_heads]
        return attn_output

    def _prefill_triton(
        self,
        query: torch.Tensor,
        flatten_k: torch.Tensor,
        flatten_v: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ) -> torch.Tensor:
        """Triton-based prefill fallback.

        This is the fallback path when Flash Attention 3 is not available.
        Uses custom Triton kernel for attention computation.

        Args:
            query: Query tensor.
            flatten_k: Flattened key cache.
            flatten_v: Flattened value cache.
            attn_metadata: Attention metadata.

        Returns:
            Attention output tensor.
        """
        # Total number of query tokens (used as an upper bound on per-sequence length).
        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))

        attn_output = self.flash_attention_fwd(
            query,
            flatten_k,
            flatten_v,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            max_seqlen_q=max_q_seqlen,
            max_seqlen_k=attn_metadata.max_kv_seqlen,
            window_size=self.sliding_window,
            softmax_scale=self.scale,
            softcap=self.logit_softcapping,
            causal=self.causal,
        )

        return attn_output

    def _prefill_fa3(
        self,
        query: torch.Tensor,
        flatten_k: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ) -> torch.Tensor:
        """Flash Attention 3 optimized prefill.

        This path uses Flash Attention 3's optimized kernels with split
        rope (positional) and nope (non-positional) components.

        Args:
            query: Query tensor.
            flatten_k: Flattened key cache.
            attn_metadata: Attention metadata.

        Returns:
            Attention output tensor.
        """
        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
        kv_flatten_size = attn_metadata.kv_flatten_size
        causal = self.causal

        # Split query and key into rope (positional) and nope (non-positional) parts:
        # the first v_head_size channels are nope, the remainder is rope.
        q_rope = query[:, :, self.v_head_size:]
        q_nope = query[:, :, :self.v_head_size]
        k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
        c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
        attn_output = flash_attn_varlen_func(
            q=q_rope,
            k=k_rope,
            v=c_kv,
            qv=q_nope,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            max_seqlen_q=max_q_seqlen,
            max_seqlen_k=kv_flatten_size,
            softmax_scale=self.scale,
            causal=causal,
            window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
        )
        return attn_output

    def run_flatten_kv_cache(self,
                             k_cache: torch.Tensor,
                             v_cache: torch.Tensor,
                             attn_metadata: TritonAttentionMetadata,
                             out_dtype: torch.dtype,
                             is_nsa: bool,
                             k_scales_zeros: torch.Tensor = None,
                             v_scales_zeros: torch.Tensor = None):
        """Gather the paged KV cache into contiguous buffers for prefill.

        FA3 and NSA paths use the unpadded size with ``'shd'`` layout; the
        Triton fallback uses ``'hsd'`` layout padded to whole blocks plus one.
        For an FP8 cache, the value buffer is a view of the first
        ``_MLA_NOPE_SIZE`` channels of the flattened keys.

        Returns:
            Tuple ``(flatten_k, flatten_v)``.
        """

        kv_start_loc = attn_metadata.kv_start_loc
        kv_seqlens = attn_metadata.kv_seqlens
        block_offsets = attn_metadata.block_offsets
        kv_flatten_size = attn_metadata.kv_flatten_size
        quant_policy = attn_metadata.quant_policy
        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
        BLOCK_BS = k_cache.size(1)

        if self.use_fa3 or is_nsa:
            out_size = kv_flatten_size
            flatten_kv_layout = 'shd'
        else:
            # Round up to whole blocks and pad one more block to avoid invalid kv visit.
            out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
            flatten_kv_layout = 'hsd'

        if is_fp8_kvcache:
            flatten_k = self.flatten_kv_cache_mla_fp8(
                k_cache,
                kv_seqlens,
                block_offsets,
                start_loc=kv_start_loc,
                out_size=out_size,
                out_dtype=out_dtype,
                flatten_kv_layout=flatten_kv_layout,
            )
            flatten_v = flatten_k[..., :self._MLA_NOPE_SIZE]
        else:
            flatten_k, flatten_v = self.flatten_kv_cache(
                k_cache,
                v_cache,
                kv_seqlens,
                block_offsets,
                start_loc=kv_start_loc,
                out_size=out_size,
                out_dtype=out_dtype,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
                quant_policy=quant_policy,
                flatten_kv_layout=flatten_kv_layout,
            )

        return flatten_k, flatten_v

    def _get_max_q_seqlen(
        self,
        query: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
    ) -> int:
        """Get max q seqlen.

        For prefill this is the total number of query tokens. For decoding it
        is that total divided evenly by the batch size.
        """
        q_seqlens = attn_metadata.q_seqlens
        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
        batch_size = q_seqlens.size(0)
        if attn_metadata.is_decoding:
            max_q_seqlen = max_q_seqlen // batch_size
        return max_q_seqlen

    def _fill_kv_cache_impl(self,
                            key: torch.Tensor,
                            value: torch.Tensor,
                            k_cache: torch.Tensor,
                            v_cache: torch.Tensor,
                            attn_metadata: TritonAttentionMetadata,
                            max_q_seqlen: int,
                            k_scales_zeros: torch.Tensor = None,
                            v_scales_zeros: torch.Tensor = None):
        """Write new keys (and values) into the KV cache.

        Non-FP8 caches use the parent implementation. For an FP8 cache, each
        row is laid out as ``[nope (fp8) | scales (float32) | pe (key dtype)]``.
        The nope part is quantized blockwise together with its scales, and the
        pe part is written unquantized. ``value`` is not used in this case.

        Raises:
            AssertionError: If an FP8 cache is combined with a non-zero quant_policy.
        """
        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
        if not is_fp8_kvcache:
            return super()._fill_kv_cache_impl(
                key,
                value,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
            )

        block_offsets = attn_metadata.block_offsets
        kv_seqlens = attn_metadata.kv_seqlens
        quant_policy = attn_metadata.quant_policy
        assert quant_policy == 0

        # fill seqlen args
        fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(
            key,
            attn_metadata,
            max_q_seqlen,
        )

        # Split k_cache into nope, scale, and pe components
        scale_offset = self._MLA_NOPE_SIZE
        scale_end = scale_offset + self._MLA_SCALE_SIZE
        k_cache_scale = k_cache[..., scale_offset:scale_end].view(torch.float32)
        k_cache_nope = k_cache[..., :self._MLA_NOPE_SIZE]
        k_cache_pe = k_cache[..., scale_end:].view(key.dtype)
        self.fill_kv_cache_blocked_fp8(
            key[..., :self._MLA_NOPE_SIZE],
            None,
            k_cache_nope,
            None,
            k_cache_scale,
            None,
            cu_seqlen_q=attn_metadata.cu_seqlens_q,
            kv_seqlens=kv_seqlens,
            max_q_seqlen=max_q_seqlen,
            block_offsets=block_offsets,
            group_size=128,
            scale_fmt='ue8m0',
        )
        self.fill_kv_cache(
            key[..., self._MLA_NOPE_SIZE:],
            None,
            k_cache_pe,
            None,
            fill_q_start_loc,
            fill_seqlens,
            kv_seq_length=kv_seqlens,
            max_q_seq_length=fill_max_q_seqlen,
            block_offsets=block_offsets,
        )

    def _forward_decoding(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        nsa_indices: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for decoding stage.

        Uses flash_mla_with_kvcache for efficient decoding with paged KV cache.
        Supports both regular and sparse (NSA) attention patterns.

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            attn_metadata: Attention metadata.
            nsa_indices: Optional sparse attention indices.

        Returns:
            Attention output tensor.
        """
        return self.flash_mla_decoding(query, k_cache, nsa_indices, attn_metadata)

    def _forward_prefill(
        self,
        query: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        nsa_indices: torch.Tensor = None,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for prefill stage.

        Supports three execution paths:
        1. Sparse (NSA + FP8): flash_mla_sparse_fwd for sparse attention
        2. FA3 optimized: flash_attn_varlen_func with split q_rope/q_nope
        3. Triton fallback: Custom Triton kernel implementation

        Args:
            query: Query tensor.
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata.
            nsa_indices: Optional sparse attention indices.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.

        Returns:
            Attention output tensor.
        """
        # Flatten KV cache once for all prefill paths
        flatten_k, flatten_v = self.run_flatten_kv_cache(
            k_cache,
            v_cache,
            attn_metadata,
            out_dtype=query.dtype,
            is_nsa=nsa_indices is not None,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
        )

        # Dispatch to appropriate prefill implementation
        if nsa_indices is not None:
            return self._prefill_sparse(query, flatten_k, nsa_indices, attn_metadata)
        elif self.use_fa3:
            return self._prefill_fa3(query, flatten_k, attn_metadata)
        else:
            return self._prefill_triton(query, flatten_k, flatten_v, attn_metadata)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        nsa_indices: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass for MLA attention computation.

        This method handles both prefill and decoding stages by:
        1. Validating NSA requirements (FP8 KV cache)
        2. Computing max query sequence length
        3. Filling KV cache if new key/value are provided
        4. Dispatching to appropriate stage-specific method

        Architecture:
        - Decoding: Uses flash_mla_with_kvcache with paged KV cache
        - Prefill: Three paths based on availability and requirements
          * Sparse (NSA + FP8): flash_mla_sparse_fwd
          * FA3 optimized: flash_attn_varlen_func with split q_rope/q_nope
          * Triton fallback: Custom triton kernel

        Args:
            query: Query tensor.
            key: Key tensor (None for decoding-only, in which case the cache is not written).
            value: Value tensor (None for decoding-only).
            k_cache: Key cache tensor.
            v_cache: Value cache tensor.
            attn_metadata: Attention metadata containing stage info and indices.
            k_scales_zeros: Key quantization scales/zeros.
            v_scales_zeros: Value quantization scales/zeros.
            nsa_indices: Optional sparse attention indices.

        Returns:
            Attention output tensor.

        Raises:
            AssertionError: If ``nsa_indices`` is given with a non-FP8 KV cache.
        """
        # Validate NSA requirements
        is_nsa = nsa_indices is not None
        if is_nsa:
            is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
            assert is_fp8_kvcache, 'NSA sparse attention requires FP8 KV cache'

        # Shared preparation
        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)

        # Fill KV cache with new key/value if provided. Only `key` is required:
        # the FP8 path ignores `value` and the fill kernels accept a None value.
        if key is not None:
            self._fill_kv_cache_impl(
                key,
                value,
                k_cache,
                v_cache,
                attn_metadata,
                max_q_seqlen,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
            )

        # Dispatch to stage-specific forward method
        if attn_metadata.is_decoding:
            return self._forward_decoding(query, k_cache, attn_metadata, nsa_indices)
        else:
            return self._forward_prefill(
                query,
                k_cache,
                v_cache,
                attn_metadata,
                nsa_indices,
                k_scales_zeros,
                v_scales_zeros,
            )


================================================
FILE: lmdeploy/pytorch/backends/cuda/awq_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

import lmdeploy.pytorch.distributed as dist

from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl


def wq_gemm_forward(
    x,
    qweight,
    qzeros,
    scales,
    w_bit=4,
    group_size=128,
    bias=None,
    out_features=0,
):
    """Run the AWQ 4-bit weight GEMM on ``x``.

    The ``awq_linear`` kernel is called on fp16 activations. A non-fp16 input
    is cast to fp16 first, and the result is cast back to the input dtype.

    Args:
        x: Input activations; all leading dims are flattened for the GEMM.
        qweight: Packed quantized weight.
        qzeros: Packed quantized zero points.
        scales: Dequantization scales.
        w_bit: Accepted for interface compatibility; not used here.
        group_size: Accepted for interface compatibility; not used here.
        bias: Optional bias added to the GEMM output.
        out_features: Size of the output's last dimension.

    Returns:
        Tensor shaped ``x.shape[:-1] + (out_features,)``. A 2-D result gets an
        extra leading dim of size 1.
    """
    from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear
    target_shape = x.shape[:-1] + (out_features, )
    orig_dtype = x.dtype
    needs_cast = orig_dtype != torch.float16

    x2d = (x.half() if needs_cast else x).flatten(0, -2)
    result = awq_linear(x2d, qweight, scales, qzeros)
    if bias is not None:
        result = result + bias
    result = result.reshape(target_shape)

    # Callers expect a 3-D tensor whenever the result would otherwise be 2-D.
    if result.dim() == 2:
        result = result.unsqueeze(0)

    return result.to(dtype=orig_dtype) if needs_cast else result


class AwqLinearW4A16Impl(LinearW4A16Impl):
    """AWQ W4A16 linear implementation backed by the ``awq_linear`` kernel."""

    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):
        # Initialize the base class, consistent with the other linear impls.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size

    def forward(self,
                x,
                qweight: torch.Tensor,
                scales: torch.Tensor,
                qzeros: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Compute the quantized linear output.

        Args:
            x: Input activations.
            qweight: Packed 4-bit weight.
            scales: Dequantization scales; ``scales.size(1)`` gives the output width.
            qzeros: Packed zero points.
            bias: Optional bias.
            all_reduce: If True, all-reduce the output in place over ``group``.
            group: Process group used for the all-reduce.

        Returns:
            The linear output tensor.
        """
        # The output width is taken from the scales tensor, not self.out_features.
        out_features = scales.size(1)
        out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
        if all_reduce:
            dist.all_reduce(out, group=group)
        return out


class AwqLinearW4A16Builder(LinearW4A16Builder):
    """Builder that creates :class:`AwqLinearW4A16Impl` instances."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              w_bit: int,
              group_size: int,
              bias: bool = False,
              dtype: torch.dtype = None):
        """Create an AWQ W4A16 linear impl.

        ``bias`` and ``dtype`` are accepted for interface compatibility and
        are not used.
        """
        impl = AwqLinearW4A16Impl(in_features=in_features,
                                  out_features=out_features,
                                  w_bit=w_bit,
                                  group_size=group_size)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8, deep_gemm_fp8, quant_fp8, quant_fp8_tma
from lmdeploy.utils import get_logger

from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl
from .warmup_manager import WarmupMeta, get_warmup_manager

logger = get_logger('lmdeploy')


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
    """Blocked FP8 linear implemented with Triton kernels."""

    def __init__(self, in_features: int, out_features: int, block_size: int, out_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.block_size = block_size

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[dist.ProcessGroup] = None,
                rank: int = 0,
                scatter_size: List[int] = None):
        """Quantize ``x`` blockwise to FP8 and multiply by the blocked FP8 weight.

        Args:
            x: Input activations; leading dims are flattened for the GEMM and restored afterwards.
            weight: FP8 weight; its dtype selects the activation quantization dtype.
            scale: Weight block scales.
            bias: Optional bias added in place.
            all_reduce: If True, reduce the output across ``group``.
            group: Process group used for the reduction.
            rank: This rank's index, used with ``scatter_size``.
            scatter_size: If given, reduce-scatter by these sizes instead of all-reduce.

        Returns:
            The (possibly reduced) output tensor in ``x``'s dtype.
        """
        x_shape = x.shape
        x = x.flatten(0, -2)
        input_quant, input_scale = quant_fp8(x,
                                             self.block_size,
                                             dtype=weight.dtype,
                                             trans_scale=True,
                                             scale_fmt=self.scale_fmt)

        out = blocked_gemm_fp8(input_quant, input_scale, weight.t(), scale.t(), out_dtype=x.dtype)
        if bias is not None:
            out += bias

        out = out.unflatten(0, x_shape[:-1])

        if all_reduce:
            if scatter_size is not None:
                out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
            else:
                # Reduce over the caller's group (not the default group), matching
                # DeepGemmLinearBlockedF8Impl.
                dist.all_reduce(out, group=group)
        return out


class TritonLinearBlockedF8Builder(LinearBlockedF8Builder):
    """Blocked FP8 linear builder.

    Prefers the DeepGEMM implementation and falls back to the Triton one
    when ``deep_gemm`` cannot be imported.
    """

    @staticmethod
    def build(in_features: int, out_features: int, block_size: int = 128, bias: bool = True, dtype: torch.dtype = None):
        """Build a blocked FP8 linear impl.

        ``bias`` is accepted for interface compatibility and is not used.
        """
        # Only the import is guarded: a bare `except` would also swallow
        # KeyboardInterrupt/SystemExit and mislabel DeepGemm construction errors
        # as import failures.
        try:
            import deep_gemm  # noqa: F401
        except Exception:
            logger.warning('Failed to import deep_gemm, LinearBlockedF8 fallback to triton implementation.')
            return TritonLinearBlockedF8Impl(in_features, out_features, block_size, dtype)
        logger.debug('build with DeepGemmLinearBlockedF8Impl')
        return DeepGemmLinearBlockedF8Impl(in_features, out_features, block_size, dtype)


class DeepGemmLinearBlockedF8Impl(LinearBlockedF8Impl):
    """Blocked FP8 linear backed by DeepGEMM kernels.

    On construction, a warmup callback is registered once per unique
    (in_features, out_features, block_size, out_dtype) configuration.
    """

    def __init__(self, in_features: int, out_features: int, block_size: int, out_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.block_size = block_size

        manager = get_warmup_manager()
        warmup_key = ('deepgemm_blockedfp8_gemm_'
                      f'{in_features}_{out_features}_{block_size}_{out_dtype}')
        if warmup_key not in manager:
            manager[warmup_key] = self.warmup

    def warmup(self, warmup_meta: WarmupMeta):
        """Run the quantize+GEMM path once for each aligned token count up to the maximum."""
        import random

        from lmdeploy.pytorch.third_party.deep_gemm import get_m_alignment_for_contiguous_layout
        device = 'cuda'
        limit = warmup_meta.max_num_tokens
        align = get_m_alignment_for_contiguous_layout()
        in_dim, out_dim = self.in_features, self.out_features
        bs = self.block_size

        weight = torch.empty(out_dim, in_dim, dtype=torch.float8_e4m3fn, device=device)
        scale_shape = ((out_dim + bs - 1) // bs, (in_dim + bs - 1) // bs)
        scale = torch.empty(scale_shape, dtype=torch.float32, device=device)

        # Every multiple of `align` up to max_num_tokens rounded up to `align`;
        # shuffled so ranks might compile different kernels concurrently.
        token_counts = list(range(align, limit + align - 1, align))
        random.shuffle(token_counts)
        for num_tokens in token_counts:
            inputs = torch.empty(num_tokens, in_dim, dtype=self.out_dtype, device=device)
            q_input, q_scale = quant_fp8_tma(inputs, bs, dtype=weight.dtype)
            deep_gemm_fp8(q_input, q_scale, weight, scale, out_dtype=inputs.dtype)

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[dist.ProcessGroup] = None,
                rank: int = 0,
                scatter_size: List[int] = None):
        """Quantize ``x`` to FP8, run the DeepGEMM kernel, and optionally reduce across ``group``."""
        lead_shape = x.shape[:-1]
        x2d = x.flatten(0, -2)
        num_rows = x2d.size(0)
        q_input, q_scale = quant_fp8_tma(x2d, self.block_size, dtype=weight.dtype, scale_fmt=self.scale_fmt)

        out = deep_gemm_fp8(q_input, q_scale, weight, scale, out_dtype=x2d.dtype)
        # Keep only the rows belonging to real inputs (the kernel output may be larger; assumed padding).
        out = out[:num_rows]
        if bias is not None:
            out += bias
        out = out.unflatten(0, lead_shape)

        if not all_reduce:
            return out
        if scatter_size is not None:
            return dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
        dist.all_reduce(out, group=group)
        return out


================================================
FILE: lmdeploy/pytorch/backends/cuda/causal_conv1d.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache

import torch

from ..causal_conv1d import CausalConv1dBuilder, CausalConv1dImpl
from .utils import has_tilelang


class CausalConv1dTilelangImpl(CausalConv1dImpl):
    """CausalConv1d implementation using lmdeploy's ``kernels.cuda.causal_conv1d`` kernels.

    Both methods forward directly to the kernels bound in ``__init__``.
    """

    def __init__(self):
        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update

    def conv1d_fn(self,
                  x: torch.Tensor,
                  weight: torch.Tensor,
                  bias: torch.Tensor | None = None,
                  seq_idx: torch.Tensor | None = None,
                  return_final_states: bool = False,
                  activation: str | None = None):
        """Apply the causal 1-D convolution over a full input."""
        kernel = self.causal_conv1d_fn
        return kernel(
            x,
            weight,
            bias=bias,
            seq_idx=seq_idx,
            return_final_states=return_final_states,
            activation=activation,
        )

    def update_fn(self,
                  x: torch.Tensor,
                  conv_state: torch.Tensor,
                  weight: torch.Tensor,
                  bias: torch.Tensor | None = None,
                  activation: str | None = None,
                  conv_state_indices: torch.Tensor | None = None):
        """Advance the convolution state with new inputs ``x``."""
        kernel = self.causal_conv1d_update
        return kernel(
            x,
            conv_state,
            weight,
            bias=bias,
            activation=activation,
            conv_state_indices=conv_state_indices,
        )


class CausalConv1dDaoImpl(CausalConv1dTilelangImpl):
    """Causal conv1d backed by the external ``causal_conv1d`` package (Dao-AILab).

    Reuses the forwarding methods of :class:`CausalConv1dTilelangImpl` and only
    swaps the kernel functions they dispatch to.

    Raises:
        RuntimeError: if the package or one of its entry points cannot be
            loaded; the underlying error is chained as ``__cause__``.
    """

    def __init__(self):
        try:
            import causal_conv1d
            conv_fn = causal_conv1d.causal_conv1d_fn
            conv_update = causal_conv1d.causal_conv1d_update
        except Exception as err:
            # Chain the real failure: a present-but-broken install (ABI mismatch,
            # missing symbol) must stay diagnosable, not look like "not installed".
            raise RuntimeError(
                'causal_conv1d is not installed, please refer to https://github.com/Dao-AILab/causal-conv1d') from err
        self.causal_conv1d_fn = conv_fn
        self.causal_conv1d_update = conv_update


@lru_cache
def has_dao():
    """Tell whether the Dao-AILab ``causal_conv1d`` package is usable.

    Both ``causal_conv1d_fn`` and ``causal_conv1d_update`` must be reachable;
    any failure while importing or looking them up counts as "unavailable".
    The answer is computed once and cached.
    """
    try:
        import causal_conv1d  # noqa: F401
        _ = (causal_conv1d.causal_conv1d_fn, causal_conv1d.causal_conv1d_update)
    except Exception:
        return False
    return True


class CausalConv1dCudaBuilder(CausalConv1dBuilder):
    """Select a CUDA causal-conv1d implementation.

    Preference order: lmdeploy's tilelang kernels, then the external
    ``causal_conv1d`` package.
    """

    @staticmethod
    def build() -> CausalConv1dImpl:
        """Instantiate the first implementation whose availability probe succeeds.

        Raises:
            RuntimeError: if no provider is available.
        """
        providers = (
            (has_tilelang, CausalConv1dTilelangImpl),
            (has_dao, CausalConv1dDaoImpl),
        )
        for is_available, impl_cls in providers:
            if is_available():
                return impl_cls()
        raise RuntimeError('No available implementation for CausalConv1d, '
                           'please install https://tilelang.com/ or https://github.com/Dao-AILab/causal-conv1d')


================================================
FILE: lmdeploy/pytorch/backends/cuda/flash_attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class TritonFlashAttentionImpl(FlashAttentionImpl):
    """Variable-length flash attention using lmdeploy's Triton ``flash_attn_varlen_func``."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logit_softcapping: float = None,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Softmax scale defaults to 1/sqrt(head_dim).
        self.scale = 1.0 / (head_dim**0.5) if scale is None else scale
        # No explicit KV head count means one KV head per query head.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # Value head dim defaults to the query/key head dim.
        self.v_head_dim = head_dim if v_head_dim is None else v_head_dim
        self.causal = causal
        self.sliding_window = sliding_window
        self.logit_softcapping = logit_softcapping

        from lmdeploy.pytorch.kernels.cuda import flash_attn_varlen_func
        self.flash_attention_fwd = flash_attn_varlen_func

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """Attend over packed variable-length sequences.

        Sequence boundaries come from the ``*_start_loc`` / ``*_seqlens``
        tensors. Keys/values are passed with ``kv_layout='shd'`` (presumably
        (seq, head, dim) -- TODO confirm against the kernel).
        """
        kernel_kwargs = dict(
            q_start_loc=q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            max_seqlen_q=max_q_seqlen,
            window_size=self.sliding_window,
            softmax_scale=self.scale,
            softcap=self.logit_softcapping,
            causal=self.causal,
            kv_layout='shd',
        )
        return self.flash_attention_fwd(query, key, value, **kernel_kwargs)


class TritonFlashAttentionBuilder(FlashAttentionBuilder):
    """Builder producing :class:`TritonFlashAttentionImpl` instances."""

    @staticmethod
    def build(
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logit_softcapping: float = None,
        **kwargs,
    ) -> FlashAttentionImpl:
        """Create the Triton implementation; any extra ``kwargs`` are ignored."""
        impl_kwargs = dict(
            num_heads=num_heads,
            head_dim=head_dim,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_dim=v_head_dim,
            causal=causal,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
        )
        return TritonFlashAttentionImpl(**impl_kwargs)


================================================
FILE: lmdeploy/pytorch/backends/cuda/gated_delta_rule.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache

import torch

from ..gated_delta_rule import GatedDeltaRuleBuilder, GatedDeltaRuleImpl
from .utils import has_tilelang


@lru_cache
def has_fla():
    """Tell (once, cached) whether fla's ``chunk_gated_delta_rule`` is importable.

    Any exception raised by the import is treated as "not available".
    """
    try:
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule  # noqa: F401
    except Exception:
        return False
    return True


class CudaGatedDeltaRuleImpl(GatedDeltaRuleImpl):
    """Gated delta rule on CUDA.

    The chunked path uses fla's ``chunk_gated_delta_rule``; the recurrent path
    uses lmdeploy's ``fused_recurrent_gated_delta_rule`` kernel. Both fla and
    tilelang must be importable.
    """

    def __init__(self):
        if not has_fla() or not has_tilelang():
            raise ImportError('fla and tilelang is required for CudaGatedDeltaRuleImpl')
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule

        from lmdeploy.pytorch.kernels.cuda.gated_delta_rule import fused_recurrent_gated_delta_rule
        self.chunk_func = chunk_gated_delta_rule
        self.recurrent_func = fused_recurrent_gated_delta_rule

    def chunk_gated_delta_rule(self,
                               q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor | None = None,
                               beta: torch.Tensor | None = None,
                               initial_state: torch.Tensor | None = None,
                               state_indices: torch.Tensor | None = None,
                               scale: float | None = None,
                               use_qk_l2norm_in_kernel: bool = False,
                               cu_seqlens: torch.Tensor | None = None,
                               output_final_state: bool = False):
        """Chunked gated delta rule that updates the state cache in place.

        Args:
            initial_state: state cache whose dim-0 rows are selected by
                ``state_indices``; required.
            state_indices: rows of ``initial_state`` used by this batch; required.
            use_qk_l2norm_in_kernel: when True, ``q``/``k`` are L2-normalized
                here (outside fla) before the kernel call.
            output_final_state: when True, return the updated state cache.

        Returns:
            ``(core_attn_out, last_state)``. ``last_state`` is the whole
            ``initial_state`` tensor after the in-place write-back
            (``index_copy_`` returns its receiver), or ``None`` when
            ``output_final_state`` is False. The cache is updated in both cases.
        """
        assert initial_state is not None, 'initial_state is required.'
        assert state_indices is not None, 'state_indices is required.'
        recurrent_state = initial_state
        init_state = recurrent_state.index_select(0, state_indices)
        if use_qk_l2norm_in_kernel:
            # l2norm in fla would recompile when seqlen changed, so normalize here instead.
            q = torch.nn.functional.normalize(q, p=2, dim=-1)
            k = torch.nn.functional.normalize(k, p=2, dim=-1)
        # Always request the final state. It must be written back into the cache
        # even when the caller does not want it returned, and fla returns None as
        # the final state when output_final_state=False.
        core_attn_out, last_state = self.chunk_func(
            q,
            k,
            v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=init_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=False,
            cu_seqlens=cu_seqlens,
        )

        last_state = recurrent_state.index_copy_(0, state_indices, last_state.to(recurrent_state.dtype))
        if not output_final_state:
            last_state = None
        return core_attn_out, last_state

    def fused_recurrent_gated_delta_rule(self,
                                         q: torch.Tensor,
                                         k: torch.Tensor,
                                         v: torch.Tensor,
                                         g: torch.Tensor | None = None,
                                         beta: torch.Tensor | None = None,
                                         initial_state: torch.Tensor | None = None,
                                         state_indices: torch.Tensor | None = None,
                                         scale: float | None = None,
                                         use_qk_l2norm_in_kernel: bool = False,
                                         output_final_state: bool = False):
        """Recurrent gated delta rule.

        All arguments, including ``state_indices``, are forwarded unchanged to
        lmdeploy's ``fused_recurrent_gated_delta_rule`` kernel.
        """
        return self.recurrent_func(
            q,
            k,
            v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            state_indices=state_indices,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            output_final_state=output_final_state,
        )


class CudaGatedDeltaRuleBuilder(GatedDeltaRuleBuilder):
    """Builder for the CUDA gated delta rule implementation."""

    @staticmethod
    def build() -> GatedDeltaRuleImpl:
        """Return a new :class:`CudaGatedDeltaRuleImpl` (raises ImportError if its deps are missing)."""
        impl = CudaGatedDeltaRuleImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/cuda/graph_runner.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Any, Dict, List, Tuple

import torch
from torch.profiler import record_function

from lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend
from lmdeploy.pytorch.backends.selector import get_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
from lmdeploy.pytorch.strategies.base import StrategyFactoryBase
from lmdeploy.utils import get_logger

from ..graph_runner import GraphRunner
from .attention import TritonAttentionMetadata

logger = get_logger('lmdeploy')


def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n.

    Exact for integers of any size; the previous fixed shift cascade only
    filled 64 bits and gave wrong (non power-of-2) results for ``n > 2**64``.
    Non-positive ``n`` yields ``0``, as before.
    """
    import operator

    # operator.index accepts int-like values (e.g. numpy integers) but rejects floats,
    # matching the old bitwise implementation's input domain.
    n = operator.index(n)
    if n <= 0:
        return 0
    return 1 << (n - 1).bit_length()


@functools.lru_cache
def _get_capture_batch_size_impl(max_batches: int):
    """Capture batch size."""
    ret = []
    batch_size = 1
    batch_step = 256
    # power of 2
    while batch_size <= min(batch_step, max_batches):
        ret.append(batch_size)
        batch_size *= 2

    # step
    ret += list(range(batch_size, max_batches + 1, batch_step))

    if max_batches != ret[-1]:
        ret.append(max_batches)
    return ret


def _false(*args, **kwargs):
    """Default value of not support cuda graph."""
    return False


class CUDASingleGraphRunner:
    """Cuda single graph runner.

    Owns one captured ``torch.cuda.CUDAGraph`` together with its
    :class:`CudaGraphMeta` (padded batch size, token budget, decoding flag,
    static input/output buffers). The model supplies the buffer and context
    hooks: ``make_buffers_cudagraph``, ``fill_buffers_cudagraph``,
    ``update_context_cudagraph``, ``make_output_buffers`` and
    ``get_outputs_cudagraph``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        max_batches: int,
        max_tokens: int,
        num_blocks: int,
        is_decoding: bool,
        pool: Tuple[int, int],
        model_config: ModelConfig,
        device: torch.device,
        decode_query_len: int = 1,
    ):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.model_config = model_config

        self.meta = CudaGraphMeta(
            max_batchs=max_batches,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            is_decoding=is_decoding,
            device=device,
            input_buffers=dict(),
            output_buffers=dict(),
            vocab_size=self.model_config.vocab_size,
            use_mla_fp8_cache=getattr(self.model_config, 'use_mla_fp8_cache', False),
            use_flash_mla=getattr(self.model_config, 'use_flash_mla', False),
            mla_index_topk=getattr(self.model_config, 'mla_index_topk', None),
            decode_query_len=decode_query_len,
            use_fa3_decoding=model_config.model_paradigm == 'ar_spec',
        )
        self.device = device
        self.max_batches = max_batches
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.is_decoding = is_decoding
        self.pool = pool
        self._graph: torch.cuda.CUDAGraph = None

    @record_function('capture_cudagraph')
    def capture(self, **kwargs):
        """Capture graph.

        Allocates the static input buffers and copies ``kwargs`` into them. It
        then runs one eager warmup forward and records the same forward into a
        new CUDA graph on the current stream, using the shared memory ``pool``.

        Returns:
            Outputs built from the eager warmup run. Work recorded during
            capture is not executed, so the warmup result is the real result
            for this step.
        """
        logger.debug(f'Capturing graph with meta: {self.meta}')
        self.meta.input_buffers = self.model.make_buffers_cudagraph(self.meta, **kwargs)
        padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        current_stream = torch.cuda.current_stream()

        # warmup
        warmup_output = self.model(**padded_kwargs)
        warmup_buffers = self.model.make_output_buffers(warmup_output)

        self._graph = torch.cuda.CUDAGraph()
        # Unsafe CUDA calls from other threads could invalidate the capture,
        # so use the 'thread_local' capture error mode here.
        with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
            output = self.model(**padded_kwargs)

        output_buffers = self.model.make_output_buffers(output)
        self.meta.output_buffers = output_buffers
        output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs)
        return output

    @record_function('forward_cudagraph')
    def forward(self, **kwargs):
        """Replay the captured graph for new inputs.

        Copies ``kwargs`` into the static input buffers, refreshes the context,
        replays the graph and reads the results from the captured output
        buffers.
        """
        assert self._graph is not None
        self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        self._graph.replay()
        output_buffers = self.meta.output_buffers
        output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
        return output

    def __del__(self):
        """Drop the reference to the captured graph.

        ``__init__`` may raise before ``_graph`` is assigned; ``pop`` avoids a
        spurious ``AttributeError`` during garbage collection in that case.
        """
        self.__dict__.pop('_graph', None)


class CUDAGraphRunner(GraphRunner):
    """Cuda graph runner.

    Keeps one :class:`CUDASingleGraphRunner` per graph key (padded batch size,
    decoding flag, microbatch flag, per-sequence query length). A key's graph
    is captured on first use and replayed afterwards. Inputs rejected by the
    ``enable_graph`` predicate run eagerly.
    """

    def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                 backend_config: BackendConfig, device: torch.device):
        super().__init__(model, model_config, cache_config, backend_config, device)
        self.max_batches = cache_config.max_batches
        self.max_tokens = cache_config.max_prefill_token_num
        self.num_blocks = cache_config.num_gpu_blocks

        # Predicate called with the forward kwargs; decides per call whether a graph may be used.
        self.enable_graph = self.check_enable_graph()

        # One memory pool shared by every captured graph.
        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
        self.has_try_compile_model: bool = False

        # strategy factory
        build_ctx = model.ctx_mgr.build_ctx
        strategy_factory: StrategyFactoryBase = build_ctx.strategy_factory
        self.cudagraph_strategy = strategy_factory.build_cudagraph_strategy()

    def check_enable_graph(self):
        """Return the predicate deciding whether cuda graph can be used.

        Eager mode always disables graphs. Otherwise the model's
        ``support_cuda_graph`` is used, falling back to "never" when the model
        does not define it.
        """
        if self.backend_config.eager_mode:
            return _false

        return getattr(self.model, 'support_cuda_graph', _false)

    def _try_compile_model_once(self):
        """Mark model compilation as attempted (compilation itself is currently disabled)."""
        if self.has_try_compile_model:
            return

        # TODO: recovery it when torch.compile is stable (should be add a flag to enable it?)
        # if hasattr(self.model, 'compile_model'):
        #     method = getattr(self.model, 'compile_model')
        #     method()

        self.has_try_compile_model = True

    def _get_capture_tokens(self, batch_size: int):
        """Round ``batch_size`` up to the smallest captured batch size.

        Raises:
            ValueError: if ``batch_size`` exceeds every capture size.
        """
        cap_sizes = self.get_capture_batch_sizes()
        for size in cap_sizes:
            if size >= batch_size:
                return size
        # Explicit raise instead of ``assert False``: asserts are stripped under
        # ``python -O``, which would silently return None here.
        raise ValueError(f'Unsupported batch_size={batch_size}')

    def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
                      attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):
        """Get graph key.

        Returns:
            ``(padded_batch_size, is_decoding, enable_microbatch, query_len)``.
            The batch is padded to a capture size, using ``meta.padding_batch_size``
            when one is set.
        """
        context = self.ctx_mgr.current_context()
        is_decoding = context.is_decoding
        batch_size = attn_metadata.q_seqlens.size(0)
        meta = self.get_meta()
        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
        # for draft model to distinguish inputs from target model and itself
        query_len = input_ids.size(1) // batch_size
        if meta.padding_batch_size is None:
            batch_size = self._get_capture_tokens(batch_size)
        else:
            batch_size = self._get_capture_tokens(meta.padding_batch_size)
        return (batch_size, is_decoding, enable_microbatch, query_len)

    def _prepare_inputs(self, **kwargs):
        """Prepare inputs: ensure ``attn_metadata.block_offsets`` is int32."""
        assert 'attn_metadata' in kwargs, 'attn_metadata is required for cudagraph.'
        attn_metadata: TritonAttentionMetadata = kwargs['attn_metadata']
        if attn_metadata.block_offsets.dtype != torch.int32:
            attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)
        return kwargs

    def _get_max_tokens(self, graph_key: tuple, input_ids: torch.Tensor, q_seqlens: torch.Tensor):
        """Token budget for a new decoding graph, delegated to the cudagraph strategy."""
        max_batches = graph_key[0]
        is_decoding = graph_key[1]
        assert is_decoding
        origin_batch_size = q_seqlens.size(0)
        num_tokens = input_ids.size(1)
        return self.cudagraph_strategy.get_max_tokens(max_batches, origin_batch_size, num_tokens)

    def __call__(self, **kwargs):
        """Run the model, through a cuda graph when enabled.

        The first call for a graph key captures a graph and returns the warmup
        outputs; later calls with the same key replay it.
        """
        if not self.backend_config.eager_mode and get_backend().get_name() == 'cuda':
            self._try_compile_model_once()

        kwargs = self._prepare_inputs(**kwargs)
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
            with record_function('forward_eager'):
                output = self.model(**kwargs)
                return self.model.make_output_buffers(output)

        graph_key = self.get_graph_key(**kwargs)
        max_batches = graph_key[0]
        is_decoding = graph_key[1]
        decode_query_len = graph_key[3]
        if graph_key not in self._runner_map:
            max_tokens = self._get_max_tokens(graph_key, kwargs['input_ids'], kwargs['attn_metadata'].q_seqlens)
            runner = CUDASingleGraphRunner(
                self.model,
                max_batches=max_batches,
                max_tokens=max_tokens,
                num_blocks=self.num_blocks,
                is_decoding=is_decoding,
                pool=self.graph_pool_handle,
                model_config=self.model_config,
                device=self.device,
                decode_query_len=decode_query_len,
            )
            output = runner.capture(**kwargs)
            self._runner_map[graph_key] = runner
            # Return the capture (warmup) output instead of replaying: SSM layers
            # already updated their state during the warmup run, so a replay
            # would apply the update a second time.
            return output
        else:
            runner = self._runner_map[graph_key]
            output = runner.forward(**kwargs)
            return output

    @record_function('prepare_inputs_for_generation')
    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare inputs.

        With the DeepEP MoE backend, it first switches DeepEP to low-latency
        mode for decoding steps and to normal mode otherwise.
        """

        if get_moe_backend().use_deepep_moe_backend():
            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode
            deepep_mode = DeepEPMode.LOW_LATENCY if context.is_decoding else DeepEPMode.NORMAL
            DeepEPBuffer.set_deepep_mode(deepep_mode)

        return self.model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
        )

    def reset(self):
        """Remove all graphs to prevent hanging on exit."""
        self._runner_map.clear()
        if get_moe_backend().use_deepep_moe_backend():
            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer

            if hasattr(DeepEPBuffer, 'destroy'):
                from torch import distributed as dist

                DeepEPBuffer.destroy()
                dist.barrier()

    def update_inputs(self, inputs):
        """Update inputs.

        For data-parallel decoding, syncs the tp size to the capture size
        covering ``meta.padding_batch_size``.
        """
        if self.backend_config.eager_mode:
            return inputs
        is_decoding = inputs.is_decoding
        dp_meta = inputs.dp_meta
        if is_decoding and dp_meta is not None:
            meta = self.get_meta()
            padding_batch_size = meta.padding_batch_size
            tp_size = self._get_capture_tokens(padding_batch_size)
            dp_meta.sync_tp_size(tp_size)
        return inputs

    def get_capture_batch_sizes(self) -> List[int]:
        """Capture batch sizes (cached list shared across calls; do not mutate)."""
        return _get_capture_batch_size_impl(self.cache_config.max_batches)


================================================
FILE: lmdeploy/pytorch/backends/cuda/lora.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass

import torch

from lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora
from lmdeploy.pytorch.model_inputs import StepContextManager

from ..lora import AdapterInfo, LoRABuilder, LoRAImpl


@dataclass
class PackedLoRAInput:
    """Packed lora input.

    Token activations flattened to 2-D, plus the per-sequence metadata that
    the fused LoRA kernel reads.
    """
    # Activations with all leading dims flattened (last dim kept); made contiguous by the producer.
    x: torch.Tensor
    # Per-sequence start offsets into the flattened tokens (taken from the step context).
    q_start_loc: torch.Tensor
    # Per-sequence query lengths (taken from the step context).
    q_seqlens: torch.Tensor
    # Adapter id per sequence (the context's ``local_adapter_ids``).
    adapter_ids: torch.Tensor
    # Upper bound on sequence length; TritonLoRAImpl fills it with the total token count.
    max_seq_len: int
    # Whether the current step is a decoding step.
    is_decoding: bool


class TritonLoRAImpl(LoRAImpl):
    """LoRA forward using the fused Triton kernel ``fused_lora``."""

    @staticmethod
    def _make_packed_lora_input(x, ctx_mgr):
        """Bundle ``x`` with the current step context's sequence and adapter metadata."""
        ctx = ctx_mgr.current_context()

        # Total number of tokens in ``x``, used as an upper bound on the max sequence length.
        num_tokens = x.numel() // x.size(-1)
        flat_x = x.flatten(0, -2).contiguous()

        return PackedLoRAInput(
            x=flat_x,
            q_start_loc=ctx.q_start_loc,
            q_seqlens=ctx.q_seqlens,
            adapter_ids=ctx.local_adapter_ids,
            max_seq_len=num_tokens,
            is_decoding=ctx.is_decoding,
        )

    def forward(self,
                x: torch.Tensor,
                lora_A: torch.Tensor,
                lora_B: torch.Tensor,
                base_output: torch.Tensor,
                adapter_info: AdapterInfo,
                ctx_mgr: StepContextManager,
                colwise: bool,
                is_tp: bool = True):
        """Add the LoRA delta into ``base_output[..., adapter_info.base_slice]`` and return ``base_output``.

        If ``base_output`` is contiguous, the kernel accumulates straight into
        the sliced view (``cum=True``). Otherwise the kernel's result is
        reshaped and added to the slice afterwards. ``colwise`` and ``is_tp``
        are not used here.
        """
        packed = self._make_packed_lora_input(x, ctx_mgr)
        target = base_output[..., adapter_info.base_slice]
        accumulate_in_place = base_output.is_contiguous()

        lora_out = fused_lora(
            packed.x,
            lora_A,
            lora_B,
            scaling=adapter_info.scalings,
            rank_start=adapter_info.rank_offsets,
            ranks=adapter_info.ranks,
            seq_start=packed.q_start_loc,
            seq_lens=packed.q_seqlens,
            adapter_ids=packed.adapter_ids,
            max_rank=adapter_info.max_rank,
            max_seqlen=packed.max_seq_len,
            output=target.flatten(0, -2) if accumulate_in_place else None,
            cum=accumulate_in_place,
        )

        if not accumulate_in_place:
            target.add_(lora_out.reshape(target.shape))
        return base_output


class TritonLoRABuilder(LoRABuilder):
    """Builder returning the Triton LoRA implementation."""

    @staticmethod
    def build():
        """Return a fresh :class:`TritonLoRAImpl`."""
        impl = TritonLoRAImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .blocked_fp8 import TritonFusedMoEBlockedF8Builder  # noqa: F401
from .default import TritonFusedMoEBuilder  # noqa: F401
from .w8a8 import TritonFusedMoEW8A8Builder  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Callable, List

import torch
import torch.distributed as dist

from lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend
from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
    """Fused MoE with block-wise quantized activations, computed by Triton kernels."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.renormalize = renormalize
        self.block_size = block_size
        self.out_dtype = out_dtype

    def ep_expert_list(self, world_size: int, rank: int):
        """Contiguous ids of the experts owned by ``rank``.

        Experts are split into ceil(num_experts / world_size) sized chunks; the
        last ranks may get fewer (or none).
        """
        per_rank = -(-self.num_experts // world_size)  # ceil division
        start = rank * per_rank
        stop = min(start + per_rank, self.num_experts)
        return list(range(start, stop))

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """Run the fused MoE over the tokens of ``hidden_states``.

        Activations are block-quantized (``self.block_size``) to
        ``gate_up_weights.dtype`` before the kernel call. The output keeps the
        input's dtype and leading shape.
        """
        lead_shape = hidden_states.shape[:-1]
        tokens = hidden_states.flatten(0, -2)
        x_q, x_scale = quant_fp8(tokens, self.block_size, dtype=gate_up_weights.dtype, scale_fmt=self.scale_fmt)

        # When this rank holds only a subset of experts, pass the first local id
        # and the global expert count (assumes expert_list is a contiguous range -- TODO confirm).
        partial = expert_list is not None and len(expert_list) != self.num_experts
        expert_offset = expert_list[0] if partial else 0
        num_experts = self.num_experts if partial else None

        output = fused_moe_blocked_fp8(x_q,
                                       x_scale,
                                       gate_up_weights,
                                       gate_up_scale,
                                       down_weights,
                                       down_scale,
                                       topk_weights=topk_weights,
                                       topk_ids=topk_ids,
                                       topk=self.top_k,
                                       w1_bias=gate_up_bias,
                                       w2_bias=down_bias,
                                       out_dtype=tokens.dtype,
                                       expert_offset=expert_offset,
                                       num_experts=num_experts,
                                       renormalize=self.renormalize,
                                       act_func=act_func)
        return output.unflatten(0, lead_shape)


class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):
    """Expert-parallel blocked fp8 MoE routed through dlblas' DeepEP MoE layers.

    A failed import of DeepGEMM, or of DeepEP's token dispatcher, is reported
    as a performance warning rather than an error.
    """

    def __init__(self,
                 ep_size: int,
                 ep_group: dist.ProcessGroup,
                 top_k: int,
                 num_experts: int,
                 hidden_dim: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.bfloat16,
                 layer_idx: int = 0):
        # The parent records num_experts, top_k, renormalize, block_size and out_dtype.
        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.layer_idx = layer_idx

        try:
            import deep_gemm  # noqa: F401
        except ImportError:
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
        else:
            self.use_deep_gemm = True

        try:
            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
            get_moe_backend().set_deepep_moe_backend()
            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
                DeepEPBuffer.set_explicitly_destroy()
        except ImportError:
            logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')

        # Build once up front (low-latency mode) so buffers are pre-allocated before the first forward.
        self.fusedmoe_build(True)

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts owned by ``rank``.

        With EPLB enabled, this is the logical-id slice of the layer's
        physical-to-logical map for the rank's contiguous block of slots.
        Otherwise it is the parent's even split.
        """
        if not get_dist_manager().current_context().dist_config.enable_eplb:
            return super().ep_expert_list(world_size=world_size, rank=rank)

        from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer
        phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)
        slots_per_rank = (self.num_experts + world_size - 1) // world_size
        begin = rank * slots_per_rank
        end = min(begin + slots_per_rank, self.num_experts)
        return phy2log[begin:end].tolist()

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None,
                **kwargs):
        """Dispatch tokens to the DeepEP MoE and gather the results back.

        ``gate_up_bias``, ``down_bias``, ``act_func`` and extra ``kwargs`` are
        not forwarded to the DeepEP MoE.
        """
        split = split_inputs_by_attn_tp(hidden_states, topk_weights, topk_ids)
        hidden_states, topk_weights, topk_ids, split_size = split

        topk_weights = self.do_renormalize(topk_weights)
        # Low-latency kernels only for decoding steps, and only when DeepGEMM is available.
        is_decoding = get_step_ctx_manager().current_context().is_decoding
        moe = self.fusedmoe_build(is_decoding and self.use_deep_gemm)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
                                 down_scale, expert_list)
        return gather_outputs_by_attn_tp(out_states, split_size)

    def do_renormalize(self, topk_weights):
        """Renormalize routing weights when ``self.renormalize`` is set (via ``_renormalize``)."""
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        """Return dlblas' DeepEP MoE for the requested mode (any reuse/caching is up to dlblas)."""
        from dlblas.layers.moe.ep_moe import build_deepep_moe
        chunk_size = 16 * 1024
        return build_deepep_moe(low_latency_mode,
                                self.ep_size,
                                self.ep_group,
                                self.num_experts,
                                self.hidden_dim,
                                self.block_size,
                                self.top_k,
                                self.out_dtype,
                                layer_idx=self.layer_idx,
                                chunk_size=chunk_size)


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
    """Builder for the blocked-FP8 fused MoE implementations on CUDA.

    Without expert parallelism the Triton kernel implementation is returned;
    with ``ep_size > 1`` the DeepEP-based implementation is returned instead.
    """

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              hidden_dim: int = 1,
              renormalize: bool = False,
              block_size: int = 128,
              ep_size: int = 1,
              ep_group: dist.ProcessGroup = None,
              out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0,
              custom_gateup_act: bool = False):
        """Build from mlp.

        Args:
            top_k: number of experts selected per token.
            num_experts: total number of experts.
            hidden_dim: hidden size (only used by the EP implementation).
            renormalize: whether routing weights are renormalized.
            block_size: quantization block size of the FP8 weights.
            ep_size: expert-parallel world size.
            ep_group: expert-parallel process group (EP path only).
            out_dtype: output dtype.
            layer_idx: layer index (EP path only).
            custom_gateup_act: must be ``False`` on the EP path.

        Raises:
            AssertionError: if ``ep_size > 1`` and ``custom_gateup_act`` is not ``False``.
        """
        if ep_size <= 1:
            # NOTE(review): hidden_dim, ep_group, layer_idx and custom_gateup_act are
            # not forwarded on this path — confirm that is intended.
            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
                                               num_experts=num_experts,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype)

        assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'
        return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
                                           ep_group=ep_group,
                                           top_k=top_k,
                                           num_experts=num_experts,
                                           hidden_dim=hidden_dim,
                                           renormalize=renormalize,
                                           block_size=block_size,
                                           out_dtype=out_dtype,
                                           layer_idx=layer_idx)


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe/default.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Callable, List, Optional

import torch

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend
from lmdeploy.pytorch.backends.moe import FusedMoEBuilder, FusedMoEImpl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda import fused_moe
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEImpl(FusedMoEImpl):
    """Fused MoE implementation backed by the Triton ``fused_moe`` kernel."""

    def __init__(self, top_k: int, num_experts: int, renormalize: bool = False):
        self.top_k = top_k
        self.num_experts = num_experts
        self.renormalize = renormalize

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
        """Re-layout the expert weights without changing their logical shape.

        ``transpose(1, 2).contiguous().transpose(1, 2)`` swaps the memory order
        of dims 1 and 2, so dim 1 becomes the contiguous one.
        """

        def _swap_storage_order(weight: torch.Tensor) -> torch.Tensor:
            return weight.transpose(1, 2).contiguous().transpose(1, 2)

        return _swap_storage_order(gate_up_weights), _swap_storage_order(down_weights)

    def ep_expert_list(self, world_size: int, rank: int):
        """Return the contiguous expert ids owned by ``rank``.

        Experts are handed out in chunks of ``ceil(num_experts / world_size)``,
        so the last rank(s) may own fewer experts, or none.
        """
        total = self.num_experts
        chunk = -(-total // world_size)  # ceil division
        start = rank * chunk
        stop = min(start + chunk, total)
        return list(range(start, stop))

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                down_weights: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward.

        When ``expert_list`` covers only part of the experts, its first id is
        passed to the kernel as ``expert_offset`` together with the total
        ``num_experts``. Otherwise the offset is 0 and ``num_experts`` is None.
        """
        holds_subset = expert_list is not None and len(expert_list) != self.num_experts
        expert_offset = expert_list[0] if holds_subset else 0
        num_experts = self.num_experts if holds_subset else None
        return fused_moe(hidden_states,
                         gate_up_weights,
                         down_weights,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         topk=self.top_k,
                         w1_bias=gate_up_bias,
                         w2_bias=down_bias,
                         expert_offset=expert_offset,
                         num_experts=num_experts,
                         renormalize=self.renormalize,
                         act_func=act_func)


# Modified from dlblas: https://github.com/DeepLink-org/DLBlas
class FusedMoENormal:
    """DeepEP normal-mode MoE wrapper around dlblas' ``DeepEPTokenDispatcherNormal``."""

    def __init__(
        self,
        ep_size: int,
        ep_group: dist.ProcessGroup,
        num_experts: int,
        hidden_dim: int,
        layer_index: int = 0,
        top_k: int = 8,
        out_dtype: torch.dtype = torch.bfloat16,
    ):
        """Create the normal-mode token dispatcher.

        Each rank owns ``num_experts // ep_size`` experts.
        """
        from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherNormal
        local_experts = num_experts // ep_size
        self.layer_index = layer_index
        self.top_k = top_k
        self.num_experts = num_experts
        self.num_local_experts = local_experts
        self.out_dtype = out_dtype
        self.token_dispatcher = DeepEPTokenDispatcherNormal(
            group=ep_group,
            num_experts=num_experts,
            num_local_experts=local_experts,
            hidden_size=hidden_dim,
            params_dtype=out_dtype,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.LongTensor,
        up_weights: torch.Tensor,
        down_weights: torch.Tensor,
        expert_list: List[int] = None,
    ):
        """Dispatch tokens, run ``fused_moe_v3`` on the received tokens, then combine."""
        from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3
        dispatched = self.token_dispatcher.dispatch(hidden_states, topk_ids, topk_weights, expert_list)
        recv_states, recv_ids, recv_weights, tokens_per_expert = dispatched
        # Drop this frame's references to the pre-dispatch routing tensors.
        del topk_ids, topk_weights
        expert_out = fused_moe_v3(recv_states, recv_ids, recv_weights, up_weights, down_weights, tokens_per_expert)
        return self.token_dispatcher.combine(expert_out)

    def capture(self):
        """Return ``capture()`` of the dispatcher's normal-mode buffer."""
        normal_buffer = self.token_dispatcher.buffer_normal
        return normal_buffer.capture()

    def wait(self, event):
        """Release the dispatcher, then make the current stream wait on ``event``."""
        dispatcher = self.token_dispatcher
        dispatcher.release()
        event.current_stream_wait()

    def dispatch_async(self,
                       x: torch.Tensor,
                       topk_idx: torch.Tensor,
                       topk_weights: torch.Tensor,
                       num_experts: Optional[int] = None,
                       previous_event=None,
                       async_finish=True):
        """Start a normal-mode asynchronous dispatch (thin dispatcher passthrough)."""
        dispatcher = self.token_dispatcher
        return dispatcher.dispatch_normal_async(x, topk_idx, topk_weights, num_experts, previous_event,
                                                async_finish)

    def combine_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True):
        """Start a normal-mode asynchronous combine (thin dispatcher passthrough)."""
        dispatcher = self.token_dispatcher
        return dispatcher.combine_normal_async(x, handle, previous_event, async_finish)

    def release(self):
        """Release dispatcher resources and return the dispatcher's result."""
        dispatcher = self.token_dispatcher
        return dispatcher.release()

    def fusedmoe_forward(self, state, up_weight, down_weight):
        """Run the experts on the received tensors stored in an async dispatch ``state``."""
        from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3
        recv_hidden = state['recv_hidden_states']
        recv_idx = state['recv_topk_idx']
        recv_weights = state['recv_topk_weights']
        tokens_per_expert = state['recv_tokens_per_expert']
        return fused_moe_v3(recv_hidden, recv_idx, recv_weights, up_weight, down_weight, tokens_per_expert)


def _disposible_tensor(tensor):
    """Wrap a tensor, or each element of an iterable of tensors, in ``DisposibleTensor``.

    A single ``torch.Tensor`` yields one wrapper; anything else is iterated
    and a list of wrappers is returned.
    """
    from dlblas.utils.utils import DisposibleTensor
    if not isinstance(tensor, torch.Tensor):
        return [DisposibleTensor(item) for item in tensor]
    return DisposibleTensor(tensor)


def dispatch_ll(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    use_fp8: bool = True,
):
    """Dispatch low latency: send tokens through the low-latency DeepEP buffer and wait.

    Written as a free function that takes the token dispatcher as ``self``
    (``FusedMoELowLatency.forward`` passes ``self.token_dispatcher``). The
    dispatch handle is stored on ``self.handle``; presumably the dispatcher's
    ``combine`` consumes it later.

    Args:
        self: the low-latency token dispatcher.
        hidden_states: tokens to dispatch. ``hidden_states.shape[0]`` is used as
            the local token count in the ``expected_m`` estimate.
        topk_idx: routed expert ids; converted to int64 here.
            ``topk_idx.shape[1]`` is used as top-k.
        topk_weights: routing weights, returned unchanged.
        num_experts: total expert count. Must not be None, since it is used as
            the divisor for ``expected_m``.
        use_fp8: forwarded to ``low_latency_dispatch``.

    Returns:
        ``(packed_recv_hidden, topk_idx, topk_weights, masked_m, expected_m)``,
        with ``packed_recv_hidden`` wrapped by ``_disposible_tensor``.
    """
    if num_experts is not None and self.num_experts is not None:
        assert self.num_experts == num_experts
    topk_idx = topk_idx.to(torch.int64)
    # expected_m == floor(tokens * group_size * top_k / num_experts) + 1. This rough
    # per-expert row count is handed to the masked grouped GEMM in
    # FusedMoELowLatency.experts, which clamps it to the received per-group size.
    expected_m = (hidden_states.shape[0] * self.get_buffer().group_size * topk_idx.shape[1] +
                  num_experts) // num_experts

    (
        packed_recv_hidden,
        masked_m,
        self.handle,
        event,
        hook,
    ) = self.get_buffer().low_latency_dispatch(
        hidden_states,
        topk_idx,
        self.num_max_dispatch_tokens_per_rank,
        num_experts,
        use_fp8=use_fp8,
        async_finish=not self.return_recv_hook,
        return_recv_hook=self.return_recv_hook,
    )
    # Block until tokens are received: call the receive hook when one was
    # requested, otherwise make the current stream wait on the dispatch event.
    hook() if self.return_recv_hook else event.current_stream_wait()
    packed_recv_hidden = _disposible_tensor(packed_recv_hidden)
    return (
        packed_recv_hidden,
        topk_idx,
        topk_weights,
        masked_m,
        expected_m,
    )


def dispatch_async_ll(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    num_experts: Optional[int] = None,
    use_fp8: bool = True,
    async_finish: bool = True,
):
    """Start a low-latency DeepEP dispatch without waiting for completion.

    Like ``dispatch_ll``, this takes the token dispatcher as ``self``. Nothing
    here waits: the caller must use the returned ``event`` (when
    ``async_finish``) or ``hook`` (otherwise).

    Args:
        self: the low-latency token dispatcher.
        hidden_states: tokens to dispatch.
        topk_idx: routed expert ids; must already be int64.
        num_experts: only checked against ``self.num_experts``. The dispatch
            itself always uses ``self.num_experts``.
        use_fp8: forwarded to ``low_latency_dispatch``.
        async_finish: request an event (True) or a receive hook (False).

    Returns:
        ``(recv_hidden_states, recv_expert_count, handle, event, hook)``, with
        ``recv_hidden_states`` wrapped by ``_disposible_tensor``.
    """
    assert topk_idx.dtype == torch.int64
    if num_experts is not None and self.num_experts is not None:
        assert self.num_experts == num_experts
    # Exactly one completion mechanism is requested: an event when
    # async_finish is True, otherwise a receive hook.
    (
        recv_hidden_states,
        recv_expert_count,
        handle,
        event,
        hook,
    ) = self.get_buffer().low_latency_dispatch(
        hidden_states,
        topk_idx,
        self.num_max_dispatch_tokens_per_rank,
        num_experts=self.num_experts,
        use_fp8=use_fp8,
        async_finish=async_finish,
        return_recv_hook=not async_finish,
    )
    recv_hidden_states = _disposible_tensor(recv_hidden_states)
    return recv_hidden_states, recv_expert_count, handle, event, hook


class FusedMoELowLatency:
    """DeepEP low-latency MoE wrapper around dlblas' ``DeepEPTokenDispatcherLowLatency``.

    Tokens are dispatched with ``use_fp8=False``, and the experts run as two
    masked grouped GEMMs (``m_grouped_bf16_gemm_nt_masked``) around a fused
    SiLU-and-mul activation.
    """

    def __init__(
        self,
        ep_size: int,
        ep_group: dist.ProcessGroup,
        num_experts: int,
        hidden_dim: int,
        layer_index: int,
        out_dtype: torch.dtype = torch.bfloat16,
    ):
        """Create the low-latency token dispatcher.

        Args:
            ep_size: expert-parallel world size. Each rank owns
                ``num_experts // ep_size`` experts.
            ep_group: expert-parallel process group.
            num_experts: total number of experts.
            hidden_dim: hidden size of the dispatched tokens.
            layer_index: layer index (stored only).
            out_dtype: dtype passed to the dispatcher as ``params_dtype``.
        """
        from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherLowLatency
        self.num_experts = num_experts
        self.layer_index = layer_index
        self.out_dtype = out_dtype
        self.token_dispatcher = DeepEPTokenDispatcherLowLatency(
            group=ep_group,
            num_experts=num_experts,
            num_local_experts=num_experts // ep_size,
            hidden_size=hidden_dim,
            params_dtype=out_dtype,
        )

    def experts(
        self,
        hidden_states: torch.Tensor,
        gate_up_weight: torch.Tensor,
        gate_down_weight: torch.Tensor,
        masked_m: torch.Tensor,
        expected_m: int,
    ):
        """Run the local experts on received tokens laid out as ``(num_groups, m, hidden)``.

        Args:
            hidden_states: 3-D received tokens, possibly wrapped in ``DisposibleTensor``.
            gate_up_weight: gate/up weights. ``size(1)`` is the gate-up output width.
            gate_down_weight: down weights. ``size(1)`` is the output width.
            masked_m: per-group valid row counts, passed to the masked GEMMs
                and to the activation.
            expected_m: row-count hint for the GEMM, clamped to ``m`` here.

        Returns:
            A new ``(num_groups, m, gate_down_weight.size(1))`` tensor.
        """
        from dlblas.utils.utils import DisposibleTensor

        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_moe_ep
        from lmdeploy.pytorch.third_party.deep_gemm import m_grouped_bf16_gemm_nt_masked
        num_groups, m, _ = hidden_states.shape
        n = gate_up_weight.size(1)
        expected_m = min(expected_m, m)
        gateup_output = gate_up_weight.new_empty((num_groups, m, n))
        m_grouped_bf16_gemm_nt_masked(DisposibleTensor.maybe_unwrap(hidden_states), gate_up_weight, gateup_output,
                                      masked_m, expected_m)
        # The received tokens are no longer needed after the first GEMM; dispose
        # them early (presumably frees their storage).
        DisposibleTensor.maybe_dispose(hidden_states)
        down_input = silu_and_mul_moe_ep(gateup_output, masked_m)
        # Drop the gate-up result before allocating the down output to lower peak memory.
        del gateup_output
        n = gate_down_weight.size(1)
        down_output = down_input.new_empty((num_groups, m, n))
        m_grouped_bf16_gemm_nt_masked(down_input, gate_down_weight, down_output, masked_m, expected_m)
        return down_output

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                up_weights: torch.Tensor,
                down_weights: torch.Tensor,
                expert_list: List[int] = None):
        """Dispatch (non-FP8), run the local experts, and combine.

        ``expert_list`` is accepted for interface compatibility but not used.
        """
        recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = dispatch_ll(
            self.token_dispatcher,
            hidden_states,
            topk_ids,
            topk_weights,
            self.num_experts,
            use_fp8=False,
        )
        # Drop this frame's reference to the pre-dispatch tokens.
        hidden_states = None
        out_states = self.experts(recv_hidden_states, up_weights, down_weights, masked_m, expected_m)
        out_states = self.token_dispatcher.combine(out_states, topk_idx, topk_weights)
        return out_states

    def wait(self, event):
        """Make the current stream wait on ``event``."""
        event.current_stream_wait()

    def dispatch_async(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        num_experts: Optional[int] = None,
        use_fp8: bool = False,
        async_finish: bool = True,
    ):
        """Start a low-latency asynchronous dispatch (see ``dispatch_async_ll``)."""
        return dispatch_async_ll(self.token_dispatcher, hidden_states, topk_idx, num_experts, use_fp8, async_finish)

    def combine_async(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        handle: tuple,
        async_finish: bool,
    ):
        """Start an asynchronous combine on the low-latency dispatcher."""
        return self.token_dispatcher.combine_async(hidden_states, topk_idx, topk_weights, handle, async_finish)

    def fusedmoe_forward(self, state, up_weight, down_weight):
        """Run the experts on the tensors stored in an async dispatch ``state``.

        ``state`` must provide ``recv_hidden_states``, ``recv_expert_count``,
        ``raw_hidden_shape`` and ``topk_idx``.
        """
        recv_hidden_states = state['recv_hidden_states']
        masked_m = state['recv_expert_count']
        hidden_shape = state['raw_hidden_shape']
        topk_idx = state['topk_idx']
        # Same expected_m estimate as in dispatch_ll, using the low-latency buffer's group size.
        expected_m = (hidden_shape[0] * self.token_dispatcher.buffer_low_latency.group_size * topk_idx.shape[1] +
                      self.token_dispatcher.num_experts) // self.token_dispatcher.num_experts
        return self.experts(recv_hidden_states, up_weight, down_weight, masked_m, expected_m)


def build_deepep_moe(
    low_latency_mode: bool,
    ep_size: int,
    ep_group: dist.ProcessGroup,
    num_experts: int,
    hidden_dim: int,
    top_k: int,
    layer_idx: int = 0,
    out_dtype: torch.dtype = torch.bfloat16,
):
    """Create the DeepEP MoE wrapper for the requested dispatch mode.

    Returns ``FusedMoELowLatency`` when ``low_latency_mode`` is truthy and
    ``FusedMoENormal`` otherwise. Only the normal-mode wrapper takes ``top_k``.
    """
    shared_kwargs = dict(
        ep_size=ep_size,
        ep_group=ep_group,
        num_experts=num_experts,
        hidden_dim=hidden_dim,
        layer_index=layer_idx,
        out_dtype=out_dtype,
    )
    if not low_latency_mode:
        return FusedMoENormal(top_k=top_k, **shared_kwargs)
    return FusedMoELowLatency(**shared_kwargs)


class FusedMoEEPImpl(TritonFusedMoEImpl):
    """Expert-parallel fused MoE for float16/bfloat16 weights, built on DeepEP.

    Each forward step builds a DeepEP wrapper with ``build_deepep_moe``: the
    low-latency one while decoding, the normal one otherwise.
    """

    def __init__(
        self,
        ep_size: int,
        ep_group: dist.ProcessGroup,
        top_k: int,
        num_experts: int,
        hidden_dim: int,
        renormalize: bool = False,
        layer_idx: int = 0,
        out_dtype: torch.dtype = torch.bfloat16,
    ):
        """Store the EP configuration and prepare the DeepEP backend.

        Args:
            ep_size: expert-parallel world size.
            ep_group: expert-parallel process group.
            top_k: number of experts selected per token.
            num_experts: total number of experts.
            hidden_dim: hidden size of the tokens.
            renormalize: whether routing weights are renormalized.
            layer_idx: layer index passed to the DeepEP wrappers.
            out_dtype: dtype passed to the DeepEP wrappers.
        """
        super().__init__(top_k, num_experts, renormalize)
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.layer_idx = layer_idx
        self.out_dtype = out_dtype

        try:
            import deep_gemm  # noqa: F401
        except ImportError:
            # Only logged here, not raised.
            logger.exception('DeepGEMM is required for DeepEP MoE implementation.')

        try:
            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
            get_moe_backend().set_deepep_moe_backend()
            # Guarded because not every DeepEPBuffer provides this switch.
            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
                DeepEPBuffer.set_explicitly_destroy()
        except ImportError:
            logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')

        # pre-allocate buffer: build the low-latency wrapper once at construction time.
        self.fusedmoe_build(True)

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
        """Return the weights unchanged; the DeepEP path needs no re-layout."""
        return gate_up_weights, down_weights

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                down_weights: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward.

        Splits the tokens across attention-TP ranks, renormalizes the routing
        weights if configured, runs the DeepEP MoE (low-latency while decoding),
        and gathers the outputs back.

        Raises:
            AssertionError: if ``act_func`` or either bias is given. The DeepEP
                path supports neither a custom activation nor expert biases.
        """
        assert act_func is None, 'Activation function is not supported in DeepEP MoE.'
        # The DeepEP wrappers take no bias inputs. Fail loudly instead of silently
        # dropping the biases and returning wrong results.
        assert gate_up_bias is None and down_bias is None, 'Bias is not supported in DeepEP MoE.'
        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,
                                                                                    topk_ids)

        topk_weights = self.do_renormalize(topk_weights)
        step_ctx = get_step_ctx_manager().current_context()
        low_latency_mode = step_ctx.is_decoding
        moe = self.fusedmoe_build(low_latency_mode)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, down_weights, expert_list)

        out_states = gather_outputs_by_attn_tp(out_states, split_size)
        return out_states

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank.

        Raises:
            NotImplementedError: when EPLB is enabled in the dist config.
        """
        if get_dist_manager().current_context().dist_config.enable_eplb:
            raise NotImplementedError('float16/bfloat16 enable_eplb is not Implemented.')
        else:
            return super().ep_expert_list(world_size=world_size, rank=rank)

    def do_renormalize(self, topk_weights):
        """Apply ``_renormalize`` with this layer's ``renormalize`` flag."""
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        """Build the DeepEP wrapper (low-latency or normal) for this layer."""
        deepep_moe = build_deepep_moe(low_latency_mode,
                                      self.ep_size,
                                      self.ep_group,
                                      self.num_experts,
                                      self.hidden_dim,
                                      self.top_k,
                                      layer_idx=self.layer_idx,
                                      out_dtype=self.out_dtype)
        return deepep_moe


class TritonFusedMoEBuilder(FusedMoEBuilder):
    """Builder selecting the CUDA fused MoE implementation for unquantized weights."""

    @staticmethod
    def build(
        top_k: int,
        num_experts: int,
        renormalize: bool = False,
        hidden_dim: int = 1,
        ep_size: int = 1,
        ep_group: dist.ProcessGroup = None,
        layer_idx: int = 0,
        out_dtype: torch.dtype = torch.bfloat16,
    ):
        """Build from mlp.

        Returns ``FusedMoEEPImpl`` when ``ep_size > 1``. Otherwise returns a
        ``TritonFusedMoEImpl``, which ignores ``hidden_dim``, ``ep_group``,
        ``layer_idx`` and ``out_dtype``.
        """
        if ep_size <= 1:
            return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize)

        return FusedMoEEPImpl(
            ep_size=ep_size,
            ep_group=ep_group,
            top_k=top_k,
            num_experts=num_experts,
            hidden_dim=hidden_dim,
            renormalize=renormalize,
            layer_idx=layer_idx,
            out_dtype=out_dtype,
        )


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe/ep_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import distributed as dist

from lmdeploy.pytorch.distributed import get_dist_manager


def split_inputs_by_attn_tp(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
):
    """Split input by attn tp.

    Returns this attention-TP rank's slice (along dim 0) of each input plus
    the list of per-rank split sizes. The inputs come back unsplit, with
    ``None`` sizes, when ``attn_tp == 1`` or when there are fewer rows than
    attention-TP ranks.
    """
    dist_ctx = get_dist_manager().current_context()
    attn_tp = dist_ctx.dist_config.attn_tp
    attn_rank = dist_ctx.attn_tp_group.rank
    num_states = hidden_states.size(0)

    if attn_tp == 1 or attn_tp > num_states:
        return hidden_states, topk_weights, topk_ids, None

    # The first ``extra`` ranks take one additional row so the sizes sum to num_states.
    per_rank, extra = divmod(num_states, attn_tp)
    split_size = [per_rank + (1 if rank < extra else 0) for rank in range(attn_tp)]

    local_states, local_weights, local_ids = (torch.split(tensor, split_size, dim=0)[attn_rank]
                                              for tensor in (hidden_states, topk_weights, topk_ids))
    return local_states, local_weights, local_ids, split_size


def gather_outputs_by_attn_tp(out_states: torch.Tensor, split_size: List[int]):
    """Gather output by attn tp.

    Counterpart of ``split_inputs_by_attn_tp``: all-gathers every rank's rows
    into one ``(sum(split_size), out_states.shape[1])`` tensor. When
    ``split_size`` is None (inputs were not split), ``out_states`` is returned
    as-is.
    """
    if split_size is not None:
        gpu_group = get_dist_manager().current_context().attn_tp_group.gpu_group
        gathered = out_states.new_empty((sum(split_size), out_states.shape[1]))
        # The split pieces are views of ``gathered``, so all_gather fills it in place.
        receive_views = list(gathered.split(split_size, dim=0))
        dist.all_gather(receive_views, out_states, group=gpu_group)
        return gathered
    return out_states


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe/w8a8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import torch

from lmdeploy.pytorch.backends.moe import FusedMoEW8A8Builder, FusedMoEW8A8Impl
from lmdeploy.pytorch.kernels.cuda import fused_moe_w8a8
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8
from lmdeploy.pytorch.models.q_modules import QTensor
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):
    """Fused MoE with 8-bit weights and activations, backed by ``fused_moe_w8a8``."""

    def __init__(
        self,
        top_k: int,
        num_experts: int,
        renormalize: bool = False,
        out_dtype: torch.dtype = torch.float16,
        quant_dtype: torch.dtype = torch.int8,
    ):
        self.top_k = top_k
        self.num_experts = num_experts
        self.renormalize = renormalize
        self.out_dtype = out_dtype
        self.quant_dtype = quant_dtype

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
                       down_scale: torch.Tensor):
        """Return weights and scales unchanged (int8/fp8 weights are not transposed)."""
        return gate_up_weights, down_weights, gate_up_scale, down_scale

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                expert_list: List[int] = None):
        """forward.

        ``hidden_states`` is either a plain tensor, quantized per token here
        with ``per_token_quant_int8``, or an already quantized ``QTensor``.
        When ``expert_list`` covers only part of the experts, its first id is
        used as ``expert_offset`` and the total ``num_experts`` is passed along.
        """
        if not isinstance(hidden_states, torch.Tensor):
            assert isinstance(hidden_states, QTensor)
            input_quant = hidden_states.tensor
            input_scale = hidden_states.scale
        else:
            input_quant, input_scale = per_token_quant_int8(hidden_states.contiguous(),
                                                            1e-7,
                                                            quant_dtype=self.quant_dtype)

        holds_subset = expert_list is not None and len(expert_list) != self.num_experts
        expert_offset = expert_list[0] if holds_subset else 0
        num_experts = self.num_experts if holds_subset else None
        return fused_moe_w8a8(input_quant,
                              input_scale,
                              gate_up_weights,
                              gate_up_scale,
                              down_weights,
                              down_scale,
                              topk_weights=topk_weights,
                              topk_ids=topk_ids,
                              topk=self.top_k,
                              out_dtype=self.out_dtype,
                              quant_dtype=self.quant_dtype,
                              expert_offset=expert_offset,
                              num_experts=num_experts,
                              renormalize=self.renormalize)


class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):
    """Builder producing :class:`TritonFusedMoEW8A8Impl`."""

    @staticmethod
    def build(
        top_k: int,
        num_experts: int,
        renormalize: bool = False,
        out_dtype: torch.dtype = torch.float16,
        quant_dtype: torch.dtype = torch.int8,
    ):
        """Build from mlp; every argument is forwarded unchanged."""
        impl_kwargs = dict(top_k=top_k,
                           num_experts=num_experts,
                           renormalize=renormalize,
                           out_dtype=out_dtype,
                           quant_dtype=quant_dtype)
        return TritonFusedMoEW8A8Impl(**impl_kwargs)


================================================
FILE: lmdeploy/pytorch/backends/cuda/moe_router.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.kernels.cuda.fused_noaux_tc import fused_noaux_tc_routing

from ..default.moe_router import DefaultRouterNoauxTCImpl
from ..moe_router import RouterNoauxTCBuilder, RouterNoauxTCImpl


def is_power_of_two(n):
    """Return True if ``n`` is a positive power of two (1, 2, 4, 8, ...)."""
    if n <= 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit leaves 0.
    return not n & (n - 1)


class TritonRouterNoauxTCImpl(DefaultRouterNoauxTCImpl):
    """No-aux top-k router that uses the fused Triton routing kernel when possible.

    Configurations the fused kernel cannot serve fall back to
    ``DefaultRouterNoauxTCImpl.forward``.
    """

    def __init__(
        self,
        scoring_func: str,
        top_k: int,
        n_group: int,
        topk_group: int,
        n_routed_experts: int,
        routed_scaling_factor: float,
        renormalize: bool = True,
        router_n_groups: int = -1,
    ):
        super().__init__(scoring_func=scoring_func,
                         top_k=top_k,
                         n_group=n_group,
                         topk_group=topk_group,
                         n_routed_experts=n_routed_experts,
                         routed_scaling_factor=routed_scaling_factor,
                         renormalize=renormalize,
                         router_n_groups=router_n_groups)
        # Decided once, from the configuration attributes read below.
        self.enable_custom_kernel = self.should_enable_custom_kernel()

    def should_enable_custom_kernel(self) -> bool:
        """Return whether ``fused_noaux_tc_routing`` supports this configuration.

        All must hold: ``router_n_groups`` is not positive, ``scoring_func`` is
        ``'sigmoid'``, ``n_routed_experts`` is a multiple of 32 and a power of
        two, and ``n_group`` is a power of two.
        """
        supported = (not self.router_n_groups > 0 and self.scoring_func == 'sigmoid'
                     and self.n_routed_experts % 32 == 0 and is_power_of_two(self.n_routed_experts)
                     and is_power_of_two(self.n_group))
        return bool(supported)

    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Router forward: fused Triton kernel when enabled, default implementation otherwise."""
        if not self.enable_custom_kernel:
            return super().forward(logits, bias)
        return fused_noaux_tc_routing(
            logits,
            bias,
            num_experts=self.n_routed_experts,
            n_group=self.n_group,
            topk_group=self.topk_group,
            top_k=self.top_k,
            renormalize=self.renormalize,
            routed_scaling_factor=self.routed_scaling_factor,
        )


class TritonRouterNoauxTCBuilder(RouterNoauxTCBuilder):
    """Builder producing :class:`TritonRouterNoauxTCImpl`."""

    @staticmethod
    def build(
        scoring_func: str,
        top_k: int,
        n_group: int,
        topk_group: int,
        n_routed_experts: int,
        routed_scaling_factor: float,
        renormalize: bool = True,
        router_n_groups: int = -1,
    ) -> RouterNoauxTCImpl:
        """Create the Triton no-aux router; all arguments are forwarded unchanged."""
        router_kwargs = dict(scoring_func=scoring_func,
                             top_k=top_k,
                             n_group=n_group,
                             topk_group=topk_group,
                             n_routed_experts=n_routed_experts,
                             routed_scaling_factor=routed_scaling_factor,
                             renormalize=renormalize,
                             router_n_groups=router_n_groups)
        return TritonRouterNoauxTCImpl(**router_kwargs)


================================================
FILE: lmdeploy/pytorch/backends/cuda/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from lmdeploy.pytorch.kernels.cuda import multinomial_sampling

from ..multinomial_sampling import MultinomialSamplingBuilder, MultinomialSamplingImpl


class TritonMultinomialSamplingImpl(MultinomialSamplingImpl):
    """Multinomial sampling delegated to the Triton ``multinomial_sampling`` kernel."""

    def forward(self,
                scores: torch.Tensor,
                seeds: torch.LongTensor,
                offsets: torch.LongTensor,
                indices: torch.Tensor = None):
        """Sample from ``scores`` using the given per-call ``seeds``/``offsets``."""
        sampled = multinomial_sampling(scores, seeds, offsets, indices)
        return sampled


class TritonMultinomialSamplingBuilder(MultinomialSamplingBuilder):
    """Triton multinomial sampling builder."""

    @staticmethod
    def build():
        """Create a :class:`TritonMultinomialSamplingImpl`.

        Declared as a ``staticmethod`` like the other backend builders, so it
        works whether called on the class or on an instance.
        """
        return TritonMultinomialSamplingImpl()


================================================
FILE: lmdeploy/pytorch/backends/cuda/norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.pytorch.kernels.cuda import rms_norm

from ..norm import RMSNormBuilder, RMSNormImpl


class TritonRMSNormImpl(RMSNormImpl):
    """RMS norm computed by the Triton ``rms_norm`` kernel."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """forward.

        Without ``residual``, returns the normalized tensor. With ``residual``,
        returns the kernel's ``(output, residual)`` pair.
        """
        if residual is not None:
            normed, new_residual = rms_norm(x, weight, self.eps, residual=residual)
            return normed, new_residual
        return rms_norm(x, weight, self.eps)


class TritonRMSNormBuilder(RMSNormBuilder):
    """Builder producing :class:`TritonRMSNormImpl`."""

    @staticmethod
    def build(weight: torch.Tensor, eps: float = 1e-6):
        """build.

        NOTE(review): ``weight`` is forwarded as ``TritonRMSNormImpl``'s
        ``hidden_size`` argument (annotated ``int`` there). The name and
        ``torch.Tensor`` annotation here look stale; confirm against callers.
        """
        return TritonRMSNormImpl(hidden_size=weight, eps=eps)


================================================
FILE: lmdeploy/pytorch/backends/cuda/nsa.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index
from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8

from ..nsa import BaseNSAIndexFP8, BaseNSAIndexFP8Builder, NSAIndexMeta


class TritonNSAIndexFP8(BaseNSAIndexFP8):
    """FP8 NSA top-k indexer composed of Triton kernels.

    ``forward`` quantizes the queries to FP8, writes the new keys into the
    blocked FP8 key cache, scores queries against the cached keys with
    ``fp8_index`` (causal), and returns ``bitonic_topk`` of those scores.
    """

    def __init__(self, topk: int, softmax_scale: float, block_size: int, fill: int) -> None:
        super().__init__()
        self.topk = topk
        self.softmax_scale = softmax_scale
        self.block_size = block_size
        # Passed to bitonic_topk as ``fill`` (presumably the padding value for unused slots).
        self.fill = fill
        # TODO: configurable scale fmt
        self.scale_fmt = 'ue8m0'

    def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor,
                meta: NSAIndexMeta) -> Tensor:
        """Score queries against the cached FP8 keys and select the top-k per query.

        Args:
            q: 3-D query tensor.
            k: 2-D new keys for this step, written into ``k_cache``/``k_s_cache``.
            weights: presumably per-head index weights. The query scales are
                reshaped to ``weights.shape`` and multiplied by them.
            k_cache: blocked FP8 key cache. Its dtype is used to quantize ``q``.
            k_s_cache: scale cache paired with ``k_cache``.
            meta: sequence lengths, cumulative query lengths and block offsets.

        Returns:
            The output of ``bitonic_topk`` (descending, ``self.topk`` entries).
        """

        assert q.dim() == 3
        assert k.dim() == 2
        cu_seqlen_q = meta.cu_seqlen_q
        q_seqlens = meta.q_seqlens
        k_seqlens = meta.k_seqlens
        block_offset = meta.block_offset
        max_q_seqlen = meta.max_q_seqlen
        max_kv_seqlen = meta.max_kv_seqlen

        # Quantize q in groups of block_size along its last dim (flattened to 2-D
        # for the kernel), then fold softmax_scale and weights into the scales.
        q_shape = q.shape
        q = q.reshape(-1, q_shape[-1])
        q, q_s = quant_fp8(q, self.block_size, dtype=k_cache.dtype, trans_scale=True, scale_fmt=self.scale_fmt)
        q = q.reshape(*q_shape)
        q_s = q_s.reshape(weights.shape)
        q_s = q_s * self.softmax_scale * weights

        # Write the new keys into the FP8 cache. ``k[:, None]`` and the
        # ``[..., None, :]`` views insert a singleton axis (presumably the head axis
        # the kernel expects); the value-cache slots are passed as None.
        fill_kv_cache_blocked_fp8(k[:, None],
                                  None,
                                  k_cache[..., None, :],
                                  None,
                                  k_s_cache[..., None, :],
                                  None,
                                  cu_seqlen_q=cu_seqlen_q,
                                  kv_seqlens=k_seqlens,
                                  max_q_seqlen=max_q_seqlen,
                                  block_offsets=block_offset,
                                  group_size=self.block_size,
                                  scale_fmt=self.scale_fmt)

        # Causal scores of each query against the cached keys.
        scores = fp8_index(q,
                           q_s,
                           k_cache,
                           k_s_cache[..., 0],
                           cu_seqlen_q,
                           k_seqlens,
                           block_offset,
                           max_q_seqlen=max_q_seqlen,
                           max_k_seqlen=max_kv_seqlen,
                           causal=True)
        return bitonic_topk(scores, q_seqlens, k_seqlens, self.topk, fill=self.fill, descending=True)


class TritonNSAIndexFP8Builder(BaseNSAIndexFP8Builder):
    """Builder producing :class:`TritonNSAIndexFP8` instances."""

    @staticmethod
    def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8:
        """Create a Triton FP8 NSA indexer.

        Args:
            topk: number of key positions kept per query.
            softmax_scale: softmax scale handed to the indexer.
            block_size: quantization group size.
            fill: fill value handed to the indexer.
        """
        indexer = TritonNSAIndexFP8(topk, softmax_scale=softmax_scale, block_size=block_size, fill=fill)
        return indexer


================================================
FILE: lmdeploy/pytorch/backends/cuda/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..base import OpType
from ..default import DefaultOpsBackend

logger = get_logger('lmdeploy')


class CudaOpsBackend(DefaultOpsBackend):
    """Op backend for CUDA devices.

    Maps :class:`OpType` values to Triton/CUDA implementation builders, builds
    per-step attention metadata and creates the CUDA graph runner.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend name, ``'cuda'``."""
        return 'cuda'

    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """Return the CUDA implementation builder for ``layer_type``.

        Builders are imported lazily, so only modules for requested ops are
        loaded. Unknown op types fall back to the default backend.
        """
        if layer_type == OpType.PagedAttention:
            from .attention import TritonAttentionBuilder
            return TritonAttentionBuilder
        if layer_type == OpType.FlashAttention:
            from .flash_attention import TritonFlashAttentionBuilder
            return TritonFlashAttentionBuilder
        if layer_type == OpType.ApplyRotaryEmb:
            from .apply_rotary_emb import TritonApplyRotaryEmbBuilder
            return TritonApplyRotaryEmbBuilder
        if layer_type == OpType.RMSNorm:
            from .norm import TritonRMSNormBuilder
            return TritonRMSNormBuilder
        if layer_type == OpType.LoRA:
            from .lora import TritonLoRABuilder
            return TritonLoRABuilder
        if layer_type == OpType.LinearW8A8:
            from .qmodules import TritonLinearW8A8Builder
            return TritonLinearW8A8Builder
        if layer_type == OpType.RMSNormW8A8:
            # The quantized RMS-norm builder lives in qmodules (same class name as .norm's).
            from .qmodules import TritonRMSNormBuilder
            return TritonRMSNormBuilder
        if layer_type == OpType.MultinomialSampling:
            from .multinomial_sampling import TritonMultinomialSamplingBuilder
            return TritonMultinomialSamplingBuilder
        if layer_type == OpType.SiluAndMul:
            from .activation import TritonSiluAndMulBuilder
            return TritonSiluAndMulBuilder
        if layer_type == OpType.LinearW4A16:
            from .awq_modules import AwqLinearW4A16Builder
            return AwqLinearW4A16Builder
        if layer_type == OpType.FusedMoE:
            from .moe import TritonFusedMoEBuilder
            return TritonFusedMoEBuilder
        if layer_type == OpType.FusedMoEW8A8:
            from .moe import TritonFusedMoEW8A8Builder
            return TritonFusedMoEW8A8Builder
        if layer_type == OpType.FusedMoEBlockedF8:
            from .moe import TritonFusedMoEBlockedF8Builder
            return TritonFusedMoEBlockedF8Builder
        if layer_type == OpType.LinearBlockedF8:
            from .blockedf8_modules import TritonLinearBlockedF8Builder
            return TritonLinearBlockedF8Builder
        if layer_type == OpType.NSAIndexFP8:
            from .nsa import TritonNSAIndexFP8Builder
            return TritonNSAIndexFP8Builder
        if layer_type == OpType.RouterNoauxTC:
            from .moe_router import TritonRouterNoauxTCBuilder
            return TritonRouterNoauxTCBuilder
        if layer_type == OpType.CausalConv1d:
            from .causal_conv1d import CausalConv1dCudaBuilder
            return CausalConv1dCudaBuilder
        if layer_type == OpType.GatedDeltaRule:
            from .gated_delta_rule import CudaGatedDeltaRuleBuilder
            return CudaGatedDeltaRuleBuilder

        logger.debug(f'Op {layer_type} fallback to default implementation.')
        return super().get_layer_impl_builder(layer_type)

    @staticmethod
    def get_attention_metadata_cls():
        """Return the attention metadata class used by this backend."""
        from .attention import TritonAttentionMetadata
        return TritonAttentionMetadata

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Return the K-cache block shape ``(block_size, num_heads, head_size)``.

        ``dtype`` is part of the interface but does not affect the shape here.
        """
        shape = (block_size, num_heads, head_size)
        return shape

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Return the V-cache block shape ``(block_size, num_heads, head_size)``.

        ``dtype`` is part of the interface but does not affect the shape here.
        """
        shape = (block_size, num_heads, head_size)
        return shape

    @classmethod
    def update_meta_flashmla(cls, attn_metadata, model_config: ModelConfig, decoding_query_len: int):
        """Attach FlashMLA scheduling metadata to ``attn_metadata`` in place.

        Sets ``tile_scheduler_metadata`` and ``num_splits`` from
        ``flash_mla.get_mla_metadata`` and casts ``block_offsets`` to int32.
        """
        import flash_mla

        heads = model_config.num_attention_heads * decoding_query_len
        fp8_cache = model_config.use_mla_fp8_cache
        topk = model_config.mla_index_topk
        # num_heads_q is only forwarded when top-k (sparse) indexing is configured.
        heads_q = heads if topk is not None else None
        sched_meta, splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),
                                                        heads,
                                                        num_heads_k=1,
                                                        num_heads_q=heads_q,
                                                        is_fp8_kvcache=fp8_cache,
                                                        topk=topk)
        attn_metadata.tile_scheduler_metadata = sched_meta
        attn_metadata.num_splits = splits

        # Ensure int32 block offsets (presumably what the FlashMLA kernels expect).
        if attn_metadata.block_offsets.dtype != torch.int32:
            attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)

    @classmethod
    def update_meta_flashattn(cls, attn_metadata, step_context):
        """Attach FlashAttention scheduler metadata and ``max_kv_seqlen``.

        Returns the same ``attn_metadata`` object after updating it.
        """
        from lmdeploy.pytorch.models.utils.cudagraph import _get_meta_flashattn
        model_config = step_context.model_config
        num_seqs = attn_metadata.q_seqlens.size(0)
        # Assumes every sequence contributes the same number of query tokens.
        query_len = step_context.input_ids.size(1) // num_seqs
        page_size = step_context.kv_caches[0][0].size(1)
        window = (model_config.sliding_window, model_config.sliding_window)
        attn_metadata.scheduler_metadata = _get_meta_flashattn(
            batch_size=num_seqs,
            max_seqlen_q=query_len,
            max_seqlen_k=step_context.max_kv_seqlen,
            num_heads_q=model_config.num_attention_heads,
            num_heads_kv=model_config.num_key_value_heads,
            headdim=model_config.head_dim,
            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
            qkv_dtype=model_config.dtype,
            page_size=page_size,
            window_size=window,
        )
        attn_metadata.max_kv_seqlen = step_context.max_kv_seqlen
        return attn_metadata

    @classmethod
    def update_step_context(cls, step_context):
        """Build attention metadata for ``step_context`` and store it there.

        Decoding steps additionally receive FlashMLA metadata when
        ``use_flash_mla`` is set, or FlashAttention metadata for the
        ``'ar_spec'`` model paradigm.
        """
        meta_cls = cls.get_attention_metadata_cls()
        q_seqlens = step_context.q_seqlens
        kv_seqlens = step_context.kv_seqlens
        model_config = step_context.model_config
        use_flash_mla = model_config.use_flash_mla
        use_fa3_decoding = model_config.model_paradigm == 'ar_spec'

        # A single cumsum over the stacked [q; kv] lengths replaces separate
        # pad + cumsum kernels per tensor.
        stacked = torch.stack([q_seqlens, kv_seqlens], dim=0)
        cu_seqlens = torch.nn.functional.pad(torch.cumsum(stacked, dim=1, dtype=torch.int32), (1, 0))
        cu_seqlens_q, cu_seqlens_k = cu_seqlens[0], cu_seqlens[1]

        if step_context.is_decoding:
            kv_start_loc = None
            kv_flatten_size = None
        else:
            kv_start_loc = cu_seqlens_k[:-1].to(kv_seqlens.dtype)
            kv_flatten_size = step_context.sum_kv_seqlen

        attn_metadata = meta_cls(
            step_context.is_decoding,
            step_context.block_offsets,
            q_start_loc=step_context.q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            kv_flatten_size=kv_flatten_size,
            quant_policy=step_context.kv_quant_policy,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_kv_seqlen=step_context.max_kv_seqlen,
        )

        if step_context.is_decoding:
            if use_flash_mla:
                query_len = step_context.input_ids.size(1) // q_seqlens.size(0)
                cls.update_meta_flashmla(attn_metadata, model_config, query_len)
            elif use_fa3_decoding:
                attn_metadata = cls.update_meta_flashattn(attn_metadata, step_context)

        step_context.attn_metadata = attn_metadata
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                           backend_config: BackendConfig, device: torch.device):
        """Warm up ops, then return a :class:`CUDAGraphRunner` wrapping ``model``."""
        from .graph_runner import CUDAGraphRunner
        from .warmup_manager import WarmupMeta, get_warmup_manager

        # Warm up with the cache config's prefill token budget and max batch count.
        meta = WarmupMeta(
            max_num_tokens=cache_config.max_prefill_token_num,
            max_batch_size=cache_config.max_batches,
            dtype=model_config.dtype,
        )
        manager = get_warmup_manager()
        manager.warmup(meta)

        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)

    @staticmethod
    def device_count():
        """Return the number of visible CUDA devices."""
        return torch.cuda.device_count()

    @staticmethod
    def support_ray():
        """Return ``True``: this backend can run under ray."""
        return True


================================================
FILE: lmdeploy/pytorch/backends/cuda/qmodules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_token_quant_int8,
                                                               rms_norm_dynamic_quant)
from lmdeploy.pytorch.models.q_modules import QTensor

from ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl


class TritonRMSNormW8A8Impl(RMSNormW8A8Impl):
    """RMS norm fused with dynamic activation quantization (Triton kernel)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.quant_dtype = quant_dtype

    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """Normalize ``x`` with ``weight`` and quantize the result.

        Returns:
            A ``QTensor`` when ``residual`` is None; otherwise a tuple
            ``(QTensor, residual)`` with the residual returned by the kernel.
        """
        if residual is not None:
            quant_x, scale, new_residual = rms_norm_dynamic_quant(x,
                                                                  weight,
                                                                  self.eps,
                                                                  residual=residual,
                                                                  quant_dtype=self.quant_dtype)
            return QTensor(quant_x, scale), new_residual

        quant_x, scale = rms_norm_dynamic_quant(x, weight, self.eps, quant_dtype=self.quant_dtype)
        return QTensor(quant_x, scale)


class TritonRMSNormBuilder(RMSNormW8A8Builder):
    """Builder for :class:`TritonRMSNormW8A8Impl`."""

    @staticmethod
    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
        """Create a Triton RMS-norm + dynamic-quant implementation."""
        return TritonRMSNormW8A8Impl(hidden_size=hidden_size, eps=eps, quant_dtype=quant_dtype)


class TritonLinearW8A8Impl(LinearW8A8Impl):
    """Triton W8A8 linear implementation.

    Activations are either quantized per token on the fly or arrive already
    quantized as a ``QTensor``. The quantized matmul produces ``out_dtype``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 out_dtype: torch.dtype = torch.float16,
                 quant_dtype: torch.dtype = torch.int8):
        # Initialize the base class, as the sibling RMS-norm impl does, so any
        # base-class state is set up consistently.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.quant_dtype = quant_dtype

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Run the quantized linear layer.

        Args:
            x: a plain tensor (quantized here per token) or a ``QTensor``
                carrying pre-quantized data and its scale.
            weight: quantized weight.
            scale: weight scale.
            bias: optional bias added by the matmul kernel.
            all_reduce: all-reduce the output over ``group`` when True.
            group: process group for the all-reduce.

        Returns:
            Output tensor of dtype ``self.out_dtype``.
        """
        if isinstance(x, torch.Tensor):
            input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype)
        else:
            # Already quantized, e.g. by a fused RMS-norm + quant op.
            assert isinstance(x, QTensor)
            input_quant, input_scale = x.tensor, x.scale

        out = matmul_kernel_dynamic_quant(input_quant,
                                          weight,
                                          input_scale,
                                          scale,
                                          output_dtype=self.out_dtype,
                                          bias=bias)

        if all_reduce:
            dist.all_reduce(out, group=group)
        return out


class TritonLinearW8A8Builder(LinearW8A8Builder):
    """Builder for :class:`TritonLinearW8A8Impl`."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: torch.dtype = None,
              quant_dtype: torch.dtype = torch.int8):
        """Create a Triton W8A8 linear implementation.

        ``dtype`` becomes the implementation's output dtype; ``bias`` is not
        forwarded to the constructor.
        """
        return TritonLinearW8A8Impl(in_features=in_features,
                                    out_features=out_features,
                                    out_dtype=dtype,
                                    quant_dtype=quant_dtype)


================================================
FILE: lmdeploy/pytorch/backends/cuda/token_dispatcher.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# DeepEP is optional. ``use_deepep`` records whether it (and the lmdeploy env
# setting below) could be imported; TokenDispatcherBuilder uses it to choose
# between the DeepEP and the all-to-all dispatcher.
try:
    from deep_ep import Buffer

    from lmdeploy.pytorch.envs import deep_ep_buffer_num_sms

    # Apply the configured SM count to DeepEP buffers at import time.
    Buffer.set_num_sms(deep_ep_buffer_num_sms)
    use_deepep = True
except ImportError:
    use_deepep = False

from typing import List, Optional, Tuple

import torch
import torch.distributed as dist

from ..default.token_dispatcher import AlltoAllTokenDispatcher
from ..token_dispatcher import TokenDispatcherImpl

# Module-level DeepEP buffer caches. The get_buffer_* helpers reuse them and
# re-create them when the process group changes or a larger size is required.
_buffer_normal = None
_buffer_low_latency = None
_buffer_common = None


def get_buffer_common(
    group: dist.ProcessGroup,
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_experts: int,
    hidden_bytes: int,
):
    """Return a shared DeepEP buffer sized for both normal and low-latency kernels.

    NVL bytes are the larger of the dispatch/combine hints. RDMA bytes are the
    larger of those hints and the low-latency RDMA hint. The cached buffer is
    reused unless the group changed or it is too small.

    Args:
        group: expert-parallel process group.
        num_max_dispatch_tokens_per_rank: low-latency dispatch token budget per rank.
        hidden: hidden size.
        num_experts: total number of experts.
        hidden_bytes: bytes per token of hidden states.

    Returns:
        The cached ``Buffer`` (created with ``low_latency_mode=True``).
    """
    global _buffer_common
    ep_size = group.size()

    nvl_bytes = 0
    rdma_bytes = 0
    for cfg in (Buffer.get_dispatch_config(ep_size), Buffer.get_combine_config(ep_size)):
        nvl_bytes = max(cfg.get_nvl_buffer_size_hint(hidden_bytes, ep_size), nvl_bytes)
        rdma_bytes = max(cfg.get_rdma_buffer_size_hint(hidden_bytes, ep_size), rdma_bytes)

    low_latency_rdma = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, ep_size,
                                                             num_experts)
    rdma_bytes = max(low_latency_rdma, rdma_bytes)

    needs_new = (_buffer_common is None or _buffer_common.group != group or _buffer_common.num_nvl_bytes < nvl_bytes
                 or _buffer_common.num_rdma_bytes < rdma_bytes)
    if needs_new:
        _buffer_common = Buffer(
            group,
            num_nvl_bytes=nvl_bytes,
            num_rdma_bytes=rdma_bytes,
            low_latency_mode=True,
            num_qps_per_rank=max(num_experts // ep_size, Buffer.num_sms // 2),
        )
    return _buffer_common


def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
    """Copy from DeepEP example usage in model inference prefilling.

    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling

    Returns a cached normal-mode ``Buffer`` whose NVL/RDMA sizes are the larger
    of the dispatch and combine size hints. A new buffer is created when the
    group changes or the cached one is too small.
    """
    global _buffer_normal
    ep_size = group.size()
    configs = (Buffer.get_dispatch_config(ep_size), Buffer.get_combine_config(ep_size))
    hints = [(cfg.get_nvl_buffer_size_hint(hidden_bytes, ep_size), cfg.get_rdma_buffer_size_hint(hidden_bytes, ep_size))
             for cfg in configs]
    nvl_bytes = max([0] + [nvl for nvl, _ in hints])
    rdma_bytes = max([0] + [rdma for _, rdma in hints])

    stale = (_buffer_normal is None or _buffer_normal.group != group or _buffer_normal.num_nvl_bytes < nvl_bytes
             or _buffer_normal.num_rdma_bytes < rdma_bytes)
    if stale:
        _buffer_normal = Buffer(group, nvl_bytes, rdma_bytes)
    return _buffer_normal


def get_buffer_low_latency(
    group: dist.ProcessGroup,
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_experts: int,
):
    """Copy from DeepEP example usage in model inference decoding.

    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding

    Returns a cached low-latency ``Buffer``. It is re-created when missing, bound
    to another group, not in low-latency mode, or smaller than the RDMA size hint.
    """
    global _buffer_low_latency
    ep_size = group.size()
    rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, ep_size, num_experts)

    cached = _buffer_low_latency
    needs_new = (cached is None or cached.group != group or not cached.low_latency_mode
                 or cached.num_rdma_bytes < rdma_bytes)
    if not needs_new:
        return cached

    assert num_experts % ep_size == 0, f'num_experts: {num_experts} must be divisible by ep_size: {group.size()}'
    _buffer_low_latency = Buffer(
        group,
        num_rdma_bytes=rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=max(num_experts // ep_size, Buffer.num_sms // 2),
    )
    return _buffer_low_latency


class DeepEPTokenDispatcher(TokenDispatcherImpl):
    """Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py.

    Uses DeepEP normal-mode dispatch/combine on the shared common buffer, and
    permutes received tokens by expert between dispatch and combine.
    """

    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
        num_max_dispatch_tokens_per_rank=128,
    ):
        # Fail fast with a clear message before touching any other argument.
        if not use_deepep:
            raise ImportError('DeepEP is not installed. Please install DeepEP package from '
                              'https://github.com/deepseek-ai/deepep.')
        self.group = group
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_bytes = params_dtype.itemsize
        # Handle used for combine operation
        self.handle = None
        # Per-dispatch state. It is initialized here so that calls made before the
        # first dispatch (or after release()) see None instead of a missing attribute.
        self.hidden_shape = None
        self.tokens_per_expert = None
        self.topk_idx = None
        self.topk_weights = None
        self.hidden_shape_before_permute = None
        self.dispatched_routing_map = None
        self.reversed_mapping_for_combine = None
        # The shared buffer is sized for normal and low-latency kernels.
        self.buffer_normal = get_buffer_common(self.group,
                                               num_max_dispatch_tokens_per_rank,
                                               self.hidden_size,
                                               self.num_experts,
                                               hidden_bytes=self.hidden_size * self.params_bytes)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_list: List[int] = None,
        previous_event=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Send tokens to their experts' ranks and permute them by expert.

        Args:
            hidden_states: local token hidden states.
            topk_idx: selected expert ids per token (cast to int64).
            topk_weights: routing weights matching ``topk_idx``.
            expert_list: unused; kept for interface compatibility.
            previous_event: optional event the DeepEP kernels depend on.

        Returns:
            ``(hidden_states, topk_idx, topk_weights, tokens_per_expert)`` for the
            tokens received by this rank. ``hidden_states`` is permuted by expert
            when non-empty.
        """
        self.hidden_shape = hidden_states.shape
        topk_idx = topk_idx.to(torch.int64)
        (
            hidden_states,
            topk_idx,
            topk_weights,
            recv_tokens_per_expert,
            handle,
            event,
        ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, self.num_experts, previous_event)
        self.tokens_per_expert = torch.tensor(
            recv_tokens_per_expert,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        tokens_per_expert = self.get_number_of_tokens_per_expert()
        self.handle = handle
        self.topk_idx = topk_idx
        self.topk_weights = topk_weights
        if hidden_states.shape[0] > 0:
            hidden_states, _, _, _, _ = self.get_permuted_hidden_states_by_experts(hidden_states)
        return hidden_states, topk_idx, topk_weights, tokens_per_expert

    def dispatch_normal(
        self,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        previous_event=None,
    ):
        """Synchronous DeepEP normal dispatch (layout computation + dispatch).

        ``topk_weights`` is cast to float32 before dispatch. Returns
        ``(recv_x, recv_topk_idx, recv_topk_weights, recv_tokens_per_expert, handle, event)``.
        """
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = self.buffer_normal.get_dispatch_layout(
            topk_idx,
            num_experts,
            previous_event=previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_tokens_per_expert,
            handle,
            event,
        ) = self.buffer_normal.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights.to(torch.float32),
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_tokens_per_expert,
            handle,
            event,
        )

    def dispatch_normal_async(self,
                              x: torch.Tensor,
                              topk_idx: torch.Tensor,
                              topk_weights: torch.Tensor,
                              num_experts: Optional[int] = None,
                              previous_event=None,
                              async_finish=True):
        """DeepEP normal dispatch, optionally asynchronous.

        Allocation moves to the comm stream only when ``previous_event`` is given
        and ``async_finish`` is True. ``topk_weights`` is passed through unchanged
        (no float32 cast). Returns the same 6-tuple as :meth:`dispatch_normal`.
        """
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = self.buffer_normal.get_dispatch_layout(
            topk_idx,
            num_experts=self.num_experts if num_experts is None else num_experts,
            previous_event=previous_event,
            async_finish=async_finish,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_tokens_per_expert,
            handle,
            event,
        ) = self.buffer_normal.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=async_finish,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_tokens_per_expert,
            handle,
            event,
        )

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Un-permute expert outputs, combine them back to the source ranks and
        reshape to the input shape recorded by the last :meth:`dispatch`."""
        if hidden_states.shape[0] > 0:
            hidden_states = self.get_restored_hidden_states_by_experts(hidden_states)
        hidden_states, _ = self.combine_normal(hidden_states, self.handle)
        self.handle = None
        return hidden_states.view(self.hidden_shape)

    def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
        """Synchronous DeepEP normal combine; returns ``(combined_x, event)``."""
        combined_x, _, event = self.buffer_normal.combine(
            x,
            handle,
            async_finish=False,
            previous_event=previous_event,
            allocate_on_comm_stream=False,
        )
        return combined_x, event

    def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
        """DeepEP normal combine, optionally asynchronous; returns ``(combined_x, event)``."""
        combined_x, _, event = self.buffer_normal.combine(
            x,
            handle,
            async_finish=async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )
        return combined_x, event

    def release(self):
        """Drop per-dispatch state; always returns True."""
        self.tokens_per_expert = None
        self.handle = None
        self.topk_idx = None
        self.topk_weights = None
        self.hidden_shape_before_permute = None
        self.dispatched_routing_map = None
        self.reversed_mapping_for_combine = None
        return True

    def get_number_of_tokens_per_expert(self) -> torch.Tensor:
        """Get the number of tokens per expert."""
        return self.tokens_per_expert

    def get_permuted_hidden_states_by_experts(
        self,
        hidden_states: torch.Tensor,
        topk_idx: Optional[torch.Tensor] = None,
        topk_weights: Optional[torch.Tensor] = None,
        num_experts: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Size, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Group received tokens by expert.

        Arguments left as None fall back to the values cached on ``self``. The
        permutation state is cached on ``self`` for a later un-permute.

        Returns:
            ``(hidden_states, hidden_states_shape, dispatched_routing_map,
            topk_weights, reversed_mapping_for_combine)``.
        """
        (dispatched_routing_map,
         topk_weights) = super().indices_to_multihot(self.topk_idx if topk_idx is None else topk_idx,
                                                     self.topk_weights if topk_weights is None else topk_weights,
                                                     self.num_experts if num_experts is None else num_experts)
        hidden_states_shape = hidden_states.shape
        (hidden_states, reversed_mapping_for_combine) = super().permute(
            hidden_states,
            dispatched_routing_map,
        )
        self.hidden_shape_before_permute = hidden_states_shape
        self.dispatched_routing_map = dispatched_routing_map
        self.topk_weights = topk_weights
        self.reversed_mapping_for_combine = reversed_mapping_for_combine
        return hidden_states, hidden_states_shape, dispatched_routing_map, topk_weights, reversed_mapping_for_combine

    def get_restored_hidden_states_by_experts(
        self,
        hidden_states: torch.Tensor,
        reversed_mapping_for_combine: Optional[torch.Tensor] = None,
        hidden_shape_before_permute: Optional[torch.Size] = None,
        dispatched_routing_map: Optional[torch.Tensor] = None,
        topk_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Undo the expert permutation, weighting by the routing probabilities.

        Arguments left as None fall back to the state cached on ``self``.

        Raises:
            AssertionError: if the probabilities actually used are not float32.
        """
        input_dtype = hidden_states.dtype
        probs = self.topk_weights if topk_weights is None else topk_weights
        # Validate the probabilities that unpermute will use, not the cached ones,
        # so explicit-argument callers are checked against their own tensor.
        assert probs.dtype == torch.float32, 'DeepEP only supports float32 probs'
        hidden_states = super().unpermute(
            hidden_states,
            sorted_indices=self.reversed_mapping_for_combine
            if reversed_mapping_for_combine is None else reversed_mapping_for_combine,
            restore_shape=self.hidden_shape_before_permute
            if hidden_shape_before_permute is None else hidden_shape_before_permute,
            routing_map=self.dispatched_routing_map if dispatched_routing_map is None else dispatched_routing_map,
            probs=probs,
        )
        return hidden_states.to(input_dtype)


class DeepEPTokenDispatcherLowLatency(TokenDispatcherImpl):
    """MoE token dispatcher built on DeepEP low-latency kernels.

    :meth:`dispatch` sends FP8 data (``use_fp8=True``). :meth:`combine` casts
    routing weights to float32. With ``return_recv_hook`` the receive hook is
    called; otherwise the current stream waits on the returned event.
    """

    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
        return_recv_hook: bool = False,
        num_max_dispatch_tokens_per_rank: int = 128,
    ):
        """Create the dispatcher and fetch the shared DeepEP buffer.

        Args:
            group: expert-parallel process group.
            num_experts: total number of experts.
            num_local_experts: experts hosted on this rank.
            hidden_size: hidden size of the tokens.
            params_dtype: dtype of the hidden states (used for buffer sizing).
            return_recv_hook: use DeepEP receive hooks instead of async events.
            num_max_dispatch_tokens_per_rank: per-rank token budget for
                low-latency dispatch and buffer sizing (default 128, the value
                previously hard-coded).

        Raises:
            ImportError: if DeepEP is not installed.
        """
        if not use_deepep:
            raise ImportError('DeepEP is not installed. Please install DeepEP package from '
                              'https://github.com/deepseek-ai/deepep.')
        self.group = group
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_bytes = params_dtype.itemsize
        self.handle = None
        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        self.buffer_low_latency = get_buffer_common(self.group,
                                                    self.num_max_dispatch_tokens_per_rank,
                                                    self.hidden_size,
                                                    self.num_experts,
                                                    hidden_bytes=self.hidden_size * self.params_bytes)
        self.return_recv_hook = return_recv_hook

    def _wait(self, event, hook):
        """Wait for the last low-latency op.

        Calls ``hook`` when receive hooks are enabled; otherwise makes the
        current stream wait on ``event``.
        """
        if self.return_recv_hook:
            hook()
        else:
            event.current_stream_wait()

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """Dispatch tokens with low-latency FP8 kernels and wait for completion.

        Returns:
            ``(packed_recv_hidden, topk_idx, topk_weights, masked_m, expected_m)``.
            ``topk_idx`` is cast to int64. ``expected_m`` equals
            ``floor(local_tokens * group_size * topk / num_experts) + 1``
            (presumably a per-expert token-count hint for the expert GEMM; it
            assumes every rank has ``local_tokens`` tokens).
        """
        topk_idx = topk_idx.to(torch.int64)
        expected_m = (hidden_states.shape[0] * self.buffer_low_latency.group_size * topk_idx.shape[1] +
                      num_experts) // num_experts

        packed_recv_hidden, masked_m, self.handle, event, hook = self.buffer_low_latency.low_latency_dispatch(
            hidden_states,
            topk_idx,
            self.num_max_dispatch_tokens_per_rank,
            num_experts,
            use_fp8=True,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
        self._wait(event, hook)
        return (
            packed_recv_hidden,
            topk_idx,
            topk_weights,
            masked_m,
            expected_m,
        )

    def dispatch_async(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        num_experts: Optional[int] = None,
        use_fp8: bool = True,
        async_finish: bool = True,
    ):
        """Low-latency dispatch without waiting.

        A receive hook is requested exactly when ``async_finish`` is False.
        Returns ``(recv_hidden_states, recv_expert_count, handle, event, hook)``;
        the caller owns ``handle`` and must wait on ``event``/``hook``.
        """
        assert topk_idx.dtype == torch.int64
        recv_hidden_states, recv_expert_count, handle, event, hook = self.buffer_low_latency.low_latency_dispatch(
            hidden_states,
            topk_idx,
            self.num_max_dispatch_tokens_per_rank,
            num_experts=self.num_experts if num_experts is None else num_experts,
            use_fp8=use_fp8,
            async_finish=async_finish,
            return_recv_hook=not async_finish,
        )
        return recv_hidden_states, recv_expert_count, handle, event, hook

    def combine(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Combine expert outputs using the handle from the last :meth:`dispatch`.

        ``topk_weights`` is cast to float32. Waits for completion and returns
        the combined hidden states.
        """
        combined_hidden_states, event, hook = self.buffer_low_latency.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights.to(torch.float32),
            self.handle,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
        self._wait(event, hook)
        return combined_hidden_states

    def combine_async(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        handle: Tuple,
        async_finish: bool,
    ) -> Tuple[torch.Tensor, object, object]:
        """Low-latency combine without waiting.

        Requires int64 ``topk_idx`` and float32 ``topk_weights``. Returns
        ``(combined_hidden_states, event, hook)``.
        """
        assert topk_idx.dtype == torch.int64
        assert topk_weights.dtype == torch.float32
        combined_hidden_states, event, hook = self.buffer_low_latency.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights,
            handle,
            async_finish=async_finish,
            return_recv_hook=not async_finish,
        )
        return combined_hidden_states, event, hook


class TokenDispatcherBuilder:
    """Factory choosing the token dispatcher implementation.

    Selects ``DeepEPTokenDispatcher`` only when the module-level ``use_deepep``
    is exactly ``True``; otherwise falls back to ``AlltoAllTokenDispatcher``.
    """

    @staticmethod
    def build(
        group,
        num_experts,
        num_local_experts,
        hidden_size,
        params_dtype,
    ) -> TokenDispatcherImpl:
        """build."""
        if use_deepep is not True:
            # The all-to-all dispatcher does not take hidden_size / params_dtype.
            return AlltoAllTokenDispatcher(group, num_experts, num_local_experts)
        return DeepEPTokenDispatcher(group, num_experts, num_local_experts, hidden_size, params_dtype)


================================================
FILE: lmdeploy/pytorch/backends/cuda/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache


@lru_cache
def has_tilelang():
    """Report whether the optional ``tilelang`` package is importable.

    Any exception raised while importing (not only ImportError) counts as
    "unavailable". The result is cached after the first call.
    """
    try:
        import tilelang  # noqa: F401
    except Exception:
        return False
    return True


================================================
FILE: lmdeploy/pytorch/backends/cuda/warmup_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass

import torch

from lmdeploy.pytorch.utils import singleton
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


@dataclass
class WarmupMeta:
    """Warmup parameters passed to every registered warmup callable.

    ``WarmupManager.warmup`` calls each registered function with one instance
    of this class.
    """
    # Presumably the largest number of tokens a warmup should cover — TODO confirm with callers.
    max_num_tokens: int
    # Presumably the largest batch size a warmup should cover — TODO confirm with callers.
    max_batch_size: int
    # Tensor dtype the warmed-up ops should be exercised with.
    dtype: torch.dtype


@singleton
class WarmupManager:
    """Registry of named warmup callables.

    Callables are registered with ``manager[key] = func`` and are all invoked
    (in a random order) by :meth:`warmup`.
    """

    def __init__(self):
        self._warmup_calls = {}

    def __contains__(self, key: str):
        """Return True when a warmup callable is registered under ``key``."""
        return key in self._warmup_calls

    def __getitem__(self, key: str):
        """Return the callable registered under ``key``, or None if absent."""
        return self._warmup_calls.get(key)

    def __setitem__(self, key: str, val):
        """Register (or replace) the warmup callable for ``key``."""
        self._warmup_calls[key] = val

    def warmup(self, warmup_meta: WarmupMeta):
        """Invoke every registered callable with ``warmup_meta`` in shuffled order.

        Does nothing (and logs nothing) when no callable is registered. The
        shuffle uses the global ``random`` module state.
        """
        if not self._warmup_calls:
            return
        import random
        logger.info('Warming up ops.')
        pending = [*self._warmup_calls.values()]
        random.shuffle(pending)
        for warmup_fn in pending:
            warmup_fn(warmup_meta)


def get_warmup_manager():
    """Get warmup manager.

    Returns:
        WarmupManager: the object produced by calling the ``@singleton``-decorated
        class. Presumably the same shared instance on every call — this depends
        on ``lmdeploy.pytorch.utils.singleton``.
    """
    return WarmupManager()


================================================
FILE: lmdeploy/pytorch/backends/deepep_moe_checker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.utils import singleton


@singleton
class MoEBackend:
    """Flag recording whether the DeepEP MoE backend has been selected.

    Wrapped by ``singleton``; the flag starts False and can only be switched on.
    """

    def __init__(self):
        """Initialize moe backend with DeepEP disabled."""
        self._use_deepep_moe_backend = False

    def set_deepep_moe_backend(self):
        """Mark the DeepEP MoE backend as enabled."""
        self._use_deepep_moe_backend = True

    def use_deepep_moe_backend(self):
        """Return True once :meth:`set_deepep_moe_backend` has been called."""
        enabled = self._use_deepep_moe_backend
        return enabled


def get_moe_backend():
    """Return the ``MoEBackend`` object.

    ``MoEBackend`` is decorated with ``@singleton``, so presumably every call
    returns the same shared flag holder — depends on
    ``lmdeploy.pytorch.utils.singleton``.
    """
    return MoEBackend()


================================================
FILE: lmdeploy/pytorch/backends/default/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import DefaultOpsBackend  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/default/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from torch import nn

from ..activation import GeluAndMulBuilder, GeluAndMulImpl, SiluAndMulBuilder, SiluAndMulImpl


class DefaultSiluAndMulImpl(SiluAndMulImpl):
    """Computes ``silu(gate) * up`` where ``gate``/``up`` are the two halves of the last dim."""

    def __init__(self, inplace: bool):
        self.inplace = inplace
        # With inplace=True, nn.SiLU overwrites the gate half (a view of the input).
        self.silu = nn.SiLU(inplace=inplace)

    def forward(self, x):
        """Split ``x`` in two along the last dim and return ``silu(first) * second``."""
        halves = x.chunk(2, dim=-1)
        return self.silu(halves[0]) * halves[1]


class DefaultSiluAndMulBuilder(SiluAndMulBuilder):
    """Builder producing :class:`DefaultSiluAndMulImpl`."""

    @staticmethod
    def build(inplace: bool = False):
        """Create the implementation; ``inplace`` is forwarded to ``nn.SiLU``."""
        impl = DefaultSiluAndMulImpl(inplace)
        return impl


class DefaultGeluAndMulImpl(GeluAndMulImpl):
    """Computes ``gelu(gate) * up`` where ``gate``/``up`` are the two halves of the last dim."""

    def __init__(self, approximate: str = 'none'):
        # ``approximate`` is passed straight to nn.GELU ('none' or 'tanh').
        self.act = nn.GELU(approximate=approximate)

    def forward(self, x):
        """Split ``x`` in two along the last dim and return ``gelu(first) * second``."""
        halves = x.chunk(2, dim=-1)
        return self.act(halves[0]) * halves[1]


class DefaultGeluAndMulBuilder(GeluAndMulBuilder):
    """Builder producing :class:`DefaultGeluAndMulImpl`."""

    @staticmethod
    def build(approximate: str = 'none'):
        """Create the implementation with the given GELU approximation mode."""
        impl = DefaultGeluAndMulImpl(approximate)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/apply_rotary_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


def rotate_half(x):
    """Return ``concat(-x2, x1)`` over the last dim, where ``x1``/``x2`` are its halves.

    The output is allocated with ``torch.empty_like(x)``, so it keeps ``x``'s
    dtype, device and memory format.
    """
    split = x.shape[-1] // 2
    rotated = torch.empty_like(x)
    rotated[..., split:] = x[..., :split]
    rotated[..., :split] = -x[..., split:]
    return rotated


class DefaultApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
    """Apply rotary embedding implementation.

    Computes ``t * cos + rotate_half(t) * sin`` for query and key. ``cos`` and
    ``sin`` get a new axis at dim -2 so they broadcast over the head axis.
    """

    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """Rotate ``query`` and ``key``; with ``inplace`` the input tensors are overwritten and returned."""
        cos = cos.unsqueeze(-2)
        sin = sin.unsqueeze(-2)
        if not inplace:
            return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)

        # The rotated term must be computed before the in-place multiply clobbers the input.
        query_rot = rotate_half(query) * sin
        query.mul_(cos).add_(query_rot)
        key_rot = rotate_half(key) * sin
        key.mul_(cos).add_(key_rot)
        return query, key


class DefaultApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
    """Builder producing :class:`DefaultApplyRotaryEmbImpl`."""

    @staticmethod
    def build():
        """Create a stateless apply-rotary-embedding implementation."""
        impl = DefaultApplyRotaryEmbImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/awq_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache
from typing import Optional

import torch

import lmdeploy.pytorch.distributed as dist

from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl


@lru_cache
def get_shifts(bits: int, device: torch.device):
    """Return the per-nibble right-shift amounts of the AWQ int32 packing.

    For ``bits=4`` this is ``[0, 16, 4, 20, 8, 24, 12, 28]``: the eight
    ``bits``-wide fields of each int32, in AWQ's interleaved order. The
    ``(2, 4)`` reshape means only ``bits=4`` is supported. The result is cached
    per ``(bits, device)``.
    """
    offsets = torch.arange(0, 32, bits, device=device).reshape(2, 4)
    return offsets.t().reshape(-1)


def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
    """Unpack AWQ-packed weights and zero points into int8 tensors.

    Every packed element of the 2-D inputs expands column-wise into
    ``len(shifts)`` fields of ``bits`` bits each.

    Returns:
        ``(iweights, izeros)``, each shaped ``(rows, cols * len(shifts))``.
    """
    shifts = get_shifts(bits, qzeros.device)
    field_mask = (2**bits) - 1

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # Shift every field down to the low bits, narrow to int8, then mask off
        # the higher fields (and any sign-extension of the arithmetic shift).
        fields = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(torch.int8)
        fields = fields.view(fields.shape[0], -1)
        return torch.bitwise_and(fields, field_mask)

    return _unpack(qweight), _unpack(qzeros)


def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
    """Dequantize AWQ-packed weights as ``(w - zero) * scale`` per group of rows.

    Rows of the unpacked weight are grouped in chunks of ``group_size``; each
    chunk uses the matching row of ``izeros`` and ``scales``.
    """
    int_weight, int_zeros = unpack_awq(qweight, qzeros, bits)
    grouped = int_weight.unflatten(0, (-1, group_size))
    dequant = (grouped - int_zeros[:, None]) * scales[:, None]
    return dequant.flatten(0, 1)


class DefaultLinearW4A16Impl(LinearW4A16Impl):
    """W4a16 linear implementation.

    Dequantizes the full weight on every call and runs the matmul in float16.
    """

    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):
        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size

    def forward(self,
                x,
                qweight: torch.Tensor,
                scales: torch.Tensor,
                qzeros: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Compute ``x @ dequant(W) (+ bias)`` and cast back to ``x``'s dtype.

        ``x`` is cast to float16 for the matmul if it is not already float16.
        With ``all_reduce`` the result is summed across ``group``.
        """
        orig_dtype = x.dtype
        needs_cast = orig_dtype != torch.float16
        target_shape = x.shape[:-1] + (self.out_features, )

        x_half = x.half() if needs_cast else x
        weight = dequantize_gemm(qweight, qzeros, scales, self.w_bit, self.group_size)
        result = torch.matmul(x_half, weight)
        if bias is not None:
            result = result + bias
        result = result.reshape(target_shape)

        if needs_cast:
            result = result.to(dtype=orig_dtype)
        if all_reduce:
            dist.all_reduce(result, group=group)
        return result


class DefaultLinearW4A16Builder(LinearW4A16Builder):
    """Builder producing :class:`DefaultLinearW4A16Impl`."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              w_bit: int,
              group_size: int,
              bias: bool = False,
              dtype: torch.dtype = None):
        """Create the implementation; ``bias`` and ``dtype`` are accepted but unused here."""
        return DefaultLinearW4A16Impl(
            in_features=in_features,
            out_features=out_features,
            w_bit=w_bit,
            group_size=group_size,
        )


================================================
FILE: lmdeploy/pytorch/backends/default/embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..embedding import EmbeddingBuilder, EmbeddingImpl


def get_masked_input_and_mask(input: torch.Tensor, start_index: int, end_index: int):
    """Map global token ids onto a vocab shard ``[start_index, end_index)``.

    Returns:
        ``(local_ids, out_of_shard)``. ``local_ids`` are ids shifted by
        ``start_index`` and clamped into ``[0, end_index - start_index - 1]``.
        ``out_of_shard`` is True where the clamp changed the id, i.e. where the
        id does not belong to this shard.
    """
    shifted = input - start_index
    last_local = end_index - start_index - 1
    local_ids = shifted.clamp(0, last_local)
    return local_ids, local_ids != shifted


class DefaultEmbeddingImpl(EmbeddingImpl):
    """Embedding lookup, optionally vocab-parallel across ``group``."""

    def __init__(self, start_index: int, end_index: int):
        # This rank's vocab shard is [start_index, end_index).
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Look up ``x`` in ``weight``.

        With ``all_reduce``, ids outside this shard produce zero rows and the
        partial results are summed across ``group``.
        """
        if not all_reduce:
            return F.embedding(x, weight)

        local_ids, out_of_shard = get_masked_input_and_mask(x, self.start_index, self.end_index)
        out = F.embedding(local_ids, weight)
        out.masked_fill_(out_of_shard.unsqueeze(-1), 0)
        dist.all_reduce(out, group=group)
        return out


class DefaultEmbeddingBuilder(EmbeddingBuilder):
    """Builder producing :class:`DefaultEmbeddingImpl`."""

    @staticmethod
    def build(start_index: int, end_index: int):
        """Create the implementation for the vocab shard ``[start_index, end_index)``."""
        impl = DefaultEmbeddingImpl(start_index, end_index)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/linear.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..linear import LinearBuilder, LinearImpl


class DefaultLinearImpl(LinearImpl):
    """Plain ``F.linear`` with optional tensor-parallel reduction."""

    def forward(self,
                x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: dist.ProcessGroup = None,
                rank: int = 0,
                scatter_size: List[int] = None):
        """Compute ``F.linear(x, weight, bias)``.

        With ``all_reduce``, the output is either all-reduced over ``group`` or,
        when ``scatter_size`` is given, reduce-scattered by those sizes.
        """
        out = F.linear(x, weight, bias)
        if not all_reduce:
            return out
        if scatter_size is None:
            dist.all_reduce(out, group=group)
            return out
        from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes
        return reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)


class DefaultLinearBuilder(LinearBuilder):
    """Builder producing :class:`DefaultLinearImpl`."""

    @staticmethod
    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):
        """Create the (stateless) implementation; the arguments are not needed here."""
        impl = DefaultLinearImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..moe import SoftmaxTopKBuilder, SoftmaxTopKImpl


class DefaultSoftmaxTopKImpl(SoftmaxTopKImpl):
    """Softmax followed by top-k expert selection.

    When ``n_groups > 0`` the experts along ``dim`` are split into ``n_groups``
    equal groups and ``top_k // n_groups`` experts are selected from each group.
    """

    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):
        self.top_k = top_k
        self.dim = dim
        self.n_groups = n_groups
        # Grouping is only active for n_groups > 0 (see forward). Checking the
        # divisibility unconditionally made n_groups=0 raise ZeroDivisionError.
        if self.n_groups > 0:
            assert self.top_k % self.n_groups == 0, f'{self.top_k} cannot be divided by {self.n_groups}'

    def forward(self, x: torch.Tensor):
        """Return ``(topk_weights, topk_ids)`` of the float32 softmax of ``x`` along ``dim``."""
        routing_weights = torch.softmax(x, dim=self.dim, dtype=torch.float32)
        if self.n_groups > 0:
            num_experts = routing_weights.shape[self.dim]
            assert num_experts % self.n_groups == 0, f'{num_experts} cannot be divided by {self.n_groups}'
            per_group_top_k = self.top_k // self.n_groups
            group_size = num_experts // self.n_groups
            # NOTE(review): get_group_offsets is not defined in this class; presumably
            # provided by SoftmaxTopKImpl — confirm.
            group_offsets = self.get_group_offsets(self.n_groups, group_size, routing_weights.device)
            routing_weights = routing_weights.unflatten(self.dim, (self.n_groups, group_size))
            # topk over the last axis: assumes dim == -1 so the group-local axis is last.
            topk_weights, topk_ids = torch.topk(routing_weights, per_group_top_k, dim=-1)
            # Convert group-local ids back to global expert ids.
            topk_ids = (topk_ids + group_offsets).flatten(-2, -1)
            topk_weights = topk_weights.flatten(-2, -1)
        else:
            topk_weights, topk_ids = torch.topk(routing_weights, self.top_k, dim=self.dim)
        return topk_weights, topk_ids


class DefaultSoftmaxTopKBuilder(SoftmaxTopKBuilder):
    """Builder producing :class:`DefaultSoftmaxTopKImpl` (softmax + top-k routing)."""

    @staticmethod
    def build(top_k: int, dim: int = -1, n_groups: int = -1):
        """Create the implementation; ``n_groups > 0`` enables grouped top-k."""
        impl = DefaultSoftmaxTopKImpl(top_k, dim, n_groups=n_groups)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/moe_router.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Tuple

import torch

from ..moe_router import RouterNoauxTCBuilder, RouterNoauxTCImpl


def _compute_scores(scoring_func: str, logits: torch.Tensor):
    """Compute scores."""
    if scoring_func == 'softmax':
        scores = logits.softmax(dim=-1, dtype=torch.float32)
    elif scoring_func == 'sigmoid':
        scores = logits.sigmoid()
    else:
        raise NotImplementedError('unsupported scoring function '
                                  f'for MoE gating: {scoring_func}')
    return scores


@functools.lru_cache
def get_group_offsets(n_groups: int, group_size: int, device: 'str | torch.device') -> torch.Tensor:
    """Return the starting expert index of every group, shaped ``[1, n_groups, 1]``.

    Adding this to group-local top-k indices of shape ``[..., n_groups, k]``
    yields global expert indices. The tensor is cached per
    ``(n_groups, group_size, device)`` and shared between callers, so it must
    not be modified in place.

    The annotation is a string: a bare ``str | torch.device`` is evaluated at
    definition time and raises TypeError on Python < 3.10.
    """
    group_offsets = (torch.arange(n_groups, device=device) * group_size).view(1, -1, 1)  # [1, n_groups, 1]
    return group_offsets


class DefaultRouterNoauxTCImpl(RouterNoauxTCImpl):
    """Pure-PyTorch "noaux_tc" MoE router.

    The flow is:

    1. Scores come from the logits via ``scoring_func``.
    2. ``bias`` is added to obtain the scores used for expert selection.
    3. Experts are picked either by grouped top-k (``router_n_groups > 0``) or
       by group-limited top-k (top ``topk_group`` of ``n_group`` groups).
    4. The selected weights are optionally renormalized and then scaled by
       ``routed_scaling_factor``.
    """

    def __init__(
        self,
        scoring_func: str,
        top_k: int,
        n_group: int,
        topk_group: int,
        n_routed_experts: int,
        routed_scaling_factor: float,
        renormalize: bool = True,
        router_n_groups: int = -1,
    ):

        self.scoring_func = scoring_func
        self.top_k = top_k
        self.n_group = n_group
        self.topk_group = topk_group
        self.n_routed_experts = n_routed_experts

        # renorm
        self.renormalize = renormalize
        self.routed_scaling_factor = routed_scaling_factor

        # n_group
        self.router_n_groups = router_n_groups
        if router_n_groups > 0:
            # Each group contributes top_k // router_n_groups experts; an uneven
            # split would silently return fewer than top_k experts per token.
            assert top_k % router_n_groups == 0, \
                f'top_k={top_k} cannot be divided by router_n_groups={router_n_groups}'

    def _forward_router_n_groups(self, scores_for_choice: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Take ``top_k // router_n_groups`` experts from each of ``router_n_groups`` equal groups."""
        assert scores_for_choice.shape[-1] % self.router_n_groups == 0, \
            f'{scores_for_choice.shape[-1]} cannot be divided by {self.router_n_groups}'
        per_group_top_k = self.top_k // self.router_n_groups
        group_size = scores_for_choice.shape[-1] // self.router_n_groups
        group_offsets = get_group_offsets(self.router_n_groups, group_size, device=scores_for_choice.device)
        scores_for_choice = scores_for_choice.unflatten(-1, (self.router_n_groups, group_size))
        # NOTE(review): the returned weights are the bias-adjusted selection scores,
        # whereas _forward_default gathers weights from the unbiased ``scores`` —
        # confirm this difference is intended.
        topk_weight, topk_idx = torch.topk(scores_for_choice, per_group_top_k, dim=-1)
        topk_idx = (topk_idx + group_offsets).flatten(-2, -1)
        topk_weight = topk_weight.flatten(-2, -1)
        return topk_weight, topk_idx

    def _forward_default(self, scores: torch.Tensor, scores_for_choice: torch.Tensor,
                         sequence_length: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select experts from the best ``topk_group`` groups; weights come from ``scores``.

        A group's score is the sum of its two highest selection scores.
        """
        group_scores = (scores_for_choice.view(sequence_length, self.n_group,
                                               -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
        score_mask = (group_mask.unsqueeze(-1).expand(sequence_length, self.n_group,
                                                      self.n_routed_experts // self.n_group).reshape(
                                                          sequence_length, -1))  # [n, e]
        # Experts outside the selected groups get a selection score of 0.
        tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
        _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
        topk_weight = scores.gather(1, topk_idx)

        return topk_weight, topk_idx

    def renorm(self, topk_weight: torch.Tensor) -> torch.Tensor:
        """Optionally normalize the weights to sum to 1, then scale by ``routed_scaling_factor``."""
        if self.renormalize:
            # The epsilon guards against division by zero.
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
            if not topk_weight.is_contiguous():
                topk_weight = topk_weight.contiguous()

        topk_weight = topk_weight * self.routed_scaling_factor
        return topk_weight

    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Router forward; returns ``(topk_weight, topk_idx)`` for each of the ``logits.shape[0]`` rows."""
        sequence_length = logits.shape[0]

        scores = _compute_scores(self.scoring_func, logits)
        scores_for_choice = scores.view(sequence_length, -1) + bias[None]
        if self.router_n_groups > 0:
            topk_weight, topk_idx = self._forward_router_n_groups(scores_for_choice)
        else:
            topk_weight, topk_idx = self._forward_default(scores, scores_for_choice, sequence_length)

        topk_weight = self.renorm(topk_weight)
        return topk_weight, topk_idx


class DefaultRouterNoauxTCBuilder(RouterNoauxTCBuilder):
    """Builder producing :class:`DefaultRouterNoauxTCImpl`."""

    @staticmethod
    def build(
        scoring_func: str,
        top_k: int,
        n_group: int,
        topk_group: int,
        n_routed_experts: int,
        routed_scaling_factor: float,
        renormalize: bool = True,
        router_n_groups: int = -1,
    ):
        """Create the router; all arguments are forwarded unchanged."""
        router_kwargs = dict(
            scoring_func=scoring_func,
            top_k=top_k,
            n_group=n_group,
            topk_group=topk_group,
            n_routed_experts=n_routed_experts,
            routed_scaling_factor=routed_scaling_factor,
            renormalize=renormalize,
            router_n_groups=router_n_groups,
        )
        return DefaultRouterNoauxTCImpl(**router_kwargs)


================================================
FILE: lmdeploy/pytorch/backends/default/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from ..multinomial_sampling import MultinomialSamplingBuilder, MultinomialSamplingImpl


class DefaultMultinomialSamplingImpl(MultinomialSamplingImpl):
    """Multinomial sampling implementation api."""

    def forward(self,
                scores: torch.Tensor,
                seeds: torch.LongTensor,
                offsets: torch.LongTensor,
                indices: torch.Tensor = None):
        """Draw one sample per row of ``scores``.

        Args:
            scores: Non-negative sampling weights, one row per sequence.
            seeds: Not used by this backend.
            offsets: Not used by this backend.
            indices: Optional per-row token ids aligned with ``scores``. If
                given, the sampled column is mapped through it. If None (the
                default), the sampled column positions are returned directly.

        Returns:
            A 1-D tensor with one sampled id per row.
        """
        # The default backend draws from torch's global RNG; seeds/offsets are ignored.
        sampled_index = torch.multinomial(scores, num_samples=1, replacement=True)
        if indices is None:
            # Previously torch.gather(None, ...) crashed when the default was used.
            return sampled_index.view(-1)
        outputs = torch.gather(indices, dim=1, index=sampled_index)
        return outputs.view(-1)


class DefaultMultinomialSamplingBuilder(MultinomialSamplingBuilder):
    """Builder producing :class:`DefaultMultinomialSamplingImpl`."""

    @staticmethod
    def build():
        """Create the (stateless) multinomial sampler."""
        impl = DefaultMultinomialSamplingImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..norm import LayerNormBuilder, LayerNormImpl, RMSNormBuilder, RMSNormImpl


class DefaultRMSNormImpl(RMSNormImpl):
    """RMS norm with an optional fused residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """Normalize over the last dim in float32, cast back and scale by ``weight``.

        With ``residual``, ``x + residual`` is normalized and also returned as
        the new residual: the result is ``(normed, x + residual)``.
        """
        input_dtype = x.dtype
        with_residual = residual is not None
        if with_residual:
            residual = x + residual
            x = residual
        x_fp32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = weight * (x_fp32 * inv_rms).to(input_dtype)
        return (normed, residual) if with_residual else normed


class DefaultRMSNormBuilder(RMSNormBuilder):
    """Builder producing :class:`DefaultRMSNormImpl`."""

    @staticmethod
    def build(hidden_size: int, eps: float = 1e-6):
        """Create an RMS norm for ``hidden_size`` features with epsilon ``eps``."""
        impl = DefaultRMSNormImpl(hidden_size, eps)
        return impl


class DefaultLayerNormImpl(LayerNormImpl):
    """Layer norm implementation with an optional fused residual add."""

    def __init__(self, normalized_shape: int, eps: float = 1e-6):
        # F.layer_norm expects a shape sequence; wrap a bare int.
        self.normalized_shape = (normalized_shape, ) if isinstance(normalized_shape, int) else normalized_shape
        self.eps = eps

    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor = None,
                bias: torch.Tensor = None,
                residual: torch.Tensor = None):
        """Apply ``F.layer_norm``.

        With ``residual``, returns ``(normed, x + residual)``.
        """
        with_residual = residual is not None
        if with_residual:
            residual = x + residual
            x = residual
        normed = torch.nn.functional.layer_norm(x, self.normalized_shape, weight=weight, bias=bias, eps=self.eps)
        return (normed, residual) if with_residual else normed


class DefaultLayerNormBuilder(LayerNormBuilder):
    """Layer norm implementation builder."""

    @staticmethod
    def build(normalized_shape: int, eps: float = 1e-6):
        """Create a layer norm over ``normalized_shape`` with epsilon ``eps``."""
        impl = DefaultLayerNormImpl(normalized_shape, eps)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/default/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from ..base import OpsBackend, OpType


class DefaultOpsBackend(OpsBackend):
    """Reference ops backend built purely on PyTorch."""

    @staticmethod
    def get_name() -> str:
        return 'default'

    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """Get builder of given layer type.

        Builder modules are imported lazily, only when their op is requested.

        Raises:
            RuntimeError: if ``layer_type`` has no default implementation.
        """
        import importlib

        # OpType -> (sibling module, builder class name)
        builder_table = {
            OpType.Linear: ('linear', 'DefaultLinearBuilder'),
            OpType.RotaryEmbedding: ('rotary_embedding', 'DefaultRotaryEmbeddingBuilder'),
            OpType.ApplyRotaryEmb: ('apply_rotary_emb', 'DefaultApplyRotaryEmbBuilder'),
            OpType.SiluAndMul: ('activation', 'DefaultSiluAndMulBuilder'),
            OpType.GeluAndMul: ('activation', 'DefaultGeluAndMulBuilder'),
            OpType.RMSNorm: ('norm', 'DefaultRMSNormBuilder'),
            OpType.LayerNorm: ('norm', 'DefaultLayerNormBuilder'),
            OpType.MultinomialSampling: ('multinomial_sampling', 'DefaultMultinomialSamplingBuilder'),
            OpType.LinearW4A16: ('awq_modules', 'DefaultLinearW4A16Builder'),
            OpType.SoftmaxTopK: ('moe', 'DefaultSoftmaxTopKBuilder'),
            OpType.Embedding: ('embedding', 'DefaultEmbeddingBuilder'),
            OpType.RouterNoauxTC: ('moe_router', 'DefaultRouterNoauxTCBuilder'),
        }
        entry = builder_table.get(layer_type)
        if entry is None:
            raise RuntimeError(f'{layer_type} not supported.')
        module_name, builder_name = entry
        # Same as `from .<module_name> import <builder_name>`.
        module = importlib.import_module(f'.{module_name}', package=__package__)
        return getattr(module, builder_name)

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Get block shape of k: ``(block_size, num_heads, head_size)``; ``dtype`` is unused."""
        return (block_size, num_heads, head_size)

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Get block shape of v: ``(block_size, num_heads, head_size)``; ``dtype`` is unused."""
        return (block_size, num_heads, head_size)

    @staticmethod
    def init():
        """No backend-specific initialization is needed."""
        return None

    @staticmethod
    def ccl_backend() -> str:
        return 'nccl'


================================================
FILE: lmdeploy/pytorch/backends/default/rotary_embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn

from ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
                                RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)


def safe_torch_compile(**compile_kwargs):
    """Decorator factory: ``torch.compile`` the function, falling back to eager on failure.

    ``torch.compile`` is lazy. Graph capture and code generation run on the
    first call of the compiled function, and may run again on later calls (for
    example, recompilation after a guard fails). Any call through the compiled
    path is therefore guarded, not just the first.

    On failure the same call is retried eagerly:

    * If the eager call succeeds, the compiled path is to blame, and the wrapper
      permanently switches to the original function.
    * If the eager call also raises, that error belongs to the function itself.
      It propagates, and the compiled path stays enabled.

    A failing call may run the function body twice, so decorated functions
    should be free of side effects.

    Args:
        **compile_kwargs: Forwarded to ``torch.compile``.
    """

    def decorator(func):
        compiled_func = None
        compile_failed = False

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal compiled_func, compile_failed

            if compile_failed:
                return func(*args, **kwargs)

            try:
                if compiled_func is None:
                    compiled_func = torch.compile(func, **compile_kwargs)
                return compiled_func(*args, **kwargs)
            except Exception:
                # Raises here mean the error is func's own; keep the compiled path.
                result = func(*args, **kwargs)
                compile_failed = True
                return result

        return wrapper

    return decorator


@safe_torch_compile(dynamic=True)
def _rotary_embedding_fwd(position_ids: torch.Tensor,
                          inv_freq: torch.Tensor,
                          scaling_factor: float,
                          mscale: float = None,
                          dtype: torch.dtype = None,
                          device_type: torch.device = None):
    """Compute rotary ``cos``/``sin`` tables for ``position_ids``.

    Positions are divided by ``scaling_factor`` and the angles are computed in
    float32 with autocast disabled. The angles are duplicated along the last
    dim, optionally multiplied by ``mscale``, and the result is cast to
    ``dtype`` (float16 by default). ``position_ids`` is indexed as
    ``[:, None, :]``, so it is expected to be 2-D.
    """
    out_dtype = torch.float16 if dtype is None else dtype
    # torch.autocast needs a device-type string; non-strings and 'mps' use 'cpu'.
    autocast_device = 'cuda' if device_type is None else device_type
    if not isinstance(autocast_device, str) or autocast_device == 'mps':
        autocast_device = 'cpu'

    scaled_positions = position_ids.float() / scaling_factor
    freq_table = inv_freq[None, :, None].float().expand(scaled_positions.shape[0], -1, 1)
    pos_table = scaled_positions[:, None, :]
    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    with torch.autocast(device_type=autocast_device, enabled=False):
        angles = (freq_table.float() * pos_table.float()).transpose(1, 2)
        angles = angles.repeat(1, 1, 2)
        cos, sin = angles.cos(), angles.sin()
        if mscale is not None:
            cos, sin = cos * mscale, sin * mscale

    return cos.to(dtype=out_dtype), sin.to(dtype=out_dtype)


class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
    """Base rotary embedding with standard ``1 / base**(2i/dim)`` frequencies."""

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim
        self.register_buffer('inv_freq', 1.0 / (self.base**exponents), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Return ``(cos, sin)`` in ``x``'s dtype for ``position_ids``."""
        # Move the frequency buffer next to the activations on first use per device.
        if self.inv_freq.device != x.device:
            self.inv_freq = self.inv_freq.to(x.device)
        return _rotary_embedding_fwd(
            position_ids,
            self.inv_freq,
            scaling_factor=self.scaling_factor,
            dtype=x.dtype,
            device_type=x.device.type,
        )


class LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0, max_position_embeddings: int = 2048):
        super().__init__(dim, base, scaling_factor)
        self.max_position_embeddings = max_position_embeddings

    def _ntk_inv_freq(self, seq_len: torch.Tensor):
        """Inverse frequencies computed with an NTK-enlarged base for ``seq_len``."""
        ratio = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
        ntk_base = self.base * ratio**(self.dim / (self.dim - 2))
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=seq_len.device).float() / self.dim
        return 1.0 / (ntk_base**exponents)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Use NTK frequencies only when ``max(position_ids) + 1`` exceeds ``max_position_embeddings``."""
        seq_len = torch.max(position_ids) + 1
        ntk_inv_freq = self._ntk_inv_freq(seq_len)
        if self.inv_freq.device != x.device:
            self.inv_freq = self.inv_freq.to(x.device)
        # torch.where keeps the choice on-device (no host sync on seq_len).
        use_ntk = seq_len > self.max_position_embeddings
        inv_freq = torch.where(use_ntk, ntk_inv_freq, self.inv_freq)
        return _rotary_embedding_fwd(
            position_ids,
            inv_freq,
            scaling_factor=1.0,
            dtype=x.dtype,
            device_type=x.device.type,
        )


class Llama3RotaryEmbeddingImpl(RotaryEmbeddingImpl):
    """Llama3-style rotary embedding.

    Frequencies are rescaled once at construction time and stored back into the
    ``inv_freq`` buffer:

    * wavelength above ``original_max_position_embeddings / low_freq_factor``: divided by ``scaling_factor``;
    * wavelength below ``original_max_position_embeddings / high_freq_factor``: kept as is;
    * otherwise: smoothly interpolated between the two.

    ``self.scaling_factor`` is then reset to 1.0 because the scaling is already baked in.
    """

    def __init__(
        self,
        dim: int,
        base: int = 10000,
        scaling_factor: float = 1.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        # NOTE(review): 8194 looks like a typo for 8192 (Llama 3.x's value). The builder in this
        # file always passes it explicitly; confirm before changing the default.
        original_max_position_embeddings: int = 8194,
    ):
        super().__init__(dim, base, scaling_factor)
        context_len = original_max_position_embeddings
        factor = self.scaling_factor
        base_inv_freq = self.inv_freq

        long_wavelen_limit = context_len / low_freq_factor
        short_wavelen_limit = context_len / high_freq_factor
        wavelen = 2 * math.pi / base_inv_freq

        is_low_freq = wavelen > long_wavelen_limit
        is_high_freq = wavelen < short_wavelen_limit

        scaled = torch.where(is_low_freq, base_inv_freq / factor, base_inv_freq)
        smooth = (context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        blended = (1 - smooth) * scaled / factor + smooth * scaled
        is_medium_freq = ~is_high_freq * ~is_low_freq
        rescaled_inv_freq = torch.where(is_medium_freq, blended, scaled)

        self.scaling_factor = 1.0
        self.register_buffer('inv_freq', rescaled_inv_freq)


def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) dimension index ``d`` at which ``num_rotations``
    full periods of the rotary wavelength ``2*pi*base**(2*d/dim)`` span
    ``max_position_embeddings`` positions."""
    full_turns = num_rotations * 2 * math.pi
    numerator = dim * math.log(max_position_embeddings / full_turns)
    return numerator / (2 * math.log(base))


# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048, truncate: bool = True):
    """Return the ``(low, high)`` dimension bounds of the YaRN interpolation ramp.

    With ``truncate`` the bounds are widened to integers (floor/ceil). The
    result is always clamped to ``[0, dim - 1]``.
    """
    low, high = (yarn_find_correction_dim(rot, dim, base, max_position_embeddings) for rot in (low_rot, high_rot))
    if truncate:
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention magnitude factor: ``0.1 * mscale * ln(scale) + 1``, or 1.0
    when no extension is applied (``scale <= 1``)."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min, max, dim):
    """Float32 ramp of length ``dim``: 0 up to index ``min``, rising linearly to 1
    at index ``max``, clamped to [0, 1] outside that range."""
    # Local aliases so the body does not read like the ``min``/``max`` builtins.
    lo, hi = min, max
    if lo == hi:
        hi += 0.001  # Prevent singularity

    ramp = (torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo)
    return ramp.clamp(0, 1)


class YarnRotaryEmbeddingImpl(RotaryEmbeddingImpl):
    """YaRN rotary embedding implementation.

    The inverse frequencies blend extrapolated (unscaled) and interpolated
    (divided by ``scaling_factor``) frequencies through a linear ramp over the
    correction range given by ``beta_fast``/``beta_slow``. ``self.mscale`` ends up
    as the cos/sin magnitude multiplier, or ``None`` when it equals 1.0.
    """

    def __init__(self,
                 dim: int,
                 base: int = 10000,
                 scaling_factor: float = 1.0,
                 original_max_position_embeddings: int = 4096,
                 yarn_params: YarnParameters = None):
        super().__init__(dim, base, scaling_factor)
        self.original_max_position_embeddings = original_max_position_embeddings
        assert yarn_params is not None
        self.beta_fast = yarn_params.beta_fast
        self.beta_slow = yarn_params.beta_slow
        self.mscale_all_dim = yarn_params.mscale_all_dim
        self.truncate = yarn_params.truncate

        # Blend interpolated and extrapolated frequencies.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        freq_extra = 1.0 / (self.base**exponents)
        freq_inter = 1.0 / (self.scaling_factor * self.base**exponents)
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
            truncate=self.truncate,
        )
        extra_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(dtype=torch.float32)
        blended = freq_inter * (1 - extra_weight) + freq_extra * extra_weight
        self.register_buffer('inv_freq', blended, persistent=False)

        # Magnitude factor: an explicit attention_factor wins; otherwise use the mscale ratio.
        attention_factor = yarn_params.attention_factor
        if attention_factor is None:
            attention_factor = float(
                yarn_get_mscale(self.scaling_factor, yarn_params.mscale) /
                yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
        self.mscale = None if attention_factor == 1.0 else attention_factor

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``."""
        target_device = x.device
        if self.inv_freq.device != target_device:
            self.inv_freq = self.inv_freq.to(target_device)
        return _rotary_embedding_fwd(position_ids,
                                     self.inv_freq,
                                     scaling_factor=1.0,
                                     mscale=self.mscale,
                                     dtype=x.dtype,
                                     device_type=target_device.type)


class LongRoPEScalingRotaryEmbeddingImpl(RotaryEmbeddingImpl):
    """LongRoPE scaling rotary embedding implementation.

    The base inverse frequencies are divided element-wise by ``short_factor``,
    or by ``long_factor`` once ``max(position_ids) + 1`` exceeds
    ``original_max_position_embeddings``. The cos/sin magnitude multiplier is
    either derived from the context-extension ratio (when neither
    ``short_mscale`` nor ``long_mscale`` is given) or picked from those two
    values with the same short/long rule.
    """

    def __init__(
        self,
        dim: int,
        base: int = 10000,
        max_position_embeddings: int = 4096,
        longrope_params: LongRoPEScalingParameters = None,
    ):
        super().__init__(dim, base)
        short_factor = torch.tensor(longrope_params.short_factor, dtype=torch.float32)
        long_factor = torch.tensor(longrope_params.long_factor, dtype=torch.float32)
        self.register_buffer('short_factor', short_factor, persistent=False)
        self.register_buffer('long_factor', long_factor, persistent=False)
        self.original_max_position_embeddings = \
            longrope_params.original_max_position_embeddings
        self.mscale = None
        self.short_mscale = longrope_params.short_mscale
        self.long_mscale = longrope_params.long_mscale
        if self.short_mscale is None and self.long_mscale is None:
            scale = (max_position_embeddings / self.original_max_position_embeddings)
            if scale <= 1.0:
                self.mscale = 1.0
            else:
                self.mscale = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Rope forward; returns ``(cos, sin)`` cast to ``x.dtype``."""
        dtype = x.dtype
        device = position_ids.device
        # Pass the device *type* string, as every other implementation in this file does
        # (previously a torch.device object was passed here).
        device_type = x.device.type
        if self.short_factor.device != device:
            self.register_buffer('short_factor', self.short_factor.to(device), persistent=False)
            self.register_buffer('long_factor', self.long_factor.to(device), persistent=False)
        # inv_freq is combined with the (moved) factors below, so it must live on the same device.
        if self.inv_freq.device != device:
            self.inv_freq = self.inv_freq.to(device)

        max_pos_ids = position_ids.max() + 1
        mask = max_pos_ids > self.original_max_position_embeddings
        ext_factors = torch.where(mask, self.long_factor, self.short_factor)

        mscale = self.mscale
        if mscale is None:
            mscale = torch.where(mask, self.long_mscale, self.short_mscale)

        inv_freq = self.inv_freq * (1.0 / ext_factors)
        return _rotary_embedding_fwd(position_ids,
                                     inv_freq,
                                     scaling_factor=1.0,
                                     mscale=mscale,
                                     dtype=dtype,
                                     device_type=device_type)


class FopeRotaryEmbeddingImpl(RotaryEmbeddingImpl):
    """Fourier position embedding (FoPE) rotary implementation.

    Per-frequency ``cos(pos * f)`` / ``sin(pos * f)`` terms are mixed by
    coefficient tensors supplied to :meth:`forward` (``cos_coef`` / ``sin_coef``),
    padded up to ``head_dim // 2``, duplicated to the full head dimension and
    multiplied by ``attention_scaling``.
    """

    def __init__(self,
                 dim: int,
                 max_position_embeddings: int = 4096,
                 scaling_factor: float = 1.0,
                 params: FopeParameters = None):
        super().__init__(dim, scaling_factor=scaling_factor)
        assert params is not None
        self.head_dim = dim
        self.max_position_embeddings = max_position_embeddings
        # Applied as a plain multiplier to both cos and sin in forward().
        self.attention_scaling = scaling_factor
        self.params = params

        inv_freq = params.inv_freq
        if params.num_inv_freq is not None:
            # Keep an explicit number of leading frequencies.
            selected = torch.ones_like(inv_freq, dtype=torch.bool)
            selected[params.num_inv_freq:] = False
        else:
            # Keep only frequencies whose period (2*pi / f) is shorter than max_position_embeddings.
            selected = inv_freq > (2.0 * torch.pi / self.max_position_embeddings)

        # The base class presumably already registers 'inv_freq'; re-registering replaces it
        # and marks it non-persistent.
        self.register_buffer('inv_freq', inv_freq[selected], persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor, sin_coef: torch.Tensor, cos_coef: torch.Tensor):
        """Return ``(cos, sin)`` cast to ``x.dtype``.

        ``x`` is unpacked as ``(batch, seq_len, _)``. With ``params.fope_sep_head``
        the coefficients are indexed per KV head (``einsum 'bhtD, hDd -> bthd'``),
        otherwise they are shared (``einsum 'btD, Dd -> btd'``).
        """
        if self.inv_freq.device != x.device:
            self.inv_freq = self.inv_freq.to(x.device)

        batch_size, seq_len, _ = x.shape
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        pos_cos = freqs.cos()
        pos_sin = freqs.sin()

        if self.params.fope_sep_head:
            num_heads = self.params.num_key_value_heads
            pos_cos = pos_cos.unsqueeze(1).expand(batch_size, num_heads, seq_len, -1)
            pos_sin = pos_sin.unsqueeze(1).expand(batch_size, num_heads, seq_len, -1)
            sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef.float())
            cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef.float())
        else:
            sin = torch.einsum('btD, Dd -> btd', pos_sin, sin_coef.float())
            cos = torch.einsum('btD, Dd -> btd', pos_cos, cos_coef.float())

        # NOTE(review): both sin and cos are padded with 1. An identity rotation would normally
        # use sin=0 for the padded dims; confirm this matches the reference FoPE implementation.
        sin = F.pad(input=sin, pad=(0, self.head_dim // 2 - sin.size(-1)), mode='constant', value=1)
        cos = F.pad(input=cos, pad=(0, self.head_dim // 2 - cos.size(-1)), mode='constant', value=1)

        sin = torch.cat((sin, sin), dim=-1)
        cos = torch.cat((cos, cos), dim=-1)

        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
    """Builder that maps a ``RopeType`` to the matching rotary embedding
    implementation of this backend."""

    @staticmethod
    def build(
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
        scaling_factor: float = 1.0,
        yarn_params: YarnParameters = None,
        longrope_params: LongRoPEScalingParameters = None,
        llama3_params: Llama3Parameters = None,
        fope_params: FopeParameters = None,
        emb_type: RopeType = RopeType.Default,
    ):
        """Instantiate the rotary embedding for ``emb_type``.

        Only the parameter object relevant to ``emb_type`` is read.

        Raises:
            NotImplementedError: if ``emb_type`` has no implementation here.
        """
        if emb_type in (RopeType.Default, RopeType.LinearScaling):
            return RotaryEmbeddingImpl(dim, base, scaling_factor)
        if emb_type == RopeType.DynamicNTKScaling:
            return LlamaDynamicNTKScalingRotaryEmbedding(dim, base, scaling_factor, max_position_embeddings)
        if emb_type == RopeType.Llama3:
            return Llama3RotaryEmbeddingImpl(
                dim,
                base,
                scaling_factor,
                llama3_params.low_freq_factor,
                llama3_params.high_freq_factor,
                llama3_params.original_max_position_embeddings,
            )
        if emb_type == RopeType.Yarn:
            return YarnRotaryEmbeddingImpl(dim, base, scaling_factor, max_position_embeddings, yarn_params=yarn_params)
        if emb_type == RopeType.LongRoPEScaling:
            return LongRoPEScalingRotaryEmbeddingImpl(dim,
                                                      base,
                                                      max_position_embeddings=max_position_embeddings,
                                                      longrope_params=longrope_params)
        if emb_type == RopeType.Fope:
            return FopeRotaryEmbeddingImpl(dim,
                                           max_position_embeddings=max_position_embeddings,
                                           scaling_factor=scaling_factor,
                                           params=fope_params)
        raise NotImplementedError(f'Unsupported embedding type: {emb_type}')


================================================
FILE: lmdeploy/pytorch/backends/default/token_dispatcher.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from ..token_dispatcher import TokenDispatcherImpl


class AlltoAllTokenDispatcher(TokenDispatcherImpl):
    """Expert-parallel token dispatcher built on ``all_to_all_single``.

    ``dispatch`` permutes local tokens by routed expert and exchanges them
    across ``ep_group`` so that every rank receives the tokens for its local
    experts. ``combine`` reverses the exchange and the permutation. Tensor
    parallelism is not supported here (``tp_size`` is fixed to 1).
    """

    def __init__(
        self,
        ep_group,
        num_experts,
        num_local_experts: int,
    ) -> None:
        self.num_local_experts = num_local_experts
        assert num_experts is not None
        self.num_experts = num_experts
        assert self.num_local_experts > 0, 'Expected at least one expert'
        self.ep_size = num_experts // num_local_experts
        self.ep_group = ep_group
        self.tp_size = 1
        self.input_splits = None
        self.output_splits = None
        input_chunk_idxs = torch.arange(self.num_experts, device=torch.device('cpu'))
        # Reorders chunks laid out as [rank][local_expert] into [local_expert][rank] ...
        self.sort_input_by_local_experts = input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel()
        # ... and this one undoes that reordering.
        self.restore_output_by_local_experts = input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel()

    def sort_chunks_by_idxs(self, input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
        """Split ``input`` along dim 0 into chunks of ``split_sizes`` and
        concatenate them back in the order given by ``sorted_idxs``."""
        chunks = torch.split(input, split_sizes.tolist(), dim=0)
        return torch.cat([chunks[i] for i in sorted_idxs.tolist()], dim=0)

    def all_to_all(self, group: torch.distributed.group, input_: torch.Tensor, output_split: torch.Tensor,
                   input_split: torch.Tensor):
        """Exchange variable-sized row chunks of ``input_`` across ``group``.

        ``input_split`` / ``output_split`` hold per-rank row counts to send and
        receive. The output is allocated on ``input_``'s device (not hard-coded
        to ``torch.cuda.current_device()``) with ``input_``'s dtype.
        """
        output_split_sizes = output_split.tolist()
        input_split_sizes = input_split.tolist()
        output = input_.new_empty((sum(output_split_sizes), *input_.shape[1:]))
        torch.distributed.all_to_all_single(
            output,
            input_,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    def preprocess(self, routing_map: torch.Tensor, local_expert_indices) -> torch.Tensor:
        """Compute the send/receive splits for the all-to-all exchange.

        Sets ``self.input_splits`` / ``self.output_splits`` (and, with more than
        one local expert, ``self.num_global_tokens_per_local_expert``) and
        returns the number of tokens each local expert will receive.
        """
        assert (len(local_expert_indices) == self.num_local_experts), 'Invalid local expert indices'
        for i in range(len(local_expert_indices) - 1):
            assert (local_expert_indices[i] == local_expert_indices[i + 1] -
                    1), 'local_expert_indices must be continous'

        num_local_tokens_per_expert = routing_map.sum(dim=0).long()
        # NOTE(review): these host copies are non_blocking and are read with .tolist() later
        # without an explicit synchronize; correctness presumably relies on a synchronizing op
        # in between (e.g. inside permute). Confirm before removing any such op.
        self.input_splits = (num_local_tokens_per_expert.reshape(self.ep_size, self.num_local_experts).sum(axis=1).to(
            torch.device('cpu'), non_blocking=True))
        dim_size = list(num_local_tokens_per_expert.size())
        dim_size[0] = dim_size[0] * torch.distributed.get_world_size(self.ep_group)
        output = num_local_tokens_per_expert.new_empty(dim_size)
        torch.distributed.all_gather_into_tensor(output, num_local_tokens_per_expert.contiguous(), group=self.ep_group)
        # [ep_size * num_experts] -> [tp_size, ep_size, num_experts]
        num_global_tokens_per_expert = (output.reshape(self.ep_size, self.tp_size, self.num_experts).transpose(0, 1))
        first_expert, last_expert = local_expert_indices[0], local_expert_indices[-1]
        # [tp_size, ep_size, num_local_experts]
        num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, :,
                                                                          first_expert:last_expert + 1].contiguous()
        num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
        self.output_splits = (num_global_tokens_per_rank[0].to(torch.device('cpu'), non_blocking=True))
        num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
        if self.num_local_experts > 1:
            # Only needed to re-sort received chunks by local expert. The previously assigned
            # .view(-1, num_local_experts) was immediately overwritten, so it has been dropped.
            self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.to(torch.device('cpu'),
                                                                                            non_blocking=True)
        return num_tokens_per_local_expert

    def dispatch(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, probs: torch.Tensor,
                 local_expert_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Permute and send tokens to the ranks owning their experts.

        Returns ``(global_input_tokens, None, None, tokens_per_expert)``.
        """
        # Validate before probs is consumed by indices_to_multihot.
        assert probs.dim() == 2, 'Expected 2D tensor for probs'
        self.hidden_shape = hidden_states.shape
        self.topk_ids = topk_ids
        self.routing_map, self.topk_weights = super().indices_to_multihot(topk_ids, probs, self.num_experts)
        assert self.routing_map.dim() == 2, 'Expected 2D tensor for token2expert mask'
        assert self.routing_map.dtype == torch.bool, 'Expected bool tensor for mask'
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self.preprocess(self.routing_map, local_expert_indices)
        self.hidden_shape_before_permute = hidden_states.shape

        permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = super().permute(
            hidden_states,
            self.routing_map,
        )
        global_input_tokens = self.all_to_all(self.ep_group, permutated_local_input_tokens, self.output_splits,
                                              self.input_splits)
        if self.num_local_experts > 1:
            global_input_tokens = self.sort_chunks_by_idxs(
                global_input_tokens,
                self.num_global_tokens_per_local_expert.ravel(),
                self.sort_input_by_local_experts,
            )
        return global_input_tokens, None, None, tokens_per_expert

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Send expert outputs back and restore the original token order and shape."""
        if self.num_local_experts > 1:
            hidden_states = self.sort_chunks_by_idxs(
                hidden_states,
                self.num_global_tokens_per_local_expert.mT.ravel(),
                self.restore_output_by_local_experts,
            )
        permutated_local_input_tokens = self.all_to_all(self.ep_group, hidden_states, self.input_splits,
                                                        self.output_splits)
        output = super().unpermute(
            permutated_local_input_tokens,
            self.reversed_local_input_permutation_mapping,
            restore_shape=self.hidden_shape_before_permute,
            probs=self.topk_weights,
            routing_map=self.routing_map,
        )
        return output.view(self.hidden_shape)


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.kernels.dlinfer.activation import silu_and_mul

from ..activation import SiluAndMulBuilder, SiluAndMulImpl


class DlinferSiluAndMulImpl(SiluAndMulImpl):
    """Fused SiLU-and-multiply activation backed by the dlinfer kernel."""

    def forward(self, x):
        """Apply the dlinfer ``silu_and_mul`` kernel to ``x`` and return its result."""
        activated = silu_and_mul(x)
        return activated


class DlinferSiluAndMulBuilder(SiluAndMulBuilder):
    """Builder for the dlinfer SiLU-and-multiply implementation."""

    @staticmethod
    def build(inplace: bool = False):
        """Return a new ``DlinferSiluAndMulImpl``; ``inplace`` is accepted but unused."""
        impl = DlinferSiluAndMulImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
    """Rotary-embedding application backed by the dlinfer kernel."""

    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """Apply ``cos``/``sin`` to ``query`` and ``key`` via ``apply_rotary_pos_emb``.

        When ``inplace`` is False, fresh output buffers shaped like ``query`` and
        ``key`` are allocated and handed to the kernel. Otherwise ``None`` is
        passed for both outputs (presumably the kernel then writes into the
        inputs; confirm against the kernel).
        """
        q_out = k_out = None
        if not inplace:
            q_out = query.new_empty(query.shape)
            k_out = key.new_empty(key.shape)
        return apply_rotary_pos_emb(query, key, cos, sin, q_out, k_out)


class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
    """Builder for the dlinfer rotary-embedding implementation."""

    @staticmethod
    def build():
        """Return a new ``DlinferApplyRotaryEmbImpl``."""
        impl = DlinferApplyRotaryEmbImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/ascend/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import AscendOpsBackend, SocVersion  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Dict, Tuple

import torch
import torch.distributed as dist

from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.utils import get_logger

from ..moe import DlinferMoECommType, DlinferMoeMetadata
from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class SocVersion:
    """Helpers to identify the Ascend NPU this process runs on."""
    Ascend310P: str = 'Ascend310P'
    Ascend910: str = 'Ascend910'

    @classmethod
    @lru_cache(maxsize=1)
    def device_name(cls) -> str:
        """Return ``torch.npu.get_device_name()`` (cached).

        If the query fails, a warning is logged and ``None`` is returned (and cached).
        """
        try:
            return torch.npu.get_device_name()
        except ImportError:
            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly.')
        except Exception as e:
            logger.warning(f'Error during Ascend get device name: {str(e)}. '
                           'Please check your Ascend environment configuration.')
        return None

    @classmethod
    def _device_name_or_empty(cls) -> str:
        """``device_name()`` with a failed lookup (``None``) mapped to ``''``."""
        return cls.device_name() or ''

    @classmethod
    def is_Ascend310P(cls) -> bool:
        """True if the device name starts with ``'Ascend310P'``; False if it is unknown."""
        # Previously a failed lookup (None) raised AttributeError on .startswith.
        return cls._device_name_or_empty().startswith(cls.Ascend310P)

    @classmethod
    def is_Ascend910(cls) -> bool:
        """True if the device name starts with ``'Ascend910'``; False if it is unknown."""
        return cls._device_name_or_empty().startswith(cls.Ascend910)

    @classmethod
    @lru_cache(maxsize=1)
    def soc_version(cls) -> int:
        """Return ``torch.npu.get_soc_version()`` (cached)."""
        return torch.npu.get_soc_version()

    @classmethod
    def is_A2(cls) -> bool:
        """True if the SoC version lies in [220, 225]."""
        return 220 <= cls.soc_version() <= 225

    @classmethod
    def is_A3(cls) -> bool:
        """True if the SoC version lies in [250, 255]."""
        return 250 <= cls.soc_version() <= 255


@dataclass
class DistMeta:
    """Snapshot of the distributed topology kept by the Ascend backend.

    NOTE(review): the field meanings below are inferred from their names;
    confirm against the code that constructs this object.
    """
    dp_size: int  # presumably data-parallel world size
    tp_size: int  # presumably tensor-parallel world size
    ep_size: int  # presumably expert-parallel world size
    tp_rank: int  # presumably this process's rank within tp_group
    ep_rank: int  # presumably this process's rank within ep_group
    tp_group: torch.distributed.ProcessGroup  # process group for tensor-parallel collectives
    ep_group: torch.distributed.ProcessGroup  # process group for expert-parallel collectives


class AscendKVQuantMeta:
    """Class-level holder for KV-cache quantization scales and zero points
    parsed from a calibration record file."""
    # Flipped to True once set_value() has populated quant_meta.
    has_set_value: bool = False
    # Shared (class-level) mapping: tensor name -> itertools.cycle over per-layer tensors.
    quant_meta: Dict = {}

    @classmethod
    def set_value(cls, device: str, dtype: torch.dtype, record_file: str, total_layers: int):
        """Parse ``scale: <s> offset: <o>`` records and cache per-layer tensors.

        The file must contain either one pair per layer (shared by K and V) or
        two pairs per layer (K first, then V). Every ``quant_meta`` entry is an
        ``itertools.cycle``, so successive ``next`` calls walk the layers in
        order and wrap around.

        Raises:
            ValueError: if the number of pairs is neither ``total_layers`` nor
                ``2 * total_layers``.
        """
        with open(record_file, 'r') as fin:
            text = fin.read()
        matches = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', text)
        pairs = [(float(s), float(o)) for s, o in matches]
        num_pairs = len(pairs)

        # Normalise both layouts to one ((k_scale, k_offset), (v_scale, v_offset)) entry per layer.
        if num_pairs == total_layers:
            per_layer = [(pair, pair) for pair in pairs]
        elif num_pairs == total_layers * 2:
            per_layer = list(zip(pairs[0::2], pairs[1::2]))
        else:
            raise ValueError(f'num of scale_offset_pairs({num_pairs}) '
                             f'must match num of total_layers({total_layers})')

        def as_tensor(*values):
            return torch.tensor(list(values), device=device, dtype=dtype)

        cls.quant_meta.update({
            'k_scales': itertools.cycle([as_tensor(k[0]) for k, _ in per_layer]),
            'k_zeros': itertools.cycle([as_tensor(k[1]) for k, _ in per_layer]),
            'v_scales': itertools.cycle([as_tensor(v[0]) for _, v in per_layer]),
            'v_zeros': itertools.cycle([as_tensor(v[1]) for _, v in per_layer]),
            'kv_scales': itertools.cycle([as_tensor(k[0], v[0]) for k, v in per_layer]),
            'kv_zeros': itertools.cycle([as_tensor(k[1], v[1]) for k, v in per_layer]),
        })
        cls.has_set_value = True


class AscendOpsBackend(DlinferOpsBackend):
    """Ascend layer backend."""
    # Set by ``build_graph_runner`` (graph mode flag / cache_config.max_batches).
    enable_graph: bool = False
    # Lazily built (block_num, block_size) grid of flat KV-cache slot ids.
    total_slots = None
    max_batches = None
    # Distributed layout, resolved once on the first ``update_step_context``.
    dist_meta: DistMeta = None

    @staticmethod
    def get_name() -> str:
        """Backend name."""
        return 'ascend'

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Key-cache block shape; only Ascend910 devices are supported."""
        if SocVersion.is_Ascend910():
            return (block_size, num_heads, head_size)
        else:
            raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Value-cache block shape; only Ascend910 devices are supported."""
        if SocVersion.is_Ascend910():
            return (block_size, num_heads, head_size)
        else:
            raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')

    @staticmethod
    @lru_cache(maxsize=1)
    def _get_moe_group_name(group):
        """Return the HCCL communicator name of ``group`` (None if no group).

        Defined at class level so the memo persists across steps; a cache
        decorator on a function nested inside ``update_step_context`` is
        recreated on every call and never hits.
        """
        if group is None:
            return None
        local_rank = torch.distributed.get_rank(group=group)
        backend = group._get_backend(torch.device('npu'))
        group_name = backend.get_hccl_comm_name(local_rank)
        return group_name

    @classmethod
    def update_step_context(cls, step_context):
        """Update step context.

        Normalises ``block_offsets``, ``kv_seqlens`` and ``q_seqlens`` to
        int32, then attaches ``attn_metadata`` (KV-cache slot indices,
        attention masks, sequence lengths, KV-quant parameters) and
        ``moe_metadata`` (MoE communication type, padding, process groups).

        For paged prefill (prefill where some sequence has cached history),
        ``step_context.block_offsets`` is expanded to one row per query token.

        Returns:
            The same ``step_context``, updated in place.

        Raises:
            ValueError: if KV quant policy 8 is requested in eager mode without
                a valid ``ASCEND_QUANT_RECORD_FILE``, or the SoC is unsupported
                for expert-parallel MoE.
        """

        block_num, block_size, *_ = step_context.kv_caches[0][0].shape
        is_unpaged_prefill = False
        if not step_context.is_decoding:
            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
        if step_context.block_offsets.dtype != torch.int32:
            step_context.block_offsets = step_context.block_offsets.to(torch.int32)
        # Per-sequence block table. For paged prefill the table stored on the
        # context is expanded to one row per query token just below, after
        # which its row ``i`` no longer belongs to sequence ``i``; per-sequence
        # slot lookup must use this unexpanded reference.
        seq_block_offsets = step_context.block_offsets
        if not (step_context.is_decoding or is_unpaged_prefill):
            step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
        if step_context.kv_seqlens.dtype != torch.int32:
            step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
        if step_context.q_seqlens.dtype != torch.int32:
            step_context.q_seqlens = step_context.q_seqlens.to(torch.int32)

        def get_total_slots():
            # (block_num, block_size) grid of flat slot ids, cached on the class.
            if cls.total_slots is None:
                cls.total_slots = torch.arange(block_num * block_size,
                                               dtype=torch.int32,
                                               device=step_context.block_offsets.device)
                cls.total_slots = cls.total_slots.view(block_num, block_size)
            return cls.total_slots

        def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
            """Get sequence lengths on CPU.

            Returns:
                q_seqlens_cpu: query sequence lengths (per sequence).
                kv_seqlens_cpu: kv sequence lengths (per sequence), used for
                    list/max seqlens calculation.
                kv_seqlens_expanded: kv sequence lengths expanded per token via
                    repeat_interleave, used for attention metadata.
            """
            if is_decoding:
                q_seqlens_cpu = None
                kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()
            elif is_unpaged_prefill:
                q_seqlens_cpu = step_context.q_seqlens.cpu()
                kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu
            else:
                q_seqlens_cpu = step_context.q_seqlens.cpu()
                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
                # Expand kv_seqlens to per-token for paged prefill attention
                kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)
            return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded

        def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
            if is_decoding:
                q_seqlens_list, kv_seqlens_list = None, None
            elif is_unpaged_prefill:
                q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
            else:
                q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
            return q_seqlens_list, kv_seqlens_list

        def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
            if is_decoding:
                max_q_seq_len, max_kv_seq_len = 1, None
            elif is_unpaged_prefill:
                max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
            else:
                max_q_seq_len = max(q_seqlens_list)
                max_kv_seq_len = max(kv_seqlens_list)
            return max_q_seq_len, max_kv_seq_len

        def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
                                                    max_q_seq_len, max_kv_seq_len):
            kv_start_indices, attention_mask = [], []
            if is_decoding:
                # Slot of position kv_seqlen - 1 (the token added this step).
                idx = (step_context.kv_seqlens - 1) % block_size
                last_block_idx = (step_context.kv_seqlens - 1) // block_size
                last_block = step_context.block_offsets.gather(1, last_block_idx.view(-1, 1)).view(-1)
                kv_start_indices = last_block * block_size + idx
            else:
                total_slots = get_total_slots()
                for i in range(step_context.q_start_loc.size(0)):
                    q_seq_len = q_seqlens_list[i]
                    kv_seq_len = kv_seqlens_list[i]

                    history_length = kv_seq_len - q_seq_len
                    # Index the per-sequence table, not the per-token expanded one.
                    slot_tables = total_slots[seq_block_offsets[i]].view(-1)
                    slots = slot_tables[history_length:kv_seq_len]
                    kv_start_indices.append(slots)

                    if not is_unpaged_prefill:
                        single_attention_mask = torch.triu(
                            torch.ones(q_seq_len,
                                       seq_block_offsets.shape[1] * block_size,
                                       dtype=torch.bool,
                                       device=seq_block_offsets.device),
                            diagonal=kv_seq_len - q_seq_len + 1,
                        )
                        attention_mask.append(single_attention_mask)

                if is_unpaged_prefill:
                    attention_mask.append(
                        torch.triu(torch.ones(max_q_seq_len,
                                              max_kv_seq_len,
                                              dtype=step_context.kv_caches[0][0].dtype,
                                              device=step_context.block_offsets.device),
                                   diagonal=max_kv_seq_len - max_q_seq_len + 1))
                else:
                    attention_mask = [torch.cat(attention_mask)]

                kv_start_indices = torch.cat(kv_start_indices)

            return kv_start_indices, attention_mask

        def get_dist_meta():
            if cls.dist_meta is not None:
                return cls.dist_meta
            dist_ctx = get_dist_manager().current_context()
            dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep
            tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank
            tp_group = dist_ctx.attn_tp_group.gpu_group
            ep_group = dist_ctx.ep_gpu_group
            cls.dist_meta = DistMeta(dp_size=dp_size,
                                     tp_size=tp_size,
                                     ep_size=ep_size,
                                     tp_rank=tp_rank,
                                     ep_rank=ep_rank,
                                     tp_group=tp_group,
                                     ep_group=ep_group)
            return cls.dist_meta

        def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
            if ep_size <= 1:
                return 0, 0, 0
            # get padded_tokens_current_rank
            is_graph = cls.enable_graph and step_context.is_decoding
            if is_graph:
                from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
                actual_tokens_current_rank = step_context.q_seqlens.shape[0]
                padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
                                                 cls.max_batches)
            else:
                actual_tokens_current_rank = step_context.q_seqlens.sum().item()
                padded_tokens_current_rank = actual_tokens_current_rank
            # get max_tokens_across_dp
            if dp_size > 1:
                runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank],
                                                     dtype=step_context.q_seqlens.dtype,
                                                     device=torch.npu.current_device())
                world_size = dp_size * tp_size
                runtime_tokens_buffer = torch.zeros([world_size],
                                                    dtype=step_context.q_seqlens.dtype,
                                                    device=torch.npu.current_device())
                dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group)
                max_tokens_across_dp = torch.max(runtime_tokens_buffer).item()
            else:
                max_tokens_across_dp = padded_tokens_current_rank
            return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp

        def init_mc2_token_capacity(tp_size):
            # min(max_batches, 512) rounded up to a multiple of tp_size.
            max_num_tokens = min(cls.max_batches, 512)
            num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
            return num_tokens_per_tp_rank * tp_size

        def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
            if ep_size <= 1:
                return DlinferMoECommType.ALLGATHER
            mc2_token_capacity = init_mc2_token_capacity(tp_size)
            is_graph = cls.enable_graph and step_context.is_decoding
            if is_graph:
                max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
            if SocVersion.is_A2():
                if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:
                    return DlinferMoECommType.MC2
                else:
                    return DlinferMoECommType.ALLGATHER
            elif SocVersion.is_A3():
                if max_tokens_across_dp <= mc2_token_capacity:
                    return DlinferMoECommType.MC2
                else:
                    return DlinferMoECommType.ALLTOALL
            else:
                raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')

        def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size,
                         moe_comm_type):
            x_active_mask = None
            if moe_comm_type == DlinferMoECommType.MC2:
                padded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size
                pad_size = padded_size - padded_tokens_current_rank
                x_active_mask = torch.ones(actual_tokens_current_rank,
                                           dtype=torch.bool,
                                           device=torch.npu.current_device())
            elif moe_comm_type == DlinferMoECommType.ALLTOALL:
                # NOTE(review): negative when this rank has more than tp_size
                # tokens; presumably the consumer ignores pad_size <= 0 - confirm.
                pad_size = tp_size - padded_tokens_current_rank
            elif moe_comm_type == DlinferMoECommType.ALLGATHER:
                pad_size = max_tokens_across_dp - padded_tokens_current_rank
            else:
                pad_size = 0
            return pad_size, x_active_mask

        q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
                                                                             is_unpaged_prefill)
        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
                                                           kv_seqlens_cpu)
        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
                                                        kv_seqlens_list)
        kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
                                                                                   is_unpaged_prefill, q_seqlens_list,
                                                                                   kv_seqlens_list, max_q_seq_len,
                                                                                   max_kv_seq_len)

        if not cls.enable_graph and step_context.kv_quant_policy == 8:
            record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
            # Explicit raise (not ``assert``) so validation survives ``python -O``.
            if not record_file:
                raise ValueError('please specify valid ASCEND_QUANT_RECORD_FILE')
            path = Path(record_file)
            is_path = path.is_absolute() or path.is_relative_to('/')
            exists = path.exists()
            if not (is_path and exists):
                raise ValueError('please specify valid ASCEND_QUANT_RECORD_FILE')
            if not AscendKVQuantMeta.has_set_value:
                total_layers = len(step_context.kv_caches)
                AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype,
                                            record_file, total_layers)

        attn_meta_cls = cls.get_attention_metadata_cls()
        attn_metadata = attn_meta_cls(
            step_context.is_decoding,
            step_context.block_offsets,
            q_start_loc=None,
            q_seqlens=q_seqlens_cpu,
            # kv_seqlens_expanded is only expanded in paged prefill,
            # otherwise it equals kv_seqlens_cpu
            kv_seqlens=kv_seqlens_expanded,
            kv_start_indices=kv_start_indices,
            block_size=block_size,
            attention_mask=attention_mask,
            is_unpaged_prefill=is_unpaged_prefill,
            max_q_seq_len=max_q_seq_len,
            max_kv_seq_len=max_kv_seq_len,
            quant_policy=step_context.kv_quant_policy,
            quant_meta=AscendKVQuantMeta.quant_meta,
        )
        step_context.attn_metadata = attn_metadata

        cls.dist_meta = get_dist_meta()
        actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
            cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group)
        moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
                                             cls.dist_meta.ep_size)
        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank,
                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type)
        moe_group_name = cls._get_moe_group_name(cls.dist_meta.ep_group)

        moe_metadata = DlinferMoeMetadata(
            max_tokens_across_dp=max_tokens_across_dp,
            pad_size=pad_size,
            dp_size=cls.dist_meta.dp_size,
            tp_size=cls.dist_meta.tp_size,
            ep_size=cls.dist_meta.ep_size,
            tp_rank=cls.dist_meta.tp_rank,
            ep_rank=cls.dist_meta.ep_rank,
            tp_group=cls.dist_meta.tp_group,
            ep_group=cls.dist_meta.ep_group,
            moe_comm_type=moe_comm_type,
            x_active_mask=x_active_mask,
            moe_group_name=moe_group_name,
        )
        step_context.moe_metadata = moe_metadata
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                           backend_config: BackendConfig, device: torch.device):
        """Build graph runner and record graph mode / max batch size on the class."""
        AscendOpsBackend.enable_graph = not backend_config.eager_mode
        AscendOpsBackend.max_batches = cache_config.max_batches
        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import AscendGraphRunner
        return AscendGraphRunner(model, model_config, cache_config, backend_config, device)

    @staticmethod
    def init():
        """Initialize Ascend backend (best effort: failures are only logged)."""
        try:
            from torch_npu.contrib import transfer_to_npu  # noqa: F401
        except ImportError:
            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. '
                           'Ascend initialization skipped.')
        except Exception as e:
            logger.warning(f'Error during Ascend initialization: {str(e)}. '
                           'Please check your Ascend environment configuration.')

    @staticmethod
    def ccl_backend():
        """Collective communication backend name."""
        return 'hccl'

    @staticmethod
    def device_count():
        """Get num available devices."""
        return torch.npu.device_count()

    @staticmethod
    def support_ray():
        """Support ray."""
        if not _envs.ascend_set_rt_visable_devices_by_ray:
            os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'
        return True


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/ascend/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch_npu

ACL_FORMAT_FRACTAL_NZ = 29


def nd_to_nz_spec(tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 2D ACL_FORMAT_ND tensor to ACL_FORMAT_FRACTAL_NZ layout.

    Copied from vllm-ascend (commit 420e794c35fe887db2be81cf9db0461f5b71da0b).
    It is meant for Ascend 310P devices and behaves similarly to the
    TransdataOperation.

    Both dimensions are zero-padded up to a multiple of 16. The columns are then
    split into 16-wide tiles, and the tile axis is moved in front of the row
    axis before the NPU format cast.

    Args:
        tensor: 2D input tensor in ND layout.

    Returns:
        Tensor of shape (1, padded_cols // 16, padded_rows, 16) cast to
        ACL_FORMAT_FRACTAL_NZ.
    """
    rows, cols = tensor.shape[0], tensor.shape[1]
    padded_rows = (rows + 15) // 16 * 16
    padded_cols = (cols + 15) // 16 * 16

    padded = torch.zeros((1, padded_rows, padded_cols), dtype=tensor.dtype, device=tensor.device)
    padded[0][:rows, :cols] = tensor

    # Split columns into 16-wide tiles, then place the tile axis ahead of rows.
    tiles = padded.reshape((1, padded_rows, padded_cols // 16, 16))
    nz_view = tiles.permute(0, 2, 1, 3)
    return torch_npu.npu_format_cast(nz_view.contiguous(), ACL_FORMAT_FRACTAL_NZ)


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from dataclasses import dataclass
from typing import Dict, Optional, Sequence

from torch import Tensor

from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata


@dataclass
class DlinferAttentionMetadata(AttentionMetadata):
    """Attention metadata consumed by the dlinfer attention implementation.

    Extends ``AttentionMetadata`` with the extra fields that the dlinfer
    backends fill in their ``update_step_context``.
    """
    # Flat KV-cache slot index of every token whose key/value is written this step.
    kv_start_indices: Optional[Tensor] = None
    block_size: int = 64
    # Masks forwarded to the kernel as-is; some backends pass None.
    attention_mask: Optional[Sequence[Tensor]] = tuple()
    is_unpaged_prefill: Optional[bool] = None
    max_q_seq_len: int = 1
    max_kv_seq_len: int = 1
    # Mapping from parameter name (e.g. 'k_scales', 'kv_zeros') to an iterator
    # advanced once per attention call; None when no KV quantization is set up.
    quant_meta: Optional[Dict] = None
    # Cumulative kv sequence lengths with a leading 0 (set by some backends for
    # paged prefill).
    cu_seq_lens_kv: Optional[Tensor] = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
    """Attention implementation that dispatches to the dlinfer kernels.

    ``forward`` writes key/value into the paged KV cache via ``fill_kv_cache``
    and then calls ``paged_attention_fwd``. Only causal attention is accepted.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = None,
        sliding_window: int = None,
        logit_softcapping: float = None,
        causal: bool = True,
        **kwargs,
    ):
        assert causal
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         v_head_size,
                         alibi,
                         sliding_window,
                         logit_softcapping,
                         causal=causal,
                         **kwargs)

        # Kernel import is deferred until an implementation is constructed.
        from lmdeploy.pytorch.kernels.dlinfer import fill_kv_cache, paged_attention_fwd

        self.fill_kv_cache = fill_kv_cache
        self.paged_attention_fwd = paged_attention_fwd

    @staticmethod
    def _next_quant_params(quant_meta):
        """Advance the KV-quant iterators in ``quant_meta`` by one step.

        Returns ``(k_scales_zeros, v_scales_zeros, kv_scales, kv_zeros)``.
        Missing entries (or a None ``quant_meta``) give ``[]`` for the k/v
        scale-zero pairs and ``None`` for the combined kv scales/zeros. The
        iterators are consumed in a fixed order: k_scales, k_zeros, v_scales,
        v_zeros, kv_scales, kv_zeros.
        """
        if quant_meta is None:
            return [], [], None, None

        if 'k_scales' in quant_meta:
            k_pair = [next(quant_meta['k_scales']), next(quant_meta['k_zeros'])]
        else:
            k_pair = []
        if 'v_scales' in quant_meta:
            v_pair = [next(quant_meta['v_scales']), next(quant_meta['v_zeros'])]
        else:
            v_pair = []
        kv_scales = next(quant_meta['kv_scales']) if 'kv_scales' in quant_meta else None
        kv_zeros = next(quant_meta['kv_zeros']) if 'kv_zeros' in quant_meta else None
        return k_pair, v_pair, kv_scales, kv_zeros

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        k_cache: Tensor,
        v_cache: Tensor,
        attn_metadata: DlinferAttentionMetadata,
        k_scales_zeros: Tensor = None,
        v_scales_zeros: Tensor = None,
        learnable_sink: Tensor = None,
        nsa_indices: Tensor = None,
        inplace: bool = True,
    ) -> Tensor:
        """Fill the KV cache with ``key``/``value``, then run paged attention.

        NOTE(review): the incoming ``k_scales_zeros``/``v_scales_zeros`` are
        always replaced by values drawn from ``attn_metadata.quant_meta`` (or
        empty lists); ``learnable_sink`` and ``nsa_indices`` are unused.

        Returns:
            The result of ``paged_attention_fwd``. The output buffer handed to
            it is a slice of ``query`` when ``inplace`` is True, otherwise a
            freshly allocated tensor.
        """
        meta = attn_metadata
        quant_bits = meta.quant_policy
        k_scales_zeros, v_scales_zeros, kv_scales, kv_zeros = self._next_quant_params(meta.quant_meta)

        k_cache, v_cache = self.fill_kv_cache(key,
                                              value,
                                              k_cache,
                                              v_cache,
                                              meta.kv_start_indices,
                                              k_scales_zeros=k_scales_zeros,
                                              v_scales_zeros=v_scales_zeros,
                                              quant_bits=quant_bits)

        if not inplace:
            attn_output = query.new_empty((*query.shape[:-1], self.v_head_size))
        else:
            attn_output = query[..., :self.v_head_size]

        return self.paged_attention_fwd(
            query,
            key,
            value,
            attn_output,
            k_cache,
            v_cache,
            meta.block_offsets,
            q_start_loc=meta.q_start_loc,
            q_seqlens=meta.q_seqlens,
            kv_seqlens=meta.kv_seqlens,
            cu_seq_lens_kv=meta.cu_seq_lens_kv,
            max_q_seq_len=meta.max_q_seq_len,
            max_kv_seq_len=meta.max_kv_seq_len,
            is_decoding=meta.is_decoding,
            block_size=meta.block_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            v_head_size=self.v_head_size,
            attn_mask=meta.attention_mask,
            softmax_scale=self.scale,
            is_unpaged_prefill=meta.is_unpaged_prefill,
            kv_scales=kv_scales,
            kv_zeros=kv_zeros,
            quant_bits=quant_bits,
        )


class DlinferAttentionBuilder(AttentionBuilder[DlinferAttentionMetadata]):
    """Factory that produces ``DlinferAttentionImpl`` instances."""

    @staticmethod
    def build(
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi_scale: float = None,
        sliding_window: int = None,
        logit_softcapping: float = None,
        causal: bool = True,
        learnable_sink: bool = False,
        **kwargs,
    ) -> DlinferAttentionImpl:
        """Create a dlinfer attention implementation.

        ``learnable_sink`` is accepted for interface compatibility but is not
        forwarded; every other option (including extra ``kwargs``) is passed
        through to ``DlinferAttentionImpl``.
        """
        impl_options = dict(
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_size=v_head_size,
            alibi_scale=alibi_scale,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            causal=causal,
        )
        impl_options.update(kwargs)
        return DlinferAttentionImpl(num_heads, head_size, **impl_options)


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/awq_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

from lmdeploy.pytorch.kernels.dlinfer import awq_linear

from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl


class AwqLinearW4A16Impl(LinearW4A16Impl):
    """AWQ linear layer (4-bit weights) running on the dlinfer ``awq_linear`` kernel."""

    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):
        self.in_features, self.out_features = in_features, out_features
        self.w_bit, self.group_size = w_bit, group_size

    def forward(self,
                x,
                qweight: torch.Tensor,
                scales: torch.Tensor,
                qzeros: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Apply the quantized projection through ``awq_linear``.

        ``all_reduce`` is handed to the kernel; ``group`` is accepted for
        interface compatibility but not forwarded.
        """
        return awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size)


class AwqLinearW4A16Builder(LinearW4A16Builder):
    """Builder returning ``AwqLinearW4A16Impl`` layers."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              w_bit: int,
              group_size: int,
              bias: bool = False,
              dtype: torch.dtype = None):
        """Create an AWQ W4A16 linear impl; ``bias`` and ``dtype`` are not used here."""
        return AwqLinearW4A16Impl(in_features=in_features,
                                  out_features=out_features,
                                  w_bit=w_bit,
                                  group_size=group_size)


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import CambOpsBackend  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class CambOpsBackend(DlinferOpsBackend):
    """Camb backend built on dlinfer."""
    # Lazily built (num_blocks, block_size) grid of flat KV-cache slot ids.
    total_slots = None

    @staticmethod
    def get_name() -> str:
        """Return the backend name."""
        return 'camb'

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Key-cache block layout: (num_heads, block_size, head_size)."""
        return num_heads, block_size, head_size

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Value-cache block layout: (num_heads, block_size, head_size)."""
        return num_heads, block_size, head_size

    @classmethod
    def update_step_context(cls, step_context):
        """Attach dlinfer attention metadata to ``step_context``.

        Computes the KV-cache slot of every token written this step, the
        cumulative query offsets and, for paged prefill, the cumulative kv
        lengths, then stores them in ``step_context.attn_metadata``.

        Returns:
            The same ``step_context``.
        """
        num_blocks, _, block_size, _ = step_context.kv_caches[0][0].shape

        def slot_grid():
            # Build once and cache on the class.
            if cls.total_slots is None:
                flat_slots = torch.arange(num_blocks * block_size,
                                          dtype=torch.int32,
                                          device=step_context.block_offsets.device)
                cls.total_slots = flat_slots.view(num_blocks, block_size)
            return cls.total_slots

        q_start_loc = step_context.q_start_loc
        q_seqlens = step_context.q_seqlens
        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
        block_offsets = step_context.block_offsets.to(torch.int32)
        max_q_seq_len = torch.max(q_seqlens).cpu().item()
        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()

        cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
        cu_seq_lens_kv = None
        is_unpaged_prefill = False

        q_lens = step_context.q_seqlens.tolist()
        kv_lens = step_context.kv_seqlens.tolist()
        if step_context.is_decoding:
            # One new token per sequence: locate its slot without a Python loop.
            last_pos = step_context.kv_seqlens - 1
            last_block_idx = last_pos // block_size
            last_block = block_offsets.gather(  # dtype of gather must be int64
                1, last_block_idx.view(-1, 1)).view(-1)
            kv_start_indices = (last_block * block_size + last_pos % block_size).to(torch.int32)
        else:
            is_unpaged_prefill = q_lens == kv_lens
            per_seq_slots = []
            for seq_idx in range(q_start_loc.size(0)):
                q_len, kv_len = q_lens[seq_idx], kv_lens[seq_idx]
                seq_slots = slot_grid()[block_offsets[seq_idx]].view(-1)
                # New tokens occupy the positions after the cached history.
                per_seq_slots.append(seq_slots[kv_len - q_len:kv_len])
            kv_start_indices = torch.cat(per_seq_slots)
            if not is_unpaged_prefill:
                leading_zero = torch.tensor([0], device=kv_seqlens.device)
                cu_seq_lens_kv = torch.cat((leading_zero, kv_seqlens.cumsum(0))).int()

        attn_meta_cls = cls.get_attention_metadata_cls()
        step_context.attn_metadata = attn_meta_cls(
            step_context.is_decoding,
            block_offsets,
            q_start_loc=cu_seqlens,
            cu_seq_lens_kv=cu_seq_lens_kv,
            q_seqlens=q_seqlens,
            kv_seqlens=kv_seqlens,
            kv_start_indices=kv_start_indices,
            block_size=block_size,
            attention_mask=None,
            is_unpaged_prefill=is_unpaged_prefill,
            max_q_seq_len=max_q_seq_len,
            max_kv_seq_len=max_kv_seq_len,
        )
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                           backend_config: BackendConfig, device: torch.device):
        """Build a CUDA-style graph runner for this backend."""
        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)

    @staticmethod
    def support_ray():
        """Report that ray is supported."""
        return True


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/flash_attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class DlinferFlashAttentionImpl(FlashAttentionImpl):
    """Flash attention backed by dlinfer's ``flash_attention_fwd`` kernel."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logit_softcapping: float = None,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Unset options fall back to: 1/sqrt(head_dim) scale, MHA (kv heads == q heads),
        # and a value head dim equal to the query/key head dim.
        self.scale = 1.0 / (head_dim**0.5) if scale is None else scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.v_head_dim = head_dim if v_head_dim is None else v_head_dim
        self.causal = causal
        self.sliding_window = sliding_window
        self.logit_softcapping = logit_softcapping
        # Imported lazily so that the kernel module is only loaded when this impl is built.
        from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
        self.flash_attention_fwd = flash_attention_fwd

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """Run the dlinfer flash attention kernel.

        The output has the shape of ``query`` with its last dimension replaced by
        ``v_head_dim``; the kernel writes into it in place and it is returned.
        """
        out = query.new_empty((*query.shape[:-1], self.v_head_dim))
        self.flash_attention_fwd(
            query,
            key,
            value,
            out,
            q_start_loc=q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            max_q_seqlen=max_q_seqlen,
            window_size=self.sliding_window,
            sm_scale=self.scale,
            logit_softcapping=self.logit_softcapping,
            causal=self.causal,
        )
        return out


class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
    """Builder that produces :class:`DlinferFlashAttentionImpl` instances."""

    @staticmethod
    def build(
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logit_softcapping: float = None,
        **kwargs,
    ) -> FlashAttentionImpl:
        """Create the dlinfer flash attention impl.

        Extra keyword arguments are accepted for interface compatibility and ignored.
        """
        impl = DlinferFlashAttentionImpl(num_heads,
                                         head_dim,
                                         scale=scale,
                                         num_kv_heads=num_kv_heads,
                                         v_head_dim=v_head_dim,
                                         causal=causal,
                                         sliding_window=sliding_window,
                                         logit_softcapping=logit_softcapping)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/linear.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List, Optional

import torch
import torch.distributed as dist

from lmdeploy.pytorch.kernels.dlinfer import linear

from ..linear import LinearBuilder, LinearImpl


class DlinferLinearImpl(LinearImpl):
    """Linear layer implemented with dlinfer's ``linear`` kernel."""

    def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Prepare weights for the kernel.

        When the env var ``DLINFER_LINEAR_USE_NN_LAYOUT`` is ``'1'``, the weight is
        transposed and made contiguous; otherwise it is returned unchanged.
        """
        use_nn_layout = os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1'
        if not use_nn_layout:
            return weight, bias
        return weight.data.t().contiguous(), bias

    def forward(self,
                x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: dist.ProcessGroup = None,
                rank: int = 0,
                scatter_size: List[int] = None):
        """Apply the linear kernel, optionally all-reducing the result over ``group``.

        ``rank`` and ``scatter_size`` are accepted for interface compatibility and unused here.
        """
        result = linear(x, weight, bias, False)
        if not all_reduce:
            return result
        dist.all_reduce(result, group=group)
        return result


class DlinferLinearBuilder(LinearBuilder):
    """Builder that produces :class:`DlinferLinearImpl` instances."""

    @staticmethod
    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):
        """Create the impl; the layer geometry arguments are not needed by it."""
        impl = DlinferLinearImpl()
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/maca/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import MacaOpsBackend  # noqa: F401


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class MacaOpsBackend(DlinferOpsBackend):
    """Maca layer backend."""
    # Lazily-built (block_num, block_size) table of flat kv-cache slot ids, cached on the class.
    # NOTE(review): it is built once and never invalidated, so it assumes a single kv-cache
    # geometry and device per process — confirm this holds for all engines.
    total_slots = None

    @staticmethod
    def get_name() -> str:
        """Backend name."""
        return 'maca'

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Shape of one key-cache block."""
        return (block_size, num_heads, head_size)

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Shape of one value-cache block."""
        return (block_size, num_heads, head_size)

    @classmethod
    def update_step_context(cls, step_context):
        """Update step context.

        Builds the dlinfer attention metadata for the current step and stores it in
        ``step_context.attn_metadata``.

        * Decoding: every sequence writes exactly one token, so its flat kv-cache slot is
          computed from the block holding position ``kv_seqlen - 1`` with tensor ops only.
        * Prefill: for each sequence, the slots of the tokens after its cached history
          (``kv_seqlen - q_seqlen`` .. ``kv_seqlen``) are gathered from the slot table
          and concatenated.
        """

        def get_total_slots():
            # slot id of (block b, offset o) is b * block_size + o
            if cls.total_slots is None:
                cls.total_slots = torch.arange(block_num * block_size,
                                               dtype=torch.long,
                                               device=step_context.block_offsets.device)
                cls.total_slots = cls.total_slots.view(block_num, block_size)
            return cls.total_slots

        kv_start_indices, attention_mask = [], []
        # The layer-0 value cache is unpacked as (block_num, block_size, _, _).
        block_num, block_size, _, _ = step_context.kv_caches[0][1].shape

        is_unpaged_prefill = False
        if not step_context.is_decoding:
            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
        q_start_loc = step_context.q_start_loc
        cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()

        q_seqlens = step_context.q_seqlens.int()
        kv_seqlens = step_context.kv_seqlens.int()

        if step_context.is_decoding:
            # max_q_seq_len, max_kv_seq_len is not used in decoding stage
            max_q_seq_len = -1
            max_kv_seq_len = -1

            # collect kv_start_indices without using a for-loop,
            # (fill kv-cache for just ONE token during the decoding phase)
            idx = (step_context.kv_seqlens - 1) % block_size
            b_num = (step_context.kv_seqlens - 1) // block_size
            last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
            kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
        else:
            # Copy the per-sequence lengths to the host once. Indexing a device tensor element
            # by element (and calling .item() for the maxima) forces a sync per access.
            q_seqlens_host = q_seqlens.tolist()
            kv_seqlens_host = kv_seqlens.tolist()
            max_q_seq_len = max(q_seqlens_host)
            max_kv_seq_len = max(kv_seqlens_host)

            total_slots = get_total_slots()
            for i in range(step_context.q_start_loc.size(0)):
                q_seq_len = q_seqlens_host[i]
                kv_seq_len = kv_seqlens_host[i]
                # collect kv start indices during the prefill phase: only the tokens
                # after the cached history need new slots.
                history_length = kv_seq_len - q_seq_len
                slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
                kv_start_indices.append(slot_tables[history_length:kv_seq_len])
            kv_start_indices = torch.cat(kv_start_indices)

        attn_meta_cls = cls.get_attention_metadata_cls()
        attn_metadata = attn_meta_cls(
            step_context.is_decoding,
            step_context.block_offsets.int(),
            q_start_loc=cu_seqlens,
            q_seqlens=q_seqlens,
            kv_seqlens=kv_seqlens,
            kv_start_indices=kv_start_indices,
            block_size=block_size,
            attention_mask=attention_mask,
            is_unpaged_prefill=is_unpaged_prefill,
            max_q_seq_len=max_q_seq_len,
            max_kv_seq_len=max_kv_seq_len,
        )

        step_context.attn_metadata = attn_metadata
        return step_context

    @staticmethod
    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                           backend_config: BackendConfig, device: torch.device):
        """Build graph runner."""
        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)

    @staticmethod
    def support_ray():
        """Support ray."""
        return True


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List

import torch

from lmdeploy.pytorch.kernels.dlinfer import DlinferMoECommType  # noqa: F401
from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetadata  # noqa: F401
from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager

from ..moe import FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder, SoftmaxTopKImpl


class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl):
    """Softmax top-k router using dlinfer's ``moe_gating_topk_softmax``."""

    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):
        self.top_k = top_k
        self.dim = dim
        self.n_groups = n_groups

    def forward(self, x: torch.Tensor):
        """Return ``(routing_weights, selected_experts)`` for router logits ``x``.

        If the current step context carries ``moe_metadata``, its ``router_n_groups`` is
        set to this router's ``n_groups`` before the kernel is called.
        """
        ctx = get_step_ctx_manager().current_context()
        moe_metadata = getattr(ctx, 'moe_metadata', None)
        if moe_metadata is not None:
            moe_metadata.router_n_groups = self.n_groups
        weights, experts = moe_gating_topk_softmax(x, self.top_k, moe_metadata)
        return weights, experts


class DlinferSoftmaxTopKBuilder(SoftmaxTopKBuilder):
    """Builder that produces :class:`DlinferSoftmaxTopKImpl` instances."""

    @staticmethod
    def build(top_k: int, dim: int = -1, n_groups: int = -1):
        """Create the softmax top-k impl."""
        impl = DlinferSoftmaxTopKImpl(top_k, dim=dim, n_groups=n_groups)
        return impl


class DlinferFusedMoEImpl(FusedMoEImpl):
    """Fused MoE implemented with dlinfer's ``fused_moe`` kernel."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 ep_size: int = 1,
                 ep_group: torch.distributed.ProcessGroup = None):
        self.top_k = top_k
        self.num_experts = num_experts
        self.renormalize = renormalize
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.expert_ids_per_ep_rank = None
        if self.ep_size > 1:
            # Map each global expert id to its local index: id modulo the experts per EP rank.
            experts_per_rank = self.num_experts // self.ep_size
            local_ids = [expert_id % experts_per_rank for expert_id in range(num_experts)]
            self.expert_ids_per_ep_rank = torch.tensor(
                local_ids,
                dtype=torch.int32,
                device=torch.cuda.current_device(),
            )

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
        """Prepare expert weights for the kernel.

        On 'npu' devices the last two dims of both weights are transposed (and made
        contiguous) unless the env var ``DLINFER_RESET_MOE_UPDATE_WEIGHTS`` is ``'1'``;
        on other devices the weights are returned unchanged.
        """
        if gate_up_weights.device.type != 'npu':
            return gate_up_weights, down_weights
        if os.getenv('DLINFER_RESET_MOE_UPDATE_WEIGHTS', '0') == '1':
            return gate_up_weights, down_weights
        gate_up_t = gate_up_weights.transpose(-1, -2).contiguous()
        down_t = down_weights.transpose(-1, -2).contiguous()
        return gate_up_t, down_t

    def ep_expert_list(self, world_size: int, rank: int):
        """Return the contiguous range of expert ids owned by ``rank`` (ceil-split)."""
        total = self.num_experts
        per_rank = (total + world_size - 1) // world_size
        start = rank * per_rank
        stop = min(start + per_rank, total)
        return list(range(start, stop))

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                down_weights: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """Run the fused MoE kernel. Biases are not supported and must be None.

        ``expert_list`` and ``act_func`` are accepted for interface compatibility and unused.
        """
        assert gate_up_bias is None
        assert down_bias is None

        ctx = get_step_ctx_manager().current_context()
        moe_metadata = getattr(ctx, 'moe_metadata', None)
        if moe_metadata is not None:
            moe_metadata.expert_ids_per_ep_rank = self.expert_ids_per_ep_rank
        return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, self.top_k,
                         self.renormalize, moe_metadata)


class DlinferFusedMoEBuilder(FusedMoEBuilder):
    """Builder that produces :class:`DlinferFusedMoEImpl` instances."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              renormalize: bool = False,
              hidden_dim: int = 1,
              ep_size: int = 1,
              ep_group: torch.distributed.ProcessGroup = None,
              layer_idx: int = 0,
              out_dtype: torch.dtype = torch.bfloat16):
        """Create the fused MoE impl.

        ``hidden_dim``, ``layer_idx`` and ``out_dtype`` are accepted for interface
        compatibility and not forwarded.
        """
        impl = DlinferFusedMoEImpl(top_k, num_experts, renormalize, ep_size, ep_group)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.pytorch.kernels.dlinfer import rms_norm

from ..norm import RMSNormBuilder, RMSNormImpl


class DlinferRMSNormImpl(RMSNormImpl):
    """RMS norm implemented with dlinfer's ``rms_norm`` kernel."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """Normalize ``x``.

        Without ``residual`` only the normalized tensor is returned; with it, the kernel's
        two outputs are returned as ``(x, residual)``.
        """
        if residual is not None:
            normed, new_residual = rms_norm(x, weight, self.eps, residual=residual)
            return normed, new_residual
        return rms_norm(x, weight, self.eps)


class DlinferRMSNormBuilder(RMSNormBuilder):
    """Dlinfer RMS norm implementation builder."""

    @staticmethod
    def build(hidden_size: int, eps: float = 1e-6):
        """Build a :class:`DlinferRMSNormImpl`.

        The first argument is forwarded as ``DlinferRMSNormImpl.hidden_size``. It was previously
        misnamed/annotated as ``weight: torch.Tensor``; it is renamed to match the impl and the
        sibling ``DlinferRMSNormW8A8Builder.build(hidden_size: int, ...)``.

        Args:
            hidden_size (int): stored on the impl as ``hidden_size``.
            eps (float): epsilon passed to the ``rms_norm`` kernel.
        """
        return DlinferRMSNormImpl(hidden_size, eps)


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/op_backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.utils import get_logger

from ..base import OpType
from ..default import DefaultOpsBackend

logger = get_logger('lmdeploy')


class DlinferOpsBackend(DefaultOpsBackend):
    """Common op backend for dlinfer devices.

    Device-specific backends subclass it and must provide ``update_step_context``.
    """

    @staticmethod
    def get_name() -> str:
        """Backend name."""
        return 'dlinfer'

    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """Return the dlinfer builder for ``layer_type``.

        Builders are imported lazily on first request; op types without a dlinfer
        implementation fall back to the default backend.
        """
        if layer_type == OpType.PagedAttention:
            from .attention import DlinferAttentionBuilder
            return DlinferAttentionBuilder
        if layer_type == OpType.FlashAttention:
            from .flash_attention import DlinferFlashAttentionBuilder
            return DlinferFlashAttentionBuilder
        if layer_type == OpType.ApplyRotaryEmb:
            from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
            return DlinferApplyRotaryEmbBuilder
        if layer_type == OpType.SiluAndMul:
            from .activation import DlinferSiluAndMulBuilder
            return DlinferSiluAndMulBuilder
        if layer_type == OpType.RMSNorm:
            from .norm import DlinferRMSNormBuilder
            return DlinferRMSNormBuilder
        if layer_type == OpType.LinearW8A8:
            from .qmodules import DlinferLinearW8A8Builder
            return DlinferLinearW8A8Builder
        if layer_type == OpType.RMSNormW8A8:
            from .qmodules import DlinferRMSNormW8A8Builder
            return DlinferRMSNormW8A8Builder
        if layer_type == OpType.SoftmaxTopK:
            from .moe import DlinferSoftmaxTopKBuilder
            return DlinferSoftmaxTopKBuilder
        if layer_type == OpType.FusedMoE:
            from .moe import DlinferFusedMoEBuilder
            return DlinferFusedMoEBuilder
        if layer_type == OpType.Linear:
            from .linear import DlinferLinearBuilder
            return DlinferLinearBuilder
        if layer_type == OpType.LinearW4A16:
            from .awq_modules import AwqLinearW4A16Builder
            return AwqLinearW4A16Builder
        if layer_type == OpType.RotaryEmbedding:
            from .rotary_embedding import DlinferRotaryEmbeddingBuilder
            return DlinferRotaryEmbeddingBuilder

        logger.debug(f'Op {layer_type} fallback to default implementation.')
        return super().get_layer_impl_builder(layer_type)

    @staticmethod
    def get_attention_metadata_cls():
        """Return the attention metadata class used by dlinfer backends."""
        from .attention import DlinferAttentionMetadata
        return DlinferAttentionMetadata

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Shape of one key-cache block: (block_size, num_heads, head_size)."""
        return (block_size, num_heads, head_size)

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        """Shape of one value-cache block: (block_size, num_heads, head_size)."""
        return (block_size, num_heads, head_size)

    @classmethod
    def update_step_context(cls, step_context):
        """Update step context; must be implemented by device-specific subclasses."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/qmodules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch
import torch.distributed as dist

from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import dynamic_quant, linear_w8a8, rms_norm_w8a8
from lmdeploy.pytorch.models.q_modules import QTensor

from ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl


class DlinferLinearW8A8Impl(LinearW8A8Impl):
    """W8A8 linear layer implemented with dlinfer kernels."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 out_dtype: torch.dtype = torch.float16,
                 quant_dtype: torch.dtype = torch.int8):
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.quant_dtype = quant_dtype

    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Prepare weights for the kernel.

        When the env var ``DLINFER_LINEAR_USE_NN_LAYOUT`` is ``'1'``, both ``weight`` and
        ``scale`` are transposed and made contiguous; otherwise all inputs are returned unchanged.
        """
        if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') != '1':
            return weight, scale, bias
        return weight.data.t().contiguous(), scale.data.t().contiguous(), bias

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False,
                group: Optional[torch.distributed.ProcessGroup] = None):
        """Run the W8A8 linear kernel.

        A plain tensor input is quantized on the fly with ``dynamic_quant``; otherwise ``x``
        must be a ``QTensor`` whose ``tensor``/``scale`` are used directly. The result is
        optionally all-reduced over ``group``.
        """
        if not isinstance(x, torch.Tensor):
            assert isinstance(x, QTensor)
            x_quant, x_scale = x.tensor, x.scale
        else:
            x_quant, x_scale = dynamic_quant(x, self.quant_dtype)

        result = linear_w8a8(x_quant, weight, x_scale, scale, self.out_dtype, self.quant_dtype, bias)
        if all_reduce:
            dist.all_reduce(result, group=group)
        return result


class DlinferLinearW8A8Builder(LinearW8A8Builder):
    """Builder that produces :class:`DlinferLinearW8A8Impl` instances."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: torch.dtype = None,
              quant_dtype: torch.dtype = torch.int8):
        """Create the W8A8 linear impl; ``dtype`` becomes the impl's ``out_dtype``."""
        impl = DlinferLinearW8A8Impl(in_features, out_features, out_dtype=dtype, quant_dtype=quant_dtype)
        return impl


class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
    """RMS norm fused with quantization, implemented with dlinfer's ``rms_norm_w8a8``."""

    def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.quant_dtype = quant_dtype

    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """Normalize and quantize ``x``, wrapping the result in a ``QTensor``.

        With ``residual`` the updated residual is returned alongside: ``(QTensor, residual)``.
        """
        if residual is not None:
            quantized, rms_scale, new_residual = rms_norm_w8a8(x, weight, self.eps, self.quant_dtype, residual)
            return QTensor(quantized, rms_scale), new_residual
        quantized, rms_scale = rms_norm_w8a8(x, weight, self.eps, self.quant_dtype)
        return QTensor(quantized, rms_scale)


class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):
    """Builder that produces :class:`DlinferRMSNormW8A8Impl` instances."""

    @staticmethod
    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
        """Create the quantizing RMS norm impl."""
        impl = DlinferRMSNormW8A8Impl(hidden_size, eps=eps, quant_dtype=quant_dtype)
        return impl


================================================
FILE: lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math

import torch
from torch import nn

from ..default.rotary_embedding import (FopeRotaryEmbeddingImpl, LlamaDynamicNTKScalingRotaryEmbedding,
                                        YarnRotaryEmbeddingImpl)
from ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
                                RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)


def _rotary_embedding_fwd(position_ids: torch.Tensor,
                          inv_freq: torch.Tensor,
                          scaling_factor: float,
                          mscale: float = None,
                          dtype: torch.dtype = None):
    """Rotary embedding forward."""
    if dtype is None:
        dtype = torch.float16

    if scaling_factor != 1.0:
        position_ids = position_ids.float() / scaling_factor
    else:
        position_ids = position_ids.float()

    position_ids = position_ids.unsqueeze(-1)
    angles = position_ids * inv_freq.view(1, 1, -1)
    angles = torch.cat((angles, angles), dim=-1)

    sin = angles.sin()
    cos = angles.cos()

    if mscale is not None:
        cos = cos * mscale
        sin = sin * mscale
    return cos.to(dtype=dtype), sin.to(dtype=dtype)


class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
    """Default / linear-scaling rotary embedding for dlinfer devices."""

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device='cuda') / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype``.

        ``inv_freq`` is moved to ``x.device`` first if it lives elsewhere.
        """
        target_device = x.device
        if self.inv_freq.device != target_device:
            self.inv_freq = self.inv_freq.to(target_device)
        return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=self.scaling_factor, dtype=x.dtype)


class DlinferLlamaDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0, max_position_embeddings: int = 2048):
        super().__init__(dim, base, scaling_factor, max_position_embeddings)
        # Precompute the seq_len-independent terms of the NTK base formula.
        self.dim_scale_ratio = self.dim / (self.dim - 2)
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float()
        self.pos_freq_scaling = exponents.cuda() / self.dim
        self.scale_offset = self.scaling_factor - 1
        self.pos_scale_factor = self.scaling_factor / self.max_position_embeddings

    def _ntk_inv_freq(self, seq_len: torch.Tensor):
        """Inverse frequencies with the base rescaled for ``seq_len``."""
        scaled_len = self.pos_scale_factor * seq_len - self.scale_offset
        ntk_base = self.base * scaled_len**self.dim_scale_ratio
        return 1.0 / ntk_base**self.pos_freq_scaling

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Return ``(cos, sin)``; NTK frequencies are used only past ``max_position_embeddings``."""
        seq_len = torch.max(position_ids) + 1
        ntk_inv_freq = self._ntk_inv_freq(seq_len)
        if self.inv_freq.device != x.device:
            self.inv_freq = self.inv_freq.to(x.device)
        use_ntk = seq_len > self.max_position_embeddings
        inv_freq = torch.where(use_ntk, ntk_inv_freq, self.inv_freq)
        return _rotary_embedding_fwd(position_ids, inv_freq, scaling_factor=1.0, dtype=x.dtype)


class DlinferLlama3RotaryEmbeddingImpl(DlinferRotaryEmbeddingImpl):
    """Llama3 rotary embedding implementation.

    Rescales the base inverse frequencies once at construction:

    * wavelengths longer than ``original_max_position_embeddings / low_freq_factor``
      are divided by ``scaling_factor``;
    * wavelengths shorter than ``original_max_position_embeddings / high_freq_factor``
      are kept as-is;
    * wavelengths in between are linearly interpolated between the two.

    ``scaling_factor`` is then reset to 1.0, so ``forward`` applies no further position scaling.
    """

    def __init__(
        self,
        dim: int,
        base: int = 10000,
        scaling_factor: float = 1.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        original_max_position_embeddings: int = 8192,
    ):
        super().__init__(dim, base, scaling_factor)
        old_context_len = original_max_position_embeddings
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        inv_freq = self.inv_freq
        factor = self.scaling_factor

        wavelen = 2 * math.pi / inv_freq
        # wavelen < high_freq_wavelen: do nothing
        # wavelen > low_freq_wavelen: divide by factor
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # otherwise: interpolate between the two, using a smooth factor
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        self.scaling_factor = 1.0
        # Keep the buffer non-persistent, as the base class registered it; re-registering with
        # the default persistent=True would silently add `inv_freq` to the state_dict.
        self.register_buffer('inv_freq', inv_freq_llama, persistent=False)


class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
    """YaRN rotary embedding; frequencies/mscale come from the default implementation."""

    def __init__(self,
                 dim: int,
                 base: int = 10000,
                 scaling_factor: float = 1.0,
                 original_max_position_embeddings: int = 4096,
                 yarn_params: YarnParameters = None):
        super().__init__(dim, base, scaling_factor, original_max_position_embeddings, yarn_params)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """Return ``(cos, sin)`` in ``x.dtype``, both multiplied by ``self.mscale``."""
        target_device = x.device
        if self.inv_freq.device != target_device:
            self.inv_freq = self.inv_freq.to(target_device)
        return _rotary_embedding_fwd(position_ids,
                                     self.inv_freq,
                                     scaling_factor=1.0,
                                     mscale=self.mscale,
                                     dtype=x.dtype)


class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
    """Build dlinfer rotary embedding implementations by rope type."""

    @staticmethod
    def build(
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
        scaling_factor: float = 1.0,
        yarn_params: YarnParameters = None,
        longrope_params: LongRoPEScalingParameters = None,
        llama3_params: Llama3Parameters = None,
        fope_params: FopeParameters = None,
        emb_type: RopeType = RopeType.Default,
    ):
        """Create the rotary embedding impl for ``emb_type``.

        ``longrope_params`` is accepted for interface compatibility but not used here.

        Raises:
            NotImplementedError: if ``emb_type`` has no dlinfer implementation.
        """
        if emb_type == RopeType.Default or emb_type == RopeType.LinearScaling:
            return DlinferRotaryEmbeddingImpl(dim, base, scaling_factor)
        if emb_type == RopeType.DynamicNTKScaling:
            return DlinferLlamaDynamicNTKScalingRotaryEmbedding(dim, base, scaling_factor, max_position_embeddings)
        if emb_type == RopeType.Llama3:
            # NOTE(review): `max_position_embeddings` is forwarded positionally as the Llama3
            # `original_max_position_embeddings`; confirm callers pass the pre-scaling context
            # length here rather than the extended one.
            return DlinferLlama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,
                                                    llama3_params.high_freq_factor, max_position_embeddings)
        if emb_type == RopeType.Yarn:
            return DlinferYarnRotaryEmbeddingImpl(dim,
                                                  base,
                                                  scaling_factor,
                                                  max_position_embeddings,
                                                  yarn_params=yarn_params)
        if emb_type == RopeType.Fope:
            return FopeRotaryEmbeddingImpl(
                dim,
                max_position_embeddings=max_position_embeddings,
                scaling_factor=scaling_factor,
                params=fope_params,
            )
        raise NotImplementedError(f'Unsupported embedding type: {emb_type}')


================================================
FILE: lmdeploy/pytorch/backends/embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist


class EmbeddingImpl(ABC):
    """Abstract interface that backend embedding implementations must provide."""

    @abstractmethod
    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Look up ``x`` in ``weight``; subclasses may all-reduce the result over ``group``."""
        raise NotImplementedError


class EmbeddingBuilder(ABC):
    """Abstract factory for backend embedding implementations."""

    @staticmethod
    @abstractmethod
    def build(start_index: int, end_index: int):
        """Create an embedding impl for the index range ``[start_index, end_index)`` (assumed half-open)."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/flash_attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

from torch import Tensor


class FlashAttentionImpl(ABC):
    """FlashAttention implementation.

    Backends subclass this and must override :meth:`forward`. It is marked abstract,
    consistently with the other backend API bases, so an incomplete implementation fails at
    construction time instead of raising only when first called.
    """

    @abstractmethod
    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query, key, value: attention inputs (layout is backend-defined — presumably
                packed variable-length sequences; confirm per backend).
            q_start_loc, q_seqlens: per-sequence query start offsets and lengths.
            kv_start_loc, kv_seqlens: per-sequence key/value start offsets and lengths.
            max_q_seqlen: optional maximum query length in the batch.
        """
        raise NotImplementedError


class FlashAttentionBuilder(ABC):
    """Factory interface for backend flash-attention implementations."""

    @staticmethod
    @abstractmethod
    def build(num_heads: int,
              head_dim: int,
              scale: float = None,
              num_kv_heads: int = None,
              v_head_dim: int = None,
              causal: bool = True,
              sliding_window: int = None,
              logit_softcapping: float = None,
              **kwargs) -> FlashAttentionImpl:
        """Create a ``FlashAttentionImpl`` for the given head layout and options.

        Extra backend-specific options arrive through ``**kwargs``.
        """
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/gated_delta_rule.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch


class GatedDeltaRuleImpl(ABC):
    """Gated Delta Rule implementation api."""

    @abstractmethod
    def chunk_gated_delta_rule(self,
                               q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor | None = None,
                               beta: torch.Tensor | None = None,
                               initial_state: torch.Tensor | None = None,
                               state_indices: torch.Tensor | None = None,
                               scale: float | None = None,
                               use_qk_l2norm_in_kernel: bool = False,
                               cu_seqlens: torch.Tensor | None = None,
                               output_final_state: bool = False):
        """forward."""
        raise NotImplementedError

    @abstractmethod
    def fused_recurrent_gated_delta_rule(self,
                                         q: torch.Tensor,
                                         k: torch.Tensor,
                                         v: torch.Tensor,
                                         g: torch.Tensor | None = None,
                                         beta: torch.Tensor | None = None,
                                         initial_state: torch.Tensor | None = None,
                                         state_indices: torch.Tensor | None = None,
                                         scale: float | None = None,
                                         use_qk_l2norm_in_kernel: bool = False,
                                         output_final_state: bool = False):
        """forward."""
        raise NotImplementedError


class GatedDeltaRuleBuilder(ABC):
    """Factory interface for gated delta rule implementations."""

    @staticmethod
    @abstractmethod
    def build() -> GatedDeltaRuleImpl:
        """Create the backend's ``GatedDeltaRuleImpl``."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/graph_runner.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from dataclasses import dataclass
from typing import List

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext


@dataclass
class GraphRunnerMeta:
    """Metadata object a graph runner returns from ``get_meta``."""
    # Presumably the batch size inputs are padded to; None unless a runner sets it.
    padding_batch_size: int = None


@functools.lru_cache
def _get_capture_batch_size_impl(max_batches: int):
    """Capture batch size."""
    ret = []
    batch_size = 1
    while batch_size < max_batches:
        ret.append(batch_size)
        batch_size *= 2
    ret.append(max_batches)
    return ret


class GraphRunner:
    """Graph runner.

    Default runner: ``__call__`` invokes the wrapped model directly, and the
    helper hooks forward to the model when it provides them.
    """

    def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
                 backend_config: BackendConfig, device: torch.device, **kwargs):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.device = device
        self.model_config = model_config
        self.cache_config = cache_config
        self.backend_config = backend_config
        self._runner_meta = GraphRunnerMeta()

    def __call__(self, **kwargs):
        """Call graph runner forward."""
        return self.model(**kwargs)

    def get_model(self):
        """Get model."""
        return self.model

    def get_logits(self, hidden_states: torch.Tensor):
        """Get logits of model output; returns ``hidden_states`` unchanged if
        the model has no ``get_logits``."""
        if not hasattr(self.model, 'get_logits'):
            return hidden_states
        return self.model.get_logits(hidden_states)

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare inputs."""
        return self.model.prepare_inputs_for_generation(
            past_key_values,
            inputs_embeds,
            context,
        )

    def update_model_metas(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Update model metas.

        Delegates to ``model.update_model_metas`` when the model defines it;
        otherwise returns ``None``.
        """
        if hasattr(self.model, 'update_model_metas'):
            return self.model.update_model_metas(
                past_key_values,
                inputs_embeds,
                context,
            )

        return None

    def get_input_processor(self):
        """Get input processor, or ``None`` if the model does not provide one."""
        if hasattr(self.model, 'get_input_processor'):
            return self.model.get_input_processor()
        else:
            return None

    def reset(self):
        """Remove all graphs to prevent hanging on exit.

        No-op here: this runner does not hold graphs.
        """
        pass

    def get_meta(self):
        """Get graphrunner meta."""
        return self._runner_meta

    def update_inputs(self, inputs):
        """Hook to adjust inputs before running; identity by default."""
        return inputs

    def get_capture_batch_sizes(self) -> List[int]:
        """Capture batch sizes.

        Returns a fresh list on every call: the underlying helper is
        ``lru_cache``d, so handing out its list directly would let one caller's
        mutation leak into every later call.
        """
        return list(_get_capture_batch_size_impl(self.cache_config.max_batches))


================================================
FILE: lmdeploy/pytorch/backends/linear.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearImpl(ABC):
    """Abstract interface for backend linear layers."""

    def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Hook for backends to transform weights; default returns them as-is."""
        return (weight, bias)

    @abstractmethod
    def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False,
                group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None):
        """Apply the linear layer to ``x``; implemented by the backend."""
        raise NotImplementedError


class LinearBuilder(ABC):
    """Factory interface for backend linear implementations."""

    @staticmethod
    @abstractmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: torch.dtype = None):
        """Create a linear implementation for the given feature sizes."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/lora.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import torch

from lmdeploy.pytorch.model_inputs import StepContextManager


@dataclass
class AdapterInfo:
    """Adapter information.

    ``rank_offsets`` and ``max_rank`` are derived from ``ranks`` in
    ``__post_init__`` and are not constructor arguments.
    """
    in_features: int
    out_features: int
    ranks: torch.Tensor
    scalings: torch.Tensor
    base_slice: slice
    rank_offsets: torch.Tensor = field(init=False)
    max_rank: int = field(init=False)

    def __post_init__(self):
        """Compute ``rank_offsets`` (exclusive prefix sum of ``ranks`` along
        dim 0) and ``max_rank`` (largest rank, 0 when ``ranks`` is empty)."""
        ranks = self.ranks
        # cumsum - ranks: element i is the sum of ranks[:i].
        self.rank_offsets = ranks.cumsum(0) - ranks
        # Tensor.max() raises on an empty tensor, so treat "no adapters" as rank 0.
        self.max_rank = ranks.max().item() if ranks.numel() > 0 else 0


class LoRAImpl(ABC):
    """Abstract interface for backend LoRA kernels."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        base_output: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        adapter_info: AdapterInfo,
        ctx_mgr: StepContextManager,
        colwise: bool,
        is_tp: bool = True,
    ):
        """Apply the LoRA adapters to ``x`` on top of ``base_output``."""
        raise NotImplementedError


class LoRABuilder(ABC):
    """Factory interface for LoRA implementations."""

    @staticmethod
    @abstractmethod
    def build():
        """Create the backend's LoRA implementation."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

import torch
import torch.distributed as dist


class SoftmaxTopKImpl(ABC):
    """Abstract interface for softmax + top-k routing kernels."""

    @staticmethod
    @functools.lru_cache
    def get_group_offsets(n_groups: int, group_size: int, device: str):
        """Return ``[0, group_size, 2*group_size, ...]`` shaped ``[1, n_groups, 1]``.

        Memoised per argument tuple, so the same tensor object is returned for
        repeated calls with identical arguments.
        """
        starts = torch.arange(n_groups, device=device) * group_size
        return starts.view(1, -1, 1)

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Compute routing for ``x``; implemented by the backend."""
        raise NotImplementedError


class SoftmaxTopKBuilder(ABC):
    """Factory interface for softmax top-k implementations."""

    @staticmethod
    @abstractmethod
    def build(top_k: int, dim: int = -1, n_groups: int = -1):
        """Create an implementation selecting ``top_k`` entries along ``dim``."""
        raise NotImplementedError


class FusedMoEImpl(ABC):
    """Abstract interface for fused mixture-of-experts kernels."""

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
        """Hook for backends to transform expert weights; identity by default."""
        return (gate_up_weights, down_weights)

    def ep_expert_list(self, world_size: int, rank: int):
        """Return the experts owned by ``rank``; unsupported by default."""
        raise NotImplementedError('Not Implemented.')

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.LongTensor,
        gate_up_weights: torch.Tensor,
        down_weights: torch.Tensor,
        gate_up_bias: torch.Tensor = None,
        down_bias: torch.Tensor = None,
        expert_list: List[int] = None,
        act_func: Callable = None,
    ):
        """Run the routed experts on ``hidden_states``; implemented by the backend."""
        raise NotImplementedError


class FusedMoEBuilder(ABC):
    """Factory interface for fused MoE implementations."""

    @staticmethod
    @abstractmethod
    def build(top_k: int, num_experts: int, renormalize: bool = False, hidden_dim: int = 1, ep_size: int = 1,
              ep_group: dist.ProcessGroup = None, layer_idx: int = 0, out_dtype: torch.dtype = torch.bfloat16):
        """Create a fused MoE implementation (``ep_*`` describe expert parallelism)."""
        raise NotImplementedError


class FusedMoEW8A8Impl(ABC):
    """Abstract interface for fused MoE kernels with w8a8 quantisation."""

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
                       down_scale: torch.Tensor):
        """Hook for backends to transform weights and scales; identity by default."""
        return (gate_up_weights, down_weights, gate_up_scale, down_scale)

    def ep_expert_list(self, world_size: int, rank: int):
        """Return the experts owned by ``rank``; unsupported by default."""
        raise NotImplementedError('Not Implemented.')

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.LongTensor,
        gate_up_weights: torch.Tensor,
        gate_up_scale: torch.Tensor,
        down_weights: torch.Tensor,
        down_scale: torch.Tensor,
        expert_list: List[int] = None,
    ):
        """Run the quantised experts; implemented by the backend."""
        raise NotImplementedError


class FusedMoEW8A8Builder(ABC):
    """Factory interface for w8a8 fused MoE implementations."""

    @staticmethod
    @abstractmethod
    def build(top_k: int, num_experts: int, renormalize: bool = False, out_dtype: torch.dtype = torch.float16,
              quant_dtype: torch.dtype = torch.int8):
        """Create a w8a8 fused MoE implementation."""
        raise NotImplementedError


class FusedMoEBlockedF8Impl(ABC):
    """Abstract interface for fused MoE kernels with blocked fp8 weights.

    Carries an optional ``scale_fmt`` string that backends may consult.
    """

    def __init__(self):
        self.scale_fmt: Optional[str] = None

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
                       down_scale: torch.Tensor):
        """Hook for backends to transform weights and scales; identity by default."""
        return (gate_up_weights, down_weights, gate_up_scale, down_scale)

    def ep_expert_list(self, world_size: int, rank: int):
        """Return the experts owned by ``rank``; unsupported by default."""
        raise NotImplementedError('Not Implemented.')

    def set_scale_fmt(self, scale_fmt: Optional[str]):
        """Store ``scale_fmt`` on the instance."""
        self.scale_fmt = scale_fmt

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.LongTensor,
        gate_up_weights: torch.Tensor,
        gate_up_scale: torch.Tensor,
        down_weights: torch.Tensor,
        down_scale: torch.Tensor,
        gate_up_bias: torch.Tensor = None,
        down_bias: torch.Tensor = None,
        expert_list: List[int] = None,
        act_func: Callable = None,
    ):
        """Run the blocked-fp8 experts; implemented by the backend."""
        raise NotImplementedError


class FusedMoEBlockedF8Builder(ABC):
    """Factory interface for blocked-fp8 fused MoE implementations."""

    @staticmethod
    @abstractmethod
    def build(top_k: int, num_experts: int, hidden_dim: int = 1, renormalize: bool = False, block_size: int = 128,
              ep_size: int = 1, ep_group: dist.ProcessGroup = None, out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0, custom_gateup_act: bool = False):
        """Create a blocked-fp8 fused MoE implementation."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/moe_router.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Tuple

import torch


class RouterNoauxTCImpl(ABC):
    """Abstract interface for the noaux-tc MoE router."""

    @abstractmethod
    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route from ``logits`` and ``bias``; returns a pair of tensors."""
        raise NotImplementedError


class RouterNoauxTCBuilder(ABC):
    """Factory interface for noaux-tc router implementations."""

    @staticmethod
    @abstractmethod
    def build(scoring_func: str,
              top_k: int,
              n_group: int,
              topk_group: int,
              n_routed_experts: int,
              routed_scaling_factor: float,
              renormalize: bool = True,
              router_n_groups: int = -1):
        """Create a router configured with the given grouping and scaling options."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch


class MultinomialSamplingImpl(ABC):
    """Multinomial sampling implementation api."""

    @abstractmethod
    def forward(self,
                scores: torch.Tensor,
                seeds: torch.LongTensor,
                offsets: torch.LongTensor,
                indices: torch.Tensor = None):
        """Sample from ``scores`` using per-row ``seeds``/``offsets``.

        Implemented by the backend. ``self`` was previously missing from this
        abstract signature, which misdeclared the interface and made
        ``super().forward(...)`` calls from subclasses raise ``TypeError``.
        """
        raise NotImplementedError


class MultinomialSamplingBuilder(ABC):
    """Factory interface for multinomial sampling implementations."""

    @staticmethod
    @abstractmethod
    def build():
        """Create the backend's multinomial sampling implementation."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch


class RMSNormImpl(ABC):
    """Abstract interface for RMS normalisation kernels."""

    @abstractmethod
    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                residual: torch.Tensor = None):
        """Normalise ``x`` with ``weight`` (optionally with ``residual``)."""
        raise NotImplementedError


class RMSNormBuilder(ABC):
    """Factory interface for RMS norm implementations."""

    @staticmethod
    @abstractmethod
    def build(hidden_size: int,
              eps: float = 1e-6):
        """Create an RMS norm implementation for ``hidden_size`` features."""
        raise NotImplementedError


class LayerNormImpl(ABC):
    """Abstract interface for layer normalisation kernels."""

    @abstractmethod
    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                bias: torch.Tensor = None,
                residual: torch.Tensor = None):
        """Normalise ``x`` with ``weight``/``bias`` (optionally with ``residual``)."""
        raise NotImplementedError


class LayerNormBuilder(ABC):
    """Factory interface for layer norm implementations."""

    @staticmethod
    @abstractmethod
    def build(normalized_shape: int,
              eps: float = 1e-6):
        """Create a layer norm implementation for ``normalized_shape``."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/nsa.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass

from torch import Tensor


@dataclass
class NSAIndexMeta:
    """Per-step metadata consumed by NSA index layers."""
    # Cumulative query sequence lengths.
    cu_seqlen_q: Tensor
    # Per-sequence query lengths.
    q_seqlens: Tensor
    # Per-sequence key lengths.
    k_seqlens: Tensor
    # Block offsets; presumably block-table entries per sequence.
    block_offset: Tensor
    # Optional maxima over the batch; None when not provided.
    max_q_seqlen: int = None
    max_kv_seqlen: int = None


class BaseNSAIndexFP8(ABC):
    """Abstract interface for fp8 NSA index kernels."""

    @abstractmethod
    def forward(self,
                q: Tensor,
                k: Tensor,
                weights: Tensor,
                k_cache: Tensor,
                k_s_cache: Tensor,
                meta: NSAIndexMeta) -> Tensor:
        """Compute the index output for ``q``/``k`` against the caches."""
        raise NotImplementedError('Not implemented.')


class BaseNSAIndexFP8Builder(ABC):
    """Factory interface for fp8 NSA index implementations.

    Inherits ``ABC`` like every other builder so that ``@abstractmethod`` is
    actually enforced (without it the decorator has no effect).
    """

    @staticmethod
    @abstractmethod
    def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> 'BaseNSAIndexFP8':
        """Build layer implementation."""
        raise NotImplementedError('Not implemented.')


================================================
FILE: lmdeploy/pytorch/backends/qmodules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class RMSNormW8A8Impl(ABC):
    """Abstract interface for w8a8 RMS norm kernels."""

    @staticmethod
    def create_weight(hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):
        """Return a non-trainable all-ones ``Parameter`` of length ``hidden_size``.

        ``dtype`` defaults to ``torch.float16`` and ``device`` to ``'cuda'``.
        """
        dtype = torch.float16 if dtype is None else dtype
        device = 'cuda' if device is None else device
        ones = torch.ones(hidden_size, dtype=dtype, device=device)
        return torch.nn.Parameter(ones, requires_grad=False)

    @abstractmethod
    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
        """Normalise and quantise ``x``; implemented by the backend."""
        raise NotImplementedError


class RMSNormW8A8Builder(ABC):
    """Factory interface for w8a8 RMS norm implementations."""

    @staticmethod
    @abstractmethod
    def build(hidden_size: int,
              eps: float = 1e-6,
              quant_dtype: torch.dtype = torch.int8):
        """Create a w8a8 RMS norm implementation."""
        raise NotImplementedError


class LinearW8A8Impl(ABC):
    """Abstract interface for w8a8 linear kernels."""

    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Hook for backends to transform weight/scale/bias; identity by default."""
        return (weight, scale, bias)

    @abstractmethod
    def forward(
        self,
        x,
        weight: torch.Tensor,
        scale: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        all_reduce: bool = False,
        group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        """Apply the quantised linear layer to ``x``."""
        raise NotImplementedError


class LinearW8A8Builder(ABC):
    """Factory interface for w8a8 linear implementations."""

    @staticmethod
    @abstractmethod
    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None,
              quant_dtype: torch.dtype = torch.int8):
        """Create a w8a8 linear implementation."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/rotary_embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import List

import torch


class RopeType(Enum):
    """Supported rotary embedding variants (values are 1-based, in declaration order)."""
    Default = 1
    LinearScaling = 2
    DynamicNTKScaling = 3
    Llama3 = 4
    Yarn = 5
    LongRoPEScaling = 6
    Fope = 7


@dataclass
class YarnParameters:
    """Tunable parameters for the Yarn rotary embedding variant."""
    # Correction-range bounds.
    beta_fast: int = 32
    beta_slow: float = 1
    # Magnitude scaling knobs.
    mscale: int = 1
    mscale_all_dim: int = 0
    # Optional explicit attention factor; None leaves it to the implementation.
    attention_factor: int = None
    truncate: bool = True


@dataclass
class LongRoPEScalingParameters:
    """Parameters for LongRoPE scaling."""
    # Per-dimension factors for short and long contexts.
    short_factor: List[int]
    long_factor: List[int]
    # Context length the base model was trained with.
    original_max_position_embeddings: int
    # Optional magnitude scales; None when not specified.
    long_mscale: float = None
    short_mscale: float = None


@dataclass
class Llama3Parameters:
    """Parameters for Llama3-style rotary scaling."""
    # Frequency factors bounding the smoothing region.
    low_freq_factor: float = 1.0
    high_freq_factor: float = 4.0
    # Context length the base model was trained with.
    original_max_position_embeddings: int = 8192


@dataclass
class FopeParameters:
    """Parameters for the Fope rotary embedding variant."""
    # Number of inverse frequencies; None when not specified.
    num_inv_freq: int = None
    num_key_value_heads: int = 1
    fope_sep_head: bool = False
    # Optional precomputed inverse-frequency tensor.
    inv_freq: torch.Tensor = None


class RotaryEmbeddingImpl(ABC):
    """Abstract interface for rotary embedding implementations."""

    @abstractmethod
    def forward(self, x, position_ids, **kwargs):
        """Compute rotary embeddings for ``position_ids``; implemented by the backend."""
        raise NotImplementedError


class RotaryEmbeddingBuilder(ABC):
    """Factory interface for rotary embedding implementations.

    ``emb_type`` selects the variant; the matching ``*_params`` object carries
    its variant-specific settings.
    """

    @staticmethod
    @abstractmethod
    def build(dim: int,
              max_position_embeddings: int = 2048,
              base: int = 10000,
              scaling_factor: float = 1.0,
              yarn_params: YarnParameters = None,
              longrope_params: LongRoPEScalingParameters = None,
              llama3_params: Llama3Parameters = None,
              fope_params: FopeParameters = None,
              emb_type: RopeType = RopeType.Default):
        """Create a rotary embedding implementation."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/backends/selector.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager


def _get_backend():
    """Return the ops backend class for the current device context.

    Backend modules are imported lazily so only the active one is loaded.

    Raises:
        RuntimeError: if the context's device type has no known backend.
    """
    device_type = get_device_manager().current_context().device_type

    if device_type == 'cuda':
        from .cuda import CudaOpsBackend
        return CudaOpsBackend
    elif device_type == 'ascend':
        from .dlinfer.ascend import AscendOpsBackend
        return AscendOpsBackend
    elif device_type == 'maca':
        from .dlinfer.maca import MacaOpsBackend
        return MacaOpsBackend
    elif device_type == 'camb':
        from .dlinfer.camb import CambOpsBackend
        return CambOpsBackend
    raise RuntimeError(f'Unsupported device type: {device_type}')


def get_backend(backend_type: str = None):
    """Return the ops backend class.

    With no ``backend_type`` the current device context decides; otherwise the
    lookup runs inside a temporary ``DeviceContext(backend_type)``.
    """
    if backend_type is None:
        return _get_backend()

    override_ctx = DeviceContext(backend_type)
    with get_device_manager().context(override_ctx):
        return _get_backend()


def init_backend(backend_type: str):
    """Resolve the backend for ``backend_type`` and call its ``init()``."""
    get_backend(backend_type).init()


================================================
FILE: lmdeploy/pytorch/backends/token_dispatcher.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Tuple

import torch


class TokenDispatcherImpl(ABC):
    """Token dispatcher implementation api.

    Supplies shared MoE routing helpers (``permute``, ``unpermute``,
    ``indices_to_multihot``); concrete backends implement ``dispatch`` and
    ``combine``.
    """

    def permute(
        self,
        tokens,
        routing_map,
    ):
        """Copy from Megatron-Core moe for token permutation.

        Gathers token rows grouped by expert.

        Args:
            tokens: 2-D tensor whose first dim is the token count
                (presumably ``(num_tokens, hidden)``).
            routing_map: ``(num_tokens, num_experts)`` mask; nonzero entries
                mark the experts each token is routed to.

        Returns:
            ``(permuted_input, sorted_indices)``: the selected token rows in
            expert-major order, and the source token index of each row.
        """
        num_tokens, _ = tokens.shape
        num_experts = routing_map.shape[1]
        # Transpose to (num_experts, num_tokens) so masked_select walks the mask
        # expert by expert, producing an expert-major ordering.
        routing_map = routing_map.bool().T.contiguous()
        # Row e holds 0..num_tokens-1; masking keeps the token ids routed to expert e.
        token_indices = (torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1))
        sorted_indices = token_indices.masked_select(routing_map)
        permuted_input = tokens.index_select(0, sorted_indices)
        return permuted_input, sorted_indices

    def unpermute(
        self,
        permuted_tokens: torch.Tensor,
        sorted_indices: torch.Tensor,
        restore_shape: torch.Size,
        probs: torch.Tensor = None,
        routing_map: torch.Tensor = None,
    ):
        """Copy from Megatron-Core moe for token unpermutation.

        Scatters permuted rows back to their source tokens, summing rows that
        came from the same token.

        Args:
            permuted_tokens: rows in ``permute`` order.
            sorted_indices: source token index of each row (from ``permute``).
            restore_shape: 2-D output shape; dim 0 is indexed by token.
            probs: optional per-(token, expert) weights multiplied into each
                row before summation; requires ``routing_map``.
            routing_map: the routing mask (presumably already bool here, as
                ``masked_select`` needs a bool mask), used to pick ``probs``
                entries in the same expert-major order as ``permute``.

        Returns:
            Tensor of ``restore_shape`` with ``permuted_tokens``' dtype/device.
        """
        _, hidden = restore_shape
        if probs is not None:
            assert routing_map is not None, 'Mask must be provided to permute the probs.'
            # Transpose both so the selected probs line up with permute's expert-major rows.
            permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
            permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
        output_tokens = torch.zeros(restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype)
        # Accumulate (not overwrite) so tokens routed to several experts sum their outputs.
        output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
        return output_tokens

    def indices_to_multihot(self, topk_ids, topk_weight, num_experts):
        """Convert top-k expert ids into a dense routing mask and prob matrix.

        Args:
            topk_ids: 2-D ``(tokens, k)`` expert ids; entries equal to -1 are skipped.
            topk_weight: weights aligned element-wise with ``topk_ids``.
            num_experts: number of columns in the outputs.

        Returns:
            ``(multihot_routing_map, multihot_probs)``, both ``(tokens, num_experts)``;
            the map is bool and the probs use ``topk_weight``'s dtype.
        """
        tokens = topk_ids.shape[0]
        multihot_routing_map = torch.zeros(
            (tokens, num_experts),
            dtype=torch.bool,
            device=topk_ids.device,
        )

        multihot_probs = torch.zeros(
            (tokens, num_experts),
            dtype=topk_weight.dtype,
            device=topk_weight.device,
        )

        mask = topk_ids != -1
        valid_indices = topk_ids[mask]
        # One row index per valid entry, in the same row-major order as topk_ids[mask].
        row_indices = torch.arange(tokens, device=topk_ids.device).repeat_interleave(mask.sum(dim=1))
        multihot_routing_map[row_indices, valid_indices] = True
        multihot_probs[row_indices, valid_indices] = topk_weight[mask]
        return multihot_routing_map, multihot_probs

    @abstractmethod
    def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor, topk_ids: torch.Tensor,
                 local_expert_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Dispatch tokens to experts; implemented by the backend.

        Returns a 4-tuple of tensors (per the annotation).
        """
        raise NotImplementedError

    @abstractmethod
    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine expert outputs; implemented by the backend."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/block.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np


def _div_up(x, n):
    """Perform div up."""
    return (x + n - 1) // n


def _round_up(x, n):
    """Perform round up."""
    return _div_up(x, n) * n


class LogicalTokenBlocks:
    """Growable sequence of logical block ids backed by a NumPy buffer.

    Only the first ``len(self)`` entries of the buffer are meaningful; the
    buffer grows in multiples of ``ALLOC_SIZE`` entries.
    """
    ALLOC_SIZE = 128

    def __init__(self, blocks: np.ndarray = None):
        if blocks is not None:
            # Adopt the caller's 1-D array as the buffer (no copy).
            assert blocks.ndim == 1
            self._blocks = blocks
            self._num_real = len(blocks)
        else:
            self._blocks = np.zeros((self.ALLOC_SIZE, ), dtype=np.int64)
            self._num_real = 0
        self.last_shared_node = None

    def reserve(self, size: int):
        """Grow the buffer (zero-padded) so it holds at least ``size`` entries."""
        capacity = self._blocks.size
        if size <= capacity:
            return
        alloc = self.ALLOC_SIZE
        shortfall = size - capacity
        # Round the shortfall up to a whole number of allocation chunks.
        grow_by = ((shortfall + alloc - 1) // alloc) * alloc
        self._blocks = np.pad(self._blocks, (0, grow_by))

    def __setitem__(self, *args, **kwargs):
        """Assign into the valid prefix of the buffer."""
        valid = self.get_real_blocks()
        return valid.__setitem__(*args, **kwargs)

    def __getitem__(self, *args, **kwargs):
        """Index into the valid prefix of the buffer."""
        valid = self.get_real_blocks()
        return valid.__getitem__(*args, **kwargs)

    def get_real_blocks(self):
        """Return a view of the valid block ids."""
        return self._blocks[:self._num_real]

    def append(self, blocks: np.ndarray):
        """Append ``blocks`` after the current valid entries."""
        count = len(blocks)
        self.reserve(count + self._num_real)
        start = self._num_real
        stop = start + count
        self._num_real = stop
        self._blocks[start:stop] = blocks

    def __len__(self):
        """Number of valid block ids."""
        return self._num_real

    def resize(self, num_blocks: int):
        """Shrink the valid prefix to ``num_blocks`` entries (never grows)."""
        assert num_blocks <= len(self)
        self._num_real = num_blocks

    def reset(self):
        """Drop all valid entries and forget ``last_shared_node``."""
        self.resize(0)
        self.last_shared_node = None

    def clone(self):
        """Return a new instance holding a copy of the valid block ids."""
        copy = LogicalTokenBlocks()
        copy.append(self.get_real_blocks())
        return copy


================================================
FILE: lmdeploy/pytorch/check_env/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/check_env/adapter.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker


class AdapterChecker(BaseChecker):
    """Check that the adapter at ``adapter_path`` can be loaded with peft."""

    def __init__(self, adapter_path: str, logger=None):
        super().__init__(logger)
        self.adapter_path = adapter_path

    def check(self):
        """Import peft and load the adapter config; call ``log_and_exit`` on failure."""
        path = self.adapter_path

        try:
            import peft  # noqa: F401
        except Exception as e:
            self.log_and_exit(e, 'Adapter', message='Failed to import peft.')

        try:
            from peft import PeftConfig
            PeftConfig.from_pretrained(path)
        except Exception as e:
            message = ('Please make sure the adapter can be loaded with '
                       '`peft.PeftConfig.from_pretrained`\n')
            # e.args[0] is not always a str (e.g. OSError(errno, msg) puts an int
            # first); coerce it so the substring test cannot raise TypeError and
            # mask the real error.
            err_msg = str(e.args[0]) if e.args else ''
            if 'got an unexpected keyword argument' in err_msg:
                message += ('Or try remove all unexpected keywords '
                            'in `adapter_config.json`.')
            self.log_and_exit(e, 'Adapter', message=message)


================================================
FILE: lmdeploy/pytorch/check_env/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from logging import Logger
from typing import List

from lmdeploy.utils import can_colorize, get_logger

# ANSI escape sequences used to highlight failure hints.
RED_COLOR = '\033[31m'
RESET_COLOR = '\033[0m'


def _red_text(text: str):
    """Wrap ``text`` in red ANSI codes when ``can_colorize()`` allows it."""
    if can_colorize():
        return f'{RED_COLOR}{text}{RESET_COLOR}'
    return text


class BaseChecker:
    """Base checker.

    Subclasses implement :meth:`check`. Callers use :meth:`handle`, which first
    handles every registered prerequisite checker and then runs ``check``, at
    most once per instance.
    """

    def __init__(self, logger: Logger = None):
        if logger is None:
            logger = get_logger('lmdeploy')
        self.logger = logger
        self._is_passed = False
        self._required_checker: List[BaseChecker] = list()

    def get_logger(self):
        """Get logger."""
        return self.logger

    def register_required_checker(self, checker: 'BaseChecker'):
        """Register ``checker`` to be handled before this checker's own check."""
        self._required_checker.append(checker)

    def handle(self):
        """Run prerequisite checkers, then :meth:`check`; no-op once passed."""
        if getattr(self, '_is_passed', False):
            return
        checker_name = type(self).__name__
        self.logger.debug(f'Checking <{checker_name}>:')
        for checker in self._required_checker:
            checker.handle()
        self.check()
        # Record success under the same attribute read above, so repeated calls
        # (and prerequisites shared by several checkers) do not re-run checks.
        self._is_passed = True

    def log_and_exit(self, e: Exception = None, mod_name: str = None, message: str = None):
        """Log a failed check and terminate the process with exit code 1.

        Args:
            e: exception that caused the failure, if any; its type and text are logged.
            mod_name: name shown in the error; defaults to this checker's class name.
            message: hint passed through ``_red_text``; defaults to a generic hint.
        """
        import sys

        logger = self.logger
        if mod_name is None:
            mod_name = type(self).__name__
        if message is None:
            message = 'Please check your environment.'
        logger.debug('Exception', exc_info=1)
        if e is not None:
            logger.error(f'{type(e).__name__}: {e}')
        logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
        # sys.exit instead of the site-provided ``exit`` builtin, which is
        # missing under ``python -S`` and in some embedded interpreters.
        sys.exit(1)

    def check(self):
        """check."""
        raise NotImplementedError('check not implemented.')


================================================
FILE: lmdeploy/pytorch/check_env/cuda.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker


class CudaChecker(BaseChecker):
    """Check that CUDA is available.

    When ``model_format == 'fp8'``, also require compute capability major
    version >= 9 on device 0.
    """

    def __init__(self, model_format: str = None, logger=None) -> None:
        super().__init__(logger=logger)
        self.model_format = model_format

    def check(self):
        """check."""
        import torch

        cuda_ok = torch.cuda.is_available()
        if not cuda_ok:
            self.log_and_exit(mod_name='CUDA', message='cuda is not available.')

        if self.model_format != 'fp8':
            return
        major = torch.cuda.get_device_properties(0).major
        if major < 9:
            self.log_and_exit(mod_name='CUDA', message='model_format=fp8 requires sm>=9.0.')


================================================
FILE: lmdeploy/pytorch/check_env/deeplink.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.utils import try_import_deeplink

from .base import BaseChecker


class DeeplinkChecker(BaseChecker):
    """Run ``try_import_deeplink`` for the configured device type."""

    def __init__(self, device_type: str, logger=None) -> None:
        super().__init__(logger=logger)
        self.device_type = device_type

    def check(self):
        """check."""
        target_device = self.device_type
        try_import_deeplink(target_device)


================================================
FILE: lmdeploy/pytorch/check_env/dist.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from lmdeploy.pytorch.config import DistConfig
from lmdeploy.utils import is_dlblas_installed

from .base import BaseChecker


class DistChecker(BaseChecker):
    """Check dist environment.

    Validates the tp/dp/ep combination, the distributed executor backend and
    (for ``ray``) that ``ray`` is importable and supported by the device
    backend.
    """

    def __init__(self, tp: int, dp: int, ep: int, distributed_executor_backend: str, device_type: str, logger=None):
        super().__init__(logger)
        self.tp = tp
        self.dp = dp
        self.ep = ep
        self.dist_config = DistConfig(dp=dp, tp=tp, ep=ep)
        self.world_size = self.dist_config.world_size
        self.distributed_executor_backend = distributed_executor_backend
        self.device_type = device_type

    def check(self):
        """check."""
        distributed_executor_backend = self.distributed_executor_backend

        if distributed_executor_backend is None:
            # no backend given: ask the executor module for one
            from lmdeploy.pytorch.engine.executor import get_distributed_executor_backend
            distributed_executor_backend = get_distributed_executor_backend(self.world_size, self.dp, self.device_type)

        if distributed_executor_backend not in [None, 'uni', 'mp', 'ray']:
            self.log_and_exit(mod_name='Dist',
                              message=f'Unsupported distributed_executor_backend: {distributed_executor_backend}')

        if distributed_executor_backend == 'uni' and self.world_size > 1:
            self.log_and_exit(mod_name='Dist',
                              message='Does not support distributed_executor_backend="uni" and world_size!=1.')

        if self.dp > 1 and distributed_executor_backend != 'ray':
            self.log_and_exit(mod_name='Dist',
                              message='dp>1 requires distributed_executor_backend="ray". '
                              f'Get distributed_executor_backend={distributed_executor_backend}.')

        if self.ep > 1:
            if self.device_type == 'cuda' and not is_dlblas_installed():
                self.log_and_exit(mod_name='Dist',
                                  message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).')
            if self.ep % self.dp != 0:
                self.log_and_exit(mod_name='Dist',
                                  message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.')
        elif self.dist_config.enable_eplb:
            self.log_and_exit(mod_name='Dist', message=f'Enable eplb requires ep > 1. Get ep={self.ep}.')

        if distributed_executor_backend == 'ray':
            try:
                import ray  # noqa: F401
            except Exception as e:
                # ``Exception`` rather than ``BaseException``: a Ctrl-C or
                # SystemExit during the import must not be reported as a
                # missing ``ray``. The error is passed on so it gets logged.
                self.log_and_exit(e, mod_name='Dist', message='Multi-nodes support requires `ray`.')

            from lmdeploy.pytorch.backends import get_backend
            backend = get_backend(self.device_type)
            if not backend.support_ray():
                self.log_and_exit(mod_name='Dist', message=f'device={self.device_type} does not support ray.')


================================================
FILE: lmdeploy/pytorch/check_env/model.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from packaging import version

from .base import BaseChecker


class ModelChecker(BaseChecker):
    """Check that the model config loads and fits the local environment.

    ``check`` runs, in order: HF config loading, a transformers version
    comparison, and a dtype / model-specific environment check.
    """

    def __init__(self, model_path: str, trust_remote_code: bool, dtype: str, device_type: str, logger=None) -> None:
        super().__init__(logger=logger)
        self.model_path = model_path
        self.trust_remote_code = trust_remote_code
        self.device_type = device_type
        self.dtype = dtype

    def check_config(self, trans_version):
        """Load the HF config of ``model_path``; exit with a hint on
        failure."""
        path, remote_code = self.model_path, self.trust_remote_code
        try:
            from lmdeploy.pytorch.transformers import config_from_pretrained
            config = config_from_pretrained(path, trust_remote_code=remote_code)
        except Exception as err:
            hint = (f'Load model config with transformers=={trans_version} failed. '
                    'Please make sure model can be loaded with transformers API.')
            self.log_and_exit(err, 'transformers', message=hint)
        return config

    def check_trans_version(self, config, trans_version):
        """Warn when the config records a newer transformers than the
        installed one."""
        path = self.model_path
        logger = self.get_logger()
        saved_with = getattr(config, 'transformers_version', None)
        if saved_with is None:
            return
        saved_with = version.parse(saved_with)
        if trans_version < saved_with:
            logger.warning(f'model `{path}` requires transformers version {saved_with} '
                           f'but transformers {trans_version} is installed.')

    def check_dtype(self, config):
        """Build a ``ModelConfig`` to validate dtype and model/device
        support."""
        logger = self.get_logger()
        path, device_type, dtype = self.model_path, self.device_type, self.dtype
        try:
            import torch

            from lmdeploy.pytorch.config import ModelConfig
            from lmdeploy.utils import is_bf16_supported
            model_config = ModelConfig.from_hf_config(config, model_path=path, dtype=dtype, device_type=device_type)
            # unsupported bfloat16 only produces a warning here
            if model_config.dtype == torch.bfloat16 and not is_bf16_supported(device_type):
                logger.warning('Device does not support bfloat16.')
        except Exception as err:
            self.log_and_exit(err,
                              'Model',
                              message=f'Checking failed with error {err}. Please send issue to LMDeploy with error logs.')

        # model-specific environment hook
        try:
            model_config.check_env_func(device_type)
        except Exception as err:
            self.log_and_exit(err, 'Model', message=f'Checking failed with error {err}.')

    def check(self):
        """check."""
        import transformers
        installed = version.parse(transformers.__version__)

        config = self.check_config(installed)
        self.check_trans_version(config, installed)
        self.check_dtype(config)


================================================
FILE: lmdeploy/pytorch/check_env/torch.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker


class TorchChecker(BaseChecker):
    """Smoke-test PyTorch with a tiny addition on ``device``."""

    def __init__(self, device: str = 'cuda', logger=None) -> None:
        super().__init__(logger=logger)
        self.device = device

    def check(self):
        """check."""
        try:
            import torch
            lhs = torch.tensor([1, 2], device=self.device)
            rhs = lhs.new_tensor([3, 4], device=self.device)
            total = lhs + rhs
            torch.testing.assert_close(total, lhs.new_tensor([4, 6]))
        except Exception as err:
            self.log_and_exit(err, 'PyTorch', 'PyTorch is not available.')


================================================
FILE: lmdeploy/pytorch/check_env/transformers.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from packaging import version

from .base import BaseChecker

# Versions outside this inclusive range only trigger a warning.
MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '5.2.0'


class TransformersChecker(BaseChecker):
    """Check transformers is available.

    Warns when the installed version is outside
    [MIN_TRANSFORMERS_VERSION, MAX_TRANSFORMERS_VERSION]; exits if
    transformers cannot be imported or its version cannot be parsed.
    """

    def check(self):
        """check."""
        logger = self.get_logger()
        try:
            # Import inside ``try`` so that a missing or broken installation
            # is reported via ``log_and_exit`` instead of an uncaught error.
            import transformers
            trans_version = version.parse(transformers.__version__)
            min_version = version.parse(MIN_TRANSFORMERS_VERSION)
            max_version = version.parse(MAX_TRANSFORMERS_VERSION)
            if trans_version < min_version or trans_version > max_version:
                logger.warning('LMDeploy requires transformers version: '
                               f'[{MIN_TRANSFORMERS_VERSION} ~ '
                               f'{MAX_TRANSFORMERS_VERSION}], '
                               'but found version: '
                               f'{transformers.__version__}')
        except Exception as e:
            self.log_and_exit(e, 'transformers', 'transformers is not available.')


================================================
FILE: lmdeploy/pytorch/check_env/triton.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from packaging import version

from .base import BaseChecker

# Newer than MAX only warns; older than MIN is fatal.
MAX_TRITON_VERSION = '3.6.0'
MIN_TRITON_VERSION = '3.0.0'


class TritonChecker(BaseChecker):
    """Check triton is available.

    ``check`` runs a tiny Triton kernel on ``cuda`` and then validates the
    installed triton version.
    """

    def check_version(self):
        """Check version.

        Warns when triton is newer than ``MAX_TRITON_VERSION`` and exits when
        it is older than ``MIN_TRITON_VERSION``.
        """
        logger = self.get_logger()

        # version check
        import triton
        max_version = version.parse(MAX_TRITON_VERSION)
        min_version = version.parse(MIN_TRITON_VERSION)
        triton_version = version.parse(triton.__version__)

        if triton_version > max_version:
            logger.warning('PytorchEngine has not been tested on '
                           f'triton>{MAX_TRITON_VERSION}.')
        if triton_version < min_version:
            msg = (f'triton>={MIN_TRITON_VERSION} is required. '
                   f'Found triton=={triton_version}')
            self.log_and_exit(mod_name='Triton', message=msg)

    def check(self):
        """check."""
        logger = self.get_logger()

        msg = ('Please ensure that your device is functioning properly with Triton.\n'
               'You can verify your environment by running '
               '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
        try:
            logger.debug('Checking Triton environment.')
            import torch

            from .triton_custom_add import custom_add
            a = torch.tensor([1, 2], device='cuda')
            b = a.new_tensor([3, 4], device='cuda')
            c = custom_add(a, b)
            torch.testing.assert_close(c, a + b)
        except RuntimeError as e:
            # This error text points at a driver / compiler mismatch, so swap
            # in a targeted hint. ``str(e)`` is used because ``e.args`` may be
            # empty or hold a non-string payload.
            ptxas_error = 'device kernel image is invalid'
            if ptxas_error in str(e):
                msg = (
                    'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n'  # noqa: E501
                    'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501
                    ' or reinstall the driver.')
            self.log_and_exit(e, 'Triton', msg)
        except Exception as e:
            self.log_and_exit(e, 'Triton', msg)

        # version check
        self.check_version()


================================================
FILE: lmdeploy/pytorch/check_env/triton_custom_add.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(A, B, C, size, BLOCK: tl.constexpr):
    """Add kernel.

    Program instance ``program_id(0)`` handles the ``BLOCK`` consecutive
    offsets starting at ``program_id(0) * BLOCK`` and stores ``A + B`` into
    ``C``. Offsets at or beyond ``size`` are masked out of the loads and the
    store.
    """
    prog_id = tl.program_id(0)
    # flat element offsets handled by this program instance
    offs = prog_id * BLOCK + tl.arange(0, BLOCK)
    a = tl.load(A + offs, mask=offs < size)
    b = tl.load(B + offs, mask=offs < size)
    tl.store(C + offs, a + b, mask=offs < size)


def custom_add(a, b):
    """Return ``a + b`` computed by ``_add_kernel``.

    The element count comes from ``size(0)`` of the output. The inputs are
    therefore assumed to be 1-D tensors of equal length on the same device;
    this is not checked here.
    """
    out = torch.empty_like(a)
    n_elements = out.size(0)
    block = 16

    # one program instance per ``block`` elements, rounded up
    launch_grid = (triton.cdiv(n_elements, block), )
    _add_kernel[launch_grid](a, b, out, n_elements, BLOCK=block)
    return out


if __name__ == '__main__':
    # Standalone smoke test: run the Triton kernel once and compare it with
    # torch's own addition.
    lhs = torch.tensor([1, 2], device='cuda')
    rhs = lhs.new_tensor([3, 4], device='cuda')
    torch.testing.assert_close(custom_add(lhs, rhs), lhs + rhs)
    print('Done.')


================================================
FILE: lmdeploy/pytorch/config.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import torch

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value
from lmdeploy.utils import get_logger, is_bf16_supported

# Module-level logger used by the config helpers in this module.
logger = get_logger('lmdeploy')


def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'):
    """Update the torch dtype from the model config.

    Args:
        config (ModelConfig): The input model config.
        dtype (str): user specified data type. Refer to
            `PyTorchEngineConfig.dtype` for detailed info
        device_type (str): The device type. Refer to `PyTorchEngineConfig.device_type` for detailed info
    """
    quantization_config = getattr(config.hf_config, 'quantization_config', dict())
    quant_method = quantization_config.get('quant_method', None)
    if quant_method == 'awq':
        logger.debug('set torch_dtype to float16 for awq.')
        config.hf_config.torch_dtype = 'float16'
        config.dtype = torch.float16
        return config

    torch_dtype = getattr(config.hf_config, 'dtype', None)
    if torch_dtype is None and hasattr(config.hf_config, 'text_config'):
        torch_dtype = getattr(config.hf_config.text_config, 'dtype', None)

    if torch_dtype is None:
        torch_dtype = getattr(config.hf_config, 'torch_dtype', None)

    # deal with case when torch_dtype is not string but torch.dtype
    if isinstance(torch_dtype, torch.dtype):
        torch_dtype = str(torch_dtype).split('.')[1]

    if torch_dtype is None:
        _dtype = 'float16' if dtype == 'auto' else dtype
        logger.warning('Model config does not have `torch_dtype`,'
                       f' use: {_dtype}')
        torch_dtype = _dtype
        # update hf_config as well
        setattr(config.hf_config, 'torch_dtype', torch_dtype)
    else:
        if torch_dtype == 'bfloat16' and not is_bf16_supported(device_type):
            torch_dtype = 'float16'
        # change to user specified data type if it is not 'auto'
        if dtype == 'auto':
            torch_dtype = torch_dtype if torch_dtype in ['float16', 'bfloat16'] else 'float16'
        else:
            torch_dtype = dtype
    config.dtype = eval(f'torch.{torch_dtype}')
    return config


@dataclass
class BackendConfig:
    """Backend config.

    Defaults: ``eager_mode=True`` and ``device_type='cuda'``.
    """
    eager_mode: bool = True
    device_type: str = 'cuda'


@dataclass
class SchedulerConfig:
    """Config of scheduler.

    ``max_batches`` and ``max_session_len`` have no defaults and must be
    supplied by the caller; the remaining fields are optional.
    """

    max_batches: int
    max_session_len: int
    max_request_output_len: int = 512
    eviction_type: str = 'recompute'
    prefill_interval: int = 16
    max_active_adapters: int = 64


@dataclass
class CacheConfig:
    """Config of key value cache.

    ``max_batches``, ``block_size``, ``num_cpu_blocks`` and ``num_gpu_blocks``
    must be supplied. ``__post_init__`` turns prefix caching off (with a
    warning) when ``window_size > 1``.
    """

    max_batches: int
    block_size: int
    num_cpu_blocks: int
    num_gpu_blocks: int
    window_size: int = -1
    cache_max_entry_count: float = 0.8
    max_prefill_token_num: int = 4096
    enable_prefix_caching: bool = False
    quant_policy: Literal[0, 4, 8] = 0
    device_type: str = 'cuda'
    num_state_caches: int = None
    states_shapes: List[Tuple] = field(default_factory=list)

    # reserved blocks for dummy inputs, init to 0 for unit test.
    num_reserved_gpu_blocks: int = 0

    # For PD Disaggregation
    role: EngineRole = EngineRole.Hybrid
    migration_backend: MigrationBackend = MigrationBackend.DLSlime

    def __post_init__(self):
        """Post init: disable prefix caching when window attention is used."""
        if self.window_size > 1 and self.enable_prefix_caching:
            logger.warning('Prefix caching is not available for window attention.')
            self.enable_prefix_caching = False


class TPMode(enum.Enum):
    """TP Mode.

    ``DistConfig`` gives a layer type (mlp / moe) ``DP_TP`` when its tp size
    is neither 1 nor ``attn_tp``, and ``DEFAULT`` otherwise.
    """
    DEFAULT = enum.auto()
    DP_TP = enum.auto()


@dataclass
class DistConfig:
    """Distributed (dp / ep / per-layer tp) configuration.

    ``world_size``, ``tp`` and the tp modes are derived in ``__post_init__``;
    any values passed for them are overwritten.
    """
    dp: int = 1
    ep: int = 1
    dp_rank: int = 0
    enable_microbatch: bool = False
    enable_eplb: bool = False
    world_size: int = 1

    # tp
    tp: int = 1  # default tp, equal to attn_tp
    attn_tp: int = None  # tp for attention
    mlp_tp: int = None  # tp for mlp
    moe_tp: int = None  # tp for moe

    # tp mode
    mlp_tp_mode: TPMode = TPMode.DEFAULT
    moe_tp_mode: TPMode = TPMode.DEFAULT

    def __post_init__(self):
        """Post init.

        Derive per-layer tp sizes, ``world_size`` and tp modes, and check
        (with ``assert``) that all sizes divide each other consistently.
        """
        assert self.dp_rank < self.dp
        assert self.dp >= 1

        dp = self.dp
        tp = self.tp
        ep = self.ep

        # ignore per-layer tp for dp==1; everything falls back to ``tp``
        if dp == 1:
            self.mlp_tp = None
            self.attn_tp = None
            self.moe_tp = None

        # mlp and moe tp; moe tp is 1 when experts are sharded by ep
        self.mlp_tp = self.mlp_tp or tp
        self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp)

        # world_size: ep when ep>1, otherwise the largest layer tp
        world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp)
        self.world_size = world_size
        assert (world_size >= dp and world_size % dp == 0), (f'world_size {world_size}, dp {dp}')
        assert (world_size >= ep and world_size % ep == 0), (f'world_size {world_size}, ep {ep}')
        assert (world_size >= self.mlp_tp
                and world_size % self.mlp_tp == 0), (f'world_size {world_size}, mlp_tp {self.mlp_tp}')
        assert (world_size >= self.moe_tp
                and world_size % self.moe_tp == 0), (f'world_size {world_size}, moe_tp {self.moe_tp}')

        # attn tp defaults to the ranks available per dp group; ``tp`` mirrors it
        self.attn_tp = self.attn_tp or self.world_size // dp
        self.tp = self.attn_tp
        if self.mlp_tp > 1:
            assert (self.mlp_tp >= self.attn_tp
                    and self.mlp_tp % self.attn_tp == 0), (f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}')
        if self.moe_tp > 1:
            assert (self.moe_tp >= self.attn_tp
                    and self.moe_tp % self.attn_tp == 0), (f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}')
        assert (world_size >= self.attn_tp
                and world_size % self.attn_tp == 0), (f'world_size {world_size}, attn_tp {self.attn_tp}')

        # tp mode: DP_TP when a layer's tp differs from both 1 and attn_tp
        self.mlp_tp_mode = TPMode.DEFAULT if (self.mlp_tp in [1, self.attn_tp]) else TPMode.DP_TP
        self.moe_tp_mode = TPMode.DEFAULT if (self.moe_tp in [1, self.attn_tp]) else TPMode.DP_TP

    def get_tp_by_layer(self, layer_type: str):
        """Get tp by layer type.

        Returns:
            tuple: ``(tp_size, TPMode)`` for ``'attn'``, ``'mlp'`` or
            ``'moe'``; ``(1, TPMode.DEFAULT)`` for ``None``.

        Raises:
            ValueError: for any other layer type.
        """
        if layer_type == 'attn':
            return self.attn_tp, TPMode.DEFAULT
        elif layer_type == 'mlp':
            return self.mlp_tp, self.mlp_tp_mode
        elif layer_type == 'moe':
            return self.moe_tp, self.moe_tp_mode
        elif layer_type is None:
            # for some layer that we don't need tp
            return 1, TPMode.DEFAULT
        else:
            raise ValueError(f'Unknown layer type: {layer_type}')

    @classmethod
    def from_engine_config(cls, engine_config: PytorchEngineConfig):
        """From engine config (``*_tp_size`` fields map to per-layer tp)."""
        dist_config = cls(
            dp=engine_config.dp,
            ep=engine_config.ep,
            dp_rank=engine_config.dp_rank,
            enable_microbatch=engine_config.enable_microbatch,
            enable_eplb=engine_config.enable_eplb,
            tp=engine_config.tp,
            attn_tp=engine_config.attn_tp_size,
            mlp_tp=engine_config.mlp_tp_size,
            moe_tp=engine_config.moe_tp_size,
        )
        return dist_config


def _override_hf_config_dict(hf_config: dict, key: str, hf_overrides):
    """Apply ``hf_overrides`` to ``hf_config[key]`` where ``hf_config`` is a
    dict.

    A dict override is merged recursively into an existing dict or
    ``PretrainedConfig`` value. In every other case (new key, non-dict
    override, non-mergeable value) the value is replaced outright.
    """
    from transformers import PretrainedConfig
    if key in hf_config:
        current = hf_config[key]
        if isinstance(hf_overrides, dict) and isinstance(current, (dict, PretrainedConfig)):
            for sub_key, sub_value in hf_overrides.items():
                _override_hf_config(current, sub_key, sub_value)
            return
    hf_config[key] = hf_overrides


def _overide_hf_config_cfg(hf_config: Any, key: str, hf_overrides):
    """Apply ``hf_overrides`` to attribute ``key`` of a config object.

    ``hf_config`` must provide ``update(dict)`` (as
    ``transformers.PretrainedConfig`` does). If the attribute is missing or
    None, or if either side is not mergeable, the value is replaced outright.
    Otherwise a dict override is merged recursively.
    """
    from transformers import PretrainedConfig
    if getattr(hf_config, key, None) is None:
        # nothing to merge into: set the value and stop, as the dict variant does
        hf_config.update({key: hf_overrides})
        return

    hf_config_val = getattr(hf_config, key)
    is_dict = isinstance(hf_config_val, dict)
    is_cfg = isinstance(hf_config_val, PretrainedConfig)
    if not isinstance(hf_overrides, dict) or not (is_dict or is_cfg):
        # if one of them is not a dict/config, just override
        hf_config.update({key: hf_overrides})
        return

    for sub_key, value in hf_overrides.items():
        _override_hf_config(hf_config_val, sub_key, value)


def _override_hf_config(hf_config: Any, key: str, hf_overrides):
    """Dispatch one override to the dict or config-object implementation."""
    apply_override = _override_hf_config_dict if isinstance(hf_config, dict) else _overide_hf_config_cfg
    apply_override(hf_config, key, hf_overrides)


def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
    """Apply every top-level entry of ``hf_overrides`` to ``hf_config``,
    merging nested dicts recursively."""
    for name, override in hf_overrides.items():
        _override_hf_config(hf_config, name, override)


def _default_check_env(device: str):
    pass


def _patch_quantization_config(hf_config: Any, model_format: str = None):
    """Patch quantization config."""
    if model_format is None:
        return hf_config

    # skip the quantized llm and vlm models
    if hasattr(hf_config, 'quantization_config') or \
        (hasattr(hf_config, 'llm_config') and hasattr(hf_config.llm_config, 'quantization_config')) \
            or (hasattr(hf_config, 'text_config') and hasattr(hf_config.text_config, 'quantization_config')):
        logger.warning('Can not perform weight quantization on quantized model.')
        return hf_config

    if model_format == 'fp8':
        logger.debug('Patch quantization config for fp8.')
        from lmdeploy.pytorch.envs import scale_fmt
        quantization_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt)
    else:
        raise RuntimeError(f'Unsupported weight quantization method: {model_format}')

    hf_config.quantization_config = quantization_config
    # for vlm models
    if hasattr(hf_config, 'text_config'):
        hf_config.text_config.quantization_config = quantization_config
    elif hasattr(hf_config, 'llm_config'):
        hf_config.llm_config.quantization_config = quantization_config

    return hf_config


@dataclass
class ModelConfig:
    """Config of model."""

    hidden_size: int
    num_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    bos_token_id: int
    eos_token_id: List[int]
    head_dim: int
    k_head_dim: int = None
    v_head_dim: int = None
    sliding_window: int = -1
    dtype: torch.dtype = torch.float16
    vocab_size: int = 40000
    hf_config: Any = None
    llm_config: Any = None
    cogvlm_style: bool = False
    custom_module_map: Dict[str, Any] = None

    # flash mla
    use_flash_mla: bool = False
    use_mla_fp8_cache: bool = False
    mla_index_topk: Optional[int] = None

    # dllm
    model_paradigm: str = 'ar'
    dllm_mask_token: int = 0
    dllm_block_length: int = None

    # Added for deepseekv3.2 nsa index
    # caches would be added after kv cache
    cache_shapes: List[Tuple[List[int], torch.dtype]] = field(default_factory=list)
    # added for qwen3_next
    # could used for any SSM model.
    states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)

    # check env for model-device combination
    check_env_func: Callable = _default_check_env

    # fp32 lm head
    fp32_lm_head: bool = False
    tie_word_embeddings: bool = False

    # quant config
    quant_config: 'QuantizationConfig' = None

    def get_head_size(self):
        """Get head size."""
        return self.head_dim

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        dtype: str = 'auto',
        dist_config: DistConfig = None,
        hf_overrides: Dict[str, Any] = None,
        is_draft_model: bool = False,
        spec_method: str = None,
        model_format: str = None,
        device_type: str = 'auto',
    ):
        """Instantiate one of the configuration classes of the library from a
        pretrained model configuration.

        Args:
            pretrained_model_name_or_path (str): the pretrained model path
            trust_remote_code (bool):  Whether or not to allow for custom
                models defined on the Hub in their own modeling files.
            dtype (str): user specified data type for model weights and
                activations. Refer to `PyTorchEngineConfig` for details
            dist_config (DistConfig): distributed config; its ``attn_tp`` is
                the tp size used when building the model config.
            hf_overrides (Dict[str, Any]): overrides for the HF config. The
                special key ``fp32_lm_head`` is consumed here and stored on
                the returned config instead. The caller's dict is not
                modified.
            is_draft_model (bool): forwarded to the config builder.
            spec_method (str): forwarded to the config builder.
            model_format (str): if set, a weight quantization config for this
                format is patched into the HF config (only ``'fp8'`` is
                supported).
            device_type (str): used when resolving the torch dtype.

        Returns:
            ModelConfig: the built model config.
        """
        from transformers import AutoConfig

        from lmdeploy.pytorch.transformers import config_from_pretrained
        hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        if getattr(hf_config, 'model_type', None) in ['phi3']:
            # phi3 + trust_remote_code leads to error when tp.
            hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        # update quantization config
        hf_config = _patch_quantization_config(hf_config, model_format=model_format)

        model_config = cls.from_hf_config(
            hf_config,
            pretrained_model_name_or_path,
            dtype=dtype,
            dist_config=dist_config,
            is_draft_model=is_draft_model,
            spec_method=spec_method,
            device_type=device_type,
        )
        fp32_lm_head = False
        if hf_overrides is not None:
            logger.warning(f'Overriding HF config with {hf_overrides}')
            # Shallow copy before popping: removing ``fp32_lm_head`` from the
            # caller's dict would silently drop it for any later call that
            # reuses the same overrides.
            hf_overrides = dict(hf_overrides)
            fp32_lm_head = hf_overrides.pop('fp32_lm_head', False)
            override_hf_config(model_config.hf_config, hf_overrides)

        # for fp32 head
        model_config.fp32_lm_head = fp32_lm_head
        model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)

        # for serialization of transformers modules
        maybe_register_config_serialize_by_value(trust_remote_code)

        # add quant_config
        model_config.quant_config = QuantizationConfig.from_config(hf_config)
        return model_config

    @classmethod
    def from_hf_config(
        cls,
        hf_config: Any,
        model_path: str = None,
        dtype: str = 'auto',
        dist_config: DistConfig = None,
        is_draft_model: bool = False,
        spec_method: str = None,
        device_type: str = 'auto',
    ):
        """From huggingface config.

        Builds the config with ``AutoModelConfigBuilder`` using the attention
        tp size of ``dist_config`` (a default ``DistConfig`` if None). It then
        fills ``k_head_dim`` / ``v_head_dim`` from ``head_dim`` when unset,
        asserts the head counts are compatible with tp, resolves the torch
        dtype, and turns an int ``eos_token_id`` into a list.
        """
        from lmdeploy.pytorch.configurations import AutoModelConfigBuilder
        if dist_config is None:
            dist_config = DistConfig()
        tp = dist_config.attn_tp

        model_config = AutoModelConfigBuilder.build(hf_config,
                                                    model_path,
                                                    tp=tp,
                                                    is_draft_model=is_draft_model,
                                                    spec_method=spec_method)

        if model_config.k_head_dim is None:
            assert model_config.head_dim is not None
            model_config.k_head_dim = model_config.head_dim
        if model_config.v_head_dim is None:
            assert model_config.head_dim is not None
            model_config.v_head_dim = model_config.head_dim

        # check for tp
        assert model_config.num_attention_heads % tp == 0
        if model_config.num_key_value_heads >= tp:
            assert model_config.num_key_value_heads % tp == 0
        else:
            assert tp % model_config.num_key_value_heads == 0

        # should after setting `hf_config` and `model_arch` attributes
        model_config = _update_torch_dtype(model_config, dtype, device_type=device_type)

        # update eos_token_id to list
        if isinstance(model_config.eos_token_id, int):
            model_config.eos_token_id = [model_config.eos_token_id]

        return model_config


class UnmaskingStrategy(enum.Enum):
    """Strategy used to unmask tokens (see the ``dllm`` settings)."""

    # left-to-right unmasking
    SEQUENTIAL = enum.auto()
    # unmasking driven by a confidence threshold
    LOW_CONFIDENCE_DYNAMIC = enum.auto()
    # unmasking the top-k tokens of a block
    LOW_CONFIDENCE_STATIC = enum.auto()

    @classmethod
    def from_str(cls, strategy: str):
        """Parse a case-insensitive strategy name.

        Raises:
            ValueError: if the name matches no strategy.
        """
        name = strategy.lower()
        by_name = {
            'sequential': cls.SEQUENTIAL,
            'low_confidence_dynamic': cls.LOW_CONFIDENCE_DYNAMIC,
            'low_confidence_static': cls.LOW_CONFIDENCE_STATIC,
        }
        if name not in by_name:
            raise ValueError(f'Unknown unmasking strategy: {name}')
        return by_name[name]


@dataclass
class DLLMConfig:
    """Settings for ``model_paradigm`` dllm decoding.

    ``confidence_threshold`` presumably applies to
    ``UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC`` (confirm against the
    sampler); ``denoising_steps=None`` leaves the step count to the consumer.
    """
    block_length: int = 1
    unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC
    denoising_steps: int = None
    confidence_threshold: float = 0.85


@dataclass
class MiscConfig:
    """Miscellaneous engine options collected from ``PytorchEngineConfig``."""
    prefill_interval: int = 16
    custom_module_map: str = None
    empty_init: bool = False
    model_format: str = None
    hf_overrides: Dict[str, Any] = None
    disable_vision_encoder: bool = False
    logprobs_mode: str = None
    dllm_config: DLLMConfig = None
    enable_return_routed_experts: bool = False
    enable_chunked_prefill: bool = False

    @classmethod
    def from_engine_config(cls, engine_config: PytorchEngineConfig):
        """Build from a ``PytorchEngineConfig``.

        ``enable_chunked_prefill`` is always ``False`` here; it is not read
        from ``engine_config``.
        """
        strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy)
        dllm_settings = DLLMConfig(
            block_length=engine_config.dllm_block_length,
            unmasking_strategy=strategy,
            denoising_steps=engine_config.dllm_denoising_steps,
            confidence_threshold=engine_config.dllm_confidence_threshold,
        )
        return cls(
            custom_module_map=engine_config.custom_module_map,
            empty_init=engine_config.empty_init,
            prefill_interval=engine_config.prefill_interval,
            model_format=engine_config.model_format,
            hf_overrides=engine_config.hf_overrides,
            disable_vision_encoder=engine_config.disable_vision_encoder,
            logprobs_mode=engine_config.logprobs_mode,
            dllm_config=dllm_settings,
            enable_return_routed_experts=engine_config.enable_return_routed_experts,
            enable_chunked_prefill=False,
        )


@dataclass
class SpecDecodeConfig:
    """Configuration of the draft model used for speculative decoding."""
    model: str
    method: str
    cache_config: CacheConfig = None
    num_speculative_tokens: int = 1
    model_config: ModelConfig = None

    @classmethod
    def from_config(
        cls,
        method: str,
        num_speculative_tokens: int,
        model: str,
        target_cache_cfg: CacheConfig,
        target_model: str = None,
        dtype: str = 'auto',
    ):
        """Create the speculative-decoding config for ``method``.

        Args:
            method: speculative decoding method name.
            num_speculative_tokens: number of draft tokens proposed per step.
            model: draft model path; when empty, ``target_model`` is used instead.
            target_cache_cfg: cache config of the target model; its settings are copied for the draft.
            target_model: target model path used as a fallback draft model.
            dtype: dtype passed to ``ModelConfig.from_pretrained``.

        Returns:
            SpecDecodeConfig: ``cache_config`` is None for methods that need no cache (medusa).
        """
        draft_model = model or target_model
        draft_model_config = ModelConfig.from_pretrained(draft_model,
                                                         trust_remote_code=True,
                                                         dtype=dtype,
                                                         is_draft_model=True,
                                                         spec_method=method)

        # methods listed here get no cache config of their own
        no_caches = ('medusa', )
        draft_cache_config = None
        if method not in no_caches:
            # the draft cache mirrors these settings of the target cache
            copied_fields = (
                'max_batches',
                'block_size',
                'num_cpu_blocks',
                'num_gpu_blocks',
                'cache_max_entry_count',
                'max_prefill_token_num',
                'device_type',
                'migration_backend',
            )
            draft_cache_config = CacheConfig(**{name: getattr(target_cache_cfg, name) for name in copied_fields})

        return cls(
            model=draft_model,
            method=method,
            cache_config=draft_cache_config,
            model_config=draft_model_config,
            num_speculative_tokens=num_speculative_tokens,
        )


@dataclass
class QuantizationConfig:
    quant_method: str = None
    quant_dtype: torch.dtype = None
    scale_fmt: str = None
    bits: int = None
    group_size: int = None
    weight_block_size: Tuple[int] = None
    activation_scheme: str = None
    ignored_layers: List[str] = field(default_factory=list)
    hf_quant_config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_config(cls, hf_config: Any):
        quant_config = getattr(hf_config, 'quantization_config', None)

        if quant_config is None:
            if hasattr(hf_config, 'llm_config') and hasattr(hf_config.llm_config, 'quantization_config'):
                quant_config = hf_config.llm_config.quantization_config
            elif hasattr(hf_config, 'text_config') and hasattr(hf_config.text_config, 'quantization_config'):
                quant_config = hf_config.text_config.quantization_config

        # no quant config found in hf config
        if quant_config is None:
            return cls()

        quant_method = quant_config['quant_method']
        quant_dtype = quant_config.get('quant_dtype', None)
        scale_fmt = quant_config.get('scale_fmt', None)
        weight_block_size = quant_config.get('weight_block_size', None)
        activation_scheme = quant_config.get('activation_scheme', None)

        bits = None
        group_size = None

        if quant_method == 'awq':
            bits = quant_config.get('bits', 4)
            group_size = quant_config.get('group_size', 128)
        elif quant_method == 'smooth_quant':
            if quant_dtype is None:
                quant_dtype = 'int8'
        elif quant_method == 'fp8':
            fmt = quant_config.get('fmt', 'e4m3')
            if fmt == 'e4m3':
                quant_dtype = 'float8_e4m3fn'
            elif fmt == 'e5m2':
                quant_dtype = 'float8_e5m2'
            else:
                raise TypeError(f'Unsupported fp8 fmt: {fmt}')
        else:
            raise TypeError(f'Unsupported quant method: {quant_method}')

        if quant_dtype is not None:
            quant_dtype = eval(f'torch.{quant_dtype}')

        ignored_layers = quant_config.get('ignored_layers', [])
        if not ignored_layers:
            ignored_layers = quant_config.get('modules_to_not_convert', [])

        return cls(
            quant_method=quant_method,
            quant_dtype=quant_dtype,
            scale_fmt=scale_fmt,
            bits=bits,
            group_size=group_size,
            weight_block_size=weight_block_size,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            hf_quant_config=quant_config,
        )

    def get_quant_method(self, prefix: str = ''):
        """Get quant method for module."""
        if not prefix or not self.ignored_layers:
            return self.quant_method

        is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers])
        quant_method = None if is_ignore else self.quant_method
        return quant_method

    def get(self, key, default=None):
        """Get extra key from hf quant config."""
        return self.hf_quant_config.get(key, default)


================================================
FILE: lmdeploy/pytorch/configurations/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import pkgutil

from .builder import AutoModelConfigBuilder

__all__ = []

# Import every submodule of this package. Each builder class registers itself with
# AutoModelConfigBuilder when its module is imported.
for _module_info in pkgutil.walk_packages(__path__):
    _submodule_name = _module_info.name
    __all__.append(_submodule_name)
    globals()[_submodule_name] = importlib.import_module(f'{__name__}.{_submodule_name}')

__all__ += ['AutoModelConfigBuilder']


================================================
FILE: lmdeploy/pytorch/configurations/builder.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class AutoModelConfigBuilder(ABC):
    """Registry and dispatcher for model-specific ``ModelConfig`` builders.

    Every subclass, direct or indirect, registers itself when it is defined.
    ``AutoModelConfigBuilder.build`` uses the first registered builder whose ``condition`` accepts
    the HF config and falls back to ``DefaultModelConfigBuilder``.
    """

    # registered builder classes, in registration order
    _sub_classes = []

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        AutoModelConfigBuilder.register_builder(cls)

    @classmethod
    def register_builder(cls, sub_cls):
        """Add ``sub_cls`` to the shared registry (idempotent)."""
        if sub_cls not in AutoModelConfigBuilder._sub_classes:
            AutoModelConfigBuilder._sub_classes.append(sub_cls)

    @classmethod
    def condition(cls, hf_config):
        """Return whether this builder handles ``hf_config``; subclasses must override."""
        raise NotImplementedError(f'`condition` of {cls.__name__} not implemented.')

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Dispatch to the matching builder and return its ``ModelConfig``.

        Only callable on ``AutoModelConfigBuilder`` itself. ``hf_config`` and ``llm_config`` of the
        result default to ``hf_config`` when the chosen builder leaves them unset.

        Raises:
            NotImplementedError: called on a subclass that does not override ``build``.
        """
        from .default import DefaultModelConfigBuilder

        if cls is not AutoModelConfigBuilder:
            raise NotImplementedError(f'`build` of {cls.__name__} not implemented.')

        valid_builder = DefaultModelConfigBuilder
        for builder in cls._sub_classes:
            # the default builder accepts everything, so it is only the fallback
            if builder is DefaultModelConfigBuilder:
                continue

            if builder.condition(hf_config):
                valid_builder = builder
                break

        logger.debug(f'build model config with {valid_builder.__name__}')

        cfg = valid_builder.build(hf_config, model_path, **kwargs)
        if cfg.hf_config is None:
            cfg.hf_config = hf_config
        if cfg.llm_config is None:
            cfg.llm_config = hf_config

        return cfg

    @classmethod
    def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads):
        """Adjust the KV head count for tensor parallelism and write it back to ``hf_config``.

        When ``tp`` exceeds ``num_key_value_heads``, every KV head is replicated
        ``tp // num_key_value_heads`` times (stored as ``hf_config.num_replicate_key_value_heads``)
        and the effective count becomes ``tp``.

        Returns:
            int: the effective number of KV heads, also stored in ``hf_config.num_key_value_heads``.

        Raises:
            ValueError: ``tp`` is larger than, but not a multiple of, ``num_key_value_heads``.
        """
        if tp > 1 and tp > num_key_value_heads:
            # explicit check instead of assert so it survives `python -O`
            if tp % num_key_value_heads != 0:
                raise ValueError(f'tp ({tp}) must be a multiple of num_key_value_heads ({num_key_value_heads}).')
            hf_config.num_replicate_key_value_heads = tp // num_key_value_heads
            num_key_value_heads = tp

        hf_config.num_key_value_heads = num_key_value_heads
        return num_key_value_heads


================================================
FILE: lmdeploy/pytorch/configurations/chatglm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class ChatGLMModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ChatGLM-family checkpoints (``model_type == 'chatglm'``)."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose model_type is ``chatglm``."""
        return hf_config.model_type == 'chatglm'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build a ModelConfig from ChatGLM-specific field names.

        ChatGLM configs use ``num_layers``, ``padded_vocab_size`` and ``multi_query_group_num``
        instead of the usual HF names, so the default builder is not reused.
        """
        hidden_size = hf_config.hidden_size
        num_heads = hf_config.num_attention_heads
        head_dim = hidden_size // num_heads

        # fall back to pad_token_id when bos_token_id is missing or None
        bos_token_id = getattr(hf_config, 'bos_token_id', None)
        if bos_token_id is None:
            bos_token_id = hf_config.pad_token_id

        # multi-query configs declare their KV head count in multi_query_group_num
        kv_heads = hf_config.multi_query_group_num if hf_config.multi_query_attention else num_heads
        # replicate KV heads when tp exceeds their number
        kv_heads = cls.update_num_kv_heads(hf_config, kwargs.get('tp', 1), kv_heads)

        cfg = ModelConfig(hidden_size=hidden_size,
                          num_layers=hf_config.num_layers,
                          num_attention_heads=num_heads,
                          num_key_value_heads=kv_heads,
                          bos_token_id=bos_token_id,
                          eos_token_id=hf_config.eos_token_id,
                          head_dim=head_dim,
                          vocab_size=hf_config.padded_vocab_size)

        # glm-4v: a vision_config marks the model as cogvlm-style
        if hasattr(hf_config, 'vision_config'):
            cfg.cogvlm_style = True
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/cogvlm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class CogVLMModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ``CogVLMForCausalLM`` checkpoints."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose first architecture is ``CogVLMForCausalLM``."""
        archs = hf_config.architectures
        return bool(archs) and archs[0] == 'CogVLMForCausalLM'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build via the default builder, flag cogvlm-style and pin ``torch_dtype`` to a half type."""
        from lmdeploy.utils import is_bf16_supported

        # the KV head count comes from num_multi_query_heads; fall back to full MHA when it is unset/0
        hf_config.num_key_value_heads = (getattr(hf_config, 'num_multi_query_heads', None)
                                         or hf_config.num_attention_heads)

        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        cfg.cogvlm_style = True
        # bfloat16 when the platform supports it, otherwise float16
        hf_config.torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/deepseek_v2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder
from .utils import flash_mla_available


class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for DeepSeek-V2/V3 and Kimi-K2 (MLA attention), including the MTP draft model."""

    @classmethod
    def condition(cls, hf_config):
        """Match deepseek_v2, deepseek_v3 and kimi_k2 configs."""
        return hf_config.model_type in ('deepseek_v3', 'deepseek_v2', 'kimi_k2')

    @classmethod
    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):
        """Build the ModelConfig.

        Args:
            hf_config: HF config. It is mutated in place: ``use_flash_mla`` and the kv-head fields
                always, and ``architectures``/``auto_map`` for draft models.
            model_path: not used by this builder.
            is_draft_model: build the MTP draft model config.
            spec_method: speculative method; only ``'deepseek_mtp'`` is accepted.
            **kwargs: ``tp`` (tensor-parallel size, default 1) is read.
        """
        # head width = compressed latent (kv_lora_rank) + rope part; v_head_dim is 0 below
        k_head_dim = hf_config.kv_lora_rank + hf_config.qk_rope_head_dim
        num_heads = hf_config.num_attention_heads
        # multi query attn: a single KV head before tp replication
        num_kv_heads = cls.update_num_kv_heads(hf_config, kwargs.get('tp', 1), 1)
        hf_config.use_flash_mla = flash_mla_available()
        num_layers = hf_config.num_hidden_layers

        if spec_method is not None:
            assert spec_method == 'deepseek_mtp'

        if is_draft_model:
            # the draft model only contains the next-n prediction layers
            num_layers = hf_config.num_nextn_predict_layers
            hf_config.architectures[0] = 'DeepseekMTPModel'
            # remove for correct mapping when building the patched model
            if hasattr(hf_config, 'auto_map'):
                del hf_config.auto_map

        speculative = is_draft_model or spec_method is not None
        return ModelConfig(
            hidden_size=hf_config.hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_heads,
            num_key_value_heads=num_kv_heads,
            bos_token_id=getattr(hf_config, 'bos_token_id', None),
            eos_token_id=hf_config.eos_token_id,
            head_dim=k_head_dim,
            k_head_dim=k_head_dim,
            v_head_dim=0,
            vocab_size=hf_config.vocab_size,
            use_flash_mla=hf_config.use_flash_mla,
            model_paradigm='ar_spec' if speculative else 'ar',
        )


================================================
FILE: lmdeploy/pytorch/configurations/deepseek_v32.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .deepseek_v2 import DeepseekV2ModelConfigBuilder


def _check_env_v32(device: str = 'cuda'):
    """Environment check."""
    if device != 'cuda':
        return

    # check cuda
    try:
        import fast_hadamard_transform  # noqa: F401
    except ImportError:
        raise ImportError('Deepseek V3.2 requires .')

    try:
        import flash_mla  # noqa: F401
    except ImportError:
        raise ImportError('Deepseek V3.2 requires .')

    if not hasattr(flash_mla, 'flash_mla_sparse_fwd'):
        raise RuntimeError('Latest flash_mla is required: https://github.com/deepseek-ai/FlashMLA.')


class DeepseekV32ModelConfigBuilder(DeepseekV2ModelConfigBuilder):
    """Config builder for DeepSeek-V3.2 and glm_moe_dsa (MLA plus an index-key cache)."""

    @classmethod
    def condition(cls, hf_config):
        """Match deepseek_v32 and glm_moe_dsa configs."""
        return hf_config.model_type in ('deepseek_v32', 'glm_moe_dsa')

    @classmethod
    def build(cls, hf_config, model_path: str | None = None, **kwargs):
        """Extend the DeepSeek-V2 config with the extra index-key cache and an env check."""
        config = DeepseekV2ModelConfigBuilder.build(hf_config, model_path=model_path, **kwargs)

        # use_flash_mla was set by the DeepSeek-V2 builder above
        assert hf_config.use_flash_mla, 'DeepSeek-V3.2 requires flash_mla to be available.'

        config.cache_shapes = [
            ([hf_config.index_head_dim], torch.float8_e4m3fn),  # index key
            ([1], torch.float32),  # index key scale
        ]
        config.use_mla_fp8_cache = True
        config.mla_index_topk = hf_config.index_topk
        config.check_env_func = _check_env_v32
        return config


================================================
FILE: lmdeploy/pytorch/configurations/deepseek_vl2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class DeepseekVLV2ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for DeepSeek-VL2 (``model_type == 'deepseek_vl_v2'``)."""

    @classmethod
    def condition(cls, hf_config):
        """Match deepseek_vl_v2 configs."""
        return hf_config.model_type == 'deepseek_vl_v2'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build deepseek-vl2 from its language_config, keeping the full config as ``hf_config``."""
        language_config = hf_config.language_config

        if language_config.use_mla:
            from .deepseek_v2 import DeepseekV2ModelConfigBuilder
            builder = DeepseekV2ModelConfigBuilder
        else:
            # deepseek-vl2-tiny uses MHA rather than MLA, so the default builder applies
            builder = DefaultModelConfigBuilder

        cfg = builder.build(language_config, model_path, **kwargs)
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/default.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class DefaultModelConfigBuilder(AutoModelConfigBuilder):
    """Fallback builder for configs that follow the standard HF decoder field names."""

    @classmethod
    def condition(cls, hf_config):
        """Accept every config."""
        return True

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build a ModelConfig from standard HF fields.

        Side effects on ``hf_config``: ``head_dim`` is always written back, and the kv-head
        fields are updated for ``tp`` (read from ``kwargs``, default 1).
        """
        # use an explicit (truthy) head_dim, otherwise derive it from hidden_size
        head_dim = getattr(hf_config, 'head_dim', None) or hf_config.hidden_size // hf_config.num_attention_heads
        # head_dim should not be None
        hf_config.head_dim = head_dim

        num_heads = hf_config.num_attention_heads
        num_kv_heads = getattr(hf_config, 'num_key_value_heads', num_heads)

        # -1 disables the sliding window; a missing/falsy window also maps to -1
        sliding_window = -1
        if getattr(hf_config, 'use_sliding_window', True):
            sliding_window = getattr(hf_config, 'sliding_window', -1) or -1

        num_kv_heads = cls.update_num_kv_heads(hf_config, kwargs.get('tp', 1), num_kv_heads)

        return ModelConfig(
            hidden_size=hf_config.hidden_size,
            num_layers=hf_config.num_hidden_layers,
            num_attention_heads=num_heads,
            num_key_value_heads=num_kv_heads,
            bos_token_id=hf_config.bos_token_id,
            eos_token_id=hf_config.eos_token_id,
            sliding_window=sliding_window,
            head_dim=head_dim,
            k_head_dim=head_dim,
            v_head_dim=head_dim,
            vocab_size=hf_config.vocab_size,
            llm_config=hf_config,
        )


================================================
FILE: lmdeploy/pytorch/configurations/gemma.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class GemmaModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for text-only Gemma models (gemma, gemma2, gemma3_text)."""

    @classmethod
    def condition(cls, hf_config):
        """Match gemma, gemma2 and gemma3_text configs."""
        return hf_config.model_type in ('gemma', 'gemma2', 'gemma3_text')

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build gemma via the default builder, taking head_dim from the HF config."""
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        cfg.head_dim = hf_config.head_dim
        return cfg


class GemmaVLModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ``Gemma3ForConditionalGeneration`` (vision-language Gemma 3)."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose first architecture is ``Gemma3ForConditionalGeneration``."""
        archs = hf_config.architectures
        return bool(archs) and archs[0] == 'Gemma3ForConditionalGeneration'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build from text_config, disable the global sliding window and keep the full config."""
        text_config = hf_config.text_config
        text_config.architectures = ['Gemma3ForCausalLM']

        cfg = DefaultModelConfigBuilder.build(text_config, model_path, **kwargs)
        # gemma 3 does not enable sliding window on every layer
        cfg.sliding_window = -1
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/glm4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .deepseek_v2 import DeepseekV2ModelConfigBuilder
from .default import DefaultModelConfigBuilder


class Glm4MoeLiteModelConfigBuilder(DeepseekV2ModelConfigBuilder):
    """Config builder for glm4_moe_lite, which reuses the DeepSeek-V2 (MLA) builder."""

    @classmethod
    def condition(cls, hf_config):
        """Match glm4_moe_lite configs."""
        return hf_config.model_type == 'glm4_moe_lite'

    @classmethod
    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):
        """Fill in attributes the DeepSeek-V2 path expects, then delegate to it."""
        defaults = (('scoring_func', 'sigmoid'), ('moe_layer_freq', 1))
        for attr_name, attr_default in defaults:
            if not hasattr(hf_config, attr_name):
                setattr(hf_config, attr_name, attr_default)

        return super().build(hf_config,
                             model_path=model_path,
                             is_draft_model=is_draft_model,
                             spec_method=spec_method,
                             **kwargs)


class Glm4MoeModelConfigBuilder(DefaultModelConfigBuilder):
    """Config builder for glm4_moe, including its MTP draft model."""

    @classmethod
    def condition(cls, hf_config):
        """Match glm4_moe configs."""
        return hf_config.model_type == 'glm4_moe'

    @classmethod
    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):
        """Build via the default builder, then set layer count and paradigm for speculative decoding.

        Draft models mutate ``hf_config``: ``architectures[0]`` is rewritten and ``auto_map`` removed.
        """
        num_layers = hf_config.num_hidden_layers

        if spec_method is not None:
            # only the deepseek_mtp speculative method is accepted
            assert spec_method == 'deepseek_mtp'

        if is_draft_model:
            # the draft model only contains the next-n prediction layers
            num_layers = hf_config.num_nextn_predict_layers
            hf_config.architectures[0] = 'Glm4MoeMTPModel'
            # remove for correct mapping when building the patched model
            if hasattr(hf_config, 'auto_map'):
                del hf_config.auto_map

        speculative = is_draft_model or spec_method is not None
        cfg = super().build(hf_config,
                            model_path=model_path,
                            is_draft_model=is_draft_model,
                            spec_method=spec_method,
                            **kwargs)
        cfg.model_paradigm = 'ar_spec' if speculative else 'ar'
        cfg.num_layers = num_layers
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/gpt_oss.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class GptOSSModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for gpt-oss (``model_type == 'gpt_oss'``)."""

    @classmethod
    def condition(cls, hf_config):
        """Match gpt_oss configs."""
        return hf_config.model_type == 'gpt_oss'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build gpt-oss via the default builder with the model-wide sliding window disabled."""
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        # gpt_oss does not enable sliding window on every layer
        cfg.sliding_window = -1
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/interns1_pro.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class InterS1ProModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for interns1_pro / interns1_1 (built from their text_config)."""

    @classmethod
    def condition(cls, hf_config):
        """Match interns1_pro and interns1_1 configs."""
        return hf_config.model_type in ('interns1_pro', 'interns1_1')

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build from text_config and copy its dtype up to the top-level config."""
        text_config = hf_config.text_config

        # share the top-level quantization_config with text_config unless it already has one
        if hasattr(hf_config, 'quantization_config') and not hasattr(text_config, 'quantization_config'):
            text_config.quantization_config = hf_config.quantization_config

        cfg = DefaultModelConfigBuilder.build(text_config, model_path, **kwargs)
        hf_config.dtype = text_config.dtype
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/internvl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class InternVLModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ``InternVLChatModel`` checkpoints (built from their llm_config)."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose first architecture is ``InternVLChatModel``.

        Configs without ``architectures`` (None or empty) are rejected instead of raising. This
        lets the builder dispatch move on to the next candidate.
        """
        architectures = getattr(hf_config, 'architectures', None)
        return bool(architectures) and architectures[0] == 'InternVLChatModel'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build InternVL from llm_config, keeping the full config as ``hf_config``."""
        cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path, **kwargs)
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/internvl3_hf.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class InternVL3ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for HF-format InternVL3 / InternS1 checkpoints (built from their text_config)."""

    @classmethod
    def condition(cls, hf_config):
        """Match InternVL/InternS1 ``...ForConditionalGeneration`` architectures.

        Configs without ``architectures`` (None or empty) are rejected instead of raising. This
        lets the builder dispatch move on to the next candidate.
        """
        architectures = getattr(hf_config, 'architectures', None)
        return bool(architectures) and architectures[0] in [
            'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration'
        ]

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build config."""
        # hack quantization_config
        if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.text_config, 'quantization_config'):
            setattr(hf_config.text_config, 'quantization_config', hf_config.quantization_config)

        # fix transformers>5
        if hasattr(hf_config.text_config, 'tie_word_embeddings'):
            hf_config.tie_word_embeddings = hf_config.text_config.tie_word_embeddings

        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/llama.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class LlamaModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ``LlamaForCausalLM``, including speculative-decoding draft/target setups."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose first architecture is ``LlamaForCausalLM``.

        Configs without ``architectures`` (None or empty) are rejected instead of raising. This
        lets the builder dispatch move on to the next candidate.
        """
        architectures = getattr(hf_config, 'architectures', None)
        return bool(architectures) and architectures[0] in ['LlamaForCausalLM']

    @classmethod
    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):
        """Build llama.

        Draft models get their architecture prefixed with the capitalized ``spec_method`` and use
        ``draft_vocab_size`` when present. For an eagle3 target model, ``aux_hidden_state_layers``
        is set on ``hf_config``.
        """
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)

        if is_draft_model:
            # update draft model arch
            assert spec_method is not None
            hf_config.architectures[0] = spec_method.capitalize() + hf_config.architectures[0]
            cfg.vocab_size = getattr(hf_config, 'draft_vocab_size', hf_config.vocab_size)
            cfg.model_paradigm = 'ar_spec'
        elif spec_method is not None:
            # add aux_hidden_state_layers for eagle3
            if spec_method == 'eagle3':
                num_layers = cfg.num_layers
                hf_config.aux_hidden_state_layers = (2, num_layers // 2, num_layers - 3)
            cfg.model_paradigm = 'ar_spec'
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/llama4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class Llama4ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for llama4 (built from its text_config)."""

    @classmethod
    def condition(cls, hf_config):
        """Match llama4 configs."""
        return hf_config.model_type == 'llama4'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build llama4 from text_config, keeping the full config as ``hf_config``."""
        text_config = hf_config.text_config
        cfg = DefaultModelConfigBuilder.build(text_config, model_path, **kwargs)
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/llava_hf.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class LlavaHfModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for HF-format LLaVA / LLaVA-NeXT checkpoints."""

    @classmethod
    def condition(cls, hf_config):
        """Match ``LlavaForConditionalGeneration`` and ``LlavaNextForConditionalGeneration``.

        Configs without ``architectures`` (None or empty) are rejected instead of raising. This
        lets the builder dispatch move on to the next candidate.
        """
        architectures = getattr(hf_config, 'architectures', None)
        return bool(architectures) and architectures[0] in [
            'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration'
        ]

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build llava hf from text_config, falling back to hard-coded defaults for missing fields."""
        text_config = hf_config.text_config
        hidden_size = getattr(text_config, 'hidden_size', 4096)
        num_attention_heads = getattr(text_config, 'num_attention_heads', 32)
        num_key_value_heads = getattr(text_config, 'num_key_value_heads', 32)
        num_hidden_layers = getattr(text_config, 'num_hidden_layers', 32)
        bos_token_id = getattr(text_config, 'bos_token_id', 1)
        eos_token_id = getattr(text_config, 'eos_token_id', 2)
        head_dim = hidden_size // num_attention_heads

        return ModelConfig(
            hidden_size=hidden_size,
            num_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            head_dim=head_dim,
            vocab_size=text_config.vocab_size,
            hf_config=hf_config,
        )


================================================
FILE: lmdeploy/pytorch/configurations/minicpm3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for ``MiniCPM3ForCausalLM`` checkpoints."""

    @classmethod
    def condition(cls, hf_config):
        """Match configs whose first architecture is ``MiniCPM3ForCausalLM``.

        Configs without ``architectures`` (None or empty) are rejected instead of raising. This
        lets the builder dispatch move on to the next candidate.
        """
        architectures = getattr(hf_config, 'architectures', None)
        return bool(architectures) and architectures[0] in ['MiniCPM3ForCausalLM']

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build via the default builder with head dims set to qk_nope_head_dim + qk_rope_head_dim."""
        head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)

        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        cfg.head_dim = head_dim
        cfg.k_head_dim = head_dim
        cfg.v_head_dim = head_dim

        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/qwen.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class QwenModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for first-generation Qwen (``model_type == 'qwen'``)."""

    @classmethod
    def condition(cls, hf_config):
        """Match qwen configs."""
        return hf_config.model_type == 'qwen'

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build via the default builder, fill missing special-token ids and set ``torch_dtype``.

        ``torch_dtype`` is written onto ``hf_config``. It is ``bfloat16`` when the config requests
        bf16 and the platform supports it. Otherwise it is ``float16`` when the config requests fp16,
        and failing both, the platform default (bf16 if supported, else fp16).
        """
        from lmdeploy.utils import is_bf16_supported
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)

        # default special-token ids (presumably Qwen's <|im_start|>/<|im_end|> — confirm)
        if cfg.bos_token_id is None:
            cfg.bos_token_id = 151644
        if cfg.eos_token_id is None:
            cfg.eos_token_id = 151645

        bf16_ok = is_bf16_supported()
        if hf_config.bf16 and bf16_ok:
            torch_dtype = 'bfloat16'
        elif hf_config.fp16:
            torch_dtype = 'float16'
        else:
            torch_dtype = 'bfloat16' if bf16_ok else 'float16'
        hf_config.torch_dtype = torch_dtype
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/qwen3_5.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.utils import is_bf16_supported

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder
from .qwen3_next import _check_env_qwen3_next


class Qwen3_5ModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for Qwen3.5 checkpoints (``qwen3_5`` / ``qwen3_5_moe``)."""

    @classmethod
    def condition(cls, hf_config):
        """Return True for the Qwen3.5 dense and MoE model types."""
        return hf_config.model_type in ('qwen3_5', 'qwen3_5_moe')

    @classmethod
    def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
        """Build the model config from ``hf_config.text_config``.

        ``num_layers`` is reduced to the number of layers that are not
        ``'linear_attention'``. For the linear-attention layers, a conv state
        and a recurrent state shape (per tensor-parallel rank) are recorded in
        ``cfg.states_shapes``.
        """
        text_config = hf_config.text_config
        # The top-level config carries quantization settings; copy them onto
        # text_config (unless it already has its own) for the default builder.
        quant_cfg = getattr(hf_config, 'quantization_config', None)
        if quant_cfg is not None and not hasattr(text_config, 'quantization_config'):
            text_config.quantization_config = quant_cfg
        cfg = DefaultModelConfigBuilder.build(text_config, model_path, tp=tp, **kwargs)

        total_layers = cfg.num_layers
        num_linear = sum(layer_type == 'linear_attention' for layer_type in text_config.layer_types)
        cfg.num_layers = total_layers - num_linear

        # heads are split evenly across tensor-parallel ranks
        k_head_dim = text_config.linear_key_head_dim
        v_head_dim = text_config.linear_value_head_dim
        v_heads = text_config.linear_num_value_heads // tp
        k_heads = text_config.linear_num_key_heads // tp
        conv_channels = 2 * (k_head_dim * k_heads) + v_head_dim * v_heads
        kernel = text_config.linear_conv_kernel_dim

        conv_shape = (num_linear, conv_channels, kernel)
        recurrent_shape = (num_linear, v_heads, k_head_dim, v_head_dim)
        state_dtype = torch.bfloat16 if is_bf16_supported() else torch.float16
        cfg.states_shapes = [(conv_shape, state_dtype), (recurrent_shape, state_dtype)]
        cfg.check_env_func = _check_env_qwen3_next
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/qwen3_next.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


def _check_env_qwen3_next(device: str):
    """Check env for qwen3 next."""
    if device != 'cuda':
        return

    try:
        import fla  # noqa: F401
    except ImportError:
        raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.')


class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for Qwen3-Next (``model_type == 'qwen3_next'``)."""

    @classmethod
    def condition(cls, hf_config):
        """config."""
        return hf_config.model_type == 'qwen3_next'

    @classmethod
    def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
        """build.

        ``num_layers`` is reduced to ``num_layers // full_attention_interval``
        full-attention layers. Every remaining layer is counted as a linear
        (delta) layer, whose conv/recurrent state shapes per tensor-parallel
        rank are stored in ``cfg.states_shapes``.
        """
        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)

        # update num layers
        num_layers = cfg.num_layers
        num_full_layers = num_layers // hf_config.full_attention_interval
        # Derive the linear-layer count from the total rather than from
        # ``num_full_layers * (interval - 1)``: the latter silently drops the
        # trailing layers when num_layers is not a multiple of the interval.
        num_delta_layers = num_layers - num_full_layers
        cfg.num_layers = num_full_layers

        # set state shapes (heads are split across tp ranks)
        head_k_dim = hf_config.linear_key_head_dim
        head_v_dim = hf_config.linear_value_head_dim
        num_v_heads = hf_config.linear_num_value_heads // tp
        num_k_heads = hf_config.linear_num_key_heads // tp
        key_dim = head_k_dim * num_k_heads
        value_dim = head_v_dim * num_v_heads
        conv_dim = key_dim * 2 + value_dim
        conv_kernel_size = hf_config.linear_conv_kernel_dim

        conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
        recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
        dtype = torch.bfloat16
        cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
        cfg.check_env_func = _check_env_qwen3_next
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/qwen3_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


class Qwen3VLModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder shared by the Qwen2-VL, Qwen2.5-VL and Qwen3-VL families."""

    @classmethod
    def condition(cls, hf_config):
        """Return True for the supported Qwen VL model types."""
        supported = ['qwen2_vl', 'qwen2_5_vl', 'qwen3_vl', 'qwen3_vl_moe']
        return hf_config.model_type in supported

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build the LLM config from the language-model part of ``hf_config``.

        Configs without a ``text_config`` attribute (older transformers
        layouts) are passed to the default builder unchanged. Otherwise the
        default builder runs on ``text_config``, and the full multimodal
        config is attached as ``cfg.hf_config``.
        """
        if not hasattr(hf_config, 'text_config'):
            # flat config layout: nothing to unwrap
            return DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)

        text_config = hf_config.text_config
        needs_quant = hasattr(hf_config, 'quantization_config') and not hasattr(text_config, 'quantization_config')
        if needs_quant:
            # expose the top-level quantization settings to the default builder
            text_config.quantization_config = hf_config.quantization_config
        cfg = DefaultModelConfigBuilder.build(text_config, model_path, **kwargs)
        # mirror the language model's dtype onto the outer config
        hf_config.dtype = text_config.dtype
        cfg.hf_config = hf_config
        return cfg


================================================
FILE: lmdeploy/pytorch/configurations/sdar.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .default import AutoModelConfigBuilder, DefaultModelConfigBuilder


class SDARModelConfigBuilder(AutoModelConfigBuilder):
    """Config builder for SDAR checkpoints (``sdar`` / ``sdar_moe``)."""

    @classmethod
    def condition(cls, hf_config):
        """Return True for the SDAR dense and MoE model types."""
        return hf_config.model_type in ('sdar', 'sdar_moe')

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build the default config and switch it to the ``'dllm'`` paradigm."""
        model_cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
        model_cfg.model_paradigm = 'dllm'
        # token id used as the dllm mask token
        # NOTE(review): presumably the mask token of the SDAR vocab — confirm
        model_cfg.dllm_mask_token = 151669
        return model_cfg


================================================
FILE: lmdeploy/pytorch/configurations/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def flash_mla_available():
    """Check if flash mla is available.

    FlashMLA is only considered when device 0 reports an integer
    compute-capability ``major`` >= 9. On such a device, this returns True if
    ``flash_mla`` can be imported, and logs an install hint if it cannot.

    Returns:
        bool: True if FlashMLA should be used by default.
    """
    use_flash_mla = False
    try:
        # In some torch_npu versions device_properties has no 'major'
        # attribute; in others 'major' is None. Read it defensively.
        device_properties = torch.cuda.get_device_properties(0)
        major = getattr(device_properties, 'major', None)
        if isinstance(major, int) and major >= 9:
            import flash_mla  # noqa: F401
            use_flash_mla = True
    except ImportError:
        logger.warning('For higher performance, please install flash_mla https://github.com/deepseek-ai/FlashMLA')
    return use_flash_mla


def flash_attn_v3_available():
    """Check if flash attn v3 is available.

    FlashAttention-3 is used only on compute-capability 9.x devices with a
    CUDA runtime >= 12.3 and an importable ``flash_attn_interface``. Any
    exception while probing is treated as "not available" and logs an
    install hint.

    Returns:
        bool: True if FA3 can be used.
    """
    use_fa3 = False
    try:
        # Now flash-attention only support FA3 for sm90a && cuda >= 12.3
        if torch.cuda.get_device_capability()[0] == 9:
            cuda_version = torch.version.cuda
            # Compare numerically: a string comparison would order e.g.
            # '12.10' before '12.3'. torch.version.cuda is None on non-CUDA builds.
            if cuda_version is not None and tuple(int(v) for v in cuda_version.split('.')[:2]) >= (12, 3):
                import flash_attn_interface  # noqa: F401
                assert torch.ops.flash_attn_3 is not None
                use_fa3 = True
    except Exception:
        logger.warning('For higher performance, please install FlashAttention-3 '
                       'https://github.com/Dao-AILab/flash-attention')
    return use_fa3


================================================
FILE: lmdeploy/pytorch/consts.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# dllm token states.
# NOTE(review): presumably masked / decoded / already cached — confirm at use sites.
DLLM_MASKED, DLLM_UNMASKED, DLLM_CACHED = 0, 1, 2


================================================
FILE: lmdeploy/pytorch/devices/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .device_manager import DefaultContext, DeviceContext, get_device_manager

# Public names re-exported by the ``devices`` package.
__all__ = ['DeviceContext', 'DefaultContext', 'get_device_manager']


================================================
FILE: lmdeploy/pytorch/devices/device_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Callable

from lmdeploy.pytorch.utils import CtxMgrBase, singleton


@dataclass
class DeviceContext:
    """Device settings carried by the device manager's context."""

    # device type string; defaults to 'cuda'
    device_type: str = 'cuda'


# default context used when no other context is active
DefaultContext = DeviceContext(device_type='cuda')


@singleton
class DeviceManager(CtxMgrBase[DeviceContext]):
    """Context manager for :class:`DeviceContext`, seeded with ``DefaultContext``.

    It also keeps a registry of context callbacks addressed by integer
    handles. Handles come from a monotonically increasing counter and are
    never reused.
    """

    def __init__(self):
        super().__init__(DefaultContext)
        # handle -> registered callback
        self._context_callback: dict[int, Callable] = {}
        self._next_cb_handle = 0

    def register_context_callback(self, callback: Callable):
        """Store ``callback`` and return the handle that identifies it."""
        new_handle = self._next_cb_handle
        self._context_callback[new_handle] = callback
        self._next_cb_handle = new_handle + 1
        return new_handle

    def unregister_context_callback(self, handle: int):
        """Drop the callback stored under ``handle``; unknown handles are ignored."""
        if handle in self._context_callback:
            del self._context_callback[handle]


def get_device_manager():
    """Return the ``DeviceManager`` instance (the class is wrapped by ``singleton``)."""
    manager = DeviceManager()
    return manager


================================================
FILE: lmdeploy/pytorch/disagg/README.md
================================================
# LMDeploy-DistServe

## Key Components

1. **Router Service**: Coordinates between prefill/decode engines
2. **Migration Manager**: Facilitates high-performance memory sharing

## Installation

```
# Inference Engine
pip install "lmdeploy[all]>=0.7.0"

# Transfer Engine
pip install "dlslime>=0.0.2"
```

## Quick Start

A PD disaggregated deployment of internlm2_5-7b-chat is shown below:

### 1. Launch Router Service

```shell
lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8000 --routing-strategy "min_expected_latency" --serving-strategy DistServe --log-level INFO
```

LMDeploy-DistServe supports both NVLink and RDMA for transferring the KV cache from the Prefill Engine to the Decode Engine. RDMA is the default mode; set `--migration-protocol NVLink` to use NVLink transport.

### 2. Configure Endpoints

First deploy your prefill and decode engines.

```shell
# Prefill Engine
CUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --role Prefill --proxy-url http://0.0.0.0:8000 --backend pytorch
# Decode Engine
CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23334 --role Decode --proxy-url http://0.0.0.0:8000 --backend pytorch
```

Currently, only the **PyTorch backend** supports PD disaggregation.

## API Usage

```shell
# API Invoke
curl -X POST "http://localhost:8000/v1/completions" \
-H "Content-Type: application/json" \
-d '{"model": "internlm/internlm2_5-7b-chat", "temperature":0, "prompt": "Shanghai is a city that ", "max_tokens": 16, "stream": false}'
# Output
{
  "id": "2",
  "object": "text_completion",
  "created": 1743662400,
  "model": "internlm/internlm2_5-7b-chat",
  "choices": [
    {
      "index": 0,
      "text": " is very famous for its skyscrapers. It is also a city",
      "logprobs": null,
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 7,
    "total_tokens": 23,
    "completion_tokens": 16
  }
}
```

## Troubleshooting

### RDMA Connection Failed:

Make sure ibverbs is correctly installed:

```
# on Ubuntu
sudo apt install libibverbs-dev
# on CentOS
sudo yum install ibverbs-devel
```

```bash
ibstat        # Verify IB device status
ibv_devinfo   # Check device capabilities
```

### Check GPU Direct RDMA:

Currently, LMDeploy-DistServe uses GPUDirect RDMA to perform KV transfer. Make sure the GPUDirect RDMA driver is loaded into the kernel.

```bash
lsmod | grep nv_peer_mem
# GPUDirect RDMA info will be printed if GPUDirect RDMA is correctly loaded.
```

### Connection Pool

Currently, if the proxy disconnects, the connection pool must be warmed up again. A future enhancement could involve:

A dedicated connection pool management server (e.g., using Raft-based tools such as etcd, as done in Mooncake) to improve connection discovery and avoid repeated warmups.

### Proxy

Do not register an engine node with **different proxies**; this is not supported and is not considered valid usage for now.


================================================
FILE: lmdeploy/pytorch/disagg/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/disagg/backend/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.logger import get_logger

logger = get_logger('lmdeploy')

# Each backend is optional: importing its module registers it in
# MIGRATION_BACKENDS as a side effect. Only backends that imported
# successfully are exported, so ``from ... import *`` cannot fail on a name
# that was never bound.
__all__ = []

try:
    logger.debug('Registering DLSlime Backend')
    from .dlslime import DLSlimeBackend
    __all__ += ['DLSlimeBackend']
except ImportError:
    logger.debug('Disable DLSlime Backend')

try:
    logger.debug('Registering Mooncake Backend')
    from .mooncake import MooncakeBackend
    __all__ += ['MooncakeBackend']
except ImportError:
    logger.warning('Disable Mooncake Backend')


================================================
FILE: lmdeploy/pytorch/disagg/backend/backend.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import Registry

# Registry of KV-cache migration backends. Backend classes register under a
# ``MigrationBackend`` member name (e.g. ``MigrationBackend.DLSlime.name``).
MIGRATION_BACKENDS = Registry(
    'migration_backend',
    locations=['lmdeploy.pytorch.disagg.backend.backend'],
)


================================================
FILE: lmdeploy/pytorch/disagg/backend/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod

from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
                                                   MigrationProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment


class MigrationBackendImpl:
    """Interface implemented by every KV-cache migration backend.

    Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod``
    does not block instantiation; an unimplemented method raises
    ``NotImplementedError`` when it is called.
    """

    @abstractmethod
    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Set up the link to the remote engine described by ``init_request``."""
        raise NotImplementedError

    @abstractmethod
    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        """Register a local memory region for transfer."""
        raise NotImplementedError

    @abstractmethod
    def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol):
        """Return the local endpoint info to send to ``remote_engine_id``."""
        # Must raise (not return) the exception, so a missing implementation
        # fails loudly instead of handing an exception class to the caller.
        raise NotImplementedError

    @abstractmethod
    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
        """Connect to ``remote_engine_id`` using its endpoint info."""
        raise NotImplementedError

    @abstractmethod
    def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
        """Transfer the blocks described by ``assignment``."""
        raise NotImplementedError

    @abstractmethod
    def store(self, assignment: MigrationAssignment, async_op: bool = False):
        """Store blocks described by ``assignment``."""
        raise NotImplementedError

    @abstractmethod
    def load(self, assignment: MigrationAssignment, async_op: bool = False):
        """Load blocks described by ``assignment``."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/disagg/backend/dlslime.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import json
import os
from typing import Dict

from dlslime import RDMAEndpoint, available_nic

from lmdeploy.logger import get_logger
from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
                                                   MigrationProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment

logger = get_logger('lmdeploy')

# Raw env value (None when unset). Any non-empty string, including '0', is
# truthy and enables waiting for transfers in an executor thread.
LMDEPLOY_USE_ASYNC_MIGRATION = os.getenv('LMDEPLOY_USE_ASYNC_MIGRATION')


class DLSlimeMigrationManagement:
    """Migration state for one remote engine, backed by DLSlime endpoints."""

    def __init__(self, init_request: DistServeInitRequest):
        self.rank = init_request.rank
        self.local_engine_config: DistServeEngineConfig = init_request.local_engine_config
        self.remote_engine_config: DistServeEngineConfig = init_request.remote_engine_config
        # one endpoint per transport protocol
        self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = {}
        requested = init_request.protocol
        if requested == MigrationProtocol.RDMA:
            nic_list = available_nic()
            # ranks are spread round-robin over the available NICs
            nic = nic_list[self.rank % len(nic_list)]
            logger.info(f'use device {nic} for kv migration')
            self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(
                device_name=nic,
                ib_port=1,
                link_type=init_request.rdma_config.link_type.name,
            )
        elif requested == MigrationProtocol.NVLINK:
            try:
                from dlslime import NVLinkEndpoint as endpoint_cls
            except ImportError:
                logger.warning('Notice: DLSlime not compiled from source with NVLink. Fallback to RDMAEndpoint.')
                endpoint_cls = RDMAEndpoint
            self.endpoint[MigrationProtocol.NVLINK] = endpoint_cls()

    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        """Register a local buffer on the endpoint for the request's protocol."""
        req = register_mr_request
        endpoint = self.endpoint[req.protocol]
        endpoint.register_memory_region(req.mr_key, req.addr, req.offset, req.length)

    def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo):
        """Connect to the peer described by the JSON-encoded endpoint info."""
        endpoint = self.endpoint[kvtransfer_endpoint_info.protocol]
        endpoint.connect(json.loads(kvtransfer_endpoint_info.endpoint_info))

    async def p2p_migrate(self, assignment: MigrationAssignment):
        """Issue one batched ``read`` for ``assignment`` and wait for it.

        When LMDEPLOY_USE_ASYNC_MIGRATION is set, the blocking ``wait`` runs
        in the default executor; otherwise it blocks the calling thread.
        """
        # entry layout passed to DLSlime ``read``:
        # (mr_key, mr_key, target_offset, source_offset, length)
        read_batch = [(task.mr_key, task.mr_key, task.target_offset, task.source_offset, task.length)
                      for task in assignment.batch]

        pending = self.endpoint[assignment.protocol].read(read_batch)
        if not LMDEPLOY_USE_ASYNC_MIGRATION:
            return pending.wait()
        event_loop = asyncio.get_running_loop()
        return await event_loop.run_in_executor(None, pending.wait)


@MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name)
class DLSlimeBackend(MigrationBackendImpl):
    """Migration backend built on the DLSlime transfer engine.

    Keeps one ``DLSlimeMigrationManagement`` per remote engine id.
    """

    def __init__(self):
        # remote_engine_id -> per-link state
        self.links: Dict[str, DLSlimeMigrationManagement] = {}

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Create (or replace) the link for ``init_request.remote_engine_id``."""
        link = DLSlimeMigrationManagement(init_request)
        self.links[init_request.remote_engine_id] = link

    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        """Forward a memory-region registration to the matching link."""
        link = self.links[register_mr_request.remote_engine_id]
        link.register_memory_region(register_mr_request)

    def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol):
        """Return the DLSlime endpoint info for ``remote_engine_id`` / ``protocol``."""
        endpoint = self.links[remote_engine_id].endpoint[protocol]
        return endpoint.endpoint_info()

    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
        """Connect the link for ``remote_engine_id`` to its peer."""
        link = self.links[remote_engine_id]
        link.connect(conn_req)

    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
        """Run the migration on the matching link; ``async_op`` is not used here."""
        link = self.links[assignment.remote_engine_id]
        await link.p2p_migrate(assignment)

    def store(self, assignment: MigrationAssignment, async_op: bool = False):
        """Not supported by this backend."""
        raise NotImplementedError

    def load(self, assignment: MigrationAssignment, async_op: bool = False):
        """Not supported by this backend."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/disagg/backend/mooncake.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import json
import os
import socket
import subprocess
from typing import Dict

from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.config import MigrationBackend, MooncakeEngineConfig
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
                                                   MigrationProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

# Raw env value (None when unset). Any non-empty string, including '0', is
# truthy and selects the executor-based migration path.
LMDEPLOY_USE_ASYNC_MIGRATION = os.getenv('LMDEPLOY_USE_ASYNC_MIGRATION')


def get_rdma_nics():
    """Get all available RDMA network interface cards on the current machine.

    Parses the output of the ``ibv_devices`` command. If the command cannot
    be run, the error is logged and whatever was collected so far (usually
    nothing) is returned.

    Returns:
        list: List of RDMA NICs, e.g. ['erdma_0', 'erdma_1']
    """
    nics = []
    try:
        proc = subprocess.run(['ibv_devices'], stdout=subprocess.PIPE, text=True)
        if proc.returncode == 0:
            # ``ibv_devices`` prints a header line and a dashed separator:
            #     device                 node GUID
            #     ------              ----------------
            rows = proc.stdout.strip().split('\n')[2:]
            nics.extend(row.split()[0].strip() for row in rows if row.strip())
    except Exception as e:
        logger.error(f'Error executing ibv_devices command: {e}')
    return nics


def get_local_ip_by_remote() -> str:
    """Return the local IP address of the interface used for outbound traffic.

    A UDP socket is ``connect``-ed to a public address. No packet is sent for
    UDP, but the OS selects an outgoing interface, whose address is read back
    with ``getsockname``. IPv4 is tried first, then IPv6. Every socket is
    closed on all paths (the original leaked both).

    Raises:
        ValueError: if neither an IPv4 nor an IPv6 address can be determined;
            the last underlying error is attached as ``__cause__``.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('8.8.8.8', 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass  # fall through to IPv6

    # try ipv6
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            # Google's public DNS server, see
            # https://developers.google.com/speed/public-dns/docs/using#addresses
            s.connect(('2001:4860:4860::8888', 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception as e:
        raise ValueError('Can not get local ip') from e


class MooncakeMigrationManagement:
    """Manages migration for a single connection in Mooncake backend.

    Owns a Mooncake ``TransferEngine`` and the table of locally registered
    memory regions. After :meth:`connect`, it also holds the peer's
    memory-region table and session id.
    """

    def __init__(self, init_request: DistServeInitRequest):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError('Please install mooncake by following the instructions at '
                              'https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md '
                              'to run LMDeploy with MooncakeBackend.') from e

        self.rank = init_request.rank
        self.local_engine_config: MooncakeEngineConfig = init_request.local_engine_config
        self.remote_engine_config: MooncakeEngineConfig = init_request.remote_engine_config
        self.local_engine_id = init_request.local_engine_id
        self.remote_engine_id = init_request.remote_engine_id

        self.engine = TransferEngine()
        self.hostname = get_local_ip_by_remote()

        # Get all RDMA information once during initialization
        self.ibv_devices = get_rdma_nics()

        # mr_key (str) -> {'addr', 'length', 'offset'}
        self.local_kv_table: Dict[str, Dict] = {}
        self.remote_kv_table: Dict[str, Dict] = {}
        self.remote_url: str = ''  # peer session id "<host>:<rpc_port>", set by connect()

        # Initialize the p2p connection
        self._initialize_p2p(init_request)

        self.port: int = self.engine.get_rpc_port()

    def _initialize_p2p(self, init_request: DistServeInitRequest):
        """Initialize the transfer engine for this specific link.

        Raises:
            RuntimeError: if no RDMA device exists or initialization fails.
        """
        # TODO: Support more types of metadata_server
        # e.g. "etcd://192.168.0.137:2379"
        metadata_server = 'P2PHANDSHAKE'

        # Default protocol (Currently only RDMA is supported)
        protocol = 'rdma'

        if not self.ibv_devices:
            raise RuntimeError('No RDMA devices available')

        # ranks are spread round-robin over the RDMA devices
        device_name = self.ibv_devices[self.rank % len(self.ibv_devices)]

        # Initialize the engine
        result = self.engine.initialize(self.hostname, metadata_server, protocol, device_name)
        if result != 0:
            raise RuntimeError(f'Failed to initialize Mooncake engine: {result}')

        logger.info(f'Mooncake engine initialized for remote_engine_id {self.remote_engine_id} '
                    f'with hostname {self.hostname}, RPC port: {self.engine.get_rpc_port()}')

    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        """Register a local buffer with the engine and record it under its mr_key.

        Raises:
            RuntimeError: if the engine rejects the registration.
        """
        # buffer address and length as provided by the request
        buffer_addr = register_mr_request.addr
        buffer_length = register_mr_request.length

        # Register memory region with the engine
        result = self.engine.register_memory(buffer_addr, buffer_length)
        if result != 0:
            raise RuntimeError(f'Failed to register memory region: {result}')

        mr_key = str(register_mr_request.mr_key)
        self.local_kv_table[mr_key] = {
            'addr': buffer_addr,
            'length': buffer_length,
            'offset': register_mr_request.offset
        }

        logger.info(f'Registered memory region with mr_key {mr_key}, '
                    f'addr: {buffer_addr}, length: {buffer_length} for remote_engine_id {self.remote_engine_id}')

    @property
    def endpoint_info(self) -> Dict:
        """Local memory regions plus this engine's session id, for the peer."""
        mr_info = {
            mr_key: {
                'addr': buffer_info['addr'],
                'length': buffer_info['length'],
                'offset': buffer_info['offset']
            }
            for mr_key, buffer_info in self.local_kv_table.items()
        }

        endpoint_info = {'mr_info': mr_info, 'session_id': f'{self.hostname}:{self.port}'}

        logger.info(f'Generated endpoint info for remote engine {self.remote_engine_id}: '
                    f"session_id={endpoint_info['session_id']}, "
                    f'mr_count={len(mr_info)}')

        return endpoint_info

    def connect(self, connect_request: DistServeKVTransferEndpointInfo):
        """Record the remote session id and memory-region table.

        No handshake happens here; the stored session id is passed to each
        transfer call in :meth:`_migrate`.
        """
        remote_endpoint_info = json.loads(connect_request.endpoint_info)

        self.remote_url = remote_endpoint_info['session_id']
        self.remote_kv_table = remote_endpoint_info['mr_info']

        logger.info(f'Received remote buffer info: {len(self.remote_kv_table)} regions')
        for mr_key, buffer_info in self.remote_kv_table.items():
            logger.debug(f"Remote buffer mr_key {mr_key}: addr=0x{buffer_info['addr']:x}, "
                         f"length={buffer_info['length']}")

        logger.info(f'Connecting to remote engine {self.remote_engine_id} at {self.remote_url}')

    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
        """Read the blocks in ``assignment`` from the remote engine.

        Args:
            assignment: blocks to transfer.
            async_op: accepted for interface compatibility; not used.

        Raises:
            RuntimeError: if no connection exists or any transfer fails.
        """
        if not LMDEPLOY_USE_ASYNC_MIGRATION:
            # blocking transfer on the calling thread
            self._migrate(assignment)
            return

        import asyncio

        # Run the blocking transfers in the default executor. ``_migrate``
        # raises on failure, and the exception propagates through the await.
        # (The previous code also awaited a never-resolved future and hung.)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._migrate, assignment)

    def _migrate(self, assignment: MigrationAssignment):
        """Synchronously read each task's bytes from the remote buffer into the local one.

        Raises:
            RuntimeError: if not connected, if an mr_key is unknown locally or
                remotely, or if a transfer returns a non-zero status.
        """
        if not self.remote_url:
            raise RuntimeError(f'No connection established to remote engine {self.remote_engine_id}')

        for i, task in enumerate(assignment.batch):
            mr_key = str(task.mr_key)

            if mr_key not in self.local_kv_table:
                raise RuntimeError(f'Memory region with mr_key {mr_key} not registered locally')

            if mr_key not in self.remote_kv_table:
                raise RuntimeError(f'Remote memory region with mr_key {mr_key} not registered')

            # Get local buffer information
            local_buffer_info = self.local_kv_table[mr_key]
            local_addr = local_buffer_info['addr'] + task.source_offset

            # Get remote buffer information
            remote_buffer_info = self.remote_kv_table[mr_key]
            remote_addr = remote_buffer_info['addr'] + task.target_offset

            logger.debug(f'Task {i}: Migrating {task.length} bytes')
            logger.debug(f'  Local Engine: {self.local_engine_id}')
            logger.debug(f'  Remote Engine: {assignment.remote_engine_id}')
            logger.debug(f'  MR Key: {mr_key}')
            logger.debug(f"  Local:  0x{local_buffer_info['addr']:x} + {task.source_offset} = 0x{local_addr:x}")
            logger.debug(f"  Remote: 0x{remote_buffer_info['addr']:x} + {task.target_offset} = 0x{remote_addr:x}")
            logger.debug(f'  Session: {self.remote_url}')

            result = self.engine.transfer_sync_read(
                self.remote_url,
                local_addr,
                remote_addr,
                task.length,
            )
            if result != 0:
                raise RuntimeError(f'Failed to perform sync transfer: {result}')


@MIGRATION_BACKENDS.register_module(MigrationBackend.Mooncake.name)
class MooncakeBackend(MigrationBackendImpl):
    """Mooncake backend that manages multiple migration connections.

    Keeps one ``MooncakeMigrationManagement`` per remote engine id.
    """

    def __init__(self):
        # remote_engine_id -> link; keyed by the same id type as the base
        # interface and DLSlimeBackend (str), not int
        self.links: Dict[str, MooncakeMigrationManagement] = {}

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Create (or replace) the link for ``init_request.remote_engine_id``."""
        self.links[init_request.remote_engine_id] = MooncakeMigrationManagement(init_request)

    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        """Forward a memory-region registration to the matching link."""
        self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request)

    def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol):
        """Return the link's endpoint info; ``protocol`` is ignored (RDMA only)."""
        return self.links[remote_engine_id].endpoint_info

    def p2p_connect(self, remote_engine_id: str, connect_request: DistServeKVTransferEndpointInfo):
        """Record the peer endpoint info on the matching link."""
        self.links[remote_engine_id].connect(connect_request)

    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
        """Run the migration on the matching link."""
        await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op)

    def store(self, assignment: MigrationAssignment, async_op: bool = False):
        """Not supported by this backend."""
        raise NotImplementedError

    def load(self, assignment: MigrationAssignment, async_op: bool = False):
        """Not supported by this backend."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/disagg/config.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from typing import Optional

from pydantic import BaseModel


class ServingStrategy(enum.Enum):
    """How prefill and decode work is placed on engines.

    Attributes:
        Hybrid: one engine runs both the prefill and the decode phase.
        DistServe: prefill and decode run on different engines; once the
            Prefill Engine finishes the prefill phase, the KV cache is
            migrated to the Decode Engine.
    """

    Hybrid = enum.auto()  # value 1
    DistServe = enum.auto()  # value 2


class EngineRole(enum.Enum):
    """Role of an engine.

    Note: In LMDeploy-DistServe every engine is technically a hybrid engine;
        its effective role depends on which kind of request it receives.
        The role must nevertheless be declared when the engine server
        starts, because:
            1. The proxy needs it to discover the engine correctly.
            2. Hybrid, prefill and decode engines build ModelInputs
                differently in the DP engine (DSV3 DP + EP).
    """

    Hybrid = enum.auto()  # value 1
    Prefill = enum.auto()  # value 2
    Decode = enum.auto()  # value 3


class MigrationBackend(enum.Enum):
    """Library used to perform KVCache migration."""

    # Explicit values, identical to what ``enum.auto()`` assigns.
    DLSlime = 1
    Mooncake = 2


class RDMALinkType(enum.Enum):
    """Link layer of an RDMA connection (InfiniBand or RoCE)."""

    # Explicit values, identical to what ``enum.auto()`` assigns.
    IB = 1
    RoCE = 2


class DistServeRDMAConfig(BaseModel):
    """DistServe RDMA Config.

    Args:
        with_gdr: whether to use GPUDirect RDMA. Defaults to True.
        link_type: RDMA link layer. Defaults to `RDMALinkType.RoCE`.

    Warning: Only GDR is supported by now.
    Warning: Technically, both RoCE and IB are supported.
        However, IB mode is not tested because no IB testing
        environment is available.
    """

    # RDMA with GPU Direct RDMA Access
    with_gdr: bool = True
    link_type: RDMALinkType = RDMALinkType.RoCE


class DistServeTCPConfig(BaseModel):
    """Placeholder config for the TCP migration protocol (declares no fields yet).

    TODO: Add TCP Protocol
    """


class DistServeNVLinkConfig(BaseModel):
    """Placeholder config for the NVLink migration protocol (declares no fields yet).

    TODO: Add NVLink Protocol
    """


class DistServeEngineConfig(BaseModel):
    """DistServe Engine Config.

    In Disaggregated LLM Serving, we need to get the engine info of each
    PD peer for the following reasons:
        1. Cache: The stride of cache block for correct offset of KV Transfer.
        2. Parallel: Prefill and decode may use different parallel strategies
            to achieve high SLO attainment or high throughput. In this
            situation, we need to calculate which prefill-decode worker peers
            need to connect. For example, if the prefill worker uses pp4 and
            the decode worker uses tp2pp2, the prefill-decode worker connection
            peers are (0, 0), (0, 1), (1, 0), (1, 1), (2, 2), (2, 3), (3, 2),
            (3, 3). Under (tp4, tp4), instead, the prefill-decode worker
            connection peers are (0, 0), (1, 1), (2, 2), (3, 3).
    """

    # parallel config
    # (dp, pp, tp, ep)
    tp_size: int
    ep_size: int
    dp_size: int
    # NOTE: with pydantic v2 (model_validate/model_dump are used in this
    # package), `Optional[int]` without a default is still a required field:
    # it must be provided, but may be None.
    pp_size: Optional[int]

    # Rank of DP
    dp_rank: int

    # cache config
    block_size: int
    num_cpu_blocks: int
    num_gpu_blocks: int


class MooncakeEngineConfig(DistServeEngineConfig):
    """Mooncake Transfer Engine Config.

    Currently identical to `DistServeEngineConfig` (no extra fields).

    TODO: Support more specific config for Mooncake.
    """
    pass


================================================
FILE: lmdeploy/pytorch/disagg/conn/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/disagg/conn/engine_conn.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
from typing import TYPE_CHECKING, Dict, List
from urllib.parse import urlparse

import zmq
import zmq.asyncio

from lmdeploy.logger import get_logger
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
                                                   DistServeConnectionResponse, DistServeConnectionStatus,
                                                   DistServeDropConnectionRequest, DistServeEngineEndpointInfo,
                                                   DistServeInitRequest, DistServeInitResponse,
                                                   DistServeKVTransferEndpointInfo)
from lmdeploy.pytorch.engine.executor.dist_utils import find_available_port

if TYPE_CHECKING:
    from lmdeploy.pytorch.engine.engine import Engine

# Module logger, shared under the 'lmdeploy' logger name.
logger = get_logger('lmdeploy')


class EngineP2PConnection:
    """Per-engine bookkeeping of the PD peer-to-peer connections.

    For every remote engine id this object owns a ZMQ context with:
      * a PUSH socket bound on the local host, used by :meth:`zmq_send` to
        send ``DistServeCacheFreeRequest`` messages to the peer, and
      * a PULL socket connected to the peer's PUSH socket, served by the
        :meth:`handle_zmq_recv` task.
    KV-transfer endpoints are initialized and connected through
    ``engine.executor``.
    """

    def __init__(self, engine: 'Engine'):
        self.engine: Engine = engine
        self.p2p_conn_ctx: Dict[str, zmq.asyncio.Context] = {}
        self.p2p_sender: Dict[str, zmq.asyncio.Socket] = {}
        self.p2p_receiver: Dict[str, zmq.asyncio.Socket] = {}
        # Receive-loop task per remote engine. Keeping a reference prevents the
        # task from being garbage collected mid-flight and lets a disconnect
        # cancel it.
        self.p2p_recv_tasks: Dict[str, asyncio.Task] = {}

        # NOTE(review): the raw environment string is stored, so any non-empty
        # value (including '0' or 'false') is truthy.
        self.use_unique_kvtransfer_engine = os.environ.get('LMDEPLOY_USE_UNIQUE_KVTRANSFER_ENGINE', False)

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Create the ZMQ sockets and KV-transfer endpoints for one remote engine.

        Returns:
            DistServeInitResponse: the local ZMQ address the remote engine must
            connect to, plus the KV-transfer endpoint info from the executor.
        """
        remote_engine_id = init_request.remote_engine_id
        # A repeated initialization (e.g. a retried connection) must not leak
        # the context, sockets and receive task of the previous attempt.
        self._close_zmq(remote_engine_id)

        ctx = zmq.asyncio.Context(2)
        sender = ctx.socket(zmq.PUSH)
        sender_port = find_available_port()
        # local_engine_id is this engine's URL; bind on its host so the peer can reach us.
        sender_hostname = urlparse(init_request.local_engine_id).hostname
        zmq_address = f'tcp://{sender_hostname}:{sender_port}'
        sender.bind(zmq_address)
        receiver = ctx.socket(zmq.PULL)

        self.p2p_conn_ctx[remote_engine_id] = ctx
        self.p2p_sender[remote_engine_id] = sender
        self.p2p_receiver[remote_engine_id] = receiver

        kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo] = self.engine.executor.p2p_initialize(
            init_request)

        return DistServeInitResponse(engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_address),
                                     kvtransfer_endpoint_info=kvtransfer_endpoint_info,
                                     status=DistServeConnectionStatus.SUCCESS)

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Connect to an initialized remote engine and start its receive loop.

        Must be called while an asyncio event loop is available, since the
        receive loop is scheduled on it.
        """
        remote_engine_id = conn_request.remote_engine_id
        self.p2p_receiver[remote_engine_id].connect(conn_request.remote_engine_endpoint_info.zmq_address)
        self.engine.executor.p2p_connect(remote_engine_id=remote_engine_id,
                                         conn_request=conn_request.remote_kvtransfer_endpoint_info)
        # Never run two receive loops on the same socket.
        previous_task = self.p2p_recv_tasks.pop(remote_engine_id, None)
        if previous_task is not None:
            previous_task.cancel()
        event_loop = asyncio.get_event_loop()
        self.p2p_recv_tasks[remote_engine_id] = event_loop.create_task(self.handle_zmq_recv(remote_engine_id))
        return DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS)

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Tear down the ZMQ connection to ``drop_conn_request.remote_engine_id``."""
        # TODO (JimyMa): drop RDMA Connection
        # Close synchronously: the former `self.zmq_disconnect(...)` call only
        # created a coroutine that was never awaited, so nothing was closed.
        self._close_zmq(drop_conn_request.remote_engine_id)
        return {'success': True}

    async def zmq_send(self, remote_engine_id: str, remote_session_id: int):
        """Ask the remote engine to free the cache of ``remote_session_id``."""
        await self.p2p_sender[remote_engine_id].send_pyobj(
            DistServeCacheFreeRequest(remote_engine_id=remote_engine_id, remote_session_id=remote_session_id))

    async def handle_zmq_recv(self, remote_engine_id: str):
        """Serve cache-free requests arriving from ``remote_engine_id`` forever.

        Each ``DistServeCacheFreeRequest`` ends the referenced session in the
        local scheduler. Unknown sessions and unexpected message types are
        logged and skipped, so one bad message does not stop the loop (and with
        it every later cache free).
        """
        receiver = self.p2p_receiver[remote_engine_id]
        while True:
            # SECURITY NOTE: recv_pyobj unpickles network data; the peer must be trusted.
            req = await receiver.recv_pyobj()
            if not isinstance(req, DistServeCacheFreeRequest):
                logger.error(f'Unsupported zmq request {type(req)}')
                continue
            session_id = req.remote_session_id
            if session_id in self.engine.scheduler.sessions:
                self.engine.scheduler.end_session(session_id=session_id)
            else:
                logger.error(f'invalid free, {remote_engine_id}, {session_id}')

    async def zmq_disconnect(self, remote_engine_id: str):
        """Close the ZMQ resources of ``remote_engine_id`` (async wrapper)."""
        self._close_zmq(remote_engine_id)

    def _close_zmq(self, remote_engine_id: str):
        """Cancel the receive loop and release ZMQ resources; no-op for unknown peers."""
        task = self.p2p_recv_tasks.pop(remote_engine_id, None)
        if task is not None:
            task.cancel()
        # linger=0 discards unsent messages so that `ctx.term()` cannot block
        # forever on a peer that is gone.
        receiver = self.p2p_receiver.pop(remote_engine_id, None)
        if receiver is not None:
            receiver.close(linger=0)
        sender = self.p2p_sender.pop(remote_engine_id, None)
        if sender is not None:
            sender.close(linger=0)
        ctx = self.p2p_conn_ctx.pop(remote_engine_id, None)
        if ctx is not None:
            ctx.term()


================================================
FILE: lmdeploy/pytorch/disagg/conn/protocol.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from typing import List, Optional

from pydantic import BaseModel

from lmdeploy.pytorch.disagg.config import (DistServeEngineConfig, DistServeNVLinkConfig, DistServeRDMAConfig,
                                            DistServeTCPConfig)


class MigrationProtocol(enum.Enum):
    """Transport protocol used for KVCache migration.

    Attributes:
        TCP: plain TCP transport.
        RDMA: IB or RoCEv1/v2.
        NVLINK: high-bandwidth device-to-device link.

    Warning: By now, only `GPUDirect RDMA` is supported in DistServe.
        The other protocols are reserved and will be implemented in the future.
    """

    # Explicit values, identical to what ``enum.auto()`` assigns.
    TCP = 1
    RDMA = 2
    NVLINK = 3


class DistServeConnectionStatus(enum.Enum):
    """Outcome reported by the DistServe connection endpoints."""

    # TODO(JimyMa): Add more connection failure handler
    # Explicit values, identical to what ``enum.auto()`` assigns.
    SUCCESS = 1
    FAIL = 2


class DistServeInitRequest(BaseModel):
    # Body of the `distserve/p2p_initialize` call sent by the proxy to each
    # engine of a prefill/decode pair.
    # URL and config of the engine that receives this request.
    local_engine_id: str
    local_engine_config: DistServeEngineConfig

    # URL and config of the peer engine on the other side of the link.
    remote_engine_id: str
    remote_engine_config: DistServeEngineConfig

    protocol: MigrationProtocol

    # NOTE(review): not set by the proxy's connect path; presumably a worker rank.
    rank: Optional[int] = None

    # Protocol-specific settings; presumably only the one matching `protocol` is used.
    tcp_config: Optional[DistServeTCPConfig] = None
    rdma_config: Optional[DistServeRDMAConfig] = None
    nvlink_config: Optional[DistServeNVLinkConfig] = None


class DistServeEngineEndpointInfo(BaseModel):
    # Control-plane endpoint of an engine: the `tcp://host:port` address of
    # its ZMQ PUSH socket.
    zmq_address: str


class DistServeKVTransferEndpointInfo(BaseModel):
    # KV-transfer endpoint description for one protocol. `endpoint_info` is an
    # opaque string so that different media (RDMA, NVLink, ...) can share this type.
    protocol: MigrationProtocol
    endpoint_info: str


class DistServeInitResponse(BaseModel):
    # Reply of `distserve/p2p_initialize`.
    status: DistServeConnectionStatus
    # the control plane initialization feedback
    engine_endpoint_info: DistServeEngineEndpointInfo
    # the KVCache Transfer initialization feedback
    # To ensure generality (where endpoint_info can be initialization information
    # for different media such as RDMA, NVLink, etc.), we use a string (str) to
    # store this information.
    kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo]


class DistServeConnectionRequest(BaseModel):
    # Body of `distserve/p2p_connect`: tells an engine how to reach its peer,
    # using the peer's `DistServeInitResponse` contents.
    protocol: MigrationProtocol
    remote_engine_id: str
    remote_engine_endpoint_info: DistServeEngineEndpointInfo
    remote_kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo]


class DistServeConnectionResponse(BaseModel):
    # Reply of `distserve/p2p_connect`.
    status: DistServeConnectionStatus


class MigrationRequest(BaseModel):
    # Describes the KV cache to migrate from a remote engine.
    protocol: MigrationProtocol

    remote_engine_id: str
    remote_session_id: int
    # NOTE(review): presumably the token produced by the remote prefill — confirm.
    remote_token_id: int
    # NOTE(review): presumably the remote cache block ids holding the session's KV.
    remote_block_ids: List[int]

    # NOTE(review): semantics not visible here; presumably marks a placeholder prefill.
    is_dummy_prefill: bool = False


class DistServeCacheFreeRequest(BaseModel):
    # Asks the receiving engine to end session `remote_session_id` and so free
    # its cache; sent over ZMQ between engines and via `distserve/free_cache`.
    remote_engine_id: str
    remote_session_id: int


class DistServeDropConnectionRequest(BaseModel):
    # Body of `distserve/p2p_drop_connect`: `engine_id` is the engine receiving
    # the request and `remote_engine_id` the peer whose connection is dropped.
    engine_id: str
    remote_engine_id: str


================================================
FILE: lmdeploy/pytorch/disagg/conn/proxy_conn.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import enum
import os
from collections import defaultdict
from typing import Dict, Set, Tuple

import aiohttp
import requests

from lmdeploy.logger import get_logger
from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
                                                   DistServeConnectionResponse, DistServeDropConnectionRequest,
                                                   DistServeInitRequest, DistServeInitResponse)
from lmdeploy.pytorch.disagg.messages import PDConnectionMessage

# Module logger, shared under the 'lmdeploy' logger name.
logger = get_logger('lmdeploy')

# Total timeout (in seconds) of the aiohttp requests issued by the proxy; None
# (variable unset or blank) disables the timeout. `aiohttp.ClientTimeout`
# expects a number, so the raw environment string must be converted; passing
# the string through made every request fail once the variable was set.
_AIOHTTP_TIMEOUT_ENV = os.getenv('AIOHTTP_TIMEOUT', None)
AIOHTTP_TIMEOUT = (float(_AIOHTTP_TIMEOUT_ENV)
                   if _AIOHTTP_TIMEOUT_ENV is not None and _AIOHTTP_TIMEOUT_ENV.strip() else None)


class PDConnectionStatus(enum.Enum):
    """Lifecycle state of a prefill->decode link."""

    # Explicit values, identical to what ``enum.auto()`` assigns.
    Disconnected = 1
    Connected = 2
    Connecting = 3


class PDConnectionState:
    """Status of one prefill->decode link plus the event its waiters block on."""

    def __init__(self, status: PDConnectionStatus, event: asyncio.Event):
        self.event = event
        self.status = status

    async def wait(self):
        """Block until ``self.event`` is set."""
        link_event = self.event
        await link_event.wait()

    def set_status(self, status: PDConnectionStatus):
        """Replace the current status with ``status``."""
        self.status = status


def get_server_api(url: str, api: str):
    """Join a server base ``url`` and an ``api`` path with a single slash."""
    return '/'.join((url, api))


class PDConnectionPool:
    """Constructing the link of Prefill and Decode engine for the migration of
    KVCache.

    Note: we use Peer to Peer transportation in KVCache migration.
    Note: Lazy link construction is supported, which performs the connection
        at the first LLM request. As a result, we don't need to construct the
        PD communication group when starting an engine server.
    Note: we perform simple fault tolerance by checkpointing the session_id of
        every request that is being migrated, and trigger `gc` when the decode
        instance crashes.
    TODO (JimyMa): By now, only engines with same parallel configuration can be
        correctly connected.
    """

    # Maximum concurrent connections
    CONN_SEMAPHORE_SIZE = 2048

    def __init__(self):
        # all prefill and decode instances
        # TODO (JimyMa): Maybe encoding instances
        self.prefill_endpoints: Set[str] = set()
        self.decode_endpoints: Set[str] = set()

        # Links of PD Connection, keyed by (prefill_url, decode_url).
        self.pool: Dict[Tuple[str, str], PDConnectionState] = {}

        # put migrating session to `self.migration_session_shelf` for increasing fault tolerance
        # if a session is finished, then pop it from `self.migration_session_shelf`
        # if a decode instance is disconnected, then gc all blocks of these sessions in prefill instance.
        self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set)

        # conn_perform handler queue
        self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = (asyncio.Queue())

        # conn Registry Lock
        self.conn_lock = asyncio.Lock()

        # Connection Retry when failure
        self.max_retry_cnt = 8

        # trigger signal when conn request arrive.
        self.conn_req_event = asyncio.Event()

        # conn initialized signal
        self.initialized = False

        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage collected.
        self._background_tasks: Set[asyncio.Task] = set()

    def reg_instance(self, role: EngineRole, endpoint: str):
        """Register ``endpoint`` as a prefill or decode instance.

        Raises:
            ValueError: if ``role`` is neither Prefill nor Decode.
        """
        if role == EngineRole.Prefill:
            self.prefill_endpoints.add(endpoint)
        elif role == EngineRole.Decode:
            self.decode_endpoints.add(endpoint)
        else:
            raise ValueError(f'Unsupported role: {role}')

    def dereg_instance(self, endpoint: str):
        """Deregister ``endpoint``; unknown endpoints are ignored.

        Deregistering a decode instance also drops (and garbage-collects) every
        link that targets it.
        """
        if endpoint in self.prefill_endpoints:
            self.prefill_endpoints.remove(endpoint)
        elif endpoint in self.decode_endpoints:
            dropped_keys = [conn_key for conn_key in self.pool if conn_key[1] == endpoint]
            for conn_key in dropped_keys:
                self.drop(conn_key)
            # TODO(JimyMa): handle side-effect by kvcache migration
            self.decode_endpoints.remove(endpoint)

    def shelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):
        """Record that ``session_id`` is migrating over link ``conn_key``."""
        self.migration_session_shelf[conn_key].add(session_id)

    def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):
        """Forget ``session_id`` on link ``conn_key``.

        Missing links or sessions are ignored (e.g. the link was already dropped
        and its shelf garbage-collected).
        """
        sessions = self.migration_session_shelf.get(conn_key)
        if sessions is not None:
            sessions.discard(session_id)

    def _spawn(self, coro) -> 'asyncio.Task':
        """Schedule ``coro`` as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _ensure_initialized(self):
        """Lazily create the HTTP session, semaphore and connection dispatcher."""
        if self.initialized:
            return
        self.conn_sem = asyncio.Semaphore(self.CONN_SEMAPHORE_SIZE)
        self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT)
        self.conn_sess = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit_per_host=256),
            timeout=self.aiotimeout,
        )
        self._perform_conn_task = self._spawn(self._perform_conn())
        self.initialized = True

    async def _get_engine_config(self, server_endpoint: str) -> DistServeEngineConfig:
        """Fetch the engine config of ``server_endpoint``."""
        async with self.conn_sem:
            async with self.conn_sess.get(
                    get_server_api(server_endpoint, 'distserve/engine_info'),
                    timeout=self.aiotimeout,
            ) as resp:
                result = await resp.json()
                # NOTE(review): the endpoint's JSON body is itself a JSON string.
                return DistServeEngineConfig.model_validate_json(result)

    async def _p2p_initialize(self, server_endpoint: str,
                              init_request: DistServeInitRequest) -> DistServeInitResponse:
        """Call ``distserve/p2p_initialize`` on ``server_endpoint``."""
        async with self.conn_sem:
            async with self.conn_sess.post(
                    get_server_api(server_endpoint, 'distserve/p2p_initialize'),
                    json=init_request.model_dump(mode='json'),
                    timeout=self.aiotimeout,
            ) as resp:
                result = await resp.json()
                return DistServeInitResponse.model_validate(result)

    async def _p2p_connect(self, server_endpoint: str,
                           conn_request: DistServeConnectionRequest) -> DistServeConnectionResponse:
        """Call ``distserve/p2p_connect`` on ``server_endpoint``."""
        async with self.conn_sem:
            async with self.conn_sess.post(
                    get_server_api(server_endpoint, 'distserve/p2p_connect'),
                    json=conn_request.model_dump(mode='json'),
                    timeout=self.aiotimeout,
            ) as resp:
                result = await resp.json()
                return DistServeConnectionResponse.model_validate(result)

    def _set_link_status(self, link: Tuple[str, str], status: PDConnectionStatus):
        """Update the status of ``link`` if it is still in the pool (it may have been dropped)."""
        state = self.pool.get(link)
        if state is not None:
            state.set_status(status)

    async def _conn_worker(self, conn_req: PDConnectionMessage, conn_event: asyncio.Event):
        """Establish link ``(p_url, d_url)``; ``conn_event`` is set however it ends."""
        link = (conn_req.p_url, conn_req.d_url)
        try:
            logger.debug(f'{link} connecting...')
            # Step 1. Get Remote Engine Configuration
            prefill_engine_config = await self._get_engine_config(conn_req.p_url)
            decode_engine_config = await self._get_engine_config(conn_req.d_url)

            # Note: Only Same Parallel Configurations are supported by now.
            # An explicit check (rather than `assert`) also holds under `python -O`.
            if prefill_engine_config.tp_size != decode_engine_config.tp_size:
                raise ValueError(f'tp_size mismatch: prefill={prefill_engine_config.tp_size}, '
                                 f'decode={decode_engine_config.tp_size}')

            # Step 2. Construct Initialize Configuration
            prefill_init_req = DistServeInitRequest(
                protocol=conn_req.protocol,
                local_engine_id=conn_req.p_url,
                local_engine_config=prefill_engine_config,
                remote_engine_id=conn_req.d_url,
                remote_engine_config=decode_engine_config,
                tcp_config=conn_req.tcp_config,
                rdma_config=conn_req.rdma_config,
                nvlink_config=conn_req.nvlink_config,
            )
            decode_init_req = DistServeInitRequest(
                protocol=conn_req.protocol,
                local_engine_id=conn_req.d_url,
                local_engine_config=decode_engine_config,
                remote_engine_id=conn_req.p_url,
                remote_engine_config=prefill_engine_config,
                tcp_config=conn_req.tcp_config,
                rdma_config=conn_req.rdma_config,
                nvlink_config=conn_req.nvlink_config,
            )

            prefill_init_resp = await self._p2p_initialize(conn_req.p_url, prefill_init_req)
            decode_init_resp = await self._p2p_initialize(conn_req.d_url, decode_init_req)

            # Step 3. Connection: each side receives the other side's endpoints.
            prefill_endpoint_conn_reqs = DistServeConnectionRequest(
                protocol=conn_req.protocol,
                remote_engine_id=conn_req.d_url,
                remote_engine_endpoint_info=decode_init_resp.engine_endpoint_info,
                remote_kvtransfer_endpoint_info=decode_init_resp.kvtransfer_endpoint_info)
            decode_endpoint_conn_reqs = DistServeConnectionRequest(
                protocol=conn_req.protocol,
                remote_engine_id=conn_req.p_url,
                remote_engine_endpoint_info=prefill_init_resp.engine_endpoint_info,
                remote_kvtransfer_endpoint_info=prefill_init_resp.kvtransfer_endpoint_info)
            await self._p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs)
            await self._p2p_connect(conn_req.d_url, decode_endpoint_conn_reqs)
            self._set_link_status(link, PDConnectionStatus.Connected)
            logger.debug(f'{link} connected')
        except Exception as e:
            self._set_link_status(link, PDConnectionStatus.Disconnected)
            logger.error(f'pd connection error: {e}')
        finally:
            # Always release the waiters, even if bookkeeping above failed.
            conn_event.set()

    async def _wait_for_conn(self, link_event: asyncio.Event, conn_event: asyncio.Event):
        """Set ``conn_event`` once the in-flight attempt signalled by ``link_event`` ends."""
        await link_event.wait()
        conn_event.set()

    def _dispatch_conn_request(self, conn_req: PDConnectionMessage, conn_event: asyncio.Event):
        """Start, join or short-circuit a connection attempt for ``conn_req``."""
        link = (conn_req.p_url, conn_req.d_url)
        state = self.pool.get(link)
        if state is None:
            state = PDConnectionState(PDConnectionStatus.Disconnected, conn_event)
            self.pool[link] = state
        if state.status == PDConnectionStatus.Connecting:
            # Join the in-flight attempt.
            self._spawn(self._wait_for_conn(state.event, conn_event))
        elif state.status == PDConnectionStatus.Disconnected:
            # Every attempt needs a fresh event: after a failed attempt the old
            # one is already set, and joiners would return immediately.
            state.event = conn_event
            state.set_status(PDConnectionStatus.Connecting)
            self._spawn(self._conn_worker(conn_req, conn_event))
        else:
            # Already connected (it finished after the caller checked);
            # release the caller instead of leaving it waiting forever.
            conn_event.set()

    async def _perform_conn(self):
        """Dispatcher loop: drain ``self.waiting_conn`` whenever a request arrives."""
        logger.debug('perform_conn start')
        while True:
            if self.waiting_conn.empty():
                await self.conn_req_event.wait()

            self.conn_req_event.clear()

            while not self.waiting_conn.empty():
                conn_req, conn_event = self.waiting_conn.get_nowait()
                self._dispatch_conn_request(conn_req, conn_event)

    async def connect(self, conn_req: PDConnectionMessage):
        """Ensure the link ``(conn_req.p_url, conn_req.d_url)`` is connected.

        Up to ``self.max_retry_cnt`` attempts are made.

        Raises:
            TimeoutError: if the link is still not connected after all attempts.
        """
        self._ensure_initialized()

        self.reg_instance(EngineRole.Prefill, conn_req.p_url)
        self.reg_instance(EngineRole.Decode, conn_req.d_url)

        link = (conn_req.p_url, conn_req.d_url)
        for cnt in range(self.max_retry_cnt):
            if self.is_connected(*link):
                return
            if cnt > 0:
                logger.warning(f'Connection failure, retry cnt: {cnt}')
            conn_event = asyncio.Event()
            self.waiting_conn.put_nowait((conn_req, conn_event))
            self.conn_req_event.set()
            await conn_event.wait()
        # The last attempt may have succeeded; check before reporting failure.
        if self.is_connected(*link):
            return
        async with self.conn_lock:
            self._set_link_status(link, PDConnectionStatus.Disconnected)
        raise TimeoutError('PDConnection Failure')

    def is_connected(self, p_url: str, d_url: str):
        """Return True if the link ``(p_url, d_url)`` exists and is Connected."""
        link = self.pool.get((p_url, d_url), None)
        if not link:
            return False
        return link.status == PDConnectionStatus.Connected

    def drop(self, pd_key: Tuple[str, str]):
        """Drop link ``pd_key``: free shelved sessions on the prefill side, then
        ask both engines to disconnect and forget the link.

        Errors are logged, not raised (best effort). Uses blocking ``requests``.
        """
        left, right = pd_key

        def cache_free(server_endpoint, cache_free_request: DistServeCacheFreeRequest) -> None:
            try:
                requests.post(get_server_api(server_endpoint, 'distserve/free_cache'),
                              json=cache_free_request.model_dump(mode='json'))
            except Exception as e:
                logger.warning(f'error cache block free {server_endpoint, cache_free_request}. ErrorMsg: {str(e)}')

        def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConnectionRequest):
            try:
                requests.post(get_server_api(server_endpoint, 'distserve/p2p_drop_connect'),
                              json=p2p_disconnect_request.model_dump(mode='json'))
            except Exception as e:
                logger.warning(f'error drop connect {server_endpoint, p2p_disconnect_request}. ErrorMsg: {str(e)}')

        # trigger gc; pop the shelf so a dropped link does not leak its entry
        logger.warning('cache block gc triggered.')
        shelved_sessions = self.migration_session_shelf.pop((left, right), set())
        try:
            for session_id in shelved_sessions:
                cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id))
        except Exception as e:
            logger.warning(f'gc error, ErrorMsg: {str(e)}')

        # trigger p2p disconnect
        logger.warning('drop connection triggered.')
        try:
            drop_connect(left, DistServeDropConnectionRequest(engine_id=left, remote_engine_id=right))
            drop_connect(right, DistServeDropConnectionRequest(engine_id=right, remote_engine_id=left))
        except Exception as e:
            logger.warning(f'p2p disconnect error, ErrorMsg: {str(e)}')

        self.pool.pop((left, right), None)


================================================
FILE: lmdeploy/pytorch/disagg/messages.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

from pydantic import BaseModel

from lmdeploy.pytorch.disagg.config import DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig
from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol


class MigrationExecutionBatch(BaseModel):
    """Input of the Migration."""

    protocol: MigrationProtocol
    requests: List[Tuple[str, List[Tuple[int, int]]]] = []


class AssignmentInstruct(BaseModel):
    """Assignment Batch."""
    mr_key: int
    target_offset: int
    source_offset: int
    length: int


class MigrationAssignment(BaseModel):
    """Migration Assignment."""
    protocol: MigrationProtocol
    remote_engine_id: str
    batch: List[AssignmentInstruct]


class PDConnectionMessage(BaseModel):
    p_url: str
    d_url: str
    protocol: MigrationProtocol = MigrationProtocol.RDMA
    tcp_config: Optional[DistServeTCPConfig] = None
    rdma_config: Optional[DistServeRDMAConfig] = None
    nvlink_config: Optional[DistServeNVLinkConfig] = None


class DistServeRegisterMRMessage(BaseModel):
    protocol: MigrationProtocol

    remote_engine_id: str
    mr_key: int
    addr: int
    offset: int
    length: int


================================================
FILE: lmdeploy/pytorch/distributed.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional

import torch
from torch import distributed as dist
from torch.distributed import ProcessGroup, ReduceOp, Work  # noqa: F401

from lmdeploy.pytorch.utils import CtxMgrBase, singleton

from .config import DistConfig, TPMode


@dataclass
class DistGroup:
    """Process groups of one parallel dimension, as seen from the local rank.

    ``cpu_groups``/``gpu_groups`` hold the lists of groups this object owns;
    ``close`` destroys exactly those lists.
    """
    rank: int = 0
    cpu_group: dist.ProcessGroup = None
    gpu_group: dist.ProcessGroup = None
    cpu_groups: List[dist.ProcessGroup] = None
    gpu_groups: List[dist.ProcessGroup] = None
    gpu_gather_group: dist.ProcessGroup = None

    def close(self):
        """Destroy the owned group lists; no-op when torch.distributed is not initialized."""
        if dist.is_initialized():
            # cpu groups first, then gpu groups; each emptied list becomes None
            for attr_name in ('cpu_groups', 'gpu_groups'):
                owned = getattr(self, attr_name)
                if owned is None:
                    continue
                for process_group in owned:
                    dist.destroy_process_group(process_group)
                setattr(self, attr_name, None)


def _build_tp_group_impl(tp: int,
                         rank: int,
                         world_size: int,
                         timeout: timedelta,
                         cpu_backend: str = 'gloo',
                         ccl_backend: str = 'nccl',
                         attn_tp: int = 1,
                         tp_mode: TPMode = TPMode.DEFAULT):
    """Build the tensor-parallel groups and return the ones containing ``rank``.

    Ranks are split into contiguous blocks of ``tp``. For every block, one
    ``ccl_backend`` group and one ``cpu_backend`` group are created. Every rank
    creates every group, because torch requires all processes to call
    ``dist.new_group`` for each group.

    In ``TPMode.DP_TP`` with ``attn_tp != tp``, each block additionally gets
    ``attn_tp`` strided "gather" groups: ranks of the block whose in-block
    index is congruent modulo ``attn_tp``. With ``attn_tp == tp`` the tp gpu
    group itself serves as the gather group. In other modes there is none.

    Returns:
        DistGroup: rank within the tp group, the local cpu/gpu groups, the
        lists of all tp groups, and the gather group (or None).
    """
    assert tp > 1
    tp_rank = rank % tp
    tp_group_id = rank // tp
    # Gather groups are appended block by block (``attn_tp`` per tp block), so
    # the index must include the block offset. Without it, ranks outside the
    # first tp block got a group they are not a member of.
    gather_group_id = tp_group_id * attn_tp + tp_rank % attn_tp
    ranks = range(world_size)
    tp_gpu_groups = []
    tp_cpu_groups = []
    gather_groups = []
    for start in range(0, world_size, tp):
        tp_ranks = ranks[start:start + tp]
        group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend)
        tp_gpu_groups.append(group)
        cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend)
        tp_cpu_groups.append(cpu_group)

        # create gather group
        if tp_mode == TPMode.DP_TP and attn_tp != tp:
            for g_start in range(start, start + attn_tp):
                g_ranks = ranks[g_start:(g_start + tp):attn_tp]
                gather_group = dist.new_group(ranks=g_ranks, timeout=timeout, backend=ccl_backend)
                gather_groups.append(gather_group)
    tp_gpu_group = tp_gpu_groups[tp_group_id]
    tp_cpu_group = tp_cpu_groups[tp_group_id]

    if tp_mode == TPMode.DP_TP:
        if attn_tp == tp:
            gather_group = tp_gpu_group
        else:
            gather_group = gather_groups[gather_group_id]
    else:
        gather_group = None
    return DistGroup(
        rank=tp_rank,
        cpu_group=tp_cpu_group,
        gpu_group=tp_gpu_group,
        cpu_groups=tp_cpu_groups,
        gpu_groups=tp_gpu_groups,
        gpu_gather_group=gather_group,
    )


def _build_attn_tp_group(context: 'DistContext',
                         timeout: timedelta,
                         cpu_backend: str = 'gloo',
                         ccl_backend: str = 'nccl'):
    """Populate ``context.attn_tp_group`` with the attention tensor-parallel group.

    A placeholder ``DistGroup(rank=0)`` is used when ``attn_tp`` is 1.
    """
    attn_tp = context.dist_config.attn_tp
    if attn_tp == 1:
        attn_group = DistGroup(rank=0)
    else:
        attn_group = _build_tp_group_impl(
            attn_tp,
            context.rank,
            context.dist_config.world_size,
            timeout=timeout,
            cpu_backend=cpu_backend,
            ccl_backend=ccl_backend,
            attn_tp=attn_tp,
            tp_mode=TPMode.DEFAULT,
        )
    context.attn_tp_group = attn_group


def _build_mlp_tp_group(context: 'DistContext',
                        timeout: timedelta,
                        cpu_backend: str = 'gloo',
                        ccl_backend: str = 'nccl'):
    """Populate ``context.mlp_tp_group`` with the MLP tensor-parallel group.

    Uses a placeholder group when ``mlp_tp`` is 1. Reuses the attention group
    when the sizes match, so ``context.attn_tp_group`` must be built first.
    """
    cfg = context.dist_config
    mlp_tp = cfg.mlp_tp
    if mlp_tp == 1:
        mlp_group = DistGroup(rank=0)
    elif mlp_tp == cfg.attn_tp:
        # NOTE(review): the reused attention group was built with
        # TPMode.DEFAULT, so it carries no gather group even if mlp_tp_mode
        # is DP_TP — confirm that combination cannot occur.
        mlp_group = context.attn_tp_group
    else:
        mlp_group = _build_tp_group_impl(
            mlp_tp,
            context.rank,
            cfg.world_size,
            timeout=timeout,
            cpu_backend=cpu_backend,
            ccl_backend=ccl_backend,
            attn_tp=cfg.attn_tp,
            tp_mode=cfg.mlp_tp_mode,
        )
    context.mlp_tp_group = mlp_group


def _build_moe_tp_group(context: 'DistContext',
                        timeout: timedelta,
                        cpu_backend: str = 'gloo',
                        ccl_backend: str = 'nccl'):
    """Populate ``context.moe_tp_group`` with the MoE tensor-parallel group.

    Uses a placeholder group when ``moe_tp`` is 1. Otherwise it reuses the
    attention group, then the MLP group, when the sizes match, so both must be
    built first. A new group is built only when neither matches.
    """
    cfg = context.dist_config
    moe_tp = cfg.moe_tp
    if moe_tp == 1:
        moe_group = DistGroup(rank=0)
    elif moe_tp == cfg.attn_tp:
        moe_group = context.attn_tp_group
    elif moe_tp == cfg.mlp_tp:
        # NOTE(review): the reused MLP group was built with mlp_tp_mode, which
        # may differ from moe_tp_mode — confirm this is intended.
        moe_group = context.mlp_tp_group
    else:
        moe_group = _build_tp_group_impl(
            moe_tp,
            context.rank,
            cfg.world_size,
            timeout=timeout,
            cpu_backend=cpu_backend,
            ccl_backend=ccl_backend,
            attn_tp=cfg.attn_tp,
            tp_mode=cfg.moe_tp_mode,
        )
    context.moe_tp_group = moe_group


def _build_tp_group(context: 'DistContext', timeout: timedelta, cpu_backend: str = 'gloo', ccl_backend: str = 'nccl'):
    """Build the attention, MLP and MoE tp groups; ``tp_group`` aliases the attention one."""
    # Order matters: the MLP builder may reuse the attention group and the MoE
    # builder may reuse either of them.
    for builder in (_build_attn_tp_group, _build_mlp_tp_group, _build_moe_tp_group):
        builder(context, timeout, cpu_backend, ccl_backend)
    context.tp_group = context.attn_tp_group


@dataclass
class DistContext:
    """Distributed state of this process: ranks, process groups and config.

    Normally created through :meth:`build`.
    """
    # rank of this process among ``dist_config.world_size`` ranks
    rank: int = 0
    # data parallel rank, copied from ``dist_config.dp_rank`` by ``build``
    dp_rank: int = 0
    # ``rank % ep``, set by ``_build_ep_group`` when ep > 1
    ep_rank: int = 0

    # tensor parallel groups. ``tp_group`` aliases ``attn_tp_group`` once
    # built; mlp/moe groups may alias the attn/mlp group when sizes match.
    tp_group: DistGroup = None
    attn_tp_group: DistGroup = None
    mlp_tp_group: DistGroup = None
    moe_tp_group: DistGroup = None

    # group over all ranks, created with the cpu backend ('gloo' in build)
    cpu_group: dist.ProcessGroup = None
    # expert parallel group of this rank, and the list of all ep groups
    ep_gpu_group: dist.ProcessGroup = None
    ep_gpu_groups: List[dist.ProcessGroup] = None
    dist_config: DistConfig = None

    @classmethod
    def _build_ep_group(cls, context: 'DistContext', timeout: timedelta, ccl_backend: str = 'nccl'):
        """Build expert parallel gpu groups; no-op when ``ep <= 1``.

        The ``world_size`` ranks are split into contiguous chunks of ``ep``
        ranks with one process group per chunk.
        """
        dist_config = context.dist_config
        ep = dist_config.ep
        if ep <= 1:
            return

        dp_rank = context.dp_rank
        world_size = dist_config.world_size
        ep_rank = context.rank % ep
        # NOTE(review): the chunk that contains ``context.rank`` is
        # ``context.rank // ep``; ``dp_rank // ep`` picks the same chunk only in
        # some layouts (e.g. ep == world_size, a single chunk). Confirm intent.
        ep_group_id = dp_rank // ep
        ranks = range(world_size)
        ep_gpu_groups = []
        # every rank creates every group: torch requires all processes to
        # enter ``dist.new_group``, even those that are not members
        for start in range(0, world_size, ep):
            ep_ranks = ranks[start:start + ep]
            group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend)
            ep_gpu_groups.append(group)
        ep_gpu_group = ep_gpu_groups[ep_group_id]

        context.ep_rank = ep_rank
        context.ep_gpu_group = ep_gpu_group
        context.ep_gpu_groups = ep_gpu_groups

    @classmethod
    def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'):
        """Build dist context.

        With ``world_size == 1`` only trivial ``DistGroup(rank=0)`` groups are
        set and no process group is created. Otherwise torch.distributed must
        already be initialized; the cpu group, tp groups and ep groups are built.
        """
        # effectively "no timeout" (~97 years)
        timeout = timedelta(days=35600)
        cpu_backend = 'gloo'

        if dist_config is None:
            dist_config = DistConfig()

        dp_rank = dist_config.dp_rank
        world_size = dist_config.world_size
        context = DistContext(rank=rank,
                              dp_rank=dp_rank,
                              dist_config=dist_config,
                              attn_tp_group=DistGroup(rank=0),
                              mlp_tp_group=DistGroup(rank=0),
                              moe_tp_group=DistGroup(rank=0),
                              tp_group=DistGroup(rank=0))
        if world_size == 1:
            return context

        assert dist.is_initialized()

        # cpu group
        context.cpu_group = dist.new_group(ranks=list(range(world_size)), timeout=timeout, backend=cpu_backend)

        # tp
        _build_tp_group(context, timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend)

        # ep
        cls._build_ep_group(context, timeout, ccl_backend=ccl_backend)

        return context

    def close(self):
        """Close the tp groups and destroy the ep groups.

        Does nothing if torch.distributed is not initialized. The mlp/moe tp
        groups may be the very same object as the attn (or mlp) tp group, so
        each distinct group object is closed only once.
        """
        if not dist.is_initialized():
            return
        closed_ids = set()
        for group in (self.attn_tp_group, self.mlp_tp_group, self.moe_tp_group):
            if group is None or id(group) in closed_ids:
                continue
            closed_ids.add(id(group))
            group.close()
        if self.ep_gpu_groups is not None:
            for group in self.ep_gpu_groups:
                dist.destroy_process_group(group)
            self.ep_gpu_groups = None


# Built at import time with rank 0 and a default ``DistConfig()``; passed to
# ``DistManager`` as its initial context. (``build`` creates no process group
# when ``world_size == 1``.)
DefaultContext = DistContext.build()


@singleton
class DistManager(CtxMgrBase[DistContext]):
    """Context manager holding ``DistContext`` objects (wrapped by ``singleton``)."""

    def __init__(self):
        # ``DefaultContext`` is handed to the base manager as the initial context
        super().__init__(DefaultContext)

    def current_config(self) -> DistConfig:
        """Return the ``DistConfig`` attached to the current context."""
        context = self.current_context()
        return context.dist_config


def get_dist_manager():
    """Get the distributed context manager.

    ``DistManager`` is decorated with ``singleton``, so this presumably returns
    the same instance on every call.
    """
    return DistManager()


def get_world_rank():
    """Return ``(world_size, rank)`` of the current distributed context."""
    context = get_dist_manager().current_context()
    size = context.dist_config.world_size
    return size, context.rank


def get_tp_world_rank(layer_type: Optional[str] = None):
    """Return ``(tp_size, tp_rank)`` for a layer type.

    Args:
        layer_type: ``None`` for the generic tp group, or one of ``'attn'``,
            ``'mlp'``, ``'moe'``.

    Raises:
        RuntimeError: for any other ``layer_type``.
    """
    ctx = get_dist_manager().current_context()
    cfg = ctx.dist_config
    if layer_type is None:
        size, group = cfg.tp, ctx.tp_group
    elif layer_type == 'attn':
        size, group = cfg.attn_tp, ctx.attn_tp_group
    elif layer_type == 'mlp':
        size, group = cfg.mlp_tp, ctx.mlp_tp_group
    elif layer_type == 'moe':
        size, group = cfg.moe_tp, ctx.moe_tp_group
    else:
        raise RuntimeError(f'Unknown layer type: {layer_type}')
    return size, group.rank


def get_dp_world_rank():
    """Return ``(dp_size, dp_rank)`` of the current distributed context."""
    context = get_dist_manager().current_context()
    dp_size = context.dist_config.dp
    return dp_size, context.dp_rank


def get_ep_world_rank():
    """Return ``(ep_size, ep_rank)`` of the current distributed context."""
    context = get_dist_manager().current_context()
    ep_size = context.dist_config.ep
    return ep_size, context.ep_rank


def _check_group_device(device: str):
    """Assert that ``device`` is ``'cpu'`` or ``'gpu'``."""
    supported = ('cpu', 'gpu')
    message = 'Expect process group device in ("cpu", "gpu"), ' f'but get {device}.'
    assert device in supported, message


def get_process_group(device: str = None):
    """Return the default world process group.

    ``device`` is accepted but unused: the same group is returned for any value.
    """
    world_group = dist.GroupMember.WORLD
    return world_group


def get_dist_group(layer_type: str = 'attn'):
    """Return the ``DistGroup`` of the current context for a layer type.

    Raises:
        RuntimeError: if ``layer_type`` is not ``'attn'``, ``'mlp'`` or ``'moe'``.
    """
    ctx = get_dist_manager().current_context()
    candidates = (('attn', 'attn_tp_group'), ('mlp', 'mlp_tp_group'), ('moe', 'moe_tp_group'))
    for name, attr in candidates:
        if layer_type == name:
            return getattr(ctx, attr)
    raise RuntimeError(f'Unknown layer type: {layer_type}')


def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'):
    """Return the cpu or gpu process group of a layer type's tp group.

    Returns ``None`` if the ``DistGroup`` itself is ``None``.
    """
    _check_group_device(device)
    dist_group = get_dist_group(layer_type)
    if dist_group is None:
        return None
    return dist_group.cpu_group if device == 'cpu' else dist_group.gpu_group


def get_group(group_type: str, device: str):
    """Resolve a group name (``'tp'``, ``'world'`` or ``'all'``) to a process group.

    Raises:
        RuntimeError: for any other ``group_type``.
    """
    if group_type == 'tp':
        return get_tp_group(device)
    if group_type in ('world', 'all'):
        return get_process_group(device)
    raise RuntimeError(f'Unknown group type: {group_type}')


def all_reduce(tensor, op=ReduceOp.SUM, group='tp', async_op=False):
    """All-reduce ``tensor`` across ``group``.

    ``group`` may be a process group or a name resolved via
    ``get_group(name, 'gpu')``.
    """
    process_group = get_group(group, 'gpu') if isinstance(group, str) else group
    return dist.all_reduce(tensor, op, process_group, async_op)


def broadcast(tensor, src, group='tp', async_op=False):
    """Broadcast ``tensor`` from rank ``src`` over ``group``.

    ``group`` may be a process group or a name resolved via
    ``get_group(name, 'gpu')``.
    """
    process_group = get_group(group, 'gpu') if isinstance(group, str) else group
    return dist.broadcast(tensor, src, process_group, async_op)


def all_gather_object(object_list, obj, group='tp'):
    """Gather picklable ``obj`` from every rank of ``group`` into ``object_list``.

    A string ``group`` is resolved to its cpu process group via
    ``get_group(name, 'cpu')``.
    """
    process_group = get_group(group, 'cpu') if isinstance(group, str) else group
    return dist.all_gather_object(object_list, obj, group=process_group)


def all_gather(tensor_list, tensor, group='tp', async_op=False):
    """All-gather ``tensor`` from every rank of ``group`` into ``tensor_list``.

    A string ``group`` is resolved via ``get_group(name, 'gpu')``.
    """
    process_group = get_group(group, 'gpu') if isinstance(group, str) else group
    return dist.all_gather(tensor_list, tensor, group=process_group, async_op=async_op)


def all_gather_into_tensor(output_tensor, input_tensor, group='tp', async_op=False):
    """All-gather ``input_tensor`` from every rank into one ``output_tensor``.

    A string ``group`` is resolved via ``get_group(name, 'gpu')``.
    """
    process_group = get_group(group, 'gpu') if isinstance(group, str) else group
    return dist.all_gather_into_tensor(output_tensor, input_tensor, group=process_group, async_op=async_op)


def reduce_scatter(output, input_list, op=ReduceOp.SUM, group='tp', async_op=False):
    """Reduce ``input_list`` across ``group`` and scatter the result into ``output``.

    A string ``group`` is resolved via ``get_group(name, 'gpu')``.
    """
    process_group = get_group(group, 'gpu') if isinstance(group, str) else group
    return dist.reduce_scatter(output, input_list, op=op, group=process_group, async_op=async_op)


def gather_by_tp_sizes(x: torch.Tensor,
                       tp_sizes: List[int],
                       group: Optional[dist.ProcessGroup] = None,
                       async_op: bool = False):
    """All-gather ``x`` along dim -2, where rank ``i`` contributes ``tp_sizes[i]`` rows.

    The output has the shape of ``x`` with dim -2 replaced by ``sum(tp_sizes)``.
    It returns ``(output, work_handle)`` when ``async_op`` is true, otherwise
    just the output.
    """
    assert all(size >= 0 for size in tp_sizes), f'Invalid tp sizes: {tp_sizes}'
    leading = x.shape[:-2]
    trailing = x.shape[-1:]
    gathered = x.new_empty((*leading, sum(tp_sizes), *trailing))
    # one view per rank into ``gathered``; all_gather writes each rank's slice in place
    slots = list(gathered.split(tp_sizes, -2))
    work = dist.all_gather(slots, x, group=group, async_op=async_op)
    return (gathered, work) if async_op else gathered


def reduce_scatter_by_tp_sizes(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup):
    """Reduce-scatter ``out`` split along dim -2 by ``tp_sizes``.

    Each chunk of ``out`` is repeated ``attn_tp`` times (taken from the current
    dist config) to form the input list. The output is ``inputs[rank]``, a view
    of ``out`` that is also one of the inputs. Returns that view.
    """
    attn_tp = get_dist_manager().current_config().attn_tp
    inputs = []
    for chunk in out.split(tp_sizes, -2):
        inputs.extend([chunk] * attn_tp)
    target = inputs[rank]
    dist.reduce_scatter(target, inputs, group=group)
    return target


================================================
FILE: lmdeploy/pytorch/engine/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .engine import Engine
from .engine_instance import EngineInstance

# Public names of the engine package (re-exported from the imports above).
__all__ = ['Engine', 'EngineInstance']


================================================
FILE: lmdeploy/pytorch/engine/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
                                                   DistServeInitRequest)


class EngineBase:
    """Base engine interface.

    Every method except ``start_loop`` raises ``NotImplementedError`` here;
    subclasses are expected to override what they support.
    """

    def close(self) -> None:
        """Shut the engine down."""
        raise NotImplementedError('This method is not implemented.')

    def start_loop(self) -> None:
        """Start the engine loop; the base implementation does nothing."""
        return None

    def end_session(self, session_id: int):
        """Terminate the session identified by ``session_id``."""
        raise NotImplementedError('This method is not implemented.')

    def p2p_initialize(self, conn_request: DistServeInitRequest):
        """Initialize the p2p (RDMA) link described by ``conn_request``."""
        raise NotImplementedError('This method is not implemented.')

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Connect the p2p (RDMA) link described by ``conn_request``."""
        raise NotImplementedError('This method is not implemented.')

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Drop a connection.

        1. drop engine connection (zmq connection)
        2. TODO(JimyMa) drop RDMA Connection.
        """
        raise NotImplementedError('This method is not implemented.')

    def create_instance(self, cuda_stream_id=0):
        """Create an engine instance."""
        raise NotImplementedError('This method is not implemented.')


class EngineInstanceBase:
    """Base engine-instance interface; every coroutine raises ``NotImplementedError``."""

    async def async_end(self, session_id: int):
        """End the session ``session_id``."""
        raise NotImplementedError('This method is not implemented.')

    async def async_cancel(self, session_id: int):
        """Cancel the ongoing streaming inference of ``session_id``."""
        raise NotImplementedError('This method is not implemented.')

    async def async_stream_infer(self, *args, **kwargs):
        """Submit a streaming inference request."""
        raise NotImplementedError('This method is not implemented.')


================================================
FILE: lmdeploy/pytorch/engine/cache_engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Sequence, Tuple

import torch

from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import (AssignmentInstruct, DistServeRegisterMRMessage, MigrationAssignment,
                                              MigrationExecutionBatch)
from lmdeploy.utils import get_logger

from ..config import CacheConfig, ModelConfig

# Type alias for a pair of cache tensors (by its name, presumably (key, value)).
KVCache = Tuple[torch.Tensor, torch.Tensor]

logger = get_logger('lmdeploy')


def round_up(x: int, alignment: int) -> int:
    """Return the smallest multiple of ``alignment`` that is >= ``x``."""
    num_chunks = (x + alignment - 1) // alignment
    return num_chunks * alignment


@dataclass
class CacheDesc:
    """Shape and dtype of one cache slot, plus derived sizes.

    After init:
      * ``numel``: element count of ``shape``.
      * ``size``: byte size, ``numel * dtype.itemsize``.
      * ``aligned_size``: ``size`` rounded up to a multiple of ``alignment``.
    """
    shape: List[int]
    dtype: torch.dtype
    alignment: int = 256

    def __post_init__(self):
        count = math.prod(self.shape)
        nbytes = count * self.dtype.itemsize
        self.numel, self.size = count, nbytes
        self.aligned_size = round_up(nbytes, self.alignment)


def _get_kv_cache_dtype(model_config: ModelConfig):
    """Return the kv cache dtype: fp8 (e4m3fn) for the MLA fp8 cache, else the model dtype."""
    model_dtype = model_config.dtype
    return torch.float8_e4m3fn if model_config.use_mla_fp8_cache else model_dtype


# Last-dim size of the key cache when ``use_mla_fp8_cache`` is set (the value
# cache then has size 0 because key and value are shared). With the fp8 dtype
# this is also the byte width: 512*1 + 4*4 + 64*2 = 656 — presumably 512 fp8
# latent values, 4 fp32 scales and 64 16-bit rope dims; TODO confirm.
MLA_FP8_HEAD_DIM = 656


class CacheEngine:
    """Host and Device memory maintainer.

    Each of the gpu and cpu caches is one uint8 memory pool of shape
    ``(num_layers, num_blocks, bytes_per_block)``. The pool is sliced into
    typed views: key, value, optional quant scale/zero, optional custom caches.

    Args:
        cache_config (CacheConfig): config of the cache information.
        model_config (ModelConfig): config of the model.
        rank (int): distribution rank, 0 on non-distributed environment.
        tp_rank (int): tensor parallel rank; selects this rank's endpoint
            in ``p2p_connect``.
        world_size (int): distribution world size, 1 on non-distributed
            environment. The kv heads are split across it.
        cache_stream (torch.cuda.Stream): the stream used for cache engine swap,
            if set to None, it's created in CacheEngine.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        rank: int = 0,
        tp_rank: int = 0,
        world_size: int = 1,
        cache_stream: torch.cuda.Stream = None,
    ) -> None:
        self.world_size = world_size
        self.rank = rank
        self.tp_rank = tp_rank
        self.cache_config = cache_config
        self.model_config = model_config

        self.block_size = cache_config.block_size
        self.num_layers = model_config.num_layers
        self.kv_cache_dtype = _get_kv_cache_dtype(self.model_config)

        # kv quantization is switched off when the fp8 MLA cache is used.
        # NOTE: this mutates the caller's ``cache_config`` in place.
        if self.model_config.use_mla_fp8_cache:
            cache_config.quant_policy = 0

        # NOTE(review): ``self.kv_cache_dtype`` is not read by ``allocate_caches``,
        # which always uses uint8 for quant_policy 4/8 (also on ascend/npu).
        if cache_config.quant_policy > 0:
            if self.cache_config.device_type in ['cuda']:
                self.kv_cache_dtype = torch.uint8
            elif self.cache_config.device_type in ['ascend', 'npu']:
                self.kv_cache_dtype = torch.int8
            else:
                raise ValueError(f'unsupported device_type {self.cache_config.device_type}')

        # Initialize the cache.
        self.local_gpu_cache = self.allocate_gpu_cache()
        self.local_cpu_cache = self.allocate_cpu_cache()

        self.migration_backend_impl: Optional[MigrationBackendImpl] = None

        # Initialize the stream for caching operations.
        self.cache_stream = cache_stream or torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization.
        self.events = torch.cuda.Event()

        logger.debug(f'Initialize cache engine with {cache_config.num_gpu_blocks}'
                     f' gpu blocks and {cache_config.num_cpu_blocks} cpu blocks.')

    @property
    def cpu_cache(self):
        """Cpu cache: a per-layer list of tuples of cache views."""
        return self.local_cpu_cache

    @property
    def gpu_cache(self):
        """Gpu cache: a per-layer list of tuples of cache views."""
        return self.local_gpu_cache

    @property
    def num_gpu_blocks(self):
        """Num gpu blocks."""
        return self.cache_config.num_gpu_blocks

    @property
    def num_cpu_blocks(self):
        """Num cpu blocks."""
        return self.cache_config.num_cpu_blocks

    @classmethod
    def _get_key_block_shape_impl(cls,
                                  model_config: ModelConfig,
                                  block_size: int,
                                  head_size: int,
                                  world_size: int = 1,
                                  quant_policy: Literal[0, 4, 8] = 0):
        """Get the shape of a single key block.

        kv heads are split evenly by ``world_size``. The fp8 MLA cache uses
        ``(block_size, num_heads, MLA_FP8_HEAD_DIM)``. Otherwise the attention
        backend decides, with ``head_size`` halved for 4-bit packing.
        """
        attn_backend = get_backend()
        dtype = model_config.dtype
        num_heads = model_config.num_key_value_heads

        # split heads by tp
        assert num_heads % world_size == 0, \
            f'num_heads: {num_heads}, world_size: {world_size}'
        num_heads = num_heads // world_size

        # patch for flash mla
        if model_config.use_mla_fp8_cache:
            return (block_size, num_heads, MLA_FP8_HEAD_DIM)

        if quant_policy == 4:  # pack head_dim to uint8
            assert head_size % 2 == 0, \
                f'head_size: {head_size}, quant_policy: {quant_policy}'
            head_size = head_size // 2
        return attn_backend.get_k_block_shape(block_size, num_heads, head_size, dtype)

    @classmethod
    def _get_value_block_shape_impl(cls,
                                    model_config: ModelConfig,
                                    block_size: int,
                                    head_size: int,
                                    world_size: int = 1,
                                    quant_policy: Literal[0, 4, 8] = 0):
        """Get the shape of a single value block.

        Same rules as the key block, except the fp8 MLA cache gets a
        zero-sized last dim because key and value are shared.
        """
        attn_backend = get_backend()
        dtype = model_config.dtype
        num_heads = model_config.num_key_value_heads

        # split heads by tp
        assert num_heads % world_size == 0, \
            f'num_heads: {num_heads}, world_size: {world_size}'
        num_heads = num_heads // world_size

        # patch for flash mla
        if model_config.use_mla_fp8_cache:
            # flash mla shared key and value
            return (block_size, num_heads, 0)

        if quant_policy == 4:  # pack head_dim to uint8
            assert head_size % 2 == 0, \
                f'head_size: {head_size}, quant_policy: {quant_policy}'
            head_size = head_size // 2

        return attn_backend.get_v_block_shape(block_size, num_heads, head_size, dtype)

    @classmethod
    def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:
        """Get key cache description (uint8 dtype when quant_policy is 4 or 8)."""
        head_size = model_config.k_head_dim
        if head_size is None:
            head_size = model_config.head_dim
        shape = cls._get_key_block_shape_impl(
            model_config,
            block_size=cache_config.block_size,
            head_size=head_size,
            world_size=world_size,
            quant_policy=cache_config.quant_policy,
        )
        shape = list(shape)
        dtype = _get_kv_cache_dtype(model_config)
        if cache_config.quant_policy in (4, 8):
            dtype = torch.uint8
        return CacheDesc(shape=shape, dtype=dtype)

    @classmethod
    def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:
        """Get value cache description (uint8 dtype when quant_policy is 4 or 8)."""
        head_size = model_config.v_head_dim
        if head_size is None:
            head_size = model_config.head_dim
        shape = cls._get_value_block_shape_impl(
            model_config,
            block_size=cache_config.block_size,
            head_size=head_size,
            world_size=world_size,
            quant_policy=cache_config.quant_policy,
        )
        shape = list(shape)
        dtype = _get_kv_cache_dtype(model_config)
        if cache_config.quant_policy in (4, 8):
            dtype = torch.uint8
        return CacheDesc(shape=shape, dtype=dtype)

    @classmethod
    def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig,
                              cache_config: CacheConfig):
        """Get quant cache descs.

        When quantization is off, this returns ``[]``. Otherwise it returns the
        key and value scale/zero descs: the kv shape with its last dim replaced
        by 2, in the model dtype.
        """
        if cache_config.quant_policy == 0:
            return []

        dtype = model_config.dtype
        key_scale_zero_shape = k_cache_desc.shape[:-1] + [2]
        val_scale_zero_shape = v_cache_desc.shape[:-1] + [2]
        key_scale_zero_desc = CacheDesc(shape=key_scale_zero_shape, dtype=dtype)
        val_scale_zero_desc = CacheDesc(shape=val_scale_zero_shape, dtype=dtype)
        return [key_scale_zero_desc, val_scale_zero_desc]

    @classmethod
    def get_custom_cache_descs(cls, model_config: ModelConfig, cache_config: CacheConfig) -> List[CacheDesc]:
        """Get descs for ``model_config.cache_shapes``, each prefixed by ``block_size``."""
        if len(model_config.cache_shapes) == 0:
            return []

        block_size = cache_config.block_size

        descs = []
        for shape, dtype in model_config.cache_shapes:
            custom_shape = (block_size, *shape)
            desc = CacheDesc(shape=custom_shape, dtype=dtype)
            descs.append(desc)
        return descs

    @classmethod
    def allocate_caches(cls, num_blocks: int, model_config: ModelConfig, cache_config: CacheConfig, world_size: int,
                        device: str):
        """Allocate one memory pool and slice it into typed caches.

        Returns:
            ``(mem_pool, caches)``. ``mem_pool`` is a zero-filled uint8 tensor of
            shape ``(num_layers, num_blocks, pool_bytes)``. ``caches`` holds one
            view per desc (key, value, quant, custom), each shaped
            ``(num_layers, num_blocks, *desc.shape)``.
        """

        num_layers = model_config.num_layers

        # get all descs
        k_cache_desc = cls.get_k_cache_desc(model_config, cache_config, world_size)
        v_cache_desc = cls.get_v_cache_desc(model_config, cache_config, world_size)
        quant_cache_descs = cls.get_quant_cache_descs(k_cache_desc, v_cache_desc, model_config, cache_config)
        custom_cache_descs = cls.get_custom_cache_descs(model_config, cache_config)
        cache_descs = [k_cache_desc, v_cache_desc] + quant_cache_descs + custom_cache_descs

        # get mempool size
        mem_pool_size = 0
        for desc in cache_descs:
            mem_pool_size += desc.aligned_size

        # create pool
        mem_pool = torch.zeros((num_layers, num_blocks, mem_pool_size), dtype=torch.uint8, device=device)

        # slice caches; each desc starts at an aligned byte offset in the pool
        caches = []
        remain_pool = mem_pool
        for desc in cache_descs:
            cache = remain_pool[:, :, :desc.size].view(desc.dtype).view((num_layers, num_blocks, *desc.shape))
            remain_pool = remain_pool[:, :, desc.aligned_size:]
            caches.append(cache)
        return mem_pool, caches

    def allocate_gpu_cache(self):
        """Allocate caches on GPU.

        Stores the pool in ``full_gpu_cache`` and the per-layer tuples of views
        in ``local_gpu_cache``, which is also returned.
        """
        mem_pool, caches = self.allocate_caches(
            num_blocks=self.num_gpu_blocks,
            model_config=self.model_config,
            cache_config=self.cache_config,
            world_size=self.world_size,
            device='cuda',
        )
        self.full_gpu_cache = mem_pool
        # zip over dim 0 (layers): one tuple of views per layer
        self.local_gpu_cache = list(zip(*caches))
        return self.local_gpu_cache

    def allocate_cpu_cache(self):
        """Allocate caches on Host.

        Stores the pool in ``full_cpu_cache`` and the per-layer tuples of views
        in ``local_cpu_cache``, which is also returned.
        """
        mem_pool, caches = self.allocate_caches(
            num_blocks=self.num_cpu_blocks,
            model_config=self.model_config,
            cache_config=self.cache_config,
            world_size=self.world_size,
            device='cpu',
        )
        self.full_cpu_cache = mem_pool
        self.local_cpu_cache = list(zip(*caches))
        return self.local_cpu_cache

    @staticmethod
    def get_custom_cache_shape_impl(num_layers: int, num_blocks: int, block_size: int, shape: List[int]):
        """Get the full custom cache shape ``(num_layers, num_blocks, block_size, *shape)``."""
        return (num_layers, num_blocks, block_size, *shape)

    @staticmethod
    def _allocate_single_custom_cache(shape: Sequence[int], dtype: torch.dtype, device: str):
        """Allocate one uninitialized custom cache tensor."""
        return torch.empty(shape, dtype=dtype, device=device)

    def allocate_custom_cache(self, device: str):
        """Allocate one standalone tensor per ``model_config.cache_shapes`` entry on ``device``.

        The block count is always ``num_gpu_blocks``, whatever ``device`` is.
        """
        num_layers = self.model_config.num_layers
        custom_caches = []
        for shape, dtype in self.model_config.cache_shapes:
            custom_shape = self.get_custom_cache_shape_impl(
                num_layers=num_layers,
                num_blocks=self.num_gpu_blocks,
                block_size=self.block_size,
                shape=shape,
            )
            custom_cache = self._allocate_single_custom_cache(shape=custom_shape, dtype=dtype, device=device)
            custom_caches.append(custom_cache)
        return custom_caches

    @torch.inference_mode()
    def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor], src_to_dst: Dict[int, int]):
        """Move caches from src memory to dst memory on ``self.cache_stream``.

        Does nothing when ``src_to_dst`` is empty. Otherwise it records
        ``self.events`` on the cache stream after the copies are enqueued.

        Args:
            src (List[torch.Tensor]): Source memory pools; blocks are on dim 1.
            dst (List[torch.Tensor]): Destination memory pools, paired with ``src``.
            src_to_dst (Dict[int, int]): Map from source block id to destination block id.
        """
        if not src_to_dst:
            # nothing to copy; unpacking an empty mapping below would raise
            return
        BLOCKS_PER_COPY = 2
        num_copy = len(src_to_dst)
        src_idx, dst_idx = list(zip(*src_to_dst.items()))
        src_idx = torch.tensor(src_idx, device=src[0].device)
        dst_idx = torch.tensor(dst_idx, device=dst[0].device)
        with torch.cuda.stream(self.cache_stream):
            for scache, dcache in zip(src, dst):
                for idx in range(0, num_copy, BLOCKS_PER_COPY):
                    sidx = src_idx[idx:idx + BLOCKS_PER_COPY]
                    didx = dst_idx[idx:idx + BLOCKS_PER_COPY]
                    sdata = scache[:, sidx]
                    dcache.index_copy_(1, didx, sdata.to(dcache.device))
            self.events.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Move cache from Host to Device.

        Args:
            src_to_dst (Dict[int, int]): Map between src and dst.
        """
        self._swap([self.full_cpu_cache], [self.full_gpu_cache], src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Move cache from Device to Host.

        Args:
            src_to_dst (Dict[int, int]): Map between src and dst.
        """
        self._swap([self.full_gpu_cache], [self.full_cpu_cache], src_to_dst)

    @classmethod
    def get_cache_block_size(cls, cache_config: CacheConfig, model_config: ModelConfig, world_size: int = 1) -> int:
        """Get the memory needed by one cache block across all layers.

        Computed by allocating a single-block pool on the ``meta`` device.

        Args:
            cache_config (CacheConfig): The config of the cache.
            model_config (ModelConfig): The config of the model.
            world_size (int): Number of ranks the kv heads are split across.

        Return:
            int: Required memory size in bytes.
        """
        mem_pool, _ = cls.allocate_caches(
            num_blocks=1,
            model_config=model_config,
            cache_config=cache_config,
            world_size=world_size,
            device='meta',
        )

        return mem_pool.numel() * mem_pool.element_size()

    # Methods for PD Disaggregation begin.

    def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistServeKVTransferEndpointInfo:
        """Initialize the migration backend and register ``full_gpu_cache`` as a memory region.

        Returns this rank's endpoint info for the requested protocol.
        """
        if not self.migration_backend_impl:
            self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]()
        migration_init_request.rank = self.rank
        self.migration_backend_impl.p2p_initialize(migration_init_request)
        for i, t in enumerate([self.full_gpu_cache]):
            if t.numel() == 0:
                continue
            register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol,
                                                             remote_engine_id=migration_init_request.remote_engine_id,
                                                             mr_key=i,
                                                             addr=t.data_ptr(),
                                                             offset=t.storage_offset(),
                                                             length=t.numel() * t.itemsize)
            self.migration_backend_impl.register_memory_region(register_mr_request)
        return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol,
                                               endpoint_info=json.dumps(
                                                   self.migration_backend_impl.endpoint_info(
                                                       migration_init_request.remote_engine_id,
                                                       migration_init_request.protocol)))

    def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]):
        """Connect to ``remote_engine_id`` using the endpoint at index ``tp_rank``."""
        self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank])

    async def migrate(self, migration_execution_inputs: MigrationExecutionBatch):
        """Migrate the requested blocks from remote engines into the local gpu pool.

        For every layer and ``(target, source)`` block pair, one transfer of
        ``assignment_len`` bytes is issued. Does nothing when there are no
        requests.
        """
        if not migration_execution_inputs.requests:
            # nothing to migrate; also avoids using an unbound ``remote_engine_id``
            return

        # bytes per (layer, block) in the uint8 pool, and bytes per layer
        assignment_len = self.full_gpu_cache.element_size() * self.full_gpu_cache.size(-1)
        layer_stride = self.cache_config.num_gpu_blocks * assignment_len

        def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote_layer_stride):
            return [
                AssignmentInstruct(mr_key=mr_key,
                                   target_offset=block_id[0] * assignment_len + layer * remote_layer_stride,
                                   source_offset=block_id[1] * assignment_len + layer * layer_stride,
                                   length=assignment_len) for layer in range(self.model_config.num_layers)
                for block_id in block_ids
            ]

        assignment_batch: List[AssignmentInstruct] = []
        for migration_exe_req in migration_execution_inputs.requests:
            remote_engine_id = migration_exe_req[0]
            blocks_to_migration = migration_exe_req[1]
            remote_layer_stride = self.migration_backend_impl.links[
                remote_engine_id].remote_engine_config.num_gpu_blocks * assignment_len

            for i, t in enumerate([self.full_gpu_cache]):
                if t.numel() == 0:
                    continue
                assignment_batch.extend(
                    get_assignment_batch(i, blocks_to_migration, assignment_len, layer_stride, remote_layer_stride))
        # NOTE(review): the whole batch is sent with the last request's
        # ``remote_engine_id``; this assumes one remote engine per batch — confirm.
        await self.migration_backend_impl.p2p_migrate(
            MigrationAssignment(
                protocol=migration_execution_inputs.protocol,
                remote_engine_id=remote_engine_id,
                batch=assignment_batch,
            ))

    # Methods for PD Disaggregation end.


class StateCacheEngine:
    """Owns one uint8 pool holding ``num_state_caches`` sets of state caches."""

    def __init__(self, cache_config: CacheConfig):
        self.cache_config = cache_config
        pool, caches = self.allocate_caches(num_caches=cache_config.num_state_caches,
                                            state_shapes=cache_config.states_shapes,
                                            device='cuda')
        self.mem_pool = pool
        self._state_caches = caches

    @staticmethod
    def allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device):
        """Allocate a zeroed pool and slice typed state caches out of it.

        Returns ``(mem_pool, caches)``. ``mem_pool`` has shape
        ``(num_caches, pool_bytes)``, and each cache view is
        ``(num_caches, *shape)``. With no shapes or no caches, it returns an
        empty ``(0, 0)`` pool and ``[]``.
        """

        if len(state_shapes) == 0 or num_caches == 0:
            return torch.empty((0, 0), dtype=torch.uint8, device=device), []

        descs = [CacheDesc(shape, dtype) for shape, dtype in state_shapes]
        pool_bytes = sum(desc.aligned_size for desc in descs)
        pool = torch.zeros((num_caches, pool_bytes), dtype=torch.uint8, device=device)

        # each desc occupies [offset, offset + size) of every row; offsets advance by aligned_size
        views = []
        offset = 0
        for desc in descs:
            raw = pool[:, offset:offset + desc.size]
            views.append(raw.view(desc.dtype).view((num_caches, *desc.shape)))
            offset += desc.aligned_size
        return pool, views

    @staticmethod
    def get_cache_state_size(state_shapes: List[Tuple[Tuple[int], torch.dtype]]) -> int:
        """Return the bytes needed for one set of state caches.

        Args:
            state_shapes (List[Tuple[Tuple[int], torch.dtype]]): The shapes and dtypes of the states.

        Return:
            int: Required memory size in bytes (measured on the ``meta`` device).
        """
        pool, _ = StateCacheEngine.allocate_caches(num_caches=1, state_shapes=state_shapes, device='meta')
        return pool.numel() * pool.element_size()

    @property
    def state_caches(self):
        """The typed state cache views."""
        return self._state_caches

    def init_caches(self, idx: torch.Tensor, mask: torch.Tensor):
        """Zero the pool rows selected by ``idx`` where ``mask`` is true.

        idx: indices of caches to be initialized.
        mask: mask to indicate which idx to be initialized.
        """
        if idx is None:
            return
        if len(self._state_caches) <= 0:
            return

        # scatter the per-idx mask into a mask over all caches, then zero in place
        full_mask = torch.zeros((self.cache_config.num_state_caches, ), dtype=torch.bool, device=idx.device)
        full_mask.index_copy_(0, idx, mask)
        broadcast_shape = (-1, ) + (1, ) * (self.mem_pool.dim() - 1)
        self.mem_pool.masked_fill_(full_mask.view(broadcast_shape), 0)


================================================
FILE: lmdeploy/pytorch/engine/config_builder.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os

from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig
from lmdeploy.pytorch.config import (BackendConfig, CacheConfig, DistConfig, MiscConfig, SchedulerConfig,
                                     SpecDecodeConfig)
from lmdeploy.utils import get_logger, get_max_batch_size, get_model


class ConfigBuilder:
    """Derive the individual pytorch sub-configs from a ``PytorchEngineConfig``."""

    @staticmethod
    def update_engine_config(engine_config: PytorchEngineConfig):
        """Return a normalized copy of ``engine_config``.

        ``None`` yields a default ``PytorchEngineConfig``; any other value is
        deep-copied so the caller's object is left untouched. Afterwards:

        * an unset ``max_batch_size`` is filled with ``get_max_batch_size``;
        * with ``dllm_block_length`` set, ``max_batch_size`` is shrunk so that
          ``max_batch_size * dllm_block_length`` fits in ``max_prefill_token_num``;
        * ``dp`` is reset to 1 (and ``dp_rank`` to 0) when neither tensor nor
          expert parallelism is enabled.
        """
        logger = get_logger('lmdeploy')

        cfg = PytorchEngineConfig() if engine_config is None else copy.deepcopy(engine_config)

        if cfg.max_batch_size is None:
            cfg.max_batch_size = get_max_batch_size(cfg.device_type)

        block_len = cfg.dllm_block_length
        if block_len is not None:
            token_budget = cfg.max_prefill_token_num
            requested_bs = cfg.max_batch_size
            if requested_bs * block_len > token_budget:
                cfg.max_batch_size = token_budget // block_len
                logger.warning(f'Update max_batch_size to {cfg.max_batch_size} '
                               f'since dllm_block_length({block_len}) * max_batch_size '
                               f'({requested_bs}) > max_prefill_token_num ({token_budget}).')

        # dp without tp or ep is turned off
        if cfg.dp != 1 and cfg.tp == 1 and cfg.ep == 1:
            logger.warning('Data parallelism is enabled but tensor parallelism and '
                           'expert parallelism are not enabled. Setting dp=1.')
            cfg.dp = 1
            cfg.dp_rank = 0

        return cfg

    @staticmethod
    def build_scheduler_config(engine_config: PytorchEngineConfig):
        """Build the scheduler config (batch size, session length, prefill interval)."""
        return SchedulerConfig(
            max_batches=engine_config.max_batch_size,
            max_session_len=engine_config.session_len,
            prefill_interval=engine_config.prefill_interval,
        )

    @staticmethod
    def build_cache_config(engine_config: PytorchEngineConfig):
        """Build the kv cache config from the engine config."""
        return CacheConfig(
            max_batches=engine_config.max_batch_size,
            block_size=engine_config.block_size,
            num_cpu_blocks=engine_config.num_cpu_blocks,
            num_gpu_blocks=engine_config.num_gpu_blocks,
            cache_max_entry_count=engine_config.cache_max_entry_count,
            max_prefill_token_num=engine_config.max_prefill_token_num,
            enable_prefix_caching=engine_config.enable_prefix_caching,
            quant_policy=engine_config.quant_policy,
            device_type=engine_config.device_type,
            migration_backend=engine_config.migration_backend,
            role=engine_config.role,
            # one gpu block is kept back for dummy input and padding
            num_reserved_gpu_blocks=1,
        )

    @staticmethod
    def build_backend_config(engine_config: PytorchEngineConfig):
        """Build the backend config (eager mode and device type)."""
        return BackendConfig(
            eager_mode=engine_config.eager_mode,
            device_type=engine_config.device_type,
        )

    @staticmethod
    def build_dist_config(engine_config: PytorchEngineConfig):
        """Build the distributed config via ``DistConfig.from_engine_config``."""
        return DistConfig.from_engine_config(engine_config=engine_config)

    @staticmethod
    def build_misc_config(engine_config: PytorchEngineConfig):
        """Build the misc config via ``MiscConfig.from_engine_config``."""
        return MiscConfig.from_engine_config(engine_config)

    @staticmethod
    def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig,
                                cache_config: CacheConfig):
        """Build the speculative decoding config, or return None when disabled.

        A draft model that is not an existing local path is resolved with
        ``get_model`` using the engine's download dir and revision.
        """
        if speculative_config is None:
            return None

        draft_model = speculative_config.model
        if draft_model and not os.path.exists(draft_model):
            draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision)

        return SpecDecodeConfig.from_config(
            method=speculative_config.method,
            num_speculative_tokens=speculative_config.num_speculative_tokens,
            model=draft_model,
            target_model=target_model,
            target_cache_cfg=cache_config,
            dtype=engine_config.dtype,
        )


================================================
FILE: lmdeploy/pytorch/engine/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import gc
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig
from lmdeploy.pytorch.disagg.config import EngineRole
from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
                                                   DistServeInitRequest)
from lmdeploy.utils import get_logger, get_model

from ..adapter.adapter import AdapterManager
from ..config import CacheConfig, ModelConfig
from ..messages import SchedulerSequence, UpdateTokenMode
from ..paging import Scheduler
from ..strategies import build_strategy_factory
from .base import EngineBase
from .config_builder import ConfigBuilder
from .engine_checker import EngineChecker
from .executor import build_executor
from .request import Request, RequestManager, RequestType, Response

logger = get_logger('lmdeploy')

SeqList = List[SchedulerSequence]


@dataclass
class InferOutput:
    """The output of the model inference for one session.

    ``session_id``, ``resp`` and ``token_ids`` are required; every other
    field defaults to ``None`` (``finish`` defaults to ``False``). Field order
    defines the generated ``__init__`` signature, so do not reorder.
    """

    session_id: int
    resp: Response
    token_ids: Union[np.ndarray, List[int]]
    meta: Any = None
    # NOTE(review): presumably True once the sequence has finished generating — confirm in the engine loop.
    finish: bool = False
    logits: Optional[torch.Tensor] = None
    logprobs: Optional[torch.Tensor] = None

    # send cache blocks back for migration in Disaggregated LLM Serving
    # when Prefill Engine is Done.
    cache_block_ids: Optional[List[int]] = None

    # for logging
    req_metrics: Optional[RequestMetrics] = None

    # expert ids
    routed_experts: Optional[torch.Tensor] = None


def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):
    """Build a ``SequenceMeta`` from the cache block size and the given strategies."""
    # imported lazily, as in the original implementation
    from lmdeploy.pytorch.messages import SequenceMeta

    return SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy)


def response_reqs(req_manager: RequestManager,
                  resp: Response,
                  resp_type: ResponseType,
                  data: Any = None,
                  err_msg: str = ''):
    """Fill ``resp`` with the given type, data and error message, then send it.

    A response whose type is already ``ResponseType.FINISH`` is left untouched
    and not sent again.
    """
    if resp.type != ResponseType.FINISH:
        resp.type = resp_type
        resp.data = data
        resp.err_msg = err_msg
        req_manager.response(resp)


class Engine(EngineBase):
    """The inference engine of lmdeploy pytorch.

    Builds the configs, the model executor, the strategies and the scheduler,
    and serves session/message requests through a ``RequestManager`` whose
    main loop function is :meth:`async_loop`.

    Args:
        model_path (str): The hugging face model path. A path that does not
            exist locally is resolved with ``get_model``.
        engine_config (PytorchEngineConfig): The config of the Engine. A
            default config is used when ``None``.
        trust_remote_code (bool): Trust remote code.
        speculative_config (SpeculativeConfig): Optional speculative decoding
            config, used to build ``self.specdecode_config``.
    """

    def __init__(
        self,
        model_path: str,
        engine_config: PytorchEngineConfig = None,
        trust_remote_code: bool = True,
        speculative_config: SpeculativeConfig = None,
    ) -> None:
        # make sure engine config exist
        # (update_engine_config returns a default config or a deep copy,
        # so the caller's object is never mutated)
        engine_config = ConfigBuilder.update_engine_config(engine_config)

        # frequently gc would cause latency spike
        # default threshold (700, 10, 10)
        # WARNING: I don't know if it is a good idea to put gc setting here.
        # NOTE(review): gc.set_threshold is process-wide, not engine-local.
        gc.set_threshold(10000, 100, 100)

        # dist args
        self.tp = engine_config.tp
        self.dp = engine_config.dp
        self.dp_rank = engine_config.dp_rank

        # download models and adapters
        # (anything that is not an existing local path is treated as a hub id)
        if not os.path.exists(model_path):
            model_path = get_model(model_path, engine_config.download_dir, engine_config.revision)

        adapters = engine_config.adapters
        if adapters is not None and len(adapters) > 0:
            adapters = self._download_adapters(adapters, engine_config)

        # check environment
        checker = EngineChecker(model_path=model_path,
                                engine_config=engine_config,
                                trust_remote_code=trust_remote_code,
                                logger=logger)
        checker.handle()

        # build configs
        scheduler_config = ConfigBuilder.build_scheduler_config(engine_config)
        cache_config = ConfigBuilder.build_cache_config(engine_config)
        backend_config = ConfigBuilder.build_backend_config(engine_config)
        dist_config = ConfigBuilder.build_dist_config(engine_config)
        misc_config = ConfigBuilder.build_misc_config(engine_config)
        # spec decode
        self.specdecode_config = ConfigBuilder.build_specdecode_config(model_path, speculative_config, engine_config,
                                                                       cache_config)

        # build model agent
        self.executor = build_executor(
            model_path,
            cache_config=cache_config,
            backend_config=backend_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=engine_config.device_type,
            distributed_executor_backend=engine_config.distributed_executor_backend,
            dtype=engine_config.dtype,
            specdecode_config=self.specdecode_config,
        )
        self.executor.init()

        # strategies
        # NOTE: self.model_config is a property reading self.executor.model_config,
        # so the strategy factory can only be built after the executor exists.
        self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config,
                                                       self.specdecode_config)
        self.sampling_strategy = self.strategy_factory.build_sampling_strategy()
        self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy()
        self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config,
                                                                           scheduler_config=scheduler_config)
        self.seq_strategy = self.strategy_factory.build_sequence_strategy()

        self.input_processor = self.executor.get_input_processor()
        # From here on use the executor's cache config; its block counts are
        # written back into engine_config below.
        # NOTE(review): engine_strategy above received the ConfigBuilder cache config;
        # confirm whether it is the same object as self.executor.cache_config.
        cache_config = self.executor.cache_config
        self.adapter_manager = self._build_adapter_manager(adapters)
        self.seq_meta = _build_seq_meta(cache_config,
                                        seq_strategy=self.seq_strategy,
                                        sampling_strategy=self.sampling_strategy)
        self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta)

        # engine args
        self.model_path = model_path
        self.engine_config = engine_config
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.backend_config = backend_config
        self.dist_config = dist_config
        self.misc_config = self.executor.misc_config
        # depends on self.scheduler_config and self.cache_config set just above
        self.max_session_len = self._get_max_session_len()
        self.engine_config.num_cpu_blocks = self.cache_config.num_cpu_blocks
        self.engine_config.num_gpu_blocks = self.cache_config.num_gpu_blocks

        self.req_manager = self._bind_request_manager()

        # create main thread
        self.req_manager.set_main_loop_func(self.async_loop)
        # task running async_loop; assigned inside async_loop
        self._loop_main = None

        # for PD Disaggregation
        # For migrating prefill request to decode engine
        # (assigned from the engine loop in async_loop)
        self.migration_event: asyncio.Event = None
        # For backpressure prefill request when cache is full
        # NOTE: the misspelled name ("perfill") is kept as-is; other code may reference it.
        self.perfill_watermark_event: asyncio.Event = None

        self.engine_conn = EngineP2PConnection(self)

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        engine_config: PytorchEngineConfig = None,
                        trust_remote_code: bool = True,
                        speculative_config: SpeculativeConfig = None,
                        **kwargs):
        """Lmdeploy python inference engine.

        When ``engine_config.enable_mp_engine`` is set, a multi-process engine
        built by ``build_mp_engine`` is returned instead of an ``Engine``.

        Args:
            pretrained_model_name_or_path (str):
                It could be one of the following options:
                    - i) The model_id of a lmdeploy-quantized model hosted
                      inside a model repo on huggingface.co, such as
                      "InternLM/internlm-chat-20b-4bit",
                      "lmdeploy/llama2-chat-70b-4bit", etc.
                    - ii) The model_id of a model hosted inside a model repo
                      on huggingface.co, such as "InternLM/internlm-chat-7b",
                      "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                      and so on.
            engine_config (PytorchEngineConfig): Pytorch engine config.
            trust_remote_code (bool): Trust remote code
            speculative_config (SpeculativeConfig): Optional speculative
                decoding config.
            **kwargs: Ignored; only logged at debug level.
        """
        if engine_config is not None and engine_config.enable_mp_engine:
            from .mp_engine import build_mp_engine
            backend = engine_config.mp_engine_backend
            return build_mp_engine(
                backend=backend,
                model_path=pretrained_model_name_or_path,
                engine_config=engine_config,
                trust_remote_code=trust_remote_code,
                speculative_config=speculative_config,
            )
        if len(kwargs) > 0:
            logger.debug(f'Get unexpected kwargs: {kwargs}')
        return cls(
            model_path=pretrained_model_name_or_path,
            engine_config=engine_config,
            trust_remote_code=trust_remote_code,
            speculative_config=speculative_config,
        )

    def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig):
        """Download adapters.

        Returns a new ``{name: path}`` dict where existing local paths are kept
        and every other entry is resolved with ``get_model``.
        """
        download_dir = engine_config.download_dir
        revision = engine_config.revision
        new_adapters = dict()
        for name, path in adapters.items():
            if os.path.exists(path):
                new_adapters[name] = path
                continue
            new_path = get_model(path, download_dir=download_dir, revision=revision)
            new_adapters[name] = new_path

        return new_adapters

    def _build_adapter_manager(self, adapters):
        """Wrap the (possibly None) adapter mapping in an ``AdapterManager``."""
        return AdapterManager(adapters)

    def _bind_request_manager(self):
        """Bind request manager.

        Creates a ``RequestManager`` and registers the session/message
        callbacks for each ``RequestType`` handled by this engine.
        """
        req_manager = RequestManager()
        req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)
        req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)
        req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)
        req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)
        return req_manager

    def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, err_msg: str = ''):
        """response.

        Thin wrapper around ``response_reqs`` bound to ``self.req_manager``.
        """
        return response_reqs(self.req_manager, resp, resp_type, data, err_msg)

    def _get_max_session_len(self):
        """Get max session len.

        The token capacity is ``(num_gpu_blocks - num_reserved_gpu_blocks) *
        block_size``. If a positive ``window_size`` fits in that capacity, the
        capacity is treated as unbounded. One block is then subtracted, and a
        configured ``max_session_len`` is capped by the result (or the result
        is used directly when no session length was configured).
        """
        session_len = self.scheduler_config.max_session_len
        num_gpu_blocks = self.cache_config.num_gpu_blocks - self.cache_config.num_reserved_gpu_blocks
        max_tokens = (num_gpu_blocks * self.cache_config.block_size)
        window_size = self.cache_config.window_size
        if window_size > 0 and window_size <= max_tokens:
            max_tokens = (1 << 63) - 1
        max_tokens -= self.cache_config.block_size
        if session_len is None:
            session_len = max_tokens
        else:
            session_len = min(max_tokens, session_len)
        return session_len

    def _on_add_session(self, reqs: List[Request], **kwargs):
        """On add session callback.

        Adds each unknown session to the scheduler. Replies SUCCESS for new
        sessions and SESSION_REPEAT for existing ones, unless the request data
        sets ``response`` to False.
        """
        for req in reqs:
            session_id = req.data['session_id']
            resp = req.data.get('response', True)
            resp_type = ResponseType.SESSION_REPEAT
            if session_id not in self.scheduler.sessions:
                self.scheduler.add_session(session_id)
                resp_type = ResponseType.SUCCESS
            if resp:
                self._response(req.resp, resp_type)

    def _on_stop_session(self, reqs: List[Request], **kwargs):
        """On stop session callback.

        Stops the session in the scheduler and sends a CANCEL (marked done) on
        every pending response of its sequences. Replies SUCCESS, or
        SESSION_NOT_EXIST for unknown sessions, unless ``response`` is False.
        """
        for req in reqs:
            session_id = req.data['session_id']
            resp = req.data.get('response', True)
            resp_type = ResponseType.SESSION_NOT_EXIST
            if session_id in self.scheduler.sessions:
                self.scheduler.stop_session(session_id)
                session = self.scheduler.sessions[session_id]
                for seq in session.sequences.values():
                    _resp: Response = getattr(seq, 'resp', None)
                    if _resp is not None:
                        _resp.type = ResponseType.CANCEL
                        _resp.is_done = True
                        self.req_manager.response(_resp)
                resp_type = ResponseType.SUCCESS
            if resp:
                self._response(req.resp, resp_type)

    def _on_end_session(self, reqs: List[Request], **kwargs):
        """On end session callback.

        If the session's first sequence has ``preserve_cache`` set, only its
        state is finished and the session is kept; otherwise the session is
        ended via :meth:`end_session`.
        """
        for req in reqs:
            session_id = req.data['session_id']
            resp = req.data.get('response', True)
            resp_type = ResponseType.SESSION_NOT_EXIST
            if session_id in self.scheduler.sessions:
                msgs = list(self.scheduler.sessions[session_id].sequences.values())
                if len(msgs) > 0 and msgs[0].preserve_cache:
                    msgs[0].state.finish()
                else:
                    self.end_session(session_id)
                resp_type = ResponseType.SUCCESS
            if resp:
                self._response(req.resp, resp_type)

    def _on_add_message(self, reqs: List[Request], **kwargs):
        """On add message callback.

        Rejects requests for unknown sessions with SESSION_NOT_EXIST,
        preprocesses multimodal inputs in place in ``req.data``, and forwards
        the remaining requests to :meth:`_add_message`.
        """
        valid_reqs = []
        for req in reqs:
            req_data = req.data
            session_id = req_data['session_id']
            if self.scheduler and session_id not in self.scheduler.sessions:
                self._response(req.resp, ResponseType.SESSION_NOT_EXIST)
                continue
            # req_data is mutated below; valid_reqs holds the same objects
            valid_reqs.append(req)
            if req_data.get('input_multimodals', None) is None:
                continue
            elif self.input_processor is None:
                # NOTE(review): unlike the disable_vision_encoder branch below, the
                # multimodal inputs are left in req_data here; confirm that is intended.
                logger.warning('Do not support Multimodal inputs.')
                continue
            input_ids = req_data['token_ids']
            input_multimodals = req_data['input_multimodals']
            if len(input_multimodals) == 0:
                req_data['input_multimodals'] = None
                continue

            if self.engine_config.disable_vision_encoder:
                # ignore multimodal inputs
                req_data['input_multimodals'] = None
                logger.warning('Vision encoder has not been loaded, multimodal inputs will be ignored.')
                continue

            result = self.input_processor.preprocess_input(input_ids, input_multimodals)

            input_ids = result.input_ids
            input_multimodals = result.input_multimodals

            req_data['token_ids'] = input_ids
            req_data['input_multimodals'] = input_multimodals

        if len(valid_reqs) > 0:
            self._add_message(valid_reqs)

    def _add_message(self, reqs: List[Request]):
        """Add request tokens to their sessions.

        A session without sequences gets a new sequence; otherwise its first
        sequence is extended with the new input tokens and re-activated. The
        request's response object is attached to the sequence as ``msg.resp``.
        """

        def __update_max_new_tokens(msg):
            """Update max new tokens.

            Prefill-role engines generate exactly one token; otherwise
            ``max_new_tokens`` is clipped so the total stays within
            ``self.max_session_len``.
            """
            max_session_len = self.max_session_len
            sampling_param = msg.sampling_param
            max_new_tokens = sampling_param.max_new_tokens
            num_all_tokens = msg.num_valid_ids
            if self.engine_config.role == EngineRole.Prefill:
                sampling_param.max_new_tokens = 1
            elif max_new_tokens + num_all_tokens > max_session_len:
                # NOTE(review): when num_all_tokens >= max_session_len this sets a
                # value <= 0; confirm it is handled downstream.
                logger.warning(
                    f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. '
                    f'Update max_new_tokens={max_session_len - num_all_tokens}.')
                sampling_param.max_new_tokens = max_session_len - num_all_tokens

        scheduler = self.scheduler
        for req in reqs:
            session_id = req.data['session_id']
            sess = scheduler.sessions.get(session_id, None)
            if sess is None:
                self._response(req.resp, ResponseType.SESSION_NOT_EXIST)
                continue
            # TODO: support 1 session n sequence
            sampling_param = req.data['sampling_param']
            if len(sess.sequences) == 0:
                migration_request = req.data.get('migration_request')
                assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.')
                sess.add_sequence(req.data['token_ids'],
                                  sampling_param=sampling_param,
                                  adapter_name=req.data['adapter_name'],
                                  multimodals=req.data.get('input_multimodals'),
                                  input_embeddings=req.data.get('input_embeddings', ),
                                  migration_request=migration_request,
                                  resp_cache=req.data.get('with_cache'),
                                  preserve_cache=req.data.get('preserve_cache'))
                msg = next(iter(sess.sequences.values()))
                if migration_request:
                    # migration_event is assigned in async_loop; this assumes the loop is running
                    self.migration_event.set()
            else:
                msg = next(iter(sess.sequences.values()))
                msg.update_token_ids(
                    req.data['token_ids'],
                    multimodals=req.data.get('input_multimodals'),
                    embeddings=req.data.get('input_embeddings'),
                    mode=UpdateTokenMode.INPUTS,
                )
                msg.sampling_param = sampling_param
                msg.state.activate()

            __update_max_new_tokens(msg)
            msg.resp = req.resp

    @property
    def model_config(self) -> ModelConfig:
        """Model config (read from the executor)."""
        return self.executor.model_config

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Delegate to ``EngineP2PConnection.p2p_initialize``."""
        return self.engine_conn.p2p_initialize(init_request)

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Delegate to ``EngineP2PConnection.p2p_connect``."""
        return self.engine_conn.p2p_connect(conn_request)

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Delegate to ``EngineP2PConnection.p2p_drop_connect``."""
        return self.engine_conn.p2p_drop_connect(drop_conn_request)

    def _loop_finally(self):
        """Finally process for dist.

        Clears ``migration_event`` and releases the executor.
        """
        logger.info('Cleanup executor.')
        self.migration_event = None
        self.executor.release()

    def update_params(self, request: Any):
        """Update params (delegated to the executor)."""
        self.executor.update_params(request)

    def sleep(self, level: int = 1):
        """Sleep (delegated to the executor)."""
        self.executor.sleep(level)

    def wakeup(self, tags: Optional[List[str]] = None):
        """Wakeup (delegated to the executor)."""
        self.executor.wakeup(tags)

    async def async_loop(self):
        """Main loop task: build and run the engine loop until it finishes.

        Records the current task in ``self._loop_main`` so :meth:`stop` and
        :meth:`close` can cancel it. On exit (normal, cancelled or failed) the
        engine loop is stopped and :meth:`_loop_finally` releases the executor.
        """
        engine_loop = None
        try:
            from lmdeploy.pytorch.engine.engine_loop import build_engine_loop
            self._loop_main = asyncio.current_task()
            event_loop = asyncio.get_event_loop()

            # create engine loop
            engine_loop = build_engine_loop(self)
            self.migration_event = engine_loop.migration_event

            # start engine loop
            engine_loop.start(event_loop)
            await engine_loop.wait_tasks()
        except asyncio.CancelledError:
            logger.info('Engine main loop cancelled.')
            raise
        except BaseException:
            # since AsyncEngine will not wait for engine loop
            # we have to log it here.
            logger.exception('Engine main loop failed.')
            raise
        finally:
            logger.debug('Engine main loop finally cleanup.')
            if engine_loop is not None:
                engine_loop.stop()
            self._loop_finally()

    def close(self):
        """Close the engine.

        If the main loop task exists it is cancelled, and its ``finally``
        clause performs the cleanup; otherwise cleanup runs directly.
        """
        if self.executor.device_type == 'cuda':
            # https://discuss.pytorch.org/t/how-to-delete-a-tensor-in-gpu-to-free-up-memory/48879/32
            # W/O this, repeatedly rebuilding and destroying engines within the same process
            # will cause more and more reserved CUDA memory.
            torch._C._cuda_clearCublasWorkspaces()
        if self._loop_main is not None:
            self._loop_main.cancel()
        else:
            self._loop_finally()

    def start(self):
        """Start engine loop tasks (no-op if the loop is already alive)."""
        if self.req_manager.is_loop_alive():
            return True
        self.req_manager.create_loop_task()
        return True

    def stop(self):
        """Stop engine loop tasks by cancelling the main loop task, if any."""
        if self._loop_main is not None:
            self._loop_main.cancel()

    async def wait_tasks(self):
        """Wait async tasks to finish.

        Waits on the request manager's tasks; logs a warning and returns when
        the main loop has not been started.
        """
        if self._loop_main is None:
            logger.warning('No engine main loop to wait for.')
            return

        try:
            # await self._loop_main
            await self.req_manager.wait_tasks()
        except asyncio.CancelledError:
            logger.info('Engine main loop cancelled in wait_tasks.')
            raise

    def create_instance(self, cuda_stream_id=0):
        """Create a pytorch engine instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream (currently unused).
        Returns:
            EngineInstance: an instance of pytorch engine
        """
        from .engine_instance import EngineInstance
        return EngineInstance(self)

    def start_loop(self):
        """Alias of start, API for AsyncEngine."""
        return self.start()

    def end_session(self, session_id: int):
        """End session.

        Returns:
            bool: True if the session existed and was ended, else False.
        """
        if session_id in self.scheduler.sessions:
            self.scheduler.end_session(session_id)
            return True
        return False

    def get_engine_config(self):
        """Return the (normalized) engine config used by this engine."""
        return self.engine_config

    def get_schedule_metrics(self):
        """Return the scheduler's ``schedule_metrics``."""
        return self.scheduler.schedule_metrics


================================================
FILE: lmdeploy/pytorch/engine/engine_checker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.messages import PytorchEngineConfig

from ..check_env.adapter import AdapterChecker
from ..check_env.base import BaseChecker
from ..check_env.dist import DistChecker
from ..check_env.model import ModelChecker
from ..check_env.torch import TorchChecker
from ..check_env.transformers import TransformersChecker


class EngineChecker(BaseChecker):
    """Check that the environment can run the pytorch engine.

    The constructor registers the required sub-checkers (torch, cuda + triton
    or deeplink depending on ``device_type``, transformers, model, adapters
    and dist); :meth:`check` validates engine config values.

    Args:
        model_path (str): Path of the model to check.
        engine_config (PytorchEngineConfig): The engine config to validate.
        trust_remote_code (bool): Forwarded to the model checker.
        logger: Optional logger passed to ``BaseChecker``; all sub-checkers
            use ``self.get_logger()``.
    """

    def __init__(self,
                 model_path: str,
                 engine_config: PytorchEngineConfig,
                 trust_remote_code: bool = True,
                 logger=None):
        super().__init__(logger)
        logger = self.get_logger()

        self.engine_config = engine_config

        dtype = engine_config.dtype
        device_type = engine_config.device_type

        # pytorch
        torch_checker = TorchChecker(logger=logger)

        if device_type == 'cuda':
            # chain: torch <- cuda <- triton
            from ..check_env.cuda import CudaChecker
            from ..check_env.triton import TritonChecker
            cuda_checker = CudaChecker(model_format=engine_config.model_format, logger=logger)
            cuda_checker.register_required_checker(torch_checker)
            triton_checker = TritonChecker(logger=logger)
            triton_checker.register_required_checker(cuda_checker)
            self.register_required_checker(triton_checker)
        else:
            # deeplink
            from ..check_env.deeplink import DeeplinkChecker
            dl_checker = DeeplinkChecker(device_type, logger=logger)
            self.register_required_checker(dl_checker)
            self.register_required_checker(torch_checker)

        # transformers + model; pass the shared logger like every other sub-checker
        trans_checker = TransformersChecker(logger=logger)
        model_checker = ModelChecker(model_path=model_path,
                                     trust_remote_code=trust_remote_code,
                                     dtype=dtype,
                                     device_type=device_type,
                                     logger=logger)
        model_checker.register_required_checker(torch_checker)
        model_checker.register_required_checker(trans_checker)
        self.register_required_checker(model_checker)

        # adapters
        adapters = engine_config.adapters
        if adapters is not None:
            for adapter in adapters.values():
                adapter_checker = AdapterChecker(adapter, logger=logger)
                self.register_required_checker(adapter_checker)

        # dist
        dist_checker = DistChecker(engine_config.tp,
                                   engine_config.dp,
                                   engine_config.ep,
                                   engine_config.distributed_executor_backend,
                                   device_type=engine_config.device_type,
                                   logger=logger)
        self.register_required_checker(dist_checker)

    def check(self):
        """Validate engine config values, reporting problems via ``log_and_exit``."""
        engine_config = self.engine_config

        if engine_config.thread_safe:
            self.log_and_exit(
                mod_name='Engine',
                message='thread safe mode is no longer supported.\n'
                'Read https://github.com/InternLM/lmdeploy/blob/main/docs/en/advance/pytorch_multithread.md for more details.',  # noqa: E501
            )

        if engine_config.max_batch_size <= 0:
            self.log_and_exit(mod_name='Engine',
                              message='max_batch_size should be'
                              f' greater than 0, but got {engine_config.max_batch_size}')

        # 0 means "auto"; any explicit value must be at least 16
        num_gpu_blocks = engine_config.num_gpu_blocks
        if 0 < num_gpu_blocks < 16:
            self.log_and_exit(mod_name='Engine',
                              message='num_gpu_blocks should be at least 16, '
                              f'but got {num_gpu_blocks}. Set num_gpu_blocks to 0 to automatically '
                              'determine the number of GPU blocks based on the model size and device memory.')

    def _handle_impl(self):
        """Run the ``BaseChecker`` handling (sub-checkers and :meth:`check`)."""
        return super().handle()

    def handle(self):
        """Run all checks.

        Skipped when ``envs.enable_check_env`` is false. For cuda in a
        non-daemon process the checks run in a ``spawn`` worker process
        (presumably to keep CUDA initialization out of this process); daemon
        processes are not allowed to create children, so they check in-process.
        """
        import multiprocessing as mp
        import sys
        from concurrent.futures import ProcessPoolExecutor

        from lmdeploy.pytorch import envs
        if not envs.enable_check_env:
            return

        current_proc = mp.current_process()
        if not current_proc.daemon and self.engine_config.device_type == 'cuda':
            mp_ctx = mp.get_context('spawn')
            with ProcessPoolExecutor(mp_context=mp_ctx) as executor:
                try:
                    executor.submit(self._handle_impl).result()
                except SystemExit:
                    # a check exited inside the worker; exit here as well.
                    # sys.exit instead of the site builtin ``exit``, which also
                    # closes sys.stdin and is absent under ``python -S``.
                    sys.exit(1)
                except BaseException as e:
                    self.log_and_exit(e, mod_name='Engine')
        else:
            return self._handle_impl()


================================================
FILE: lmdeploy/pytorch/engine/engine_instance.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List

from lmdeploy.messages import EngineOutput, GenerationConfig
from lmdeploy.utils import get_logger

from ..messages import SamplingParam
from .base import EngineInstanceBase
from .engine import Engine
from .request import RequestSender, RequestType, Response, ResponseType

logger = get_logger('lmdeploy')

InputMultiModalType = List[Dict[str, Any]]


def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):
    """Return whether ``resp.type`` is one of the expected states.

    Args:
        resp (Response): The response to inspect.
        state (ResponseType | List[ResponseType]): A single expected state or a
            list of acceptable states.
        warning_msg (str): If given, logged as a warning when the check fails.
    """
    expected_states = [state] if isinstance(state, ResponseType) else state
    matched = resp.type in expected_states
    if warning_msg is not None and not matched:
        logger.warning(warning_msg)
    return matched


def _check_resp_success(resp: Response, warning_msg: str = None):
    """Return True if ``resp`` is a SUCCESS response; optionally warn otherwise."""
    success_state = ResponseType.SUCCESS
    return _check_resp(resp, success_state, warning_msg)


async def async_try_add_session(req_sender: RequestSender, session_id: int):
    """Asynchronously request the engine to add a session.

    Both SUCCESS and SESSION_REPEAT responses are accepted; any other response
    is only logged as a warning.

    Args:
        req_sender (RequestSender): The sender used to talk to the engine.
        session_id (int): The session id to add.
    """
    payload = dict(session_id=session_id)
    resp = await req_sender.async_send(RequestType.ADD_SESSION, payload)
    accepted_states = [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT]
    _check_resp(resp, accepted_states, f'Can not add session {session_id} with error: {resp.type}')


async def async_cancel(req_sender: RequestSender, session_id: int):
    """Asynchronously stop the current streaming inference of a session.

    A non-SUCCESS response is logged as a warning.
    """
    payload = dict(session_id=session_id)
    resp = await req_sender.async_send(RequestType.STOP_SESSION, payload)
    failure_msg = f'Failed to cancel session: {session_id}. Error: {resp.type}.'
    _check_resp_success(resp, failure_msg)


def try_add_session(req_sender: RequestSender, session_id: int):
    """Synchronously request the engine to add a session.

    Both SUCCESS and SESSION_REPEAT responses are accepted; any other response
    is only logged as a warning.

    Args:
        req_sender (RequestSender): The sender used to talk to the engine.
        session_id (int): The session id to add.
    """
    payload = dict(session_id=session_id)
    resp = req_sender.send(RequestType.ADD_SESSION, payload)
    accepted_states = [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT]
    _check_resp(resp, accepted_states, f'Can not add session {session_id} with error: {resp.type}')


def end(req_sender: RequestSender, session_id: int):
    """End the given session.

    The END_SESSION request is sent via ``send_async`` with ``response=False``,
    and nothing is returned to the caller.
    """
    logger.debug(f'session[{session_id}] try end session.')
    payload = dict(session_id=session_id, response=False)
    req_sender.send_async(RequestType.END_SESSION, payload)


def cancel(req_sender: RequestSender, session_id: int):
    """Stop current streaming inference.

    Sends a STOP_SESSION request through ``req_sender.send`` and checks the
    returned response; a non-SUCCESS response is logged as a warning.

    Args:
        req_sender (RequestSender): The sender used to talk to the engine.
        session_id (int): The session id to cancel.
    """
    # distinct from `end`'s message so cancel and end requests can be told apart in logs
    logger.debug(f'session[{session_id}] try cancel session.')
    resp = req_sender.send(RequestType.STOP_SESSION, dict(session_id=session_id))
    _check_resp_success(resp, (f'Failed to cancel session: {session_id}. '
                               f'Error: {resp.type}.'))


class EngineInstance(EngineInstanceBase):
    """Request-level instance of the PyTorch engine.

    Each instance owns a ``RequestSender`` built from the engine's request
    manager and exposes sync/async, streaming/non-streaming inference APIs.

    Args:
        engine (Engine): engine
    """

    def __init__(self, engine: Engine):
        self.engine = engine
        self.req_sender = engine.req_manager.build_sender()

        self.max_input_len = self.engine.max_session_len
        # object-ref transfer of routed experts is only enabled with the ray executor backend
        self._enable_transfer_obj_ref = engine.engine_config.enable_transfer_obj_ref and \
            engine.engine_config.distributed_executor_backend == 'ray'

    def __del__(self):
        """Destructor: unregister this instance's sender from the request manager."""
        # `req_sender` is absent when __init__ failed before/inside build_sender();
        # there is nothing to unregister in that case.
        req_sender = getattr(self, 'req_sender', None)
        if req_sender is None:
            return
        self.engine.req_manager.senders.pop(req_sender.sender_id)

    def _get_extra_outputs(self, resp: Response):
        """Get extra outputs.

        ``routed_experts`` is only forwarded on FINISH/CANCEL responses. With
        object-ref transfer enabled it is put into the ray object store and the
        pickled reference is returned base64-encoded instead of the raw data.
        """
        outputs = dict(routed_experts=None)
        routed_experts = resp.data.get('routed_experts', None) if resp.data else None
        if routed_experts is not None and resp.type in [ResponseType.FINISH, ResponseType.CANCEL]:
            if self._enable_transfer_obj_ref:
                import pybase64
                import ray

                ref = ray.put(routed_experts)
                data = ray.cloudpickle.dumps(ref)
                outputs['routed_experts'] = pybase64.b64encode(data).decode('utf-8')
            else:
                outputs['routed_experts'] = routed_experts
        return outputs

    async def _async_try_add_session(self, session_id: int):
        """Add new session.

        Args:
            session_id (int): The session id to add.
        """
        return await async_try_add_session(self.req_sender, session_id)

    def _try_add_session(self, session_id: int):
        """Add new session.

        Args:
            session_id (int): The session id to add.
        """
        return try_add_session(self.req_sender, session_id)

    async def async_stream_infer(self,
                                 session_id: int,
                                 input_ids: List[int],
                                 gen_config: GenerationConfig = None,
                                 multimodal: InputMultiModalType = None,
                                 adapter_name: str = None,
                                 **kwargs):
        """Send stream inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            gen_config (GenerationConfig): The sampling parameters.
            multimodal (InputMultiModalType): Optional multimodal inputs.
            adapter_name (str): The lora adapter name.

        Yields:
            EngineOutput: One output per engine response. SUCCESS outputs carry
            only the tokens produced since the previous yield. The stream ends
            after a FINISH/CANCEL output, or after an error output with an empty
            token list (including INPUT_LENGTH_ERROR for over-long inputs).
        """
        if len(input_ids) > self.max_input_len:
            yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [])
            return
        gen_config = gen_config or GenerationConfig()
        sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
        logger.debug(f'session[{session_id}] try add session.')
        self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False))
        msg = dict(
            token_ids=input_ids,
            session_id=session_id,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
            input_multimodals=multimodal,
            migration_request=gen_config.migration_request,
            with_cache=gen_config.with_cache,
            preserve_cache=gen_config.preserve_cache,
        )
        logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.')
        resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
        # number of generated tokens already yielded to the caller
        output_offset = 0

        while True:
            resp = await self.req_sender.async_recv(resp, wait_main=True)

            cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
            req_metrics = resp.data.get('req_metrics', None) if resp.data else None
            logprobs = resp.data.pop('logprobs', None) if resp.data else None
            extra_outputs = self._get_extra_outputs(resp)
            routed_experts = extra_outputs.get('routed_experts', None)

            if resp.type == ResponseType.SUCCESS:
                token_ids = resp.data['token_ids']
                num_ids = len(token_ids) - output_offset
                logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.')
                yield EngineOutput(resp.type,
                                   token_ids[output_offset:].tolist(),
                                   cache_block_ids=cache_block_ids,
                                   req_metrics=req_metrics,
                                   routed_experts=routed_experts,
                                   logprobs=logprobs)
                output_offset = len(token_ids)
            elif resp.type in (ResponseType.FINISH, ResponseType.CANCEL):
                resp_data = resp.data
                if resp_data is None:
                    # request might be cancelled before any output
                    token_ids = []
                    logits = None
                else:
                    token_ids = resp_data['token_ids'][output_offset:].tolist()
                    logits = resp_data.get('logits', None)
                # token_ids is already sliced past output_offset; do not subtract it again
                num_ids = len(token_ids)
                logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
                yield EngineOutput(resp.type,
                                   token_ids,
                                   logits=logits,
                                   cache_block_ids=cache_block_ids,
                                   req_metrics=req_metrics,
                                   routed_experts=routed_experts,
                                   logprobs=logprobs)
                break
            else:
                logger.debug(f'session[{session_id}] failed.')
                yield EngineOutput(resp.type, [])
                break

    async def async_infer(self,
                          session_id: int,
                          input_ids: List[int] = None,
                          multimodal: InputMultiModalType = None,
                          gen_config: GenerationConfig = None,
                          **kwargs):
        """Send inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            multimodal (InputMultiModalType): Optional multimodal inputs.
            gen_config (GenerationConfig): The sampling parameters.

        Returns:
            EngineOutput: The first output whose status is neither SUCCESS nor
            FINISH, otherwise the last output of the stream.
        """
        async for outputs in self.async_stream_infer(session_id,
                                                     input_ids,
                                                     multimodal=multimodal,
                                                     gen_config=gen_config,
                                                     **kwargs):
            status = outputs.status
            if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
                return outputs

        return outputs

    def stream_infer(self,
                     session_id: int,
                     input_ids: List[int],
                     multimodal: InputMultiModalType = None,
                     gen_config: GenerationConfig = None,
                     adapter_name: str = None,
                     **kwargs):
        """Send stream inference request (synchronous wrapper).

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            multimodal (InputMultiModalType): Optional multimodal inputs.
            gen_config (GenerationConfig): The sampling parameters.
            adapter_name (str): The lora adapter name.

        Yields:
            EngineOutput: The same outputs as ``async_stream_infer``, each driven
            to completion with ``req_sender.run_until_complete``.
        """

        def __call_async():
            """Call async."""
            coro_gen = self.async_stream_infer(session_id,
                                               input_ids,
                                               multimodal=multimodal,
                                               gen_config=gen_config,
                                               adapter_name=adapter_name,
                                               **kwargs)
            while True:
                try:
                    yield self.req_sender.run_until_complete(coro_gen.__anext__())
                except StopAsyncIteration:
                    break

        yield from __call_async()

    def infer(self,
              session_id: int,
              input_ids: List[int] = None,
              multimodal: InputMultiModalType = None,
              gen_config: GenerationConfig = None,
              **kwargs):
        """Send inference request (synchronous wrapper of ``async_infer``).

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            multimodal (InputMultiModalType): Optional multimodal inputs.
            gen_config (GenerationConfig): The sampling parameters.

        Returns:
            EngineOutput: See ``async_infer``.
        """
        return self.req_sender.run_until_complete(
            self.async_infer(session_id, input_ids, multimodal=multimodal, gen_config=gen_config, **kwargs))

    async def async_end(self, session_id: int):
        """End the given session."""
        return end(self.req_sender, session_id)

    def end(self, session_id: int):
        """End the given session."""
        return end(self.req_sender, session_id)

    async def async_cancel(self, session_id: int):
        """Stop current streaming inference."""
        return await async_cancel(self.req_sender, session_id)

    def cancel(self, session_id: int):
        """Stop current streaming inference."""
        return cancel(self.req_sender, session_id)


================================================
FILE: lmdeploy/pytorch/engine/engine_loop.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

import numpy as np
import torch
from torch.profiler import record_function

from lmdeploy.messages import RequestMetrics
from lmdeploy.pytorch.disagg.config import EngineRole
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.messages import MessageStatus, UpdateTokenMode
from lmdeploy.pytorch.utils import cancel_async_tasks, wait_for_async_tasks
from lmdeploy.utils import get_logger

from .engine import InferOutput, ResponseType, response_reqs

if TYPE_CHECKING:
    from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection
    from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta
    from lmdeploy.pytorch.paging import Scheduler
    from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy

    from .engine import Engine, SeqList
    from .executor import ExecutorBase
    from .inputs_maker import InputsMakerAsync
    from .request import RequestManager

logger = get_logger('lmdeploy')
_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64)


class CounterEvent(asyncio.Event):
    """``asyncio.Event`` whose ``clear`` calls must be balanced by ``set`` calls.

    Each ``clear()`` increments a pending counter (clearing the event only on the
    first, unbalanced call). Each ``set()`` decrements the counter (never below
    zero), and the underlying event is set only once the counter is back to zero.
    """

    def __init__(self):
        super().__init__()
        # outstanding clear() calls not yet matched by set()
        self._counter = 0

    def set(self):
        pending = self._counter
        if pending > 0:
            pending -= 1
            self._counter = pending
        if pending == 0:
            super().set()

    def clear(self):
        balanced = self._counter == 0
        if balanced and super().is_set():
            super().clear()
        self._counter = self._counter + 1


class RunableEventAsync:
    """Awaitable event that reflects whether the scheduler has unfinished work.

    Args:
        scheduler (Scheduler): queried through ``has_unfinished()`` on every ``set``.
    """

    def __init__(self, scheduler: 'Scheduler'):
        self.scheduler = scheduler
        self.event = asyncio.Event()

    async def wait(self):
        """Block until the event is set."""
        await self.event.wait()

    def set(self):
        """Set the event while work remains; clear it otherwise."""
        refresh = self.event.set if self.scheduler.has_unfinished() else self.event.clear
        refresh()


def build_runable_event(scheduler: 'Scheduler'):
    """Create a ``RunableEventAsync`` bound to ``scheduler``."""
    runable_event = RunableEventAsync(scheduler)
    return runable_event


@dataclass
class EngineLoopConfig:
    """Settings consumed by ``EngineLoop``.

    Kept as a standalone config so the loop's settings can be injected
    (dependency injection) rather than read from the engine directly.
    """
    role: EngineRole
    num_speculative_tokens: Optional[int] = None
    enable_metrics: bool = False
    enable_transfer_obj_ref: bool = False

    @staticmethod
    def from_engine(engine: 'Engine'):
        """Build an ``EngineLoopConfig`` from ``engine``'s configs.

        ``num_speculative_tokens`` is None when the engine has no speculative
        decoding config.
        """
        spec_cfg = engine.specdecode_config
        num_spec_tokens = None if spec_cfg is None else spec_cfg.num_speculative_tokens

        engine_cfg = engine.engine_config
        return EngineLoopConfig(
            role=engine_cfg.role,
            num_speculative_tokens=num_spec_tokens,
            enable_metrics=engine_cfg.enable_metrics,
            enable_transfer_obj_ref=engine_cfg.enable_transfer_obj_ref,
        )


class EngineLoop:
    """Engine loop manager should be created in an async context.

    Owns the long-running asyncio tasks of the engine: request preprocessing,
    the main forward loop, response sending and, for non-hybrid roles, the
    migration loop.
    """

    def __init__(self,
                 req_manager: 'RequestManager',
                 scheduler: 'Scheduler',
                 executor: 'ExecutorBase',
                 seq_strategy: 'SequenceStrategy',
                 inputs_maker: 'InputsMakerAsync',
                 config: EngineLoopConfig,
                 engine_conn: Optional['EngineP2PConnection'] = None):
        self.req_manager = req_manager
        self.scheduler = scheduler
        self.executor = executor
        self.seq_strategy = seq_strategy
        self.inputs_maker = inputs_maker
        self.config = config
        self.engine_conn = engine_conn

        # tasks and control events
        self.tasks: Set[asyncio.Task] = set()
        self.stop_event = asyncio.Event()
        self.resp_queue = asyncio.Queue()
        self.forward_event = CounterEvent()
        self.migration_event = asyncio.Event()
        self.has_runable_event = RunableEventAsync(self.scheduler)

        # check init
        if self.config.role != EngineRole.Hybrid:
            assert self.engine_conn is not None, 'Engine connection must be provided for non-hybrid engine role.'

    async def preprocess_loop(self):
        """Preprocess request."""
        while not self.stop_event.is_set():
            await self.req_manager.step()
            self.has_runable_event.set()

    @staticmethod
    def _log_resps(outputs: List[InferOutput]):
        """Log resps."""
        # isEnabledFor honours the effective (inherited) level, so the session id
        # list is only built when the debug records would actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            session_ids = [out.session_id for out in outputs]
            logger.debug(f'Response sessions: {session_ids}')
            logger.debug(f'Response: num_outputs={len(outputs)}.')

    def _send_resp(self, out: InferOutput):
        """Send response."""
        # skip cancelled response
        if out.resp.is_done:
            return
        resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
        logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None)
        response_reqs(self.req_manager,
                      out.resp,
                      resp_type,
                      data=dict(token_ids=out.token_ids,
                                logits=out.logits,
                                cache_block_ids=out.cache_block_ids,
                                req_metrics=out.req_metrics,
                                routed_experts=out.routed_experts,
                                logprobs=logprobs))

    @staticmethod
    def _update_logprobs(step_outputs: List[InferOutput]):
        """Append each output's logprobs (as an index->value dict) to ``resp.data['logprobs']``."""
        for out in step_outputs:
            cur_logprobs = out.logprobs
            if cur_logprobs is None:
                continue

            if out.resp.data is None:
                out.resp.data = dict()
            out.resp.data.setdefault('logprobs', [])

            # logprobs to dict
            vals = cur_logprobs[0]
            indices = cur_logprobs[1]
            cur_logprobs = dict(zip(indices, vals))
            logprobs = out.resp.data['logprobs']
            logprobs.append(cur_logprobs)

    def _send_resps(self, step_outputs: List[InferOutput]):
        """Send response callback.

        Logprobs of every output are accumulated, but only the most recent
        output of each session is actually sent.
        """
        self._log_resps(step_outputs)
        self._update_logprobs(step_outputs)

        is_done = set()
        for out in reversed(step_outputs):
            if out.session_id in is_done:
                continue
            is_done.add(out.session_id)
            self._send_resp(out)

    async def send_response_loop(self):
        """Send response to client."""
        que = self.resp_queue
        while not self.stop_event.is_set():
            num_outs = que.qsize()
            if num_outs > 0:
                # drain everything already queued in one batch
                resps = []
                for _ in range(num_outs):
                    resps += que.get_nowait().values()
            else:
                resps = (await que.get()).values()
            self._send_resps(resps)

    @record_function('make_infer_outputs')
    def _make_infer_outputs(
        self,
        batched_outputs: 'BatchedOutputs',
        running: 'SeqList',
        model_inputs: 'ModelInputs',
        delta: 'ModelInputsDelta',
    ):
        """Make infer output.

        Returns:
            Dict[int, InferOutput]: Outputs keyed by session id; empty for a
            chunked long-context step.
        """

        def __get_logit(msg, logits: torch.Tensor, seq_length: List[int], idx: int):
            logit = logits.split(seq_length)[idx]
            if len(msg.all_logits) > 0:
                # for chunked long context
                msg.append_logits(logit)
                logit = msg.logits
                msg.all_logits.resize(0)

            return logit

        logits = batched_outputs.logits
        all_routed_experts = batched_outputs.all_routed_experts

        if model_inputs is not None and model_inputs.is_chunk:
            # chunk long context does not need to update seqs and outputs
            seq = running[0]
            seq.append_routed_experts(all_routed_experts)
            seq.append_logits(logits)
            return dict()

        new_token_timestamp = batched_outputs.new_token_timestamp
        logprobs = batched_outputs.logprobs

        if logprobs is not None:
            logprobs.vals = logprobs.vals.tolist()
            logprobs.indices = logprobs.indices.tolist()

        # captured before update_running, which may change lengths/status
        seq_length = [seq.num_token_ids for seq in running]
        is_run = [seq.status == MessageStatus.RUNNING for seq in running]
        self.seq_strategy.update_running(running=running,
                                         batched_outputs=batched_outputs,
                                         model_inputs=model_inputs,
                                         delta=delta)

        # generate output
        outputs: Dict[int, InferOutput] = dict()
        for idx, msg in enumerate(running):
            if not is_run[idx]:
                continue
            token_ids = msg.generated_ids
            finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED
            if not finish and len(token_ids) == 0:
                continue
            resp_data = msg.resp.data
            if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids):
                # no new tokens
                continue
            session_id = msg.session_id
            if msg.resp_cache:
                cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()
            else:
                cache_block_ids = None

            # logprobs
            num_logprobs = msg.sampling_param.num_logprobs
            cur_logprobs = None
            if logprobs is not None and num_logprobs > 0:
                cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])
            # get spec stats info
            spec_info = None
            num_draft_tokens = self.config.num_speculative_tokens
            if num_draft_tokens is not None and model_inputs is None and self.config.enable_metrics:
                num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1
                spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens.item())
            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info)
            out = InferOutput(session_id=session_id,
                              resp=msg.resp,
                              finish=finish,
                              token_ids=token_ids,
                              cache_block_ids=cache_block_ids,
                              req_metrics=req_metrics,
                              logprobs=cur_logprobs,
                              routed_experts=msg.routed_experts)
            outputs[session_id] = out

            if msg.return_logits:
                out.logits = __get_logit(msg, logits, seq_length, idx)
        return outputs

    async def _main_loop_try_send_next_inputs(self):
        """Try send next inputs."""
        scheduler = self.scheduler
        if not scheduler.has_unfinished():
            await self.has_runable_event.wait()

        scheduler.collect_migration_done()
        return await self.inputs_maker.send_next_inputs()

    async def _main_loop_get_outputs(
        self,
        running: 'SeqList',
        forward_inputs: Dict[str, Any],
    ):
        """Get outputs and prefetch.

        Returns:
            The ``(forward_inputs, next_running)`` pair from
            ``inputs_maker.prefetch_next_inputs``.
        """
        model_inputs = forward_inputs['inputs']
        delta = forward_inputs['delta']
        self.inputs_maker.update_running_seqs(running, model_inputs)

        # try prefetch inputs
        self.scheduler.collect_migration_done()
        forward_inputs, next_running = await self.inputs_maker.prefetch_next_inputs()

        # send output
        out = await self.executor.get_output_async()
        if out is not None:
            step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta)
            self.resp_queue.put_nowait(step_outputs)

        return forward_inputs, next_running

    async def main_loop(self):
        """Main loop of the engine.

        Each engine instance would communicate with the engine by queue.
        """
        has_runable_event = self.has_runable_event
        scheduler = self.scheduler
        forward_inputs = None
        next_running = None

        async def __no_running_warning():
            # TODO (JimyMa): add watermark check event instead of async sleep.
            # self.perfill_watermark_event.wait()
            logger.warning(f'no next prefill running request, Maybe cache is full, '
                           f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, '
                           f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}')
            await asyncio.sleep(0.1)

        while not self.stop_event.is_set():
            if next_running is None:
                forward_inputs, next_running = await self._main_loop_try_send_next_inputs()
                if next_running is None:
                    await __no_running_warning()
                    continue

            scheduler.activate_seqs(next_running)
            forward_inputs, next_running = await self._main_loop_get_outputs(
                running=next_running,
                forward_inputs=forward_inputs,
            )
            self.inputs_maker.deactivate_evict_seqs()
            has_runable_event.set()

    def update_running_migration(self, running: 'SeqList', next_token_ids: np.ndarray, stopped: torch.Tensor,
                                 model_metas: List[Dict[str, Any]]):
        """Update scheduler.

        Only sequences in MIGRATION_RUNNING status are updated; stopped ones are
        additionally finished.
        """
        if model_metas is None:
            model_metas = [None] * len(running)
        for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas):
            if msg.status != MessageStatus.MIGRATION_RUNNING:
                continue
            update_token = token

            # fill token
            msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)
            if stop:
                update_token = _EMPTY_TOKEN
                msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)
                msg.state.finish()

    async def _migration_loop_migrate(self, migration_ready: 'SeqList'):
        """Migration loop migrate.

        For each ready sequence (dummy prefills excluded), pairs remote prefill
        block ids with local block ids, runs the executor migration and notifies
        the remote engine over ``engine_conn``.
        """
        for msg in migration_ready:
            # skip dummy prefill migration
            if msg.migration_request.is_dummy_prefill:
                continue

            migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = []
            migration_request = msg.migration_request
            prefill_block_ids = migration_request.remote_block_ids
            decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg))

            assert len(prefill_block_ids) == len(decode_block_ids), (
                f'#prefill block ids ({len(prefill_block_ids)}) must equal to '
                f'#decode block ids ({len(decode_block_ids)}), '
                f'all id length: {msg.num_token_ids}')
            migration_execution_requests.append((
                migration_request.remote_engine_id,
                list(zip(prefill_block_ids, decode_block_ids)),
            ))
            migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol,
                                                       requests=migration_execution_requests)
            logger.info(f'migrating session: {msg.session_id} begin')
            await self.executor.migrate(migration_inputs)
            logger.info(f'migrating session: {msg.session_id} done')
            await self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id,
                                            remote_session_id=migration_request.remote_session_id)

    async def _migration_loop_get_outputs(self, migration_ready: 'SeqList'):
        """Migration loop get outputs.

        Emits the remote token of each migrated sequence as a SUCCESS output and
        feeds it into the sequence via ``update_running_migration``.
        """
        outputs: Dict[int, InferOutput] = dict()
        for msg in migration_ready:
            session_id = msg.session_id
            msg.resp.type = ResponseType.SUCCESS
            token_ids = [msg.migration_request.remote_token_id]
            # MUST be a wall-clock time
            new_token_timestamp = time.time()
            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
            out = InferOutput(
                session_id=session_id,
                resp=msg.resp,
                finish=False,
                token_ids=np.array(token_ids),
                req_metrics=req_metrics,
            )
            outputs[session_id] = out
            self.update_running_migration([msg], np.array([token_ids]), [False], [None])
        self.resp_queue.put_nowait(outputs)

    async def _migration_loop_process_ready(self, migration_ready: 'SeqList'):
        """Process migration ready."""
        await self._migration_loop_migrate(migration_ready)

        # generate output
        with self.scheduler.seqs_migration_activation(migration_ready):
            await self._migration_loop_get_outputs(migration_ready)
        self.has_runable_event.set()

    async def migration_loop(self):
        """Async loop migration."""
        while not self.stop_event.is_set():
            migration_ready = self.scheduler._schedule_migration()
            if not migration_ready and not self.scheduler.has_migration_waiting():
                await self.migration_event.wait()
            elif migration_ready:
                self.migration_event.clear()
                await self._migration_loop_process_ready(migration_ready)
            else:
                # release coroutine for decoding
                await asyncio.sleep(.5)

    def start(self, event_loop: asyncio.AbstractEventLoop):
        """Create async tasks."""
        # start executor
        logger.info('Starting executor.')
        self.executor.start(self.forward_event)
        # start owned loops
        self.tasks.add(event_loop.create_task(self.executor.wait_tasks(), name='MainLoopWaitExecutor'))
        logger.info('Starting async task MainLoopPreprocessMessage.')
        self.tasks.add(event_loop.create_task(self.preprocess_loop(), name='MainLoopPreprocessMessage'))
        logger.info('Starting async task MainLoopResponse.')
        self.tasks.add(event_loop.create_task(self.send_response_loop(), name='MainLoopSendResponse'))
        logger.info('Starting async task MainLoop.')
        self.tasks.add(event_loop.create_task(self.main_loop(), name='MainLoopMain'))
        if self.config.role != EngineRole.Hybrid:
            logger.info('Starting async task MigrationLoop.')
            self.tasks.add(event_loop.create_task(self.migration_loop(), name='MainLoopMigration'))

        # finished tasks remove themselves from the owned set
        for task in self.tasks:
            task.add_done_callback(self.tasks.discard)

    async def wait_tasks(self):
        """Wait for all tasks to finish."""
        if not self.tasks:
            return

        # copy the tasks so callback of tasks would not update it
        tasks = self.tasks.copy()
        try:
            await wait_for_async_tasks(tasks)
        except asyncio.CancelledError:
            logger.info('EngineLoop wait_tasks cancelled.')
            raise
        except BaseException:
            logger.error('EngineLoop wait_tasks failed.')
            raise
        finally:
            logger.debug('EngineLoop wait_tasks cleanup.')
            # Make sure task finished/cancelled here.
            # Error might happen if executor release before executor wait_tasks finish.
            await cancel_async_tasks(tasks)

    def stop(self):
        """Stop all loops."""
        if self.stop_event.is_set():
            # Already stopped, avoid calling executor.stop() multiple times
            return
        self.executor.stop()
        self.stop_event.set()
        self.cancel()

    def cancel(self):
        """Cancel all loops."""
        for task in self.tasks:
            if not task.done():
                task.cancel()
        self.tasks.clear()


def build_engine_loop(engine: 'Engine'):
    """Create an ``EngineLoop`` wired to the components of ``engine``.

    The loop config comes from ``EngineLoopConfig.from_engine`` and a new inputs
    maker is built for the engine; every other collaborator is taken directly
    from the engine's attributes.
    """
    from .inputs_maker import build_inputs_maker

    loop_config = EngineLoopConfig.from_engine(engine)
    maker = build_inputs_maker(engine)
    loop_kwargs = dict(
        req_manager=engine.req_manager,
        scheduler=engine.scheduler,
        executor=engine.executor,
        seq_strategy=engine.seq_strategy,
        inputs_maker=maker,
        config=loop_config,
        engine_conn=engine.engine_conn,
    )
    return EngineLoop(**loop_kwargs)


================================================
FILE: lmdeploy/pytorch/engine/executor/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from logging import Logger
from typing import Dict

from lmdeploy.pytorch import envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.utils import get_logger

from .base import ExecutorBase


def get_distributed_executor_backend(world_size: int, dp: int, device_type: str, logger: Logger = None):
    """Get distributed executor backend.

    Resolution order:
      1. ``envs.executor_backend`` (environment LMDEPLOY_EXECUTOR_BACKEND) if set.
      2. ``'uni'`` when ``world_size == 1``.
      3. ``'ray'`` when ``dp > 1``.
      4. ``'mp'`` when the device backend does not support ray, else ``'ray'``.

    Args:
        world_size: total number of workers.
        dp: data parallel size.
        device_type: device type used to look up the backend.
        logger: optional logger. When given, the reason for the choice is
            logged for cases 1, 3 and the ``'mp'`` branch of 4.

    Returns:
        The name of the executor backend.
    """
    from lmdeploy.pytorch.backends import get_backend

    def _log_info(message: str):
        if logger is not None:
            logger.info(message)

    def _log_and_set_backend(message: str, executor_backend: str):
        """Log the reason together with the chosen backend and return it."""
        message += f' distributed_executor_backend={executor_backend}.'
        _log_info(message)
        return executor_backend

    executor_backend = envs.executor_backend
    if executor_backend is not None:
        return _log_and_set_backend('found environment LMDEPLOY_EXECUTOR_BACKEND.', executor_backend)

    if world_size == 1:
        return 'uni'

    if dp > 1:
        return _log_and_set_backend(f'dp={dp}.', 'ray')

    backend = get_backend(device_type)
    if not backend.support_ray():
        return _log_and_set_backend(f'device={device_type} does not support ray.', 'mp')
    else:
        return 'ray'

    # TODO: fix mp hanging, do not delete the comment.
    # device_count = backend.device_count()
    # if device_count is None:
    #     return _log_and_set_backend(f'device={device_type} can not get device_count.', 'mp')

    # if device_count < world_size:
    #     executor_backend = 'ray'
    #     return _log_and_set_backend(f'local device_count({device_count})=world_size({world_size}),', 'mp')


def build_executor(
    model_path: str,
    cache_config: CacheConfig,
    backend_config: BackendConfig,
    dist_config: DistConfig,
    misc_config: MiscConfig,
    adapters: Dict[str, str] = None,
    device_type: str = 'cuda',
    distributed_executor_backend: str = None,
    dtype: str = 'auto',
    specdecode_config: SpecDecodeConfig = None,
) -> ExecutorBase:
    """Build model agent executor.

    Loads the ``ModelConfig`` of ``model_path`` and picks the distributed
    executor backend. The backend comes from ``get_distributed_executor_backend``
    when not given. The function then instantiates the matching executor.

    Args:
        model_path: path of the model.
        cache_config: cache config passed to the executor.
        backend_config: backend config passed to the executor.
        dist_config: distributed config (``dp`` and ``world_size`` are read).
        misc_config: misc config (``hf_overrides``, ``model_format`` and
            ``empty_init`` are read).
        adapters: optional adapter mapping.
        device_type: device type string.
        distributed_executor_backend: ``'uni'``, ``'mp'`` or ``'ray'``; chosen
            automatically when ``None``.
        dtype: model dtype passed to ``ModelConfig.from_pretrained``.
        specdecode_config: optional speculative decoding config.

    Returns:
        ExecutorBase: a ``UniExecutor``, ``MPExecutor`` or ``RayExecutor``.

    Raises:
        AssertionError: ``dp > 1`` or ``misc_config.empty_init`` without the ray
            backend, or the uni backend with ``world_size != 1``.
        RuntimeError: unknown ``distributed_executor_backend``.
    """
    logger = get_logger('lmdeploy')
    dp = dist_config.dp
    world_size = dist_config.world_size

    model_config = ModelConfig.from_pretrained(
        model_path,
        trust_remote_code=True,
        dtype=dtype,
        hf_overrides=misc_config.hf_overrides,
        dist_config=dist_config,
        is_draft_model=False,
        spec_method=None if specdecode_config is None else specdecode_config.method,
        model_format=misc_config.model_format,
        device_type=device_type,
    )

    if distributed_executor_backend is None:
        distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger)

    # The message parts are concatenated implicitly (no comma between them);
    # a comma would turn the assertion message into a tuple.
    if dp > 1:
        assert distributed_executor_backend == 'ray', (
            'dp>1 requires distributed_executor_backend="ray", '
            f'get distributed_executor_backend="{distributed_executor_backend}"')

    if misc_config.empty_init:
        assert distributed_executor_backend == 'ray', (
            'empty_init requires distributed_executor_backend="ray", '
            f'get distributed_executor_backend="{distributed_executor_backend}"')

    if distributed_executor_backend is not None:
        logger.info(f'Build <{distributed_executor_backend}> executor.')
    if distributed_executor_backend == 'uni':
        assert world_size == 1, 'uni executor only support world_size==1.'
        from .uni_executor import UniExecutor
        return UniExecutor(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            specdecode_config=specdecode_config,
        )
    elif distributed_executor_backend == 'mp':
        from .mp_executor import MPExecutor
        logger.warning('MPExecutor will be deprecated in future releases, please use RayExecutor instead.')
        return MPExecutor(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            specdecode_config=specdecode_config,
        )
    elif distributed_executor_backend == 'ray':
        from .ray_executor import RayExecutor
        return RayExecutor(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            dtype=dtype,
            specdecode_config=specdecode_config,
        )
    else:
        raise RuntimeError(f'Unsupported distributed_executor_backend: {distributed_executor_backend}.')


================================================
FILE: lmdeploy/pytorch/engine/executor/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Inspired by vLLM: https://github.com/vllm-project/vllm
import asyncio
import contextlib
from typing import Any, Dict, List, Optional

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.engine.cache_engine import CacheEngine
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class ExecutorBase:
    """Executor base class.

    Keeps the configs shared by every executor and declares the executor
    interface. Most methods below only raise ``NotImplementedError`` and are
    meant to be overridden by concrete executors. Implemented here: KV-cache
    sizing (``update_configs``) and the initialization sequence (``init``).
    """

    def __init__(self,
                 model_path: str,
                 model_config: ModelConfig,
                 cache_config: CacheConfig,
                 backend_config: BackendConfig,
                 dist_config: DistConfig,
                 misc_config: MiscConfig,
                 adapters: Dict[str, str] = None,
                 specdecode_config: SpecDecodeConfig = None,
                 device_type: str = 'cuda'):
        """Initialize Executor.

        ``cache_config`` is mutated in place:
        - ``window_size`` is copied from ``model_config.sliding_window``;
        - prefix caching is turned off when that window is positive.

        ``model_path`` and ``adapters`` are accepted but not stored by this
        base class.
        """
        cache_config.window_size = model_config.sliding_window
        if cache_config.window_size is not None and cache_config.window_size > 0:
            # do not support sliding window prefix caching
            logger.warning('Sliding window prefix caching is not supported.')
            cache_config.enable_prefix_caching = False
        self.model_config = model_config
        self.cache_config = cache_config
        self.backend_config = backend_config
        self.dist_config = dist_config
        self.misc_config = misc_config
        self.dp = dist_config.dp
        self.world_size = dist_config.world_size
        self.device_type = device_type
        self.specdecode_config = specdecode_config

    def download_models(self):
        """Download model."""
        raise NotImplementedError('Not Implemented.')

    def build_model(self):
        """Build model."""
        raise NotImplementedError('Not Implemented.')

    def gather_free_mem(self):
        """Gather available memory.

        Must return an iterable of free memory sizes in bytes, presumably one
        per worker. ``update_configs`` uses the minimum.
        """
        raise NotImplementedError('Not Implemented.')

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Set all cache config (main and optional speculative-decoding cache)."""
        raise NotImplementedError('Not Implemented.')

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Set all model config (main and optional speculative-decoding model)."""
        raise NotImplementedError('Not Implemented.')

    def build_graph_runner(self):
        """Build graph runner."""
        raise NotImplementedError('Not Implemented.')

    def build_cache_engine(self):
        """Build cache engine."""
        raise NotImplementedError('Not Implemented.')

    def warmup(self):
        """warmup."""
        raise NotImplementedError('Not Implemented.')

    async def sleep(self, level: int = 1):
        """Sleep."""
        raise NotImplementedError('Not Implemented.')

    def wakeup(self, tags: Optional[List[str]] = None):
        """Wakeup."""
        raise NotImplementedError('Not Implemented.')

    def update_params(self, request: Any):
        """Update params."""
        raise NotImplementedError('Not Implemented.')

    def get_input_processor(self):
        """Get input processor."""
        raise NotImplementedError('Not Implemented.')

    def start(self, forward_event: asyncio.Event):
        """Start engine loop."""
        raise NotImplementedError('Not Implemented.')

    async def wait_tasks(self):
        """Wait tasks."""
        raise NotImplementedError('Not Implemented.')

    def stop(self):
        """Stop engine loop."""
        raise NotImplementedError('Not Implemented.')

    def release(self):
        """Release resources."""
        raise NotImplementedError('Not Implemented.')

    async def forward_async(self, inputs):
        """Start forward."""
        raise NotImplementedError('Not Implemented')

    async def get_output_async(self):
        """Get output async."""
        raise NotImplementedError('Not Implemented')

    """ PD Disaggregation API Begin """

    def p2p_initialize(self, remote_engine_config: DistServeInitRequest):
        """Init rdma link."""
        raise NotImplementedError('Not implemented')

    def p2p_connect(self, conn_request: List[DistServeKVTransferEndpointInfo]):
        """rdma_connect."""
        raise NotImplementedError('Not Implemented')

    async def migrate(self, batch: MigrationExecutionBatch):
        """KV Cache Migration."""
        raise NotImplementedError('Not Implemented')

    """ PD Disaggregation API End """

    def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_size: int):
        """Find best prefill num.

        Halves ``max_prefill_token_num`` until the memory left after reserving
        the estimated runtime buffer (scaled by ``cache_max_entry_count``) fits
        at least 16 cache blocks, or until the token number reaches 0.

        Args:
            num_free_gpu_mem: free device memory in bytes.
            cache_block_size: bytes per cache block; 0 skips the fit check.
            vocal_size: vocabulary size (sic).

        Returns:
            tuple: ``(runtime_cache_size, max_prefill_token_num)``. When the
            loop ends because the token number reached 0,
            ``runtime_cache_size`` is the estimate for the last size tried.
            When the configured token number is already <= 0, it is ``0``.
        """
        cache_max_entry_count = self.cache_config.cache_max_entry_count
        max_prefill_token_num = self.cache_config.max_prefill_token_num
        max_batches = self.cache_config.max_batches
        runtime_cache_size = 0
        while max_prefill_token_num > 0:
            # estimate runtime mem size
            # NOTE(review): presumably a 2-byte-per-element logits buffer of
            # (prefill tokens + 2 * max_batches) x vocab; confirm.
            runtime_cache_size = int((max_prefill_token_num + max_batches * 2) * vocal_size * 2)
            num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count
            if cache_block_size == 0 or int(num_available) // cache_block_size >= 16:
                break
            max_prefill_token_num = max_prefill_token_num // 2
        return runtime_cache_size, max_prefill_token_num

    def _adjust_block_size(self):
        """Adjust block_size.

        flash_mla requires ``block_size == 64`` (ValueError otherwise). Models
        with ``k_head_dim >= 512`` get ``block_size`` capped at 32.
        """
        if self.model_config.use_flash_mla is True:
            if self.cache_config.block_size != 64:
                raise ValueError('Please set block_size to 64 for flash_mla.')
            return
        # TODO: support kernel with both large head dim and large block size.
        if self.model_config.k_head_dim >= 512 and self.cache_config.block_size > 32:
            self.cache_config.block_size = 32
            logger.warning(
                f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.'  # noqa
            )

    def _get_state_cache_mem(self):
        """Get state cache mem usage.

        Returns 0 when ``cache_config.states_shapes`` is empty. Otherwise it
        returns the per-cache state size times the number of state caches.

        Side effects on ``cache_config``:
        - ``num_state_caches`` defaults to ``max_batches + 8`` when unset;
        - prefix caching is disabled.
        """
        cache_config = self.cache_config
        if len(cache_config.states_shapes) == 0:
            return 0

        from lmdeploy.pytorch.engine.cache_engine import StateCacheEngine

        num_state_caches = cache_config.num_state_caches
        if num_state_caches is None:
            # add more caches for eviction
            # TODO: Share memory between state cache and pageable cache
            num_state_caches = int(cache_config.max_batches + 8)
            cache_config.num_state_caches = num_state_caches

        mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes)
        mems *= num_state_caches

        if cache_config.enable_prefix_caching:
            cache_config.enable_prefix_caching = False
            logger.warning('Prefix caching has not been support for state space model.')

        return mems

    def update_configs(self):
        """Update cache config.

        Sizes the caches from the free memory reported by ``gather_free_mem``:
          1. adjust ``block_size`` (and copy it to the spec cache config);
          2. reserve memory for state caches;
          3. shrink ``max_prefill_token_num`` so the runtime buffer still leaves
             room for cache blocks;
          4. if ``num_gpu_blocks`` is 0, derive it from the remaining memory
             scaled by ``cache_max_entry_count``.
        The final configs are passed to ``set_cache_config`` and
        ``set_model_config``.

        Raises:
            AssertionError: the state caches alone use up the free memory.
            RuntimeError: no memory left for the runtime buffer or the KV cache.
        """
        self._adjust_block_size()
        # spec
        if self.specdecode_config and self.specdecode_config.cache_config:
            self.specdecode_config.cache_config.block_size = self.cache_config.block_size
        cache_config = self.cache_config
        model_config = self.model_config
        cache_config.states_shapes = model_config.states_shapes

        # get free mems
        # the smallest reported value bounds the cache size everywhere
        free_mems = self.gather_free_mem()
        free_mem = min(free_mems)
        logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb')

        # get state cache size
        state_cache_mem = self._get_state_cache_mem()
        free_mem = free_mem - state_cache_mem
        assert free_mem > 0, 'No enough gpu memory for state cache. Please reduce max_batch_size.'

        vocal_size = self.model_config.vocab_size
        tp = self.dist_config.attn_tp
        cache_block_size = CacheEngine.get_cache_block_size(cache_config, model_config, tp)
        spec_cache_config = None
        spec_model_config = None
        spec_cache_block_size = 0
        if self.specdecode_config:
            spec_model_config = self.specdecode_config.model_config
            if spec_cache_config := self.specdecode_config.cache_config:
                # the spec cache block size is computed with tp=1
                spec_cache_block_size = CacheEngine.get_cache_block_size(spec_cache_config, spec_model_config, 1)

        # main and spec caches get the same number of blocks (see below), so the
        # runtime-size search uses their combined block size
        runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size + spec_cache_block_size,
                                                                    vocal_size)
        if cache_config.max_prefill_token_num != max_prefill_token_num:
            if max_prefill_token_num <= 0:
                raise RuntimeError('No enough gpu memory for runtime.')
            cache_config.max_prefill_token_num = max_prefill_token_num
            logger.warning(f'No enough memory. Update max_prefill_token_num={max_prefill_token_num}')

        if spec_cache_config is not None:
            spec_cache_config.max_prefill_token_num = max_prefill_token_num

        free_mem -= runtime_mem
        logger.debug(f'estimated max runtime memory: {runtime_mem >> 20} mb')
        available_mem = free_mem * cache_config.cache_max_entry_count

        # num_gpu_blocks == 0 means "derive from memory"; an explicit value is kept
        # NOTE(review): a cache_block_size of 0 would raise ZeroDivisionError here.
        if cache_config.num_gpu_blocks == 0:
            cache_config.num_gpu_blocks = int(available_mem / cache_block_size)
            if cache_config.num_gpu_blocks <= 0:
                raise RuntimeError('No enough gpu memory for kv cache.')
            if spec_cache_config is not None:
                spec_cache_config.num_gpu_blocks = cache_config.num_gpu_blocks

        self.set_cache_config(cache_config, spec_cache_config)
        self.set_model_config(model_config, spec_model_config)

    def init(self):
        """init.

        Runs the initialization sequence in order:
        build model -> update configs -> build graph runner -> build cache
        engine -> warm up.
        """
        logger.info('Building Model.')
        self.build_model()
        logger.info('Updating configs.')
        self.update_configs()
        logger.info('Building GraphRunner and warmup ops, please waiting.')
        self.build_graph_runner()
        logger.info(f'Building CacheEngine with config: \n{self.cache_config}.')
        if self.specdecode_config:
            if spec_cache_config := self.specdecode_config.cache_config:
                logger.info(f'Building Spec CacheEngine with config: \n{spec_cache_config}.')
        self.build_cache_engine()
        logger.info('Warming up model.')
        self.warmup()

    @contextlib.contextmanager
    def remote_log(self, msg: str):
        """Send log for debugging.

        Do not use it in production. The base implementation is a no-op context
        manager.
        """
        # Different executor may have different log sending logic.
        yield


================================================
FILE: lmdeploy/pytorch/engine/executor/base_worker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import gc
from typing import Any, Dict, List, Optional

from lmdeploy.pytorch.backends.selector import get_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.model_agent import build_model_agent
from lmdeploy.utils import get_logger

from .dist_utils import init_process_group, setup_master_addr

logger = get_logger('lmdeploy')


class WorkerWrapperBase:
    """Worker wrapper.

    A per-rank wrapper that owns a model agent and forwards executor calls to
    it. Call ``init_process_group`` before ``build_model``: the model agent is
    built with the ``DistContext`` created there.
    """

    def __init__(
        self,
        model_path: str,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        model_config: ModelConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        log_level: int = 30,
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Store configs and apply process-wide settings.

        Args:
            model_path: path of the model.
            cache_config: cache config.
            backend_config: backend config.
            model_config: model config.
            dist_config: distributed config (``dp``, ``tp`` and ``world_size``
                are read).
            misc_config: misc config.
            adapters: optional adapter mapping.
            device_type: device type string.
            log_level: level set on the ``lmdeploy`` logger (30 is
                ``logging.WARNING``).
            specdecode_config: optional speculative decoding config.
        """
        self.model_path = model_path
        self.model_config = model_config
        self.cache_config = cache_config
        self.backend_config = backend_config
        self.dist_config = dist_config
        self.misc_config = misc_config
        self.adapters = adapters
        self.device_type = device_type
        self.log_level = log_level
        self.dp = dist_config.dp
        self.tp = dist_config.tp
        self.world_size = dist_config.world_size
        self.specdecode_config = specdecode_config
        logger.setLevel(log_level)
        self.out_que: asyncio.Queue = None

        # frequently gc would cause latency spike
        # default threshold (700, 10, 10)
        gc.set_threshold(10000, 100, 100)

    def init_process_group(self, rank: int, master_addr: str = None, master_port: str = None):
        """Initialize process group.

        Stores ``rank``. When ``world_size > 1`` it optionally exports the
        master address/port (only if both are given) and initializes the torch
        process group. A ``DistContext`` is always built (as ``self.dist_ctx``)
        using the backend's collective communication backend.
        """
        self.rank = rank
        if self.world_size > 1:
            if master_addr is not None and master_port is not None:
                setup_master_addr(master_addr, master_port)

            init_process_group(rank, self.world_size)

        ccl_backend = get_backend(self.device_type).ccl_backend()
        self.dist_ctx = DistContext.build(self.rank, self.dist_config, ccl_backend)

    def pack_output(self, output: Dict):
        """Pack output before returning it; identity hook for subclasses."""
        return output

    async def get_outputs(self):
        """Get outputs: await the next packed output of the model agent."""
        return await self.get_output_async()

    def build_model(self):
        """Build model: create the device context and the model agent."""
        self.device_ctx = DeviceContext(device_type=self.device_type)

        self.model_agent = build_model_agent(
            model_path=self.model_path,
            model_config=self.model_config,
            cache_config=self.cache_config,
            backend_config=self.backend_config,
            misc_config=self.misc_config,
            device_ctx=self.device_ctx,
            dist_ctx=self.dist_ctx,
            adapters=self.adapters,
            specdecode_config=self.specdecode_config,
        )
        self.model_agent.build_model()

    def get_free_mem(self):
        """Get free mem reported by the model agent."""
        return self.model_agent.get_free_mem()

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Set all cache config."""
        self.model_agent.set_cache_config(cache_config, spec_cache_config)

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Set all model config."""
        self.model_agent.set_model_config(model_config, spec_model_config)

    def build_graph_runner(self):
        """Build graph runner."""
        self.model_agent.build_graph_runner()

    def build_cache_engine(self):
        """Build cache engine."""
        self.model_agent.build_cache_engine()

    def update_params(self, request: Any):
        """Update params."""
        self.model_agent.update_params(request)

    def warmup(self):
        """warmup."""
        self.model_agent.warmup()

    async def sleep(self, level: int = 1):
        """Sleep."""
        await self.model_agent.sleep(level)

    def wakeup(self, tags: Optional[List[str]] = None):
        """Wakeup."""
        self.model_agent.wakeup(tags)

    def get_input_processor(self):
        """Get input processor."""
        return self.model_agent.get_input_processor()

    def start(self):
        """Start engine loop and create the output queue."""
        self.model_agent.start()
        self.out_que = asyncio.Queue()

    async def wait_tasks(self):
        """Wait tasks of the model agent; failures are logged with the rank and
        re-raised."""
        try:
            await self.model_agent.wait_tasks()
        except asyncio.CancelledError:
            logger.debug('WorkerWrapper wait_tasks cancelled.')
            raise
        except BaseException:
            # we want to keep logs in both ray logs and engine logs
            msg = f'WorkerWrapper rank[{self.rank}] wait_tasks failed.'
            logger.exception(msg)
            raise

    def stop(self):
        """Stop engine loop."""
        self.model_agent.stop()

    async def stop_async(self):
        """Stop engine loop asynchronously."""
        await self.model_agent.stop_async()

    async def forward_async(self, inputs):
        """Start forward by handing ``inputs`` to the model agent."""
        self.model_agent.set_forward_inputs(inputs)

    async def get_output_async(self):
        """Get output async, passed through ``pack_output``."""
        ret = await self.model_agent.get_output_async()
        ret = self.pack_output(ret)
        return ret

    def release(self):
        """Release resources held by the model agent."""
        self.model_agent.release()

    """ PD Disaggregation API Begin """

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Init rdma link via the agent's cache engine."""
        return self.model_agent.cache_engine.p2p_initialize(init_request)

    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
        """Connect to a remote engine via the agent's cache engine."""
        return self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)

    async def migrate(self, inputs: MigrationExecutionBatch):
        """KV Cache Migration via the agent's cache engine."""
        return await self.model_agent.cache_engine.migrate(inputs)

    """ PD Disaggregation API End """


================================================
FILE: lmdeploy/pytorch/engine/executor/dist_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
import socket
from datetime import timedelta

import torch.distributed as dist

from lmdeploy.pytorch.backends.selector import get_backend


def find_available_port() -> int:
    """Find an available TCP port on this host.

    Binds a temporary socket to port 0 on all interfaces so the OS assigns an
    unused port, then closes the socket. The port is not reserved afterwards,
    so another process could take it before the caller binds it.

    Returns:
        int: the port number chosen by the OS.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('0.0.0.0', 0))
        s.listen(1)
        port = s.getsockname()[1]
        return port


def setup_master_addr(addr: str, port: str):
    """Export ``MASTER_ADDR`` / ``MASTER_PORT`` and log them.

    These variables are read by ``torch.distributed``'s default ``env://``
    initialization. A non-string ``port`` is converted with ``str``.
    """
    from lmdeploy.utils import get_logger
    log = get_logger('lmdeploy')

    port = port if isinstance(port, str) else str(port)
    os.environ.update(MASTER_ADDR=addr, MASTER_PORT=port)
    log.info(f'MASTER_ADDR={addr}, MASTER_PORT={port}')


def init_dist_environ(rank: int, world_size: int):
    """Publish this process's ``RANK`` and ``WORLD_SIZE`` as environment
    variables."""
    for key, value in (('RANK', rank), ('WORLD_SIZE', world_size)):
        os.environ[key] = str(value)


def init_process_group(rank: int, world_size: int):
    """Initialize the default ``torch.distributed`` process group for
    ``rank``.

    Steps:
    - exports ``RANK``/``WORLD_SIZE``;
    - removes ``TORCHELASTIC_USE_AGENT_STORE`` from the environment;
    - initializes the group with the current backend's collective
      communication backend.

    The timeout of 35600 days effectively disables the collective timeout.
    """
    timeout = timedelta(days=35600)
    init_dist_environ(rank, world_size)
    # presumably so a store set up by an enclosing torchelastic launch is not reused
    os.environ.pop('TORCHELASTIC_USE_AGENT_STORE', None)

    backend_name = get_backend().ccl_backend()
    dist.init_process_group(backend=backend_name, rank=rank, world_size=world_size, timeout=timeout)
    assert dist.is_initialized()


================================================
FILE: lmdeploy/pytorch/engine/executor/mp_executor.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/v1/executor/multiproc_executor.py
import asyncio
import multiprocessing.shared_memory as shared_memory
import os
import pickle
import signal
import struct
from contextlib import asynccontextmanager, contextmanager
from multiprocessing.context import SpawnContext
from typing import Any, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from lmdeploy.pytorch.backends.selector import init_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.utils import get_logger, try_import_deeplink

from .base import ExecutorBase
from .base_worker import WorkerWrapperBase
from .dist_utils import find_available_port, setup_master_addr

logger = get_logger('lmdeploy')

# 1m shared memory
SHARED_BLOCK_SIZE = 1 << 20
# num shared block
NUM_SHARED_BLOCK = 32
# data size
HEAD_SIZE = 8
# block real size
SHARED_BLOCK_REAL_SIZE = SHARED_BLOCK_SIZE + HEAD_SIZE


def get_num_packages(data_size):
    """Return how many ``SHARED_BLOCK_SIZE`` chunks are needed to hold
    ``data_size`` bytes (ceiling division)."""
    return -(-data_size // SHARED_BLOCK_SIZE)


class Notifier:
    """Cross-process signalling for a ring of ``NUM_SHARED_BLOCK`` shared
    blocks.

    There is one event per block. The sender sets the event of each block it
    fills, and receivers wait on the event of the block they read next. Both
    sides walk the events round-robin.

    After the last event of the ring is used, the sender and all receivers meet
    at a barrier twice:
    - the first meeting guarantees every receiver is done with the ring;
    - the second guarantees the sender has cleared all events before anyone
      continues, so the ring can be reused.
    """

    def __init__(self, num_receiver: int, mp_ctx: SpawnContext):
        self.events = [mp_ctx.Event() for _ in range(NUM_SHARED_BLOCK)]
        # all receivers plus the single sender take part in each barrier
        self.bar = mp_ctx.Barrier(num_receiver + 1)
        self._event_id = 0

    def _update_event_id(self):
        """Advance to the next event in the ring."""
        self._event_id = (self._event_id + 1) % NUM_SHARED_BLOCK

    def _clear_events(self):
        """Reset every event of the ring (sender side, between the barriers)."""
        for event in self.events:
            event.clear()

    def set(self):
        """Mark the current block as ready for the receivers (sender side)."""
        self.events[self._event_id].set()
        if self._event_id == NUM_SHARED_BLOCK - 1:
            self.bar.wait()
            self._clear_events()
            self.bar.wait()
        self._update_event_id()

    async def set_async(self):
        """Async variant of ``set``.

        Not safe if multiple requests might be sent concurrently. Only the
        first barrier wait is offloaded to a thread; the second runs
        synchronously. Receivers presumably arrive there right after the first
        meeting, so it should return promptly.
        """
        event_loop = asyncio.get_running_loop()
        self.events[self._event_id].set()
        if self._event_id == NUM_SHARED_BLOCK - 1:
            await event_loop.run_in_executor(None, self.bar.wait)
            self._clear_events()
            self.bar.wait()
        self._update_event_id()

    @contextmanager
    def wait(self):
        """Block until the current block is ready; advance the ring on exit.

        On the last block, the barrier handshake runs after the body, i.e.
        after the caller has read the block. If the body raises, neither the
        handshake nor the ring advance happens.
        """
        self.events[self._event_id].wait()
        yield
        if self._event_id == NUM_SHARED_BLOCK - 1:
            self.bar.wait()
            self.bar.wait()
        self._update_event_id()

    @asynccontextmanager
    async def wait_async(self):
        """Async variant of ``wait``.

        The event wait is offloaded to a thread. The barrier waits on the last
        block run synchronously and block the event loop until every
        participant arrives.
        """
        event_loop = asyncio.get_running_loop()
        await event_loop.run_in_executor(None, self.events[self._event_id].wait)
        yield
        if self._event_id == NUM_SHARED_BLOCK - 1:
            self.bar.wait()
            self.bar.wait()
        self._update_event_id()

    def close(self):
        """Wake all waiters: set every event and abort the barrier.

        Pending and future barrier waits raise ``BrokenBarrierError``.
        """
        for event in self.events:
            event.set()
        self.bar.abort()


class SharedBuffer:
    """Shared buffer.

    A ring of ``NUM_SHARED_BLOCK`` blocks inside one ``SharedMemory`` segment.
    Each block is ``HEAD_SIZE`` header bytes followed by up to
    ``SHARED_BLOCK_SIZE`` payload bytes.

    Messages are pickled and split across consecutive blocks, and every block
    carries the header ``struct.pack('II', data_size, receiver_mask)``. A
    reader only unpickles messages whose mask has bit ``proc_id`` set.
    """

    def __init__(self, proc_id: int, notifier: Notifier, name: str = None):
        """Create the segment (``name is None``) or attach to an existing one.

        Args:
            proc_id: receiver id. A negative value yields an empty mask, so
                this side never accepts messages.
            notifier: signalling object paired with this buffer.
            name: name of an existing segment to attach to.
        """
        self.proc_id = proc_id
        self.notifier = notifier
        self.is_create = name is None
        if self.is_create:
            # one segment holding the whole ring of NUM_SHARED_BLOCK blocks
            self.shm = shared_memory.SharedMemory(create=True, size=SHARED_BLOCK_REAL_SIZE * NUM_SHARED_BLOCK)
        else:
            self.shm = shared_memory.SharedMemory(name=name)
        self._buf_id = 0

        if proc_id >= 0:
            self.proc_mask = 1 << proc_id
        else:
            self.proc_mask = 0

        self.is_closed = False

    @contextmanager
    def acquire_buf(self):
        """Yield a view of the current block; move to the next block on
        exit."""
        buf = self.shm.buf
        assert buf is not None
        buf_start = self._buf_id * SHARED_BLOCK_REAL_SIZE
        out_buf = buf[buf_start:buf_start + SHARED_BLOCK_REAL_SIZE]
        yield out_buf
        self._buf_id = (self._buf_id + 1) % NUM_SHARED_BLOCK

    def name(self):
        """Name of the underlying shared memory segment."""
        return self.shm.name

    def pack_data(self, data, receiver_mask):
        """Pickle ``data`` and write it chunk by chunk into the ring.

        This is a generator. It yields the block after writing each chunk, so
        the caller can notify receivers before the next chunk is written.
        """
        dumped_data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        data_size = len(dumped_data)

        num_packs = get_num_packages(data_size)
        head = struct.pack('II', data_size, receiver_mask)

        # Walk the payload with an offset over a zero-copy view. Re-slicing the
        # remaining ``bytes`` on every chunk would copy the tail each time
        # (quadratic in the number of chunks).
        with memoryview(dumped_data) as payload:
            offset = 0
            for _ in range(num_packs):
                with self.acquire_buf() as buf:
                    pac_size = min(data_size - offset, SHARED_BLOCK_SIZE)
                    buf[:HEAD_SIZE] = head
                    buf[HEAD_SIZE:HEAD_SIZE + pac_size] = payload[offset:offset + pac_size]
                    offset += pac_size
                    yield buf

    def send(self, data, receiver_mask: int = 0xff):
        """Send ``data`` to the receivers selected by ``receiver_mask``.

        The default mask 0xff only covers receivers with proc_id 0-7.
        """
        for _ in self.pack_data(data, receiver_mask):
            self.notifier.set()

    async def send_async(self, data, receiver_mask: int = 0xff):
        """Async ``send``; the default mask covers proc_id 0-7 only."""
        for _ in self.pack_data(data, receiver_mask):
            await self.notifier.set_async()

    def _receive_step0(self):
        """Read the first block of a message (caller already waited on it).

        Returns:
            tuple: ``(payload_so_far, is_receiver, remain_size)``. The payload
            is a ``bytearray`` that stays empty when this side is not a
            receiver.
        """
        with self.acquire_buf() as buf:
            head = buf[:HEAD_SIZE]
            data_size, receiver_mask = struct.unpack('II', head)
            is_receiver = ((receiver_mask & self.proc_mask) > 0)

            pac_size = min(data_size, SHARED_BLOCK_SIZE)
            remain_size = data_size - pac_size

            # bytearray: appending later chunks is amortized O(1)
            dumped_data = bytearray()
            if is_receiver:
                dumped_data += buf[HEAD_SIZE:HEAD_SIZE + pac_size]

        return dumped_data, is_receiver, remain_size

    def _receive_step1(self, dumped_data, is_receiver, remain_size):
        """Read the remaining blocks (blocking waits) and unpickle.

        Non-receivers still walk the blocks to keep the ring in sync, and
        return ``None``.
        """
        while remain_size > 0:
            with self.notifier.wait(), self.acquire_buf() as buf:
                pac_size = min(remain_size, SHARED_BLOCK_SIZE)
                remain_size -= pac_size
                if not is_receiver:
                    continue
                dumped_data += buf[HEAD_SIZE:HEAD_SIZE + pac_size]

        if not is_receiver:
            return None
        data = pickle.loads(dumped_data)
        return data

    def receive(self):
        """Unpack data."""
        with self.notifier.wait():
            dumped_data, is_receiver, remain_size = self._receive_step0()
        return self._receive_step1(dumped_data, is_receiver, remain_size)

    async def receive_async(self):
        """Async receive data.

        Only the wait for the first block is asynchronous; later blocks are
        read synchronously.
        """
        async with self.notifier.wait_async():
            dumped_data, is_receiver, remain_size = self._receive_step0()
        return self._receive_step1(dumped_data, is_receiver, remain_size)

    def close(self):
        """Close the segment (unlinking it if we created it) and the notifier.

        Idempotent.
        """
        if self.is_closed:
            return
        self.shm.close()
        if self.is_create:
            self.shm.unlink()
        self.notifier.close()
        self.is_closed = True


class MPExecutor(ExecutorBase):
    """Single node multi device Executor powered by multiprocess.

    One ``ExecutorProc`` is spawned per rank (``world_size`` in total).
    Commands are written to a single shared ``comm_buf`` and each process
    filters them by ``receiver_mask``; every process replies on its own
    ``ret_buf`` (``self.ret_bufs[proc_id]``).
    """

    @classmethod
    def setup_master_addr(cls):
        """Setup master addr.

        Ensures ``MASTER_ADDR`` (default ``'127.0.0.1'``) and ``MASTER_PORT``
        (default: a free local port) exist in ``os.environ`` and forwards the
        resulting values to the module-level ``setup_master_addr`` helper.
        """
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        if 'MASTER_PORT' not in os.environ:
            # only probe for a free port when none was provided
            os.environ['MASTER_PORT'] = str(find_available_port())
        addr = os.environ['MASTER_ADDR']
        port = os.environ['MASTER_PORT']
        setup_master_addr(addr, port)

    def __init__(self,
                 model_path: str,
                 model_config: ModelConfig,
                 cache_config: CacheConfig,
                 backend_config: BackendConfig,
                 dist_config: DistConfig,
                 misc_config: MiscConfig,
                 adapters: Dict[str, str] = None,
                 specdecode_config: SpecDecodeConfig = None,
                 device_type: str = 'cuda'):
        """Initialize Executor.

        Creates the shared command buffer, spawns one ``ExecutorProc`` per
        rank with its own return buffer, and installs a ``SIGUSR1`` handler
        that tears the executor down when a child process reports a failure.
        """
        super().__init__(model_path=model_path,
                         model_config=model_config,
                         cache_config=cache_config,
                         backend_config=backend_config,
                         dist_config=dist_config,
                         misc_config=misc_config,
                         specdecode_config=specdecode_config,
                         adapters=adapters,
                         device_type=device_type)

        # initialize processes.
        self.setup_master_addr()
        mp_ctx = mp.get_context('spawn')
        self.mp_ctx = mp_ctx
        self.comm_notifier = Notifier(self.world_size, mp_ctx)
        self.comm_buf = SharedBuffer(-1, notifier=self.comm_notifier)
        self.comm_buf_name = self.comm_buf.name()

        logger.info('Creating processes.')
        self.procs: List[ExecutorProc] = []
        self.ret_bufs: List[SharedBuffer] = []
        for proc_id in range(self.world_size):
            proc = ExecutorProc(proc_id=proc_id, mp_ctx=mp_ctx)

            ret_notifier = Notifier(1, mp_ctx)
            ret_buf = SharedBuffer(0, notifier=ret_notifier)
            self.ret_bufs.append(ret_buf)
            proc.start(proc_id=proc_id,
                       comm_notifier=self.comm_notifier,
                       comm_buf_name=self.comm_buf_name,
                       ret_notifier=ret_notifier,
                       ret_buf_name=ret_buf.name(),
                       model_path=model_path,
                       model_config=model_config,
                       cache_config=cache_config,
                       backend_config=backend_config,
                       dist_config=dist_config,
                       misc_config=misc_config,
                       specdecode_config=specdecode_config,
                       adapters=adapters,
                       device_type=device_type,
                       log_level=logger.level)
            self.procs.append(proc)

        self._prefetch_task: asyncio.Task = None
        self.remote_outs: asyncio.Queue = None

        # child processes send SIGUSR1 to the parent when their main loop
        # fails or is cancelled (see ExecutorProc._main_loop)
        def signal_handler(signum, frame):
            logger.error('Received custom termination signal from sub processing, exiting...')
            self.stop()
            self.release()
            os._exit(1)

        signal.signal(signal.SIGUSR1, signal_handler)

    def _broadcast_command(self, method: str, args, kwargs, receiver_mask: int, return_mask: int) -> int:
        """Send one rpc command to the ranks selected by ``receiver_mask``.

        Returns:
            int: the effective return mask. Only ranks that receive the
            command can reply, so ``return_mask`` is restricted to
            ``receiver_mask``; waiting on any other rank's ``ret_buf`` would
            never be answered.
        """
        if args is None:
            args = list()
        if kwargs is None:
            kwargs = dict()
        return_mask &= receiver_mask
        self.comm_buf.send(
            dict(
                method=method,
                args=args,
                kwargs=kwargs,
                return_mask=return_mask,
            ),
            receiver_mask=receiver_mask,
        )
        return return_mask

    def collective_rpc(self,
                       method: str,
                       args: Tuple[Any] = None,
                       kwargs: Dict[str, Any] = None,
                       receiver_mask: int = 0xff,
                       return_mask: int = 0xff):
        """Collective rpc.

        Returns:
            list | None: one entry per rank (``None`` for ranks not asked to
            reply), or ``None`` when no rank replies.
        """
        return_mask = self._broadcast_command(method, args, kwargs, receiver_mask, return_mask)

        if return_mask:
            outputs = [None] * len(self.ret_bufs)
            for proc_id, ret_buf in enumerate(self.ret_bufs):
                if bool(return_mask & (1 << proc_id)):
                    outputs[proc_id] = ret_buf.receive()
            return outputs

    async def collective_rpc_async(self,
                                   method: str,
                                   args: Tuple[Any] = None,
                                   kwargs: Dict[str, Any] = None,
                                   receiver_mask: int = 0xff,
                                   return_mask: int = 0xff):
        """Collective rpc, awaiting replies asynchronously.

        Same contract as ``collective_rpc``; ``return_mask`` is now also
        restricted to ``receiver_mask`` here, so it cannot wait on a rank that
        never received the command.
        """
        return_mask = self._broadcast_command(method, args, kwargs, receiver_mask, return_mask)

        if return_mask:
            outputs = [None] * len(self.ret_bufs)
            for proc_id, ret_buf in enumerate(self.ret_bufs):
                if bool(return_mask & (1 << proc_id)):
                    outputs[proc_id] = await ret_buf.receive_async()
            return outputs

    def download_models(self):
        """Download model."""
        raise NotImplementedError('Not Implemented.')

    def build_model(self):
        """Build model."""
        self.collective_rpc('build_model')

    def gather_free_mem(self):
        """Gather available memory (one entry per rank)."""
        ret = self.collective_rpc('get_free_mem')
        return ret

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Set all cache config."""
        self.collective_rpc('set_cache_config', args=(cache_config, spec_cache_config))

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Set all model config."""
        self.collective_rpc('set_model_config', args=(model_config, spec_model_config))

    def build_graph_runner(self):
        """Build graph runner."""
        self.collective_rpc('build_graph_runner')

    def build_cache_engine(self):
        """Build cache engine."""
        self.collective_rpc('build_cache_engine')

    def warmup(self):
        """Warm up all workers."""
        self.collective_rpc('warmup')

    async def _prefetch_outputs(self):
        """Forever fetch outputs from rank 0 into ``remote_outs``."""
        while True:
            out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]
            self.remote_outs.put_nowait(out)

    def start(self, forward_event: asyncio.Event):
        """Start engine loop.

        ``forward_event`` is accepted for interface compatibility and is not
        used by this executor.
        """
        self.collective_rpc('start')

        self.remote_outs = asyncio.Queue()
        event_loop = asyncio.get_event_loop()
        self._prefetch_task = event_loop.create_task(self._prefetch_outputs())

    async def wait_tasks(self):
        """Wait tasks."""
        # we don't need a complex wait tasks since MPExecutor will be deprecated soon.
        await self._prefetch_task

    async def forward_async(self, inputs):
        """Start forward (no rank replies)."""
        await self.collective_rpc_async('forward_async', args=(inputs, ), return_mask=0)

    async def get_output_async(self):
        """Get output async."""
        return await self.remote_outs.get()

    def get_input_processor(self):
        """Get input processor from rank 0."""
        return self.collective_rpc('get_input_processor', receiver_mask=1, return_mask=1)[0]

    def stop(self):
        """Stop engine loop."""
        if self._prefetch_task is not None:
            self._prefetch_task.cancel()

    def release(self):
        """Terminate and join all processes, then close the shared buffers."""
        for proc in self.procs:
            proc.close()

        for proc in self.procs:
            proc.join()

        self.comm_buf.close()
        for ret_buf in self.ret_bufs:
            ret_buf.close()


class MPWorkerWrapper(WorkerWrapperBase):
    """Worker wrapper instantiated inside each multiprocess executor process."""

    def __init__(
        self,
        model_path: str,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        model_config: ModelConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        specdecode_config: SpecDecodeConfig = None,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        log_level: int = 30,
    ):
        """Forward every argument unchanged, by keyword, to ``WorkerWrapperBase``."""
        init_kwargs = dict(
            model_path=model_path,
            cache_config=cache_config,
            backend_config=backend_config,
            model_config=model_config,
            dist_config=dist_config,
            misc_config=misc_config,
            specdecode_config=specdecode_config,
            adapters=adapters,
            device_type=device_type,
            log_level=log_level,
        )
        super().__init__(**init_kwargs)


class ExecutorProc:
    """Handle for one spawned executor worker process."""

    def __init__(self, proc_id: int, mp_ctx: SpawnContext):
        """Executor proc.

        Args:
            proc_id (int): rank of this process, used to name it.
            mp_ctx (SpawnContext): multiprocessing context used to spawn it.
        """
        self.proc_id = proc_id
        self.mp_ctx = mp_ctx
        self._proc = None

    def start(self, **kwargs):
        """Start proc; ``kwargs`` are forwarded to ``_main_loop``."""
        assert self._proc is None
        proc = self.mp_ctx.Process(target=self._main_loop,
                                   kwargs=kwargs,
                                   name=f'ExecutorProc-{self.proc_id}',
                                   daemon=True)
        proc.start()
        self._proc = proc

    def close(self):
        """Stop proc.

        Calls ``Process.terminate()`` when the process is alive; no-op when it
        was never started or has already exited.
        """
        if self._proc is None:
            return
        if not self._proc.is_alive():
            return
        self._proc.terminate()

    def join(self):
        """Wait for the process to exit (no-op if it was never started)."""
        if self._proc is None:
            return
        self._proc.join()

    def _main_loop(
        self,
        proc_id: int,
        comm_notifier: Any,
        comm_buf_name: str,
        ret_notifier: Any,
        ret_buf_name: str,
        model_path: str,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        specdecode_config: SpecDecodeConfig = None,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        log_level: int = 30,
    ):
        """Main loop.

        Entry point of the child process. It binds device ``proc_id``, builds
        the worker, joins the process group, attaches to the shared buffers
        created by the parent, and serves commands until it is terminated.
        On failure or cancellation it notifies the parent with ``SIGUSR1``.
        """
        import sys

        init_backend(device_type)
        torch.cuda.set_device(proc_id)

        # catch signal: turn SIGTERM (sent by ``close``) into SystemExit so
        # the ``finally`` cleanup below runs
        def handle_sigterm(signum, frame):
            logger.debug(f'Proc[{proc_id}] terminated.')
            sys.exit(0)

        signal.signal(signal.SIGTERM, handle_sigterm)

        worker = MPWorkerWrapper(model_path,
                                 cache_config=cache_config,
                                 backend_config=backend_config,
                                 model_config=model_config,
                                 dist_config=dist_config,
                                 misc_config=misc_config,
                                 specdecode_config=specdecode_config,
                                 adapters=adapters,
                                 device_type=device_type,
                                 log_level=log_level)
        try_import_deeplink(device_type)
        worker.init_process_group(proc_id)
        comm_buf = SharedBuffer(proc_id, notifier=comm_notifier, name=comm_buf_name)
        ret_buf = SharedBuffer(-1, notifier=ret_notifier, name=ret_buf_name)
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
        destroy_pg = worker.world_size > 1
        try:
            event_loop.run_until_complete(
                self._main_loop_impl(proc_id, comm_buf=comm_buf, ret_buf=ret_buf, worker=worker))
        except asyncio.CancelledError:
            logger.warning(f'Proc[{proc_id}] main loop cancelled.')
            destroy_pg = False
            os.kill(os.getppid(), signal.SIGUSR1)
        except SystemExit:
            # terminated by executor
            logger.debug(f'Proc[{proc_id}] system exit.')
        except KeyboardInterrupt:
            logger.debug(f'Proc[{proc_id}] keyboard interrupt.')
            sys.exit(0)
        except BaseException:
            logger.exception(f'Proc[{proc_id}] failed')
            os.kill(os.getppid(), signal.SIGUSR1)
        finally:
            logger.debug(f'Proc[{proc_id}] cleanup.')
            worker.stop()
            worker.release()
            comm_buf.close()
            ret_buf.close()
            if dist.is_initialized() and destroy_pg:
                dist.destroy_process_group()

    @staticmethod
    async def _task_wrapper(func, args: List, kwargs: Dict, need_return: bool, ret_buf: SharedBuffer):
        """Await ``func(*args, **kwargs)`` and send its result when requested."""
        ret = await func(*args, **kwargs)
        if need_return:
            await ret_buf.send_async(ret)

    async def _main_loop_impl(self, proc_id: int, comm_buf: SharedBuffer, ret_buf: SharedBuffer,
                              worker: MPWorkerWrapper):
        """Main loop.

        Receives commands from ``comm_buf`` and dispatches them to the
        worker method of the same name. Coroutine methods run as background
        tasks; plain methods run inline. Results go to ``ret_buf`` when this
        rank's bit is set in the command's ``return_mask``.
        """
        proc_mask = 1 << proc_id
        event_loop = asyncio.get_event_loop()
        # The event loop keeps only weak references to tasks; hold strong
        # references so background calls are not garbage collected mid-run.
        pending_tasks: set = set()

        def _on_task_done(task: asyncio.Task):
            pending_tasks.discard(task)
            if task.cancelled():
                return
            exc = task.exception()
            if exc is not None:
                logger.error(f'proc[{proc_id}] async method failed.', exc_info=exc)

        while True:
            command = await comm_buf.receive_async()
            if command is None:
                # message not addressed to this rank
                continue
            method = command['method']
            return_mask = command.get('return_mask', True)
            args = command.get('args', list())
            kwargs = command.get('kwargs', dict())
            need_return = bool(proc_mask & return_mask)

            func = getattr(worker, method, None)
            if func is None:
                # explicit raise instead of ``assert`` (stripped under -O)
                raise AttributeError(f'method: <{method}> not exists.')
            call_async = asyncio.iscoroutinefunction(func)

            logger.debug(f'proc[{proc_id}] call method: <{method}>.')
            if call_async:
                task = event_loop.create_task(self._task_wrapper(func, args, kwargs, need_return, ret_buf))
                pending_tasks.add(task)
                task.add_done_callback(_on_task_done)
            else:
                ret = func(*args, **kwargs)
                if need_return:
                    ret_buf.send(ret)


================================================
FILE: lmdeploy/pytorch/engine/executor/ray_executor.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import contextlib
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import ray
import ray.exceptions
import torch
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.backends.selector import init_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.ray import RayContext, get_device_str
from lmdeploy.pytorch.utils import wait_for_async_tasks
from lmdeploy.utils import get_logger, try_import_deeplink

from .base import ExecutorBase
from .base_worker import WorkerWrapperBase
from .dist_utils import find_available_port

# Logger named 'lmdeploy', used throughout this module.
logger = get_logger('lmdeploy')


def _get_master_addr():
    """Return the distributed master address.

    ``_envs.dist_master_addr`` is used when set; otherwise the host part (the
    text before the first ``':'``) of the Ray runtime's GCS address.
    """
    configured = _envs.dist_master_addr
    if configured is None:
        gcs_address = ray.get_runtime_context().gcs_address
        configured = gcs_address.split(':')[0]
    return configured


def _get_master_port():
    """Return ``_envs.dist_master_port`` if set, else a free local port."""
    configured = _envs.dist_master_port
    return find_available_port() if configured is None else configured


def get_ascend_device_rank_mapping(master_addr):
    """Build the Ascend device/rank layout from the rank table file.

    The rank table path is read from ``_envs.ascend_rank_table_file``.

    Args:
        master_addr (str): expected to equal the ``server_id`` of the first
            entry in the table's ``server_list``.

    Returns:
        tuple: ``(rank_mapping, worker_ips, envs)``.
            ``rank_mapping`` maps global rank -> local device id.
            ``worker_ips[r]`` is the ``server_id`` hosting global rank ``r``.
            ``envs`` holds the environment variables to forward
            (``ASCEND_RANK_TABLE_FILE_PATH``).

    Raises:
        ValueError: if the rank table file is not configured, or the table
            does not match ``master_addr``, has no devices, or its rank ids
            are not contiguous from 0. Parsing errors (missing keys, bad
            ints) propagate unchanged after being logged.
    """
    rank_table_file = _envs.ascend_rank_table_file
    if not rank_table_file:
        raise ValueError('ASCEND_RANK_TABLE_FILE_PATH is not set')
    with open(rank_table_file, 'r', encoding='utf-8') as f:
        rank_table = json.load(f)
    try:
        # explicit check instead of ``assert``: asserts are stripped under -O
        if master_addr != rank_table['server_list'][0]['server_id']:
            raise ValueError('Master address does not match rank table')
        rank_mapping: Dict[int, int] = {}
        worker_ip_by_rank: Dict[int, str] = {}
        for server in rank_table['server_list']:
            node_ip = server['server_id']
            for idx, device in enumerate(server['device']):
                # Prefer explicit device_id if present; fall back to enumeration order.
                local_rank = int(device.get('device_id', idx))
                global_rank = int(device['rank_id'])
                rank_mapping[global_rank] = local_rank
                worker_ip_by_rank[global_rank] = node_ip

        if not worker_ip_by_rank:
            raise ValueError('Rank table contains no devices.')

        ranks = sorted(worker_ip_by_rank)
        # distinct sorted ints are exactly 0..n-1 iff first is 0 and last is n-1
        if ranks[0] != 0 or ranks[-1] != len(ranks) - 1:
            raise ValueError(f'Rank ids are not contiguous starting from 0: {ranks[:8]}...{ranks[-8:]}')
        worker_ips = [worker_ip_by_rank[r] for r in ranks]
    except Exception:
        # report the file path, not the (possibly large) parsed table
        logger.error(f'Parse rank table file({rank_table_file}) failed')
        raise

    envs = {
        'ASCEND_RANK_TABLE_FILE_PATH': rank_table_file,
    }
    return rank_mapping, worker_ips, envs


def _update_env_cuda_alloc_conf(env_vars: Dict):
    """Update runtime env for CUDA alloc conf."""
    cuda_alloc_conf = os.getenv('PYTORCH_CUDA_ALLOC_CONF', None)
    if cuda_alloc_conf is None:
        return

    # check and update conf, skip expandable_segments
    cuda_alloc_conf = cuda_alloc_conf.split(',')
    new_cuda_alloc_conf = []
    for conf in cuda_alloc_conf:
        if 'expandable_segments' in conf:
            if 'True' in conf:
                logger.warning('"expandable_segments:True" is not supported.')
            continue
        new_cuda_alloc_conf.append(conf)
    if len(new_cuda_alloc_conf) == 0:
        new_cuda_alloc_conf = ['expandable_segments:False']
    cuda_alloc_conf = ','.join(new_cuda_alloc_conf)

    # update env_vars
    env_vars['PYTORCH_CUDA_ALLOC_CONF'] = cuda_alloc_conf


def _update_runtime_envs(runtime_env: Dict):
    """Merge lmdeploy environment variables into a Ray ``runtime_env``.

    An existing ``runtime_env['env_vars']`` dict is updated in place, then
    sanitized by ``_update_env_cuda_alloc_conf``. Returns ``runtime_env``.
    """
    lmdeploy_envs = _envs.get_all_envs()
    merged: Dict = runtime_env.get('env_vars', {})
    merged.update(lmdeploy_envs)
    _update_env_cuda_alloc_conf(merged)
    runtime_env['env_vars'] = merged
    return runtime_env


def _update_runtime_env_nsys(runtime_env: Dict):
    """Attach an Nsight Systems config (``runtime_env['nsight']``) and return it.

    The output pattern defaults to ``'worker_process_%p'`` and becomes
    ``<prefix>%p`` when ``_envs.ray_nsys_output_prefix`` is set.
    """
    prefix_path = _envs.ray_nsys_output_prefix
    output_pattern = "'worker_process_%p'" if prefix_path is None else f'{prefix_path}%p'
    runtime_env['nsight'] = {
        't': 'cuda,cudnn,cublas,nvtx',
        'o': output_pattern,
        'stop-on-exit': 'true',
    }
    return runtime_env


class RemoteLogger:
    """Keep ``torch.profiler.record_function`` ranges open between calls.

    ``start`` opens a range and returns an integer handle; ``end`` closes the
    range for that handle. Unknown handles are ignored.
    """

    def __init__(self):
        self._records = {}
        self._next_handle = 0

    def start(self, msg: str):
        """Open a profiler range named ``msg`` and return its handle."""
        scope = torch.profiler.record_function(msg)
        scope.__enter__()
        handle, self._next_handle = self._next_handle, self._next_handle + 1
        self._records[handle] = scope
        return handle

    def end(self, handle: int):
        """Close the range for ``handle``; no-op if the handle is unknown."""
        scope = self._records.pop(handle, None)
        if scope is None:
            return
        scope.__exit__(None, None, None)


class RayWorkerWrapper(WorkerWrapperBase):
    """Worker wrapper hosted by a Ray actor."""

    def __init__(
        self,
        model_path: str,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        model_config: ModelConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        dtype: str = 'auto',
        log_level: int = 30,
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Initialize the device backend, then the base worker wrapper.

        Note: ``dtype`` is accepted but not forwarded to ``WorkerWrapperBase``.
        """
        init_backend(device_type)
        try_import_deeplink(device_type)

        super().__init__(
            model_path=model_path,
            cache_config=cache_config,
            backend_config=backend_config,
            model_config=model_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            log_level=log_level,
            specdecode_config=specdecode_config,
        )
        self.node_ip = ray.util.get_node_ip_address()
        self._remote_logger = RemoteLogger()

    def set_device(self, local_rank):
        """Set worker local rank (``torch.cuda.set_device``)."""
        torch.cuda.set_device(local_rank)

    def set_env(self, envs: Dict[str, str]):
        """Set the given variables in this worker's ``os.environ``."""
        for key, value in envs.items():
            os.environ[key] = value

    def get_node_ip(self):
        """Get worker ip."""
        return self.node_ip

    def warmup_dist(self):
        """Run one tiny all_reduce on the TP group.

        A non-default CUDA_VISIBLE_DEVICES may make the first all_reduce
        slow (root cause unknown), so it is triggered here up front.
        """
        logger.debug('Warmup all_reduce.')

        from lmdeploy.pytorch.distributed import all_reduce, get_dist_manager
        with get_dist_manager().context(self.dist_ctx):
            group = self.dist_ctx.tp_group.gpu_group
            tmp = torch.empty((1, ), device='cuda')
            all_reduce(tmp, group=group)

    def pack_output(self, output: Any):
        """Pack output by calling its ``to_numpy()`` method."""
        return output.to_numpy()

    def remote_log_start(self, msg: str):
        """Remote log start; returns a handle for ``remote_log_end``."""
        return self._remote_logger.start(msg)

    def remote_log_end(self, handle: int):
        """Remote log end."""
        return self._remote_logger.end(handle)

    def exit(self):
        """Exit actor."""
        ray.actor.exit_actor()


class RayExecutor(ExecutorBase):
    """Ray executor."""

    def __init__(
        self,
        model_path: str,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        dtype: str = 'auto',
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Initialize Executor.

        Inside the device context for ``device_type`` this:
        1. creates the ``RayContext`` and its placement group;
        2. resolves the master address/port;
        3. creates the worker actors via ``_init_workers_ray``;
        4. initializes the process group on every worker, offset by this DP
           group's rank offset;
        5. warms up collective communication when ``world_size > 1``.

        For ``dp == 1`` the master endpoint comes from ``_get_master_addr`` /
        ``_get_master_port``. For ``dp > 1`` it must be given through
        ``LMDEPLOY_DP_MASTER_ADDR`` / ``LMDEPLOY_DP_MASTER_PORT``.

        Raises:
            RuntimeError: if ``dp > 1`` and the DP master address or port is missing.
        """
        super().__init__(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            specdecode_config=specdecode_config,
        )

        # all setup below runs inside the device context of ``device_type``
        device_ctx = DeviceContext(device_type)
        with get_device_manager().context(device_ctx):
            logger.info('Init ray cluster.')
            attn_tp = dist_config.attn_tp
            self.ray_ctx = RayContext(attn_tp, dp=dist_config.dp, device_type=device_type)
            placement_group = self.ray_ctx.get_placement_group()
            self.placement_group = placement_group

            # master endpoint passed to ``init_process_group`` below
            if self.dp == 1:
                self.master_addr = _get_master_addr()
                self.master_port = _get_master_port()
            else:
                self.master_addr = _envs.dp_master_addr
                self.master_port = _envs.dp_master_port
                if self.master_addr is None or self.master_port is None:
                    raise RuntimeError('DP > 1 requires "LMDEPLOY_DP_MASTER_ADDR" and "LMDEPLOY_DP_MASTER_PORT".')

            # create workerwrapper actors
            worker_kwargs = dict(
                model_path=model_path,
                cache_config=cache_config,
                model_config=model_config,
                backend_config=backend_config,
                dist_config=dist_config,
                misc_config=misc_config,
                adapters=adapters,
                device_type=device_type,
                dtype=dtype,
                log_level=logger.level,
                specdecode_config=specdecode_config,
            )

            logger.info('Init ray workers.')
            self.workers = self._init_workers_ray(placement_group, worker_kwargs)
            # DAG is built lazily on the first ``forward_async`` call
            self.dag = None
            self._prefetch_task: asyncio.Task = None
            self.remote_outs: asyncio.Queue = None

            logger.info('Init distributed environment by device.')
            # global rank of this DP group's first worker: dp_rank * attn_tp
            self.rank_offset = dist_config.dp_rank * attn_tp
            self._init_distributed_environment_by_device(device_type)

            logger.info('Init distributed process group.')
            ray.get([
                worker.init_process_group.remote(rank + self.rank_offset, self.master_addr, self.master_port)
                for rank, worker in enumerate(self.workers)
            ])

            if self.dist_config.world_size > 1:
                logger.info('Warming up distribute environment, this might take long time, please waiting...')
                ray.get([worker.warmup_dist.remote() for worker in self.workers])

    def collective_rpc(self,
                       method: str,
                       args: Tuple[Any] = None,
                       kwargs: Dict[str, Any] = None,
                       timeout: float = None):
        """Invoke ``method`` on every worker actor and block for all results.

        Returns:
            list: one result per worker, in ``self.workers`` order.

        Raises:
            ray.exceptions.GetTimeoutError: if ``timeout`` elapses first.
        """
        call_args = list() if args is None else args
        call_kwargs = dict() if kwargs is None else kwargs
        refs = [getattr(worker, method).remote(*call_args, **call_kwargs) for worker in self.workers]
        return ray.get(refs, timeout=timeout)

    def build_model(self):
        """Ask every worker to build its model."""
        self.collective_rpc(method='build_model')

    def gather_free_mem(self):
        """Collect the free memory reported by each worker (list, one per worker)."""
        free_mems = self.collective_rpc('get_free_mem')
        return free_mems

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Push the cache config (and optional speculative one) to every worker."""
        self.collective_rpc('set_cache_config', args=(cache_config, spec_cache_config))

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Push the model config (and optional speculative one) to every worker."""
        self.collective_rpc('set_model_config', args=(model_config, spec_model_config))

    def build_graph_runner(self):
        """Ask every worker to build its graph runner."""
        self.collective_rpc(method='build_graph_runner')

    def build_cache_engine(self):
        """Ask every worker to build its cache engine."""
        self.collective_rpc(method='build_cache_engine')

    def update_params(self, request: Any):
        """Forward a parameter-update ``request`` to every worker."""
        self.collective_rpc('update_params', args=(request, ))

    def warmup(self):
        """Warm up every worker (docstring previously said "Build cache engine")."""
        self.collective_rpc('warmup')

    def sleep(self, level: int = 1):
        """Put every worker to sleep at the given ``level``."""
        self.collective_rpc('sleep', args=(level, ))

    def wakeup(self, tags: Optional[List[str]] = None):
        """Wake workers up.

        ``update_configs()`` is called first when ``tags`` is ``None`` or
        contains ``'kv_cache'``.
        """
        refresh_configs = tags is None or 'kv_cache' in tags
        if refresh_configs:
            self.update_configs()
        self.collective_rpc('wakeup', args=(tags, ))

    def get_input_processor(self):
        """Get the input processor from the first worker."""
        return ray.get(self.workers[0].get_input_processor.remote())

    def _prefetch_task_callback(self, task: asyncio.Task):
        """Done-callback: retrieve the task outcome and log it at debug level."""
        task_name = task.get_name()
        try:
            task.result()
        except asyncio.CancelledError:
            logger.debug(f'{task_name} cancelled.')
        except KeyboardInterrupt:
            logger.debug(f'{task_name} KeyboardInterrupt.')
        except BaseException:
            logger.debug(f'{task_name} task failed.')

    def start(self, forward_event: asyncio.Event):
        """Start engine loop.

        Keeps ``forward_event``, issues ``start`` on every worker and creates
        the ``remote_outs`` queue.
        NOTE(review): no prefetch task is created here despite the log message.
        """
        self.forward_event = forward_event
        self.collective_rpc(method='start')
        queue = asyncio.Queue()
        self.remote_outs = queue
        logger.info('Starting async task RayPrefetchOutput loop.')

    async def wait_tasks(self):
        """Wait tasks.

        Awaits ``wait_tasks`` on every worker actor (plus the prefetch task,
        if any) via ``wait_for_async_tasks``. Cancellation and failures are
        logged and re-raised. On exit, ``ray.cancel`` is attempted on every
        remote ``wait_tasks`` call that was issued.
        """
        dp_rank = self.dist_config.dp_rank
        # ObjectRefs of the remote ``wait_tasks`` calls, cancelled in ``finally``
        tasks_to_cancel = set()
        event_loop = asyncio.get_event_loop()

        async def _wait_single_worker(worker):
            try:
                task = worker.wait_tasks.remote()
                tasks_to_cancel.add(task)
                await task
            except ray.exceptions.ActorDiedError:
                # It is safe to ignore wait tasks on died actor
                logger.info('RayExecutor worker has been killed before finish wait_tasks.')

        tasks = [
            event_loop.create_task(_wait_single_worker(worker), name=f'WorkerWaitTasks_{idx}')
            for idx, worker in enumerate(self.workers)
        ]
        if self._prefetch_task is not None:
            tasks.append(self._prefetch_task)
        try:
            await wait_for_async_tasks(tasks)
        except asyncio.CancelledError:
            logger.info(f'RayExecutor DP[{dp_rank}] wait_tasks cancelled.')
            raise
        except BaseException:
            logger.error(f'RayExecutor DP[{dp_rank}] wait_tasks failed.')
            raise
        finally:
            logger.debug(f'RayExecutor DP[{dp_rank}] wait_tasks cleanup.')
            # best effort: cancel remote calls that may still be running
            for task in tasks_to_cancel:
                try:
                    ray.cancel(task)
                except ray.exceptions.ActorDiedError:
                    logger.debug('RayExecutor worker has been killed before finish cancel task.')
                except Exception as e:
                    logger.error(f'RayExecutor DP[{dp_rank}] Cancel wait_tasks failed: {e}')

    def stop(self):
        """Stop engine loop.

        With ``dp == 1`` every worker is asked to run ``stop_async``. Workers
        that already died are tolerated. Any pending prefetch task is then
        cancelled.
        """
        # TODO: with dp > 1 worker loops are not stopped explicitly; teardown
        # relies on external means (e.g. Ray actor destruction). Coordinating
        # shutdown across dp ranks, some of which may already have failed, is
        # non-trivial, so the stop_async RPC is only issued when dp == 1.
        if self.dp == 1:
            try:
                # deliberately no timeout: one might prevent dumping the
                # profile (the trade-off is a possible hang)
                self.collective_rpc('stop_async')
            except ray.exceptions.ActorDiedError:
                logger.info('RayExecutor worker has been killed before finish stop_async.')
            logger.debug('RayExecutor workers stopped.')
        prefetch_task = self._prefetch_task
        if prefetch_task is not None:
            prefetch_task.cancel()

    def release(self):
        """Release workers and shut down the Ray context.

        With ``dp == 1`` the workers are first asked to ``release`` (5 s
        timeout). If they are dead or do not answer in time they are killed.
        With ``dp > 1`` they are killed directly. ``ray_ctx.shutdown()`` runs
        even if an unexpected error is raised earlier.
        """

        def _kill_workers():
            # plain loop instead of a side-effect-only list comprehension
            for worker in self.workers:
                ray.kill(worker)

        try:
            if _envs.ray_timeline_enable:
                ray.timeline(_envs.ray_timeline_output_path)

            if self.dp == 1:
                try:
                    self.collective_rpc('release', timeout=5.0)
                    logger.debug('RayExecutor workers released.')
                except ray.exceptions.ActorDiedError:
                    logger.info('RayExecutor worker has been killed before finish release.')
                    _kill_workers()
                except ray.exceptions.GetTimeoutError:
                    logger.info('Ray release timeout, killing workers')
                    _kill_workers()
            else:
                _kill_workers()
        finally:
            self.ray_ctx.shutdown()

    def _compile_dag(self):
        """Build the Ray DAG used by ``forward_async``.

        Every worker's ``forward_async`` is bound to one shared ``InputNode``
        and the per-worker results are gathered in a ``MultiOutputNode``.
        """
        from ray.dag.input_node import InputNode
        from ray.dag.output_node import MultiOutputNode
        with InputNode() as input_data:
            worker_outputs = [w.forward_async.bind(input_data) for w in self.workers]
            dag = MultiOutputNode(worker_outputs)
        return dag

    async def forward_async(self, inputs):
        """Start forward.

        Builds the DAG on first use. Before submitting ``inputs`` it waits for
        the previous submission and frees that submission's ``ray.put`` input.
        Then ``inputs`` are put into the object store and the DAG is executed;
        the result refs are only kept (in ``self._prev_out``), not returned.
        """

        if self.dag is None:
            self.dag = self._compile_dag()
            self._prev_inputs = None
            self._prev_out = None

        if self._prev_out is not None:
            try:
                # NOTE(review): ``ray.get`` is synchronous and blocks the event
                # loop inside this coroutine until the previous step finishes
                ray.get(self._prev_out)
            except SystemExit:
                logger.error('Ray worker exited.')
                raise
            finally:
                # free ray.put inputs
                try:
                    ray._private.internal_api.free(self._prev_inputs)
                except Exception as e:
                    logger.warning(f'Free input ref failed: {e}')

        self._prev_inputs = ray.put(inputs)
        # make sure in order
        self._prev_out = self.dag.execute(self._prev_inputs)

    async def get_output_async(self):
        """Await the next output of worker 0 and convert it with ``to_tensor()``."""
        packed = await self.workers[0].get_outputs.remote()
        return packed.to_tensor()

    @contextlib.contextmanager
    def remote_log(self, msg: str):
        """Send log for debugging.

        Starts a remote log on the first worker with ``msg`` when the ``with``
        block is entered and ends it when the block exits, including when the
        block raises.

        Do not use it in production.
        """
        handle_ref = self.workers[0].remote_log_start.remote(msg)
        try:
            yield
        finally:
            # Always pair the start call with an end call, even if the body raised.
            handle = ray.get(handle_ref)
            ray.get(self.workers[0].remote_log_end.remote(handle))

    def _sort_workers(self, driver_ip: str, workers: List[RayWorkerWrapper]):
        """Sort workers by ip.

        Ordering keys, in priority order:

        1. Workers on the same node as the driver come first.
        2. Then workers on nodes that host fewer workers.
        3. Remaining ties are broken by the node IP string.

        ``sorted`` is stable, so workers that share all three keys keep their
        input order. The effect is that workers on the same node end up next to
        each other, with the driver's node first.
        """
        node_ips = ray.get([w.get_node_ip.remote() for w in workers])

        workers_per_ip = dict.fromkeys(node_ips, 0)
        for node_ip in node_ips:
            workers_per_ip[node_ip] += 1

        def _rank(pair):
            _, node_ip = pair
            on_driver_node = node_ip == driver_ip
            return (0 if on_driver_node else 1, workers_per_ip[node_ip], node_ip)

        ranked = sorted(zip(workers, node_ips), key=_rank)
        return [w for w, _ in ranked]

    def _sort_workers_by_ip(self, ips, workers: List[RayWorkerWrapper]):
        """Reorder ``workers`` so that the i-th worker runs on node ``ips[i]``.

        Args:
            ips: Expected node IP for every position. Duplicates are allowed,
                one entry per worker placed on that node.
            workers: Ray worker handles to reorder.

        Returns:
            A new list containing the same workers, ordered to follow ``ips``.
            Workers on the same node keep their relative input order.

        Raises:
            ValueError: If ``ips`` and ``workers`` differ in length, or if the
                worker node IPs do not match ``ips`` including multiplicity.
        """
        from collections import Counter, defaultdict, deque

        # Validate before issuing any remote calls.
        if len(ips) != len(workers):
            raise ValueError(f'The length of the ips list does not match the workers, '
                             f'ips length: {len(ips)}, workers length: {len(workers)}')

        worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers])

        # Compare with multiplicity: equal *sets* can still hide a per-node count
        # mismatch, leaving some positions of ``ips`` without a matching worker.
        if Counter(ips) != Counter(worker_ips):
            raise ValueError(f'The IP addresses in the ips list do not match the worker IPs. '
                             f'ips: {ips}, worker_ips: {worker_ips}')

        # Bucket workers per node (input order preserved), then hand them out in
        # ``ips`` order. Unlike sorting by one priority per IP, this stays correct
        # when the same IP appears at non-contiguous positions of ``ips``.
        workers_by_ip = defaultdict(deque)
        for worker, ip in zip(workers, worker_ips):
            workers_by_ip[ip].append(worker)
        return [workers_by_ip[ip].popleft() for ip in ips]

    def _valid_bundle_id(self, bundle_id: int):
        """Return whether placement-group bundle ``bundle_id`` may host a worker.

        Every bundle is accepted when the placement group is owned by this
        executor (``ray_ctx.owned_pg``) or when no ``ray_external_pg_bundles``
        whitelist is configured; otherwise only whitelisted ids are accepted.
        """
        if self.ray_ctx.owned_pg:
            return True
        allowed = _envs.ray_external_pg_bundles
        if not allowed:
            return True
        return bundle_id in allowed

    def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict):
        """Create one ``RayWorkerWrapper`` actor per usable placement-group bundle.

        A bundle is usable when its spec reserves a non-zero amount of the
        current device resource and ``_valid_bundle_id`` accepts it; at most
        ``dist_config.attn_tp`` bundles are used, in bundle order.

        Args:
            placement_group: Placement group the actors are scheduled into.
            worker_kwargs: Keyword arguments forwarded to ``RayWorkerWrapper``.

        Returns:
            list: The created actor handles.
        """
        device_str = get_device_str()
        usable_bundles = [
            idx for idx, spec in enumerate(placement_group.bundle_specs)
            if spec.get(device_str, 0) and self._valid_bundle_id(idx)
        ]
        usable_bundles = usable_bundles[:self.dist_config.attn_tp]

        workers = []
        for bundle_id in usable_bundles:
            strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if device_str == 'GPU':
                # GPU actors get the updated runtime env (plus nsys settings when enabled)
                env = _update_runtime_envs(dict())
                if _envs.ray_nsys_enable:
                    env = _update_runtime_env_nsys(env)
                actor_options = dict(
                    num_cpus=0,
                    num_gpus=0.01,
                    scheduling_strategy=strategy,
                    runtime_env=env,
                )
            else:
                # other devices request a small share of the custom device resource
                actor_options = dict(
                    num_cpus=0,
                    num_gpus=0,
                    resources={device_str: 0.01},
                    scheduling_strategy=strategy,
                )

            workers.append(ray.remote(**actor_options)(RayWorkerWrapper).remote(**worker_kwargs))
        return workers

    def _init_distributed_environment_by_device(self, device_str: str):
        """Order the workers (and bind device ids where needed) for ``device_str``.

        * ``cuda``: workers are sorted by node.
        * ``ascend``: delegated to ``_init_ascend_distributed_environment``.
        * ``camb`` / ``maca``: workers are sorted, then each is bound to the
          device index equal to its sorted position.

        Raises:
            ValueError: For any other device type.
        """
        driver_ip = _get_master_addr()

        if device_str == 'ascend':
            self._init_ascend_distributed_environment(driver_ip)
            return

        if device_str not in ('cuda', 'camb', 'maca'):
            raise ValueError(f'Unsupported device type: {device_str}')

        self.workers = self._sort_workers(driver_ip, self.workers)
        if device_str != 'cuda':
            ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])

    def _init_ascend_distributed_environment(self, driver_ip):
        """Init ascend distributed environment.

        Three modes, selected from ``_envs``:

        * A rank table file is configured (multiple nodes): the workers are
          reordered to match this process's slice
          ``[rank_offset, rank_offset + len(workers))`` of the rank table. They
          then receive the mapped device ids and the rank-table envs.
        * No rank table, and devices are not set by ray (single node): the
          workers are sorted and bound to device ``idx + rank_offset``.
        * Otherwise: the workers are only sorted.

        Raises:
            ValueError: If the rank table holds fewer ranks than this slice needs.
        """
        rank_table_file = _envs.ascend_rank_table_file
        set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray

        if not rank_table_file:
            self.workers = self._sort_workers(driver_ip, self.workers)
            if not set_rt_visable_devices_by_ray:
                # single node, multiple devices: bind devices by sorted position
                offset = self.rank_offset
                ray.get([worker.set_device.remote(idx + offset) for idx, worker in enumerate(self.workers)])
            return

        rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
        first_rank = self.rank_offset
        end_rank = first_rank + len(self.workers)
        world_size = len(worker_ips)
        if end_rank > world_size:
            raise ValueError(
                'Rank table world_size is smaller than required ranks for current dp_rank. '
                f'rank_table_world_size={world_size}, required_rank_range=[{first_rank}, {end_rank})')

        # In dp mode each process only owns a slice of global ranks.
        owned_ips = worker_ips[first_rank:end_rank]
        self.workers = self._sort_workers_by_ip(owned_ips, self.workers)

        ray.get([worker.set_device.remote(rank_mapping[first_rank + idx]) for idx, worker in enumerate(self.workers)])
        ray.get([worker.set_env.remote(envs) for worker in self.workers])

    """ PD Disaggregation API Begin """

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Run ``p2p_initialize`` on every worker and return the per-worker results."""
        rpc_args = (init_request, )
        return self.collective_rpc('p2p_initialize', rpc_args)

    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
        """Rdma connect.

        Runs ``p2p_connect`` on every worker against ``remote_engine_id`` and
        returns the per-worker results.
        """
        rpc_args = (remote_engine_id, conn_request)
        return self.collective_rpc('p2p_connect', rpc_args)

    async def migrate(self, batch: MigrationExecutionBatch):
        """Run KV-cache migration of ``batch`` on all workers concurrently; results are in worker order."""
        pending = [w.migrate.remote(batch) for w in self.workers]
        return await asyncio.gather(*pending)

    """ PD Disaggregation API Begin """


================================================
FILE: lmdeploy/pytorch/engine/executor/uni_executor.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Dict, List

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.engine.model_agent import build_model_agent
from lmdeploy.utils import get_logger

from .base import ExecutorBase

# Module logger, obtained under the 'lmdeploy' name.
logger = get_logger('lmdeploy')


class UniExecutor(ExecutorBase):
    """Single node single device Executor.

    Owns one in-process model agent and forwards every executor call to it.
    """

    def __init__(
        self,
        model_path: str,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        misc_config: MiscConfig,
        adapters: Dict[str, str] = None,
        device_type: str = 'cuda',
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Initialize Executor.

        The base executor is given a default ``DistConfig()``. The model agent
        is built immediately with a ``DeviceContext`` for ``device_type``.
        """
        super().__init__(model_path=model_path,
                         model_config=model_config,
                         cache_config=cache_config,
                         backend_config=backend_config,
                         dist_config=DistConfig(),
                         misc_config=misc_config,
                         adapters=adapters,
                         device_type=device_type,
                         specdecode_config=specdecode_config)

        self.device_ctx = DeviceContext(device_type=device_type)
        self.model_agent = build_model_agent(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            misc_config=misc_config,
            device_ctx=self.device_ctx,
            adapters=adapters,
            specdecode_config=specdecode_config,
        )

    def download_models(self):
        """Download model.

        Raises:
            NotImplementedError: Always; this executor does not support it.
        """
        raise NotImplementedError('Not Implemented.')

    def build_model(self):
        """Build model."""
        self.model_agent.build_model()

    def gather_free_mem(self):
        """Gather available memory (a one-element list for the single device)."""
        return [self.model_agent.get_free_mem()]

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Set all cache config (target model and optional speculative-decoding model)."""
        self.model_agent.set_cache_config(cache_config, spec_cache_config)

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Set all model config (target model and optional speculative-decoding model)."""
        self.model_agent.set_model_config(model_config, spec_model_config)

    def build_graph_runner(self):
        """Build graph runner."""
        self.model_agent.build_graph_runner()

    def build_cache_engine(self):
        """Build cache engine."""
        self.model_agent.build_cache_engine()

    def warmup(self):
        """Warm up the model agent."""
        self.model_agent.warmup()

    def start(self, forward_event: asyncio.Event):
        """Start engine loop."""
        self.model_agent.start(forward_event)

    async def wait_tasks(self):
        """Wait tasks."""
        await self.model_agent.wait_tasks()

    def stop(self):
        """Stop engine loop."""
        self.model_agent.stop()

    def release(self):
        """Release resources."""
        self.model_agent.release()

    async def forward_async(self, inputs):
        """Start forward."""
        self.model_agent.set_forward_inputs(inputs)
        # switch to task: ModelAgent._async_loop_inputs_preprocess
        await asyncio.sleep(0)

    async def get_output_async(self, dp_rank: int = 0):
        """Get output async (only ``dp_rank == 0`` exists for this executor)."""
        assert dp_rank == 0
        return await self.model_agent.get_output_async()

    def get_input_processor(self):
        """Get input processor."""
        return self.model_agent.get_input_processor()

    """ PD Disaggregation API Begin """

    def p2p_initialize(self, init_request: DistServeInitRequest):
        """Init rdma link.

        note: return list to be composable with multiprocess executor like ray.
        """
        return [self.model_agent.cache_engine.p2p_initialize(init_request)]

    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
        """rdma_connect.

        note: return list to be composable with multiprocess executor like ray,
        consistent with ``p2p_initialize``.
        """
        return [self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)]

    async def migrate(self, batch: MigrationExecutionBatch):
        """KV Cache Migration."""
        return await self.model_agent.cache_engine.migrate(batch)

    """ PD Disaggregation API End """


================================================
FILE: lmdeploy/pytorch/engine/guided_process.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
from typing import Any, Dict, List, Optional, Tuple

import torch
import xgrammar as xgr
from transformers import PreTrainedTokenizerBase

# Module logger, obtained under the 'lmdeploy' name.
logger = logging.getLogger('lmdeploy')


class GuidedDecodingManager:
    """Build and cache xgrammar ``GrammarMatcher`` objects for guided decoding.

    Matchers are cached per ``(session_id, seq_id)``. ``get_processor`` returns
    the cached matcher when one exists instead of compiling a new grammar.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):
        """Create the grammar compiler.

        Args:
            tokenizer: HuggingFace tokenizer used to build xgrammar's tokenizer info.
            vocab_size: Vocabulary size for the token bitmask; falls back to
                ``tokenizer.vocab_size`` when None.
        """
        if vocab_size is None:
            vocab_size = tokenizer.vocab_size

        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
        self.compiler = xgr.GrammarCompiler(tokenizer_info)
        self.vocab_size = vocab_size
        # session_id -> {seq_id -> matcher}. Kept per instance on purpose: a mutable
        # class attribute would be shared (and wiped by ``clear``) across all managers.
        self.processors: Dict[int, Dict[int, xgr.GrammarMatcher]] = {}

    def get_processors(self, session_ctx: List[Dict[str, Any]],
                       response_formats: Tuple[Dict]) -> Dict[int, xgr.GrammarMatcher]:
        """Return a matcher for every batch entry that requests structured output.

        Args:
            session_ctx: Per-entry dicts carrying ``session_id`` and ``seq_id``.
            response_formats: Per-entry response format. Entries that are not
                dicts, or whose ``type`` is ``'text'`` (the default), are skipped.

        Returns:
            Mapping from batch index to its ``xgr.GrammarMatcher``.

        Raises:
            ValueError: On an unsupported format type or an unusable json schema.
        """
        processors = {}
        for i, _format in enumerate(response_formats):
            if not isinstance(_format, dict) or _format.get('type', 'text') == 'text':
                continue

            schema_type = _format['type']
            if schema_type == 'json_schema':
                schema = _format['json_schema']
                if isinstance(schema, dict):
                    for key in ('json_schema', 'schema'):
                        if key in schema:
                            schema = json.dumps(schema[key], ensure_ascii=False)
                            # Stop at the first match: ``schema`` is now a str, so
                            # testing the next key would be a substring check that
                            # could then index a str with a str key.
                            break

                if not isinstance(schema, str):
                    raise ValueError(f'Cannot parse schema {schema}. The schema must be '
                                     'either a dictionary or a string that contains the'
                                     ' JSON Schema specification')
            elif schema_type == 'regex_schema':
                schema = _format.get('regex_schema', '')
            elif schema_type == 'json_object':
                schema = '{"type" : "object", "additionalProperties": true}'
            else:
                raise ValueError(f'unsupported format type: {schema_type}')

            session_id = session_ctx[i]['session_id']
            seq_id = session_ctx[i]['seq_id']

            processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)

        return processors

    def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) -> xgr.GrammarMatcher:
        """Return the cached matcher for ``(session_id, seq_id)`` or compile a new one.

        Args:
            session_id: Session the sequence belongs to.
            seq_id: Sequence id inside the session.
            schema: JSON schema text (a dict is also accepted for ``json_schema``),
                a regex pattern, or the object schema used for ``json_object``.
            type: One of ``'json_schema'``, ``'regex_schema'``, ``'json_object'``.

        Raises:
            ValueError: If ``type`` is not supported.
        """
        session_dict = self.processors.get(session_id)
        if session_dict is not None and seq_id in session_dict:
            return session_dict[seq_id]

        if type == 'json_schema':
            if isinstance(schema, str):
                schema = json.loads(schema)

            assert isinstance(schema, dict)
            compiled = self.compiler.compile_json_schema(schema)
        elif type == 'regex_schema':
            compiled = self.compiler.compile_regex(schema)
        elif type == 'json_object':
            compiled = self.compiler.compile_json_schema(schema)
        else:
            # Explicit raise rather than ``assert False``: asserts are stripped under
            # ``python -O``, which would surface as UnboundLocalError on ``compiled``.
            raise ValueError(f'Do not support schema type {type}')

        processor = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)
        self.processors.setdefault(session_id, {})[seq_id] = processor
        logger.info(f'create guided processor for session_id={session_id}, seq_id={seq_id}, and '
                    f'total_processors={len(self.processors)}')
        return processor

    def remove_processor(self, session_id: int):
        """Drop every cached matcher of ``session_id`` (no-op if none exist)."""
        if session_id in self.processors:
            del self.processors[session_id]
            logger.info(
                f'delete guided processor for session_id={session_id}, and total_processors={len(self.processors)}')

    def allocate_batched_bitmap(self, batch_size: int) -> torch.Tensor:
        """Allocate an xgrammar token bitmask for ``batch_size`` rows over ``self.vocab_size`` tokens."""
        return xgr.allocate_token_bitmask(batch_size, self.vocab_size)

    def fill_bitmap(self, processor: xgr.GrammarMatcher, guided_bitmask: torch.Tensor, index: int) -> None:
        """Write ``processor``'s allowed-next-token mask into row ``index`` of ``guided_bitmask``."""
        processor.fill_next_token_bitmask(guided_bitmask, index)

    def accept_token(self, processor: xgr.GrammarMatcher, token: int) -> None:
        """Advance ``processor`` by the sampled ``token``."""
        processor.accept_token(token)

    def apply_batched_bitmap(self, logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
        """Mask ``logits`` in place with ``guided_bitmask``.

        On cpu/cuda the bitmask is moved to the logits' device and applied
        directly. On other devices it is applied to a float32 CPU copy of the
        logits, which is then copied back in the original dtype.
        """
        device = logits.device
        dtype = logits.dtype

        if device.type in {'cpu', 'cuda'}:
            xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
        else:
            cpu_logits = logits.cpu().float()
            cpu_mask = guided_bitmask.cpu()
            xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
            logits.copy_(cpu_logits.to(device, dtype))

    def clear(self) -> None:
        """Drop all cached matchers."""
        self.processors.clear()
        logger.info(f'clear guided processors, total_processors={len(self.processors)}')


================================================
FILE: lmdeploy/pytorch/engine/input_process.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs

# Per-sequence model metadata: a string-keyed mapping.
TypeModelMetas = Dict[str, Any]

# Raw multimodal inputs: a list of string-keyed dicts.
InputMultiModalType = List[Dict[str, Any]]


@dataclass
class PreprocessInputResult:
    """Value returned by ``preprocess_input`` of a model input processor.

    Only ``input_ids`` is required; the other fields default to None.
    """
    # token ids to feed the model
    input_ids: List[int]
    # multimodal inputs after preprocessing, if any
    input_multimodals: Optional[MultiModalInputs] = None
    # model-specific metadata, if any
    model_metas: Optional[TypeModelMetas] = None


class BaseModelInputProcessor(ABC):
    """Interface for model-specific preprocessing of raw model inputs."""

    @abstractmethod
    def preprocess_input(self,
                         input_ids: List[int],
                         input_mms: InputMultiModalType = None,
                         **kwargs) -> PreprocessInputResult:
        """Turn token ids and optional multimodal items into a ``PreprocessInputResult``.

        Subclasses must override this.
        """
        raise NotImplementedError('Not implemented.')


class DefaultModelInputProcessor(BaseModelInputProcessor):
    """Pass-through processor used when a model needs no special preprocessing."""

    def preprocess_input(self,
                         input_ids: List[int],
                         input_mms: MultiModalInputs = None,
                         **kwargs) -> PreprocessInputResult:
        """Wrap the inputs unchanged.

        ``input_mms`` becomes ``input_multimodals``; extra keyword arguments are
        accepted and ignored, and ``model_metas`` is left as None.
        """
        result = PreprocessInputResult(input_ids=input_ids, input_multimodals=input_mms)
        return result


================================================
FILE: lmdeploy/pytorch/engine/inputs_maker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import numpy as np
import torch
from torch.profiler import record_function

from lmdeploy.pytorch.disagg.config import EngineRole
from lmdeploy.pytorch.messages import MessageStatus
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, VisionModelInputs
from lmdeploy.utils import get_logger

if TYPE_CHECKING:
    from lmdeploy.pytorch.adapter.adapter import AdapterManager
    from lmdeploy.pytorch.messages import SchedulerSequence
    from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
    from lmdeploy.pytorch.paging import Scheduler
    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy

    from .engine import Engine, SeqList
    from .executor import ExecutorBase

# Module logger, obtained under the 'lmdeploy' name.
logger = get_logger('lmdeploy')


def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
    """Tensorlize block_offsets."""
    # copy on numpy is faster than torch.nn.utils.rnn.pad_sequence
    batch_size = len(block_offsets)
    max_len = max([len(off) for off in block_offsets])
    out = np.zeros((batch_size, max_len), dtype=block_offsets[0].dtype)

    for idx, off in enumerate(block_offsets):
        off_len = len(off)
        out[idx, :off_len] = off
    return torch.as_tensor(out, dtype=dtype)


@dataclass
class InputsMakerConfig:
    """Input maker config.

    Added for Dependency Injection: the inputs maker receives these settings
    as a value object; ``from_engine`` builds one from an ``Engine``.
    """
    max_batches: int
    max_prefill_token_num: int
    role: EngineRole
    is_ssm: bool = False
    dp: int = 1
    spec_decoding: bool = False
    enable_chunked_prefill: bool = False

    @staticmethod
    def from_engine(engine: 'Engine'):
        """Collect the inputs-maker settings from ``engine``'s configs."""
        cfg = engine.cache_config
        has_states = len(cfg.states_shapes) > 0
        uses_spec_decoding = engine.specdecode_config is not None
        return InputsMakerConfig(
            max_batches=cfg.max_batches,
            max_prefill_token_num=cfg.max_prefill_token_num,
            role=cfg.role,
            is_ssm=has_states,
            dp=engine.dist_config.dp,
            spec_decoding=uses_spec_decoding,
            enable_chunked_prefill=engine.misc_config.enable_chunked_prefill,
        )


class LongContextChunker:
    """Long context chunker.

    Tracks at most one sequence whose token count exceeds
    ``max_prefill_token_num`` and hands out its prefill chunk by chunk. Chunk
    ends are pulled back so that no multimodal item is split across chunks.
    """

    def __init__(self, max_prefill_token_num: int):
        # base per-chunk token budget; may be enlarged per sequence in set_seq
        self.max_prefill_token_num = max_prefill_token_num

        # long prefill seq
        self.clear()

    def enabled(self):
        """Is enabled (a long-context sequence is currently tracked)."""
        return self.seq is not None

    def is_long_context(self, seq: 'SchedulerSequence'):
        """Is long context (``seq.num_token_ids`` exceeds ``max_prefill_token_num``)."""
        return seq.num_token_ids > self.max_prefill_token_num

    def set_seq(self, seq: 'SchedulerSequence'):
        """Set seq.

        Start tracking ``seq`` at its current history length, cache its
        multimodal inputs sorted by ``start``, and set the chunk budget
        ``max_prefill_num`` to the larger of ``max_prefill_token_num`` and the
        size of the largest single multimodal item.
        """
        self.seq = seq
        self.next_step = seq.num_history_ids

        # fill multimodals
        # if image size exceeds max_prefill_token_num, enlarge it
        max_prefill_num = self.max_prefill_token_num
        mm = seq.get_input_multimodals()
        self.multimodals = defaultdict(list)
        for key, value in mm.items():
            # sorted by start
            value = sorted(value, key=lambda x: x.start)
            self.multimodals[key] = value
            max_mm_size = max([v.end - v.start for v in value], default=0)
            max_prefill_num = max(max_prefill_num, max_mm_size)

        self.max_prefill_num = max_prefill_num

    def multimodal_iter(self):
        """Multimodal iterator.

        Yields ``(modal_type, item)`` for every cached item across all
        modalities, ordered by ``item.start``.
        """
        multimodal_data = []
        for modal_type, modal_datas in self.multimodals.items():
            if len(modal_datas) == 0:
                continue
            multimodal_data += [(modal_type, data) for data in modal_datas]

        multimodal_data = sorted(multimodal_data, key=lambda x: x[1].start)
        for modal_type, data in multimodal_data:
            yield modal_type, data

    def next_chunk_size(self):
        """Get chunk size.

        Returns:
            tuple: ``(chunk_size, multimodals)``.

            * ``(0, None)`` when no sequence is tracked.
            * ``(chunk_size, None)`` when ``self.multimodals`` has no entries.
            * Otherwise the chunk starts at ``seq.num_history_ids`` and is cut
              just before the first multimodal item that would cross its end.
              ``multimodals`` then maps modal type to the items lying entirely
              inside the chunk.
        """
        seq = self.seq
        if seq is None:
            return 0, None

        # tentative size: the remaining tokens, capped by the chunk budget
        llm_chunk_size = min(seq.num_token_ids, self.max_prefill_num)

        if len(self.multimodals) == 0:
            # no vlm inputs found
            return llm_chunk_size, None

        start = seq.num_history_ids
        end = start + llm_chunk_size
        out_multimodals: 'MultiModalInputs' = defaultdict(list)
        for modal_type, mm in self.multimodal_iter():
            assert mm.start >= start, 'multimodal data should be sorted by start'
            if mm.start >= end:
                # | start ... end ... mm.start ... mm.end |
                # if start is beyond threshold, stop
                break

            if mm.end > end:
                # | start ... mm.start ... end ... mm.end |
                # assume multimodals not overlap
                end = mm.start
                break

            # | start ... mm.start ... mm.end ... end |
            out_multimodals[modal_type].append(mm)

        return end - start, out_multimodals

    def is_last_chunk(self):
        """Is last chunk (nothing tracked, or the remaining tokens fit in one chunk)."""
        if self.seq is None:
            return True
        return self.seq.num_token_ids <= self.max_prefill_num

    def clear(self):
        """Clear: stop tracking any sequence and reset the chunk budget."""
        self.seq: 'SchedulerSequence' = None
        self.multimodals: MultiModalInputs = defaultdict(list)
        self.next_step: int = 0
        self.max_prefill_num: int = self.max_prefill_token_num

    def update_step(self, inputs: ModelInputs):
        """Step chunker.

        Given the inputs built for a non-final chunk, advance ``next_step`` by
        ``inputs.max_q_seqlen``, push it to the sequence via ``seq.set_step``
        and drop multimodal items that end at or before the new step. Does
        nothing when no sequence is tracked or on the last chunk.
        """
        if self.seq is None:
            return
        if self.is_last_chunk():
            # last chunk should be treated as normal prefill
            return
        assert inputs.is_chunk
        chunk_size = inputs.max_q_seqlen
        self.next_step += chunk_size
        self.seq.set_step(self.next_step)

        # remove used multimodals
        for mms in self.multimodals.values():
            while len(mms) > 0 and mms[0].end <= self.next_step:
                mms.pop(0)
        # keep only modalities that still have pending items
        self.multimodals = dict((k, v) for k, v in self.multimodals.items() if len(v) > 0)

    def check_enable(self):
        """Stop tracking the sequence once its status is no longer RUNNING."""
        if not self.enabled():
            return
        if self.seq.status != MessageStatus.RUNNING:
            self.clear()


class InputsMakerAsync:

    def __init__(
        self,
        executor: 'ExecutorBase',
        scheduler: 'Scheduler',
        adapter_manager: 'AdapterManager',
        engine_strategy: 'EngineStrategy',
        sampling_strategy: 'SamplingStrategy',
        model_agent_strategy: 'ModelAgentStrategy',
        config: InputsMakerConfig,
    ):
        """Store collaborators, select the prefill routine and reset per-step state.

        Args:
            executor: Executor the built inputs are sent to.
            scheduler: Scheduler providing block tables for sequences.
            adapter_manager: Resolves adapter names to local adapter ids.
            engine_strategy: Engine strategy.
            sampling_strategy: Sampling strategy.
            model_agent_strategy: Model agent strategy.
            config: Inputs maker settings.
        """
        # collaborators
        self.executor = executor
        self.scheduler = scheduler
        self.adapter_manager = adapter_manager

        # settings
        self.config = config
        self.spec_decoding = config.spec_decoding

        # strategies
        self.engine_strategy = engine_strategy
        self.sampling_strategy = sampling_strategy
        self.model_agent_strategy = model_agent_strategy

        # binds self.do_prefill according to role / chunked-prefill settings
        self._init_do_prefill(config)

        # state carried over to the next forward
        self.next_is_prefill = True
        self.forward_inputs = None

        # running_seqs marks the seqs that have been sent to the executor
        self.running_seqs: List['SchedulerSequence'] = []
        self.to_evict_seqs: List['SchedulerSequence'] = []

        # splits over-long prefills into chunks
        self.long_context_chunker = LongContextChunker(config.max_prefill_token_num)

    def _init_do_prefill(self, config: InputsMakerConfig):
        """Bind ``self.do_prefill`` to the prefill routine matching ``config``.

        A prefill-role engine always uses ``do_prefill_pnode``. Otherwise chunked
        prefill selects ``do_prefill_chunked``; the fallback is ``do_prefill_default``.
        """
        if config.role == EngineRole.Prefill:
            chosen = self.do_prefill_pnode
        elif config.enable_chunked_prefill:
            chosen = self.do_prefill_chunked
        else:
            chosen = self.do_prefill_default
        self.do_prefill = chosen

    def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs):
        """Create vision model inputs.

        Returns None when no message has input embeddings or multimodal data
        for this step. Otherwise returns a ``VisionModelInputs`` with:

        * for embeddings: the per-message tensors, a ``(batch, max_q_len)``
          bool mask of the query positions they cover, and their
          ``[start, end]`` ranges;
        * the per-message multimodal inputs.
        """
        num_seqs = len(messages)

        def _collect_embeddings():
            """Gather embedding tensors, the query-position mask and the ranges."""
            q_len_max = model_inputs.seq_length.max().item()
            embeddings = []
            ranges = []
            for msg in messages:
                embeddings.append([
                    item.embeddings if isinstance(item.embeddings, torch.Tensor) else torch.as_tensor(item.embeddings)
                    for item in msg.input_embeddings
                ])
                ranges.append(torch.tensor([[item.start, item.end] for item in msg.input_embeddings]))

            mask = torch.zeros((num_seqs, q_len_max), dtype=torch.bool)
            for row, msg in enumerate(messages):
                offset = msg.num_history_ids
                for item in msg.input_embeddings:
                    # positions are absolute; shift them into the current query window
                    mask[row][item.start - offset:item.end - offset] = True
            return embeddings, mask, ranges

        has_embedding = (any(len(msg.history_embeddings) > 0 for msg in messages)
                         and any(len(msg.input_embeddings) > 0 for msg in messages))

        input_multimodals = None
        has_multimodal = False
        if any(not msg.history_multimodals.empty() for msg in messages):
            candidates = [msg.get_input_multimodals() for msg in messages]
            has_multimodal = any(len(val) > 0 for mm in candidates for val in mm.values())
            if has_multimodal:
                input_multimodals = candidates

        if not (has_embedding or has_multimodal):
            # no vision inputs
            return None

        input_embeddings = input_embedding_indexing = input_embedding_ranges = None
        if has_embedding:
            input_embeddings, input_embedding_indexing, input_embedding_ranges = _collect_embeddings()

        return VisionModelInputs(history_lengths=model_inputs.history_lengths,
                                 input_embeddings=input_embeddings,
                                 input_embedding_indexing=input_embedding_indexing,
                                 input_embedding_ranges=input_embedding_ranges,
                                 input_multimodals=input_multimodals)

    @property
    def torch_int_dtype(self):
        """Integer dtype used for block offsets: int32 on cuda, int64 on other devices."""
        return torch.int32 if self.executor.device_type == 'cuda' else torch.int64

    def _set_adapter_ids(self, model_inputs: ModelInputs, messages: 'SeqList'):
        """Attach per-message local adapter ids to ``model_inputs``.

        Skipped when the adapter manager holds at most one adapter. The ids
        tensor is created with ``seq_length.new_tensor``.
        """
        if self.adapter_manager.num_adapters() > 1:
            names = [msg.adapter_name for msg in messages]
            ids = self.adapter_manager.get_adapter_ids(names)
            model_inputs.local_adapter_ids = model_inputs.seq_length.new_tensor(ids)

    @torch.inference_mode()
    @record_function('create_model_inputs')
    def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
        """Create model inputs from messages.

        Args:
            messages (SeqList): The input messages.
            is_prefill (bool): Whether this is a prefill step. When decoding,
                every message is assumed to contribute as many tokens as the
                first one.

        Returns:
            ModelInputs: Token ids of all messages packed into one row (shape
            ``(1, total_tokens)``), plus sequence/history lengths, block
            offsets, and the optional adapter, vision and ssm-state fields.
        """
        num_seqs = len(messages)
        is_decoding = not is_prefill

        history_lengths = torch.tensor([msg.num_history_ids for msg in messages])
        per_seq_tokens = [msg.token_ids for msg in messages]
        input_ids = torch.as_tensor(np.concatenate(per_seq_tokens))[None]

        if is_prefill:
            seq_length = torch.tensor([len(toks) for toks in per_seq_tokens], dtype=torch.long)
            max_q_seqlen = seq_length.max().item()
        else:
            max_q_seqlen = len(per_seq_tokens[0])
            seq_length = torch.full((num_seqs, ), max_q_seqlen, dtype=torch.long)

        kv_seqlens = seq_length + history_lengths

        raw_block_tables = self.scheduler.get_block_tables(messages)
        block_offsets = _tensorlize_block_offsets(raw_block_tables, dtype=self.torch_int_dtype)

        num_ignored_history = torch.tensor([msg.num_ignored_history for msg in messages])
        model_metas = [msg.model_meta for msg in messages]

        model_inputs = ModelInputs(
            input_ids=input_ids,
            seq_length=seq_length,
            history_lengths=history_lengths,
            block_offsets=block_offsets,
            is_decoding=is_decoding,
            num_ignored_history=num_ignored_history,
            max_q_seqlen=max_q_seqlen,
            max_kv_seqlen=kv_seqlens.max().item(),
            sum_kv_seqlen=kv_seqlens.sum().item(),
            model_metas=model_metas,
        )

        # optional fields: adapters, vision inputs, ssm state offsets
        self._set_adapter_ids(model_inputs, messages)
        model_inputs.vision_inputs = self._create_vision_model_inputs(messages, model_inputs)
        if self.config.is_ssm:
            model_inputs.state_offsets = torch.tensor([msg.logical_state for msg in messages])

        return model_inputs

    @torch.inference_mode()
    @record_function('create_model_inputs_long_context')
    def create_model_inputs_long_context(self,
                                         seq: 'SchedulerSequence',
                                         chunk_size: int,
                                         multimodals: Optional['MultiModalInputs'] = None):
        """Build single-sequence ``ModelInputs`` for one long-context chunk.

        Args:
            seq (SchedulerSequence): The long-context sequence being chunked.
            chunk_size (int): Number of leading ``seq.token_ids`` to feed.
            multimodals (MultiModalInputs, optional): Multimodal inputs for
                this chunk; attached as vision inputs when non-empty.

        Returns:
            ModelInputs: A prefill batch of size 1 marked ``is_chunk=True``.
        """
        q_seqlens = torch.tensor([chunk_size])
        history_lens = torch.tensor([seq.num_history_ids])
        # batch of one sequence: the max and the sum of kv lengths coincide
        kv_len = (q_seqlens + history_lens).item()
        block_table = self.scheduler.get_block_tables([seq])[0]

        model_inputs = ModelInputs(
            input_ids=torch.as_tensor(seq.token_ids[:chunk_size])[None],
            seq_length=q_seqlens,
            history_lengths=history_lens,
            block_offsets=torch.as_tensor(block_table, dtype=self.torch_int_dtype)[None],
            is_decoding=False,
            num_ignored_history=torch.tensor([seq.num_ignored_history]),
            max_q_seqlen=q_seqlens.item(),
            max_kv_seqlen=kv_len,
            sum_kv_seqlen=kv_len,
            model_metas=[seq.model_meta],
            is_chunk=True,
        )

        self._set_adapter_ids(model_inputs, [seq])

        has_multimodal = multimodals is not None and len(multimodals) > 0
        if has_multimodal:
            model_inputs.vision_inputs = VisionModelInputs(
                history_lengths=model_inputs.history_lengths,
                input_multimodals=[multimodals],
            )

        if self.config.is_ssm:
            model_inputs.state_offsets = torch.tensor([seq.logical_state])

        return model_inputs

    @torch.inference_mode()
    @record_function('create_model_inputs_delta')
    def create_model_inputs_delta(self):
        """Create a decoding-step ``ModelInputsDelta`` for the running seqs.

        Calls ``scheduler.schedule_running`` on ``self.running_seqs`` (with
        ``num_decode_tokens`` and the decoding prealloc size) and uses the
        returned mask to split them into valid and invalid sequences.

        Returns:
            tuple: ``(delta, valid_seqs, invalid_seqs)``. ``delta`` is None
            when no sequence is valid; otherwise it holds the block offsets,
            ignored-history counts and KV-length bounds for the valid
            sequences, with ``indice_cpu`` giving their positions in
            ``self.running_seqs``.
        """
        batch_size = len(self.running_seqs)
        assert batch_size > 0
        # every running sequence feeds this many query tokens per decode step
        num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
        max_q_seqlen = num_decode_tokens
        prealloc_size = self.engine_strategy.get_prealloc_size(True)
        valid_mask = self.scheduler.schedule_running(self.running_seqs,
                                                     num_decode_tokens=num_decode_tokens,
                                                     prealloc_size=prealloc_size)

        valid_mask = np.array(valid_mask)
        # positions (within running_seqs) of the sequences that stay in the batch
        indices_cpu = np.arange(0, batch_size)[valid_mask]
        valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
        invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
        if len(valid_seqs) == 0:
            return None, valid_seqs, invalid_seqs

        # block offsets
        block_offsets = self.scheduler.get_block_tables(valid_seqs)
        block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)

        # sliding window: ignored history only matters when window_size > 0
        if self.scheduler.cache_config.window_size > 0:
            num_ignored_history = torch.tensor([msg.num_ignored_history for msg in valid_seqs])
        else:
            num_ignored_history = torch.zeros(len(valid_seqs), dtype=torch.long)

        # NOTE(review): max_q_seqlen is added once per sequence in kv_seqlens and
        # then again below, and the sum term uses batch_size (all running seqs,
        # including invalid ones) instead of len(valid_seqs). This looks like a
        # deliberate upper bound (e.g. headroom for an in-flight step) — confirm.
        kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
        sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
        max_kv_seqlen = max(kv_seqlens) + max_q_seqlen

        output = ModelInputsDelta(
            indices=None,
            block_offsets=block_offsets,
            indice_cpu=indices_cpu,
            max_q_seqlen=max_q_seqlen,
            max_kv_seqlen=max_kv_seqlen,
            sum_kv_seqlen=sum_kv_seqlen,
            num_ignored_history=num_ignored_history,
        )

        return output, valid_seqs, invalid_seqs

    def create_model_inputs_delta_valid_only(self):
        """Create model inputs delta for valid running seqs only.

        Only check validation, no resources will be scheduled. A running
        sequence counts as valid iff its status is ``MessageStatus.RUNNING``.

        Returns:
            tuple: ``(delta, valid_seqs, invalid_seqs)``. When every running
            sequence is valid this is ``(None, self.running_seqs, [])``.
            Otherwise ``delta`` carries ``indice_cpu`` and KV bounds, with
            ``block_offsets`` and ``num_ignored_history`` left as None; the
            KV bounds are 0 when no sequence is valid.
        """
        from lmdeploy.pytorch.messages import MessageStatus
        batch_size = len(self.running_seqs)

        valid_mask = [seq.status == MessageStatus.RUNNING for seq in self.running_seqs]
        # fast path: nothing was invalidated, no delta needed
        if all(valid_mask):
            return None, self.running_seqs, []

        valid_mask = np.array(valid_mask, dtype=bool)
        indices_cpu = np.arange(0, batch_size)[valid_mask]
        valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
        invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]

        num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
        max_q_seqlen = num_decode_tokens
        kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
        if len(kv_seqlens) == 0:
            sum_kv_seqlen = 0
            max_kv_seqlen = 0
        else:
            # NOTE(review): same bound as create_model_inputs_delta; the sum term
            # uses batch_size (all running seqs) rather than len(valid_seqs).
            sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
            max_kv_seqlen = max(kv_seqlens) + max_q_seqlen

        output = ModelInputsDelta(
            indices=None,
            block_offsets=None,
            indice_cpu=indices_cpu,
            max_q_seqlen=max_q_seqlen,
            max_kv_seqlen=max_kv_seqlen,
            sum_kv_seqlen=sum_kv_seqlen,
            num_ignored_history=None,
        )

        return output, valid_seqs, invalid_seqs

    def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]):
        """Record the sequences that were just sent for forwarding.

        Prefill-role nodes never track running sequences. A decoding step
        (``inputs is None``) replaces ``self.running_seqs``. A non-final
        long-context chunk only advances the chunker. Any other prefill
        extends ``self.running_seqs`` in place.
        """
        if self.config.role == EngineRole.Prefill:
            return

        decoding = inputs is None
        chunking = self.long_context_chunker.enabled()
        if chunking and not decoding:
            self.long_context_chunker.update_step(inputs)
            return

        if not decoding:
            self.running_seqs += running
            return
        self.running_seqs = running

    def deactivate_evict_seqs(self):
        """Move pending eviction candidates back to the waiting queue.

        Each sequence in ``self.to_evict_seqs`` is first deactivated
        (running -> ready), then evicted (ready -> waiting). The list is then
        cleared in place.
        """
        evict = self.to_evict_seqs
        if not len(evict):
            return
        sched = self.scheduler
        sched.deactivate_seqs(evict)  # running -> ready
        sched.evict_seqs(evict)  # ready -> waiting
        evict.clear()

    @torch.inference_mode()
    @record_function('make_forward_inputs')
    def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False):
        """Make forward inputs for ModelAgent._async_step_background().

        Tries, in order: the next long-context chunk (when chunking is
        active), a prefill batch (when ``prefill`` is True), and finally a
        decoding delta over ``self.running_seqs`` (skipped on prefill-role
        nodes).

        Args:
            prefill (bool): Whether to try scheduling a prefill batch.
            enable_empty (bool): Only used in the debug log message here.

        Returns:
            dict | None: Forward-step arguments (``running``, ``inputs``,
            ``delta``, swap maps, sampling inputs, stopping criteria,
            ``return_logits``, ``extra_inputs``, ``return_routed_experts``),
            or None when neither inputs nor a delta could be built.
        """

        def __need_logits(seqs: 'SeqList'):
            """Need logits (always when speculative decoding is enabled)."""
            if self.spec_decoding:
                return True
            return any(seq.return_logits for seq in seqs)

        def __need_routed_experts(seqs: 'SeqList'):
            """Need routed experts."""
            return any(seq.return_routed_experts for seq in seqs)

        def __create_model_inputs(seqs):
            """Create model inputs.

            Also refreshes ``self.running_seqs`` to the running sequences that
            are still in ``MessageStatus.RUNNING`` (status check only).
            """
            inputs = self.create_model_inputs(seqs, True)
            delta, valid_seqs, _ = self.create_model_inputs_delta_valid_only()
            self.running_seqs = valid_seqs
            extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs)
            return inputs, delta, extra_inputs

        def __create_inputs_chunk(running: 'SeqList'):
            """Create inputs for the next (non-final) long-context chunk."""
            chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
            inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)
            extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)
            return inputs, extra_inputs

        def __create_inputs_long_context_chunk():
            """Continue an in-progress long-context sequence.

            The last chunk goes through the regular prefill path and clears
            the chunker; earlier chunks produce chunk inputs and no delta.
            """
            seq = self.long_context_chunker.seq
            running = [seq]
            if self.long_context_chunker.is_last_chunk():
                inputs, delta, extra_inputs = __create_model_inputs(running)
                self.long_context_chunker.clear()
            else:
                inputs, extra_inputs = __create_inputs_chunk(running)
                delta = None
            inputs.is_first_chunk = False
            return running, inputs, delta, extra_inputs

        def __create_inputs_prefill():
            """Schedule a prefill batch and build its inputs.

            A single long-context sequence starts chunking instead of a full
            prefill. Prefill-role nodes use ``prealloc_size=0``.
            """
            if self.config.role == EngineRole.Prefill:
                prealloc_size = 0
            else:
                prealloc_size = self.engine_strategy.get_prealloc_size(True)
            scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size)
            running = scheduler_output.running
            swap_in_map = scheduler_output.swap_in_map
            swap_out_map = scheduler_output.swap_out_map

            inputs = None
            delta = None
            extra_inputs = None
            if len(running) == 1 and self.long_context_chunker.is_long_context(running[0]):
                # set long context chunker
                self.long_context_chunker.set_seq(running[0])
                inputs, extra_inputs = __create_inputs_chunk(running)
            elif len(running) > 0:
                # create inputs
                inputs, delta, extra_inputs = __create_model_inputs(running)
            return running, inputs, delta, extra_inputs, swap_in_map, swap_out_map

        scheduler = self.scheduler
        logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')

        inputs = None
        delta = None
        swap_in_map = {}
        swap_out_map = {}

        self.long_context_chunker.check_enable()
        if self.long_context_chunker.enabled():
            # long context chunking
            running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
        elif prefill:
            # prefill
            (
                running,
                inputs,
                delta,
                extra_inputs,
                swap_in_map,
                swap_out_map,
            ) = __create_inputs_prefill()

        # try decoding (also reached when prefill scheduled nothing)
        if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:
            prefill = False
            delta, running, invalid_seqs = self.create_model_inputs_delta()
            # sequences that could not be scheduled are evicted later by deactivate_evict_seqs()
            self.to_evict_seqs = invalid_seqs
            extra_inputs = None

        # nothing to run: no prefill inputs and no decoding delta
        if inputs is None and delta is None:
            return None

        sampling_inputs = self.sampling_strategy.make_sampling_inputs(running)
        # stopping criteria are only rebuilt when full inputs were created
        if inputs is not None:
            stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)
        else:
            stopping_criteria = None

        return_logits = __need_logits(running)
        return_routed_experts = __need_routed_experts(running)

        return dict(
            running=running,
            inputs=inputs,
            delta=delta,
            swap_in_map=swap_in_map,
            swap_out_map=swap_out_map,
            sampling_inputs=sampling_inputs,
            stopping_criteria=stopping_criteria,
            return_logits=return_logits,
            extra_inputs=extra_inputs,
            return_routed_experts=return_routed_experts,
        )

    def do_prefill_pnode(self):
        """Prefill-node policy: every step is a prefill step."""
        should_prefill = True
        return should_prefill

    def do_prefill_default(self):
        """Default policy deciding whether the next step is a prefill.

        Returns False when nothing is waiting. Otherwise returns True if the
        waiting sequences' cumulative ``num_token_ids`` reaches
        ``config.max_prefill_token_num``, or if fewer than half of
        ``config.max_batches`` sequences are ready or running. Returns False
        in all remaining cases (decode instead).
        """
        scheduler = self.scheduler
        if not scheduler.has_waiting():
            return False

        # enough queued tokens to justify a prefill step
        queued = 0
        for waiting_seq in scheduler.waiting:
            queued += waiting_seq.num_token_ids
            if queued >= self.config.max_prefill_token_num:
                return True

        # batch is under-filled: refill it with a prefill
        active = scheduler.num_ready() + scheduler.num_running()
        return active < self.config.max_batches * 0.5

    def do_prefill_chunked(self):
        """Chunked prefill strategy.

        Prefill whenever the scheduler has no ready sequences. Supports both
        dp=1 and dp>1.
        """
        has_ready = self.scheduler.has_ready()
        return not has_ready

    async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool = False):
        """Build the next forward inputs and hand them to the executor.

        Args:
            prefill (bool): Passed to ``_make_forward_inputs``.
            enable_empty (bool): Passed to ``_make_forward_inputs``.

        Returns:
            tuple: ``(forward_inputs, next_running)``, or ``(None, None)``
            when there is nothing to run. The ``'running'`` entry is popped
            out of ``forward_inputs`` before it is sent, and the sent dict is
            also stored on ``self.forward_inputs``.
        """
        forward_inputs = self._make_forward_inputs(prefill, enable_empty)
        if forward_inputs is None:
            return None, None

        next_running = forward_inputs.pop('running')
        model_inputs = forward_inputs['inputs']
        if logger.level <= logging.DEBUG and model_inputs is not None:
            logger.debug(f'Sending forward inputs: {model_inputs.log_info()}')
            logger.debug(f'Forward session_ids: {[seq.session_id for seq in next_running]}')

        await self.executor.forward_async(forward_inputs)
        self.forward_inputs = forward_inputs
        return forward_inputs, next_running

    async def send_next_inputs(self):
        """Decide prefill vs. decode with ``do_prefill`` and send the next inputs."""
        return await self._send_next_inputs_impl(self.do_prefill())

    async def prefetch_next_inputs(self):
        """Send the next forward inputs ahead of time.

        Same as ``send_next_inputs`` except that ``enable_empty`` is True.
        """
        should_prefill = self.do_prefill()
        logger.debug('Prefetching next forward inputs.')
        return await self._send_next_inputs_impl(should_prefill, True)


def build_inputs_maker(engine: 'Engine'):
    """Create an ``InputsMakerAsync`` wired to ``engine``'s components.

    The maker config is derived via ``InputsMakerConfig.from_engine``; the
    executor, scheduler, adapter manager and strategies are taken from the
    engine as-is.
    """
    maker_config = InputsMakerConfig.from_engine(engine)
    return InputsMakerAsync(
        config=maker_config,
        executor=engine.executor,
        scheduler=engine.scheduler,
        adapter_manager=engine.adapter_manager,
        engine_strategy=engine.engine_strategy,
        sampling_strategy=engine.sampling_strategy,
        model_agent_strategy=engine.model_agent_strategy,
    )


================================================
FILE: lmdeploy/pytorch/engine/logits_process.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from dataclasses import dataclass, fields
from functools import lru_cache
from typing import Any

import numpy as np
import torch

from lmdeploy.messages import LogitsProcessor
from lmdeploy.pytorch import envs

from ..messages import SchedulerSequence
from .guided_process import GuidedDecodingManager


def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):
    """Process temperature."""
    temperature = temperature.to(scores.dtype)
    scores.div_(temperature[:, None])
    return scores


def _process_bad_words_(scores: torch.Tensor,
                        bad_words: torch.Tensor,
                        mask: torch.Tensor,
                        filter_value: float = -float('inf')):
    """Apply bad-word filtering to token scores.

    This function updates ``scores`` in place by setting the scores of
    "bad" token indices to ``filter_value``.
    Args:
        scores (torch.Tensor): A tensor of shape ``[batch_size, vocab_size]``
            containing the logits or scores for each token in the vocabulary.
        bad_words (torch.Tensor): A tensor of shape
            ``[batch_size, num_bad_words]`` containing token indices that
            should be suppressed. Invalid or masked positions may contain
            negative values; these entries are ignored and not used as
            indices into ``scores``.
        mask (torch.Tensor): A boolean tensor with the same shape as
            ``bad_words``. Positions with ``True`` indicate that the
            corresponding entry in ``bad_words`` is a valid bad-word index
            that should be filtered. Positions with ``False`` are treated as
            invalid/masked and are not applied to ``scores``.
        filter_value (float, optional): The value to assign to the scores of
            bad-word tokens. Defaults to ``-float('inf')``.
    Returns:
        torch.Tensor: The ``scores`` tensor after bad-word filtering has
        been applied.
    """
    # invalid badwords might be negative
    valid_bad_words = bad_words.where(mask, 0)
    filtered_scores = scores.gather(1, valid_bad_words)
    filtered_scores[mask] = filter_value
    scores.scatter_(1, valid_bad_words, filtered_scores)
    return scores


def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):
    """Process repetition penalty."""
    score = torch.gather(scores, 1, input_ids)
    penalty = penalty.to(score.dtype)
    score = torch.where(score < 0, score * penalty[:, None], score / penalty[:, None])
    scores.scatter_(1, input_ids, score)
    return scores


def _filter_topk_sorted_(scores: torch.Tensor, topk: torch.LongTensor, filter_value: float = -float('inf')):
    """Filter topk on sorted scores."""
    filter_value = -float('inf')
    num_tokens = scores.size(1)
    token_idx = torch.arange(num_tokens, device=scores.device)
    mask = token_idx[None, :] >= topk[:, None]
    scores.masked_fill_(mask, filter_value)
    return scores


def _filter_topp_sorted_(scores: torch.Tensor, topp: torch.Tensor, filter_value: float = -float('inf')):
    """Filter topp on sorted scores."""
    softmax_scores = scores.softmax(-1)
    cum_scores = softmax_scores.cumsum(1) - softmax_scores
    mask = cum_scores > topp[:, None]
    mask[:, 0] = False  # keep at least one
    scores.masked_fill_(mask, filter_value)
    return scores


def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: float = -float('inf')):
    """Filter minp on sorted scores."""
    softmax_scores = scores.softmax(-1)
    top_probs, _ = softmax_scores.max(dim=-1, keepdim=True)
    scaled_min_p = minp.unsqueeze(dim=1) * top_probs
    mask = softmax_scores < scaled_min_p
    scores.masked_fill_(mask, filter_value)
    return scores


@lru_cache
def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1):
    return torch.ones(fill, dtype=dtype, device=device)


def ngram(
    token_ids: torch.Tensor,
    n: torch.Tensor | None,
    threshold: torch.Tensor,
    max_n: int,
    max_window_size: int,
):
    """Detect whether the trailing n-gram of each row repeats earlier.

    Each row is trimmed to its last ``max_window_size`` tokens. The last
    ``max_n`` tokens form the target window, which is compared against every
    length-``max_n`` sliding window using one grouped ``conv1d``. Token ids
    are mapped to ``log2(id + 2)`` and divided by the target window's L2
    norm. A window matches when both of these are ~1 (``isclose``):

      1. the dot product of the window with the normalized target, and
      2. the window's squared norm over the compared positions divided by
         the target's squared norm,

    which together indicate the window equals the target (up to floating
    point tolerance). The target window always matches itself.

    Parameters
    ----------
    token_ids : torch.Tensor
        Token ids of shape (batch_size, seq_len). Values may be 0.
    n : torch.Tensor or None
        Effective n-gram length per row, shape (batch_size,). Only the last
        ``n[i]`` positions of each window are compared. ``None`` means all
        rows use ``max_n`` (see lmdeploy/pytorch/strategies/ar/sampling.py).
        NOTE(review): ``n[i] == 0`` would make the target norm zero — callers
        presumably guarantee ``n >= 1``; confirm.
    threshold : torch.Tensor
        Minimum number of matching windows for a row to be reported as
        found, shape (batch_size,). Rows with ``threshold <= 0`` are never
        found.
    max_n : int
        Maximum n-gram length (kernel width of the sliding window).
    max_window_size : int
        Only the last ``max_window_size`` tokens of each row are searched.
        NOTE(review): assumes ``max_window_size >= max_n``; otherwise the
        trimmed rows are shorter than the kernel.

    Returns
    -------
    matched_mask : torch.Tensor
        Boolean mask of shape (batch_size, W - max_n + 1), where
        ``W = min(seq_len, max_window_size)``, marking windows that match the
        target. Shape (batch_size, 0) when ``seq_len < max_n``.
    found : torch.Tensor
        Boolean tensor of shape (batch_size,): ``count >= threshold`` and
        ``threshold > 0``, where ``count`` is the number of matches.
    """

    batch_size, seq_len = token_ids.size()
    if seq_len < max_n:
        # Not enough tokens to form a single n-gram
        matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device)
        found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device)
        return matched_mask, found
    # token_ids could be 0; +2 makes every log2 value >= 1 so the norm is never 0
    token_ids = (token_ids + 2).to(torch.float32).log2()

    # Trim to max_window_size
    if seq_len >= max_window_size:
        token_ids = token_ids[:, -max_window_size:]
    max_window_size = token_ids.size(1)

    # normalize ids by the L2 norm of the (masked) target window
    # we would set n=None if n shared same value. Read lmdeploy/pytorch/strategies/ar/sampling.py for more details
    same_n = n is None
    norm = token_ids[:, -max_n:]
    if not same_n:
        # keep only the last n positions (zero out the first max_n - n)
        mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1))
        norm = norm * mask.to(torch.float32)
    norm = norm.norm(2, dim=-1, keepdim=True)
    normed_ids = token_ids / norm

    # stack (ids, ids^2) against (target, ones) so dot product and squared
    # length are both computed by one grouped conv1d
    normed_n_ids = normed_ids[:, -max_n:]
    normed_ids_p2 = normed_ids * normed_ids
    ones_ids = torch.ones_like(normed_n_ids)
    if not same_n:
        # keep only the last n positions (zero out the first max_n - n)
        normed_n_ids = normed_n_ids * mask.to(torch.float32)
        ones_ids = ones_ids * mask.to(torch.float32)
    normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0)
    normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

    # one group per (row, channel): first half = dot product, second = squared length ratio
    match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
    match_norm, match_ones = match_norm.chunk(2, dim=0)

    # both match result should be close to 1
    one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1)
    matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor)

    # threshold
    count = matched_mask.sum(-1)
    found = (count >= threshold) & (threshold > 0)

    return matched_mask, found


def _filter_repetition_ngram_(
    scores: torch.Tensor,
    stop_words: torch.Tensor,
    generated_ids: torch.Tensor,
    n: torch.Tensor | None,
    threshold: torch.Tensor,
    max_n: int,
    max_ngram_window_size: int,
):
    """Filter ngram.

    if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist.
    """
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # use first stop words
    _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size)
    stop_words = stop_words[:, 0]
    # fill all scores -inf
    scores.masked_fill_(found[:, None], -float('inf'))
    # set stop words to 0
    stop_scores = scores.gather(1, stop_words[:, None])
    stop_scores.masked_fill_(found[:, None], 0)
    scores.scatter_(1, stop_words[:, None], stop_scores)
    return scores


def _multinomial_sampling(scores: torch.Tensor,
                          seeds: torch.LongTensor,
                          offsets: torch.LongTensor,
                          indices: torch.LongTensor = None):
    """Sample token ids via lmdeploy's seeded multinomial sampling kernel.

    The import is deferred until the first call.
    """
    from lmdeploy.pytorch.nn.multinomial_sampling import multinomial_sampling as _sample
    return _sample(scores, seeds, offsets, indices)


# Type alias: a batch of scheduler sequences.
SeqList = list[SchedulerSequence]


@dataclass
class SamplingInputsDelta:
    num_ignore_eos: torch.Tensor = None
    random_offsets: torch.Tensor = None
    all_ids: None | torch.Tensor = None


@dataclass
class SamplingInputs:
    """Batched sampling parameters consumed by the logits processor."""
    temperature: torch.Tensor = None
    bad_words: torch.LongTensor = None
    bad_mask: torch.BoolTensor = None
    stop_words: torch.LongTensor = None
    stop_mask: torch.BoolTensor = None
    repetition_penalty: torch.Tensor = None
    top_k: torch.LongTensor = None
    top_p: torch.Tensor = None
    min_p: torch.Tensor = None
    random_seeds: torch.Tensor = None
    random_offsets: torch.Tensor = None
    max_top_k: int = 1
    min_top_p: float = 1.0
    response_formats: tuple[str, ...] = ()
    logits_processors: list[list[LogitsProcessor]] = None
    max_num_logprobs: None | int = None
    all_ids: None | torch.Tensor = None
    num_ignore_eos: torch.Tensor = None
    batch_size: int = 0
    session_ctx: None | list[dict[str, Any]] = None
    session_to_cleanup: None | list[int] = None
    # for repetition_penalty and ngram
    generated_ids: torch.Tensor | None = None
    generated_ids_cpu: np.ndarray | None = None

    # n gram
    repetition_ngram_size: torch.Tensor | None = None
    repetition_ngram_threshold: torch.Tensor | None = None
    max_repetition_ngram_size: int = 0

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy whose tensor fields live on ``device``.

        If ``generated_ids`` is missing but ``generated_ids_cpu`` is set, the
        numpy array is copied into a tensor and cached on ``self`` first.
        Non-tensor fields are passed through unchanged.
        """
        if self.generated_ids is None and self.generated_ids_cpu is not None:
            # .copy() so the tensor does not share memory with generated_ids_cpu
            self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())

        def _move(value):
            if isinstance(value, torch.Tensor):
                return value.to(device, non_blocking=non_blocking)
            return value

        return SamplingInputs(**{f.name: _move(getattr(self, f.name)) for f in fields(self)})

    def get_delta(self) -> SamplingInputsDelta:
        """Collect this object's tensor fields into a ``SamplingInputsDelta``.

        NOTE(review): every tensor field is set on the delta, including ones
        not declared on ``SamplingInputsDelta``; ``update_delta`` only reads
        the declared ones.
        """
        delta = SamplingInputsDelta()
        for name, value in ((f.name, getattr(self, f.name)) for f in fields(self)):
            if isinstance(value, torch.Tensor):
                setattr(delta, name, value)
        return delta

    def update_delta(self, delta: SamplingInputsDelta):
        """Overwrite fields with every non-None dataclass field of ``delta``."""
        updates = {f.name: getattr(delta, f.name) for f in fields(delta)}
        for name, value in updates.items():
            if value is not None:
                setattr(self, name, value)


def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits):
    """Apply custom logits processors."""
    for seq_id, processors in enumerate(batched_logits_processors):
        if processors is not None:
            for processor in processors:
                logits[seq_id] = processor(all_ids[seq_id], logits[seq_id])
    return logits


def _torch_topk(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):
    if k == 1:
        # torch.topk would not fallback to torch.max/torch.min automatically
        if largest:
            return torch.max(x, dim=dim, keepdim=True)
        else:
            return torch.min(x, dim=dim, keepdim=True)
    else:
        return torch.topk(x, k, dim=dim, largest=largest, sorted=sorted)


class FusedLogitsProcessor:
    """Custom logits processor.

    ``__call__`` transforms raw scores (guided decoding bitmasks, custom
    logits processors, repetition penalty, n-gram repetition stop,
    temperature, bad words and stop words). ``sampling`` picks the next
    token ids, and ``compute_logprobs`` gathers the requested
    log-probabilities.
    """

    def __init__(
        self,
        sampling_inputs: SamplingInputs,
        logprobs_mode: None | str = None,
        guided_decoding_manager: None | GuidedDecodingManager = None,
    ):
        """
        Args:
            sampling_inputs (SamplingInputs): Batched sampling parameters.
            logprobs_mode (str | None): ``'raw_logits'`` or ``'raw_logprobs'``
                capture the scores before processing; any other value
                disables raw logprob capture.
            guided_decoding_manager (GuidedDecodingManager | None): Source of
                guided-decoding processors, if guided decoding is enabled.
        """
        self.sampling_inputs: SamplingInputs = sampling_inputs
        self.logprobs_mode = logprobs_mode
        self.guided_decoding_manager = guided_decoding_manager
        # release guided-decoding processors of the listed sessions
        if sampling_inputs.session_to_cleanup:
            self.cleanup_sessions(sampling_inputs.session_to_cleanup)

        if self.guided_decoding_manager:
            self.guided_processors = self.guided_decoding_manager.get_processors(sampling_inputs.session_ctx,
                                                                                 sampling_inputs.response_formats)
        else:
            self.guided_processors = {}

    async def _wait_stream_once(self):
        """Yield to the event loop once if the current CUDA stream has pending work."""
        stream = torch.cuda.current_stream()
        if not stream.query():
            await asyncio.sleep(0)

    async def __call__(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
        r"""Apply all configured logits transforms.

        Args:
            scores (torch.Tensor):
                Prediction scores of a language modeling head.
                These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token
                when using beam search

        Return:
            tuple[torch.Tensor, torch.Tensor | None]: The processed prediction
            scores, and the raw logits/logprobs captured before processing
            (None when logprobs are not requested or ``logprobs_mode`` is not
            a raw mode).
        """

        num_logprobs = self.sampling_inputs.max_num_logprobs
        # get raw logprobs; ``max_num_logprobs`` defaults to None, which means
        # "no logprobs" just like a negative value (``None < 0`` would raise)
        if num_logprobs is None or num_logprobs < 0:
            logprobs = None
        else:
            if self.logprobs_mode == 'raw_logits':
                logprobs = scores.clone()
            elif self.logprobs_mode == 'raw_logprobs':
                logprobs = scores.log_softmax(dim=-1)
            else:
                logprobs = None

        sampling_inputs = self.sampling_inputs
        all_ids = sampling_inputs.all_ids
        custom_logits_processors = self.sampling_inputs.logits_processors
        if self.guided_decoding_manager and self.guided_processors:
            # the bitmap buffer is allocated lazily on first use and reused
            if not hasattr(self, 'guided_bitmask'):
                self.guided_bitmask = self.guided_decoding_manager.allocate_batched_bitmap(len(scores))

            assert self.guided_bitmask is not None
            guided_bitmask = self.guided_bitmask

            await self._wait_stream_once()
            for i, processor in self.guided_processors.items():
                self.guided_decoding_manager.fill_bitmap(processor, guided_bitmask, i)

            self.guided_decoding_manager.apply_batched_bitmap(scores, guided_bitmask)

        # ``logits_processors`` defaults to None; any(None) would raise
        if custom_logits_processors and any(custom_logits_processors):
            await self._wait_stream_once()
            scores = _apply_custom_logits_processors(custom_logits_processors, all_ids, scores)

        repetition_penalty = sampling_inputs.repetition_penalty
        if repetition_penalty is not None:
            generated_ids = sampling_inputs.generated_ids
            scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)

        if sampling_inputs.max_repetition_ngram_size > 0:
            generated_ids = sampling_inputs.generated_ids
            assert generated_ids is not None
            assert sampling_inputs.repetition_ngram_threshold is not None
            max_repetition_ngram_window_size = envs.repetition_window_size
            scores = _filter_repetition_ngram_(
                scores,
                sampling_inputs.stop_words,
                generated_ids,
                sampling_inputs.repetition_ngram_size,
                sampling_inputs.repetition_ngram_threshold,
                sampling_inputs.max_repetition_ngram_size,
                max_repetition_ngram_window_size,
            )

        temperature = sampling_inputs.temperature
        if temperature is not None:
            scores = _process_temperature_(scores, temperature)

        bad_words = sampling_inputs.bad_words
        if bad_words is not None:
            bad_mask = sampling_inputs.bad_mask
            scores = _process_bad_words_(scores, bad_words, bad_mask)

        stop_words = sampling_inputs.stop_words
        if stop_words is not None:
            # stop words are only suppressed for rows that still ignore EOS
            ignore_eos = sampling_inputs.num_ignore_eos > 0
            stop_mask = sampling_inputs.stop_mask
            stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)
            scores = _process_bad_words_(scores, stop_words, stop_mask)

        return scores, logprobs

    @torch.inference_mode()
    def sampling(self, logits: torch.Tensor):
        """Pick the next token id for every row.

        Uses greedy argmax when ``max_top_k == 1``; otherwise restricts to the
        top-``max_top_k`` candidates (all tokens, sorted, when
        ``max_top_k <= 0``), applies top-k/top-p/min-p filtering and samples
        with the per-row seeds and offsets. Guided-decoding processors are
        told about the chosen tokens.
        """
        sampling_inputs = self.sampling_inputs

        def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
            """Random sampling."""
            max_topk = sampling_inputs.max_top_k
            top_k = sampling_inputs.top_k
            if max_topk <= 0:
                max_topk = scores.size(1)
                if top_k is not None:
                    top_k = torch.masked_fill(top_k, top_k <= 0, max_topk)

            if top_k is not None:
                scores = _filter_topk_sorted_(scores, top_k)

            top_p = sampling_inputs.top_p
            if top_p is not None:
                scores = _filter_topp_sorted_(scores, top_p)

            min_p = sampling_inputs.min_p
            if min_p is not None:
                scores = _filter_minp_sorted_(scores, min_p)

            softmax_scores = scores.softmax(1)

            seeds = sampling_inputs.random_seeds
            offsets = sampling_inputs.random_offsets
            return _multinomial_sampling(softmax_scores, seeds, offsets, indices)

        if sampling_inputs.max_top_k == 1:
            result = logits.argmax(-1)
        else:
            # sort logits is too slow. and we only need topk logits
            max_topk = sampling_inputs.max_top_k
            if max_topk <= 0:
                scores, indices = logits.sort(1, descending=True)
            else:
                scores, indices = _torch_topk(logits, max_topk, dim=1)
            result = __random_sampling(scores, indices)

        if self.guided_decoding_manager and self.guided_processors:
            for i, processor in self.guided_processors.items():
                self.guided_decoding_manager.accept_token(processor, result[i])

        return result

    @torch.inference_mode()
    def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
        """Compute logprobs.

        Returns None when ``raw_logprobs`` is None. Otherwise returns
        ``(logprobs, indices)``: column 0 is the chosen token, followed by the
        top-``max_num_logprobs`` entries when that value is positive. Indices
        are cast to ``torch.int32``.
        """
        if raw_logprobs is None:
            return None

        indices = token_ids.unsqueeze(-1)
        logprobs = raw_logprobs.gather(-1, indices)
        num_logprobs = self.sampling_inputs.max_num_logprobs
        if num_logprobs is not None and num_logprobs > 0:
            topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1)
            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
            indices = torch.cat([indices, topk_indices], dim=-1)

        return logprobs, indices.to(torch.int32)

    def cleanup_sessions(self, session_ids: list[int]):
        """Remove the guided-decoding processors of ``session_ids``, if any manager is set."""
        if self.guided_decoding_manager:
            for session_id in session_ids:
                self.guided_decoding_manager.remove_processor(session_id)


================================================
FILE: lmdeploy/pytorch/engine/model_agent/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager

from .agent import BaseModelAgent, BatchedOutputs  # noqa: F401


def build_model_agent(
    model_path: str,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    backend_config: BackendConfig,
    misc_config: MiscConfig,
    dist_ctx: DistContext = None,
    device_ctx: DeviceContext = None,
    adapters: Dict[str, str] = None,
    specdecode_config: SpecDecodeConfig = None,
):
    """Create model agent.

    Args:
        model_path (str): the path of the input model
        model_config (ModelConfig): config of the model
        cache_config (CacheConfig): config of kv cache
        backend_config (BackendConfig): config of backend devices
        misc_config (MiscConfig): miscellaneous engine config
        dist_ctx (DistContext): distributed context; defaults to the current
            context of ``get_dist_manager()`` when None
        device_ctx (DeviceContext): device context; defaults to the current
            context of ``get_device_manager()`` when None
        adapters (Dict[str, str]): lora adapters
        specdecode_config (SpecDecodeConfig): config of speculative decoding

    Returns:
        BaseModelAgent: the created model agent.
    """

    if device_ctx is None:
        device_mgr = get_device_manager()
        device_ctx = device_mgr.current_context()
    if dist_ctx is None:
        dist_mgr = get_dist_manager()
        dist_ctx = dist_mgr.current_context()

    model_agent = BaseModelAgent(
        model_path,
        model_config=model_config,
        cache_config=cache_config,
        backend_config=backend_config,
        misc_config=misc_config,
        adapters=adapters,
        dist_ctx=dist_ctx,
        device_ctx=device_ctx,
        specdecode_config=specdecode_config,
    )
    return model_agent


================================================
FILE: lmdeploy/pytorch/engine/model_agent/agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import time
from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from multiprocessing.reduction import ForkingPickler
from os import getenv
from typing import Any, Dict, List, Optional

import numpy as np
import pybase64
import torch
import torch.distributed as dist
from torch.profiler import record_function

from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.disagg.config import EngineRole
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager
from lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine
from lmdeploy.pytorch.engine.guided_process import GuidedDecodingManager
from lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs, SamplingInputsDelta
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, step_ctx_manager
from lmdeploy.pytorch.models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map
from lmdeploy.pytorch.spec_decode import build_spec_agent
from lmdeploy.pytorch.strategies import build_strategy_factory
from lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
from lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks
from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights
from lmdeploy.serve.openai.protocol import UpdateParamsRequest
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger

from .inputs_maker import build_inputs_maker
from .profiler import AgentProfiler

# Module logger; shares the 'lmdeploy' logger namespace.
logger = get_logger('lmdeploy')


@dataclass
class SleepWakeupState:
    """Sleep / wakeup handshake state of the model agent.

    In the dp > 1 path of ``BaseModelAgent._async_step``, ``to_sleep`` is set
    when every dp rank reports sleeping, after which the agent waits on
    ``to_wakeup`` (and clears it) before continuing.
    """
    # set by the agent once all dp ranks report sleeping
    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
    # awaited (then cleared) by the agent before resuming steps
    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)
    # this rank's sleeping flag; gathered across dp ranks in _prepare_dp_v1
    is_sleeping: bool = False


@dataclass
class BatchedLogProbs:
    """Batched logprobs of a step: logprob values and matching token ids.

    numpy has no bfloat16 dtype, so bfloat16 values stay torch tensors after
    ``to_numpy``. ``to_numpy`` and ``to_tensor`` leave fields that are already
    in the target representation untouched, so both are safe to call
    repeatedly.
    """
    vals: torch.Tensor
    indices: torch.Tensor

    def to_cpu(self):
        """To cpu."""
        return BatchedLogProbs(vals=self.vals.cpu(), indices=self.indices.cpu())

    def to_numpy(self):
        """To numpy.

        Tensors are detached and converted with ``Tensor.numpy()`` (so they must
        already be on cpu, e.g. after ``to_cpu``). bfloat16 values are kept as
        tensors; numpy arrays are passed through.
        """
        vals = self.vals
        if isinstance(vals, torch.Tensor) and vals.dtype != torch.bfloat16:
            vals = vals.detach().numpy()
        indices = self.indices
        if isinstance(indices, torch.Tensor):
            indices = indices.detach().numpy()
        return BatchedLogProbs(vals=vals, indices=indices)

    def to_tensor(self):
        """To tensor.

        Fields that are already tensors (e.g. bfloat16 values kept by
        ``to_numpy``) are reused; everything else goes through
        ``torch.from_numpy``.
        """
        vals = self.vals
        if not isinstance(vals, torch.Tensor):
            vals = torch.from_numpy(vals)
        indices = self.indices
        if not isinstance(indices, torch.Tensor):
            indices = torch.from_numpy(indices)
        return BatchedLogProbs(vals=vals, indices=indices)


@dataclass
class BatchedOutputs:
    """Outputs of one model agent step.

    ``to_cpu`` / ``to_numpy`` / ``to_tensor`` return a new ``BatchedOutputs``
    with every field converted; nested values exposing a method of the same
    name (e.g. ``BatchedLogProbs``) convert themselves, other values are kept.
    """
    next_token_ids: torch.Tensor
    stopped: torch.Tensor
    stop_pos: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    model_metas: List[Dict[str, Any]] = None
    logprobs: Optional[BatchedLogProbs] = None
    new_token_timestamp: int = 0
    extra_outputs: Optional[ExtraOutputs] = None
    all_routed_experts: Optional[torch.Tensor] = None

    def _convert_fields(self, convert):
        """Build a new BatchedOutputs by applying ``convert`` to each field, in field order."""
        return BatchedOutputs(**{f.name: convert(getattr(self, f.name)) for f in fields(self)})

    def to_cpu(self):
        """To cpu."""

        def _cpu(value):
            if isinstance(value, torch.Tensor):
                return value.cpu()
            if hasattr(value, 'to_cpu'):
                return value.to_cpu()
            return value

        return self._convert_fields(_cpu)

    def to_numpy(self):
        """To numpy.

        bfloat16 tensors are kept as tensors because numpy has no bfloat16.
        """

        def _numpy(value):
            if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
                return value.detach().numpy()
            if hasattr(value, 'to_numpy'):
                return value.to_numpy()
            return value

        return self._convert_fields(_numpy)

    def to_tensor(self):
        """To tensor."""

        def _tensor(value):
            if isinstance(value, np.ndarray):
                return torch.from_numpy(value)
            if hasattr(value, 'to_tensor'):
                return value.to_tensor()
            return value

        return self._convert_fields(_tensor)


def msg_with_rank(rank: int, msg: str):
    """Prefix ``msg`` with a ``rank[<rank>] - `` tag."""
    return 'rank[{}] - {}'.format(rank, msg)


def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
    """Perform cache swapping.

    Issues swap-in first, then swap-out, each only when its map is non-empty
    (``None`` maps count as empty). When at least one swap was issued, waits
    on ``cache_engine.events`` before returning.
    """
    swap_in_map = swap_in_map or dict()
    swap_out_map = swap_out_map or dict()

    swapped_in = len(swap_in_map) > 0
    if swapped_in:
        cache_engine.swap_in(swap_in_map)

    swapped_out = len(swap_out_map) > 0
    if swapped_out:
        cache_engine.swap_out(swap_out_map)

    if swapped_in or swapped_out:
        cache_engine.events.wait()


@torch.inference_mode()
def model_forward(
    model: torch.nn.Module,
    inputs: ModelInputs,
    cache_engine: CacheEngine,
    state_cache_engine: StateCacheEngine,
    stream: torch.cuda.Stream = None,
):
    """Perform model forward.

    Builds a step context from ``inputs`` and the kv / state caches, then runs
    the model on ``stream`` (the current cuda stream when None).

    Returns:
        dict: the model outputs (a non-dict output is stored under
        ``'hidden_states'``) extended with ``'model_metas'``, ``'seq_length'``
        and ``'position_ids'``.
    """
    stream = stream or torch.cuda.current_stream()
    with torch.cuda.stream(stream), step_ctx_manager(model.ctx_mgr):
        # forward
        ctx_mgr = model.ctx_mgr
        context = ctx_mgr.build_context(
            inputs=inputs,
            model_config=cache_engine.model_config,
            cache_config=cache_engine.cache_config,
            kv_caches=cache_engine.gpu_cache,
            state_caches=state_cache_engine.state_caches,
            kv_quant_policy=cache_engine.cache_config.quant_policy,
        )

        with ctx_mgr.context(context):
            model_metas = model.update_model_metas(
                past_key_values=cache_engine.gpu_cache,
                context=context,
            )
            input_dict = model.prepare_inputs_for_generation(
                past_key_values=cache_engine.gpu_cache,
                context=context,
            )
            output = model(**input_dict)
            if not isinstance(output, dict):
                output = dict(hidden_states=output)
            # InternVL-3.5-Flash will change the seqlen, model_metas during forward
            if getattr(context, 'is_model_meta_updated', False):
                model_metas = context.model_metas
            output['model_metas'] = model_metas
            # keep one entry per input sequence (q_seqlens can presumably be longer, e.g. padded — confirm)
            output['seq_length'] = context.q_seqlens[:len(inputs.seq_length)]
            # for draft model reuse
            output['position_ids'] = context.position_ids
            return output


def _try_to_cuda(val, non_blocking: bool = False):
    if val is None:
        return val
    elif isinstance(val, torch.Tensor):
        return val.cuda(non_blocking=non_blocking)
    elif hasattr(val, 'to_device'):
        return val.to_device('cuda', non_blocking=non_blocking)
    else:
        raise RuntimeError(f'Can not cast {type(val)} to cuda.')


class DistGatherScalar:
    """Distribute value gather.

    Launches an asynchronous ``all_gather_into_tensor`` of ``val`` over
    ``group`` on construction. ``val`` may be a scalar or a flat list; the
    gathered tensor has one row per rank, in rank order.
    """

    def __init__(self, val, size: int, device: str = 'cpu', group: dist.ProcessGroup = None):
        self.val = val
        self.device = device
        self.group = group

        # output buffer with one slot per rank; the gather overwrites its contents
        self.all_vals = torch.tensor([val] * size, device=device)
        local_val = self.all_vals.new_tensor([val])
        self.worker = dist.all_gather_into_tensor(self.all_vals, local_val, group=group, async_op=True)

    async def async_wait(self, timeout: float = 0.001):
        """Poll the gather without blocking the event loop and return the result.

        ``timeout`` is the sleep between polls, in seconds.
        """
        handle = self.worker
        while not handle.is_completed():
            await asyncio.sleep(timeout)
        handle.wait()
        return self.all_vals


# Type alias for cache swap maps, int -> int (presumably source block id -> destination block id; confirm against CacheEngine.swap_in/swap_out).
SwapMap = Dict[int, int]


@dataclass
class StepInputs:
    """Step inputs.

    Holds the inputs of the next decoding step (model inputs, extra inputs,
    stopping criteria and sampling delta) between agent steps. All fields are
    None until the first ``merge``.
    """
    model_inputs: ModelInputs = None
    extra_inputs: ExtraInputs = None
    stopping_criteria: StoppingCriteria = None
    sampling_delta: SamplingInputsDelta = None

    @record_function('StepInputs.merge')
    def merge(
        self,
        inputs: ModelInputs,
        extra_inputs: ExtraInputs,
        stopping_criteria: StoppingCriteria,
        sampling_delta: SamplingInputsDelta,
        next_token_ids: torch.Tensor,
        model_metas,
        extra_outputs: ExtraOutputs,
        model_agent: 'BaseModelAgent',
    ):
        """Merge prefill inputs.

        Advances prefill inputs to their next step with
        ``agent_strategy.update_prefill_for_next_step``, then stores them when
        nothing is stored yet, or merges them into the stored inputs otherwise.
        """
        inputs, extra_inputs = model_agent.agent_strategy.update_prefill_for_next_step(
            inputs,
            extra_inputs,
            next_token_ids,
            model_metas,
            extra_outputs,
        )
        # clone (presumably a copy) so the stored criteria do not alias the caller's object
        stopping_criteria = stopping_criteria.clone()
        # NOTE: receives the extra_inputs returned by update_prefill_for_next_step above
        sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
                                                                           next_token_ids,
                                                                           extra_inputs=extra_inputs)
        if self.model_inputs is None:
            # first batch: store as-is
            self.model_inputs = inputs
            self.extra_inputs = extra_inputs
            self.stopping_criteria = stopping_criteria
            self.sampling_delta = sampling_delta
        else:
            # append to the already stored decoding batch
            self.model_inputs = model_agent.inputs_strategy.merge(self.model_inputs, inputs)
            self.extra_inputs = self.extra_inputs.merge(extra_inputs)
            self.stopping_criteria = self.stopping_criteria.merge(stopping_criteria)
            self.sampling_delta = model_agent.sampling_strategy.merge_sampling_delta(
                self.sampling_delta, sampling_delta)

    def update_delta(
        self,
        delta: ModelInputsDelta,
        model_agent: 'BaseModelAgent',
    ):
        """Get inputs from delta.

        Applies ``delta`` to every stored component in place (replacing each
        attribute with the strategy's updated value).
        """
        self.model_inputs = model_agent.inputs_strategy.update_inputs(self.model_inputs, delta)
        self.extra_inputs = model_agent.agent_strategy.update_extra_inputs(self.extra_inputs, delta)
        self.stopping_criteria = self.stopping_criteria.update(delta)
        self.sampling_delta = model_agent.sampling_strategy.update_sampling_delta(self.sampling_delta, delta)

    @record_function('StepInputs.step')
    def step(
        self,
        model_inputs: ModelInputs,
        extra_inputs: ExtraInputs,
        stopping_criteria: StoppingCriteria,
        sampling_delta: SamplingInputsDelta,
        next_token_ids: torch.Tensor,
        model_metas,
        extra_outputs: ExtraOutputs,
        model_agent: 'BaseModelAgent',
    ):
        """Update inputs.

        Replaces the stored state with the inputs of the next decoding step,
        built by ``agent_strategy.update_decoding_for_next_step``.
        """
        # dp might change is_decoding of decoding inputs
        # (note: this mutates the caller's model_inputs)
        model_inputs.is_decoding = True
        (
            self.model_inputs,
            self.extra_inputs,
        ) = model_agent.agent_strategy.update_decoding_for_next_step(
            model_inputs,
            next_token_ids=next_token_ids,
            model_metas=model_metas,
            extra_inputs=extra_inputs,
            extra_outputs=extra_outputs,
        )
        self.stopping_criteria = stopping_criteria.clone()
        # NOTE(review): unlike merge(), this passes the extra_inputs argument, not the
        # updated self.extra_inputs returned above — confirm which one is intended.
        self.sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
                                                                                next_token_ids,
                                                                                extra_inputs=extra_inputs)


class BaseModelAgent:
    """Base model agent.

    load model on local gpu

    Args:
        model_path (str): The hugging face model path.
        model_config (ModelConfig): The config of the model.
        cache_config (CacheConfig): The config of the cache info.
        trust_remote_code (bool): Trust remote code
    """

    def __init__(
        self,
        model_path: str,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        misc_config: MiscConfig,
        dist_ctx: DistContext,
        device_ctx: DeviceContext,
        adapters: Dict[str, str] = None,
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Initialize agent state.

        The patched model, cache engines and profiler are not built here; the
        corresponding attributes start as None.

        Args:
            model_path (str): The hugging face model path, also used to load
                the tokenizer.
            model_config (ModelConfig): The config of the model.
            cache_config (CacheConfig): The config of the cache info.
            backend_config (BackendConfig): The config of the backend.
            misc_config (MiscConfig): Miscellaneous engine config.
            dist_ctx (DistContext): Distributed context of this rank.
            device_ctx (DeviceContext): Device context of this rank.
            adapters (Dict[str, str]): lora adapters.
            specdecode_config (SpecDecodeConfig): Speculative decoding config,
                passed to the strategy factory and ``build_spec_agent``.
        """

        self.model_config = model_config
        self.cache_config = cache_config
        # use raw tokenizer
        # NOTE(review): hf modules cache is patched only when world_size > 1,
        # presumably to make multi-process tokenizer loading safe — confirm.
        if dist_ctx.dist_config.world_size > 1:
            monkey_patch_hf_modules_cache()
        # unwrap the lmdeploy Tokenizer down to the underlying tokenizer object
        self.tokenizer = Tokenizer(model_path).model.model

        # asyncio (queues / tasks start as None and are set up elsewhere)
        self._pre_in_que = None
        self._in_que = None
        self._out_que = None
        self._background_task = None
        self._preprocess_task = None
        self.tasks = set()

        # cuda stream
        # `stream` is used for forward in warmup; out_stream / cache_stream are
        # presumably used for output and cache work — confirm at their call sites.
        self.stream = torch.cuda.Stream()
        self.out_stream = torch.cuda.Stream()
        self.cache_stream = torch.cuda.Stream()

        self.dist_ctx = dist_ctx
        self.device_ctx = device_ctx

        device = 'cuda'
        self.backend_config = backend_config
        self.misc_config = misc_config
        self.dist_config = dist_ctx.dist_config
        rank = dist_ctx.rank

        self.model_path = model_path
        self.adapters = adapters
        self.device = device
        self.rank = rank

        tp = self.dist_config.tp
        world_size = self.dist_config.world_size
        self.tp = tp
        self.world_size = world_size
        # True on the first rank of each consecutive group of attn_tp ranks
        self.need_output = rank % self.dist_config.attn_tp == 0

        # built later; None until then
        self.patched_model = None
        self.cache_engine = None
        self.state_cache_engine = None
        self.profiler: AgentProfiler = None
        # guided decoding is optional: disable it if the manager cannot be built for this tokenizer
        try:
            self.guided_decoding_manager = GuidedDecodingManager(self.tokenizer, model_config.vocab_size)
        except ValueError as e:
            logger.warning(f'Failed to create GuidedManager for tokenizer {type(self.tokenizer)}: {e}')
            self.guided_decoding_manager = None

        # microbatch
        # thresholds (env overridable, default 2) checked in _prepare_dp_v1 to decide
        # whether a step may run with microbatching
        self.enable_microbatch = self.dist_config.enable_microbatch
        self.enable_microbatch_prefill_batchsize_threshold = \
            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
        self.enable_microbatch_prefill_token_threshold = \
            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
        self.enable_microbatch_decode_batchsize_threshold = \
            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))

        # strategy
        self.strategy_factory = build_strategy_factory(model_config, misc_config, specdecode_config=specdecode_config)
        self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy()
        self.agent_strategy = self.strategy_factory.build_model_agent_strategy()
        self.sampling_strategy = self.strategy_factory.build_sampling_strategy()

        # spec decoding
        self.spec_agent = build_spec_agent(specdecode_config,
                                           backend_config,
                                           dist_ctx,
                                           self.inputs_strategy,
                                           self.agent_strategy,
                                           device=device)
        # sleep wakeup state
        self.state: SleepWakeupState = SleepWakeupState()

        # decoding inputs
        self.step_inputs = StepInputs()

        # long context
        # outputs carried between chunks of a chunked prefill (see _prepare_inputs_prefill)
        self._prev_chunk_output: Dict = None

    @contextmanager
    def all_context(self):
        """Enter this agent's device and dist contexts, then ``torch.inference_mode``."""
        device_manager = get_device_manager()
        dist_manager = get_dist_manager()
        with device_manager.context(self.device_ctx):
            with dist_manager.context(self.dist_ctx):
                with torch.inference_mode():
                    yield

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Store the main cache config and hand the draft one to the spec agent."""
        self.cache_config = cache_config
        spec_agent = self.spec_agent
        spec_agent.set_cache_config(spec_cache_config)

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Store the main model config and hand the draft one to the spec agent."""
        self.model_config = model_config
        spec_agent = self.spec_agent
        spec_agent.set_model_config(spec_model_config)

    def get_free_mem(self):
        """Return the free physical GPU memory reported by ``get_gpu_memory``.

        Runs inside ``all_context`` after ``torch.cuda.empty_cache()`` so cached
        allocator blocks are released first.
        """
        with self.all_context():
            torch.cuda.empty_cache()
            physical_free, _ = get_gpu_memory()
        return physical_free

    def warmup(self):
        """warmup.

        Runs one dummy prefill forward with ``cache_config.max_batches``
        sequences, then one dummy decoding forward per cuda graph capture batch
        size (largest first; skipped for prefill-role engines), and finally
        warms up the speculative-decoding draft model. Does nothing when
        ``skip_warmup`` from ``lmdeploy.pytorch.envs`` is set.
        """
        from lmdeploy.pytorch.envs import skip_warmup
        if skip_warmup:
            return

        with self.all_context(), torch.cuda.stream(self.stream):
            max_batches = self.cache_config.max_batches
            world_size = self.dist_config.world_size
            dp = self.dist_config.dp

            if dp > 1:
                # make sure warmup started together
                group = self.dist_ctx.cpu_group
                dist.barrier(group=group)

            # warmup prefill
            inputs = self.inputs_strategy.make_dummy(max_batches,
                                                     is_decoding=False,
                                                     device='cuda',
                                                     vocab_size=self.model_config.vocab_size)
            if dp > 1:
                # assumes every dp rank builds an equally sized dummy batch
                num_tokens = inputs.input_ids.numel()
                inputs.build_dp_meta([num_tokens] * world_size)
            logger.debug('Warmup prefill start.')
            self._forward_impl(inputs)
            torch.cuda.synchronize()
            logger.debug('Warmup prefill done.')

            # warmup decoding(with cuda graph)
            capture_batch_sizes = self.patched_model.get_capture_batch_sizes()
            capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)
            if self.cache_config.role == EngineRole.Prefill:
                # do not warmup decoding for prefill engine
                capture_batch_sizes = []
            for batch_size in capture_batch_sizes:
                inputs = self.inputs_strategy.make_dummy(batch_size,
                                                         is_decoding=True,
                                                         device='cuda',
                                                         vocab_size=self.model_config.vocab_size)
                # logged size: the batch size, or the dummy token count under dp
                num_tokens = batch_size
                if dp > 1:
                    num_tokens = inputs.input_ids.numel()
                    inputs.build_dp_meta([num_tokens] * world_size)
                logger.debug(f'Warmup decoding num_tokens={num_tokens} start.')
                self._forward_impl(inputs)
                torch.cuda.synchronize()
                logger.debug(f'Warmup decoding num_tokens={num_tokens} done.')

            # warmup draft model
            self.spec_agent.warmup(max_batches, self.model_config)

    def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor):
        """Slice outputs by ``seq_length`` using the agent strategy."""
        strategy = self.agent_strategy
        return strategy.slice_outputs(inputs, seq_length)

    def _postprocess_forward_output(self, output: dict, inputs: ModelInputs):
        """Post process forward output.

        Slices ``output['hidden_states']`` with the forward's ``seq_length``
        (falling back to ``inputs.seq_length``) and writes it back in place.
        """
        states = output['hidden_states']
        lengths = output.get('seq_length', inputs.seq_length)
        # strip the leading axis for slicing, then restore it with [None]
        output['hidden_states'] = self._slice_outs(states[0], lengths)[None]
        return output

    async def _async_model_forward(
        self,
        inputs: ModelInputs,
        return_logits: bool,
    ):
        """Model forward.

        Runs ``async_forward``; when ``return_logits`` is False the hidden states
        are first sliced by ``_postprocess_forward_output``. The spec agent then
        provides the hidden states used to compute ``outputs['logits']``.
        """
        outputs = await self.async_forward(inputs)
        if not return_logits:
            outputs = self._postprocess_forward_output(outputs, inputs)

        hidden_states, outputs = self.spec_agent.update_main_model_outputs(outputs, inputs)
        outputs['logits'] = self.get_logits(hidden_states)
        return outputs

    async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs):
        """Sampling logits.

        Returns:
            tuple: ``(next_token_ids, logprobs)``; ``logprobs`` is a
            ``BatchedLogProbs`` or None when no logprobs were computed.
        """
        # record_function cannot decorate a coroutine function (it would only
        # time coroutine creation), so the profiling range is opened here.
        with record_function('sampling_logits'):
            processor = FusedLogitsProcessor(
                sampling_inputs,
                logprobs_mode=self.misc_config.logprobs_mode,
                guided_decoding_manager=self.guided_decoding_manager,
            )
            processed_logits, raw_logprobs = await processor(logits)
            next_token_ids = processor.sampling(processed_logits)
            packed = processor.compute_logprobs(raw_logprobs, next_token_ids)
            logprobs = None if packed is None else BatchedLogProbs(vals=packed[0], indices=packed[1])

        return next_token_ids, logprobs

    def _push_output(self, output: BatchedOutputs):
        """Queue ``output`` together with a cuda event recorded on the current stream.

        Presumably the consumer synchronizes on the event before reading the
        output tensors — confirm at the consumer side.
        """
        ready = torch.cuda.Event()
        ready.record()
        item = (output, ready)
        self._out_que.put_nowait(item)

    @contextmanager
    def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
        """Run the body inside the agent strategy's next-token broadcast context.

        Yields the strategy's handle. When ``enable`` is False no broadcast
        context is entered and None is yielded.
        """
        if enable:
            broadcast = self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, self.dist_ctx)
            with broadcast as handle:
                yield handle
        else:
            yield

    async def _prepare_dp_v1(self, inputs: ModelInputs):
        """Prepare dp.

        If all inputs are dummy inputs, skip forward. If any of the inputs is prefill, then do prefill. Set padding
        batch size for decoding.

        Returns:
            tuple: ``(inputs, is_all_sleeping)``. ``inputs`` is None when every
            rank holds dummy inputs; otherwise it is the dp-updated inputs.
        """
        # record_function used as a decorator on a coroutine function would only
        # time coroutine creation, so the profiling range is opened in the body.
        with record_function('prepare_dp'):
            world_size = self.dist_config.world_size
            is_decoding = inputs.is_decoding
            num_tokens = inputs.input_ids.numel()
            is_dummy = inputs.is_dummy

            # gather dp forward metadata
            batch_size = inputs.seq_length.numel()
            is_sleeping = self.state.is_sleeping
            dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens, int(is_sleeping)]
            # check enable_microbatch
            if self.enable_microbatch:
                if is_decoding:
                    enable_microbatch = batch_size >= self.enable_microbatch_decode_batchsize_threshold
                else:
                    enable_microbatch = (batch_size >= self.enable_microbatch_prefill_batchsize_threshold
                                         and num_tokens >= self.enable_microbatch_prefill_token_threshold)
                dp_forward_meta.append(int(enable_microbatch))
            group = self.dist_ctx.cpu_group
            device = 'cpu'
            gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device=device, group=group)
            # one row per rank: [is_decoding, is_dummy, num_tokens, is_sleeping, (enable_microbatch)]
            gathered_meta = (await gathered_meta.async_wait()).cpu()

            # check is_decoding
            # if any one of the rank is prefill, then all ranks are prefill
            is_decoding = gathered_meta[:, 0].all().item()
            inputs.is_decoding = is_decoding

            # check if all inputs are dummy inputs
            is_all_dummy = gathered_meta[:, 1].all().item()
            is_all_sleeping = gathered_meta[:, 3].all().item()
            if is_all_dummy:
                return None, is_all_sleeping

            # pad batch size for decoding
            all_num_tokens = gathered_meta[:, 2].tolist()
            if is_decoding:
                max_num_tokens = max(all_num_tokens)
                meta = self.patched_model.get_meta()
                meta.padding_batch_size = max_num_tokens
                logger.debug(f'max_num_tokens={max_num_tokens}')

            # update if enable_microbatch
            # microbatching is used only if every rank allows it
            if self.enable_microbatch:
                inputs.enable_microbatch = gathered_meta[:, 4].all().item()

            # update dp meta
            inputs.build_dp_meta(all_num_tokens)
            inputs = self.patched_model.update_inputs(inputs)
            return inputs, is_all_sleeping

    def _get_inputs_from_delta(
        self,
        delta: ModelInputsDelta,
        sampling_inputs: SamplingInputs,
    ):
        """Get inputs from delta.

        Applies ``delta`` to the stored step inputs and the stored sampling delta
        to ``sampling_inputs`` (in place).

        Returns:
            tuple: ``(model_inputs, extra_inputs, stopping_criteria, sampling_inputs)``.
        """
        step_inputs = self.step_inputs
        step_inputs.update_delta(delta, self)
        model_inputs = step_inputs.model_inputs
        extra_inputs = step_inputs.extra_inputs
        criteria = step_inputs.stopping_criteria
        sampling_inputs.update_delta(step_inputs.sampling_delta)
        return model_inputs, extra_inputs, criteria, sampling_inputs

    def _prepare_inputs_prefill(
        self,
        inputs: ModelInputs,
        delta: ModelInputsDelta,
    ):
        """Prepare prefill inputs.

        Applies ``delta`` to the stored decoding inputs when given (second round
        chat). For long-context chunked prefill, ``model_metas`` are carried over
        from the previous chunk's output; that cached output is reset on a first
        chunk and dropped once a non-chunk input has consumed it.
        """
        if delta is not None:
            # update decoding inputs with delta
            self.step_inputs.update_delta(delta, self)

        if inputs.is_first_chunk:
            self._prev_chunk_output = None

        prev_output = self._prev_chunk_output
        if prev_output is None:
            return inputs

        inputs.model_metas = prev_output.get('model_metas')
        if not inputs.is_chunk:
            # remove _prev_chunk_output
            self._prev_chunk_output = None
        return inputs

    async def _step_postprocess_with_output(self,
                                            last_logits: torch.Tensor,
                                            logits: torch.Tensor,
                                            inputs: ModelInputs,
                                            sampling_inputs: SamplingInputs,
                                            stopping_criteria: StoppingCriteria,
                                            model_metas: Any,
                                            need_broadcast_next: bool,
                                            return_logits: bool = False,
                                            all_routed_experts: Any = None,
                                            extra_inputs: ExtraInputs = None):
        """Step postprocess with output.

        Samples next tokens from ``last_logits``, applies the strategy's
        post-sampling and (when enabled) speculative decoding, evaluates the
        stopping criteria inside the next-token broadcast context, and pushes a
        ``BatchedOutputs`` to the output queue.

        Returns:
            tuple: ``(inputs, extra_inputs, stopping_criteria, extra_outputs,
            next_token_ids)`` carrying the updated extra inputs, stopping
            criteria and next token ids.
        """
        rank = self.rank
        logger.debug(f' rank[{rank}]: Sampling.')
        # sampling
        next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs)

        # post sampling
        next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
                                                                         extra_inputs)

        # spec decoding
        # without spec decoding the sampled ids serve as both next_token_ids and the
        # reported output_token_ids; the spec agent supplies separate values for each
        output_token_ids = next_token_ids
        if self.spec_agent.is_enabled():
            extra_inputs = await self.spec_agent.async_model_forward(next_token_ids, inputs, extra_inputs,
                                                                     sampling_inputs)
            next_token_ids = extra_inputs.next_token_ids
            output_token_ids = extra_inputs.output_token_ids
            # logits are never returned when speculative decoding is enabled
            logits = None

        # the broadcast context is a no-op when need_broadcast_next is False
        with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
            logger.debug(f' rank[{rank}]: synchronize token ids')

            # stopping criteria
            stopped, stop_pos, stopping_criteria = stopping_criteria.step(
                next_token_ids,
                sampling_inputs.stop_words,
                inputs=inputs,
                extra_inputs=extra_inputs,
            )

            # send output
            logger.debug(f' rank[{rank}]: Output')
            extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)

        self._push_output(
            BatchedOutputs(next_token_ids=output_token_ids,
                           logits=logits if return_logits else None,
                           stopped=stopped,
                           stop_pos=stop_pos,
                           model_metas=model_metas,
                           logprobs=logprobs,
                           all_routed_experts=all_routed_experts,
                           extra_outputs=extra_outputs))

        return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids

    async def _step_postprocess_without_output(
        self,
        inputs: ModelInputs,
        last_logits: torch.Tensor,
        extra_inputs: ExtraInputs,
        need_broadcast_next: bool,
    ):
        """Step postprocess without output (no sampling, nothing is queued).

        Builds placeholder next token ids via the agent strategy, enters the
        next-token broadcast context when ``need_broadcast_next`` is set, and
        returns ``(inputs, next_token_ids, extra_inputs, extra_outputs)``.
        """
        # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
        # as it can trigger recompilation on different ranks when using torch.compile.
        next_token_ids, extra_inputs = self.agent_strategy.make_dummy_next_token(inputs, last_logits, extra_inputs)

        # broadcast next token for TP > 1
        with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
            logger.debug(f' rank[{self.rank}]: synchronize token ids')

        extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
        return inputs, next_token_ids, extra_inputs, extra_outputs

    async def _async_step(
        self,
        inputs: ModelInputs,
        delta: ModelInputsDelta = None,
        swap_in_map: Dict = None,
        swap_out_map: Dict = None,
        sampling_inputs: SamplingInputs = None,
        stopping_criteria: StoppingCriteria = None,
        return_logits: bool = False,
        return_routed_experts: bool = False,
        extra_inputs: ExtraInputs = None,
    ):
        """Async forward task.

        Runs one model step:

        1. Resolve the step inputs. When ``inputs`` is None they are rebuilt from
           ``delta`` (decoding step); otherwise non-dummy inputs go through
           ``_prepare_inputs_prefill``.
        2. When ``dp > 1``, align with the other dp ranks; may skip the step entirely.
        3. Swap caches and run the forward pass.
        4. Either sample (``self.need_output``) or produce dummy next tokens.
        5. Fold the results into ``self.step_inputs`` (``step`` for delta-driven steps,
           ``merge`` for non-chunked steps on non-Prefill engines), or remember the
           output of a chunked prefill in ``self._prev_chunk_output``.

        Returns None; dummy forwards and all-dummy dp steps return early.
        """

        @record_function('update_decoding_for_next_step')
        def __update_inputs(
            inputs,
            next_token_ids,
            model_metas,
            extra_inputs,
            extra_outputs,
            stopping_criteria,
            sampling_delta: SamplingInputsDelta = None,
        ):
            """Update inputs."""
            # dp might change is_decoding of decoding inputs
            self.step_inputs.step(
                inputs,
                extra_inputs,
                stopping_criteria,
                sampling_delta,
                next_token_ids,
                model_metas,
                extra_outputs,
                model_agent=self,
            )

        dist_ctx = get_dist_manager().current_context()
        dist_config = dist_ctx.dist_config
        rank = self.rank
        tp = dist_config.attn_tp
        # next tokens only need to be broadcast when attention is tensor-parallel
        need_broadcast_next = (tp > 1)
        dp = dist_config.dp
        need_update_inputs = False

        if inputs is None:
            # decoding step, update prev_inputs with delta
            need_update_inputs = True
            assert delta is not None
            (
                inputs,
                extra_inputs,
                stopping_criteria,
                sampling_inputs,
            ) = self._get_inputs_from_delta(
                delta,
                sampling_inputs,
            )
        elif not inputs.is_dummy:
            # prefill step
            inputs = self._prepare_inputs_prefill(
                inputs,
                delta,
            )

        # dp might change is_decoding in inputs; remember the original value so it can be
        # restored after the forward pass
        is_decoding = inputs.is_decoding
        if dp > 1:
            # update inputs for dp
            inputs, is_all_sleeping = await self._prepare_dp_v1(inputs)
            # skip dummy forward.
            if inputs is None:
                if is_all_sleeping:
                    # hand-shake with `sleep`/`wakeup`: report that the loop is parked,
                    # then block until woken up
                    self.state.to_sleep.set()
                    await self.state.to_wakeup.wait()
                    self.state.to_wakeup.clear()
                    # sync after wakeup
                    dist.barrier()
                logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.')
                await asyncio.sleep(0.01)
                return

        if not is_decoding:
            # init state cache for first time prefill (sequences whose history length is 0)
            # I don't know if this is necessary...
            self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0)

        # swap caches
        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

        # inference
        logger.debug(f' rank[{rank}]: model forward. '
                     f'batch_size={inputs.seq_length.size(0)} '
                     f'num_tokens={inputs.input_ids.size(-1)} '
                     f'is_decoding={inputs.is_decoding}')
        output = await self._async_model_forward(
            inputs,
            return_logits=return_logits,
        )
        # recovery is_decoding
        inputs.is_decoding = is_decoding

        if inputs.is_dummy:
            # skip dummy forward output
            return

        logits = output['logits'][0]  # [bs, seq, prob] -> [seq, prob]
        seq_length = output.get('seq_length', inputs.seq_length)
        last_logits = self._slice_outs(logits, seq_length)  # [bs, 1, prob] -> [bs, prob]
        extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output)
        model_metas = output.get('model_metas')

        if self.need_output:
            # this rank samples real next tokens
            logger.debug(f' rank[{rank}]: Sampling.')
            # for router replay
            if return_routed_experts:
                all_routed_experts = output.get('all_routed_experts', None)
            else:
                all_routed_experts = None

            (
                inputs,
                extra_inputs,
                stopping_criteria,
                extra_outputs,
                next_token_ids,
            ) = await self._step_postprocess_with_output(
                last_logits,
                logits,
                inputs,
                sampling_inputs,
                stopping_criteria,
                model_metas,
                need_broadcast_next,
                return_logits=return_logits,
                all_routed_experts=all_routed_experts,
                extra_inputs=extra_inputs,
            )
        else:
            # no output on this rank: next tokens come from the strategy's dummy tokens
            # (and are overwritten by the broadcast when TP > 1)
            (
                inputs,
                next_token_ids,
                extra_inputs,
                extra_outputs,
            ) = await self._step_postprocess_without_output(
                inputs,
                last_logits,
                extra_inputs,
                need_broadcast_next,
            )

        sampling_delta = sampling_inputs.get_delta()
        if need_update_inputs:
            # delta-driven decoding step: advance the stored step inputs in place
            __update_inputs(inputs,
                            next_token_ids,
                            model_metas,
                            extra_inputs,
                            extra_outputs,
                            stopping_criteria,
                            sampling_delta=sampling_delta)
        elif inputs.is_chunk:
            # _prev_chunk_output is used to update model metas
            self._prev_chunk_output = output
        elif self.cache_config.role != EngineRole.Prefill:
            # fresh (non-chunked) inputs on an engine that keeps decoding: merge them
            # into the stored step inputs
            self.step_inputs.merge(
                inputs,
                extra_inputs,
                stopping_criteria,
                sampling_delta,
                next_token_ids,
                model_metas,
                extra_outputs,
                model_agent=self,
            )

    async def _async_loop_background(self, forward_event: asyncio.Event = None):
        """Drive the forward loop until the task is cancelled.

        Every iteration takes one set of forward kwargs from the inputs maker,
        runs :meth:`_async_step` with them, sets ``forward_event`` (when given),
        and lets the maker advance with ``step()``.
        """
        with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode():
            # the maker decides where the next inputs come from (dp-aware, see build_inputs_maker)
            maker = build_inputs_maker(self)
            notify = forward_event is not None

            while True:
                step_kwargs = await maker.get()
                await self._async_step(**step_kwargs)
                if notify:
                    forward_event.set()
                maker.step()

    async def _async_loop_inputs_preprocess(self, forward_event: asyncio.Event = None):
        """Copy queued forward inputs to the device, forever.

        Items from ``self._pre_in_que`` are shallow-copied; the entries named in
        ``device_keys`` that are present are moved with ``_try_to_cuda`` on
        ``self.out_stream``. That stream is synchronized before the copy is put
        on ``self._in_que``. ``forward_event``, when given, is cleared after each
        hand-off.
        """
        device_keys = ('inputs', 'delta', 'sampling_inputs', 'stopping_criteria', 'extra_inputs')
        while True:
            host_inputs = await self._pre_in_que.get()
            device_inputs = dict(host_inputs)
            logger.debug('preprocessing forward inputs.')
            with torch.cuda.stream(self.out_stream), torch.inference_mode(), record_function('inputs_H2D'):
                for key in device_keys:
                    if key in device_inputs:
                        device_inputs[key] = _try_to_cuda(device_inputs[key], non_blocking=True)
                # the copies are non-blocking: wait for them before the consumer sees the tensors
                self.out_stream.synchronize()
            logger.debug('preprocessing forward inputs done.')
            self._in_que.put_nowait(device_inputs)
            if forward_event is not None:
                forward_event.clear()

    def start(self, forward_event: asyncio.Event = None):
        """Create the agent queues and launch its background tasks.

        Two named tasks are created on the current event loop: ``ModelAgentLoop``
        (forward) and ``ModelAgentPreprocess`` (input H2D copy). Each is tracked
        in ``self.tasks`` until it finishes. Finally an ``AgentProfiler`` is built
        and its task started.
        """
        loop = asyncio.get_event_loop()
        self._pre_in_que = asyncio.Queue()
        self._in_que = asyncio.Queue()
        self._out_que = asyncio.Queue()

        def _spawn(coro, task_name):
            # register the task and drop it from the set automatically once done
            logger.debug(f'Create task {task_name}.')
            task = loop.create_task(coro, name=task_name)
            self.tasks.add(task)
            task.add_done_callback(self.tasks.discard)
            return task

        self._background_task = _spawn(self._async_loop_background(forward_event), 'ModelAgentLoop')
        self._preprocess_task = _spawn(self._async_loop_inputs_preprocess(forward_event), 'ModelAgentPreprocess')

        self.profiler = AgentProfiler(self.dist_ctx, self.stream)
        self.profiler.create_task()

    async def wait_tasks(self):
        """Await the agent's tracked background tasks.

        Returns at once when nothing is tracked. A cancellation is logged and
        re-raised; any other exception is re-raised with its context suppressed.
        """
        if not self.tasks:
            return
        try:
            await wait_for_async_tasks(self.tasks)
        except asyncio.CancelledError:
            logger.debug(f'ModelAgent rank[{self.rank}] wait_tasks cancelled.')
            raise
        except BaseException as err:
            raise err from None
        finally:
            logger.debug(f'ModelAgent rank[{self.rank}] wait_tasks cleanup.')

    def stop(self):
        """Synchronously stop the agent's background work (no-op when dp > 1).

        Dumps the profiler (if any), cancels every unfinished tracked task and
        clears the guided decoding manager.
        """
        if self.dist_config.dp > 1:
            return

        profiler = self.profiler
        if profiler is not None:
            profiler.dump()

        unfinished = [t for t in self.tasks if not t.done()]
        for t in unfinished:
            t.cancel()

        manager = self.guided_decoding_manager
        if manager:
            manager.clear()

    async def stop_async(self):
        """Asynchronously stop the agent's background work (no-op when dp > 1).

        If a profiler is active, wait until ``self.stream`` has drained and dump it.
        Then cancel every unfinished tracked task and await all of them, with
        exceptions collected rather than raised. Finally clear the guided decoding
        manager. If the wait itself is cancelled, the still-pending task names are
        logged and the cancellation is swallowed.
        """
        if self.dist_config.dp > 1:
            return

        if self.profiler is not None:
            # dirty hack for profiler: let queued GPU work finish before dumping
            while not self.stream.query():
                logger.debug('Profiler waiting for stream finish.')
                await asyncio.sleep(1)
            self.profiler.dump()

        # snapshot: each task's done-callback discards it from `self.tasks`
        tasks = list(self.tasks)
        for task in tasks:
            if not task.done():
                task.cancel()

        try:
            await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            # report the tasks that are actually still pending instead of whatever
            # task a previous loop happened to leave in a variable
            pending = [t.get_name() for t in tasks if not t.done()]
            logger.debug(f'ModelAgent tasks {pending} cancelled.')

        if self.guided_decoding_manager:
            self.guided_decoding_manager.clear()

    def set_forward_inputs(self, inputs):
        """Enqueue ``inputs`` for the preprocessing loop (requires :meth:`start`)."""
        queue = self._pre_in_que
        assert queue is not None, ('Please start backendground task before forward.')
        queue.put_nowait(inputs)

    async def get_output_async(self):
        """Wait for the next forward output and return its host copy.

        A queued ``None`` yields an empty dict. Otherwise the queued
        ``(output, event)`` pair is unpacked and ``event`` is polled until it
        completes. The output is then copied with ``to_cpu()`` on
        ``self.out_stream`` and stamped with ``new_token_timestamp``.
        """
        assert self._out_que is not None, ('Please start backendground task before forward.')
        item = await self._out_que.get()
        if item is None:
            return dict()

        device_out, ready = item
        # poll instead of blocking so the event loop keeps running
        while not ready.query():
            await asyncio.sleep(0.001)
        with torch.cuda.stream(self.out_stream), torch.inference_mode(), record_function('outputs_D2H'):
            ready.wait()
            host_out = device_out.to_cpu()
            host_out.new_token_timestamp = time.time()
        return host_out

    def _build_model(self):
        """Build the patched model, then load its weights and adapters.

        Weights are skipped when ``misc_config.empty_init`` is set. The resulting
        model is stored on ``self.patched_model`` and the context used to build it
        on ``self.build_model_ctx``.
        """
        rank = self.rank
        device = self.device
        model_config = self.model_config
        misc_config = self.misc_config

        if model_config.custom_module_map is not None:
            update_custom_module_map(model_config.custom_module_map)
        logger.debug(msg_with_rank(rank, 'build model.'))

        build_model_ctx = BuildModelContext(
            disable_vision_encoder=misc_config.disable_vision_encoder,
            dllm_config=misc_config.dllm_config,
            strategy_factory=self.strategy_factory,
            # router replay: experts are only returned on ranks with need_output
            enable_return_routed_experts=misc_config.enable_return_routed_experts and self.need_output,
            quant_config=model_config.quant_config,
            fp32_lm_head=model_config.fp32_lm_head,
            tie_word_embeddings=model_config.tie_word_embeddings,
        )
        model = build_patched_model(model_config, device=device, build_model_ctx=build_model_ctx)

        logger.debug(msg_with_rank(rank, 'loading weights.'))
        if not misc_config.empty_init:
            load_model_weights(model, self.model_path, device=device)
        if self.adapters is not None:
            logger.debug(msg_with_rank(rank, 'loading adapters.'))
            add_adapters(model, self.adapters, dtype=model_config.dtype, device=device)

        self.patched_model = model
        self.build_model_ctx = build_model_ctx

    def build_model(self):
        """Build the main model, then let ``spec_agent`` build its own on top of it."""
        with self.all_context():
            self._build_model()
            spec_agent = self.spec_agent
            spec_agent.build_model(self.misc_config.empty_init,
                                   self.patched_model,
                                   build_model_ctx=self.build_model_ctx)

    def build_graph_runner(self):
        """Wrap ``self.patched_model`` in the backend graph runner, then do the same for ``spec_agent``."""
        with self.all_context():
            backend = get_backend()
            runner_kwargs = dict(model_config=self.model_config,
                                 cache_config=self.cache_config,
                                 backend_config=self.backend_config,
                                 device=self.device)
            self.patched_model = backend.build_graph_runner(self.patched_model, **runner_kwargs)
            self.spec_agent.build_graph_runner()

    def build_cache_engine(self):
        """Create the KV cache engine, the state cache engine and the spec agent's caches."""
        with self.all_context():
            attn_tp_rank = get_dist_manager().current_context().attn_tp_group.rank
            self.cache_engine = CacheEngine(self.cache_config,
                                            self.model_config,
                                            rank=self.rank,
                                            tp_rank=attn_tp_rank,
                                            world_size=self.dist_config.attn_tp,
                                            cache_stream=self.cache_stream)
            self.state_cache_engine = StateCacheEngine(self.cache_config)
            self.spec_agent.build_cache_engine(self.cache_stream)

    def _forward_impl(self, inputs: ModelInputs):
        """Call ``model_forward`` with the agent's model, cache engines and stream."""
        return model_forward(self.patched_model,
                             inputs,
                             self.cache_engine,
                             state_cache_engine=self.state_cache_engine,
                             stream=self.stream)

    async def async_forward(self, inputs: ModelInputs):
        """Model forward.

        Args:
            inputs (ModelInputs): The model inputs for this step.

        Returns:
            Whatever ``self._forward_impl(inputs)`` returns.
        """
        output = self._forward_impl(inputs)
        # the forward call above is synchronous; yield to the event loop once afterwards
        await asyncio.sleep(0)
        return output

    @record_function('get_logits')
    def get_logits(self, hidden_states: torch.Tensor):
        """Turn ``hidden_states`` into logits using the patched model's head."""
        model = self.patched_model
        return model.get_logits(hidden_states)

    def get_input_processor(self):
        """Return the input processor exposed by the patched model."""
        model = self.patched_model
        return model.get_input_processor()

    def reset_graph_runner(self):
        """Reset graph runners so tp ranks do not hang.

        The patched model is reset only if it has a ``reset`` attribute; the spec
        agent is always reset.
        """
        model = self.patched_model
        if hasattr(model, 'reset'):
            model.reset()
        self.spec_agent.reset_graph_runner()

    @torch.inference_mode()
    def update_params(self, request: UpdateParamsRequest):
        """Update params.

        Loads one batch of weights carried by ``request`` into the patched model.

        ``request.serialized_named_tensors`` is base64-decoded and unpickled with
        ``ForkingPickler``. When it is a list, the entry at this rank's tp-group
        index is used. With ``load_format == 'flattened_bucket'`` the payload is a
        dict with ``'metadata'`` and ``'flattened_tensor'`` that is reconstructed
        into named tensors (empty metadata means no weights). Otherwise it is an
        iterable of ``(name, item)`` pairs whose items are rebuilt with
        ``_construct``.

        When ``request.finished`` is true, every submodule that has an
        ``update_weights`` method is asked to run it. Cached CUDA memory is
        released at the end.

        NOTE(review): ``ForkingPickler.loads`` can execute arbitrary code embedded in
        the payload; this must only receive data from a trusted producer.
        """

        # modified from https://github.com/vllm-project/vllm/blob/v0.8.5/examples/offline_inference/rlhf_utils.py#L82
        def _construct(item):
            # Rebuild one shared tensor on this process's current CUDA device.
            # `item` is presumably a `(rebuild_fn, args)` pair produced by torch's
            # CUDA-tensor IPC reduction -- TODO confirm against the producer.
            func, args = item
            args = list(args)
            args[6] = torch.cuda.current_device()  # device id.
            # clone() seems necessary otherwise the producer can not release the memory
            return func(*args).clone()

        with self.all_context():
            serialized_data = request.serialized_named_tensors
            if isinstance(serialized_data, list):
                # one payload per tp rank
                serialized_data = serialized_data[self.dist_ctx.tp_group.rank]
            model = self.patched_model.get_model()
            weights = ForkingPickler.loads(pybase64.b64decode(serialized_data))
            if request.load_format == 'flattened_bucket':
                metadata: List[FlattenedTensorMetadata] = weights['metadata']
                if metadata:
                    flattened_tensor: torch.Tensor = _construct(weights['flattened_tensor'])
                    bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
                    weights = bucket.reconstruct_tensors()
                else:
                    # empty data
                    weights = []
            else:
                weights = [(k, _construct(v)) for k, v in weights]

            # map incoming names onto the model's parameter names before loading
            weights = ModelWeightLoader._rename_weights_iterator(weights, model)
            model.load_weights(weights)

            if request.finished:
                # last batch: let modules that need it post-process their new weights
                for _, mod in model.named_modules():
                    if not hasattr(mod, 'update_weights'):
                        continue
                    mod.update_weights()

            torch.cuda.empty_cache()

    async def sleep(self, level: int = 1):
        """Sleep.

        Marks the agent as sleeping. When ``dp > 1``, waits until the forward loop
        reports it is parked (``state.to_sleep``). Then drops the KV cache engine,
        resets the graph runner and moves the model to ``'cpu'`` (``level == 1``) or
        ``'meta'`` (any other level, which discards the weights).

        ``torch.inference_mode`` is entered inside the body on purpose. Used as a
        decorator on an ``async def``, it only wraps creation of the coroutine
        object, not the code that runs when it is awaited.
        """
        self.state.is_sleeping = True
        if self.dist_config.dp > 1:
            await self.state.to_sleep.wait()
        # no `await` inside this block, so the thread-local mode cannot leak into
        # other coroutines running on the same event loop
        with torch.inference_mode():
            self.cache_engine = None
            self.reset_graph_runner()
            device = 'cpu' if level == 1 else 'meta'
            self.patched_model.get_model().to(device=device, non_blocking=True)
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        self.state.to_sleep.clear()

    @torch.inference_mode()
    def wakeup(self, tags: Optional[List[str]] = None):
        """Wakeup.

        Restores what :meth:`sleep` released.

        Args:
            tags: Resources to restore; defaults to ``['weights', 'kv_cache']``.
                ``'weights'``: a CPU-offloaded model is moved back to the current
                CUDA device, and a ``'meta'`` model is rebuilt (with its graph
                runner) without loading weights. ``'kv_cache'``: the cache engine
                is rebuilt, the sleeping flag is cleared, and ``state.to_wakeup``
                is set when ``dp > 1``.
        """
        if tags is None:
            tags = ['weights', 'kv_cache']
        if 'weights' in tags:
            device = next(self.patched_model.get_model().parameters()).device
            assert device.type in ['cpu', 'meta']
            if device.type == 'cpu':
                self.patched_model.get_model().to(torch.cuda.current_device())
            else:
                # user should update weights after wakeup (e.g. via update_params)
                old_empty_init = self.misc_config.empty_init
                self.misc_config.empty_init = True
                try:
                    self.build_model()
                    self.build_graph_runner()
                finally:
                    # restore even if rebuilding fails, so later builds still load weights
                    self.misc_config.empty_init = old_empty_init

        if 'kv_cache' in tags:
            self.build_cache_engine()
            # wake up signal
            self.state.is_sleeping = False
            if self.dist_config.dp > 1:
                self.state.to_wakeup.set()

    def release(self):
        """Drop the model and KV cache engine, then free cached CUDA memory."""
        self.reset_graph_runner()
        self.patched_model, self.cache_engine = None, None
        torch.cuda.empty_cache()


================================================
FILE: lmdeploy/pytorch/engine/model_agent/inputs_maker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist

from lmdeploy.pytorch.disagg.config import EngineRole

if TYPE_CHECKING:
    from .agent import BaseModelAgent


class DefaultForwardInputsMaker:
    """Forward-inputs source used by ``build_inputs_maker`` unless dp > 1.

    It just hands out whatever the agent's input queue yields.
    """

    def __init__(self, model_agent: 'BaseModelAgent'):
        # only the agent's input queue is needed
        self._in_que = model_agent._in_que

    async def get(self):
        """Wait for and return the next queued forward inputs."""
        queue = self._in_que
        return await queue.get()

    def step(self):
        """Nothing to advance between steps for this maker."""
        return None


class DPForwardInputsMaker:
    """Dp forward inputs maker.

    Forward-inputs source used when dp > 1. If no rank of the attention-TP group
    has real inputs, a dummy batch is returned instead, presumably so that every
    dp rank still takes part in the step -- TODO confirm.
    """

    def __init__(self, model_agent: 'BaseModelAgent'):
        self.model_agent = model_agent
        self.dist_ctx = model_agent.dist_ctx
        self.model_config = model_agent.model_config
        self.cache_config = model_agent.cache_config
        self.inputs_strategy = model_agent.inputs_strategy
        self.device = model_agent.device
        self._in_que = model_agent._in_que

        # maker metas
        # CUDA event recorded on the current stream; `get` polls it to learn whether
        # the work submitted before the record has completed.
        self._ready_event = torch.cuda.Event()
        self._ready_event.record()

    def _make_dummy_forward_inputs(self):
        """Make dummy forward inputs.

        The dummy batch decodes unless the engine role is Prefill. Its size is 2 with
        microbatching enabled and 1 otherwise, capped by ``cache_config.max_batches``.
        Returns a dict holding only the ``inputs`` key.
        """
        is_decoding = self.cache_config.role != EngineRole.Prefill
        dist_config = self.dist_ctx.dist_config
        batch_size = 2 if dist_config.enable_microbatch else 1
        batch_size = min(self.cache_config.max_batches, batch_size)
        model_inputs = self.inputs_strategy.make_dummy(batch_size,
                                                       is_decoding,
                                                       device=self.device,
                                                       vocab_size=self.model_config.vocab_size)
        forward_inputs = dict(inputs=model_inputs, )
        return forward_inputs

    async def _gather_has_inputs(self, has_inputs: bool = False):
        """Return whether any rank of the attention-TP group has inputs.

        Implemented as a SUM all-reduce of a 0/1 flag on the group's CPU process group.
        The async handle is polled so the event loop is not blocked. With
        ``attn_tp == 1`` the local flag is returned unchanged.
        """
        attn_tp_group = self.dist_ctx.attn_tp_group
        attn_tp = self.dist_ctx.dist_config.attn_tp
        if attn_tp == 1:
            return has_inputs

        group = attn_tp_group.cpu_group
        has_inputs = torch.tensor((int(has_inputs), ))
        handle = dist.all_reduce(has_inputs, op=dist.ReduceOp.SUM, group=group, async_op=True)
        future = handle.get_future()
        while not future.done():
            await asyncio.sleep(0)
        future.wait()
        return (has_inputs > 0).item()

    async def _get_inputs(self):
        """Return local forward inputs, or None if no TP peer has any either."""
        # get local forward inputs
        try:
            forward_inputs = self._in_que.get_nowait()
        except asyncio.QueueEmpty:
            forward_inputs = None

        # sync "has inputs" across the tp group; if a peer has inputs but this rank
        # does not yet, wait for them (assumes every attn-TP rank receives the same
        # inputs -- TODO confirm)
        has_inputs = await self._gather_has_inputs(forward_inputs is not None)
        if has_inputs and forward_inputs is None:
            forward_inputs = await self._in_que.get()

        return forward_inputs

    async def get(self):
        """Return the next forward inputs, falling back to a dummy batch."""
        # wait until there are inputs or the previous forward has finished
        while self._in_que.qsize() == 0 and not self._ready_event.query():
            await asyncio.sleep(0.001)

        # try get inputs
        forward_inputs = await self._get_inputs()

        # make dummy inputs
        if forward_inputs is None:
            forward_inputs = self._make_dummy_forward_inputs()

        return forward_inputs

    def step(self):
        """Advance after a forward step.

        Makes the current stream wait on the previous event, then records a fresh
        event that :meth:`get` polls.
        """
        self._ready_event.wait()
        self._ready_event = torch.cuda.Event()
        self._ready_event.record()


def build_inputs_maker(model_agent: 'BaseModelAgent'):
    """Pick the forward-inputs maker that matches the agent's data-parallel size."""
    use_dp = model_agent.dist_ctx.dist_config.dp > 1
    maker_cls = DPForwardInputsMaker if use_dp else DefaultForwardInputsMaker
    return maker_cls(model_agent)


================================================
FILE: lmdeploy/pytorch/engine/model_agent/profiler.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio

import torch
from torch.profiler import ProfilerActivity, profile

from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class AgentProfiler:
    """Optional torch profiler attached to one model-agent rank.

    Configured from ``lmdeploy.pytorch.envs``:

    * ``torch_profile_cpu`` / ``torch_profile_cuda`` select the activities (none
      selected means no profiler).
    * ``torch_profile_delay`` is the delay, in seconds, before recording starts.
    * ``torch_profile_duration`` is how long to record before dumping. A value
      ``<= 0`` means recording continues until :meth:`dump` is called explicitly.
    * ``torch_profile_output_prefix`` is the prefix of the dumped ``<prefix><rank>.json``.
    """

    def __init__(self, dist_ctx: DistContext, stream: torch.Stream):
        from lmdeploy.pytorch import envs
        self.rank = dist_ctx.rank
        self.dp_rank = dist_ctx.dp_rank
        self.dp = dist_ctx.dist_config.dp
        self.stream = stream
        self.name = f'rank[{self.rank}]'

        self.delay = envs.torch_profile_delay
        self.duration = envs.torch_profile_duration

        self.profiler = self._build_profiler()
        self.prefix = envs.torch_profile_output_prefix
        self._task = None
        self._started = False
        # With duration <= 0, `profile_task` never dumps by itself; the trace is only
        # written by an explicit `dump()` call. Open-ended profiling (including
        # duration == 0) is therefore rejected for dp > 1.
        if self.dp > 1 and self.duration <= 0 and self.profiler is not None:
            logger.warning('Do not support duration<=0 for dp > 1.')
            self.profiler = None

    def _build_profiler(self):
        """Return a ``torch.profiler.profile`` for the enabled activities, or None."""
        from lmdeploy.pytorch import envs
        activities = []
        if envs.torch_profile_cpu:
            activities.append(ProfilerActivity.CPU)
        if envs.torch_profile_cuda:
            activities.append(ProfilerActivity.CUDA)
        if len(activities) > 0:
            logger.warning(f'Profiler start on {self.name}. '
                           'Please Note that profiling might harm performance.')
            profiler = profile(activities=activities)
            return profiler
        else:
            return None

    def dump(self):
        """Dump profile result.

        Stops the profiler and exports a chrome trace to ``<prefix><rank>.json``.
        Does nothing when there is no profiler; only warns if it never started.
        After an attempted dump the profiler is dropped, even if the export fails
        (which is logged).
        """
        if self.profiler is None:
            return

        if not self._started:
            logger.warning(f'Profiler {self.name} not started, skip dump.')
            return

        try:
            self.profiler.stop()
            rank = self.rank
            dump_path = f'{self.prefix}{rank}.json'
            self.profiler.export_chrome_trace(dump_path)
            logger.warning(f'Profiler {self.name} dump to {dump_path}.')
        except Exception as e:
            logger.error(f'Failed to dump profile {self.name} result: {e}')
        finally:
            self.profiler = None

    async def profile_task(self):
        """Start the profiler after ``delay`` and dump it after ``duration`` (if > 0)."""
        if self.profiler is None:
            return

        # start profiler with delay
        await asyncio.sleep(self.delay)
        self.profiler.start()
        self._started = True

        if self.duration <= 0:
            # open-ended: someone must call `dump()` explicitly
            return

        # dump profiler
        await asyncio.sleep(self.duration)
        self.dump()

    def create_task(self):
        """Schedule :meth:`profile_task` on the current event loop."""
        event_loop = asyncio.get_event_loop()
        self._task = event_loop.create_task(self.profile_task())


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.messages import PytorchEngineConfig


def build_mp_engine(backend: str, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs):
    """Create a multi-process engine for ``backend``.

    ``'mp'`` selects ``ZMQMPEngine`` and ``'ray'`` selects ``RayMPEngine``. Each
    implementation is imported only when it is chosen. Any other backend raises
    ``ValueError``.
    """
    if backend == 'ray':
        from .ray_engine import RayMPEngine as engine_cls
    elif backend == 'mp':
        from .zmq_engine import ZMQMPEngine as engine_cls
    else:
        raise ValueError(f'Unsupported backend: {backend}')
    return engine_cls(model_path, engine_config=engine_config, **kwargs)


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, List, Optional

from lmdeploy.messages import ResponseType
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
                                                   DistServeInitRequest)
from lmdeploy.utils import get_logger

from ..base import EngineBase, EngineInstanceBase

logger = get_logger('lmdeploy')


@dataclass
class SessionState:
    """Per-session bookkeeping shared between an ``MPEngine`` and its instances."""
    # set by `async_stream_infer`; `async_end` / `async_cancel` wait on it
    is_exists: asyncio.Event = field(default_factory=lambda: asyncio.Event())


class MPEngine(EngineBase):
    """Engine front-end that delegates its work through ``_collective_rpc``.

    Concrete subclasses provide the transport by overriding the three
    ``_collective_rpc*`` hooks together with ``close`` and ``start_loop``. Every
    public operation below is a thin forwarding call.
    """

    def __init__(self) -> None:
        """Initialize mp engine.

        Session bookkeeping starts empty. The engine config is fetched through the
        ``get_engine_config`` rpc.
        """
        self.session_states = defaultdict(SessionState)
        self.engine_config = self._collective_rpc('get_engine_config')

    # ---- transport hooks (must be overridden) ----

    def _collective_rpc(self, func, *args, **kwargs):
        """Invoke ``func`` remotely and return its result (blocking)."""
        raise NotImplementedError('This method has not been implemented yet.')

    async def _collective_rpc_async(self, func, *args, **kwargs):
        """Awaitable variant of :meth:`_collective_rpc`."""
        raise NotImplementedError('This method has not been implemented yet.')

    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
        """Variant of :meth:`_collective_rpc` whose results are consumed with ``async for``."""
        raise NotImplementedError('This method has not been implemented yet.')

    def close(self) -> None:
        """Shut the engine down."""
        raise NotImplementedError('This method has not been implemented yet.')

    def start_loop(self) -> None:
        """Start the engine's serving loop."""
        raise NotImplementedError('This method has not been implemented yet.')

    # ---- forwarded operations ----

    def end_session(self, session_id: int):
        """Forward ``end_session`` for ``session_id``."""
        return self._collective_rpc('end_session', session_id)

    def sleep(self, level: int):
        """Forward ``sleep`` with the given level."""
        return self._collective_rpc('sleep', level)

    def wakeup(self, tags: Optional[List[str]] = None):
        """Forward ``wakeup`` with the given tags (None is passed through)."""
        return self._collective_rpc('wakeup', tags)

    def update_params(self, request: Any):
        """Forward an ``update_params`` request."""
        return self._collective_rpc('update_params', request)

    def get_schedule_metrics(self):
        """Fetch the scheduler metrics."""
        return self._collective_rpc('get_schedule_metrics')

    def p2p_initialize(self, conn_request: DistServeInitRequest):
        """Forward the RDMA link initialization request."""
        return self._collective_rpc('p2p_initialize', conn_request)

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Forward the RDMA connect request."""
        return self._collective_rpc('p2p_connect', conn_request)

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Drop connection.

        1. drop engine connection (zmq connection)
        2. TODO(JimyMa) drop RDMA Connection.
        """
        return self._collective_rpc('p2p_drop_connect', drop_conn_request)

    def create_instance(self, cuda_stream_id=0):
        """Return a new :class:`MPEngineInstance`; ``cuda_stream_id`` is accepted but unused."""
        return MPEngineInstance(self)


class MPEngineInstance(EngineInstanceBase):
    """MP Engine Instance.

    Request handle for an :class:`MPEngine`. It shares the engine's
    ``session_states`` mapping, so all instances see the same sessions.
    """

    def __init__(self, engine: MPEngine):
        self.engine = engine
        self.session_states = engine.session_states

    async def async_end(self, session_id: int):
        """End the given session.

        Returns ``ResponseType.SESSION_NOT_EXIST`` for an unknown session. Otherwise
        waits until the session's stream request was issued, ends it remotely,
        forgets the local state and returns the rpc result.
        """
        # `in` does not trigger the defaultdict factory, so unknown ids stay unknown
        if session_id not in self.session_states:
            logger.warning(f'Session {session_id} not found when end session.')
            return ResponseType.SESSION_NOT_EXIST
        await self.session_states[session_id].is_exists.wait()
        ret = await self.engine._collective_rpc_async('instance_async_end', session_id)
        # tolerate a concurrent `async_end` for the same session having removed it already
        self.session_states.pop(session_id, None)
        return ret

    async def async_cancel(self, session_id: int):
        """Stop current streaming inference.

        Returns ``ResponseType.SESSION_NOT_EXIST`` for an unknown session.
        """
        if session_id not in self.session_states:
            logger.warning(f'Session {session_id} not found when cancel session.')
            return ResponseType.SESSION_NOT_EXIST
        await self.session_states[session_id].is_exists.wait()
        return await self.engine._collective_rpc_async('instance_async_cancel', session_id)

    async def async_stream_infer(self, session_id: int, *args, **kwargs):
        """Send stream inference request and yield its results."""
        # indexing the defaultdict registers the session locally
        state = self.session_states[session_id]
        kwargs['session_id'] = session_id
        kwargs['notify_add_msg'] = True
        generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs)
        # session should have been added
        # NOTE(review): if the transport is an async generator function, calling it
        # sends nothing until the first iteration below, so `is_exists` may be set
        # before the request reaches the engine -- confirm against the subclasses.
        state.is_exists.set()

        async for result in generator:
            yield result


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/base_worker.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, List, Optional

from lmdeploy.messages import EngineOutput
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
                                                   DistServeInitRequest)
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

if TYPE_CHECKING:
    from lmdeploy.pytorch.engine.engine import Engine


class EngineInstancePool:
    """Engine Instance Pool.

    Lazily builds a fixed-size queue of engine instances. Each public call
    borrows one instance for its duration and returns it afterwards.
    """

    def __init__(self, engine):
        from lmdeploy.pytorch.engine import Engine
        self.engine: Engine = engine
        # Twice ``max_batch_size`` instances; with fewer, a sequence may not be
        # stoppable in time.
        self.num_instance = engine.engine_config.max_batch_size * 2
        self.pool = None

    def create_instance_pool(self, num_instance: int):
        """Build an ``asyncio.Queue`` holding ``num_instance`` fresh instances."""
        queue = asyncio.Queue(maxsize=num_instance)
        for fresh in (self.engine.create_instance() for _ in range(num_instance)):
            queue.put_nowait(fresh)
        return queue

    @asynccontextmanager
    async def instance(self):
        """Borrow an instance for the ``async with`` body, waiting if all are in use."""
        # the pool is created on first use, i.e. inside the running event loop
        if self.pool is None:
            self.pool = self.create_instance_pool(self.num_instance)
        borrowed = await self.pool.get()
        try:
            yield borrowed
        finally:
            self.pool.put_nowait(borrowed)

    async def async_end(self, session_id: int):
        """End ``session_id`` using a borrowed instance."""
        async with self.instance() as inst:
            return await inst.async_end(session_id)

    async def async_cancel(self, session_id: int):
        """Cancel the streaming inference of ``session_id`` using a borrowed instance."""
        async with self.instance() as inst:
            return await inst.async_cancel(session_id)

    async def async_stream_infer(self, *args, **kwargs):
        """Stream inference results; the instance is held until the stream ends."""
        async with self.instance() as inst:
            async for output in inst.async_stream_infer(*args, **kwargs):
                yield output


class EngineWorkerBase:
    """Base class for engine worker.

    The constructor starts the engine loop. Every public method forwards to
    the wrapped engine, except the ``instance_*`` methods, which go through an
    ``EngineInstancePool`` built on the same engine. For example, the ZMQ mp
    engine registers these public methods as RPC endpoints.
    """

    def __init__(self, engine: 'Engine'):
        engine.start_loop()
        self.engine = engine
        self.instance_pool = EngineInstancePool(engine)

    def end_session(self, session_id: int):
        """Forward ``end_session(session_id)`` to the engine."""
        engine = self.engine
        return engine.end_session(session_id)

    def get_engine_config(self):
        """Return the engine's config as reported by ``engine.get_engine_config()``."""
        engine = self.engine
        return engine.get_engine_config()

    def get_schedule_metrics(self):
        """Return ``engine.get_schedule_metrics()``."""
        engine = self.engine
        return engine.get_schedule_metrics()

    def p2p_initialize(self, conn_request: DistServeInitRequest):
        """Forward a p2p (rdma link) init request to the engine."""
        engine = self.engine
        return engine.p2p_initialize(conn_request)

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Forward a p2p (rdma) connect request to the engine."""
        engine = self.engine
        return engine.p2p_connect(conn_request)

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Drop connection.

        1. drop engine connection (zmq connection)
        2. TODO(JimyMa) drop RDMA Connection.
        """
        engine = self.engine
        return engine.p2p_drop_connect(drop_conn_request)

    def sleep(self, level: int = 1):
        """Forward ``sleep(level)`` to the engine."""
        engine = self.engine
        return engine.sleep(level)

    def wakeup(self, tags: Optional[List[str]] = None):
        """Forward ``wakeup(tags)`` to the engine."""
        engine = self.engine
        return engine.wakeup(tags)

    def update_params(self, request: Any):
        """Forward a parameter-update ``request`` to the engine."""
        engine = self.engine
        return engine.update_params(request)

    def close(self) -> None:
        """Close the wrapped engine."""
        engine = self.engine
        engine.close()

    async def instance_async_end(self, session_id: int):
        """End ``session_id`` through a pooled engine instance."""
        pool = self.instance_pool
        return await pool.async_end(session_id)

    async def instance_async_cancel(self, session_id: int):
        """Cancel the streaming inference of ``session_id`` through a pooled instance."""
        pool = self.instance_pool
        return await pool.async_cancel(session_id)

    async def instance_async_stream_infer(self, *args, **kwargs):
        """Stream inference results from a pooled engine instance."""
        pool = self.instance_pool
        async for output in pool.async_stream_infer(*args, **kwargs):
            yield output


class EngineOutputGather:
    """Helper class to gather incremental engine output.

    ``add`` accumulates the token ids and logprobs of every ``EngineOutput``
    per stream id. ``pop`` moves what was accumulated onto the output that is
    actually delivered and resets the accumulator. Non-``EngineOutput``
    results pass through untouched.
    """

    def __init__(self):
        # stream_id -> EngineOutput used as an accumulator
        self._output = dict()

    def get(self, stream_id):
        """Return the accumulator of ``stream_id``, creating an empty one if needed."""
        if stream_id not in self._output:
            self._output[stream_id] = EngineOutput(status=None, token_ids=[], logprobs=[])
        return self._output[stream_id]

    def add(self, stream_id, result):
        """Accumulate the tokens/logprobs of ``result`` for ``stream_id``."""
        if not isinstance(result, EngineOutput):
            return
        output = self.get(stream_id)
        output.token_ids.extend(result.token_ids or [])
        output.logprobs.extend(result.logprobs or [])

    def pop(self, stream_id, result):
        """Attach everything gathered for ``stream_id`` to ``result`` and reset.

        ``result`` is modified in place: ``token_ids`` becomes the gathered ids
        (possibly empty) and ``logprobs`` the gathered logprobs, or None when
        there are none.

        If nothing was gathered since the last pop, ``result`` is returned
        with no new tokens instead of raising KeyError. This happens when the
        same final output is handed over again together with the stop signal.
        """
        if not isinstance(result, EngineOutput):
            return result
        output = self._output.pop(stream_id, None)
        token_ids = output.token_ids if output is not None else []
        logprobs = output.logprobs if output is not None else []
        result.token_ids = token_ids or []
        result.logprobs = logprobs or None
        return result


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/ray_engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Dict

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs
from lmdeploy.utils import get_logger

from .base import MPEngine
from .base_worker import EngineOutputGather, EngineWorkerBase

logger = get_logger('lmdeploy')


class RayEngineWorker(EngineWorkerBase):
    """Engine worker running inside a Ray actor.

    Ray generators would cache every streamed result, so a streaming call is
    driven by a local task per stream instead. The driver creates it with
    ``create_stream_task`` and polls it with ``get_stream_task_result``.
    """

    def __init__(self,
                 model_path: str,
                 engine_config: PytorchEngineConfig = None,
                 log_level: int = 30,
                 **kwargs) -> None:
        """Initialize Ray engine worker."""
        from lmdeploy.pytorch.engine.engine import Engine
        logger.setLevel(log_level)
        # create engine
        if engine_config is not None:
            engine_config.enable_mp_engine = False
        engine = Engine.from_pretrained(model_path, engine_config=engine_config, **kwargs)
        super().__init__(engine)

        self._stream_id = 0
        # stream_id -> [event, (result, stopped) or None, exception or None]
        self._stream_aiter = dict()
        # stream_id -> task; keeps the stream tasks strongly referenced
        self._stream_task = dict()
        self._engine_output_gather = EngineOutputGather()

    async def _stream_task_wrapper(self, stream_id: int, init_event: asyncio.Event, func: str, *args, **kwargs):
        """Drive the streaming method ``func`` and publish its results for ``stream_id``.

        Always ends by publishing ``(last_result, True)``, so the consumer
        learns that the stream stopped even when the generator yielded nothing
        or failed. A failure is stored for ``get_stream_task_result`` to
        re-raise.
        """
        stream_state = self._stream_aiter[stream_id]
        event = stream_state[0]
        # Bind ``result`` up front. If the generator yields nothing or fails
        # before its first item, the ``finally`` below must still publish the
        # stop marker; otherwise it raises UnboundLocalError, never sets the
        # events, and both ``create_stream_task`` and the consumer wait forever.
        result = None
        try:
            method = getattr(self, func)
            generator = method(*args, **kwargs)
            init_event.set()
            async for result in generator:
                self._engine_output_gather.add(stream_id, result)
                stream_state[1] = (result, False)
                event.set()
        except Exception as e:
            logger.exception(f'Stream task {stream_id} ({func}) failed.')
            # hand the error to the consumer instead of losing it in the task
            stream_state[2] = e
        finally:
            stream_state[1] = (result, True)
            event.set()
            init_event.set()

    async def create_stream_task(self, func, *args, **kwargs):
        """Start a task streaming ``func(*args, **kwargs)`` and return its stream id."""
        stream_id = self._stream_id
        self._stream_id += 1
        self._stream_aiter[stream_id] = [asyncio.Event(), None, None]
        init_event = asyncio.Event()
        task = asyncio.get_running_loop().create_task(
            self._stream_task_wrapper(stream_id, init_event, func, *args, **kwargs))
        self._stream_task[stream_id] = task
        await init_event.wait()

        return stream_id

    async def get_stream_task_result(self, stream_id: int):
        """Wait for the next result of a stream task.

        Returns:
            tuple: ``(result, stopped)``. ``result`` carries the outputs
            gathered since the previous call.

        Raises:
            Exception: the error raised by the streamed method, once it stopped.
        """
        assert stream_id in self._stream_aiter, f'Stream id {stream_id} not found.'
        stream_state = self._stream_aiter[stream_id]

        event = stream_state[0]
        await event.wait()
        result, stopped = stream_state[1]
        event.clear()

        if stopped:
            self._stream_aiter.pop(stream_id, None)
            self._stream_task.pop(stream_id, None)
            error = stream_state[2]
            if error is not None:
                # Best-effort release of undelivered gathered output; the error
                # is what the caller needs to see.
                try:
                    self._engine_output_gather.pop(stream_id, result)
                except KeyError:
                    pass
                raise error

        result = self._engine_output_gather.pop(stream_id, result)
        return result, stopped


def _update_runtime_envs(runtime_env: Dict):
    """Merge lmdeploy's environment variables into ``runtime_env['env_vars']``.

    Existing ``env_vars`` entries are kept unless overridden by a variable
    from ``_envs.get_all_envs()``. ``runtime_env`` is modified in place and
    also returned.
    """
    merged: Dict = runtime_env.get('env_vars', {})
    merged.update(_envs.get_all_envs())
    runtime_env['env_vars'] = merged
    return runtime_env


class RayMPEngine(MPEngine):
    """Multi-process engine whose engine runs in a Ray actor (``RayEngineWorker``)
    scheduled on the Ray context's placement group."""

    def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs) -> None:
        """Initialize mp engine."""
        self.ray_ctx = self._init_ray(engine_config)
        placement_group = self.ray_ctx.get_placement_group()
        self.placement_group = placement_group

        self.worker = self._create_worker(model_path, engine_config, log_level=logger.level, **kwargs)
        super().__init__()

    def _init_ray(self, engine_config: PytorchEngineConfig = None):
        """Initialize Ray.

        The Ray world size is ``tp`` when ``dp <= 1`` and 1 otherwise.
        """
        if engine_config is None:
            engine_config = PytorchEngineConfig()

        # ``engine_config`` is never None past this point
        device_type = engine_config.device_type
        dp = engine_config.dp
        world_size = engine_config.tp if dp <= 1 else 1

        ray_ctx = RayContext(world_size, dp=dp, device_type=device_type)
        return ray_ctx

    def _create_worker(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs):
        """Create a Ray worker."""
        bundle_id = 0 if len(_envs.ray_external_pg_bundles) == 0 else _envs.ray_external_pg_bundles[0]
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=self.placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )

        runtime_env = dict()
        _update_runtime_envs(runtime_env)
        # ``engine_config`` may be None (it is forwarded to the worker as is).
        # Use the default config's device type, as ``_init_ray`` does, instead
        # of failing with AttributeError.
        if engine_config is None:
            device_type = PytorchEngineConfig().device_type
        else:
            device_type = engine_config.device_type
        device_str = get_device_str(device_type)
        resource_kwargs = get_resource_kwargs(device_str=device_str, resource_used=0.01)
        worker = ray.remote(
            num_cpus=0,
            **resource_kwargs,
            scheduling_strategy=scheduling_strategy,
            runtime_env=runtime_env,
        )(RayEngineWorker).remote(model_path, engine_config, **kwargs)

        return worker

    def _collective_rpc(self, func, *args, **kwargs):
        """Collective rpc call."""
        method = getattr(self.worker, func)
        return ray.get(method.remote(*args, **kwargs))

    async def _collective_rpc_async(self, func, *args, **kwargs):
        """Collective rpc call."""
        method = getattr(self.worker, func)
        return await method.remote(*args, **kwargs)

    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
        """Collective rpc call."""
        # ray generator would try cache every result, which is too verbose.
        stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)

        stopped = False
        while not stopped:
            result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)
            yield result

    def close(self) -> None:
        """Close mp engine."""
        logger.info('Closing mp engine.')
        self._collective_rpc('close')
        self.ray_ctx.shutdown()

    def start_loop(self) -> None:
        """Start mp engine loop."""


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/zmq_engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import atexit
import signal
from typing import TYPE_CHECKING

import torch.multiprocessing as mp

from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig
from lmdeploy.utils import get_logger

from .base import MPEngine

logger = get_logger('lmdeploy')

if TYPE_CHECKING:
    from lmdeploy.pytorch.engine.engine import Engine


def cancel_async_tasks(loop: asyncio.AbstractEventLoop):
    """Cancel all pending tasks of ``loop``, let them finish, then close ``loop``.

    ``loop`` must not be running. Cancelled tasks are run until they have
    processed their cancellation, so their ``finally``/cleanup code executes,
    before async generators are shut down and the loop is closed.
    """
    pending = [task for task in asyncio.all_tasks(loop=loop) if not task.done()]
    for task in pending:
        task.cancel()
    if pending:
        # Without this step, cancelled tasks never get to run their cleanup and
        # are destroyed while still pending. This mirrors ``asyncio.run``'s shutdown.
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()


class ZMQMPEngine(MPEngine):
    """Multi-process engine that runs the engine in a spawned child process.

    The child serves ``EngineWorkerBase`` methods over an ``AsyncRPCServer``.
    This process talks to it through an ``AsyncRPCClient``.
    """

    def __init__(self,
                 model_path: str,
                 engine_config: PytorchEngineConfig = None,
                 speculative_config: SpeculativeConfig = None,
                 **kwargs) -> None:
        """Initialize mp engine.

        Spawns the engine process, waits for it to publish its RPC port and
        connects an RPC client to that port.
        """
        from .zmq_rpc import AsyncRPCClient
        self.shared_dict = None
        self.port = None
        self.proc = None
        self._start_mp_proc(model_path, engine_config, speculative_config=speculative_config, **kwargs)

        self.rpc_client = AsyncRPCClient(port=self.port)

        super().__init__()
        atexit.register(self.close)

    def _start_mp_proc(
        self,
        model_path: str,
        engine_config: PytorchEngineConfig = None,
        speculative_config: SpeculativeConfig = None,
        **kwargs,
    ):
        """Start mp proc and store the RPC server port it publishes in ``self.port``.

        Raises:
            RuntimeError: if the child process exits before publishing the port.
        """
        logger.debug('Starting engine multi-process.')
        with mp.Manager() as manager:
            self.shared_dict = manager.dict()
            condition = manager.Condition()
            self.mp_ctx = mp.get_context('spawn')
            log_level = logger.level
            target_kwargs = dict(
                model_path=model_path,
                engine_config=engine_config,
                log_level=log_level,
                speculative_config=speculative_config,
            )
            target_kwargs.update(kwargs)
            self.proc = self.mp_ctx.Process(
                target=self._mp_proc,
                args=(self.shared_dict, condition),
                kwargs=target_kwargs,
                name='mp_engine_proc',
            )
            self.proc.start()
            logger.debug('Receiving rpc server port from mp process.')
            with condition:
                # Use a loop rather than a single ``if``: it tolerates spurious
                # wake-ups, and the timed wait lets us notice a child that died
                # before publishing its port instead of blocking forever.
                while 'rpc_server_port' not in self.shared_dict:
                    if not self.proc.is_alive():
                        raise RuntimeError('Engine process exited before reporting its rpc server port '
                                           f'(exitcode={self.proc.exitcode}).')
                    condition.wait(timeout=1.0)
            self.port = self.shared_dict['rpc_server_port']

    @staticmethod
    def _mp_proc(
        shared_dict: dict,
        condition: mp.Condition,
        model_path: str,
        engine_config: PytorchEngineConfig = None,
        log_level: str = 'WARNING',
        speculative_config: SpeculativeConfig = None,
        **kwargs,
    ):
        """Mp process function."""
        from lmdeploy.pytorch.engine import Engine

        from .zmq_rpc import AsyncRPCServer

        logger.setLevel(log_level)

        # create an async rpc server and publish its port before building the
        # engine, so the parent can connect right away
        server = AsyncRPCServer()
        with condition:
            shared_dict['rpc_server_port'] = server.port
            condition.notify()

        # create engine
        if engine_config is not None:
            engine_config.enable_mp_engine = False
        engine = Engine.from_pretrained(
            model_path,
            engine_config=engine_config,
            speculative_config=speculative_config,
            **kwargs,
        )

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            loop.run_until_complete(ZMQMPEngine._mp_proc_async(server, engine))
        except KeyboardInterrupt:
            logger.info('Received KeyboardInterrupt, stopping mp process.')

    @staticmethod
    async def _mp_proc_async(server, engine: 'Engine'):
        """Mp process function."""
        import inspect

        from .base_worker import EngineWorkerBase

        loop = asyncio.get_running_loop()
        current_task = asyncio.current_task()

        async def shutdown(loop, signame):
            logger.info(f'MP process received signal {signame}, stopping server.')
            if current_task is not None:
                current_task.cancel()

        for signame in {'SIGINT', 'SIGTERM'}:
            sig = getattr(signal, signame)
            loop.add_signal_handler(sig, lambda signame=signame: asyncio.create_task(shutdown(loop, signame)))

        worker = EngineWorkerBase(engine)

        # expose every public function of EngineWorkerBase as an rpc method
        for name, value in inspect.getmembers(EngineWorkerBase):
            if not name.startswith('_') and inspect.isfunction(value):
                method = getattr(worker, name)
                server.register_method(name, method)

        try:
            # run server
            await server.run()
        except asyncio.CancelledError:
            logger.info('RPC Server stopping due to cancellation.')
        except Exception as e:
            logger.error(f'RPC Server stopped with exception: {e}')
        finally:
            server.stop()
            engine.close()
            try:
                await engine.wait_tasks()
            except asyncio.CancelledError:
                logger.info('Engine wait_tasks cancelled during shutdown.')
            except Exception as e:
                logger.debug(f'Engine wait_tasks failed during shutdown: {e}')

    def _collective_rpc(self, func, *args, **kwargs):
        """Collective rpc call."""
        return self.rpc_client.call(func, *args, **kwargs)

    async def _collective_rpc_async(self, func, *args, **kwargs):
        """Collective rpc call."""
        return await self.rpc_client.async_call(func, *args, **kwargs)

    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
        """Collective rpc call."""
        async for out in self.rpc_client.async_stream_call(func, *args, **kwargs):
            yield out

    def close(self) -> None:
        """Close mp engine; idempotent (also registered with ``atexit``)."""
        if self.proc is None:
            return
        logger.info('Closing mp engine.')
        self.rpc_client.stop()
        self.proc.terminate()
        self.proc.join(10)
        if not self.proc.is_alive():
            self.proc.close()
        else:
            logger.warning('MP process did not terminate in time, force killing.')
            self.proc.kill()
            # Reap the killed child so it does not linger as a zombie. The wait
            # is bounded in case the process cannot exit right away.
            self.proc.join(5)
        self.proc = None

    def start_loop(self) -> None:
        """Start mp engine loop."""


================================================
FILE: lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import inspect
import pickle
from typing import Callable, Dict
from uuid import uuid4

import zmq
import zmq.asyncio
from zmq.asyncio import Context

from lmdeploy.utils import get_logger

from .base_worker import EngineOutputGather

logger = get_logger('lmdeploy')


def _task_callback(task: asyncio.Task) -> None:
    """Raise exception on finish."""
    task_name = task.get_name()
    try:
        task.result()
    except asyncio.CancelledError:
        logger.debug(f'Task <{task_name}> cancelled.')
    except Exception:
        logger.exception(f'Task <{task_name}> failed')
    finally:
        if not task.done():
            task.cancel()


class AsyncRPCServer:
    """Asyncio RPC server on a ZMQ ROUTER socket bound to a random localhost port.

    Requests are pickled dicts with ``request_id``, ``method``, ``args``,
    ``kwargs`` and optionally ``streaming``. Replies are pickled dicts with
    ``success``, ``request_id`` and either ``result`` or ``error``.
    """

    def __init__(self):
        # Warning: DO NOT expose the rpc server to external networks. Requests
        # are unpickled, so unauthorized access may lead to code execution.
        address = 'tcp://localhost'
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.port = self.socket.bind_to_random_port(address)
        # name -> (func_type, func)
        self.methods: Dict[str, tuple] = {}
        self.running = False

        # streaming: stream_id -> dict(event, result, stopped[, error])
        self.stream_output = dict()
        self._stream_idx = 0
        self._engine_output_gather = EngineOutputGather()

        # strong references to in-flight request tasks
        self.tasks = set()

    def get_port(self):
        """Return the port the server socket is bound to."""
        return self.port

    def _get_next_stream_id(self):
        """Get next stream id."""
        self._stream_idx += 1
        return self._stream_idx

    def register_method(self, name: str, func: Callable):
        """Register ``func`` under ``name`` as 'async', 'async_streaming' or 'default'."""
        if asyncio.iscoroutinefunction(func):
            func_type = 'async'
        elif inspect.isasyncgenfunction(func):
            func_type = 'async_streaming'
        else:
            func_type = 'default'
        self.methods[name] = (func_type, func)

    def send_multipart(self, client_id: bytes, data: Dict):
        """Pickle ``data`` and send it to ``client_id``; send errors are logged, not raised."""
        try:
            self.socket.send_multipart([client_id, pickle.dumps(data)])
        except zmq.ZMQError as e:
            logger.error(f'Failed to send message to client[{client_id}]: {e}')

    def call_method_default(self, client_id, method: Callable, request: Dict):
        """Call a plain method synchronously and reply with its result or error."""
        request_id = request.get('request_id')
        args = request.get('args', [])
        kwargs = request.get('kwargs', {})
        try:
            result = method(*args, **kwargs)
            response = dict(success=True, request_id=request_id, result=result)
        except Exception as e:
            response = dict(success=False, request_id=request_id, error=str(e))
        self.send_multipart(client_id, response)

    async def _method_async_task(self, client_id, request_id, method: Callable, args: tuple, kwargs: Dict):
        """Call method in a task."""
        try:
            result = await method(*args, **kwargs)
            response = dict(success=True, request_id=request_id, result=result)
        except Exception as e:
            response = dict(success=False, request_id=request_id, error=str(e))
        self.send_multipart(client_id, response)

    async def _method_async_streaming_task(self, stream_id: int, request_id: int, client_id: int, method: Callable,
                                           args: tuple, kwargs: Dict):
        """Run a streaming (async generator) method for ``stream_id``.

        Replies to the caller with ``stream_id`` right away, then publishes
        each produced result in ``self.stream_output[stream_id]`` for
        :meth:`get_stream_output` to pick up.
        """
        stream_out = dict(
            event=asyncio.Event(),
            result=None,
            stopped=False,
        )
        self.stream_output[stream_id] = stream_out
        self.send_multipart(client_id, dict(success=True, request_id=request_id, result=stream_id))
        try:
            async for result in method(*args, **kwargs):
                self._engine_output_gather.add(stream_id, result)
                stream_out['result'] = result
                stream_out['event'].set()
        except Exception as e:
            stream_out['error'] = e
        finally:
            stream_out['stopped'] = True
            # Always wake the consumer so it observes ``stopped``. Otherwise a
            # generator that finishes after its last result was consumed would
            # leave the consumer waiting forever.
            stream_out['event'].set()

    async def get_stream_output(self, stream_id: int):
        """Wait for the next output of stream ``stream_id``.

        Returns:
            tuple: ``(result, stopped)``. ``result`` carries the outputs
            gathered since the previous call.

        Raises:
            ValueError: unknown ``stream_id``.
            Exception: the error raised by the streamed method, once it stopped.
        """
        if stream_id not in self.stream_output:
            raise ValueError(f'Stream ID {stream_id} not found')
        stream_out = self.stream_output[stream_id]
        event = stream_out['event']
        await event.wait()
        event.clear()
        result = stream_out['result']
        stopped = stream_out['stopped']
        if stopped:
            # forget the stream first, so nothing raised below can leak it
            self.stream_output.pop(stream_id, None)
        if 'error' in stream_out:
            # Best-effort release of undelivered gathered output; the error is
            # what the caller needs to see.
            try:
                self._engine_output_gather.pop(stream_id, result)
            except KeyError:
                pass
            raise stream_out['error']
        result = self._engine_output_gather.pop(stream_id, result)
        return result, stopped

    async def call_method_async(self, client_id, method: Callable, request: Dict):
        """Schedule an async (or streaming) method call as a background task."""
        request_id = request.get('request_id')
        method_name = request.get('method')
        args = request.get('args', [])
        kwargs = request.get('kwargs', {})
        if request.get('streaming', False):
            # if method is a streaming method, use a different task
            stream_id = self._get_next_stream_id()
            coro = self._method_async_streaming_task(stream_id, request_id, client_id, method, args, kwargs)
        else:
            coro = self._method_async_task(client_id, request_id, method, args, kwargs)
        task = asyncio.get_running_loop().create_task(coro, name=f'{method_name}_{client_id}')
        # keep a strong reference until the task is done
        self.tasks.add(task)
        task.add_done_callback(self.tasks.discard)

    async def call_and_response(self):
        """Receive one request and dispatch it."""
        # receive message: [client_id, request_data]
        client_id, request_data = self.socket.recv_multipart()
        request = pickle.loads(request_data)

        method_name = request.get('method')
        logger.debug(f'call method: {method_name}')
        if method_name not in self.methods:
            request_id = request.get('request_id')
            response = dict(success=False, request_id=request_id, error=f'Method {method_name} not found')
            self.send_multipart(client_id, response)
        else:
            method_type, method = self.methods[method_name]
            if method_type in ('async', 'async_streaming'):
                await self.call_method_async(client_id, method, request)
            else:
                self.call_method_default(client_id, method, request)

    async def run(self):
        """Serve requests until :meth:`stop` is called, then close the socket."""
        logger.info('Starting AsyncRPCServer...')
        self.running = True
        poller = zmq.asyncio.Poller()
        poller.register(self.socket, zmq.POLLIN)

        self.register_method('_asyncrpcserver_get_stream_output', self.get_stream_output)
        try:
            events = await poller.poll(timeout=10)
            while self.running:
                while self.socket in dict(events):
                    await self.call_and_response()
                    events = await poller.poll(timeout=0)
                events = await poller.poll(timeout=10)

        except zmq.ZMQError:
            logger.exception('ZMQRPCServer error')
        except Exception:
            logger.exception('AsyncRPCServer error')
        finally:
            logger.info('Stopping AsyncRPCServer...')
            self.socket.close()
            self.context.term()
            self.running = False

    def stop(self):
        """Stop serving and cancel in-flight request tasks."""
        self.running = False
        # iterate over a snapshot; done-callbacks discard from ``self.tasks``
        for task in list(self.tasks):
            task.cancel()


class AsyncRPCClient:
    """Client for ``AsyncRPCServer``.

    Holds two DEALER sockets to ``tcp://localhost:<port>``: a blocking one
    for :meth:`call`, and an asyncio one for :meth:`async_call` /
    :meth:`async_stream_call`. Replies on the asyncio socket are dispatched
    by the background :meth:`listen` task to the futures in ``self.pending``.
    """

    def __init__(self, port: int = 5555):
        logger.info(f'Connecting to AsyncRPCServer on port {port}...')
        address = f'tcp://localhost:{port}'

        socket_type = zmq.DEALER

        # sync socket
        self.sync_ctx = zmq.Context()
        self.sync_socket = self.sync_ctx.socket(socket_type)
        self.sync_socket.connect(address)
        self.sync_poller = zmq.Poller()
        self.sync_poller.register(self.sync_socket, zmq.POLLIN)

        # async socket
        # Use a private context rather than the process-wide
        # ``Context.instance()``. ``close_sockets`` terminates this context, and
        # terminating a shared one would block on, or break, sockets owned by
        # other users of it.
        self.async_ctx = Context()
        self.async_socket = self.async_ctx.socket(socket_type)
        self.async_socket.connect(address)

        # request_id -> asyncio.Future awaiting the reply
        self.pending = {}
        self._listen_task = None
        self.running = False

    def _set_reply_default(self, request_id: int, reply: Dict):
        """Resolve the pending future of ``request_id`` with ``reply``."""
        logger.debug(f'recv reply request_id: {request_id}')
        future = self.pending.pop(request_id, None)
        if future is None:
            # Not (or no longer) tracked. Dropping the reply keeps ``listen``
            # alive instead of letting it die on a KeyError.
            logger.debug(f'No pending request for request_id: {request_id}, drop reply.')
            return
        try:
            if reply['success']:
                future.set_result(reply['result'])
            else:
                future.set_exception(Exception(reply['error']))
        except Exception as e:
            # e.g. the awaiting caller was cancelled and the future is done
            logger.debug(f'Set future failed with exception: {e}')

    def _set_reply(self, reply: Dict):
        request_id = reply['request_id']
        self._set_reply_default(request_id, reply)

    def _poll_recv(self, timeout: float = 3):
        """Poll and receive message."""
        # socket.recv would block the process, use poll to avoid hanging
        while True:
            sockets = dict(self.sync_poller.poll(timeout=timeout * 1000))
            if self.sync_socket in sockets:
                return self.sync_socket.recv()

    def _try_start_listen(self):
        """Try to start listening on async socket."""
        if self._listen_task is None or self._listen_task.done():
            logger.debug('Starting async listen task...')
            self._listen_task = asyncio.create_task(self.listen(), name='AsyncRPCClient.listen')
            self._listen_task.add_done_callback(_task_callback)

    def call(self, method, *args, **kwargs):
        """Blocking call of ``method`` on the server; raises on a failure reply."""
        request_id = str(uuid4())
        logger.debug(f'call method: {method}, request_id: {request_id}')
        data = pickle.dumps(dict(request_id=request_id, method=method, args=args, kwargs=kwargs))
        self.sync_socket.send(data)

        reply = self._poll_recv()
        reply = pickle.loads(reply)
        while reply['request_id'] != request_id:
            self._set_reply(reply)
            reply = self._poll_recv()
            reply = pickle.loads(reply)

        logger.debug(f'recv reply request_id: {request_id}')
        if reply['success']:
            return reply['result']
        else:
            raise Exception(reply['error'])

    async def _async_call_impl(self, method, streaming, *args, **kwargs):
        """Send one request on the async socket and await its reply."""
        self._try_start_listen()
        request_id = str(uuid4())
        logger.debug(f'call method: {method}, request_id: {request_id}')
        # Serialize before registering the future, so an unpicklable argument
        # cannot leave a stale entry in ``self.pending``.
        data = pickle.dumps(dict(request_id=request_id, method=method, args=args, kwargs=kwargs, streaming=streaming))
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await self.async_socket.send(data)
        except BaseException:
            self.pending.pop(request_id, None)
            raise

        return await future

    async def async_call(self, method, *args, **kwargs):
        """Async call."""
        return await self._async_call_impl(method, False, *args, **kwargs)

    async def async_stream_call(self, method, *args, **kwargs):
        """Streaming call: yield outputs until the server reports the stream stopped."""
        stream_id = await self._async_call_impl(method, True, *args, **kwargs)

        stopped = False
        while not stopped:
            output, stopped = await self.async_call('_asyncrpcserver_get_stream_output', stream_id)
            yield output

    def _fail_pending(self, reason: str):
        """Fail every outstanding async call so that its awaiter does not hang."""
        pending, self.pending = self.pending, {}
        for future in pending.values():
            if not future.done():
                future.set_exception(RuntimeError(reason))

    async def listen(self):
        """Receive replies on the async socket and resolve pending futures."""
        self._listen_task = asyncio.current_task()
        self.running = True
        try:
            while self.running:
                reply = await self.async_socket.recv()
                reply = pickle.loads(reply)
                self._set_reply(reply)
        except zmq.ZMQError:
            logger.exception('AsyncRPCClient listen error')
        finally:
            self.running = False
            # nobody will receive replies any more: wake up every awaiter
            self._fail_pending('AsyncRPCClient stopped listening for replies.')
            self.close_sockets()

    def stop(self):
        """Stop the client."""
        self.running = False
        if self._listen_task is not None:
            self._listen_task.cancel()
        self.close_sockets()

    def close_sockets(self):
        """Close sockets and terminate contexts (safe to call repeatedly).

        ``linger=0`` discards unsent messages, so ``term()`` cannot block on a
        peer that is already gone.
        """
        self.async_socket.close(linger=0)
        self.sync_socket.close(linger=0)
        self.async_ctx.term()
        self.sync_ctx.term()


================================================
FILE: lmdeploy/pytorch/engine/request.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import enum
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Coroutine, Dict, List

from lmdeploy.messages import RequestMetrics, ResponseType
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


class RequestType(enum.Enum):
    """Request type.

    Kinds of requests a ``RequestSender`` can submit. The values are written
    out explicitly and equal what ``enum.auto()`` would assign: 1, 2, ... in
    declaration order.
    """

    ADD_SESSION = 1
    ADD_MESSAGE = 2
    STOP_SESSION = 3
    END_SESSION = 4
    STOP_ENGINE = 5
    RESUME_ENGINE = 6


@dataclass
class Response:
    """Response paired with a ``Request``.

    ``RequestSender._gather_request`` creates one per request, initialised as
    ``ResponseType.INTERNAL_ENGINE_ERROR`` with its own ``asyncio.Event``.
    Presumably the engine side fills it in and sets ``event``; confirm in
    ``RequestManager``.
    """

    type: ResponseType  # outcome status of the request
    sender_id: int  # id of the RequestSender that issued the request
    event: asyncio.Event  # per-response event (presumably set when the response is ready - confirm)
    data: Any = None  # response payload, if any
    err_msg: str = ''  # error text; note _gather_request initialises it to None
    is_done: bool = False  # presumably True once the request has finished - TODO confirm
    req_metrics: RequestMetrics = None  # optional request metrics; may stay None


@dataclass
class Request:
    """A request for the engine loop.

    ``RequestSender`` puts batches of these on the manager's request queue.
    """

    type: RequestType  # what the request asks the engine to do
    sender_id: int  # id of the RequestSender that issued it
    data: Any = None  # request payload; its shape depends on `type`
    resp: Response = None  # response object paired with this request


# a batch of requests, as enqueued by RequestSender._req_put
ReqList = List[Request]


def _run_until_complete(future: Awaitable):
    """Run untile complete."""
    try:
        event_loop = asyncio.get_event_loop()
    except Exception:
        logger.warning('Can not found event loop in current thread.'
                       ' Create a new event loop.')
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
    return event_loop.run_until_complete(future)


@dataclass
class RequestSender:
    """Per-client handle for submitting requests to a :class:`RequestManager`.

    Every submission wraps its payload in a :class:`Request` together with a
    placeholder :class:`Response`. The placeholder starts with type
    ``INTERNAL_ENGINE_ERROR`` and is signalled through its ``event`` once the
    manager has answered.

    Args:
        sender_id (int): The id of the sender
        manager (RequestManager): The manager owning the request queue.
    """
    sender_id: int
    manager: 'RequestManager'
    resp_dict: Dict[int, List[Response]] = field(default_factory=dict)

    @classmethod
    def new(cls, sender_id: int, manager: 'RequestManager'):
        """Build a sender bound to ``manager``."""
        return cls(sender_id=sender_id, manager=manager)

    @property
    def req_que(self):
        """The manager's shared request queue."""
        return self.manager.requests

    @property
    def event_loop(self):
        """Event loop of the manager's main-loop task (``None`` before it exists)."""
        return self.manager.event_loop

    def is_loop_alive(self):
        """Whether the manager's engine main loop is still running."""
        return self.manager.is_loop_alive()

    def run_until_complete(self, future: Awaitable):
        """Block until ``future`` finishes, delegating to the manager."""
        return self.manager.run_until_complete(future)

    def _req_put(self, reqs: Any):
        """Enqueue ``reqs`` on the request queue without waiting."""
        self.req_que.put_nowait(reqs)

    def _gather_request(self, req_types: List[RequestType], data: List[Any]):
        """Pair each (type, payload) with a fresh placeholder response.

        Lazily creates the manager's loop task on first use.

        Returns:
            ``(responses, requests)``, two lists in input order.
        """
        manager = self.manager
        if manager._loop_task is None:
            manager.create_loop_task()
        assert len(req_types) == len(data)

        sid = self.sender_id
        resps: List[Response] = []
        reqs: List[Request] = []
        for rtype, rdata in zip(req_types, data):
            placeholder = Response(type=ResponseType.INTERNAL_ENGINE_ERROR,
                                   sender_id=sid,
                                   event=asyncio.Event(),
                                   data=None,
                                   err_msg=None)
            resps.append(placeholder)
            reqs.append(Request(type=rtype, sender_id=sid, data=rdata, resp=placeholder))
        return resps, reqs

    def batched_send_async(self, req_types: List[RequestType], data: List[Any]):
        """Queue a batch of requests; return their responses without waiting."""
        resps, reqs = self._gather_request(req_types, data)
        self._req_put(reqs)
        return resps

    def send_async(self, req_type: RequestType, data: Any):
        """Queue one request; return its response without waiting."""
        resps = self.batched_send_async(req_types=[req_type], data=[data])
        return resps[0]

    async def async_recv(self, resp: Response, wait_main: bool = False) -> Response:
        """Wait until ``resp`` is signalled.

        Polls once per second. If the engine main loop is no longer alive,
        it marks ``resp`` as ``ENGINE_STOP_ERROR`` and gives up. The event is
        cleared before returning.
        """
        if wait_main:
            await self.manager.prepare_send()
        event = resp.event
        while not event.is_set():
            try:
                await asyncio.wait_for(event.wait(), 1)
            except asyncio.TimeoutError:
                if not self.is_loop_alive():
                    logger.debug('Engine main loop failed.')
                    resp.type = ResponseType.ENGINE_STOP_ERROR
                    break
        event.clear()
        return resp

    def recv(self, resp: Response) -> Response:
        """Blocking variant of :meth:`async_recv`."""
        return self.run_until_complete(self.async_recv(resp))

    async def async_send(self, req_type: RequestType, data: Any):
        """Queue one request and await its response."""
        return await self.async_recv(self.send_async(req_type, data))

    def send(self, req_type: RequestType, data: Any) -> Response:
        """Queue one request and block until its response arrives."""
        return self.recv(self.send_async(req_type, data))


class RequestManager:
    """Request manager.

    Responsibilities:

    * owns the ``asyncio.Queue`` that :class:`RequestSender` objects push
      requests onto;
    * owns the engine main-loop task created by :meth:`create_loop_task`;
    * owns an auxiliary ``SenderWaitLoop`` task that paces senders calling
      :meth:`prepare_send`.

    :meth:`step` drains the queue and dispatches the requests, grouped by
    type and in ``request_priority`` order, to the callbacks registered with
    :meth:`bind_func`.
    """

    def __init__(self):
        self.senders: Dict[int, RequestSender] = dict()
        self.callbacks: Dict[RequestType, Callable] = dict()
        # NOTE(review): RESUME_ENGINE is not in the default priority list, so
        # step() ignores such requests unless set_request_priority() adds it.
        # Confirm this is intended.
        self.request_priority: List[RequestType] = [
            RequestType.STOP_ENGINE, RequestType.ADD_SESSION, RequestType.STOP_SESSION, RequestType.END_SESSION,
            RequestType.ADD_MESSAGE
        ]
        self.requests: asyncio.Queue = None
        self._loop_task: asyncio.Future = None
        self._loop_coro: Callable = None
        self._next_sender_id = 0

        # sender speed limiter
        self._condition: asyncio.Condition = None
        self._sender_wait_task: asyncio.Task = None
        self._send_count = 0
        self._send_event = None

    async def prepare_send(self):
        """Wait until ``sender_wait_loop`` hands this caller a send slot.

        Returns immediately when the pacing loop is not running.
        """
        condition = self._condition
        if condition is None:
            return

        self._send_count += 1
        self._send_event.set()
        async with condition:
            await condition.wait()
        self._send_count -= 1
        # sender_wait_loop resets _send_event to None when it exits, after
        # waking every waiter. Waiters resume only after that reset, so the
        # event may already be gone here.
        send_event = self._send_event
        if self._send_count == 0 and send_event is not None:
            send_event.clear()

    async def sender_wait_loop(self):
        """Pace senders.

        While at least one sender is waiting, notify one waiter every 0.1 ms.
        On exit, wake all remaining waiters and reset the pacing state.
        """
        self._condition = asyncio.Condition()
        self._send_count = 0
        self._send_event = asyncio.Event()

        try:
            while True:
                await self._send_event.wait()
                # notify one sender to control send speed
                async with self._condition:
                    self._condition.notify()
                await asyncio.sleep(0.0001)
        finally:
            # notify all senders to exit
            async with self._condition:
                self._condition.notify_all()
            self._condition = None
            self._send_event = None

    def create_loop_task(self):
        """Create the main-loop task, the sender pacing task and the request queue.

        All of them are created on the current event loop. If the main-loop
        task already exists, it is returned unchanged.
        """
        if self._loop_task is not None:
            logger.debug('loop task has been created.')
            return self._loop_task
        logger.debug('creating engine loop task.')
        event_loop = asyncio.get_event_loop()
        assert self._loop_coro is not None, ('Please set loop task with manager.start_loop')
        loop_unshielded = event_loop.create_task(self._loop_coro(), name='EngineMainLoop')
        self._loop_task = loop_unshielded
        self._sender_wait_task = event_loop.create_task(self.sender_wait_loop(), name='SenderWaitLoop')
        self.requests = asyncio.Queue()
        return self._loop_task

    async def wait_tasks(self):
        """Wait for the main-loop task to finish, then stop the pacing task.

        Raises:
            asyncio.CancelledError: if the main-loop task was cancelled, or if
                this coroutine itself is cancelled.
        """
        if self._loop_task is None:
            return

        try:
            await self._loop_task
        except asyncio.CancelledError:
            logger.info('Engine main loop task has been cancelled.')
            raise
        finally:
            sender_task = self._sender_wait_task
            if sender_task is not None:
                sender_task.cancel()
                # CancelledError derives from BaseException, so awaiting the
                # cancelled task directly would escape an `except Exception`
                # handler and reach the caller. asyncio.wait only returns once
                # the task is done and never re-raises the task's own outcome.
                await asyncio.wait({sender_task})
                if sender_task.cancelled():
                    logger.debug('Sender wait task has been cancelled.')
                elif sender_task.exception() is not None:
                    logger.debug(f'Sender wait task exited with error: {sender_task.exception()!r}')

    @property
    def event_loop(self):
        """Event loop bound to the main-loop task, or ``None`` before it is created."""
        if self._loop_task is None:
            return None
        else:
            return self._loop_task.get_loop()

    def set_main_loop_func(self, loop: Callable[[Coroutine], asyncio.Task]):
        """Set the coroutine function that create_loop_task() runs as the main loop."""
        self._loop_coro = loop

    def stop_loop(self):
        """Cancel the main-loop task (if alive) and the sender pacing task."""
        if self.is_loop_alive():
            self._loop_task.cancel()
        self._loop_task = None
        if self._sender_wait_task is not None:
            self._sender_wait_task.cancel()
            self._sender_wait_task = None

    def is_loop_alive(self):
        """Check if main loop is alive.

        Returns False when the task was never created, when it is bound to a
        different event loop than the current one, or when it has finished.
        """

        if self._loop_task is None:
            logger.debug('loop task has not been created.')
            return False
        if self._loop_task.get_loop() != asyncio.get_event_loop():
            logger.warning('Current event loop is different from'
                           ' the one bound to loop task!')
            return False
        return not self._loop_task.done()

    def build_sender(self):
        """Create a new sender with the next free id and register it."""
        sender_id = self._next_sender_id
        self._next_sender_id += 1
        new_sender = RequestSender.new(sender_id, self)
        self.senders[sender_id] = new_sender
        return new_sender

    def has_requests(self):
        """Has unprocessed request."""
        if self.requests is None:
            return False
        return not self.requests.empty()

    async def get_all_requests(self) -> Dict[RequestType, List[Request]]:
        """Drain the request queue.

        If the queue is empty, waits for one item. It then takes everything
        queued at that point without further waiting. A queue item is either a
        single :class:`Request` or a list of them.

        Returns:
            A dict with an entry for every :class:`RequestType`, each holding
            the (possibly empty) list of that type's requests in queue order.
        """
        reqs: ReqList = []

        def _collect(elem):
            """Append one queue item (a Request or a list of them) to ``reqs``."""
            if isinstance(elem, Request):
                elem = [elem]
            reqs.extend(elem)

        num_reqs = self.requests.qsize()
        if num_reqs == 0:
            _collect(await self.requests.get())
            num_reqs = self.requests.qsize()

        for _ in range(num_reqs):
            _collect(self.requests.get_nowait())

        # gather requests
        reqs_by_type: Dict[RequestType, List[Request]] = {t: [] for t in RequestType}
        for req in reqs:
            reqs_by_type[req.type].append(req)
        return reqs_by_type

    def bind_func(self, req_type: RequestType, callback: Callable):
        """Bind handler for given request type."""
        self.callbacks[req_type] = callback

    def set_request_priority(self, priority: List[RequestType]):
        """Set the order in which step() processes request types.

        Types not listed are not processed.
        """
        self.request_priority = priority

    def response(self, resp: Response):
        """Signal ``resp`` to the sender waiting on it."""
        resp.event.set()

    def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
        """Pass ``reqs`` to the callback bound to ``req_type``.

        If no callback is bound, each request is answered with
        ``HANDLER_NOT_EXIST``.
        """
        # get callback
        func = self.callbacks.get(req_type, None)
        if func is not None:
            func(reqs, **kwargs)
        else:
            # TODO: send error message
            for req in reqs:
                resp = req.resp
                resp.type = ResponseType.HANDLER_NOT_EXIST
                resp.err_msg = (f'callback for {req_type}'
                                ' not exists.')
                self.response(resp)

    async def step(self, **kwargs):
        """Handle requests.

        Should only be called in loop task.
        """

        def _log_reqs(req_type: RequestType, reqs: ReqList):
            # req_type is passed explicitly rather than captured from the
            # enclosing loop variable.
            num_reqs = len(reqs)
            if num_reqs == 0:
                return
            if logger.isEnabledFor(logging.DEBUG):
                sender_id = [req.sender_id for req in reqs]
                logger.debug(f'Receive {req_type.name} Request: senders: {sender_id}')
            elif logger.isEnabledFor(logging.INFO):
                logger.info(f'Receive {req_type.name} Request: {num_reqs}')

        reqs_by_type = await self.get_all_requests()

        # handle requests
        for req_type in self.request_priority:
            reqs: ReqList = reqs_by_type.get(req_type, [])
            if not reqs:
                continue

            _log_reqs(req_type, reqs)
            self.process_request(req_type, reqs, **kwargs)

    def run_until_complete(self, future: Awaitable):
        """Run ``future`` to completion on the current thread's event loop."""
        return _run_until_complete(future)


================================================
FILE: lmdeploy/pytorch/envs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import contextlib
import os
from typing import Union


def env_to_bool(
    env_var: str,
    default: bool = False,
    *,
    true_values: Union[set, list, tuple] = ('true', '1', 'yes', 'on'),
    false_values: Union[set, list, tuple] = ('false', '0', 'no', 'off'),
):
    """Read ``env_var`` as a boolean.

    The raw value is lower-cased and stripped before it is compared against
    ``true_values`` and ``false_values``. The defaults are immutable tuples,
    so they cannot be shared-and-mutated between calls, and the error message
    lists them in a stable order.

    Args:
        env_var: Name of the environment variable.
        default: Value returned when the variable is unset.
        true_values: Accepted spellings of True (compared after normalisation).
        false_values: Accepted spellings of False (compared after normalisation).

    Returns:
        bool: The parsed value, or ``default`` when unset.

    Raises:
        ValueError: If the value is set but matches neither collection.
    """
    value = os.getenv(env_var)
    if value is None:
        return default
    value = value.lower().strip()
    if value in true_values:
        return True
    elif value in false_values:
        return False
    else:
        raise ValueError(f"Cannot convert environment variable '{env_var}={value}' to boolean. "
                         f'Allowed true values: {true_values}, false values: {false_values}')


def env_to_int(
    env_var: str,
    default: int = 0,
):
    """Read ``env_var`` as an integer.

    ``default`` is returned when the variable is unset or when ``int()``
    cannot parse its value.
    """
    raw = os.getenv(env_var)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


def env_to_list_int(
    env_var: str,
    default: list[int] = None,
):
    """Read ``env_var`` as a comma-separated list of integers.

    The fallback is ``default``, or a new empty list when ``default`` is
    None. It is returned when the variable is unset or when any item fails
    ``int()``.
    """
    fallback = [] if default is None else default
    raw = os.getenv(env_var)
    if raw is None:
        return fallback
    try:
        return [int(item) for item in raw.split(',')]
    except ValueError:
        return fallback


def env_to_float(
    env_var: str,
    default: float = 0,
):
    """Read ``env_var`` as a float.

    ``default`` is returned when the variable is unset or when ``float()``
    cannot parse its value.
    """
    raw = os.getenv(env_var)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        return default


# Registry of environment variables read (and present) under set_envs():
# maps variable name -> raw string value.
_ENVS = dict()


@contextlib.contextmanager
def set_envs():
    """Temporarily patch ``os.getenv`` to record which variables are read.

    While the context is active, any ``os.getenv`` call whose variable exists
    in ``os.environ`` stores the name and value in ``_ENVS``. This includes
    calls made inside helpers such as ``env_to_bool``. Lookups themselves
    behave exactly like the original ``os.getenv``.

    The original ``os.getenv`` is restored on exit, including when the body
    raises (e.g. ``env_to_bool`` rejecting a malformed value). Without that,
    the patch would leak process-wide.
    """
    _origin_get_env = os.getenv

    def _patched_get_env(
        env_var: str,
        default: Union[str, None] = None,
    ):
        """Patched get_env that records variables present in the environment."""
        if env_var in os.environ:
            _ENVS[env_var] = os.environ[env_var]

        return _origin_get_env(env_var, default)

    os.getenv = _patched_get_env
    try:
        yield
    finally:
        os.getenv = _origin_get_env


# Read all settings once, at import time. Because the reads run under
# set_envs(), every variable present in os.environ is recorded in _ENVS
# (returned by get_all_envs()). This includes reads made inside the env_to_*
# helpers. A malformed boolean makes env_to_bool raise ValueError, which fails
# the import of this module. Statement order determines _ENVS insertion order.
with set_envs():
    # loader
    random_load_weight = env_to_bool('LMDEPLOY_RANDOM_LOAD_WEIGHT', True)

    # profile (ray nsys)
    ray_nsys_enable = env_to_bool('LMDEPLOY_RAY_NSYS_ENABLE', False)
    ray_nsys_output_prefix = os.getenv('LMDEPLOY_RAY_NSYS_OUT_PREFIX', None)

    # ascend
    # NOTE: "visable" (sic) in the attribute name is kept as-is; presumably
    # other modules import it under this spelling.
    ascend_set_rt_visable_devices_by_ray = env_to_bool('ASCEND_SET_RT_VISIBLE_DEVICES_BY_RAY', False)
    ascend_rank_table_file = os.getenv('ASCEND_RANK_TABLE_FILE_PATH')

    # dp
    dp_master_addr = os.getenv('LMDEPLOY_DP_MASTER_ADDR', None)
    dp_master_port = os.getenv('LMDEPLOY_DP_MASTER_PORT', None)

    # executor
    executor_backend = os.getenv('LMDEPLOY_EXECUTOR_BACKEND', None)

    # torch profiler (duration defaults to -1)
    torch_profile_cpu = env_to_bool('LMDEPLOY_PROFILE_CPU', False)
    torch_profile_cuda = env_to_bool('LMDEPLOY_PROFILE_CUDA', False)
    torch_profile_delay = env_to_int('LMDEPLOY_PROFILE_DELAY', 0)
    torch_profile_duration = env_to_int('LMDEPLOY_PROFILE_DURATION', -1)
    torch_profile_output_prefix = os.getenv('LMDEPLOY_PROFILE_OUT_PREFIX', 'lmdeploy_profile_')

    # ray timeline
    ray_timeline_enable = env_to_bool('LMDEPLOY_RAY_TIMELINE_ENABLE', False)
    ray_timeline_output_path = os.getenv('LMDEPLOY_RAY_TIMELINE_OUT_PATH', 'ray_timeline.json')

    # ray external placement group bundles (comma-separated integers)
    # only used when lmdeploy is initialized inside a Ray Actor with pg allocated
    ray_external_pg_bundles = env_to_list_int('LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES', [])

    # dist
    dist_master_addr = os.getenv('LMDEPLOY_DIST_MASTER_ADDR', None)
    dist_master_port = os.getenv('LMDEPLOY_DIST_MASTER_PORT', None)

    # logging
    log_file = os.getenv('LMDEPLOY_LOG_FILE', None)

    # check env
    enable_check_env = env_to_bool('LMDEPLOY_ENABLE_CHECK_ENV', True)

    # dlblas
    # These results are intentionally discarded. The reads exist only so that,
    # when set, the variables are recorded in _ENVS and can be passed on to Ray
    # workers. If Ray is launched from outside, the workers may not otherwise
    # see these environment variables.
    os.getenv('DEEPEP_MAX_TOKENS_PER_RANK', None)
    os.getenv('DEEPEP_ENABLE_MNNVL', None)
    os.getenv('DEEPEP_MODE', 'auto')

    # deepep
    deep_ep_buffer_num_sms = env_to_int('DEEPEP_BUFFER_NUM_SMS', 20)

    # deepgemm (read only so that, when set, they are recorded in _ENVS)
    os.getenv('DG_JIT_DEBUG', '0')
    os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', '0')

    # model agent
    skip_warmup = env_to_bool('LMDEPLOY_SKIP_WARMUP', False)

    # model format
    scale_fmt = os.getenv('LMDEPLOY_SCALE_FMT', None)

    # repetition check
    repetition_window_size = env_to_int('LMDEPLOY_REPETITION_WINDOW_SIZE', 1024)


def get_all_envs():
    """Return the environment variables recorded under ``set_envs()``.

    Maps each variable name that was present in ``os.environ`` when read
    under ``set_envs()`` to its string value. The returned object is the
    module-level registry itself, not a copy.
    """
    recorded = _ENVS
    return recorded


================================================
FILE: lmdeploy/pytorch/kernels/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,
                                  rms_norm_dynamic_quant)

# Public API of this package: the W8A8 Triton kernels imported above from
# ``.w8a8_triton_kernels``.
__all__ = [
    'matmul_kernel_dynamic_quant',
    'per_channel_quant',
    'per_token_quant_int8',
    'rms_norm_dynamic_quant',
]


================================================
FILE: lmdeploy/pytorch/kernels/cuda/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..default.w8a8_kernels import per_channel_quant
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .flashattention import flash_attn_varlen_func
from .flatten_kv_cache import flatten_kv_cache
from .fused_moe import fused_moe
from .multinomial_sampling import multinomial_sampling
from .pagedattention import flash_attn_with_kvcache
from .rms_norm import rms_norm
from .w8a8_fused_moe import fused_moe_w8a8
from .w8a8_triton_kernels import matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant

# Kernel entry points re-exported by this package. Note that
# ``per_channel_quant`` is imported from ``..default.w8a8_kernels`` rather
# than from a module in this directory.
__all__ = [
    'apply_rotary_pos_emb',
    'fused_moe',
    'flash_attn_with_kvcache',
    'fill_kv_cache',
    'multinomial_sampling',
    'rms_norm',
    'matmul_kernel_dynamic_quant',
    'per_channel_quant',
    'per_token_quant_int8',
    'rms_norm_dynamic_quant',
    'flash_attn_varlen_func',
    'flatten_kv_cache',
    'fused_moe_w8a8',
]


================================================
FILE: lmdeploy/pytorch/kernels/cuda/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from packaging import version

from .utils import get_device_props

TRITON_VERSION = version.parse(triton.__version__)

# Exponential used by the SiLU kernels in this file: ``tl.math.exp`` on
# Triton >= 3.0, ``tl.math.fast_expf`` on older releases.
fast_expf = tl.math.exp if TRITON_VERSION >= version.parse('3.0.0') else tl.math.fast_expf


@triton.jit
def _silu_and_mul_kernel(
    gateup_ptr,
    out_ptr,
    N: tl.constexpr,
    M,
    stride_gum: tl.constexpr,
    stride_gun: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Silu and mul kernel.

    For every row m < M computes ``out[m, n] = silu(gate[m, n]) * up[m, n]``.
    ``gate`` is columns [0, N) of the input row and ``up`` is columns [N, 2N).

    Launch grid:
      * axis 0 tiles the N output columns in chunks of BLOCK_SIZE_N;
      * axis 1 programs start at row ``program_id(1)`` and step by
        ``num_programs(1)`` rows, so fewer programs than rows still cover
        every row.
    """
    n_block_id = tl.program_id(0)
    m_id_start = tl.program_id(1)
    m_id_stride = tl.num_programs(1)

    # The up half starts N columns after the gate half within the same row.
    up_ptr = gateup_ptr + N * stride_gun
    offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # N is constexpr, so this branch is resolved at compile time; a mask is
    # only needed when the last column tile is partial.
    if N % BLOCK_SIZE_N == 0:
        mask = None
    else:
        mask = offs_n < N

    gate_ptrs = gateup_ptr + m_id_start * stride_gum + offs_n * stride_gun
    up_ptrs = up_ptr + m_id_start * stride_gum + offs_n * stride_gun
    out_ptrs = out_ptr + m_id_start * stride_om + offs_n * stride_on

    for _ in tl.range(m_id_start, M, m_id_stride):
        gate = tl.load(gate_ptrs, mask=mask)
        up = tl.load(up_ptrs, mask=mask)
        # exp expect fp32
        gate = gate.to(tl.float32)

        # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
        gate = gate / (1 + fast_expf(-gate))
        # cast back to the input element type before multiplying with `up`
        gate = gate.to(gateup_ptr.dtype.element_ty)
        out = gate * up

        tl.store(out_ptrs, out, mask=mask)

        # advance all three pointers by m_id_stride rows
        gate_ptrs += m_id_stride * stride_gum
        up_ptrs += m_id_stride * stride_gum
        out_ptrs += m_id_stride * stride_om


def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None):
    """Fused SiLU-gated multiply: ``out = silu(gate) * up``.

    Args:
        gate_up (Tensor): 2-D tensor ``(M, 2 * N)``. Columns ``[0, N)`` hold
            the gate and columns ``[N, 2N)`` hold the up projection.
        out (Tensor): optional output, expected to be ``(M, N)`` (shape is not
            checked). A new tensor is allocated when omitted.

    Returns:
        Tensor: ``out``.
    """
    assert gate_up.dim() == 2

    num_rows = gate_up.size(0)
    half_dim = gate_up.size(-1) // 2
    if out is None:
        out = gate_up.new_empty((num_rows, half_dim))

    # column tiles are the next power of two of N, capped at 512
    block_n = min(triton.next_power_of_2(half_dim), 512)
    num_warps = 4
    num_stages = 1

    # Cap the row programs at the number of num_warps-sized CTAs the device
    # can hold by warp count; each program then strides over the remaining rows.
    props = get_device_props(gate_up.device.index)
    max_row_programs = props['multi_processor_count'] * props['warps_per_sm'] // num_warps
    grid = (triton.cdiv(half_dim, block_n), min(num_rows, max_row_programs))
    assert grid[0] < 65536 and grid[1] < 65536

    _silu_and_mul_kernel[grid](gate_up,
                               out,
                               half_dim,
                               num_rows,
                               stride_gum=gate_up.stride(0),
                               stride_gun=gate_up.stride(1),
                               stride_om=out.stride(0),
                               stride_on=out.stride(1),
                               BLOCK_SIZE_N=block_n,
                               num_warps=num_warps,
                               num_stages=num_stages)
    return out


@triton.jit
def _silu_and_mul_moe_ep_kernel(
    gateup_ptr,
    out_ptr,
    mask_ptr,
    N: tl.constexpr,
    M: tl.constexpr,
    stride_gue: tl.constexpr,
    stride_gum: tl.constexpr,
    stride_gun: tl.constexpr,
    stride_oe: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    stride_m: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Silu and mul kernel.

    Expert-parallel variant of ``_silu_and_mul_kernel``. For expert e, only
    rows m < min(mask[e], M) are processed, where mask[e] is read from
    ``mask_ptr``. Output rows at or beyond that count are not written.

    Launch grid:
      * axis 0 tiles the N output columns;
      * axis 1 selects the expert;
      * axis 2 programs start at row ``program_id(2)`` and step by
        ``num_programs(2)`` rows.
    """
    n_block_id = tl.program_id(0)
    e_id = tl.program_id(1)
    m_id_start = tl.program_id(2)
    m_id_stride = tl.num_programs(2)

    offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # N is constexpr: compile-time choice between unmasked and masked tiles.
    if N % BLOCK_SIZE_N == 0:
        mask = None
    else:
        mask = offs_n < N

    # number of valid rows for this expert, clamped to the tensor's row count
    mask_m = tl.load(mask_ptr + e_id * stride_m)
    mask_m = tl.minimum(mask_m, M)
    # this program has no rows to process for this expert
    if mask_m < m_id_start:
        return
    gate_ptrs = gateup_ptr + e_id * stride_gue + m_id_start * stride_gum + offs_n * stride_gun
    # the up half starts N columns after the gate half
    up_ptrs = gate_ptrs + N * stride_gun
    out_ptrs = out_ptr + e_id * stride_oe + m_id_start * stride_om + offs_n * stride_on

    for _ in tl.range(m_id_start, mask_m, m_id_stride):
        gate = tl.load(gate_ptrs, mask=mask)
        up = tl.load(up_ptrs, mask=mask)
        # exp expect fp32
        gate = gate.to(tl.float32)
        # silu(x) = x / (1 + exp(-x)), then cast back to the input element type
        gate = gate / (1 + fast_expf(-gate))
        gate = gate.to(gateup_ptr.dtype.element_ty)
        out = gate * up

        tl.store(out_ptrs, out, mask=mask)

        gate_ptrs += m_id_stride * stride_gum
        up_ptrs += m_id_stride * stride_gum
        out_ptrs += m_id_stride * stride_om


def silu_and_mul_moe_ep(gate_up: torch.Tensor, mask_m: torch.Tensor, out: torch.Tensor = None):
    """SiLU-gated multiply for MoE activations with expert parallelism.

    Args:
        gate_up (Tensor): ``(num_experts, M, 2 * N)``. In each row, columns
            ``[0, N)`` hold the gate and ``[N, 2N)`` the up projection.
        mask_m (Tensor): 1-D, one entry per expert. The kernel processes only
            rows below ``min(mask_m[e], M)`` for expert ``e``; other rows of
            ``out`` are not written.
        out (Tensor): optional ``(num_experts, M, N)`` output. It is allocated
            uninitialized when omitted.

    Returns:
        Tensor: ``out``.
    """
    assert gate_up.dim() == 3
    assert mask_m.dim() == 1
    assert mask_m.size(0) == gate_up.size(0)

    mask_stride = mask_m.stride(0)
    assert gate_up.size(0) % mask_stride == 0

    num_experts = gate_up.size(0)
    num_rows = gate_up.size(1)
    half_dim = gate_up.size(-1) // 2
    if out is None:
        out = gate_up.new_empty((num_experts, num_rows, half_dim))

    block_n = min(triton.next_power_of_2(half_dim), 512)
    num_warps = 4
    num_stages = 1

    # Spread roughly one device's worth of CTAs over (column tiles x experts);
    # the remainder goes to row programs, which stride over the rows.
    props = get_device_props(gate_up.device.index)
    ctas_per_device = props['multi_processor_count'] * (props['warps_per_sm'] // num_warps)
    n_blocks = triton.cdiv(half_dim, block_n)
    m_programs = min(num_rows, triton.cdiv(ctas_per_device, n_blocks * num_experts))
    grid = (n_blocks, num_experts, m_programs)
    _silu_and_mul_moe_ep_kernel[grid](gate_up,
                                      out,
                                      mask_m,
                                      half_dim,
                                      num_rows,
                                      stride_gue=gate_up.stride(0),
                                      stride_gum=gate_up.stride(1),
                                      stride_gun=gate_up.stride(2),
                                      stride_oe=out.stride(0),
                                      stride_om=out.stride(1),
                                      stride_on=out.stride(2),
                                      stride_m=mask_stride,
                                      BLOCK_SIZE_N=block_n,
                                      num_warps=num_warps,
                                      num_stages=num_stages)
    return out


================================================
FILE: lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def _apply_rotary_impl(x_l, x_h, cos_l, cos_h, sin_l, sin_h):
    """Apply rotary positional embedding implementation.

    ``x_l``/``x_h`` are the first and second halves of the rotated features.
    Returns the rotated ``(low, high)`` halves.
    """
    # x_l, x_h: [BLOCK, BLOCK_N]
    # cos_l, cos_h, sin_l, sin_h: [BLOCK, BLOCK_N]

    # qe_l = q_l * cos_l - q_h * sin_l
    # qe_h = q_h * cos_h + q_l * sin_h

    # triton 3.4 would do fma 3 times to perform the above computation,
    # which causes higher numerical error. So we manually expand the
    # computation to avoid fma.
    # The `+ 0` forces each product to be materialised on its own, so the
    # compiler does not fuse it into the final add/sub. Do not "simplify" it away.
    x_l_new0 = x_l * cos_l + 0
    x_l_new1 = x_h * sin_l + 0
    x_h_new0 = x_h * cos_h + 0
    x_h_new1 = x_l * sin_h + 0
    return x_l_new0 - x_l_new1, x_h_new0 + x_h_new1


@triton.jit(do_not_specialize=('seq_len', ))
def apply_rotary_pos_emb_qk_kernel(
    Q,
    K,
    COS,
    SIN,
    Q_EMB,
    K_EMB,
    seq_len,
    stride_qs: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qd: tl.constexpr,
    stride_ks: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kd: tl.constexpr,
    stride_qes: tl.constexpr,
    stride_qeh: tl.constexpr,
    stride_qed: tl.constexpr,
    stride_kes: tl.constexpr,
    stride_keh: tl.constexpr,
    stride_ked: tl.constexpr,
    half_size: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_QH: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply rotary on key AND query kernel.

    Launch grid:
      * axis 0 enumerates heads. Ids [0, BLOCK_QH) are query heads; the rest
        are key heads, offset by BLOCK_QH.
      * axis 1 enumerates blocks of BLOCK positions.

    COS/SIN are addressed as dense row-major tables with rows of
    ``2 * half_size`` elements, and both use the same offsets.
    Only the first ``2 * half_size`` features of each head are written.
    """
    seq_block_id = tl.program_id(1)
    head_id = tl.program_id(0)

    pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
    pos_mask = pos_offset < seq_len
    # Wrap positions so that loads in the tail block stay in bounds; the
    # out-of-range lanes are discarded by the masked stores below. The
    # multiple_of/max_contiguous calls are compiler hints.
    pos_offset = tl.max_contiguous(tl.multiple_of(pos_offset % seq_len, BLOCK), BLOCK)

    feat_size = half_size * 2
    feat_offset_l = tl.arange(0, BLOCK_N)
    # BLOCK_N is half_size rounded up to a power of two; lanes beyond
    # half_size are wrapped for loading and masked for storing.
    feat_mask = feat_offset_l < half_size
    feat_offset_l = feat_offset_l % half_size
    feat_offset_h = half_size + feat_offset_l
    seq_mask = pos_mask[:, None] & feat_mask[None, :]
    cs_offset_l = pos_offset[:, None] * feat_size + feat_offset_l[None, :]
    cs_offset_h = pos_offset[:, None] * feat_size + feat_offset_h[None, :]
    q_elem_type = Q.dtype.element_ty
    cos_l = tl.load(COS + cs_offset_l).to(q_elem_type)
    cos_h = tl.load(COS + cs_offset_h).to(q_elem_type)
    sin_l = tl.load(SIN + cs_offset_l).to(q_elem_type)
    sin_h = tl.load(SIN + cs_offset_h).to(q_elem_type)

    if head_id < BLOCK_QH:
        # query head
        q_ptr = Q + pos_offset * stride_qs
        qe_ptr = Q_EMB + pos_offset * stride_qes
        ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd
        qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd
        qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed
        qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed
        ql_ptrs += head_id * stride_qh
        qh_ptrs += head_id * stride_qh
        qel_ptrs += head_id * stride_qeh
        qeh_ptrs += head_id * stride_qeh

        # Unmasked loads: positions/features were wrapped above. This assumes
        # Q has at least seq_len rows along its sequence stride. TODO confirm
        # with callers.
        q_l = tl.load(ql_ptrs)
        q_h = tl.load(qh_ptrs)

        qe_l, qe_h = _apply_rotary_impl(q_l, q_h, cos_l, cos_h, sin_l, sin_h)

        tl.store(qel_ptrs, qe_l, mask=seq_mask)
        tl.store(qeh_ptrs, qe_h, mask=seq_mask)
    else:
        # key head: re-base the head index past the query heads
        head_id = head_id - BLOCK_QH
        k_ptr = K + pos_offset * stride_ks
        ke_ptr = K_EMB + pos_offset * stride_kes
        kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd
        kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd
        kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked
        keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked
        kl_ptrs += head_id * stride_kh
        kh_ptrs += head_id * stride_kh
        kel_ptrs += head_id * stride_keh
        keh_ptrs += head_id * stride_keh
        k_l = tl.load(kl_ptrs)
        k_h = tl.load(kh_ptrs)

        ke_l, ke_h = _apply_rotary_impl(k_l, k_h, cos_l, cos_h, sin_l, sin_h)

        tl.store(kel_ptrs, ke_l, mask=seq_mask)
        tl.store(keh_ptrs, ke_h, mask=seq_mask)


def apply_rotary_pos_emb(q: Tensor,
                         k: Tensor,
                         cos: Tensor,
                         sin: Tensor,
                         q_embed: Tensor = None,
                         k_embed: Tensor = None):
    """Apply rotary positional embedding on query and key.

    Rotates the first ``rope_dim = 2 * half_size`` features of every head,
    pairing feature ``i`` with ``i + half_size``. When ``head_dim > rope_dim``
    (partial rope), the remaining features are passed through unchanged.

    Args:
        q (Tensor): Query state, addressed through its last three strides as
            ``(seq_len, num_heads_q, head_dim)``.
        k (Tensor): Key state, same layout as ``q``.
        cos (Tensor): cosine matrix (seq_len, dim).
        sin (Tensor): sine matrix (seq_len, dim).
        q_embed (Tensor): output q, can be same as q
        k_embed (Tensor): output k, can be same as k

    Returns:
        Tuple[Tensor, Tensor]: Embedded query and key.

    Raises:
        ValueError: If ``head_dim < rope_dim``.
    """
    if cos.device != q.device:
        cos = cos.to(device=q.device)
    if sin.device != q.device:
        sin = sin.to(device=q.device)
    # The kernel addresses cos/sin as dense row-major tables (row length
    # 2 * half_size) and uses the same offsets for both. A strided view would
    # be read incorrectly. contiguous() is a no-op when the layout already fits.
    cos = cos.contiguous()
    sin = sin.contiguous()

    seq_len = cos.numel() // cos.size(-1)

    if q.size(-1) == cos.size(-1):
        half_size = q.size(-1) // 2
    elif q.size(-1) > cos.size(-1):
        # only do rope with rope_dim size
        half_size = cos.size(-1) // 2
    else:
        raise ValueError('Not support head_dim < rope_dim, '
                         f'but given head_dim={q.size(-1)} '
                         f'rope_dim={cos.size(-1)}')
    rope_dim = 2 * half_size

    # The kernel writes only the first rope_dim features of each head. For
    # partial rope, a freshly allocated output would leave the pass-through
    # features uninitialized, so start from a copy of the input instead.
    if q_embed is None:
        q_embed = torch.empty_like(q) if q.size(-1) == rope_dim else q.clone()
    if k_embed is None:
        k_embed = torch.empty_like(k) if k.size(-1) == rope_dim else k.clone()

    BLOCK_N = triton.next_power_of_2(half_size)
    num_heads_q = q.size(-2)
    num_heads_k = k.size(-2)
    num_warps = 2
    num_stages = 1

    # compute best BLOCK size: enough positions that each thread issues one
    # 16-byte vector load per half-row
    num_threads = num_warps * 32
    elem_size = q.dtype.itemsize
    elem_per_ldgv4 = 16 // elem_size
    BLOCK = num_threads * elem_per_ldgv4 // BLOCK_N
    BLOCK = max(1, BLOCK)

    grid = (
        num_heads_q + num_heads_k,
        triton.cdiv(seq_len, BLOCK),
    )
    apply_rotary_pos_emb_qk_kernel[grid](q,
                                         k,
                                         cos,
                                         sin,
                                         q_embed,
                                         k_embed,
                                         seq_len=seq_len,
                                         stride_qs=q.stride(-3),
                                         stride_qh=q.stride(-2),
                                         stride_qd=q.stride(-1),
                                         stride_ks=k.stride(-3),
                                         stride_kh=k.stride(-2),
                                         stride_kd=k.stride(-1),
                                         stride_qes=q_embed.stride(-3),
                                         stride_qeh=q_embed.stride(-2),
                                         stride_qed=q_embed.stride(-1),
                                         stride_kes=k_embed.stride(-3),
                                         stride_keh=k_embed.stride(-2),
                                         stride_ked=k_embed.stride(-1),
                                         half_size=half_size,
                                         BLOCK=BLOCK,
                                         BLOCK_QH=num_heads_q,
                                         BLOCK_N=BLOCK_N,
                                         num_warps=num_warps,
                                         num_stages=num_stages)

    return q_embed, k_embed


================================================
FILE: lmdeploy/pytorch/kernels/cuda/awq_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
from triton import language as tl


def get_cuda_autotune_config():
    """Return the ``triton.Config`` autotuning candidates.

    There are two candidates, differing only in ``BLOCK_SIZE_N`` (64, then
    128). Both use ``GROUP_SIZE_M=8``, 3 pipeline stages and 4 warps.
    """
    return [
        triton.Config({'BLOCK_SIZE_N': block_n, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4)
        for block_n in (64, 128)
    ]


@triton.jit
def _dequant_s4_to_f16x2(weight, shift: tl.constexpr, is_top: tl.constexpr):
    """Convert two 4-bit fields of each 32-bit ``weight`` word to fp16.

    Each output pair takes one nibble from the low 16-bit half and one from
    the high 16-bit half of the word (after the optional shift):
      * ``is_top=False``: bits 0-3 and 16-19;
      * ``is_top=True``: bits 4-7 and 20-23;
      * ``shift=True``: the word is first shifted right by 8, selecting the
        next nibble pair.

    The nibbles are treated as unsigned values 0..15. Per the PTX ``mov``
    vector-unpack convention, the first output receives the low 16 bits of
    the packed result.
    """

    # lop3 immediate for (a & b) | c, using the PTX truth-table convention
    # a=0xF0, b=0xCC, c=0xAA.
    immLut: tl.constexpr = (0xf0 & 0xcc) | 0xaa
    BOTTOM_MASK: tl.constexpr = 0x000f000f
    TOP_MASK: tl.constexpr = 0x00f000f0
    # 0x6400 is fp16 1024.0; OR-ing a nibble into its mantissa yields
    # 1024 + nibble (bottom) or 1024 + 16 * nibble (top).
    I4s_TO_F16s_MAGIC_NUM: tl.constexpr = 0x64006400
    FP16_TOP_MAGIC_NUM: tl.constexpr = 0x64006400
    # fp16 0x2c00 == 1/16 and 0xd400 == -64.0, so for the top nibble
    # (1024 + 16 * n) * (1/16) - 64 == n.
    ONE_SIXTEENTH: tl.constexpr = 0x2c002c00
    NEG_64: tl.constexpr = 0xd400d400

    if shift:
        weight = weight >> 8

    if is_top:
        # tmp = (weight & TOP_MASK) | magic; tmp = tmp * 1/16 + (-64) per half
        return tl.inline_asm_elementwise("""{
        .reg .b32 tmp;
        lop3.b32 tmp, $2, $3, $4, $5;
        fma.rn.f16x2 tmp, tmp, $6, $7;
        mov.b32 {$0, $1}, tmp;
    }""",
                                         '=h,=h,r,n,n,n,r,r',
                                         args=[weight, TOP_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, ONE_SIXTEENTH, NEG_64],
                                         dtype=(tl.float16, tl.float16),
                                         is_pure=True,
                                         pack=1)
    else:
        # tmp = (weight & BOTTOM_MASK) | magic; tmp = tmp - 1024.0 per half
        return tl.inline_asm_elementwise("""{
        .reg .b32 tmp;
        lop3.b32 tmp, $2, $3, $4, $5;
        sub.f16x2 tmp, tmp, $6;
        mov.b32 {$0, $1}, tmp;
    }""",
                                         '=h,=h,r,n,n,n,r',
                                         args=[weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, FP16_TOP_MAGIC_NUM],
                                         dtype=(tl.float16, tl.float16),
                                         is_pure=True,
                                         pack=1)


@triton.jit
def _unpack_weight(weight):
    """Unpack weight.

    Expands a 2-D block of 32-bit words, each holding eight 4-bit values, from
    shape ``(rows, QN)`` to a fp16 block of shape ``(rows, QN * 8)``.

    Within each group of 8 output columns, column ``j`` is read from these bit
    positions of the source word:

        j:     0     1      2     3      4      5      6      7
        bits:  0-3   16-19  4-7   20-23  8-11   24-27  12-15  28-31

    NOTE(review): this looks like AWQ's interleaved nibble packing order;
    confirm against the checkpoint packer.
    """
    # broadcast and shift
    width: tl.constexpr = 8  # 4-bit values per 32-bit word
    BLOCK_SIZE_K: tl.constexpr = weight.shape[0]
    BLOCK_SIZE_QN: tl.constexpr = weight.shape[1]
    BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width

    # w0..w7 are the eight nibbles, named by their final column position.
    w0, w1 = _dequant_s4_to_f16x2(weight, False, False)
    w2, w3 = _dequant_s4_to_f16x2(weight, False, True)
    w4, w5 = _dequant_s4_to_f16x2(weight, True, False)
    w6, w7 = _dequant_s4_to_f16x2(weight, True, True)

    # tl.join appends a new minor dimension of size 2. Three levels of joins
    # give a trailing (2, 2, 2) axis whose row-major flattening is w0, w1, ..., w7.
    w04 = tl.join(w0, w4)
    w15 = tl.join(w1, w5)
    w26 = tl.join(w2, w6)
    w37 = tl.join(w3, w7)
    w0246 = tl.join(w04, w26)
    w1357 = tl.join(w15, w37)
    weight = tl.join(w0246, w1357)

    return weight.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['N', 'K'],
    # Split-K programs atomically add into C, so autotune trials must start from zeros.
    reset_to_zero=['c_ptr'],
)
@triton.jit
def awq_linear_kernel(
        a_ptr,
        qw_ptr,
        s_ptr,
        qz_ptr,
        c_ptr,
        M,
        N: tl.constexpr,
        K: tl.constexpr,
        stride_am,
        stride_ak: tl.constexpr,  #
        stride_wk: tl.constexpr,
        stride_wn: tl.constexpr,  #
        stride_sk: tl.constexpr,
        stride_sn: tl.constexpr,  #
        stride_zk: tl.constexpr,
        stride_zn: tl.constexpr,  #
        stride_cm,
        stride_cn: tl.constexpr,
        # Meta-parameters
        SPLIT_K: tl.constexpr,
        NUM_STAGES: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
):
    """Kernel for computing the matmul C = A x B.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N)

    B is never materialized. It is rebuilt one BLOCK_SIZE_K slice at a time
    as ``(unpack(qw) - unpack(qz)) * s`` from:

    * ``qw_ptr``: packed 4-bit weights, one row per K index, N / 8 int32 columns;
    * ``qz_ptr``: packed 4-bit zero points, one row per K block, N / 8 columns;
    * ``s_ptr``: scales, one row per K block, N columns.

    Because scales and zeros advance by one row per BLOCK_SIZE_K slice,
    BLOCK_SIZE_K is expected to equal the quantization group size. The
    ``awq_linear`` launcher passes ``group_size``.

    Grid layout:

    * axis 0 enumerates (M, N) output tiles in grouped order;
    * axis 1 has ``SPLIT_K`` programs that split the K loop. Program ``kid``
      handles K blocks ``kid, kid + SPLIT_K, ...``.

    With ``SPLIT_K > 1`` the fp16-cast partial results are atomically added
    into C, so C must be zero-initialized.

    Only ``K // BLOCK_SIZE_K`` whole K blocks are processed.

    NOTE(review): loads of packed weights and zeros along N (``offs_wn``) are
    not masked or wrapped. N is presumably a multiple of BLOCK_SIZE_N — confirm.
    """

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the "L2 Cache Optimizations" section of the Triton
    # matrix-multiplication tutorial for details.
    kid = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # Row / scale-column indices wrap modulo M / N so unmasked loads stay in
    # bounds; out-of-range rows/columns are masked on the final store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # Packed columns: each int32 holds 8 output columns.
    BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8
    offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
    s_ptrs = s_ptr + offs_bn * stride_sn
    qz_ptrs = qz_ptr + offs_wn * stride_zn

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    k_start = kid
    k_last = K // BLOCK_SIZE_K

    # prefetch: move every pointer to this split's first K block, and load the
    # quantized weights / zeros / scales of that block. A is loaded in the loop.
    a_ptrs += k_start * BLOCK_SIZE_K * stride_ak
    qw_ptrs += k_start * BLOCK_SIZE_K * stride_wk
    s_ptrs += k_start * stride_sk
    qz_ptrs += k_start * stride_zk
    qw = tl.load(qw_ptrs)
    qz = tl.load(qz_ptrs)[None, :]
    s = tl.load(s_ptrs)[None, :]
    qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
    s_ptrs += SPLIT_K * stride_sk
    qz_ptrs += SPLIT_K * stride_zk

    for k in tl.range(k_start, k_last, SPLIT_K, num_stages=NUM_STAGES):

        # unpack b: dequantize the current K block as (w - z) * s
        z = _unpack_weight(qz)
        w = _unpack_weight(qw)
        b = (w - z) * s

        # load a
        a = tl.load(a_ptrs)

        # load next q: prefetch this split's next K block. Loads past the last
        # block are masked off; their (undefined) values are never used.
        mask = k + SPLIT_K < k_last
        qz = tl.load(qz_ptrs, mask=mask)[None, :]
        s = tl.load(s_ptrs, mask=mask)[None, :]
        qw = tl.load(qw_ptrs, mask=mask)

        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_ak
        qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
        s_ptrs += SPLIT_K * stride_sk
        qz_ptrs += SPLIT_K * stride_zk

    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    if SPLIT_K > 1:
        # Several programs contribute to the same tile: accumulate atomically.
        tl.atomic_add(c_ptrs, c, mask=c_mask, sem='relaxed', scope='gpu')
    else:
        tl.store(c_ptrs, c, mask=c_mask)


def awq_linear(x, qweight, scales, qzeros):
    """AWQ 4-bit linear layer: ``x @ dequant(qweight, qzeros, scales)``.

    Args:
        x: activations indexed as (M, K).
        qweight: packed 4-bit weights with K rows.
        scales: per-group scales of shape (K // group_size, N). The output is
            allocated with ``scales``' dtype and device.
        qzeros: packed 4-bit zero points, one row per group.

    Returns:
        Tensor of shape (M, N).

    The K loop is split across ``max(1, K // 4096)`` programs per output
    tile. When split, partial results are accumulated atomically into a
    zero-initialized output.
    """
    M = x.size(0)
    K = qweight.size(0)
    N = scales.size(1)
    group_size = K // scales.size(0)
    SPLIT_K = max(1, K // 4096)

    def grid(META):
        """grid."""
        num_tiles = triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (num_tiles, SPLIT_K)

    # Split-K programs atomically add into `out`, so it has to start at zero.
    out = scales.new_zeros(M, N) if SPLIT_K > 1 else scales.new_empty(M, N)

    props = torch.cuda.get_device_properties(x.device)
    use_two_stages = props.major == 9 or (props.major == 8 and props.minor in [6, 9])
    num_stages = 2 if use_two_stages else 3

    BLOCK_SIZE_M = max(16, min(128, triton.next_power_of_2(M)))
    awq_linear_kernel[grid](
        x,
        qweight,
        scales,
        qzeros,
        out,
        M,
        N,
        K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_wk=qweight.stride(0),
        stride_wn=qweight.stride(1),
        stride_sk=scales.stride(0),
        stride_sn=scales.stride(1),
        stride_zk=qzeros.stride(0),
        stride_zn=qzeros.stride(1),
        stride_cm=out.stride(0),
        stride_cn=out.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_K=group_size,
        SPLIT_K=SPLIT_K,
        NUM_STAGES=num_stages,
    )

    return out


================================================
FILE: lmdeploy/pytorch/kernels/cuda/bitonic_topk.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from triton.language import core
from triton.language.standard import _log2

try:
    # For Triton >= 3.6.0, core.get_int_dtype must be wrapped with
    # triton.runtime.jit.constexpr_function to be usable as a constexpr helper
    # inside @triton.jit kernels (it is called from `_flip_along_middle`).
    # This try/except keeps compatibility with older Triton versions where
    # constexpr_function is not available.
    get_int_dtype = triton.runtime.jit.constexpr_function(core.get_int_dtype)
except Exception:
    # fallback to original function if constexpr_function is not available (Triton < 3.6.0)
    # NOTE(review): the broad `except Exception` also hides any other failure
    # of the wrap, not only a missing attribute.
    get_int_dtype = core.get_int_dtype


@triton.jit
def _indicator(n_dims: core.constexpr, j: core.constexpr):
    """Return ``[0, 1]`` laid out along the ``j``-th innermost of ``n_dims`` axes.

    The result has size 2 at axis ``n_dims - j - 1`` and size 1 elsewhere. It
    broadcasts to 0 at the "left" and 1 at the "right" position along that axis.
    """
    ar = core.arange(0, 2)
    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
    return ar


@triton.jit
def _flip_along_middle(x, n_dims, i):
    """Swap the two elements along the ``i``-th innermost size-2 axis of ``x``.

    Works on the bit pattern: with a pair (p, q), ``xor_sum`` over the axis
    gives ``p ^ q``, and XOR-ing that back into each element yields (q, p).
    """
    # Reinterpret as a signed integer of the same width so XOR is available.
    idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ix = x.to(idtype, bitcast=True)
    # keep_dims=True so the reduced XOR broadcasts back over the pair.
    iy = ix ^ tl.xor_sum(ix, n_dims - 1 - i, True)
    y = iy.to(x.dtype, bitcast=True)
    return y


@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr):
    """One bitonic compare-and-swap step along the ``i``-th innermost axis.

    ``x`` is a hypercube of shape ``[2] * log2(numel)``, and ``ids`` is permuted
    identically to ``x``. Where ``flip`` is 0, each pair ends up with the
    smaller value on the left. Where ``flip`` is 1, the larger value goes on
    the left.
    """
    # compare-and-swap on the ith *innermost* dimension
    n_dims: core.constexpr = _log2(x.numel)

    # determines whether we are in the right (rather than left) position along the axis:
    is_right = _indicator(n_dims, i)

    # flip along middle dimension (the bitwise XORs will be optimised away):
    y = _flip_along_middle(x, n_dims, i)
    ids_y = _flip_along_middle(ids, n_dims, i)

    # conditional swap: take the partner's value when the current order
    # disagrees with the requested direction at this position.
    mask = (x > y) != (flip ^ is_right)
    ret_x = core.where(mask, y, x)
    ret_ids = core.where(mask, ids_y, ids)
    return ret_x, ret_ids


@triton.jit
def _bitonic_merge_hypercube(x, ids, stage: core.constexpr, order: core.constexpr):
    """Run one bitonic merge stage on a ``[2] * n`` hypercube.

    ``order`` selects the direction:

    * 0: ascending
    * 1: descending
    * 2: alternating, with the direction chosen by the ``stage``-th innermost
      axis (as used by the intermediate stages of a full sort)

    Performs ``stage`` compare-and-swap rounds on axes ``stage - 1`` down to 0.
    ``ids`` is permuted together with ``x``.
    """
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        flip = _indicator(_log2(x.numel), stage)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, stage - 1 - i)
    return x, ids


@triton.jit
def _bitonic_merge(x, ids, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    """Bitonic merge stage on a tensor whose ``numel`` is a power of two.

    ``order``: 0 = ascending, 1 = descending, 2 = alternating. The tensor is
    viewed as a ``[2] * log2(numel)`` hypercube, merged, and reshaped back;
    ``ids`` follows the same permutation as ``x``. ``n_dims`` is not used by
    the body.
    """
    h = core.reshape(x, [2] * _log2(x.numel))
    h_ids = core.reshape(ids, [2] * _log2(x.numel))
    h, h_ids = _bitonic_merge_hypercube(h, h_ids, stage, order)
    x = core.reshape(h, x.shape)
    ids = core.reshape(h_ids, ids.shape)
    return x, ids


@triton.jit
def argsort(x, ids, dim: tl.constexpr = None, descending: tl.constexpr = core.CONSTEXPR_0):
    """Bitonic sort of ``x`` along its last axis, permuting ``ids`` alongside.

    Only the innermost dimension is supported (enforced by ``static_assert``).
    The axis length is treated as ``2 ** _log2(length)``, so it is expected to
    be a power of two. Stages ``1 .. n-1`` merge in alternating order; the
    last stage sorts ascending (``descending=0``) or descending (``1``).

    Returns:
        ``(sorted_x, permuted_ids)``.
    """
    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(_dim == len(x.shape) - 1, 'only minor dimension is currently supported')
    # iteratively run bitonic merge-sort steps
    n_dims: tl.constexpr = _log2(x.shape[_dim])

    for i in tl.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids


@triton.jit
def _bitonic_topk_kernel0(score_ptr,
                          seqlen_ptr,
                          out_ptr,
                          ids_ptr,
                          stride_m,
                          K: tl.constexpr,
                          fill: tl.constexpr,
                          descending: tl.constexpr = core.CONSTEXPR_0,
                          sorted: tl.constexpr = True):
    """kernel0: sort each K-wide chunk of every row independently.

    Grid is ``(rows, chunks)``. Program ``(b, c)`` handles scores
    ``[c * K, (c + 1) * K)`` of row ``b``. The sorted scores and their
    original column indices are written back into the same chunk positions
    of ``out_ptr`` / ``ids_ptr``. Chunks starting at or beyond the row's
    length (``seqlen_ptr[b]``) exit immediately.

    The chunk is left unsorted only when ``sorted`` is false and the whole row
    fits in one chunk.

    NOTE(review): padding positions are loaded as -1e6 and sorted together
    with the real scores, but only the first ``seqlen - c * K`` slots are
    stored. That appears to assume ``descending=1`` (padding sorts to the end).
    With ``descending=0``, padding would move to the front of the last chunk.
    Confirm before using ascending order.
    """
    batch_id = tl.program_id(0).to(tl.int64)
    block_id = tl.program_id(1).to(tl.int64)

    seqlen = tl.load(seqlen_ptr + batch_id)

    if block_id * K >= seqlen:
        return

    offs_k = tl.arange(0, K)
    origin_ids = block_id * K + offs_k
    # ids are kept as int32, so the number of scores per row must fit in int32
    # (the bitonic_topk launcher asserts max_kv_len < 2**31).
    origin_ids = origin_ids.to(tl.int32)
    mask = (origin_ids < seqlen)
    score_ptrs = score_ptr + batch_id * stride_m + origin_ids
    scores = tl.load(score_ptrs, mask=mask, other=-1e6)
    # NOTE(review): this masked assignment is dead; it is immediately
    # overwritten by the next line, so padded slots carry their out-of-range
    # index rather than `fill`. Confirm which one is intended.
    ids = tl.where(mask, origin_ids, fill)
    ids = origin_ids

    if sorted or (seqlen > K):
        scores, ids = argsort(scores, ids, 0, descending)

    tl.store(out_ptr + batch_id * stride_m + origin_ids, scores, mask=mask)
    tl.store(ids_ptr + batch_id * stride_m + origin_ids, ids, mask=mask)


@triton.jit
def _concate(a, b):
    """Concatenate two 1-D tensors: result is ``a`` followed by ``b``."""
    c = tl.join(a, b)  # [k, 2]
    c = c.trans()  # [2, k]
    # there are bugs in `tl.ravel` when triton<=3.2.0, so flatten with an explicit reshape
    c = tl.reshape(c, (a.numel + b.numel, ))
    return c


@triton.jit
def _split(a, k):
    """Split a 1-D tensor of length ``2 * k`` into its first and second halves."""
    a = a.reshape(2, k)  # row 0 = first half, row 1 = second half
    a = a.trans()  # [k, 2] so tl.split can separate the minor axis
    return tl.split(a)


@triton.jit
def _bitonic_topk_kernel1(score_ptr,
                          ids_ptr,
                          seqlen_ptr,
                          out_ptr,
                          stride_m,
                          K: tl.constexpr,
                          fill: tl.constexpr,
                          threshold: tl.constexpr,
                          descending: tl.constexpr = core.CONSTEXPR_0):
    """kernel1: merge the per-chunk sorted results of one row into its top-K ids.

    One program per row. It starts from chunk 0 as the running top-K. For
    every following chunk:

    1. Load the chunk in reverse order. Appended to the sorted running top-K,
       this forms a bitonic sequence of length 2K.
    2. Apply one bitonic merge in order ``descending``.
    3. Keep the first K entries as the new running top-K.

    Positions beyond the row length load as ``(threshold, fill)``. In the
    output, ids whose final score is ``<= threshold`` are replaced by ``fill``.
    Output rows are written contiguously as ``out_ptr + row * K``.
    """
    batch_id = tl.program_id(0).to(tl.int64)

    seqlen = tl.load(seqlen_ptr + batch_id)
    offs_k = tl.arange(0, K)
    score_ptrs = score_ptr + batch_id * stride_m + offs_k
    ids_ptrs = ids_ptr + batch_id * stride_m + offs_k

    # initialize the running top-K with chunk 0
    pos = offs_k
    mask = pos < seqlen
    scores = tl.load(score_ptrs, mask=mask, other=threshold)
    ids = tl.load(ids_ptrs, mask=mask, other=fill)

    # point at chunk 1, reversed (position 2K-1 down to K)
    pos = 2 * K - 1 - offs_k
    score_ptrs = score_ptr + batch_id * stride_m + pos
    ids_ptrs = ids_ptr + batch_id * stride_m + pos

    stage: tl.constexpr = _log2(2 * K)
    # The loop index is unused; pointers and `pos` advance by K per iteration.
    for k in tl.range(K, seqlen, K, num_stages=3):
        mask = pos < seqlen
        new_scores = tl.load(score_ptrs, mask=mask, other=threshold)
        new_ids = tl.load(ids_ptrs, mask=mask, other=fill)

        merged_scores = _concate(scores, new_scores)
        merged_ids = _concate(ids, new_ids)

        merged_scores, merged_ids = _bitonic_merge(merged_scores, merged_ids, stage, descending, stage)

        # keep the first half (the best K in the requested order)
        scores, _ = _split(merged_scores, K)
        ids, _ = _split(merged_ids, K)
        score_ptrs += K
        ids_ptrs += K
        pos += K

    out_ptrs = out_ptr + batch_id * K + offs_k
    ids = tl.where(scores <= threshold, fill, ids)
    tl.store(out_ptrs, ids)


def bitonic_topk(scores: torch.Tensor,
                 q_seqlens: torch.Tensor,
                 kv_seqlens: torch.Tensor,
                 k: int,
                 fill: int = -1,
                 descending: bool = True,
                 sorted: bool = True,
                 threshold: float = -1e6):
    """Bitonic top-k: indices of the ``k`` best scores in each row of ``scores``.

    Args:
        scores: 2-D scores, one row per query token; only the first
            ``kv_len`` entries of each row are considered.
        q_seqlens: query tokens per sequence. Used to expand ``kv_seqlens``
            per token when ``scores`` has one row per token rather than one
            row per sequence.
        kv_seqlens: valid length of each sequence's rows.
        k: number of indices to keep (the chunk width of the bitonic kernels).
        fill: index written for empty or below-``threshold`` slots.
        descending: select the largest (True) or smallest scores.
        sorted: whether the result must be ordered when a row fits in a single chunk.
        threshold: scores ``<= threshold`` produce ``fill``.

    Returns:
        int32 tensor of shape ``(num_tokens, k)``.
    """
    num_tokens, max_kv_len = scores.size(0), scores.size(-1)
    assert max_kv_len < (1 << 31)

    # One kv length per row of `scores`.
    row_kv_seqlens = kv_seqlens
    if num_tokens != kv_seqlens.size(0):
        row_kv_seqlens = torch.repeat_interleave(kv_seqlens, q_seqlens, output_size=num_tokens)

    order_flag = 1 if descending else 0
    warps = triton.cdiv(k, 4096)

    # Pass 1: sort every k-wide chunk of each row independently.
    chunk_scores = torch.empty_like(scores)
    chunk_ids = torch.empty_like(scores, dtype=torch.int32)
    chunk_grid = (num_tokens, triton.cdiv(max_kv_len, k))
    _bitonic_topk_kernel0[chunk_grid](scores,
                                      row_kv_seqlens,
                                      chunk_scores,
                                      chunk_ids,
                                      stride_m=scores.stride(0),
                                      K=k,
                                      fill=fill,
                                      descending=order_flag,
                                      sorted=sorted,
                                      num_warps=warps)

    # Pass 2: merge the sorted chunks of each row into its final top-k ids.
    topk_ids = kv_seqlens.new_empty((num_tokens, k), dtype=torch.int32)
    _bitonic_topk_kernel1[(num_tokens, )](chunk_scores,
                                          chunk_ids,
                                          row_kv_seqlens,
                                          topk_ids,
                                          stride_m=chunk_scores.stride(0),
                                          K=k,
                                          fill=fill,
                                          descending=order_flag,
                                          threshold=threshold,
                                          num_warps=2 * warps)
    return topk_ids


================================================
FILE: lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
from typing import Callable

import torch
import triton
import triton.language as tl

from .activation import silu_and_mul
from .blocked_gemm_fp8 import quant_fp8
from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize, moe_reduce


def get_cuda_autotune_config():
    """Return the autotune candidates for ``fused_moe_blocked_f8_kernel``.

    There is a single candidate: a 128x128 output tile with 3 stages and 4 warps.
    """
    tile = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}
    return [triton.Config(tile, num_stages=3, num_warps=4)]


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['N', 'K', 'M_NP2'],
)
@triton.jit
def fused_moe_blocked_f8_kernel(
    A,
    A_scale,
    B,
    B_scale,
    bias,
    C,
    SortedIdx,
    ExpStart,
    ExpEnd,
    N: tl.constexpr,
    K: tl.constexpr,
    group_ak: tl.constexpr,
    group_bk: tl.constexpr,
    group_bn: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_asm,
    stride_ask: tl.constexpr,
    stride_be: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bse: tl.constexpr,
    stride_bsk: tl.constexpr,
    stride_bsn: tl.constexpr,
    stride_bie: tl.constexpr,
    stride_bin: tl.constexpr,
    stride_cm,
    stride_cn: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    M_NP2: tl.constexpr,
    top_k: tl.constexpr,
    expert_offset: tl.constexpr,
    reindex_a: tl.constexpr,
    reindex_c: tl.constexpr,
):
    """Fused moe kernel.

    Block-scaled fp8 grouped GEMM.

    Grid: axis 1 is the local expert ``exp_id``; axis 0 enumerates (M, N)
    tiles of that expert.

    Expert ``exp_id`` owns sorted positions
    ``[ExpStart[exp_id + expert_offset], ExpEnd[exp_id + expert_offset])``,
    while ``B`` / ``B_scale`` / ``bias`` are indexed by ``exp_id`` without the
    offset. ``SortedIdx`` maps a sorted position to a flat (token, top_k) row.

    Row mapping:

    * A rows: ``sid // top_k`` when ``reindex_a``, else the sorted position.
    * C rows: ``sid`` when ``reindex_c``, else the sorted position.

    Scaling: each K block ``i`` has a combined scale
    ``s_i = a_scale_i * b_scale_i``. Scales after the first are clamped to
    ``>= 1e-12``, and K blocks past ``K`` use a scale of 1. The running
    accumulator is kept divided by the next block's scale, by multiplying with
    ``s_i / s_{i+1}`` after each dot. Scales are loaded two blocks ahead. The
    final ``accumulator * (acc_ratio * acc_scale)`` therefore yields
    ``sum_i s_i * (a_i @ b_i)``.

    Scale indexing assumptions:

    * The scale index per K block is ``k_start // group_{a,b}k``; the launcher
      passes ``BLOCK_SIZE_K = group_bk``.
    * One B scale is used per N tile (``offs_bsn`` is a scalar).
      NOTE(review): this presumably requires BLOCK_SIZE_N <= group_bn — confirm.
    """
    exp_id = tl.program_id(1)
    pid = tl.program_id(0)

    exp_start = tl.load(ExpStart + exp_id + expert_offset)
    exp_end = tl.load(ExpEnd + exp_id + expert_offset)
    M = exp_end - exp_start
    if M <= 0:
        return

    # The grid is sized for M_NP2 rows; tiles beyond this expert's M exit below.
    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    if GROUP_SIZE_M == 1:
        pid_m = pid % num_pid_m
        pid_n = pid // num_pid_m
    else:
        # NOTE(review): unlike the Triton matmul tutorial, pid_m uses
        # `pid % group_size_m` rather than `(pid % num_pid_in_group) % group_size_m`.
        # The two can differ for a short last group. The launcher always
        # passes GROUP_SIZE_M=1, so this branch is currently unused.
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
        return

    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask_sid = offs_sid < exp_end
    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    if reindex_a:
        offs_am = sid // top_k
    else:
        offs_am = offs_sid
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # Columns wrap modulo N: a partial last N tile recomputes and rewrites valid
    # columns instead of masking them.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

    # stride_be * exp_id can overflow int32 for large expert counts
    # (e.g. DeepSeek's 160 experts), so promote to int64.
    exp_id = exp_id.to(tl.int64)
    exp_off = stride_be * exp_id
    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    offs_bsn = pid_n * BLOCK_SIZE_N // group_bn
    as_ptrs = A_scale + offs_am * stride_asm
    bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # initialize acc_ratio and acc_scale from the scales of K blocks 0 and 1
    a_scale = tl.load(as_ptrs, mask=mask_sid, other=1.0)
    b_scale = tl.load(bs_ptrs)
    acc_scale0 = a_scale * b_scale

    k_start = BLOCK_SIZE_K
    offs_ksa = k_start // group_ak
    offs_ksb = k_start // group_bk
    a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)
    b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)
    acc_scale1 = tl.maximum(a_scale * b_scale, 1e-12)
    acc_ratio = acc_scale0 / acc_scale1
    acc_scale = acc_scale1

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # load scales of block k + 2 (two blocks ahead)
        k_start = (k + 2) * BLOCK_SIZE_K
        offs_ksa = k_start // group_ak
        offs_ksb = k_start // group_bk
        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)
        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)

        # load ab
        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # mma, then rescale by s_k / s_{k+1}
        accumulator = tl.dot(a, b, acc=accumulator)
        accumulator *= acc_ratio[:, None]

        # update scales and ratio
        new_acc_scale = tl.maximum(a_scale * b_scale, 1e-12)
        acc_ratio = acc_scale / new_acc_scale
        acc_scale = new_acc_scale

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # acc_ratio * acc_scale is the scale the accumulator is currently divided by.
    c = accumulator * (acc_ratio * acc_scale)[:, None]

    if bias is not None:
        bias_ptrs = bias + exp_id * stride_bie + offs_bn * stride_bin
        bias_val = tl.load(bias_ptrs).to(accumulator.dtype)
        c += bias_val[None]

    c = c.to(C.dtype.element_ty)

    if reindex_c:
        offs_cm = sid
    else:
        offs_cm = offs_sid
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
    tl.store(c_ptrs, c, mask=mask_sid[:, None])


def fused_moe_blocked_fp8_kernel_launcher(
    A: torch.Tensor,
    A_scale: torch.Tensor,
    B: torch.Tensor,
    B_scale: torch.Tensor,
    C: torch.Tensor,
    sorted_idx: torch.Tensor,
    exp_start: torch.Tensor,
    exp_end: torch.Tensor,
    bias: torch.Tensor = None,
    top_k: int = 1,
    num_tokens: int = None,
    expert_offset: int = 0,
    reindex_a: bool = True,
    reindex_c: bool = True,
):
    """Launch ``fused_moe_blocked_f8_kernel`` for every local expert in ``B``.

    Args:
        A, A_scale: 2-D fp8 activations and their per-group scales.
        B, B_scale: 3-D expert weights ``(E, N, K)`` and their block scales.
        C: output; flattened to 2-D before the launch.
        sorted_idx, exp_start, exp_end: expert-sorted routing produced by the caller.
        bias: optional per-expert bias.
        top_k: experts per token (used to map sorted rows back to tokens).
        num_tokens: token count used to size the M grid; defaults to ``A.size(0)``.
        expert_offset: index of the first local expert in ``exp_start`` / ``exp_end``.
        reindex_a, reindex_c: gather A rows / scatter C rows through ``sorted_idx``.

    The K block size equals the B scale group along K.
    """
    token_count = A.size(0) if num_tokens is None else num_tokens
    M_NP2 = max(64, triton.next_power_of_2(token_count))
    E, N, K = B.shape

    assert A.dim() == 2
    assert A_scale.dim() == 2
    assert B.dim() == 3
    assert B_scale.dim() == 3

    assert K % A_scale.size(1) == 0
    assert K % B_scale.size(2) == 0
    assert N % B_scale.size(1) == 0

    group_ak = K // A_scale.size(1)
    group_bk = K // B_scale.size(2)
    group_bn = N // B_scale.size(1)

    def grid(META):
        m_tiles = triton.cdiv(M_NP2, META['BLOCK_SIZE_M'])
        n_tiles = triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (m_tiles * n_tiles, E)

    A = A.flatten(0, -2)
    C = C.flatten(0, -2)
    if bias is None:
        stride_bie = stride_bin = 0
    else:
        stride_bie, stride_bin = bias.stride(0), bias.stride(1)

    fused_moe_blocked_f8_kernel[grid](
        A,
        A_scale,
        B,
        B_scale,
        bias,
        C,
        sorted_idx,
        exp_start,
        exp_end,
        N=N,
        K=K,
        group_ak=group_ak,
        group_bk=group_bk,
        group_bn=group_bn,
        stride_am=A.stride(0),
        stride_ak=A.stride(1),
        stride_asm=A_scale.stride(0),
        stride_ask=A_scale.stride(1),
        stride_be=B.stride(0),
        stride_bn=B.stride(1),
        stride_bk=B.stride(2),
        stride_bse=B_scale.stride(0),
        stride_bsn=B_scale.stride(1),
        stride_bsk=B_scale.stride(2),
        stride_cm=C.stride(0),
        stride_cn=C.stride(1),
        stride_bie=stride_bie,
        stride_bin=stride_bin,
        top_k=top_k,
        expert_offset=expert_offset,
        reindex_a=reindex_a,
        reindex_c=reindex_c,
        M_NP2=M_NP2,
        BLOCK_SIZE_K=group_bk,
        GROUP_SIZE_M=1,
    )


def fused_moe_blocked_fp8(input: torch.Tensor,
                          input_scale: torch.Tensor,
                          w1: torch.Tensor,
                          w1_scale: torch.Tensor,
                          w2: torch.Tensor,
                          w2_scale: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          topk: int,
                          w1_bias: torch.Tensor = None,
                          w2_bias: torch.Tensor = None,
                          out_dtype: torch.dtype = torch.float16,
                          expert_offset: int = 0,
                          num_experts: int = None,
                          renormalize: bool = False,
                          act_func: Callable = None) -> torch.Tensor:
    """Fused MoE forward with block-scaled fp8 weights and activations.

    Steps:

    1. gate/up projection with ``w1`` / ``w1_scale`` on the pre-quantized
       ``input`` / ``input_scale``;
    2. activation (``silu_and_mul`` unless ``act_func`` is given), then
       re-quantization to ``input.dtype`` using ``input``'s group size;
    3. down projection with ``w2`` / ``w2_scale``;
    4. top-k weighted reduction via ``moe_reduce``.

    When ``num_experts`` differs from the local expert count ``w1.shape[0]``,
    the intermediates are zero-initialized. Rows routed to experts outside the
    local slice are not written by the kernel.
    """
    device = input.device
    M = input.size(0)
    E, N, _ = w1.shape
    num_experts = E if num_experts is None else num_experts
    zero_init = num_experts != E
    group_size = input.size(-1) // input_scale.size(-1)

    topk_weights = _renormalize(topk_weights, renormalize)
    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)

    # gate and up: gather token rows, write results in sorted order
    gate_up = _make_intermediate((M, topk, N), dtype=out_dtype, device=device, zeros=zero_init)
    fused_moe_blocked_fp8_kernel_launcher(
        input,
        input_scale,
        w1,
        w1_scale,
        gate_up,
        sorted_idx=sorted_idx,
        exp_start=exp_start,
        exp_end=exp_end,
        bias=w1_bias,
        top_k=topk,
        num_tokens=M,
        expert_offset=expert_offset,
        reindex_a=True,
        reindex_c=False,
    )

    # activate, then re-quantize for the down projection
    gate_up = gate_up.flatten(0, -2)
    activation = silu_and_mul if act_func is None else act_func
    act_out = activation(gate_up)
    del gate_up
    act_out, act_scale = quant_fp8(act_out, group_size, dtype=input.dtype)

    # down: read rows in sorted order, scatter results back to (token, k) rows
    down = _make_intermediate((M, topk, w2.shape[1]), dtype=out_dtype, device=device, zeros=zero_init)
    fused_moe_blocked_fp8_kernel_launcher(
        act_out,
        act_scale,
        w2,
        w2_scale,
        down,
        sorted_idx=sorted_idx,
        exp_start=exp_start,
        exp_end=exp_end,
        bias=w2_bias,
        top_k=1,
        num_tokens=M,
        expert_offset=expert_offset,
        reindex_a=False,
        reindex_c=True,
    )

    return moe_reduce(down, topk_weights)


================================================
FILE: lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import triton
import triton.language as tl
from torch import Tensor

from lmdeploy.utils import get_logger

from .utils import get_device_props

# Module logger; shares the package-level 'lmdeploy' logger.
logger = get_logger('lmdeploy')


@triton.jit
def fast_log2_ceil(x):
    """``ceil(log2(x))`` for positive, normal float32 ``x`` via bit manipulation.

    Reads the IEEE-754 exponent field and subtracts the bias 127. It adds 1
    when any mantissa bit is set, i.e. when ``x`` is not an exact power of
    two. The result is not meaningful for zero, denormals, inf or NaN.
    """
    bits_x = tl.cast(x, tl.uint32, bitcast=True)
    exp_x = (bits_x >> 23) & 0xFF  # biased exponent
    man_bits = bits_x & ((1 << 23) - 1)  # 23-bit mantissa
    tmp = exp_x - 127 + tl.where(man_bits != 0, 1, 0)
    return tl.cast(tmp, tl.int32)


@triton.jit
def fast_pow2(x):
    """Build the float32 ``2 ** x`` directly from its bit pattern.

    ``x`` becomes the exponent field, with the bias 127 added and a zero
    mantissa. This is valid for integer ``x`` in [-126, 127].
    """
    bits_x = (x + 127) << 23
    return tl.cast(bits_x, tl.float32, bitcast=True)


@triton.jit
def fast_round_scale(amax, fp8_max_inv):
    """Smallest power of two ``>= amax * fp8_max_inv`` (power-of-two fp8 scale)."""
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))


# M / M_out vary from call to call; do_not_specialize keeps Triton from
# specializing (and recompiling) on their values.
@triton.jit(do_not_specialize=['M', 'M_out'])
def _quant_fp8_kernel(
    a_ptr,
    out_ptr,
    scale_ptr,
    M,
    M_out,
    K: tl.constexpr,
    num_groups_per_cta: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    stride_am,
    stride_ak: tl.constexpr,
    stride_om,
    stride_ok: tl.constexpr,
    stride_sm,
    stride_sg,
    ROUND_SCALE: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Quant fp8 kernel.

    Quantizes each row with one scale per ``GROUP_SIZE`` consecutive columns.

    Grid layout:

    * axis 0: each program covers ``num_groups_per_cta`` groups, i.e.
      ``GROUP_SIZE * num_groups_per_cta`` columns;
    * axis 1: programs stride over rows ``m_id_start, m_id_start +
      num_programs(1), ... < M_out``.

    Rows with ``m_id >= M`` (output padding rows) are read as zeros.

    Per group:

    * ``amax = max(max|a|, 1e-6)``;
    * ``scale = amax / fp8_max``, rounded up to a power of two when
      ``ROUND_SCALE == 1``;
    * ``out = clamp(a / scale, fp8_min, fp8_max)``, cast to ``out_ptr``'s
      element type.
    """
    group_id = tl.program_id(0) * num_groups_per_cta
    m_id_start = tl.program_id(1)
    m_id_stride = tl.num_programs(1)

    GROUP_SIZE_CTA: tl.constexpr = GROUP_SIZE * num_groups_per_cta
    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE_CTA)
    g_offs = tl.max_contiguous(tl.multiple_of(g_offs, GROUP_SIZE), GROUP_SIZE)
    gs_offs = group_id + tl.arange(0, num_groups_per_cta)
    rfp8_max = 1 / fp8_max

    # `m_id` here only seeds the pointers of this program's first row; the loop
    # below rebinds it.
    m_id = m_id_start
    a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
    o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
    s_ptr = scale_ptr + m_id * stride_sm + gs_offs * stride_sg
    if K % GROUP_SIZE_CTA == 0:
        # Every program covers full groups: no column masking needed.
        mask_n = True
        mask_s = True
        mask_o = True
    else:
        mask_n = g_offs < K
        mask_o = g_offs < K
        mask_s = gs_offs < tl.cdiv(K, GROUP_SIZE)

    for m_id in tl.range(m_id_start, M_out, m_id_stride, num_stages=NUM_STAGES):
        a = tl.load(a_ptrs, mask=mask_n & (m_id < M), other=0)
        a = a.reshape(num_groups_per_cta, GROUP_SIZE)
        a_max = tl.max(tl.abs(a), axis=1)
        a_max = tl.maximum(a_max, 1e-6).to(tl.float32)
        if ROUND_SCALE == 1:
            scale = fast_round_scale(a_max, rfp8_max)
            rscale = 1 / scale
        else:
            scale = a_max * rfp8_max
            rscale = fp8_max / a_max  # triton does not support rcp
        out = a.to(tl.float32) * rscale[:, None]

        out = tl.clamp(out, fp8_min, fp8_max)
        out = out.to(out_ptr.dtype.element_ty)
        out = out.reshape(GROUP_SIZE * num_groups_per_cta)
        tl.store(o_ptrs, out, mask=mask_o)
        tl.store(s_ptr, scale, mask=mask_s)

        a_ptrs += m_id_stride * stride_am
        o_ptrs += m_id_stride * stride_om
        s_ptr += m_id_stride * stride_sm


def _quant_fp8_launcher(A: Tensor, group_size: int, out: Tensor, scales: Tensor, scale_fmt: Optional[str] = None):
    """Quant online."""
    assert scale_fmt in (None, 'ue8m0')
    round_scale = 1 if scale_fmt == 'ue8m0' else 0
    M, K = A.shape
    M_out = out.size(0)

    dtype = out.dtype
    finfo = torch.finfo(dtype)
    fmin = finfo.min
    fmax = finfo.max

    num_warps = 2
    # every cp/ldg instruct can load 128bit=16byte data
    # each warp can read 512 byte data
    elem_size = A.element_size()
    num_groups_per_warp = 512 // (group_size * elem_size)
    num_groups_per_cta = num_groups_per_warp * num_warps
    grid_size0 = triton.cdiv(K, group_size * num_groups_per_cta)
    props = get_device_props(A.device.index)
    num_sm = props['multi_processor_count']
    warps_per_sm = props['warps_per_sm']
    blocks_per_sm = props['blocks_per_sm']
    max_ctas = num_sm * min(blocks_per_sm, warps_per_sm // num_warps)
    grid_size1 = min(M_out, max_ctas // grid_size0)
    assert grid_size1 < 65536
    num_stages = min(4, max(1, triton.cdiv(M_out, grid_size1)))
    grid = (grid_size0, grid_size1)
    _quant_fp8_kernel[grid](
        A,
        out,
        scales,
        M,
        M_out,
        K,
        num_groups_per_cta=num_groups_per_cta,
        fp8_min=fmin,
        fp8_max=fmax,
        stride_am=A.stride(0),
        stride_ak=A.stride(1),
        stride_om=out.stride(0),
        stride_ok=out.stride(1),
        stride_sm=scales.stride(0),
        stride_sg=scales.stride(1),
        ROUND_SCALE=round_scale,
        GROUP_SIZE=group_size,
        NUM_STAGES=num_stages,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return out, scales


def quant_fp8(A: Tensor,
              group_size: int,
              dtype: torch.dtype = torch.float8_e4m3fn,
              trans_scale: bool = False,
              scale_fmt: Optional[str] = None):
    """Quantize a 2-D tensor to ``dtype`` with one float32 scale per column group.

    Args:
        A: (M, K) tensor; ``K`` must be a multiple of ``group_size``.
        group_size: number of consecutive columns sharing a scale.
        dtype: quantized output dtype.
        trans_scale: if True, ``scales`` is a transposed view of a
            (K // group_size, M) buffer (column-major); otherwise it is a
            row-major (M, K // group_size) tensor.
        scale_fmt: forwarded to the launcher (``None`` or ``'ue8m0'``).

    Returns:
        ``(out, scales)`` where ``out`` has the shape of ``A``.
    """
    assert A.dim() == 2
    rows, cols = A.shape
    assert cols % group_size == 0
    n_groups = cols // group_size
    quantized = torch.empty_like(A, dtype=dtype)
    scales = (A.new_empty(n_groups, rows, dtype=torch.float32).T
              if trans_scale else A.new_empty(rows, n_groups, dtype=torch.float32))
    return _quant_fp8_launcher(A, group_size, quantized, scales, scale_fmt=scale_fmt)


def quant_fp8_tma(A: Tensor,
                  group_size: int,
                  dtype: torch.dtype = torch.float8_e4m3fn,
                  scale_fmt: Optional[str] = None):
    """Quantize ``A`` to fp8 with row padding for DeepGEMM's contiguous layout.

    The row count is rounded up to ``get_m_alignment_for_contiguous_layout()``.
    The returned ``out`` is (aligned_M, K) and ``scales`` is a transposed
    (column-major) view of shape (aligned_M, K // group_size).
    """
    from lmdeploy.pytorch.third_party.deep_gemm import ceil_div, get_m_alignment_for_contiguous_layout
    assert A.dim() == 2
    rows, cols = A.shape
    assert cols % group_size == 0
    n_groups = cols // group_size
    row_align = get_m_alignment_for_contiguous_layout()
    padded_rows = ceil_div(rows, row_align) * row_align
    quantized = A.new_empty(padded_rows, cols, dtype=dtype)
    col_major_scales = A.new_empty(n_groups, padded_rows, dtype=torch.float32).T
    return _quant_fp8_launcher(A, group_size, quantized, col_major_scales, scale_fmt=scale_fmt)


def _gemm_fp8_tma_pre_hook(nargs):
    BLOCK_M = nargs['BLOCK_M']
    BLOCK_N = nargs['BLOCK_N']
    BLOCK_K = nargs['BLOCK_K']
    nargs['desc_a'].block_shape = (BLOCK_M, BLOCK_K)
    nargs['desc_b'].block_shape = (BLOCK_N, BLOCK_K)


# Autotuned over (BLOCK_M, BLOCK_N); the pre-hook resizes the TMA descriptors
# to the chosen tile before each launch.
@triton.autotune(configs=[
    triton.Config({
        'BLOCK_M': 128,
        'BLOCK_N': 128,
    }, num_stages=3, num_warps=8, pre_hook=_gemm_fp8_tma_pre_hook),
    triton.Config({
        'BLOCK_M': 128,
        'BLOCK_N': 64,
    }, num_stages=3, num_warps=4, pre_hook=_gemm_fp8_tma_pre_hook)
],
                 key=['N', 'K'])
@triton.jit
def _gemm_fp8_tma_kernel(
    desc_a,
    a_scale_ptr,
    desc_b,
    b_scale_ptr,
    C,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    group_ak: tl.constexpr,
    group_bk: tl.constexpr,
    group_bn: tl.constexpr,
    stride_asm: tl.constexpr,
    stride_ask,
    stride_bsk: tl.constexpr,
    stride_bsn: tl.constexpr,
    stride_cm,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Block-scaled fp8 GEMM that loads A and B tiles through TMA descriptors.

    Each program computes one (BLOCK_M, BLOCK_N) tile of ``C``. ``desc_b``
    describes B laid out as (N, K) (the launcher builds it from ``B.T``), so
    every loaded B tile is transposed before the dot.

    Scaling is applied lazily. The accumulator is kept in units of the current
    K block's combined scale (a_scale * b_scale). Before each MMA it is
    multiplied by ``acc_ratio`` (previous scale / current scale), which lets
    ``tl.dot`` accumulate straight into it via ``acc=``. The true value is
    recovered at the end.

    NOTE(review): only one A scale per row and one B scale per tile are used
    for each K block. The B scale is also taken once per N tile. This
    presumably assumes group_ak == group_bk == BLOCK_K and BLOCK_N <= group_bn
    — confirm against callers.
    """
    # standard grouped (GROUP_M) tile ordering from the Triton matmul tutorial
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # rows wrap modulo M so scale loads of a partial last tile stay in bounds;
    # the final store is masked instead
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M

    offs_bsn = pid_n * BLOCK_N // group_bn
    as_ptrs = a_scale_ptr + offs_am * stride_asm
    bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn

    # scale of the first K block; acc_ratio is irrelevant on the first
    # iteration because the accumulator is still zero
    acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
    acc_ratio = 1 / acc_scale
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    off_k = 0
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # load scales
        # prefetch the NEXT K block's scales; past the end use 1.0 so the
        # final rescale below reduces to the last block's scale
        k_start = (k + 1) * BLOCK_K
        offs_ksa = k_start // group_ak
        offs_ksb = k_start // group_bk
        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=k_start < K, other=1.0)
        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)

        # load ab
        # NOTE(review): out-of-range tile regions are presumably zero-filled by the descriptor
        a = desc_a.load([off_m, off_k])
        b = desc_b.load([off_n, off_k]).T

        # mma
        # bring the running sum into the current block's scale units, then accumulate
        accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])

        # update scales and ratio
        new_acc_scale = a_scale * b_scale
        acc_ratio = acc_scale / new_acc_scale
        acc_scale = new_acc_scale

        off_k += BLOCK_K
    # after the loop acc_scale == 1.0 and acc_ratio == last block's scale
    c = accumulator * (acc_ratio * acc_scale)[:, None]

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# Autotuned over two (BLOCK_M, BLOCK_N) tile shapes, keyed on (N, K).
@triton.autotune(configs=[
    triton.Config({
        'BLOCK_M': 64,
        'BLOCK_N': 128,
    }, num_stages=3, num_warps=4),
    triton.Config({
        'BLOCK_M': 128,
        'BLOCK_N': 64,
    }, num_stages=3, num_warps=4)
],
                 key=['N', 'K'])
@triton.jit
def _gemm_fp8_kernel(
    A,
    a_scale_ptr,
    B,
    b_scale_ptr,
    C,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    group_ak: tl.constexpr,
    group_bk: tl.constexpr,
    group_bn: tl.constexpr,
    stride_am,
    stride_ak: tl.constexpr,
    stride_asm: tl.constexpr,
    stride_ask,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_bsk: tl.constexpr,
    stride_bsn: tl.constexpr,
    stride_cm,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Block-scaled fp8 GEMM using plain pointer loads (non-TMA fallback).

    Each program computes one (BLOCK_M, BLOCK_N) tile of ``C = A @ B``. A is
    addressed as (M, K) and B as (K, N) through their strides.

    Scaling is applied lazily. The accumulator is kept in units of the current
    K block's combined scale (a_scale * b_scale). Before each MMA it is
    multiplied by ``acc_ratio`` (previous scale / current scale), so ``tl.dot``
    can accumulate straight into it via ``acc=``.

    NOTE(review): one A scale per row and one B scale per tile are used for
    each K block, and one B scale per N tile. This presumably assumes
    group_ak == group_bk == BLOCK_K and BLOCK_N <= group_bn — confirm.
    """
    # standard grouped (GROUP_M) tile ordering from the Triton matmul tutorial
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # row/column indices wrap so loads of partial edge tiles stay in bounds;
    # the final store is masked instead
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    offs_bsn = pid_n * BLOCK_N // group_bn
    as_ptrs = a_scale_ptr + offs_am * stride_asm
    bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn

    # scale of the first K block; acc_ratio does not matter on the first
    # iteration because the accumulator is still zero
    acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
    acc_ratio = 1 / acc_scale
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # load scales
        # prefetch the NEXT K block's scales; past the end use 1.0 so the
        # final rescale below reduces to the last block's scale
        k_start = (k + 1) * BLOCK_K
        offs_ksa = k_start // group_ak
        offs_ksb = k_start // group_bk
        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=k_start < K, other=1.0)
        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)

        # load ab
        # the K tail is masked to zero so it adds nothing to the dot
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # mma
        # bring the running sum into the current block's scale units, then accumulate
        accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])

        # update scales and ratio
        new_acc_scale = a_scale * b_scale
        acc_ratio = acc_scale / new_acc_scale
        acc_scale = new_acc_scale

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # after the loop acc_scale == 1.0 and acc_ratio == last block's scale
    c = accumulator * (acc_ratio * acc_scale)[:, None]

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def blocked_gemm_fp8(A: Tensor,
                     A_scale: Tensor,
                     B: Tensor,
                     B_scale: torch.Tensor,
                     out_dtype: torch.dtype = torch.float16):
    """Gemm fp8: block-scaled ``C = dequant(A) @ dequant(B)``.

    Args:
        A: (M, K) quantized activations (fp8 expected).
        A_scale: 2-D scales of ``A`` with ``A_scale.size(1)`` K-groups per
            row; the group width is ``cdiv(K, A_scale.size(1))``.
        B: (K, N) quantized weights (fp8 expected).
        B_scale: 2-D scales of ``B`` shaped (K-groups, N-groups).
        out_dtype: dtype of the result.

    Returns:
        A new (M, N) tensor of ``out_dtype``.

    Uses the TMA kernel when ``supports_tma()`` holds and both ``A`` and
    ``B.T`` are contiguous; otherwise falls back to the pointer-based kernel.
    """
    assert A.dim() == 2
    assert A_scale.dim() == 2
    assert B.dim() == 2
    assert B_scale.dim() == 2

    M, K = A.shape
    K_b, N = B.shape
    # Both kernels walk B with A's K; a mismatch would read out of bounds.
    assert K_b == K, f'inner dimensions mismatch: A has K={K}, B has K={K_b}'

    group_ak = triton.cdiv(K, A_scale.size(1))
    group_bk = triton.cdiv(K, B_scale.size(0))
    group_bn = triton.cdiv(N, B_scale.size(1))

    C = A.new_empty(M, N, dtype=out_dtype)

    # both kernels load a single A/B scale per K block
    BLOCK_K = max(group_ak, group_bk)

    def grid(META):
        # 1-D launch: one program per (BLOCK_M, BLOCK_N) output tile
        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )

    from .utils import supports_tma

    run_tma = supports_tma() and A.is_contiguous() and B.T.is_contiguous()

    if run_tma:
        from .utils import TensorDescriptor

        # the real block shapes are set per autotune config by _gemm_fp8_tma_pre_hook
        dummy_block = (1, 1)
        desc_a = TensorDescriptor.from_tensor(A, block_shape=dummy_block)
        desc_b = TensorDescriptor.from_tensor(B.T, block_shape=dummy_block)

        _gemm_fp8_tma_kernel[grid](
            desc_a,
            A_scale,
            desc_b,
            B_scale,
            C,
            M=M,
            N=N,
            K=K,
            group_ak=group_ak,
            group_bk=group_bk,
            group_bn=group_bn,
            stride_asm=A_scale.stride(0),
            stride_ask=A_scale.stride(1),
            stride_bsk=B_scale.stride(0),
            stride_bsn=B_scale.stride(1),
            stride_cm=C.stride(0),
            stride_cn=C.stride(1),
            BLOCK_K=BLOCK_K,
            GROUP_M=8,
        )
    else:
        _gemm_fp8_kernel[grid](
            A,
            A_scale,
            B,
            B_scale,
            C,
            M=M,
            N=N,
            K=K,
            group_ak=group_ak,
            group_bk=group_bk,
            group_bn=group_bn,
            stride_am=A.stride(0),
            stride_ak=A.stride(1),
            stride_asm=A_scale.stride(0),
            stride_ask=A_scale.stride(1),
            stride_bk=B.stride(0),
            stride_bn=B.stride(1),
            stride_bsk=B_scale.stride(0),
            stride_bsn=B_scale.stride(1),
            stride_cm=C.stride(0),
            stride_cn=C.stride(1),
            BLOCK_K=BLOCK_K,
            GROUP_M=8,
        )

    return C


def deep_gemm_fp8(A: Tensor,
                  A_scale: Tensor,
                  B: Tensor,
                  B_scale: torch.Tensor,
                  out_dtype: torch.dtype = torch.bfloat16):
    """Fp8 GEMM through DeepGEMM's ``fp8_gemm_nt``.

    ``A`` is (M, K) and ``B`` is (N, K) ("NT" layout); the result is a freshly
    allocated (M, N) tensor. Only bf16 output is accepted.
    """
    from lmdeploy.pytorch.third_party.deep_gemm import fp8_gemm_nt
    num_rows, _ = A.shape
    num_cols, _ = B.shape
    assert out_dtype == torch.bfloat16, 'DeepGemm requires bf16 output.'
    result = A.new_empty(num_rows, num_cols, dtype=out_dtype)
    fp8_gemm_nt((A, A_scale), (B, B_scale), result, None)
    return result


================================================
FILE: lmdeploy/pytorch/kernels/cuda/causal_conv1d.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import tilelang
import tilelang.language as T
import torch

# The kernels below are modified from: https://github.com/Dao-AILab/causal-conv1d


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}, )
def causal_conv1d_fwd(hidden_size, width, has_bias, activation, dtype, stride_x, num_warps, ChunkSizeL=64):
    """Build a TileLang kernel for the packed (varlen) causal conv1d forward pass.

    Each CTA produces a (ChunkSizeC channels, ChunkSizeL positions) output
    tile. The input tile, plus a left halo of ``width - 1`` positions, is
    staged in shared memory. Each thread then computes ``l_per_thread``
    consecutive positions of a single channel (``thrs_per_row`` threads share
    one channel).

    A tap contributes only when its ``seq_idx`` equals the output position's
    ``seq_idx``, so the convolution never crosses sequence boundaries.
    Positions whose ``seq_idx`` is negative output 0. ``silu``/``swish``
    activation is applied when requested.

    NOTE(review): ``Init_states`` and ``Final_States`` are accepted by the
    prim_func but never read or written in its body.
    NOTE(review): channel indices (``gcidx``) are not bounds-checked, so this
    presumably requires hidden_size to be a multiple of ChunkSizeC — confirm.
    """
    num_threads = num_warps * 32
    num_bits = T.DataType(dtype).bits
    num_bytes = num_bits // 8
    # elems_per_row <= num_threads
    elems_per_row = 128 // num_bytes
    ChunkSizeC = elems_per_row
    silu_activation = activation in ['silu', 'swish']

    # split the (ChunkSizeC x ChunkSizeL) tile evenly over all threads
    l_per_thread = min(ChunkSizeC * ChunkSizeL // num_threads, ChunkSizeL)
    assert num_threads * l_per_thread == ChunkSizeC * ChunkSizeL
    thrs_per_row = ChunkSizeL // l_per_thread
    assert thrs_per_row * l_per_thread == ChunkSizeL
    sum_seqlen = T.dynamic('sum_seqlen')

    @T.prim_func
    def causal_conv1d_fwd_main(
        X: T.StridedTensor([hidden_size, sum_seqlen], dtype=dtype, strides=(1, stride_x)),
        W: T.Tensor([hidden_size, width], dtype=dtype),
        seq_idx: T.Tensor([sum_seqlen], dtype=T.int32),
        Bias: T.Tensor([hidden_size], dtype=dtype) = None,
        Init_states: T.Tensor([hidden_size, width - 1], dtype=dtype) = None,
        Out: T.StridedTensor([hidden_size, sum_seqlen], dtype=dtype, strides=(1, hidden_size)) = None,
        Final_States: T.Tensor([hidden_size, width - 1], dtype=dtype) = None,
    ):
        # Process sum_seqlen output positions across all threads and blocks
        # every cta process (ChunkSizeC, ChunkSizeL) output tile
        with T.Kernel(T.ceildiv(hidden_size, ChunkSizeC), T.ceildiv(sum_seqlen, ChunkSizeL),
                      threads=num_threads) as (bc, bl):

            # row j of x_smem holds position bl * ChunkSizeL + j - (width - 1)
            x_smem = T.alloc_shared((ChunkSizeL + width - 1, ChunkSizeC), dtype)

            # load x(copy can not be used on strided tensor)
            for lidx, cidx in T.Parallel(ChunkSizeL, ChunkSizeC):
                glidx = bl * ChunkSizeL + lidx
                gcidx = bc * ChunkSizeC + cidx
                x_smem[lidx + width - 1, cidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, X[gcidx, glidx],
                                                                T.cast(0.0, dtype))
            # left halo; row width - 1 is rewritten with the same value as above
            for lidx, cidx in T.Parallel(width, ChunkSizeC):
                glidx = bl * ChunkSizeL + lidx - width + 1
                gcidx = bc * ChunkSizeC + cidx
                x_smem[lidx, cidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, X[gcidx, glidx],
                                                    T.cast(0.0, dtype))

            x_local = T.alloc_local((width - 1 + l_per_thread, ), T.float32)
            seq_idx_local = T.alloc_local((width - 1 + l_per_thread, ), seq_idx.dtype)
            w_local = T.alloc_local((width, ), T.float32)
            if has_bias:
                bias_var = T.alloc_var(T.float32)
            else:
                bias_var = 0.0
            T.clear(w_local)

            # each channel (row_idx) is shared by thrs_per_row threads (col_idx)
            tid = T.get_thread_binding(0)
            row_idx = tid // thrs_per_row
            col_idx = tid % thrs_per_row

            # load w/b
            if bc * ChunkSizeC + row_idx < hidden_size:
                for widx in T.unroll(width):
                    w_local[widx] = W[bc * ChunkSizeC + row_idx, widx]
                if has_bias:
                    bias_var = Bias[bc * ChunkSizeC + row_idx]

            # load this thread's x window: l_per_thread outputs plus width - 1 history
            for i in T.unroll(l_per_thread + width - 1):
                x_local[i] = x_smem[col_idx * l_per_thread + i, row_idx]

            # load seq_idx
            # out-of-range positions get -1 so they never match a real sequence
            for i in T.unroll(l_per_thread + width - 1):
                gi = bl * ChunkSizeL + col_idx * l_per_thread + i - (width - 1)
                seq_idx_local[i] = T.if_then_else(gi >= 0 and gi < sum_seqlen, seq_idx[gi], -1)

            out_vals = T.alloc_local((l_per_thread, ), T.float32)
            T.clear(out_vals)
            for i in T.unroll(l_per_thread):
                out_vals[i] = bias_var
                seq_idx_cur = seq_idx_local[i + width - 1]
                if seq_idx_cur < 0:
                    out_vals[i] = 0.0
                    continue
                # only taps from the same sequence contribute
                for w in T.unroll(width):
                    out_vals[i] += T.if_then_else(seq_idx_local[i + w] == seq_idx_cur, w_local[w] * x_local[i + w], 0.0)
                if silu_activation:
                    out_vals[i] = T.sigmoid(out_vals[i]) * out_vals[i]

            # reuse x_smem to stage outputs: row j now holds position bl * ChunkSizeL + j.
            # NOTE(review): neighbouring threads read rows written here as their halo;
            # this relies on TileLang inserting a barrier after the x_local loads — confirm.
            for i in T.unroll(l_per_thread):
                x_smem[col_idx * l_per_thread + i, row_idx] = out_vals[i]

            for lidx, cidx in T.Parallel(ChunkSizeL, ChunkSizeC):
                glidx = bl * ChunkSizeL + lidx
                gcidx = bc * ChunkSizeC + cidx
                Out[gcidx, glidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, x_smem[lidx, cidx],
                                                   T.cast(0.0, dtype))

    return causal_conv1d_fwd_main


def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Packed (varlen) causal 1D convolution using a TileLang kernel.

    All sequences of the batch are concatenated along the last axis of ``x``.
    ``seq_idx`` tells the kernel which sequence each position belongs to, so
    the convolution never mixes tokens of different sequences.

    Args:
        x: Input of shape [1, hidden_size, sum_seqlen]; must satisfy
            ``x.stride(1) == 1`` (channel-last memory layout).
        weight: Convolution weights of shape [hidden_size, kernel_size].
        bias: Optional bias of shape [hidden_size].
        seq_idx: int32 sequence id of every position, shape [sum_seqlen]
            (a leading dim of size 1 is squeezed). Required. Positions with a
            negative id produce 0.
        initial_states: Forwarded to the kernel, which currently does not
            read it; sequences always start from zero history.
        return_final_states: Must be False; final states are not supported.
        final_states_out: Ignored.
        activation: 'silu', 'swish' or None.

    Returns:
        Output of shape [1, hidden_size, sum_seqlen] with channel-last strides.

    Raises:
        AssertionError: if any of the layout/argument requirements above is violated.
    """
    assert x.dim() == 3, 'x should be in shape of [batch_size, hidden_size, sum_seqlen]'
    assert x.size(0) == 1, 'batch_size should be 1 for continuous batching'
    assert x.stride(1) == 1, 'x should be in channel last format'
    assert weight.dim() == 2, 'weight should be in shape of [hidden_size, kernel_size]'
    assert seq_idx is not None, 'seq_idx is required for causal_conv1d_fn'
    assert activation in ['silu', 'swish', None]
    assert not return_final_states, 'return_final_states=True is not supported in this version'

    _, hidden_size, _ = x.shape
    kernel_size = weight.shape[1]
    dtype = x.dtype

    # Reshape to 2D format for kernel: [hidden_size, sum_seqlen]
    x_2d = x.squeeze(0)  # [hidden_size, sum_seqlen]
    seq_idx_1d = seq_idx.squeeze(0) if seq_idx.dim() > 1 else seq_idx  # [sum_seqlen]

    # Allocate (sum_seqlen, hidden_size) and view it transposed so the output
    # is channel-last like x, matching the kernel's Out strides (1, hidden_size).
    out = x_2d.new_empty(x_2d.size(1), hidden_size)
    out = out.T

    # Create and call the TileLang kernel
    num_warps = 4  # Tunable parameter
    kernel = causal_conv1d_fwd(hidden_size, kernel_size, bias is not None, activation, dtype, x.stride(2), num_warps)

    kernel(
        x_2d,
        weight,
        seq_idx_1d,
        bias,
        initial_states,
        out,
        None,
    )

    # Reshape back to original format: [1, hidden_size, sum_seqlen]
    out = out.unsqueeze(0)

    return out


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}, )
def causal_conv1d_update_fwd(hidden_size: int, seqlen: int, state_len: int, width: int, has_bias: bool,
                             activation: str | None, dtype, conv_stride: tuple[int, int, int], num_warps: int):
    """Build a TileLang kernel for the incremental (decode) causal conv1d update.

    The grid is (batch, ceil(hidden_size / num_threads)) and each thread owns
    one channel of one batch entry. For that channel it:
      1. shifts the rolling conv state left by ``seqlen``;
      2. appends the ``seqlen`` new inputs to the state;
      3. computes the ``seqlen`` outputs from the previous ``width - 1``
         state values plus the new inputs, with optional bias and
         silu/swish activation.
    Batch entries whose state index (``Conv_state_indices``) is negative are
    treated as padding: their outputs are zeroed and the state is untouched.

    NOTE(review): ``channel_id`` is not bounds-checked against hidden_size,
    so hidden_size is presumably expected to be a multiple of num_threads —
    confirm.
    """
    num_threads = num_warps * 32
    silu_activation = activation in ['silu', 'swish']

    advance_len = seqlen
    batch = T.dynamic('batch')
    conv_batch = T.dynamic('conv_batch')
    # batch stride of conv_state is dynamic; conv_stride[0] is not used
    conv_batch_stride = T.dynamic('conv_batch_stride')
    # NOTE(review): update_idx is computed but not referenced by the kernel below
    update_idx = -(width - 1)
    update_idx = update_idx if update_idx >= 0 else update_idx + state_len

    @T.prim_func
    def causal_conv1d_update_main(
        X: T.Tensor((batch, hidden_size, seqlen), dtype=dtype),
        Conv_State: T.StridedTensor((conv_batch, hidden_size, state_len),
                                    dtype=dtype,
                                    strides=(conv_batch_stride, conv_stride[1], conv_stride[2])),
        W: T.Tensor((hidden_size, width), dtype=dtype),
        Bias: T.Tensor((hidden_size, ), dtype=dtype) = None,
        Out: T.Tensor((batch, hidden_size, seqlen), dtype=dtype) = None,
        Conv_state_indices: T.Tensor((batch, ), dtype=T.int32) = None,
    ):
        with T.Kernel(batch, T.ceildiv(hidden_size, num_threads), threads=num_threads) as (bi, bc):
            tidx = T.get_thread_binding(0)
            batch_id = bi
            channel_id = bc * num_threads + tidx

            # load conv state index
            # (the `is not None` test is presumably resolved at trace time — confirm)
            conv_state_batch_coord = T.if_then_else(Conv_state_indices is not None, Conv_state_indices[batch_id],
                                                    T.cast(batch_id, T.int32))

            # skip padding tokens
            # tilelang does not support return in branch,
            # so I have to create this ugly branch to skip the computation for padding tokens
            if conv_state_batch_coord < 0:
                for i in T.unroll(seqlen, unroll_factor=2):
                    Out[batch_id, channel_id, i] = 0.0
            else:
                # load bias and weight
                bias_val = T.if_then_else(has_bias, T.cast(Bias[channel_id], T.float32), 0.0)
                weight_vals = T.alloc_local((width, ), T.float32)
                for i in T.unroll(width):
                    weight_vals[i] = W[channel_id, i]

                # fill conv states and read x_vals
                # shift the older part of the state left by advance_len (ascending order is safe
                # because each read index is ahead of every write index)
                x_vals = T.alloc_local((width, ), T.float32)
                for i in T.unroll(state_len - advance_len - (width - 1), unroll_factor=2):
                    Conv_State[conv_state_batch_coord, channel_id, i] = Conv_State[conv_state_batch_coord, channel_id,
                                                                                   i + advance_len]
                # the last width - 1 state values: shift them too, and keep them as conv history
                for i in T.unroll(width - 1):
                    state_val = Conv_State[conv_state_batch_coord, channel_id, state_len - (width - 1) + i]
                    if i < advance_len + (width - 1) and state_len - advance_len - (width - 1) + i >= 0:
                        Conv_State[conv_state_batch_coord, channel_id,
                                   state_len - advance_len - (width - 1) + i] = state_val
                    x_vals[i] = state_val

                # compute output
                for i in T.unroll(seqlen, unroll_factor=2):
                    x_val = X[batch_id, channel_id, i]
                    # append the new input to the tail of the state
                    if i < advance_len and state_len - advance_len + i >= 0:
                        Conv_State[conv_state_batch_coord, channel_id, state_len - advance_len + i] = x_val
                    x_vals[width - 1] = x_val
                    out_val = T.alloc_var(T.float32)
                    out_val = bias_val
                    for j in T.unroll(width):
                        out_val += weight_vals[j] * x_vals[j]
                    if silu_activation:
                        out_val = T.sigmoid(out_val) * out_val
                    Out[batch_id, channel_id, i] = out_val
                    # shift x_vals
                    for j in T.unroll(width - 1):
                        x_vals[j] = x_vals[j + 1]

    return causal_conv1d_update_main


# TODO: support cache_seqlens
# TODO: support complex layout
def causal_conv1d_update(x,
                         conv_state,
                         weight,
                         bias=None,
                         activation=None,
                         cache_seqlens=None,
                         conv_state_indices=None):
    """Tilelang implementation of causal_conv1d_update (decode step).

    Args:
        x: New inputs, (batch, hidden_size) or (batch, hidden_size, seqlen).
        conv_state: (conv_batch, hidden_size, state_len) rolling conv state;
            the kernel updates it in place.
        weight: (hidden_size, width) convolution weights.
        bias: Optional (hidden_size,) bias.
        activation: 'silu', 'swish' or None.
        cache_seqlens: Not supported; must be None.
        conv_state_indices: Optional contiguous int32 (batch,) tensor mapping
            each batch entry to a row of ``conv_state``. A negative entry marks
            padding and its output is zeroed.

    Returns:
        Output with the same shape as ``x``.

    Raises:
        AssertionError: on unsupported arguments or inconsistent shapes.
    """
    assert x.dim() in (2, 3)
    assert conv_state.dim() == 3
    assert weight.dim() == 2
    assert activation in ['silu', 'swish', None]
    assert cache_seqlens is None, 'cache_seqlens is not supported in this version'
    if conv_state_indices is not None:
        assert conv_state_indices.dim() == 1 and conv_state_indices.is_contiguous()
        assert conv_state_indices.dtype == torch.int32

    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)

    has_bias = bias is not None
    width = weight.size(-1)
    batch, hidden_size, seqlen = x.shape
    state_len = conv_state.size(-1)

    # The kernel body does not bounds-check its state/weight indices, so
    # reject inconsistent shapes here instead of reading/writing out of range.
    assert weight.size(0) == hidden_size, 'weight should be in shape of [hidden_size, width]'
    assert conv_state.size(1) == hidden_size, 'conv_state should be in shape of [batch, hidden_size, state_len]'
    # the kernel reads the last (width - 1) state entries as convolution history
    assert state_len >= width - 1, 'conv_state is shorter than width - 1'
    if conv_state_indices is None:
        # without indices, batch entry i uses conv_state[i]
        assert conv_state.size(0) >= batch, 'conv_state has fewer rows than the batch'

    out = x.new_empty(x.shape)

    num_warps = 2
    kernel = causal_conv1d_update_fwd(hidden_size=hidden_size,
                                      seqlen=seqlen,
                                      state_len=state_len,
                                      width=width,
                                      has_bias=has_bias,
                                      activation=activation,
                                      dtype=x.dtype,
                                      conv_stride=conv_state.stride(),
                                      num_warps=num_warps)

    kernel(x, conv_state, weight, bias, out, conv_state_indices)

    if unsqueeze:
        out = out.squeeze(-1)

    return out


================================================
FILE: lmdeploy/pytorch/kernels/cuda/ds_index.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl

from .utils import get_device_props


@triton.jit
def _fp8_index_kernel(
    q_ptr,
    q_s_ptr,
    k_cache_ptr,
    k_s_cache_ptr,
    cu_seqlen_q_ptr,
    k_seqlen_ptr,
    block_offset_ptr,
    out_ptr,
    stride_qm: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qd: tl.constexpr,
    stride_qsm: tl.constexpr,
    stride_qsh: tl.constexpr,
    stride_kb: tl.constexpr,
    stride_kn: tl.constexpr,
    stride_kd: tl.constexpr,
    stride_ksb: tl.constexpr,
    stride_ksn: tl.constexpr,
    stride_boff0,
    stride_boff1: tl.constexpr,
    stride_om,
    stride_on: tl.constexpr,
    max_q_seqlen,
    causal: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_SPLIT: tl.constexpr,
):
    """Fp8 index kernel: per-key index scores for one query token.

    Program (m_id, split_id) handles query slot ``m_id % max_q_seqlen`` of
    sequence ``m_id // max_q_seqlen``. It walks that sequence's paged key
    blocks ``split_id, split_id + NUM_SPLIT, ...`` and writes
    ``out[q_pos, n] = k_s[n] * sum_h q_s[h] * relu(q[h] . k[n])``.
    With ``causal`` set, keys after the query's causal position get ``-inf``.
    Columns ``>= k_seqlen`` are not written.

    NOTE(review): BLOCK_H/BLOCK_N/BLOCK_D are fed to tl.arange, so they are
    presumably powers of two — confirm.
    """
    m_id = tl.program_id(0).to(tl.int64)
    split_id = tl.program_id(1).to(tl.int64)

    assert stride_qd == 1
    assert stride_kd == 1

    batch_id = m_id // max_q_seqlen
    q_id = m_id % max_q_seqlen
    # cu_seqlen_q holds cumulative offsets: entry b + 1 is read here
    q_start = tl.load(cu_seqlen_q_ptr + batch_id)
    q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_start
    # slots beyond this sequence's length have nothing to do
    if q_id >= q_seqlen:
        return

    k_seqlen = tl.load(k_seqlen_ptr + batch_id)
    if k_seqlen <= 0:
        return

    q_pos = q_start + q_id
    offs_h = tl.arange(0, BLOCK_H)
    offs_d = tl.arange(0, BLOCK_D)
    offs_n = tl.arange(0, BLOCK_N)

    q_ptrs = q_ptr + q_pos * stride_qm + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_s_ptrs = q_s_ptr + q_pos * stride_qsm + offs_h * stride_qsh
    q = tl.load(q_ptrs)
    q_s = tl.load(q_s_ptrs)

    # k tile is addressed as (BLOCK_D, BLOCK_N) so it can be fed to tl.dot directly
    k_ptrs = k_cache_ptr + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    k_s_ptrs = k_s_cache_ptr + offs_n * stride_ksn
    o_ptrs = out_ptr + q_pos * stride_om + offs_n * stride_on + split_id * BLOCK_N * stride_on
    boff_ptr = block_offset_ptr + batch_id * stride_boff0 + split_id * stride_boff1

    # last key position this query may attend to
    causal_pos = k_seqlen - q_seqlen + q_id
    num_blocks = tl.cdiv(k_seqlen, BLOCK_N)
    # num_stages=3 requests pipelining of this loop
    for boff_id in tl.range(split_id, num_blocks, NUM_SPLIT, num_stages=3):
        # physical cache block for logical block boff_id
        boff = tl.load(boff_ptr)

        k = tl.load(k_ptrs + boff * stride_kb)
        k_s = tl.load(k_s_ptrs + boff * stride_ksb)

        logits = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32)
        logits = tl.dot(q, k, acc=logits)
        logits = tl.maximum(logits, 0) * q_s[:, None]
        # reduce over heads, then apply the per-key scale
        logits_sum = tl.sum(logits, axis=0) * k_s

        if causal:
            mask_off = boff_id * BLOCK_N + offs_n
            mask = mask_off <= causal_pos
            logits_sum = tl.where(mask, logits_sum, float('-inf'))

        tl.store(o_ptrs, logits_sum, mask=offs_n + boff_id * BLOCK_N < k_seqlen)
        boff_ptr += NUM_SPLIT * stride_boff1
        o_ptrs += NUM_SPLIT * BLOCK_N * stride_on


def fp8_index(q: torch.Tensor,
              q_s: torch.Tensor,
              k_cache: torch.Tensor,
              k_s_cache: torch.Tensor,
              cu_seqlen_q: torch.Tensor,
              k_seqlens: torch.Tensor,
              block_offset: torch.Tensor,
              max_q_seqlen: int = None,
              max_k_seqlen: int = None,
              causal: bool = False):
    """Fp8 index scores over a paged key cache.

    For every query token ``m`` and key position ``n`` of its sequence this
    computes ``out[m, n] = k_s[n] * sum_h q_s[m, h] * relu(q[m, h] . k[n])``.
    With ``causal=True``, keys after the query's causal position get ``-inf``.

    Args:
        q: (cum_seqlen, num_heads, head_dim), last dim contiguous.
        q_s: (cum_seqlen, num_heads) per-(token, head) multipliers.
        k_cache: (num_blocks, block_size, head_dim), last dim contiguous.
        k_s_cache: (num_blocks, block_size) per-key multipliers.
        cu_seqlen_q: (batch_size + 1,) cumulative query lengths.
        k_seqlens: (batch_size,) key length of each sequence.
        block_offset: (batch_size, max_blocks_per_seq) block table.
        max_q_seqlen: longest query length; inferred when None (1 for
            decoding, otherwise ``cum_seqlen`` as an upper bound).
        max_k_seqlen: width of the output; defaults to the whole cache
            (``num_blocks * block_size``).
        causal: mask keys beyond each query's causal position.

    Returns:
        float32 tensor (cum_seqlen, max_k_seqlen). Only columns below each
        sequence's ``k_seqlens`` entry are written; the rest are left
        uninitialized.
    """
    assert q.dim() == 3
    assert k_cache.dim() == 3
    assert q_s.dim() == 2
    assert k_s_cache.dim() == 2
    cum_seqlen, num_heads, head_dim = q.shape
    block_size = k_cache.size(1)
    batch_size = k_seqlens.numel()
    # one query token per sequence
    is_decoding = batch_size == cum_seqlen
    if max_k_seqlen is None:
        max_num_blocks = k_cache.size(0)
        max_k_seqlen = max_num_blocks * block_size

    # max q seqlen
    if is_decoding:
        if max_q_seqlen is None:
            max_q_seqlen = 1
        assert max_q_seqlen == 1
    elif max_q_seqlen is None:
        max_q_seqlen = cum_seqlen

    assert q.stride(-1) == 1 and k_cache.stride(-1) == 1

    out = q.new_empty((cum_seqlen, max_k_seqlen), dtype=torch.float32)

    # we better have a tensor to indicate batch id of each q
    # one program per (sequence, query slot); slots past a sequence's length exit early
    M = max_q_seqlen * batch_size
    if M == 0:
        # Empty batch: nothing to launch, and the split computation below
        # would divide by zero.
        return out

    num_warps = 4
    device_idx = q.device.index
    props = get_device_props(device_idx)
    num_sm = props['multi_processor_count']
    # estimated occupancy 12.5%
    warps_per_sm = props['warps_per_sm'] // 8
    assert warps_per_sm >= num_warps
    cta_per_sm = warps_per_sm // num_warps
    cta_per_device = num_sm * cta_per_sm
    # split each sequence's key blocks across extra programs when M alone
    # cannot fill the device
    NUM_SPLIT = max(1, triton.cdiv(cta_per_device, M))
    grid = (M, NUM_SPLIT)

    _fp8_index_kernel[grid](q,
                            q_s,
                            k_cache,
                            k_s_cache,
                            cu_seqlen_q,
                            k_seqlens,
                            block_offset,
                            out,
                            *q.stride(),
                            *q_s.stride(),
                            *k_cache.stride(),
                            *k_s_cache.stride(),
                            *block_offset.stride(),
                            *out.stride(),
                            max_q_seqlen=max_q_seqlen,
                            causal=causal,
                            BLOCK_H=num_heads,
                            BLOCK_N=block_size,
                            BLOCK_D=head_dim,
                            NUM_SPLIT=NUM_SPLIT,
                            num_warps=num_warps)
    return out


================================================
FILE: lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional

import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def _quant_int8(val):
    """Asymmetric per-row uint8 quantization of a 2-D block ``val``.

    Min/max are taken along axis 1. Returns ``(q_val, scales, zeros)`` such
    that ``val ~= (q_val - zeros[:, None]) * scales[:, None]``.

    NOTE(review): a constant row (max == min) gives ``scales == 0`` and the
    divisions below then produce inf/nan — confirm callers never hit this.
    """
    val_min = tl.min(val, 1)
    val_max = tl.max(val, 1)
    # map [val_min, val_max] onto the levels 0..255
    scales = (val_max - val_min) / 255
    zeros = -val_min / scales
    # val / scales + zeros == (val - val_min) / scales >= 0, so adding 0.5 before
    # the truncating uint8 cast rounds to the nearest level
    q_val = (val / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)
    return q_val, scales, zeros


@triton.jit
def _quant_int4(val1, val2):
    """Asymmetric per-row 4-bit quantization of two blocks packed into one uint8.

    ``val1`` and ``val2`` share one per-row min/max (taken along axis 1), and
    hence one ``scales``/``zeros`` pair. ``val1`` goes to the low nibble and
    ``val2`` to the high nibble of ``q_val``.

    NOTE(review): as in ``_quant_int8``, a constant row gives ``scales == 0``
    and inf/nan below — confirm callers never hit this.
    """
    val1 = val1.to(tl.float32)
    val2 = val2.to(tl.float32)
    val_min = tl.min(tl.minimum(val1, val2), 1)
    val_max = tl.max(tl.maximum(val1, val2), 1)
    # map [val_min, val_max] onto the levels 0..15
    scales = (val_max - val_min) / 15
    zeros = -val_min / scales
    # +0.5 then truncating cast rounds the non-negative value to nearest
    q_val1 = (val1 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)
    q_val2 = (val2 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)
    # pack: low nibble = val1, high nibble = val2
    q_val = q_val1 + q_val2 * 16
    return q_val, scales, zeros


@triton.jit
def _fill_kv_cache_kernel(
    KStates,
    VStates,
    KCaches,
    VCaches,
    QStartLoc,
    QSeqLens,
    KVSeqLens,
    BlockOffsets,
    is_decoding: tl.constexpr,
    head_dim: tl.constexpr,
    head_dim_v: tl.constexpr,
    stride_kss,
    stride_ksh,
    stride_ksd,
    stride_vss,
    stride_vsh,
    stride_vsd,
    stride_kcn: tl.constexpr,
    stride_kcb: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_vcn: tl.constexpr,
    stride_vcb: tl.constexpr,
    stride_vch: tl.constexpr,
    stride_vcd: tl.constexpr,
    stride_boff,
    BLOCK: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Fill kv cache kernel.

    Copies new key/value states (no quantization) into the paged cache.
    Program ids: axis 0 = head, axis 1 = page index relative to the page that
    holds the first new token, axis 2 = sequence in the batch.

    In decoding mode only the token at logical position ``history_seqlen``
    is written. Otherwise each program writes the slots of one page that hold
    new tokens, i.e. positions in ``[history_seqlen, kv_seqlen)``.

    Strides: ``stride_{k,v}s{s,h,d}`` are the (token, head, dim) strides of the
    states; ``stride_{k,v}c{n,b,h,d}`` are the (page, slot-in-page, head, dim)
    strides of the caches. ``BLOCK_DV == 0`` compiles out the value path.
    """
    batch_id = tl.program_id(2)
    head_id = tl.program_id(0)
    block_id = tl.program_id(1)

    q_startloc = tl.load(QStartLoc + batch_id)
    q_seqlen = tl.load(QSeqLens + batch_id)
    kv_seqlen = tl.load(KVSeqLens + batch_id)
    # Tokens already in the cache before this step.
    history_seqlen = kv_seqlen - q_seqlen

    # Logical page index, counted from the page holding the first new token.
    kv_block_id = history_seqlen // BLOCK + block_id

    if kv_seqlen <= 0:
        return

    # Programs past the last page of this sequence have nothing to write.
    if kv_block_id * BLOCK >= kv_seqlen:
        return

    if is_decoding:
        # One token: its slot inside the page and its row in the states tensor.
        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
        kv_mask = tl.full((1, ), 1, dtype=tl.int1)
        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
    else:
        page_offs = tl.arange(0, BLOCK)
        kv_offs = kv_block_id * BLOCK + page_offs
        # Only slots that hold new tokens are written; history slots are kept.
        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
        # Row in the states tensor for each page slot.
        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
        q_offs = token_off + page_offs

    # Physical page backing this logical page.
    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)

    d_off = tl.arange(0, BLOCK_D)
    mask_ks = kv_mask[:, None]
    mask_kc = mask_ks & (d_off[None, :] < head_dim)
    # Wrap padded lanes onto valid columns so loads stay in bounds; the store
    # mask (mask_kc) discards them.
    d_off = d_off % head_dim

    ks_ptr = KStates + head_id * stride_ksh
    ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd
    kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch
    kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd

    if BLOCK_DV > 0:
        dv_off = tl.arange(0, BLOCK_DV)
        mask_vs = kv_mask[:, None]
        mask_vc = mask_vs & (dv_off[None, :] < head_dim_v)
        dv_off = dv_off % head_dim_v
        vs_ptr = VStates + head_id * stride_vsh
        vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd
        vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch
        vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd

    k = tl.load(ks_ptrs, mask=mask_ks)
    if BLOCK_DV > 0:
        v = tl.load(vs_ptrs, mask=mask_vs)
    tl.store(kc_ptrs, k, mask=mask_kc)
    if BLOCK_DV > 0:
        tl.store(vc_ptrs, v, mask=mask_vc)


@triton.jit
def _fill_page_quant_int8(
    state_ptr,
    cache_ptr,
    scales_zeros_ptr,
    block_off,
    head_id,
    page_offs,
    q_offs,
    kv_mask,
    head_dim: tl.constexpr,
    stride_ss,
    stride_sh,
    stride_sd,
    stride_cn: tl.constexpr,
    stride_cb: tl.constexpr,
    stride_ch: tl.constexpr,
    stride_cd: tl.constexpr,
    stride_szn: tl.constexpr,
    stride_szb: tl.constexpr,
    stride_szh: tl.constexpr,
    stride_szd: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Fill page int8.

    Quantizes rows ``q_offs`` of one head of the states to uint8 and writes
    them to slots ``page_offs`` of page ``block_off``. Each token's scale goes
    to offset 0 and its zero point to offset ``stride_szd`` of its entry in
    ``scales_zeros_ptr``. Rows with ``kv_mask`` false are not stored.
    """
    d_off = tl.arange(0, BLOCK_D)
    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
    # Padded lanes wrap onto real columns: loads stay in bounds and the
    # duplicated values do not change the per-row min/max used for quantization.
    d_off = d_off % head_dim
    state_ptr = state_ptr + head_id * stride_sh
    state_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
    zeros_ptrs = scales_ptrs + stride_szd

    state = tl.load(state_ptrs, mask=kv_mask[:, None])
    state, scales, zeros = _quant_int8(state)

    tl.store(cache_ptrs, state, mask=mask_kc)
    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])


@triton.jit
def _fill_page_quant_int4(
    state_ptr,
    cache_ptr,
    scales_zeros_ptr,
    block_off,
    head_id,
    page_offs,
    q_offs,
    kv_mask,
    head_dim: tl.constexpr,
    stride_ss,
    stride_sh,
    stride_sd,
    stride_cn: tl.constexpr,
    stride_cb: tl.constexpr,
    stride_ch: tl.constexpr,
    stride_cd: tl.constexpr,
    stride_szn: tl.constexpr,
    stride_szb: tl.constexpr,
    stride_szh: tl.constexpr,
    stride_szd: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Fill page int4.

    ``head_dim`` is the packed (cache) width. Cache column ``d`` holds state
    column ``d`` in its low nibble and state column ``d + head_dim`` in its
    high nibble. Each token's scale goes to offset 0 and its zero point to
    offset ``stride_szd`` of its entry in ``scales_zeros_ptr``. Rows with
    ``kv_mask`` false are not stored.
    """
    d_off = tl.arange(0, BLOCK_D)
    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
    # Wrap padded lanes (d_off >= head_dim) onto real columns, as the int8 path
    # does. Loading them with a column mask and no `other` would leave undefined
    # values that enter the per-row min/max in `_quant_int4`. Duplicated real
    # values leave the min/max unchanged. Stores still use mask_kc.
    d_off = d_off % head_dim
    state_ptr = state_ptr + head_id * stride_sh
    state0_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
    state1_ptrs = state0_ptrs + head_dim * stride_sd
    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
    zeros_ptrs = scales_ptrs + stride_szd

    state0 = tl.load(state0_ptrs, mask=kv_mask[:, None])
    state1 = tl.load(state1_ptrs, mask=kv_mask[:, None])
    state, scales, zeros = _quant_int4(state0, state1)

    tl.store(cache_ptrs, state, mask=mask_kc)
    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])


@triton.jit
def _fill_page_quant(state_ptr, cache_ptr, scales_zeros_ptr, block_off, head_id, page_offs, q_offs, kv_mask,
                     head_dim: tl.constexpr, stride_ss, stride_sh, stride_sd, stride_cn: tl.constexpr,
                     stride_cb: tl.constexpr, stride_ch: tl.constexpr, stride_cd: tl.constexpr,
                     stride_szn: tl.constexpr, stride_szb: tl.constexpr, stride_szh: tl.constexpr,
                     stride_szd: tl.constexpr, BLOCK_D: tl.constexpr, quant_policy: tl.constexpr):
    """Fill page.

    Compile-time dispatch on ``quant_policy``: 8 -> ``_fill_page_quant_int8``,
    4 -> ``_fill_page_quant_int4``. Any other value fails at compile time via
    ``tl.static_assert``. All other arguments are forwarded unchanged.
    """
    if quant_policy == 8:
        return _fill_page_quant_int8(state_ptr,
                                     cache_ptr,
                                     scales_zeros_ptr,
                                     block_off,
                                     head_id,
                                     page_offs,
                                     q_offs,
                                     kv_mask,
                                     head_dim=head_dim,
                                     stride_ss=stride_ss,
                                     stride_sh=stride_sh,
                                     stride_sd=stride_sd,
                                     stride_cn=stride_cn,
                                     stride_cb=stride_cb,
                                     stride_ch=stride_ch,
                                     stride_cd=stride_cd,
                                     stride_szn=stride_szn,
                                     stride_szb=stride_szb,
                                     stride_szh=stride_szh,
                                     stride_szd=stride_szd,
                                     BLOCK_D=BLOCK_D)
    elif quant_policy == 4:
        return _fill_page_quant_int4(state_ptr,
                                     cache_ptr,
                                     scales_zeros_ptr,
                                     block_off,
                                     head_id,
                                     page_offs,
                                     q_offs,
                                     kv_mask,
                                     head_dim=head_dim,
                                     stride_ss=stride_ss,
                                     stride_sh=stride_sh,
                                     stride_sd=stride_sd,
                                     stride_cn=stride_cn,
                                     stride_cb=stride_cb,
                                     stride_ch=stride_ch,
                                     stride_cd=stride_cd,
                                     stride_szn=stride_szn,
                                     stride_szb=stride_szb,
                                     stride_szh=stride_szh,
                                     stride_szd=stride_szd,
                                     BLOCK_D=BLOCK_D)
    else:
        tl.static_assert(False, 'Unsupported quant policy')


@triton.jit
def _fill_kv_cache_quant_kernel(
    KStates,
    VStates,
    KCaches,
    VCaches,
    KScalesZeros,
    VScalesZeros,
    QStartLoc,
    QSeqLens,
    KVSeqLens,
    BlockOffsets,
    is_decoding: tl.constexpr,
    head_dim: tl.constexpr,
    head_dim_v: tl.constexpr,
    stride_kss,
    stride_ksh,
    stride_ksd,
    stride_vss,
    stride_vsh,
    stride_vsd,
    stride_kcn: tl.constexpr,
    stride_kcb: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_vcn: tl.constexpr,
    stride_vcb: tl.constexpr,
    stride_vch: tl.constexpr,
    stride_vcd: tl.constexpr,
    stride_kszn: tl.constexpr,
    stride_kszb: tl.constexpr,
    stride_kszh: tl.constexpr,
    stride_kszd: tl.constexpr,
    stride_vszn: tl.constexpr,
    stride_vszb: tl.constexpr,
    stride_vszh: tl.constexpr,
    stride_vszd: tl.constexpr,
    quant_policy: tl.constexpr,
    stride_boff,
    BLOCK: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Fill kv cache kernel with int4 and int8 quant fused.

    Uses the same page/token mapping as the unquantized fill kernel. Each
    written token is quantized per head by ``_fill_page_quant`` according to
    ``quant_policy`` (4 or 8), which also stores its scale and zero point.
    ``BLOCK_DV == 0`` compiles out the value path.

    Args:
        stride_xss: stride of sequence length dim of key or value states
        stride_xsh: stride of head_num dim of key or value states
        stride_xsd: stride of head_size dim of key or value states
        stride_xcn: stride of page num dim of key or value cache
        stride_xcb: stride of block size dim of key or value cache
        stride_xch: stride of head_num dim of key or value cache
        stride_xcd: stride of head_size dim of key or value cache
        stride_xsz{n,b,h,d}: the same four strides for the scales/zeros cache;
            along its last dim the scale is at offset 0 and the zero point at
            offset ``stride_xszd``.
    """
    batch_id = tl.program_id(2)
    head_id = tl.program_id(0)
    block_id = tl.program_id(1)

    q_startloc = tl.load(QStartLoc + batch_id)
    q_seqlen = tl.load(QSeqLens + batch_id)
    kv_seqlen = tl.load(KVSeqLens + batch_id)
    # Tokens already in the cache before this step.
    history_seqlen = kv_seqlen - q_seqlen

    # Logical page index, counted from the page holding the first new token.
    kv_block_id = history_seqlen // BLOCK + block_id

    if kv_seqlen <= 0:
        return

    # Programs past the last page of this sequence have nothing to write.
    if kv_block_id * BLOCK >= kv_seqlen:
        return

    if is_decoding:
        # One token: its slot inside the page and its row in the states tensor.
        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
        kv_mask = tl.full((1, ), 1, dtype=tl.int1)
        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
    else:
        page_offs = tl.arange(0, BLOCK)
        kv_offs = kv_block_id * BLOCK + page_offs
        # Only slots that hold new tokens are written; history slots are kept.
        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
        # Row in the states tensor for each page slot.
        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
        q_offs = token_off + page_offs

    # Physical page backing this logical page.
    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)

    _fill_page_quant(KStates,
                     KCaches,
                     KScalesZeros,
                     block_off,
                     head_id,
                     page_offs,
                     q_offs,
                     kv_mask,
                     head_dim=head_dim,
                     stride_ss=stride_kss,
                     stride_sh=stride_ksh,
                     stride_sd=stride_ksd,
                     stride_cn=stride_kcn,
                     stride_cb=stride_kcb,
                     stride_ch=stride_kch,
                     stride_cd=stride_kcd,
                     stride_szn=stride_kszn,
                     stride_szb=stride_kszb,
                     stride_szh=stride_kszh,
                     stride_szd=stride_kszd,
                     BLOCK_D=BLOCK_D,
                     quant_policy=quant_policy)

    if BLOCK_DV > 0:
        _fill_page_quant(VStates,
                         VCaches,
                         VScalesZeros,
                         block_off,
                         head_id,
                         page_offs,
                         q_offs,
                         kv_mask,
                         head_dim=head_dim_v,
                         stride_ss=stride_vss,
                         stride_sh=stride_vsh,
                         stride_sd=stride_vsd,
                         stride_cn=stride_vcn,
                         stride_cb=stride_vcb,
                         stride_ch=stride_vch,
                         stride_cd=stride_vcd,
                         stride_szn=stride_vszn,
                         stride_szb=stride_vszb,
                         stride_szh=stride_vszh,
                         stride_szd=stride_vszd,
                         BLOCK_D=BLOCK_DV,
                         quant_policy=quant_policy)


def fill_kv_cache(k_states: Tensor,
                  v_states: Optional[Tensor],
                  k_caches: Tensor,
                  v_caches: Optional[Tensor],
                  q_start_loc: Tensor,
                  q_seq_length: Tensor,
                  kv_seq_length: Tensor,
                  max_q_seq_length: int,
                  block_offsets: Tensor,
                  k_scales_zeros: Tensor = None,
                  v_scales_zeros: Tensor = None,
                  quant_policy: Literal[0, 4, 8] = 0,
                  kv_layout: str = 'bshd'):
    """Fill key/value state to cache for paged attention.

    Args:
        k_states (Tensor): New key states; the last three dims are read as
            (token, head, dim).
        v_states (Optional[Tensor]): New value states, same convention as
            ``k_states``. If None, only keys are written.
        k_caches (Tensor): 4D paged key cache laid out per ``kv_layout``.
        v_caches (Optional[Tensor]): 4D paged value cache. If None, or if it
            shares storage with ``k_caches`` and is not wider, no separate
            value write is done.
        q_start_loc (Tensor): Row in the states where each sequence's new
            tokens start, indexed by batch id.
        q_seq_length (Tensor): Number of new tokens per sequence.
        kv_seq_length (Tensor): Total kv length (history + new) per sequence.
        max_q_seq_length (int): Max of ``q_seq_length``; 1 selects the
            decoding path.
        block_offsets (Tensor): Physical page of each logical page, shape
            (batch_size, max_num_blocks).
        k_scales_zeros (Tensor): Scale/zero cache for keys; required when
            ``quant_policy`` is 4 or 8.
        v_scales_zeros (Tensor): Scale/zero cache for values. Required when
            quantizing and values are written separately. Otherwise it may be
            None.
        quant_policy (int): 0 (no quantization), 4 (packed int4) or 8 (int8).
        kv_layout (str): ``'bshd'`` or ``'bhsd'``.

    Raises:
        RuntimeError: If ``kv_layout`` is not supported.
        ValueError: If ``quant_policy`` is not supported or a required
            scale/zero cache is missing.
    """
    if kv_layout == 'bshd':
        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
    elif kv_layout == 'bhsd':
        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
    else:
        raise RuntimeError('Unsupported layout.')
    # Fail early with a clear message rather than a Triton static_assert or an
    # AttributeError on a missing scale/zero cache.
    if quant_policy not in (0, 4, 8):
        raise ValueError(f'Unsupported quant policy: {quant_policy}.')
    if quant_policy != 0 and k_scales_zeros is None:
        raise ValueError(f'k_scales_zeros is required for quant_policy={quant_policy}.')
    if v_states is None:
        v_states = k_states[..., :0]
    if v_caches is None:
        v_caches = k_caches[..., :0]

    block_offsets = block_offsets.contiguous()
    batch_size = block_offsets.size(0)
    block_size = k_caches.size(s_dim)
    num_heads = k_caches.size(h_dim)
    head_dim = k_caches.size(d_dim)
    head_dim_v = v_caches.size(d_dim)
    if v_states.size(-1) == 0:
        head_dim_v = 0
    if max_q_seq_length == 1:
        max_num_blocks = 1
    else:
        # New tokens need not start on a page boundary, so they may span one
        # more page than cdiv(max_q_seq_length, block_size).
        max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

    BLOCK = block_size
    BLOCK_D = triton.next_power_of_2(head_dim)
    BLOCK_DV = triton.next_power_of_2(head_dim_v)
    if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim:
        # Values live inside the key cache and are written by the key store.
        BLOCK_DV = 0
    if quant_policy != 0 and v_scales_zeros is None:
        if BLOCK_DV != 0:
            raise ValueError(f'v_scales_zeros is required for quant_policy={quant_policy}.')
        # The value branch is compiled out (BLOCK_DV == 0), so these strides are
        # never used; a zero-width view only satisfies the kernel signature.
        v_scales_zeros = k_scales_zeros[..., :0]
    grid = (num_heads, max_num_blocks, batch_size)
    is_decoding = max_num_blocks == 1
    if quant_policy == 0:
        _fill_kv_cache_kernel[grid](
            k_states,
            v_states,
            k_caches,
            v_caches,
            q_start_loc,
            q_seq_length,
            kv_seq_length,
            block_offsets,
            is_decoding=is_decoding,
            head_dim=head_dim,
            head_dim_v=head_dim_v,
            stride_kss=k_states.stride(-3),
            stride_ksh=k_states.stride(-2),
            stride_ksd=k_states.stride(-1),
            stride_vss=v_states.stride(-3),
            stride_vsh=v_states.stride(-2),
            stride_vsd=v_states.stride(-1),
            stride_kcn=k_caches.stride(b_dim),
            stride_kcb=k_caches.stride(s_dim),
            stride_kch=k_caches.stride(h_dim),
            stride_kcd=k_caches.stride(d_dim),
            stride_vcn=v_caches.stride(b_dim),
            stride_vcb=v_caches.stride(s_dim),
            stride_vch=v_caches.stride(h_dim),
            stride_vcd=v_caches.stride(d_dim),
            stride_boff=block_offsets.stride(0),
            BLOCK=BLOCK,
            BLOCK_D=BLOCK_D,
            BLOCK_DV=BLOCK_DV,
            num_warps=4,
            num_stages=3,
        )
    else:
        _fill_kv_cache_quant_kernel[grid](
            k_states,
            v_states,
            k_caches,
            v_caches,
            k_scales_zeros,
            v_scales_zeros,
            q_start_loc,
            q_seq_length,
            kv_seq_length,
            block_offsets,
            is_decoding=is_decoding,
            head_dim=head_dim,
            head_dim_v=head_dim_v,
            stride_kss=k_states.stride(-3),
            stride_ksh=k_states.stride(-2),
            stride_ksd=k_states.stride(-1),
            stride_vss=v_states.stride(-3),
            stride_vsh=v_states.stride(-2),
            stride_vsd=v_states.stride(-1),
            stride_kcn=k_caches.stride(b_dim),
            stride_kcb=k_caches.stride(s_dim),
            stride_kch=k_caches.stride(h_dim),
            stride_kcd=k_caches.stride(d_dim),
            stride_vcn=v_caches.stride(b_dim),
            stride_vcb=v_caches.stride(s_dim),
            stride_vch=v_caches.stride(h_dim),
            stride_vcd=v_caches.stride(d_dim),
            stride_kszn=k_scales_zeros.stride(b_dim),
            stride_kszb=k_scales_zeros.stride(s_dim),
            stride_kszh=k_scales_zeros.stride(h_dim),
            stride_kszd=k_scales_zeros.stride(d_dim),
            stride_vszn=v_scales_zeros.stride(b_dim),
            stride_vszb=v_scales_zeros.stride(s_dim),
            stride_vszh=v_scales_zeros.stride(h_dim),
            stride_vszd=v_scales_zeros.stride(d_dim),
            quant_policy=quant_policy,
            stride_boff=block_offsets.stride(0),
            BLOCK=BLOCK,
            BLOCK_D=BLOCK_D,
            BLOCK_DV=BLOCK_DV,
            num_warps=4,
            num_stages=1,
        )


@triton.jit
def fast_log2_ceil(x):
    """Return ``ceil(log2(x))`` as int32 by reading the float32 bit pattern.

    The biased exponent field minus 127 gives ``floor(log2(x))``. One is added
    when any mantissa bit is set, i.e. when ``x`` is not an exact power of two.
    Only meaningful for positive normal float32 inputs (zero, subnormals,
    inf and nan are not special-cased).
    """
    bits_x = tl.cast(x, tl.uint32, bitcast=True)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    # May wrap around in uint32 for x < 1; the final int32 cast recovers the
    # negative exponent.
    tmp = exp_x - 127 + tl.where(man_bits != 0, 1, 0)
    return tl.cast(tmp, tl.int32)


@triton.jit
def fast_pow2(x):
    """Return ``2.0 ** x`` as float32 for an integer exponent ``x``.

    Writes ``x + 127`` into the float32 exponent field with a zero mantissa.
    Only valid while ``x + 127`` stays within the normal exponent range
    (``-126 <= x <= 127``).
    """
    bits_x = (x + 127) << 23
    return tl.cast(bits_x, tl.float32, bitcast=True)


@triton.jit
def fast_round_scale(amax, fp8_max_inv):
    """Return ``amax * fp8_max_inv`` rounded up to the next power of two.

    Used by ``_quant_blocked_fp8`` when ``ROUND_SCALE == 1``, which the host
    launcher selects for ``scale_fmt == 'ue8m0'``.
    """
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))


@triton.jit
def _quant_blocked_fp8(x,
                       fp8_min: tl.constexpr,
                       fp8_max: tl.constexpr,
                       dtype: tl.constexpr,
                       GROUP_SIZE: tl.constexpr = 128,
                       ROUND_SCALE: tl.constexpr = 0):
    """Quantize a 2-D ``(M, N)`` tile to fp8 with one scale per column group.

    For each row and each contiguous group of ``GROUP_SIZE`` columns, the
    scale is ``amax / fp8_max``, with ``amax`` floored at 1e-6 so the
    division is safe. With ``ROUND_SCALE == 1`` that value is rounded up to a
    power of two. Values are divided by the scale, clamped to
    ``[fp8_min, fp8_max]`` and cast to ``dtype``, so ``out * scale``
    approximates ``x``.

    ``N`` must be a multiple of ``GROUP_SIZE`` for the reshape below.

    Returns:
        (out, scale): ``out`` is ``(M, N)`` in ``dtype``; ``scale`` is float32
        with shape ``(M, N // GROUP_SIZE)``.
    """
    x = x.to(tl.float32)
    M: tl.constexpr = x.shape[0]
    N: tl.constexpr = x.shape[1]
    rfp8_max: tl.constexpr = 1 / fp8_max
    x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)
    amax = tl.maximum(tl.max(tl.abs(x), axis=2, keep_dims=True), 1e-6)
    if ROUND_SCALE == 1:
        scale = fast_round_scale(amax, rfp8_max)
    else:
        scale = amax * rfp8_max
    out = x / scale

    out = tl.clamp(out, fp8_min, fp8_max)
    out = out.to(dtype)
    out = out.reshape(M, N)
    scale = scale.reshape(M, N // GROUP_SIZE)
    return out, scale


@triton.jit
def _fill_kv_cache_blocked_fp8_kernel(
    KStates,
    VStates,
    KCaches,
    VCaches,
    KSCaches,
    VSCaches,
    cu_seqlen_q_ptr,
    KVSeqLens,
    BlockOffsets,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    is_decoding: tl.constexpr,
    head_dim: tl.constexpr,
    head_dim_v: tl.constexpr,
    stride_kss,
    stride_ksh,
    stride_ksd,
    stride_vss,
    stride_vsh,
    stride_vsd,
    stride_kcn: tl.constexpr,
    stride_kcb: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_vcn: tl.constexpr,
    stride_vcb: tl.constexpr,
    stride_vch: tl.constexpr,
    stride_vcd: tl.constexpr,
    stride_kscn: tl.constexpr,
    stride_kscb: tl.constexpr,
    stride_ksch: tl.constexpr,
    stride_kscd: tl.constexpr,
    stride_vscn: tl.constexpr,
    stride_vscb: tl.constexpr,
    stride_vsch: tl.constexpr,
    stride_vscd: tl.constexpr,
    stride_boff,
    ROUND_SCALE: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Fill kv cache kernel with per-group fp8 quantization.

    Uses the same page/token mapping as the unquantized fill kernel, except
    that query starts/lengths come from ``cu_seqlen_q_ptr`` (cumulative
    lengths). Each written token is quantized by ``_quant_blocked_fp8`` with
    one scale per ``GROUP_SIZE`` head-dim columns. The scales go to
    ``KSCaches``/``VSCaches``. ``BLOCK_DV == 0`` compiles out the value path.
    """
    batch_id = tl.program_id(2)
    head_id = tl.program_id(0)
    block_id = tl.program_id(1)

    q_startloc = tl.load(cu_seqlen_q_ptr + batch_id)
    q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_startloc
    kv_seqlen = tl.load(KVSeqLens + batch_id)
    # Tokens already in the cache before this step.
    history_seqlen = kv_seqlen - q_seqlen

    # Logical page index, counted from the page holding the first new token.
    kv_block_id = history_seqlen // BLOCK + block_id

    if kv_seqlen <= 0:
        return

    if kv_block_id * BLOCK >= kv_seqlen:
        return

    if is_decoding:
        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
        kv_mask = tl.full((1, ), 1, dtype=tl.int1)
        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
    else:
        page_offs = tl.arange(0, BLOCK)
        kv_offs = kv_block_id * BLOCK + page_offs
        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
        q_offs = token_off + page_offs

    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)

    d_off = tl.arange(0, BLOCK_D)
    mask_ks = kv_mask[:, None]
    mask_kc = mask_ks & (d_off[None, :] < head_dim)
    # Keep pointers in bounds for padded lanes; loads/stores mask them out.
    d_off = d_off % head_dim

    BLOCK_DS: tl.constexpr = (BLOCK_D + GROUP_SIZE - 1) // GROUP_SIZE
    ds_off = tl.arange(0, BLOCK_DS)

    ks_ptr = KStates + head_id * stride_ksh
    ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd
    kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch
    kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd
    ksc_ptr = KSCaches + block_off * stride_kscn + head_id * stride_ksch
    ksc_ptrs = ksc_ptr + page_offs[:, None] * stride_kscb + ds_off[None, :] * stride_kscd

    if BLOCK_DV > 0:
        dv_off = tl.arange(0, BLOCK_DV)
        mask_vs = kv_mask[:, None]
        mask_vc = mask_vs & (dv_off[None, :] < head_dim_v)

        BLOCK_DVS: tl.constexpr = (BLOCK_DV + GROUP_SIZE - 1) // GROUP_SIZE
        dvs_off = tl.arange(0, BLOCK_DVS)

        dv_off = dv_off % head_dim_v
        vs_ptr = VStates + head_id * stride_vsh
        vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd
        vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch
        vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd
        vsc_ptr = VSCaches + block_off * stride_vscn + head_id * stride_vsch
        vsc_ptrs = vsc_ptr + page_offs[:, None] * stride_vscb + dvs_off[None, :] * stride_vscd

    # Zero-fill padded columns (d >= head_dim). With only the row mask, the
    # wrapped offsets would load duplicated real values into the padding and
    # inflate the amax of the group that straddles head_dim.
    k = tl.load(ks_ptrs, mask=mask_kc, other=0.0)
    if BLOCK_DV > 0:
        v = tl.load(vs_ptrs, mask=mask_vc, other=0.0)
    kc, kcs = _quant_blocked_fp8(k, fp8_min, fp8_max, KCaches.dtype.element_ty, GROUP_SIZE, ROUND_SCALE)
    tl.store(kc_ptrs, kc, mask=mask_kc)
    tl.store(ksc_ptrs, kcs, mask=kv_mask[:, None] & (ds_off[None, :] < tl.cdiv(head_dim, GROUP_SIZE)))
    if BLOCK_DV > 0:
        vc, vcs = _quant_blocked_fp8(v, fp8_min, fp8_max, VCaches.dtype.element_ty, GROUP_SIZE, ROUND_SCALE)
        tl.store(vc_ptrs, vc, mask=mask_vc)
        # Use the value-side group offsets (dvs_off, length BLOCK_DVS) so the
        # mask matches vsc_ptrs; the key-side ds_off has length BLOCK_DS.
        tl.store(vsc_ptrs, vcs, mask=kv_mask[:, None] & (dvs_off[None, :] < tl.cdiv(head_dim_v, GROUP_SIZE)))


def fill_kv_cache_blocked_fp8(k_states: Tensor,
                              v_states: Optional[Tensor],
                              k_caches: Tensor,
                              v_caches: Optional[Tensor],
                              ks_caches: Tensor,
                              vs_caches: Optional[Tensor],
                              cu_seqlen_q: Tensor,
                              kv_seqlens: Tensor,
                              max_q_seqlen: int,
                              block_offsets: Tensor,
                              group_size: int = 128,
                              kv_layout: str = 'bshd',
                              scale_fmt: Optional[str] = None):
    """Quantize key/value states to blocked fp8 and write them to the paged cache.

    Args:
        k_states (Tensor): Key states; the last three dims are read as
            (seq_length, num_heads, head_dim).
        v_states (Optional[Tensor]): Value states laid out like ``k_states``
            with last dim ``head_dim_v``. If None, values are not written.
        k_caches (Tensor): 4D fp8 key cache laid out per ``kv_layout``; its
            dtype also defines the fp8 clamp range.
        v_caches (Optional[Tensor]): 4D fp8 value cache. If None, or if it
            shares storage with ``k_caches`` and is not wider, values are not
            written separately.
        ks_caches (Tensor): 4D key scale cache, one scale per ``group_size``
            head-dim columns, laid out per ``kv_layout``.
        vs_caches (Optional[Tensor]): 4D value scale cache; may be None when
            values are not written separately.
        cu_seqlen_q (Tensor): Cumulative query lengths, shape (batch_size + 1, ).
        kv_seqlens (Tensor): Total kv length (history + new) per sequence,
            shape (batch_size, ).
        max_q_seqlen (int): Maximum query length; 1 selects the decoding path.
        block_offsets (Tensor): Physical page of each logical page, shape
            (batch_size, max_num_blocks).
        group_size (int, optional): Head-dim columns sharing one scale.
            Default is 128.
        kv_layout (str, optional): ``'bshd'`` or ``'bhsd'``. Default ``'bshd'``.
        scale_fmt (str, optional): ``None`` or ``'ue8m0'``. With ``'ue8m0'``
            every scale is rounded up to a power of two.
    """
    assert scale_fmt in (None, 'ue8m0'), f'Unsupported scale format: {scale_fmt}.'

    known_layouts = (('bshd', (0, 1, 2, 3)), ('bhsd', (0, 2, 1, 3)))
    cache_dims = next((dims for name, dims in known_layouts if name == kv_layout), None)
    if cache_dims is None:
        raise RuntimeError('Unsupported layout.')
    _, s_dim, h_dim, d_dim = cache_dims

    # Absent value-side tensors become zero-width views of their key-side
    # counterparts so the kernel still receives valid pointers and strides.
    if v_states is None:
        v_states = k_states[..., :0]
    if v_caches is None:
        v_caches = k_caches[..., :0]
    if vs_caches is None:
        vs_caches = ks_caches[..., :0]

    block_offsets = block_offsets.contiguous()
    head_dim = k_caches.size(d_dim)
    head_dim_v = v_states.size(-1)
    block_size = k_caches.size(s_dim)

    # Decoding writes a single token per sequence. Otherwise the new tokens
    # may span one more page than cdiv() because they need not be page-aligned.
    is_decoding = max_q_seqlen == 1
    max_num_blocks = 1 if is_decoding else triton.cdiv(max_q_seqlen, block_size) + 1
    grid = (k_caches.size(h_dim), max_num_blocks, block_offsets.size(0))

    # A value cache aliasing the key cache (and not wider) is filled by the key
    # store, so the value branch is compiled out.
    v_aliases_k = k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim
    BLOCK_DV = 0 if v_aliases_k else triton.next_power_of_2(head_dim_v)

    fp8_info = torch.finfo(k_caches.dtype)

    def _cache_strides(prefix: str, cache: Tensor) -> dict:
        """Map a 4D cache's strides to the kernel's ``stride_<prefix>{n,b,h,d}`` names."""
        return {f'stride_{prefix}{axis}': cache.stride(dim) for axis, dim in zip('nbhd', cache_dims)}

    def _state_strides(prefix: str, states: Tensor) -> dict:
        """Map (token, head, dim) strides of ``states`` to ``stride_<prefix>{s,h,d}``."""
        return {f'stride_{prefix}{axis}': states.stride(dim) for axis, dim in zip('shd', (-3, -2, -1))}

    _fill_kv_cache_blocked_fp8_kernel[grid](
        k_states,
        v_states,
        k_caches,
        v_caches,
        ks_caches,
        vs_caches,
        cu_seqlen_q,
        kv_seqlens,
        block_offsets,
        fp8_min=fp8_info.min,
        fp8_max=fp8_info.max,
        is_decoding=is_decoding,
        head_dim=head_dim,
        head_dim_v=head_dim_v,
        **_state_strides('ks', k_states),
        **_state_strides('vs', v_states),
        **_cache_strides('kc', k_caches),
        **_cache_strides('vc', v_caches),
        **_cache_strides('ksc', ks_caches),
        **_cache_strides('vsc', vs_caches),
        stride_boff=block_offsets.stride(0),
        ROUND_SCALE=1 if scale_fmt == 'ue8m0' else 0,
        GROUP_SIZE=group_size,
        BLOCK=block_size,
        BLOCK_D=triton.next_power_of_2(head_dim),
        BLOCK_DV=BLOCK_DV,
        num_warps=4,
    )


================================================
FILE: lmdeploy/pytorch/kernels/cuda/flashattention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence

import torch
import triton
import triton.language as tl
from packaging import version
from torch import Tensor

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

# Installed Triton version, checked against the minimum this module needs.
TRITON_VERSION = version.parse(triton.__version__)
VERSION_300 = version.parse('3.0.0')
VERSION_320 = version.parse('3.2.0')
# Kernels below require triton >= 3.0.0; fail at import time otherwise.
assert TRITON_VERSION >= VERSION_300

# TODO: fast op might not work on non-nv device
# tanh comes from the CUDA libdevice bindings (NVIDIA-specific).
tanh = tl.extra.cuda.libdevice.tanh
tl_log2 = tl.log2
tl_exp2 = tl.exp2


def _get_block_d(head_dim_k, head_dim_v):
    """Choose Triton block widths for the key and value head dims.

    A power-of-two key dim is covered by a single block (``BLOCK_DK1 == 0``).
    Otherwise the key dim is split into a main block of half its next power of
    two plus a tail block, at least 16 wide, covering the remainder. For
    example, 576 becomes 512 + 64 instead of being padded to 1024.

    Returns:
        tuple: ``(BLOCK_DK, BLOCK_DK1, BLOCK_DV)``.
    """
    padded_k = triton.next_power_of_2(head_dim_k)
    block_dv = triton.next_power_of_2(head_dim_v)
    if padded_k == head_dim_k:
        return padded_k, 0, block_dv
    main_k = padded_k // 2
    tail_k = max(16, triton.next_power_of_2(head_dim_k - main_k))
    return main_k, tail_k, block_dv


@triton.jit
def softcapping(qk, logit_softcapping: tl.constexpr):
    """Soft-cap attention logits as ``cap * tanh(qk / cap)``.

    ``logit_softcapping`` is a compile-time constant; a value ``<= 0``
    disables capping and returns ``qk`` unchanged.
    """
    if logit_softcapping > 0.0:
        qk = tanh(qk / logit_softcapping) * logit_softcapping
    return qk


@triton.jit
def _load_kv(ptrs, boundary_check: tl.constexpr):
    """Load kv.

    ``boundary_check`` is a compile-time constant. When it is not None, the
    load bounds-checks those dims and zero-fills out-of-range elements.
    Otherwise a plain unchecked load is emitted.

    NOTE(review): ``boundary_check``/``padding_option`` apply to block
    pointers, so ``ptrs`` is presumably a ``tl.make_block_ptr`` pointer (the
    caller advances it with ``tl.advance``) — confirm at call sites.
    """
    if boundary_check is not None:
        return tl.load(ptrs, boundary_check=boundary_check, padding_option='zero')
    else:
        return tl.load(ptrs)


@triton.jit
def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, alibi_slope,
                       global_offs_m, history_mask, kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr,
                       logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr,
                       shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_DK1: tl.constexpr):
    """Accumulate key positions [loop_start, loop_end) into the online-softmax state.

    Walks the key range in ``BLOCK_N`` steps and updates ``acc`` (unnormalized
    output), ``l_i`` (running softmax denominator) and ``m_i`` (running row
    max, in log2 units) for one query tile.

    ``k_ptrs``/``k1_ptrs`` are advanced along their second dimension (keys laid
    out as (head_dim, seq)); ``v_ptrs`` along its first (values as
    (seq, head_dim)). When ``BLOCK_DK1`` is nonzero the head dim is split into
    two tiles and ``q1``/``k1_ptrs`` cover the second one.

    With ``causal_mask``, key positions greater than ``history_mask`` (per query
    row) are masked; if ``block_sparse_size > 1`` key positions are first
    rounded down to multiples of it. ``window_size > 0`` additionally masks key
    positions below ``kv_min_loc``. With ``shared_kv`` the value tile is the
    transposed key tile. ``k_bound``/``v_bound`` are forwarded to ``_load_kv``.

    Returns the updated ``(acc, l_i, m_i)``.
    """
    k_ptrs = tl.advance(k_ptrs, (0, loop_start))
    v_ptrs = tl.advance(v_ptrs, (loop_start, 0))
    if BLOCK_DK1:
        k1_ptrs = tl.advance(k1_ptrs, (0, loop_start))

    offs_n = tl.arange(0, BLOCK_N)
    for start_n in range(loop_start, loop_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        k = _load_kv(k_ptrs, boundary_check=k_bound)
        qk = tl.dot(q, k)

        if BLOCK_DK1 != 0:
            # second head-dim tile
            k1 = _load_kv(k1_ptrs, boundary_check=k_bound)
            qk += tl.dot(q1, k1)

        if causal_mask:
            # scale, cap, convert to log2 units, then mask before taking the max
            qk *= sm_scale
            qk = softcapping(qk, logit_softcapping)
            qk = qk * tl_log2(math.e)
            if block_sparse_size > 1:
                offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size
                qk_mask = (history_mask[:, None]) >= offs_mask[None, :]
            else:
                qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
            if window_size > 0:
                qk_mask = qk_mask & ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
            qk = tl.where(
                qk_mask,
                qk,
                float(-1e30),
            )
            m_i_new = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_i_new[:, None]
        elif window_size > 0:
            # no causal mask, only the left edge of the sliding window
            qk *= sm_scale
            qk = softcapping(qk, logit_softcapping)
            qk = qk * tl_log2(math.e)
            qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
            qk = tl.where(
                qk_mask,
                qk,
                float(-1e30),
            )
            m_i_new = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_i_new[:, None]
        elif logit_softcapping > 0:
            qk *= sm_scale
            qk = softcapping(qk, logit_softcapping)
            qk = qk * tl_log2(math.e)
            m_i_new = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_i_new[:, None]
        else:
            # fast path: fold sm_scale and log2(e) into a single multiply
            qk_scale = sm_scale * tl_log2(math.e)
            m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_i_new[:, None]

        if alibi_slope is not None:
            # The bias is added after subtracting m_i_new. Assuming a non-negative
            # slope it is <= 0, so exp2 below stays <= 1; the softmax result is
            # unaffected because m_i is used consistently for rescaling.
            relative_pos = start_n + offs_n[None, :] - global_offs_m[:, None]
            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slope * tl_log2(math.e)
            qk += bias

        # -- compute p, m_i and l_i
        p = tl_exp2(qk)
        # rescale factor for everything accumulated under the previous max
        alpha = tl_exp2(m_i - m_i_new)
        l_i = alpha * l_i + tl.sum(p, 1)
        # -- update output accumulator --
        # scale acc
        acc = acc * alpha[:, None]

        # update acc
        if shared_kv:
            v = tl.trans(k)
        else:
            v = _load_kv(v_ptrs, boundary_check=v_bound)
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        m_i = m_i_new

        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0))
        if BLOCK_DK1:
            k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N))

    return acc, l_i, m_i


# # FOR DEBUG, DON'T REMOVE
# import itertools
# configs = [
#     triton.Config({
#         'BLOCK_M': BM,
#         'BLOCK_N': BN
#     }, num_stages=s, num_warps=w)
#     for BM, BN, s, w in itertools.product([64, 128], [32, 64], [3, 4], [4])
# ]


# @triton.autotune(list(configs),
#                  key=['head_dim_k', 'head_dim_v'])
@triton.jit
def _flash_prefill_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    cu_seqlens_q_ptr,
    cu_seqlens_k_ptr,
    q_start_loc_ptr,
    q_seqlens_ptr,
    kv_start_loc_ptr,
    kv_seqlens_ptr,
    sinks,
    alibi_slopes_ptr,
    sm_scale,
    stride_qs: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qd: tl.constexpr,
    stride_ks: tl.constexpr,
    stride_kh,
    stride_kd: tl.constexpr,
    stride_vs: tl.constexpr,
    stride_vh,
    stride_vd: tl.constexpr,
    stride_os: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_od: tl.constexpr,
    kv_group_num,
    head_dim_k: tl.constexpr,
    head_dim_v: tl.constexpr,
    causal: tl.constexpr,
    window_size: tl.constexpr,
    logit_softcapping: tl.constexpr,
    shared_kv: tl.constexpr,
    block_sparse_size: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DK1: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Flash attention kernel.

    Program ids are (query tile, head, sequence). Each program computes
    attention for ``BLOCK_M`` query rows of one head of one variable-length
    sequence.

    Sequence bounds come from ``cu_seqlens_*`` when given, otherwise from the
    explicit ``*_start_loc``/``*_seqlens`` pointers. Queries are aligned to the
    end of the kv sequence (``history_len = kv_seqlen - q_seqlen``).

    Keys are processed in two passes:
      1. full ``BLOCK_N`` blocks that need no per-element causal mask;
      2. the tail (diagonal blocks for causal, remainder otherwise) with
         masking and sequence-dim bounds checks.
    """
    start_m = tl.program_id(0)
    head_id = tl.program_id(1)
    batch_id = tl.program_id(2)

    if cu_seqlens_q_ptr is not None:
        q_start_loc = tl.load(cu_seqlens_q_ptr + batch_id).to(tl.int32)
        q_seqlen = tl.load(cu_seqlens_q_ptr + batch_id + 1).to(tl.int32) - q_start_loc
    else:
        q_start_loc = tl.load(q_start_loc_ptr + batch_id).to(tl.int32)
        q_seqlen = tl.load(q_seqlens_ptr + batch_id).to(tl.int32)

    if cu_seqlens_k_ptr is not None:
        kv_start_loc = tl.load(cu_seqlens_k_ptr + batch_id).to(tl.int32)
        kv_seqlen = tl.load(cu_seqlens_k_ptr + batch_id + 1).to(tl.int32) - kv_start_loc
    else:
        kv_start_loc = tl.load(kv_start_loc_ptr + batch_id).to(tl.int32)
        kv_seqlen = tl.load(kv_seqlens_ptr + batch_id).to(tl.int32)

    # the grid is sized for the longest sequence; skip tiles past this one
    if BLOCK_M * start_m >= q_seqlen:
        return

    kv_head_id = head_id // kv_group_num
    history_len = kv_seqlen - q_seqlen

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

    loop_start = 0
    kv_min_loc = tl.zeros([BLOCK_M], dtype=tl.int32)
    if window_size > 0:
        # skip key blocks entirely left of the window of the tile's first row
        start_block_id = tl.maximum(history_len + start_m * BLOCK_M - window_size, 0) // BLOCK_N
        kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0)
        loop_start = start_block_id * BLOCK_N

    offs_dk = tl.arange(0, BLOCK_DK)
    mask_dk = offs_dk < head_dim_k
    offs_dk = tl.multiple_of(tl.max_contiguous(offs_dk % head_dim_k, BLOCK_DK), BLOCK_DK)
    off_q = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk[None, :] * stride_qd)
    q_ptrs = q_ptr + off_q
    q = tl.load(q_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk[None, :]))

    # K is addressed as (head_dim_k, kv_seqlen): dim 0 = head dim, dim 1 = seq.
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,
        shape=(head_dim_k, kv_seqlen),
        strides=(stride_kd, stride_ks),
        offsets=(0, 0),
        block_shape=(BLOCK_DK, BLOCK_N),
        order=(0, 1),
    )
    # V is addressed as (kv_seqlen, head_dim_v): dim 0 = seq, dim 1 = head dim.
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + kv_start_loc * stride_vs + kv_head_id * stride_vh,
        shape=(kv_seqlen, head_dim_v),
        strides=(stride_vs, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DV),
        order=(1, 0),
    )

    # for alibi
    if alibi_slopes_ptr is not None:
        alibi_slope = tl.load(alibi_slopes_ptr + head_id)
    else:
        alibi_slope = None
    global_offs_m = history_len + offs_m

    # Bounds checks. *_bound0 is for the first pass (full sequence blocks) and
    # only needs the head dim checked when the tiles overrun it; *_bound1 is for
    # the tail pass and also checks the sequence dim.
    if BLOCK_DK + BLOCK_DK1 == head_dim_k:
        k_bound0: tl.constexpr = None
        k_bound1: tl.constexpr = (1, )
    else:
        # Head dim is dim 0 of the K block pointer; checking dim 1 here would
        # let the second head-dim tile read past head_dim_k.
        k_bound0: tl.constexpr = (0, )
        k_bound1: tl.constexpr = (0, 1)
    if head_dim_v == BLOCK_DV:
        v_bound0: tl.constexpr = None
        v_bound1: tl.constexpr = (0, )
    else:
        v_bound0: tl.constexpr = (1, )
        v_bound1: tl.constexpr = (0, 1)

    if BLOCK_DK1 != 0:
        # second head-dim tile covering columns [BLOCK_DK, BLOCK_DK + BLOCK_DK1)
        offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1)
        mask_dk1 = offs_dk1 < head_dim_k
        offs_dk1 = tl.multiple_of(tl.max_contiguous(offs_dk1 % head_dim_k, BLOCK_DK1), BLOCK_DK1)
        offs_q1 = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk1[None, :] * stride_qd)
        q1_ptrs = q_ptr + offs_q1
        q1 = tl.load(q1_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk1[None, :]))
        k1_ptrs = tl.make_block_ptr(
            base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,
            shape=(head_dim_k, kv_seqlen),
            strides=(stride_kd, stride_ks),
            offsets=(BLOCK_DK, 0),
            block_shape=(BLOCK_DK1, BLOCK_N),
            order=(0, 1),
        )
    else:
        q1 = q
        k1_ptrs = k_ptrs

    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    # l_i starts at 1 so rows that see no keys divide by 1 instead of 0
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)

    if causal:
        history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
    else:
        history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32)
        loop_end = kv_seqlen // BLOCK_N * BLOCK_N

    # pass 1: full key blocks, no per-element causal mask needed
    acc, l_i, m_i = _prefill_fwd_inner(acc,
                                       l_i,
                                       m_i,
                                       q,
                                       k_ptrs,
                                       v_ptrs,
                                       q1,
                                       k1_ptrs,
                                       loop_start,
                                       loop_end,
                                       sm_scale,
                                       alibi_slope,
                                       global_offs_m,
                                       history_mask,
                                       kv_min_loc,
                                       causal_mask=False,
                                       window_size=window_size,
                                       logit_softcapping=logit_softcapping,
                                       k_bound=k_bound0,
                                       v_bound=v_bound0,
                                       shared_kv=shared_kv,
                                       block_sparse_size=block_sparse_size,
                                       BLOCK_N=BLOCK_N,
                                       BLOCK_DK1=BLOCK_DK1)

    # pass 2: masked tail. For non-causal attention history_mask is
    # kv_seqlen - 1, so the mask only trims keys past the sequence end.
    loop_start = loop_end
    if causal:
        loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N)
    else:
        loop_end = kv_seqlen
    acc, l_i, m_i = _prefill_fwd_inner(acc,
                                       l_i,
                                       m_i,
                                       q,
                                       k_ptrs,
                                       v_ptrs,
                                       q1,
                                       k1_ptrs,
                                       loop_start,
                                       loop_end,
                                       sm_scale,
                                       alibi_slope,
                                       global_offs_m,
                                       history_mask,
                                       kv_min_loc,
                                       causal_mask=True,
                                       window_size=window_size,
                                       logit_softcapping=logit_softcapping,
                                       k_bound=k_bound1,
                                       v_bound=v_bound1,
                                       shared_kv=shared_kv,
                                       block_sparse_size=block_sparse_size,
                                       BLOCK_N=BLOCK_N,
                                       BLOCK_DK1=BLOCK_DK1)
    # epilogue
    if sinks is not None:
        # attention sink: one extra logit per head that only adds to the
        # denominator (converted to the same log2 domain as m_i)
        sink = tl.load(sinks + head_id).to(l_i.dtype)
        l_i = l_i + tl.exp2(sink * tl_log2(math.e) - m_i)

    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]

    # initialize pointers to output
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_dv = offs_dv < head_dim_v
    off_o = ((q_start_loc + offs_m[:, None]) * stride_os + head_id * stride_oh + offs_dv[None, :] * stride_od)
    out_ptrs = o_ptr + off_o
    tl.store(out_ptrs, acc, mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :])


# CUDA device capability (major, minor), filled lazily by the first
# flash_attn_varlen_func call and reused afterwards.
# NOTE(review): a single process-wide value; a mixed-architecture multi-GPU
# process would reuse the first device's capability.
_nv_cap = None


def _kernel_meta_sm7x(BLOCK_DK):
    num_warps = 4
    num_stages = min(4, max(2, 768 // BLOCK_DK))
    BLOCK_M = max(16, 8192 // BLOCK_DK)
    BLOCK_N = 32
    return BLOCK_M, BLOCK_N, num_warps, num_stages


def _kernel_meta_sm8x(BLOCK_DK: int, shared_kv: bool):
    num_warps = 8
    min_m = 64 if shared_kv else 16
    BLOCK_M = max(min_m, 16384 // BLOCK_DK)
    BLOCK_M = min(128, BLOCK_M)
    BLOCK_N = BLOCK_M
    num_stages = 3 if BLOCK_DK <= 128 else 2

    return BLOCK_M, BLOCK_N, num_warps, num_stages


def _kernel_meta_sm86(BLOCK_DK: int, shared_kv: bool):
    """Sm86 has different smem size with sm80."""
    num_warps = 4
    if BLOCK_DK <= 128:
        BLOCK_M = 128
        BLOCK_N = 64
        num_stages = 3
    elif BLOCK_DK <= 256:
        BLOCK_M = 64
        BLOCK_N = 32
        num_stages = 2
    else:
        BLOCK_M = 32
        BLOCK_N = 32
        num_stages = 2

    return BLOCK_M, BLOCK_N, num_warps, num_stages


def _kernel_meta_sm9x(BLOCK_DK: int, shared_kv: bool):

    num_warps = 8
    BLOCK_M = 128 if BLOCK_DK <= 256 else 64
    if not shared_kv and BLOCK_DK >= 512:
        BLOCK_M = 32

    # fix crash on triton<3.2.0
    if BLOCK_DK >= 512 and TRITON_VERSION < VERSION_320:
        BLOCK_M = 32
        num_warps = 4

    BLOCK_N = 128 if BLOCK_DK <= 128 else 64

    num_stages = 3 if BLOCK_DK <= 128 else 2
    return BLOCK_M, BLOCK_N, num_warps, num_stages


def _kernel_meta_sm12x(BLOCK_DK: int, shared_kv: bool):
    # Blackwell (sm_120, cc 12.x) + B200/B100 variants
    if BLOCK_DK <= 128:
        BLOCK_M = 128
        BLOCK_N = 128 if shared_kv else 64
        num_warps = 8
        num_stages = 3
    elif BLOCK_DK <= 256:
        BLOCK_M = 64
        BLOCK_N = 128 if shared_kv else 64
        num_warps = 8
        num_stages = 3
    elif BLOCK_DK <= 512:
        BLOCK_M = 64 if shared_kv else 32
        BLOCK_N = 64
        num_warps = 4
        num_stages = 2
    else:
        BLOCK_M = 32
        BLOCK_N = 32 if not shared_kv else 64
        num_warps = 4
        num_stages = 2

    return BLOCK_M, BLOCK_N, num_warps, num_stages


def _kernel_meta_rocm(BLOCK_DK: int, shared_kv: bool):
    BLOCK_N = 32
    BLOCK_M = 32 if BLOCK_DK > 128 else 64
    num_warps = 4
    num_stages = 1
    return BLOCK_M, BLOCK_N, num_warps, num_stages


def flash_attn_varlen_func(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens_q: Tensor = None,
    cu_seqlens_k: Tensor = None,
    max_seqlen_q: int = None,
    max_seqlen_k: int = None,  # not used, just for align with fa interface
    softmax_scale: float = None,
    causal: bool = False,
    window_size: int = (-1, -1),
    softcap: float = 0.0,
    # old seqlens
    q_start_loc: Tensor = None,
    q_seqlens: Tensor = None,
    kv_start_loc: Tensor = None,
    kv_seqlens: Tensor = None,
    # args not in fa
    alibi_slopes: Tensor = None,
    sinks: Tensor = None,
    block_sparse_size: int = 1,
    kv_layout: str = 'hsd',
):
    """Varlen flash Attention forward.

    Support sliding window, softcapping, alibi, attention sinks and
    block-sparse causal masking.

    Args:
        q: packed queries indexed as (token, head, dim).
        k, v: packed keys/values in ``kv_layout`` order, either 'hsd'
            (kv_head, token, dim) or 'shd' (token, kv_head, dim). If both
            share the same storage and tile width, the loaded K tile is
            reused as V.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets
            (batch + 1 entries). When omitted, ``q_start_loc``/``q_seqlens``
            or ``kv_start_loc``/``kv_seqlens`` must be provided instead.
        max_seqlen_q: longest query sequence; sizes the launch grid. Defaults
            to ``q.size(0)``, which is always large enough.
        max_seqlen_k: unused; kept for flash-attn signature compatibility.
        softmax_scale: defaults to ``1 / sqrt(head_dim)``.
        causal: causal mask aligned to the end of each kv sequence (queries
            are the last ``q_seqlen`` of ``kv_seqlen`` positions).
        window_size: left sliding window, as an ``int`` or flash-attn style
            ``(left, right)`` (only ``left`` is used). ``<= 0`` disables it.
        softcap: logit soft-capping value; ``<= 0`` or ``None`` disables it.
        alibi_slopes: optional per-head ALiBi slopes.
        sinks: optional contiguous per-head sink logits (``num_heads``
            elements) that only contribute to the softmax denominator.
        block_sparse_size: with ``causal``, key positions are rounded down to
            multiples of this before the causal comparison.
        kv_layout: 'hsd' or 'shd'.

    Returns:
        Tensor with ``q``'s leading dims and ``head_dim_v`` as its last dim.

    Raises:
        RuntimeError: if ``kv_layout`` is unsupported.
        AssertionError: on inconsistent head dims, missing sequence metadata,
            malformed ``sinks``, or ``num_heads`` not a multiple of
            ``num_kv_heads``.
    """

    global _nv_cap
    if _nv_cap is None:
        _nv_cap = torch.cuda.get_device_capability()

    def grid(args):
        # one program per (query tile, head, sequence); tiles past a
        # sequence's length return immediately inside the kernel
        return (triton.cdiv(max_seqlen_q, args['BLOCK_M']), num_heads, batch)

    if kv_layout == 'shd':
        s_dim, h_dim, d_dim = (0, 1, 2)
    elif kv_layout == 'hsd':
        s_dim, h_dim, d_dim = (1, 0, 2)
    else:
        raise RuntimeError('Unsupported layout.')

    if max_seqlen_q is None:
        # upper bound: total number of packed query tokens
        max_seqlen_q = q.size(0)

    if window_size is None:
        window_size = -1
    elif isinstance(window_size, Sequence):
        # flash-attn passes (left, right); only the left window is supported
        window_size = window_size[0]

    if softcap is None:
        softcap = -1.0

    head_dim_q = q.size(-1)
    head_dim_k = k.size(d_dim)
    head_dim_v = v.size(d_dim)

    o = q.new_empty(*q.size()[:-1], head_dim_v)
    assert head_dim_q == head_dim_k and head_dim_v == o.size(-1)

    if softmax_scale is None:
        softmax_scale = 1.0 / (head_dim_q**0.5)

    if cu_seqlens_k is None:
        assert kv_start_loc is not None and kv_seqlens is not None
    if cu_seqlens_q is None:
        assert q_start_loc is not None and q_seqlens is not None
        batch = q_seqlens.size(0)
    else:
        batch = cu_seqlens_q.size(0) - 1
    num_heads = q.size(-2)
    num_kv_heads = k.size(h_dim)
    # The kernel maps query head h to kv head h // kv_group_num, so the head
    # counts must divide evenly or kv heads would be read out of range.
    assert num_kv_heads > 0 and num_heads % num_kv_heads == 0, (
        f'num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads}).')
    kv_group_num = num_heads // num_kv_heads

    if sinks is not None:
        assert sinks.is_contiguous()
        assert sinks.numel() == num_heads

    BLOCK_DK, BLOCK_DK1, BLOCK_DV = _get_block_d(head_dim_k, head_dim_v)

    # K and V alias the same storage with equal tile widths: the kernel reuses
    # the transposed K tile as V instead of loading V.
    shared_kv = k.data_ptr() == v.data_ptr() and BLOCK_DK == BLOCK_DV

    hip_mode = getattr(torch.version, 'hip', None) is not None
    if hip_mode:
        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_rocm(BLOCK_DK, shared_kv)
    else:
        if _nv_cap[0] < 8:
            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm7x(BLOCK_DK)
        elif _nv_cap[0] < 9:
            if _nv_cap[1] in [6, 9]:
                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm86(BLOCK_DK, shared_kv)
            else:
                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DK, shared_kv)
        elif _nv_cap[0] < 10:
            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DK, shared_kv)
        else:
            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm12x(BLOCK_DK, shared_kv)

    BLOCK_M = min(128, BLOCK_M)
    _flash_prefill_fwd_kernel[grid](
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        q_start_loc,
        q_seqlens,
        kv_start_loc,
        kv_seqlens,
        sinks,
        alibi_slopes,
        sm_scale=softmax_scale,
        stride_qs=q.stride(0),
        stride_qh=q.stride(1),
        stride_qd=q.stride(2),
        stride_ks=k.stride(s_dim),
        stride_kh=k.stride(h_dim),
        stride_kd=k.stride(d_dim),
        stride_vs=v.stride(s_dim),
        stride_vh=v.stride(h_dim),
        stride_vd=v.stride(d_dim),
        stride_os=o.stride(0),
        stride_oh=o.stride(1),
        stride_od=o.stride(2),
        kv_group_num=kv_group_num,
        head_dim_k=head_dim_k,
        head_dim_v=head_dim_v,
        causal=causal,
        window_size=window_size,
        logit_softcapping=softcap,
        shared_kv=shared_kv,
        block_sparse_size=block_sparse_size,
        BLOCK_DK=BLOCK_DK,
        BLOCK_DK1=BLOCK_DK1,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o


================================================
FILE: lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal

import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def _flatten_kv_cache(
    kc_ptr,
    vc_ptr,
    ko_ptr,
    vo_ptr,
    start_loc_ptr,
    seqlens_ptr,
    block_offsets_ptr,
    stride_kcb: tl.constexpr,
    stride_kcs: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_vcb: tl.constexpr,
    stride_vcs: tl.constexpr,
    stride_vch: tl.constexpr,
    stride_vcd: tl.constexpr,
    stride_koh,
    stride_kos: tl.constexpr,
    stride_kod: tl.constexpr,
    stride_voh,
    stride_vos: tl.constexpr,
    stride_vod: tl.constexpr,
    stride_boff,
    OUT_SIZE,
    HEAD_DIM_K: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    BLOCK_BS: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Flatten kv cache.

    Program ids are (page, sequence, head). Each program copies one cache page
    (``BLOCK_BS`` slots, physical block taken from ``block_offsets``) of one
    head into output rows ``start_loc + page_id * BLOCK_BS + slot``. Rows at or
    beyond the sequence length are not stored. The last sequence is treated as
    extending to ``OUT_SIZE`` so trailing output rows get written too. When
    ``HEAD_DIM_V`` is 0 the value copy is skipped.
    """
    page_id = tl.program_id(0)
    batch_id = tl.program_id(1)
    head_id = tl.program_id(2)

    num_batches = tl.num_programs(1)

    seqlen = tl.load(seqlens_ptr + batch_id)
    start_loc = tl.load(start_loc_ptr + batch_id)
    # fill last block to prevent attention nan
    if batch_id == num_batches - 1:
        seqlen = (OUT_SIZE - start_loc).to(seqlen.dtype)
    if page_id * BLOCK_BS >= seqlen:
        return

    # start_loc is already loaded above; only the physical block id is needed
    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)

    offs_bs = tl.arange(0, BLOCK_BS)
    # wrap padded columns back into range so the unmasked loads stay in bounds
    offs_dk = tl.arange(0, BLOCK_DK) % HEAD_DIM_K
    offs_dv = tl.arange(0, BLOCK_DV) % HEAD_DIM_V
    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)
    mask_bs = offs_obs < seqlen
    mask_dk = tl.arange(0, BLOCK_DK) < HEAD_DIM_K
    mask_dv = tl.arange(0, BLOCK_DV) < HEAD_DIM_V

    kc_ptrs = (kc_ptr + b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch +
               offs_dk[None, :] * stride_kcd)
    vc_ptrs = (vc_ptr + b_off * stride_vcb + offs_bs[:, None] * stride_vcs + head_id * stride_vch +
               offs_dv[None, :] * stride_vcd)
    ko_ptrs = (ko_ptr + head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos +
               offs_dk[None, :] * stride_kod)
    vo_ptrs = (vo_ptr + head_id * stride_voh + (start_loc + offs_obs[:, None]) * stride_vos +
               offs_dv[None, :] * stride_vod)

    kc = tl.load(kc_ptrs)
    tl.store(ko_ptrs, kc, mask=mask_bs[:, None] & mask_dk[None, :])
    if HEAD_DIM_V > 0:
        vc = tl.load(vc_ptrs)
        tl.store(vo_ptrs, vc, mask=mask_bs[:, None] & mask_dv[None, :])


@triton.jit
def _dequant_int4(val, HEAD_DIM: tl.constexpr, BLOCK: tl.constexpr):
    """Dequant int4.

    Extracts a 4-bit code for each of the ``BLOCK`` columns of ``val``. Column
    ``i`` takes the low nibble when ``i // (HEAD_DIM // 2)`` is even and the
    high nibble when it is odd. So for ``i < HEAD_DIM``, the first half of the
    head dim comes from low nibbles and the second half from high nibbles.
    Returns raw codes in [0, 15]; no scale or zero point is applied here.
    """
    offs = tl.arange(0, BLOCK) // (HEAD_DIM // 2)
    shift = (offs % 2) * 4
    return (val >> shift) & 0xf


@triton.jit
def _flatten_kv_cache_quant(
    kc_ptr,
    vc_ptr,
    ko_ptr,
    vo_ptr,
    ksz_ptr,
    vsz_ptr,
    start_loc_ptr,
    seqlens_ptr,
    block_offsets_ptr,
    stride_kcb: tl.constexpr,
    stride_kcs: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_vcb: tl.constexpr,
    stride_vcs: tl.constexpr,
    stride_vch: tl.constexpr,
    stride_vcd: tl.constexpr,
    stride_kszb: tl.constexpr,
    stride_kszs: tl.constexpr,
    stride_kszh: tl.constexpr,
    stride_kszd: tl.constexpr,
    stride_vszb: tl.constexpr,
    stride_vszs: tl.constexpr,
    stride_vszh: tl.constexpr,
    stride_vszd: tl.constexpr,
    stride_koh,
    stride_kos: tl.constexpr,
    stride_kod: tl.constexpr,
    stride_voh,
    stride_vos: tl.constexpr,
    stride_vod: tl.constexpr,
    stride_boff,
    quant_policy: tl.constexpr,
    OUT_SIZE,
    HEAD_DIM_K: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    BLOCK_BS: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Flatten kv cache.

    Quantized variant: program ids are (page, sequence, head). Each program
    loads one cache page of one head, dequantizes it as
    ``code * scale - scale * zero`` (scale at index 0 and zero at index 1 of
    the last dim of ``ksz``/``vsz``), casts to the output dtype and writes rows
    ``start_loc + page_id * BLOCK_BS + slot``. With ``quant_policy == 4`` each
    stored element holds two 4-bit codes, and ``HEAD_DIM_*`` are the unpacked
    widths. As in the unquantized kernel, the last sequence is extended to
    ``OUT_SIZE``.
    """
    page_id = tl.program_id(0)
    batch_id = tl.program_id(1)
    head_id = tl.program_id(2)

    num_batches = tl.num_programs(1)

    seqlen = tl.load(seqlens_ptr + batch_id)
    start_loc = tl.load(start_loc_ptr + batch_id)
    # extend the last sequence so trailing output rows are written as well
    if batch_id == num_batches - 1:
        seqlen = OUT_SIZE - start_loc
    if page_id * BLOCK_BS >= seqlen:
        return

    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)

    offs_bs = tl.arange(0, BLOCK_BS)
    if quant_policy == 4:
        # packed int4: both halves of the head dim read the same stored column
        HALF_HDK: tl.constexpr = HEAD_DIM_K // 2
        HALF_HDV: tl.constexpr = HEAD_DIM_V // 2
        offs_dk = tl.arange(0, BLOCK_DK) % HALF_HDK
        offs_dv = tl.arange(0, BLOCK_DV) % HALF_HDV
    else:
        offs_dk = tl.arange(0, BLOCK_DK) % HEAD_DIM_K
        offs_dv = tl.arange(0, BLOCK_DV) % HEAD_DIM_V
    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)
    mask_bs = offs_obs < seqlen

    offs_dok = tl.arange(0, BLOCK_DK)
    offs_dov = tl.arange(0, BLOCK_DV)
    mask_dok = offs_dok < HEAD_DIM_K
    mask_dov = offs_dov < HEAD_DIM_V

    kc_ptrs = (kc_ptr + b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch +
               offs_dk[None, :] * stride_kcd)
    vc_ptrs = (vc_ptr + b_off * stride_vcb + offs_bs[:, None] * stride_vcs + head_id * stride_vch +
               offs_dv[None, :] * stride_vcd)
    ksz_ptrs = (ksz_ptr + b_off * stride_kszb + offs_bs * stride_kszs + head_id * stride_kszh)
    vsz_ptrs = (vsz_ptr + b_off * stride_vszb + offs_bs * stride_vszs + head_id * stride_vszh)
    ko_ptrs = (ko_ptr + head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos +
               offs_dok[None, :] * stride_kod)
    vo_ptrs = (vo_ptr + head_id * stride_voh + (start_loc + offs_obs[:, None]) * stride_vos +
               offs_dov[None, :] * stride_vod)

    kc = tl.load(kc_ptrs)
    if quant_policy == 4:
        kc = _dequant_int4(kc, HEAD_DIM_K, BLOCK_DK)
    # per-slot scale and zero point
    ks = tl.load(ksz_ptrs)
    kz = tl.load(ksz_ptrs + stride_kszd)
    ksz = ks * kz
    kq = (kc * ks[:, None] - ksz[:, None]).to(ko_ptr.dtype.element_ty)
    tl.store(ko_ptrs, kq, mask=mask_bs[:, None] & mask_dok[None, :])
    vc = tl.load(vc_ptrs)
    if quant_policy == 4:
        vc = _dequant_int4(vc, HEAD_DIM_V, BLOCK_DV)
    vs = tl.load(vsz_ptrs)
    vz = tl.load(vsz_ptrs + stride_vszd)
    vsz = vs * vz
    vq = (vc * vs[:, None] - vsz[:, None]).to(vo_ptr.dtype.element_ty)
    tl.store(vo_ptrs, vq, mask=mask_bs[:, None] & mask_dov[None, :])


def flatten_kv_cache(k_caches: Tensor,
                     v_caches: Tensor,
                     seqlens: Tensor,
                     block_offsets: Tensor,
                     start_loc: Tensor = None,
                     out_size: int = None,
                     out_dtype: torch.dtype = None,
                     k_scales_zeros: Tensor = None,
                     v_scales_zeros: Tensor = None,
                     quant_policy: Literal[0, 4, 8] = 0,
                     kv_layout: str = 'bshd',
                     flatten_kv_layout: str = 'hsd'):
    """Recovery paged kv cache to normal kv cache.

    Gathers the cache pages listed in ``block_offsets`` into contiguous
    per-token key/value tensors, dequantizing when ``quant_policy`` is nonzero.

    Args:
        k_caches, v_caches: paged caches in ``kv_layout`` order, 'bshd'
            (block, slot, head, dim) or 'bhsd' (block, head, slot, dim).
        seqlens: number of tokens to gather per sequence.
        block_offsets: (batch, pages) table of physical block ids.
        start_loc: first output row of each sequence; defaults to the
            exclusive prefix sum of ``seqlens``.
        out_size: number of output rows; ``None`` or ``<= 0`` means the full
            cache capacity (num_blocks * block_size).
        out_dtype: output dtype, defaults to ``k_caches.dtype``.
        k_scales_zeros, v_scales_zeros: per-slot (scale, zero) pairs along the
            last dim; used when ``quant_policy != 0``.
        quant_policy: 0 copies as-is; 4 unpacks two 4-bit codes per stored
            element (doubling the head dim); any other nonzero value
            dequantizes element-wise.
        flatten_kv_layout: output layout, 'hsd' (head, seq, dim) or
            'shd' (seq, head, dim).

    Returns:
        ``(k_states, v_states)``. With ``quant_policy == 0`` and K/V sharing
        storage with a narrower V, ``v_states`` is a view of the leading
        columns of ``k_states``.

    Raises:
        RuntimeError: if either layout string is unsupported.
    """
    if kv_layout == 'bshd':
        b_dim, s_dim, h_dim, d_dim = 0, 1, 2, 3
    elif kv_layout == 'bhsd':
        b_dim, s_dim, h_dim, d_dim = 0, 2, 1, 3
    else:
        raise RuntimeError('Unsupported layout.')

    out_dtype = k_caches.dtype if out_dtype is None else out_dtype
    block_size = k_caches.size(s_dim)
    if out_size is None or out_size <= 0:
        out_size = k_caches.size(b_dim) * block_size
    if start_loc is None:
        # exclusive prefix sum: first output row of each sequence
        start_loc = seqlens.cumsum(0) - seqlens

    batch_size, num_blocks = block_offsets.size()
    num_heads = k_caches.size(h_dim)
    # int4 stores two codes per element, so the logical head dim doubles
    pack_factor = 2 if quant_policy == 4 else 1
    k_head_dim = k_caches.size(d_dim) * pack_factor
    v_head_dim = v_caches.size(d_dim) * pack_factor
    BLOCK_DK = triton.next_power_of_2(k_head_dim)
    BLOCK_DV = triton.next_power_of_2(v_head_dim)
    shared_kv = v_head_dim < k_head_dim and k_caches.data_ptr() == v_caches.data_ptr()

    if flatten_kv_layout == 'hsd':
        lead_shape, head_axis, seq_axis = (num_heads, out_size), 0, 1
    elif flatten_kv_layout == 'shd':
        lead_shape, head_axis, seq_axis = (out_size, num_heads), 1, 0
    else:
        raise RuntimeError('Unsupported layout.')

    k_states = k_caches.new_empty(*lead_shape, k_head_dim, dtype=out_dtype)
    if quant_policy == 0 and shared_kv:
        # V aliases the leading columns of K: reuse the K output and make the
        # kernel skip the V copy by passing a zero V head dim.
        v_states = k_states[..., :v_head_dim]
        v_head_dim = 0
    else:
        v_states = v_caches.new_empty(*lead_shape, v_head_dim, dtype=out_dtype)

    grid = (num_blocks, batch_size, num_heads)
    cache_strides = dict(
        stride_kcb=k_caches.stride(b_dim),
        stride_kcs=k_caches.stride(s_dim),
        stride_kch=k_caches.stride(h_dim),
        stride_kcd=k_caches.stride(d_dim),
        stride_vcb=v_caches.stride(b_dim),
        stride_vcs=v_caches.stride(s_dim),
        stride_vch=v_caches.stride(h_dim),
        stride_vcd=v_caches.stride(d_dim),
    )
    out_strides = dict(
        stride_koh=k_states.stride(head_axis),
        stride_kos=k_states.stride(seq_axis),
        stride_kod=k_states.stride(2),
        stride_voh=v_states.stride(head_axis),
        stride_vos=v_states.stride(seq_axis),
        stride_vod=v_states.stride(2),
        stride_boff=block_offsets.stride(0),
    )
    sizes = dict(
        OUT_SIZE=out_size,
        HEAD_DIM_K=k_head_dim,
        HEAD_DIM_V=v_head_dim,
        BLOCK_BS=block_size,
        BLOCK_DK=BLOCK_DK,
        BLOCK_DV=BLOCK_DV,
    )

    if quant_policy != 0:
        _flatten_kv_cache_quant[grid](
            k_caches,
            v_caches,
            k_states,
            v_states,
            k_scales_zeros,
            v_scales_zeros,
            start_loc,
            seqlens,
            block_offsets,
            **cache_strides,
            stride_kszb=k_scales_zeros.stride(b_dim),
            stride_kszs=k_scales_zeros.stride(s_dim),
            stride_kszh=k_scales_zeros.stride(h_dim),
            stride_kszd=k_scales_zeros.stride(d_dim),
            stride_vszb=v_scales_zeros.stride(b_dim),
            stride_vszs=v_scales_zeros.stride(s_dim),
            stride_vszh=v_scales_zeros.stride(h_dim),
            stride_vszd=v_scales_zeros.stride(d_dim),
            **out_strides,
            quant_policy=quant_policy,
            **sizes,
        )
    else:
        _flatten_kv_cache[grid](
            k_caches,
            v_caches,
            k_states,
            v_states,
            start_loc,
            seqlens,
            block_offsets,
            **cache_strides,
            **out_strides,
            **sizes,
        )

    return k_states, v_states


@triton.jit
def dequant_fp8(x, scale, GROUP_SIZE: tl.constexpr):
    """Dequant fp8.

    Multiplies a 2D ``(M, N)`` tile ``x`` by per-group scales. Each run of
    ``GROUP_SIZE`` consecutive columns shares one factor, so ``scale`` must
    reshape to ``(M, N // GROUP_SIZE)``. ``x`` is cast to ``scale.dtype`` first,
    and the ``(M, N)`` result is returned in that dtype.
    NOTE(review): assumes N is a multiple of GROUP_SIZE; confirm at call sites.
    """
    M: tl.constexpr = x.shape[0]
    N: tl.constexpr = x.shape[1]
    x = x.to(scale.dtype)
    # split columns into (groups, GROUP_SIZE) so each group broadcasts its scale
    x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)
    scale = scale.reshape(M, N // GROUP_SIZE, 1)
    x = x * scale
    x = x.reshape(M, N)
    return x


@triton.jit
def flatten_kv_cache_mla_fp8_kernel(
    kc_nope_ptr,
    kc_scale_ptr,
    kc_pe_ptr,
    ko_ptr,
    start_loc_ptr,
    seqlens_ptr,
    block_offsets_ptr,
    stride_kcb: tl.constexpr,
    stride_kcs: tl.constexpr,
    stride_kch: tl.constexpr,
    stride_kcd: tl.constexpr,
    stride_kcsb: tl.constexpr,
    stride_kcss: tl.constexpr,
    stride_kcsh: tl.constexpr,
    stride_kcsd: tl.constexpr,
    stride_kcpb: tl.constexpr,
    stride_kcps: tl.constexpr,
    stride_kcph: tl.constexpr,
    stride_kcpd: tl.constexpr,
    stride_koh,
    stride_kos: tl.constexpr,
    stride_kod: tl.constexpr,
    stride_boff,
    OUT_SIZE,
    BLOCK_BS: tl.constexpr,
    BLOCK_NOPE: tl.constexpr,
    BLOCK_PE: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
):
    """Mla fp8 flatten kv cache kernel.

    Copies one cache block of one head of one sequence into the dense output,
    dequantizing the nope part on the way.

    Program axes: 0 = logical block index within the sequence (``page_id``),
    1 = sequence (``batch_id``), 2 = head.

    For the BLOCK_BS slots of cache block ``block_offsets[batch_id, page_id]``:
      * BLOCK_NOPE quantized values and BLOCK_NOPE // GROUP_SIZE scales are
        dequantized and stored to output columns ``[0, BLOCK_NOPE)``;
      * BLOCK_PE values are stored as-is (cast to the output element type by
        the store) to columns ``[BLOCK_NOPE, BLOCK_NOPE + BLOCK_PE)``.
    Output rows start at ``start_loc[batch_id] + page_id * BLOCK_BS``; only
    rows below the (possibly extended, see below) sequence length are stored.
    """
    page_id = tl.program_id(0)
    batch_id = tl.program_id(1)
    head_id = tl.program_id(2)
    num_batches = tl.num_programs(1)

    seqlen = tl.load(seqlens_ptr + batch_id)
    start_loc = tl.load(start_loc_ptr + batch_id)
    # fill last block to prevent attention nan
    # (the last sequence is treated as extending to OUT_SIZE, so output rows
    # past its real length are also filled, as far as the grid's blocks reach)
    if batch_id == num_batches - 1:
        seqlen = OUT_SIZE - start_loc
    # this block starts past the sequence end: nothing to copy
    if page_id * BLOCK_BS >= seqlen:
        return

    # physical cache block that backs logical block `page_id` of this sequence
    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)

    BLOCK_SCALE: tl.constexpr = BLOCK_NOPE // GROUP_SIZE
    offs_bs = tl.arange(0, BLOCK_BS)
    offs_dnope = tl.arange(0, BLOCK_NOPE)
    offs_scale = tl.arange(0, BLOCK_SCALE)
    offs_dpe = tl.arange(0, BLOCK_PE)
    # output rows covered by this block; stores are masked to the sequence length
    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)
    mask_bs = offs_obs < seqlen

    offs_kc = b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch
    kc_nope_ptrs = (kc_nope_ptr + offs_kc + offs_dnope[None, :] * stride_kcd)

    offs_kc_scale = b_off * stride_kcsb + offs_bs[:, None] * stride_kcss + head_id * stride_kcsh
    kc_scale_ptrs = (kc_scale_ptr + offs_kc_scale + offs_scale[None, :] * stride_kcsd)

    offs_kc_pe = b_off * stride_kcpb + offs_bs[:, None] * stride_kcps + head_id * stride_kcph
    kc_pe_ptrs = (kc_pe_ptr + offs_kc_pe + offs_dpe[None, :] * stride_kcpd)

    # nope and pe halves share each output row: pe starts at column BLOCK_NOPE
    offs_ko = head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos
    ko_nope_ptrs = (ko_ptr + offs_ko + offs_dnope[None, :] * stride_kod)
    ko_pe_ptrs = (ko_ptr + offs_ko + (BLOCK_NOPE + offs_dpe[None, :]) * stride_kod)

    # nope
    # Loads are unmasked: all BLOCK_BS slots of the cache block are read.
    # NOTE(review): assumes BLOCK_BS equals the cache block size — confirm.
    kc_nope = tl.load(kc_nope_ptrs)
    kc_scale = tl.load(kc_scale_ptrs)
    ko_nope = dequant_fp8(kc_nope, kc_scale, GROUP_SIZE)
    ko_nope = ko_nope.to(ko_ptr.dtype.element_ty)
    tl.store(ko_nope_ptrs, ko_nope, mask=mask_bs[:, None])

    # pe
    kc_pe = tl.load(kc_pe_ptrs)
    tl.store(ko_pe_ptrs, kc_pe, mask=mask_bs[:, None])


def flatten_kv_cache_mla_fp8(k_caches: Tensor,
                             seqlens: Tensor,
                             block_offsets: Tensor,
                             start_loc: Tensor = None,
                             out_size: int = None,
                             out_dtype: torch.dtype = None,
                             flatten_kv_layout: str = 'hsd'):
    """Gather a paged MLA fp8 key cache into a dense, dequantized tensor.

    The last cache dim is split at fixed offsets: elements ``[0, 512)`` hold
    the quantized nope part; elements ``[512, 528)`` are reinterpreted as
    float32 group scales (one per 128 nope values when cache elements are
    1 byte wide); the remaining elements are reinterpreted as ``out_dtype``
    and hold the 64-wide rope part. The output head dim is 512 + 64 = 576.

    Args:
        k_caches: 4-D cache indexed (block, slot-in-block, head, packed feature).
        seqlens: per-sequence lengths.
        block_offsets: (batch, max_blocks) table of physical cache block ids.
        start_loc: first output row of each sequence; defaults to the
            exclusive cumulative sum of ``seqlens``.
        out_size: number of output rows; ``None`` or a non-positive value
            means ``num_cache_blocks * block_size``.
        out_dtype: output dtype, ``torch.bfloat16`` when ``None``.
        flatten_kv_layout: ``'hsd'`` -> (heads, rows, dim),
            ``'shd'`` -> (rows, heads, dim).

    Returns:
        The flattened key tensor.

    Raises:
        RuntimeError: if ``flatten_kv_layout`` is not supported.
    """
    assert k_caches.dim() == 4
    b_dim, s_dim, h_dim, d_dim = range(4)

    out_dtype = torch.bfloat16 if out_dtype is None else out_dtype
    if out_size is None or out_size <= 0:
        out_size = k_caches.size(b_dim) * k_caches.size(s_dim)

    # TODO: DIRTY magic number -- packed layout of the last cache dim
    nope_dim, scale_width, rope_dim = 512, 16, 64
    scale_end = nope_dim + scale_width
    nope_cache = k_caches[..., :nope_dim]
    scale_cache = k_caches[..., nope_dim:scale_end].view(torch.float32)
    rope_cache = k_caches[..., scale_end:].view(out_dtype)

    if start_loc is None:
        start_loc = seqlens.cumsum(0) - seqlens

    batch_size, num_blocks = block_offsets.size()
    num_heads = k_caches.size(h_dim)
    head_dim = nope_dim + rope_dim
    block_size = k_caches.size(s_dim)

    if flatten_kv_layout == 'hsd':
        k_states = k_caches.new_empty(num_heads, out_size, head_dim, dtype=out_dtype)
        stride_koh, stride_kos = k_states.stride(0), k_states.stride(1)
    elif flatten_kv_layout == 'shd':
        k_states = k_caches.new_empty(out_size, num_heads, head_dim, dtype=out_dtype)
        stride_kos, stride_koh = k_states.stride(0), k_states.stride(1)
    else:
        raise RuntimeError(f'Unsupported layout: {flatten_kv_layout}.')

    def _strides(tensor):
        # (block, slot, head, feature) strides of one cache view
        return tuple(tensor.stride(dim) for dim in (b_dim, s_dim, h_dim, d_dim))

    kcb, kcs, kch, kcd = _strides(nope_cache)
    kcsb, kcss, kcsh, kcsd = _strides(scale_cache)
    kcpb, kcps, kcph, kcpd = _strides(rope_cache)

    flatten_kv_cache_mla_fp8_kernel[(num_blocks, batch_size, num_heads)](
        nope_cache,
        scale_cache,
        rope_cache,
        k_states,
        start_loc,
        seqlens,
        block_offsets,
        stride_kcb=kcb,
        stride_kcs=kcs,
        stride_kch=kch,
        stride_kcd=kcd,
        stride_kcsb=kcsb,
        stride_kcss=kcss,
        stride_kcsh=kcsh,
        stride_kcsd=kcsd,
        stride_kcpb=kcpb,
        stride_kcps=kcps,
        stride_kcph=kcph,
        stride_kcpd=kcpd,
        stride_koh=stride_koh,
        stride_kos=stride_kos,
        stride_kod=k_states.stride(2),
        stride_boff=block_offsets.stride(0),
        OUT_SIZE=out_size,
        BLOCK_BS=block_size,
        BLOCK_NOPE=nope_dim,
        BLOCK_PE=rope_dim,
        GROUP_SIZE=128,
    )
    return k_states


================================================
FILE: lmdeploy/pytorch/kernels/cuda/fused_lora.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl


def get_autotune_config():
    """Return the autotune candidates for the fused LoRA kernel.

    Every candidate uses a 128-wide K tile, 4 pipeline stages and 4 warps;
    they differ only in the (BLOCK_SIZE_M, BLOCK_SIZE_N) tile shape.
    """
    tile_shapes = ((32, 128), (16, 256))
    return [
        triton.Config({
            'BLOCK_SIZE_M': block_m,
            'BLOCK_SIZE_N': block_n,
            'BLOCK_SIZE_K': 128
        }, num_stages=4, num_warps=4) for block_m, block_n in tile_shapes
    ]


@triton.jit
def _atomic_store(ptrs, val, mask):
    """Accumulate ``val`` into the memory at ``ptrs`` (``*ptrs += val``).

    float16/float32 destinations use ``tl.atomic_add``. Other element types
    (e.g. bfloat16) fall back to a load-add-store sequence, which is NOT
    atomic. That path is only correct when no other program writes the same
    addresses concurrently.

    Args:
        ptrs: block of pointers to the destination elements.
        val: values to add.
        mask: lanes with a false mask are neither read nor written.
    """
    dtype = ptrs.dtype.element_ty
    # ``element_ty`` is a Triton dtype; it never compares equal to a
    # ``torch.dtype``, so compare against ``tl.*`` dtypes. Otherwise the
    # atomic branch would be dead code.
    if (dtype == tl.float16) | (dtype == tl.float32):
        tl.atomic_add(ptrs, val, mask=mask, sem='relaxed')
    else:
        # bfloat16 does not support atomic add: plain read-modify-write
        origin = tl.load(ptrs, mask=mask)
        val = val.to(origin.dtype)
        val += origin
        tl.store(ptrs, val, mask=mask)


@triton.autotune(
    configs=get_autotune_config(),
    key=['N', 'K'],
    # autotuning launches the kernel repeatedly; restore C between trial runs
    # because the CUM path accumulates into it
    restore_value=['c_ptr'],
)
@triton.jit
def _fused_lora_kernel(
    a_ptr,
    lora_a_ptr,
    lora_b_ptr,
    c_ptr,
    scaling_ptr,
    rank_start_ptr,
    ranks_ptr,
    seq_start_ptr,
    seq_lens_ptr,
    adapter_ids_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak: tl.constexpr,
    stride_lar: tl.constexpr,
    stride_lak: tl.constexpr,
    stride_lbr: tl.constexpr,
    stride_lbn: tl.constexpr,
    stride_cm,
    stride_cn: tl.constexpr,
    BLOCK_SIZE_R: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    CUM: tl.constexpr,
):
    """Fused lora kernel.

    Program axes: 0 = row tile within a sequence, 1 = sequence ``bid``.
    With ``a = adapter_ids[bid]``, ``r0 = rank_start[a]``, ``rank = ranks[a]``
    and ``rows = seq_start[bid] .. seq_start[bid] + seq_lens[bid]``, computes::

        C[rows] (+)= scaling[a] * (A[rows] @ lora_a[r0:r0+rank]^T) @ lora_b[r0:r0+rank]

    where ``lora_a`` is addressed as [rank, K] and ``lora_b`` as [rank, N]
    through the given strides. The rank-space intermediate is accumulated in
    float32 and cast to ``lora_b``'s dtype before scaling. With CUM the
    result is added to C, otherwise it overwrites C. When ``rank == 0`` the
    rows are zeroed (non-CUM) or left untouched (CUM).
    """
    pid = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)

    M = tl.load(seq_lens_ptr + bid)
    if M <= 0:
        return

    seq_start = tl.load(seq_start_ptr + bid)
    adapter_id = tl.load(adapter_ids_ptr + bid)
    rank_start = tl.load(rank_start_ptr + adapter_id)
    rank = tl.load(ranks_ptr + adapter_id)

    pid_m = pid

    # the grid is sized for the longest sequence; skip tiles past this one
    if pid_m * BLOCK_SIZE_M >= M:
        return

    offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_n = tl.arange(0, BLOCK_SIZE_N)

    mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M
    offs_cm = offs_m
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :]

    if rank == 0:
        if not CUM:
            for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
                mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
                c_mask = mask_cm[:, None] * mask_cn[None, :]
                tl.store(c_ptrs, 0.0, mask=c_mask)
                c_ptrs += stride_cn * BLOCK_SIZE_N
    else:

        # rows past the sequence end wrap back into it so A loads stay in
        # bounds; those rows are masked out when C is stored
        offs_am = (seq_start + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)
        # rank lanes beyond the adapter's rank wrap into range; their
        # contributions are zeroed with tl.where below
        offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak + offs_r[None, :] * stride_lar)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Load the next block of A and B
            # If it is out of bounds, set it to 0.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            la = tl.load(la_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            # We accumulate along the K dimension.
            accumulator = tl.dot(a, la, acc=accumulator)
            # Advance the ptrs to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            la_ptrs += BLOCK_SIZE_K * stride_lak
        ar = accumulator.to(lora_b_ptr.dtype.element_ty)

        scaling = tl.load(scaling_ptr + adapter_id).to(ar.dtype)
        ar *= scaling
        # drop contributions of padded rank lanes (>= rank)
        ar = tl.where(tl.arange(0, BLOCK_SIZE_R)[None, :] < rank, ar, tl.zeros_like(ar))
        lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr + offs_n[None, :] * stride_lbn)

        for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
            # masked-out columns are undefined but never stored
            lb = tl.load(lb_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N)
            c = tl.dot(ar, lb)

            mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
            c_mask = mask_cm[:, None] * mask_cn[None, :]
            if CUM:
                _atomic_store(c_ptrs, c, mask=c_mask)
            else:
                tl.store(c_ptrs, c, mask=c_mask)
            c_ptrs += stride_cn * BLOCK_SIZE_N
            lb_ptrs += stride_lbn * BLOCK_SIZE_N


def fused_lora(input: torch.Tensor,
               lora_a: torch.Tensor,
               lora_b: torch.Tensor,
               scaling: torch.Tensor,
               rank_start: torch.LongTensor,
               ranks: torch.LongTensor,
               seq_start: torch.LongTensor,
               seq_lens: torch.LongTensor,
               adapter_ids: torch.LongTensor,
               max_rank: int,
               max_seqlen: int,
               output: torch.Tensor = None,
               cum: bool = False):
    """Apply per-sequence LoRA adapters to a packed batch of tokens.

    For every sequence ``b`` with adapter ``a = adapter_ids[b]`` and rows
    ``seq_start[b] : seq_start[b] + seq_lens[b]`` of ``input``::

        out[rows] (+)= scaling[a] * (input[rows] @ lora_a[r0:r0 + rank].T) @ lora_b[r0:r0 + rank]

    with ``r0 = rank_start[a]`` and ``rank = ranks[a]``. Rows of ``out`` not
    covered by any sequence are not written.

    Args:
        input: (M, K) packed token features.
        lora_a: stacked A matrices, rank along dim 0 and K along dim 1.
        lora_b: stacked B matrices, rank along dim 0 and N along dim 1.
        scaling: per-adapter scale factor.
        rank_start: per-adapter first rank row in ``lora_a`` / ``lora_b``.
        ranks: per-adapter rank; 0 means the adapter contributes nothing.
        seq_start: per-sequence first row in ``input``.
        seq_lens: per-sequence row count.
        adapter_ids: per-sequence adapter index.
        max_rank: upper bound of ``ranks``.
        max_seqlen: upper bound of ``seq_lens``; sizes the launch grid.
        output: optional (M, N) destination; allocated when ``None``.
        cum: accumulate into ``output`` instead of overwriting it. Forced to
            False when ``output`` is ``None``.

    Returns:
        The (M, N) output tensor.
    """
    assert input.dim() == 2
    batch_size = seq_lens.numel()
    M, K = input.shape
    N = lora_b.size(1)

    if output is None:
        output = input.new_empty((M, N))
        # a freshly allocated buffer holds garbage; never accumulate into it
        cum = False
    else:
        assert output.size(0) == M
        assert output.size(1) == N

    # The rank tile is used as a tl.arange extent, which must be a power of
    # two, and as a tl.dot dimension, which needs at least 16. Lanes past an
    # adapter's real rank are masked inside the kernel, so rounding up is safe.
    BLOCK_SIZE_R = max(16, triton.next_power_of_2(max_rank))

    def grid(META):
        return (triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']), batch_size)

    _fused_lora_kernel[grid](
        input,
        lora_a,
        lora_b,
        output,
        scaling,
        rank_start,
        ranks,
        seq_start,
        seq_lens,
        adapter_ids,
        N,
        K,
        stride_am=input.stride(0),
        stride_ak=input.stride(1),
        stride_lar=lora_a.stride(0),
        stride_lak=lora_a.stride(1),
        stride_lbr=lora_b.stride(0),
        stride_lbn=lora_b.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_SIZE_R=BLOCK_SIZE_R,
        CUM=cum,
    )

    return output


================================================
FILE: lmdeploy/pytorch/kernels/cuda/fused_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from: https://github.com/vllm-project/vllm
from typing import Callable

import torch
import triton
import triton.language as tl

from .activation import silu_and_mul


def get_cuda_autotune_config():
    """Return the fused-MoE autotune candidates, grouped by GPU generation.

    Order matters: the list is partitioned by position into three groups —
    the first 2 entries target SM9x and newer, the next 3 target SM8x, and
    the rest target SM7x and older. Every entry uses ``GROUP_SIZE_M = 1``.
    """
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
    candidates = (
        # SM9x and newer
        (128, 256, 64, 3, 8),
        (64, 256, 32, 4, 4),
        # SM8
        (128, 128, 32, 4, 4),
        (64, 256, 32, 4, 4),
        (64, 128, 64, 4, 4),
        # SM7-
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
    )
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': block_k,
                'GROUP_SIZE_M': 1,
            },
            num_stages=stages,
            num_warps=warps,
        ) for block_m, block_n, block_k, stages, warps in candidates
    ]


def _config_prune_func(config: list, *args, **kwargs):
    """Fused moe config prune.

    Keep only the autotune candidates meant for the current GPU. ``config``
    is partitioned by position: ``[0, 2)`` for compute capability 9.x and
    newer, ``[2, 5)`` for 8.x, and ``[5, end)`` for anything older. Extra
    positional/keyword arguments from the autotuner are ignored.
    """
    major, _ = torch.cuda.get_device_capability()
    sm9x_end, sm8x_end = 2, 5
    if major < 8:
        return config[sm8x_end:]
    if major < 9:
        return config[sm9x_end:sm8x_end]
    return config[:sm9x_end]


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['N', 'K', 'tune_hint'],
    prune_configs_by=dict(early_config_prune=_config_prune_func),
)
@triton.jit
def fused_moe_kernel(
    A,
    B,
    bias,
    C,
    SortedIdx,
    ExpStart,
    ExpEnd,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_be: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    stride_bie: tl.constexpr,
    stride_bin: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    M_NP2: tl.constexpr,
    tune_hint: tl.constexpr,
    top_k: tl.constexpr,
    expert_offset: tl.constexpr,
    reindex_a: tl.constexpr,
    reindex_c: tl.constexpr,
):
    """Fused moe kernel.

    Grouped GEMM over experts. Program axis 1 is the local expert ``e``; its
    rows are the positions ``i`` in ``[ExpStart[e + expert_offset],
    ExpEnd[e + expert_offset])`` of the sorted order. Each row is multiplied
    by ``B[e]`` (addressed as (N, K), i.e. ``A @ B[e]^T``), the optional
    per-expert ``bias[e]`` is added, and the result is cast to A's dtype.

    Program axis 0 enumerates (M tile, N tile) pairs; ``M_NP2`` bounds the
    number of M tiles.

    reindex_a: read A row ``SortedIdx[i] // top_k`` instead of row ``i``.
    reindex_c: write C row ``SortedIdx[i]`` instead of row ``i``.
    """
    exp_id = tl.program_id(1)
    pid = tl.program_id(0)

    exp_start = tl.load(ExpStart + exp_id + expert_offset)
    exp_end = tl.load(ExpEnd + exp_id + expert_offset)
    # number of rows routed to this expert
    M = exp_end - exp_start
    if M <= 0:
        return

    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # map the flat program id to an (M tile, N tile) pair; GROUP_SIZE_M > 1
    # orders tiles in groups of GROUP_SIZE_M M tiles
    if GROUP_SIZE_M == 1:
        pid_m = pid % num_pid_m
        pid_n = pid // num_pid_m
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # M_NP2 over-provisions M tiles for most experts; skip empty tiles
    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
        return

    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask_sid = offs_sid < exp_end
    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    if reindex_a:
        offs_am = sid // top_k
    else:
        offs_am = offs_sid
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # Columns past N wrap back into [0, N) so B loads stay in bounds.
    # NOTE(review): the final store is masked by row only, so wrapped columns
    # are written again with values recomputed for the same (row, column)
    # pair; this appears benign — confirm.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

    # deepseek has 160 experts, exp index would overflow int32
    # (it is the product stride_be * exp_id that can exceed int32 for large
    # expert counts, so it is computed in int64)
    exp_off = stride_be * exp_id.to(tl.int64)
    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, acc=accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # `bias` may be None, in which case no bias is added
    if bias is not None:
        bias_ptrs = bias + exp_id * stride_bie + offs_bn * stride_bin
        bias_val = tl.load(bias_ptrs).to(accumulator.dtype)
        accumulator += bias_val[None]

    c = accumulator.to(A.dtype.element_ty)

    if reindex_c:
        offs_cm = sid
    else:
        offs_cm = offs_sid
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
    tl.store(c_ptrs, c, mask=mask_sid[:, None])


def fused_moe_kernel_launcher(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    sorted_idx: torch.Tensor,
    exp_start: torch.Tensor,
    exp_end: torch.Tensor,
    bias: torch.Tensor = None,
    top_k: int = 1,
    num_tokens: int = None,
    expert_offset: int = 0,
    reindex_a: bool = True,
    reindex_c: bool = True,
):
    """Launch the grouped per-expert GEMM ``C[rows] = A[rows] @ B[e]^T (+ bias[e])``.

    Args:
        A: activations; all leading dims are flattened into rows.
        B: (E, N, K) per-expert weights for the E local experts.
        C: output; all leading dims are flattened into rows.
        sorted_idx: expert-sorted routing ids.
        exp_start, exp_end: per-expert ranges into ``sorted_idx``, indexed
            by ``local expert + expert_offset``.
        bias: optional (E, N) per-expert bias.
        top_k: divisor mapping a routing id to its A row when ``reindex_a``.
        num_tokens: row-count bound used to size the grid; defaults to
            ``A.size(0)``.
        expert_offset: offset added to the local expert id when indexing
            ``exp_start`` / ``exp_end``.
        reindex_a: gather A rows through ``sorted_idx``.
        reindex_c: scatter C rows through ``sorted_idx``.
    """
    num_tokens = A.size(0) if num_tokens is None else num_tokens
    # row-count bound for the grid: next power of two, at least 64
    padded_m = max(64, triton.next_power_of_2(num_tokens))
    num_local_experts, N, K = B.shape
    tune_hint = min(2, triton.cdiv(padded_m, 512))

    A = A.flatten(0, -2)
    C = C.flatten(0, -2)
    has_bias = bias is not None

    def grid(META):
        num_m_tiles = triton.cdiv(padded_m, META['BLOCK_SIZE_M'])
        num_n_tiles = triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (num_m_tiles * num_n_tiles, num_local_experts)

    fused_moe_kernel[grid](
        A,
        B,
        bias,
        C,
        sorted_idx,
        exp_start,
        exp_end,
        N=N,
        K=K,
        stride_am=A.stride(0),
        stride_ak=A.stride(1),
        stride_be=B.stride(0),
        stride_bn=B.stride(1),
        stride_bk=B.stride(2),
        stride_cm=C.stride(0),
        stride_cn=C.stride(1),
        stride_bie=bias.stride(0) if has_bias else 0,
        stride_bin=bias.stride(1) if has_bias else 0,
        tune_hint=tune_hint,
        top_k=top_k,
        expert_offset=expert_offset,
        reindex_a=reindex_a,
        reindex_c=reindex_c,
        M_NP2=padded_m,
    )


@triton.jit
def _get_exp_mask_kernel(
    a_ptr,
    o_mask_ptr,
    o_k_ptr,
    stride_a_token: tl.constexpr,
    stride_a_exp: tl.constexpr,
    stride_o_exp,
    stride_o_token: tl.constexpr,
    topk: tl.constexpr,
    num_experts: tl.constexpr,
    BLOCK_NA: tl.constexpr,
    BLOCK_NO: tl.constexpr,
):
    """Scatter one token's top-k expert ids into per-expert tables.

    One program per token (program axis 0). The token's entries in
    ``o_mask`` are zeroed for all ``num_experts`` experts, then set to 1 at
    each selected expert id. The same positions of ``o_k`` receive the top-k
    slot index (0 .. topk-1). ``o_k`` is addressed with the ``o_mask``
    strides, so both tables are assumed to share one layout.
    """
    token_id = tl.program_id(0)

    # this token's top-k expert ids
    offs_n = tl.arange(0, BLOCK_NA)
    mask_n = offs_n < topk
    a_ptrs = a_ptr + token_id * stride_a_token + offs_n * stride_a_exp
    a = tl.load(a_ptrs, mask=mask_n)

    # fill zeros
    # NOTE(review): relies on this zero-fill landing before the scatter
    # below even when another thread of the program writes the same address;
    # the single-warp launch presumably serves this — confirm.
    offs_no = tl.arange(0, BLOCK_NO)
    mask_no = offs_no < num_experts
    o_ptrs = o_mask_ptr + token_id * stride_o_token + offs_no * stride_o_exp
    tl.store(o_ptrs, 0, mask=mask_no)

    # fill a
    o_ptrs = o_mask_ptr + token_id * stride_o_token + a * stride_o_exp
    tl.store(o_ptrs, 1, mask=mask_n)

    # fill kid
    ok_ptrs = o_k_ptr + token_id * stride_o_token + a * stride_o_exp
    tl.store(ok_ptrs, offs_n, mask=mask_n)


def _get_exp_mask(topk_ids: torch.Tensor, num_experts: int):
    """Get exp mask.

    Build per-expert membership tables from (num_tokens, topk) routing ids.

    Returns:
        ``(out_mask, out_k)``, both (num_experts, num_tokens) with the dtype of
        ``topk_ids``: ``out_mask[e, t]`` is 1 when token ``t`` routes to
        expert ``e`` and 0 otherwise; ``out_k[e, t]`` is the top-k slot of
        that routing and is only written where ``out_mask`` is 1.
    """
    assert topk_ids.dim() == 2
    num_tokens, topk = topk_ids.shape
    assert topk <= num_experts

    table_shape = (num_experts, num_tokens)
    out_mask = topk_ids.new_empty(table_shape)
    out_k = topk_ids.new_empty(table_shape)

    _get_exp_mask_kernel[(num_tokens, )](
        topk_ids,
        out_mask,
        out_k,
        stride_a_token=topk_ids.stride(0),
        stride_a_exp=topk_ids.stride(1),
        stride_o_exp=out_mask.stride(0),
        stride_o_token=out_mask.stride(1),
        topk=topk,
        num_experts=num_experts,
        BLOCK_NA=triton.next_power_of_2(topk),
        BLOCK_NO=triton.next_power_of_2(num_experts),
        num_warps=1,
    )
    return out_mask, out_k


@triton.jit
def _get_start_end_kernel(
    exp_cum_ptr,
    exp_topk_ptr,
    exp_out_ptr,
    start_ptr,
    end_ptr,
    stride_cum_exp,
    stride_cum_token: tl.constexpr,
    stride_out: tl.constexpr,
    num_tokens,
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Get start end kernel.

    One program per token ``t`` (program axis 0), processing all experts at
    once. For each expert ``e`` it compares ``cur = exp_cum[e, t]`` with the
    element just before it (``prev``). When ``cur > prev``, the flat routing
    id ``t * topk + exp_topk[e, t]`` is written to ``exp_out[prev]``. Token 0
    also writes each expert's ``prev`` to ``start``, and the last token
    writes each expert's ``cur`` to ``end``.

    Assumes ``exp_cum`` is the inclusive running count of the routing mask
    taken over the flattened (expert-major, token-minor) table, so that the
    element before (e, 0) is (e - 1, num_tokens - 1) — TODO confirm the
    caller keeps it contiguous.
    """
    token_start = tl.program_id(0)

    offs_exp = tl.arange(0, BLOCK_N)
    off_cum = offs_exp * stride_cum_exp + token_start * stride_cum_token
    cum_ptrs = exp_cum_ptr + off_cum
    val_k_ptrs = exp_topk_ptr + off_cum

    mask_exp = offs_exp < num_experts

    # get prev and cur cum
    token_id = token_start
    prev_cum_mask = mask_exp
    if token_start == 0:
        # (expert 0, token 0) has no predecessor: its prev is 0 (via other=0)
        prev_cum_mask = mask_exp & (tl.arange(0, BLOCK_N) > 0)
    prev_cum = tl.load(cum_ptrs - stride_cum_token, mask=prev_cum_mask, other=0)
    cur_cum = tl.load(cum_ptrs, mask=mask_exp)

    # store sorted idx
    # (only where the count increased, i.e. token t routes to expert e)
    mask_out = mask_exp & (cur_cum > prev_cum)
    val_k = tl.load(val_k_ptrs, mask=mask_exp)
    val = token_id * topk + val_k
    out_ptrs = exp_out_ptr + prev_cum * stride_out
    tl.store(out_ptrs, val, mask=mask_out)

    # fill start
    if token_id == 0:
        cur_start_ptrs = start_ptr + offs_exp
        tl.store(cur_start_ptrs, prev_cum, mask=mask_exp)

    # fill end
    if token_id == num_tokens - 1:
        cur_end_ptrs = end_ptr + offs_exp
        tl.store(cur_end_ptrs, cur_cum, mask=mask_exp)


def get_start_end(exp_cum: torch.Tensor, exp_topk: torch.Tensor, topk: int):
    """Get start end.

    Turn a per-expert running count into an expert-sorted routing index and
    per-expert ranges.

    Args:
        exp_cum: (num_experts, num_tokens) running count of the routing mask.
        exp_topk: (num_experts, num_tokens) top-k slot of each routing.
        topk: number of experts selected per token.

    Returns:
        ``(out, exp_start, exp_end)``. ``out`` has ``num_tokens * topk``
        entries holding flat ids ``token * topk + k``, grouped by expert.
        ``exp_start[e]`` / ``exp_end[e]`` bound expert ``e``'s slice of
        ``out``. Both range vectors are views of a single (2, num_experts)
        allocation.
    """
    num_experts, num_tokens = exp_cum.shape

    # one buffer, row 0 = starts, row 1 = ends
    start_end = exp_cum.new_empty(2, num_experts)
    exp_start, exp_end = start_end[0], start_end[1]

    out = exp_cum.new_empty(num_tokens * topk)

    _get_start_end_kernel[(num_tokens, )](
        exp_cum,
        exp_topk,
        out,
        exp_start,
        exp_end,
        stride_cum_exp=exp_cum.stride(0),
        stride_cum_token=exp_cum.stride(1),
        stride_out=out.stride(0),
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        BLOCK_N=triton.next_power_of_2(num_experts),
        num_warps=1,
    )
    return out, exp_start, exp_end


def _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int):
    """Get sorted idx.

    Sort (token, top-k) routing pairs by expert.

    Returns:
        ``(sorted_idx, start, end)`` as produced by ``get_start_end``.
    """
    assert topk_ids.dim() == 2
    topk = topk_ids.size(1)

    # per-expert membership tables, both (num_experts, num_tokens)
    exp_mask, exp_topk = _get_exp_mask(topk_ids, num_experts)
    # running count over the flattened expert-major table, same shape
    exp_cum = exp_mask.flatten().cumsum(0).view_as(exp_mask)

    return get_start_end(exp_cum, exp_topk, topk)


def _renormalize(topk_weights: torch.Tensor, renormalize: bool):
    """Optionally rescale routing weights to sum to one along the last dim.

    The result is always contiguous. The input object itself is returned
    when no work is needed.
    """
    weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) if renormalize else topk_weights
    return weights if weights.is_contiguous() else weights.contiguous()


def _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device, zeros: bool):
    """Allocate an intermediate buffer; zero-filled only when ``zeros`` is set."""
    allocate = torch.zeros if zeros else torch.empty
    return allocate(shape, dtype=dtype, device=device)


@triton.jit
def _moe_reduce_kernel(
    hidden_states_ptr,
    weights_ptr,
    out_ptr,
    stride_hm,
    stride_hk: tl.constexpr,
    stride_hn: tl.constexpr,
    stride_wm,
    stride_wk: tl.constexpr,
    stride_om,
    stride_on: tl.constexpr,
    fp32_acc: tl.constexpr,
    K: tl.constexpr,
    N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Weighted reduction over the top-k axis for one (row, column tile).

    Program ``pid`` handles row ``pid // cdiv(N, BLOCK_N)`` and column tile
    ``pid % cdiv(N, BLOCK_N)`` and computes
    ``out[m, n] = sum_{k < K} weights[m, k] * hidden[m, k, n]``.
    With ``fp32_acc`` both operands are promoted to float32 first. Otherwise
    the weights are cast to the hidden-state dtype.
    """
    pid = tl.program_id(0)
    num_n_split = tl.cdiv(N, BLOCK_N)
    mid = pid // num_n_split
    nid = pid % num_n_split

    offs_k = tl.arange(0, BLOCK_K)
    offs_n = nid * BLOCK_N + tl.arange(0, BLOCK_N)
    weights_ptrs = weights_ptr + mid * stride_wm + offs_k * stride_wk
    h_ptrs = hidden_states_ptr + mid * stride_hm + offs_k[:, None] * stride_hk + offs_n[None, :] * stride_hn
    o_ptrs = out_ptr + mid * stride_om + offs_n * stride_on

    mask_k = offs_k < K
    mask_n = offs_n < N  # columns of this tile that lie inside N
    mask_h = mask_k[:, None] & mask_n[None, :]

    # padded top-k lanes load 0 so they do not contribute to the sum
    h = tl.load(h_ptrs, mask=mask_h, other=0.0)
    w = tl.load(weights_ptrs, mask=mask_k, other=0.0)

    if fp32_acc:
        h = h.to(tl.float32)
        w = w.to(tl.float32)
    else:
        w = w.to(h.dtype)

    wh = h * w[:, None]
    o = wh.sum(axis=0)
    tl.store(o_ptrs, o, mask=mask_n)


def moe_reduce(hidden_states: torch.Tensor, topk_weights: torch.Tensor, fp32_acc: bool = False) -> torch.Tensor:
    """Moe reduce.

    Compute ``out[m, n] = sum_k topk_weights[m, k] * hidden_states[m, k, n]``.

    Args:
        hidden_states: (M, topk, N) per-expert outputs.
        topk_weights: (M, topk) routing weights.
        fp32_acc: accumulate in float32 instead of the hidden-state dtype.

    Returns:
        (M, N) tensor with the dtype of ``hidden_states``.
    """
    assert hidden_states.dim() == 3
    assert topk_weights.dim() == 2
    assert hidden_states.size(0) == topk_weights.size(0)
    assert hidden_states.size(1) == topk_weights.size(1)
    num_rows, topk, hidden_dim = hidden_states.shape

    out = hidden_states.new_empty((num_rows, hidden_dim))

    num_warps = 1
    # column tile spans 512 bytes of output per warp
    block_n = triton.cdiv(num_warps * 512, hidden_states.element_size())
    block_k = triton.next_power_of_2(topk)
    grid = (num_rows * triton.cdiv(hidden_dim, block_n), )

    _moe_reduce_kernel[grid](
        hidden_states,
        topk_weights,
        out,
        stride_hm=hidden_states.stride(0),
        stride_hk=hidden_states.stride(1),
        stride_hn=hidden_states.stride(2),
        stride_wm=topk_weights.stride(0),
        stride_wk=topk_weights.stride(1),
        stride_om=out.stride(0),
        stride_on=out.stride(1),
        fp32_acc=fp32_acc,
        K=topk,
        N=hidden_dim,
        BLOCK_K=block_k,
        BLOCK_N=block_n,
        num_warps=num_warps,
    )
    return out


def fused_moe(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              topk: int,
              w1_bias: torch.Tensor = None,
              w2_bias: torch.Tensor = None,
              expert_offset: int = 0,
              num_experts: int = None,
              renormalize: bool = False,
              act_func: Callable = None) -> torch.Tensor:
    """Fused moe.

    Routed expert MLP: gate/up projection with ``w1``, activation
    (``silu_and_mul`` unless ``act_func`` is given), down projection with
    ``w2``, then a ``topk_weights``-weighted sum over the top-k experts.

    Args:
        hidden_states: (M, hidden) token features.
        w1: (E, N, hidden) gate/up weights for the E local experts.
        w2: (E, out_dim, inter_dim) down weights; ``inter_dim`` must match
            the activation output width.
        topk_weights: (M, topk) routing weights.
        topk_ids: (M, topk) global expert ids.
        topk: experts per token.
        w1_bias, w2_bias: optional per-expert biases.
        expert_offset: global id of local expert 0.
        num_experts: total number of experts; defaults to E.
        renormalize: rescale ``topk_weights`` to sum to one per token.
        act_func: activation applied to the flattened gate/up output.

    Returns:
        (M, out_dim) tensor.
    """
    num_tokens = hidden_states.size(0)
    num_local_experts, gate_up_dim, _ = w1.shape
    if num_experts is None:
        num_experts = num_local_experts
    # When some experts live elsewhere, rows routed to them are never written
    # by the local kernels, so the buffers must start zeroed.
    need_zeros = num_experts != num_local_experts

    topk_weights = _renormalize(topk_weights, renormalize)
    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)

    routing = dict(
        sorted_idx=sorted_idx,
        exp_start=exp_start,
        exp_end=exp_end,
        num_tokens=num_tokens,
        expert_offset=expert_offset,
    )

    def _buffer(last_dim):
        return _make_intermediate((num_tokens, topk, last_dim),
                                  dtype=hidden_states.dtype,
                                  device=hidden_states.device,
                                  zeros=need_zeros)

    # gate and up: gather token rows, write in sorted order
    gate_up = _buffer(gate_up_dim)
    fused_moe_kernel_launcher(hidden_states,
                              w1,
                              gate_up,
                              bias=w1_bias,
                              top_k=topk,
                              reindex_a=True,
                              reindex_c=False,
                              **routing)

    # activate
    row_shape = gate_up.shape[:-1]
    flat_gate_up = gate_up.flatten(0, -2)
    activated = silu_and_mul(flat_gate_up) if act_func is None else act_func(flat_gate_up)
    activated = activated.unflatten(0, row_shape)

    # down: read in sorted order, scatter back to (token, k) rows
    down = _buffer(w2.shape[1])
    fused_moe_kernel_launcher(activated,
                              w2,
                              down,
                              bias=w2_bias,
                              top_k=1,
                              reindex_a=False,
                              reindex_c=True,
                              **routing)

    return moe_reduce(down, topk_weights)


================================================
FILE: lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from dlblas: https://github.com/DeepLink-org/DLBlas
from typing import List, Optional

import torch
import triton
import triton.language as tl

from .activation import silu_and_mul


@triton.jit
def _fwd_kernel_ep_scatter_step1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Compute per-expert start rows and fill ``m_indices`` with expert ids.

    One program per expert (program axis 0). Every program computes the
    exclusive prefix sum of ``num_recv_tokens_per_expert`` and stores it to
    ``expert_start_loc``, so all programs write identical values. It then
    writes its expert id into ``m_indices[start : start + count]`` in
    unmasked chunks of BLOCK_E. It may therefore write up to the next
    multiple of BLOCK_E past ``count``, which the caller's alignment
    assumption is meant to cover.
    """
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    # exclusive prefix sum: start row of each expert
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    # read back this expert's start (value just stored above)
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )


@triton.jit
def _fwd_kernel_ep_scatter_step2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Copy each routed (token, top-k) pair into its expert's next free row.

    Programs walk the tokens with a grid-stride loop over program axis 0.
    For each top-k slot whose expert id is non-negative,
    ``expert_start_loc[expert]`` is atomically incremented. Its previous
    value is the destination row in ``output_tensor`` and is also recorded
    in ``output_index[token, k]``. Slots with a negative expert id are
    skipped. After the kernel, ``expert_start_loc`` has advanced by the
    number of rows scattered per expert. The order of rows within an expert
    depends on atomic ordering.

    NOTE(review): the ``*_stride1`` arguments are unused; the hidden and
    top-k axes are assumed to be contiguous.
    """
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE
    for token_id in range(start_token_id, total_token_num, grid_num):
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if expert_id >= 0:
                # claim the next free row of this expert
                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index.to(tl.int64)
                tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index)
                output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)


# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
def ep_scatter(
    recv_x: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    BLOCK_E = 128  # token num of per expert is aligned to 128
    num_warps = 8
    num_experts = num_recv_tokens_per_expert.shape[0]
    hidden_size = recv_x.shape[1]
    grid = num_experts
    assert m_indices.shape[0] % BLOCK_E == 0
    _fwd_kernel_ep_scatter_step1[(grid, )](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=num_experts,
        num_warps=num_warps,
        BLOCK_E=BLOCK_E,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )
    grid = min(recv_topk.shape[0], 1024 * 8)
    _fwd_kernel_ep_scatter_step2[(grid, )](
        recv_topk.shape[0],
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=num_warps,
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
    )
    return


@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Combine expert outputs back into per-token rows, weighted by the routing weights.

    Grid axis 0 selects a ``BLOCK_D``-wide slice of the hidden dimension. Grid axis 1 walks tokens in a
    grid-stride loop. For each token, every slot with a non-negative expert id contributes row
    ``input_index[token, slot]`` of ``input_tensor`` scaled by ``recv_topk_weight[token, slot]``. The
    sum is stored in ``output_tensor[token]``.

    Loads and stores are unmasked, so the hidden size must be a multiple of ``BLOCK_D``. Only the
    ``*_stride0`` arguments are used; the last dimensions are indexed with unit stride.
    """
    cur_block = tl.program_id(0)
    start_cur_token = tl.program_id(1)
    grid_num = tl.num_programs(1)
    # align with xtuner rl
    # Accumulation happens in the output dtype (e.g. bf16), not fp32.
    compute_dtype = output_tensor.dtype.element_ty
    # compute_dtype = tl.float32

    for cur_token in range(start_cur_token, total_token_num, grid_num):
        off_d = tl.arange(0, BLOCK_D)
        accumulator = tl.zeros([BLOCK_D], dtype=compute_dtype)
        for topk_index in range(0, topk_num):
            expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index)
            # Slots with a negative expert id were never scattered and contribute nothing.
            if expert_id >= 0:
                # Row where ep_scatter placed this (token, slot) pair.
                source_token_index = tl.load(input_index + cur_token * input_index_stride0 + topk_index)
                acc_weight = tl.load(recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index)
                tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d)
                accumulator += tmp.to(compute_dtype) * acc_weight.to(compute_dtype)
        tl.store(
            output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )


@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Weighted gather of expert outputs into ``output_tensor``, written in place.

    For each row of ``output_tensor``, the rows of ``input_tensor`` named by ``input_index`` (for slots
    whose ``recv_topk_ids`` entry is non-negative) are scaled by ``recv_topk_weight`` and summed.

    The hidden size must be a multiple of 1024. The kernel tiles the hidden dimension in 1024-wide blocks
    without masking.
    """
    block_d = 1024  # block size of quantization
    warps = 2
    token_count = output_tensor.shape[0]
    hidden = input_tensor.shape[1]
    assert hidden % block_d == 0

    # Axis 0 tiles the hidden dimension; axis 1 is capped at 1024 programs and the kernel loops over
    # the remaining tokens.
    hidden_blocks = triton.cdiv(hidden, block_d)
    token_programs = min(token_count, 1024)
    _fwd_kernel_ep_gather[(hidden_blocks, token_programs)](
        token_count,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=warps,
        BLOCK_D=block_d,
    )
    return None


def _deepgemm_grouped_bf16_nt_contiguous(
    x: torch.Tensor,
    w: torch.Tensor,
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    """Run DeepGEMM's ``m_grouped_bf16_gemm_nt_contiguous`` on ``x`` and ``w``, writing into ``out``.

    ``m_indices`` is forwarded unchanged and gives the group (expert) of each row of ``x``. The result of
    the DeepGEMM call is returned as-is.
    """
    # Imported inside the function, so deep_gemm is only loaded when this path actually runs.
    from lmdeploy.pytorch.third_party import deep_gemm as _deep_gemm
    grouped_gemm = _deep_gemm.m_grouped_bf16_gemm_nt_contiguous
    result = grouped_gemm(x, w, out, m_indices)
    return result


def fused_moe_v3(
    hidden_states: torch.Tensor,
    topk_idx,
    topk_weights,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    num_recv_tokens_per_expert: Optional[List[int]],
):
    """Expert-parallel MoE forward built on DeepGEMM grouped BF16 GEMMs.

    Tokens are scattered into expert-contiguous rows (``ep_scatter``). They then go through the fused
    gate/up projection, ``silu_and_mul`` and the down projection. Finally they are gathered back and
    combined with ``topk_weights`` (``ep_gather``).

    Args:
        hidden_states: (num_tokens, K) received tokens.
        topk_idx: (num_tokens, topk) expert id per slot; negative ids are skipped.
        topk_weights: (num_tokens, topk) routing weight per slot.
        w13_weight: grouped gate/up weight; ``size(1)`` is the fused output width N.
        w2_weight: grouped down-projection weight.
        num_recv_tokens_per_expert: number of tokens received by each local expert, or None.
            ``ep_scatter`` asserts the total is a multiple of 128 (per-expert counts are expected to be
            128-aligned).

    Returns:
        (num_tokens, K) combined expert output. ``hidden_states`` is returned unchanged when there is
        nothing to compute.
    """
    if num_recv_tokens_per_expert is None:
        return hidden_states
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        return hidden_states
    _, K = hidden_states.size()
    N = w13_weight.size(1)
    gather_out = torch.empty_like(hidden_states)
    input_tensor = hidden_states.new_empty((all_tokens, K))
    m_indices = hidden_states.new_empty(all_tokens, dtype=torch.int32)
    output_index = torch.empty_like(topk_idx)
    # Build the counts in pinned host memory so the host-to-device copy can be asynchronous.
    # Copy to the device that holds hidden_states: ``.cuda()`` would target the *current* CUDA
    # device, which need not be the same one.
    num_recv_tokens_per_expert_gpu = torch.tensor(
        num_recv_tokens_per_expert,
        dtype=torch.int32,
        pin_memory=True,
        device='cpu',
    ).to(hidden_states.device, non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
    ep_scatter(
        hidden_states,
        topk_idx,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor,
        m_indices,
        output_index,
    )
    # Drop this function's reference to the input; only the scattered copy is needed from here on.
    del hidden_states
    gateup_output = gather_out.new_empty((all_tokens, N))
    _deepgemm_grouped_bf16_nt_contiguous(input_tensor, w13_weight, gateup_output, m_indices)
    # The silu_and_mul output buffer is half as wide as the fused gate/up output.
    down_input = gateup_output.new_empty((
        all_tokens,
        N // 2,
    ))
    down_input = silu_and_mul(gateup_output.view(-1, N), down_input)
    down_output = gather_out.new_empty((all_tokens, K))
    _deepgemm_grouped_bf16_nt_contiguous(
        down_input,
        w2_weight,
        down_output,
        m_indices,
    )
    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
    return gather_out


================================================
FILE: lmdeploy/pytorch/kernels/cuda/fused_noaux_tc.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1, num_stages=1),
        triton.Config({}, num_warps=1, num_stages=2),
        triton.Config({}, num_warps=1, num_stages=3),
        triton.Config({}, num_warps=1, num_stages=4),
        triton.Config({}, num_warps=2, num_stages=1),
        triton.Config({}, num_warps=2, num_stages=2),
        triton.Config({}, num_warps=2, num_stages=3),
        triton.Config({}, num_warps=2, num_stages=4),
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=4, num_stages=2),
        triton.Config({}, num_warps=4, num_stages=3),
        triton.Config({}, num_warps=4, num_stages=4),
        triton.Config({}, num_warps=8, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=2),
        triton.Config({}, num_warps=8, num_stages=3),
        triton.Config({}, num_warps=8, num_stages=4),
    ],
    key=['num_experts', 'n_group'],
)
@triton.jit
def _noaux_routing_kernel(
    logits_ptr,
    bias_ptr,
    scores_ptr,
    tmp_scores_ptr,
    batch_size,
    num_experts: tl.constexpr,
    n_group: tl.constexpr,
    group_size: tl.constexpr,
    topk_group: tl.constexpr,
    # renormalize and routed_scaling_factor are not used inside the kernel and are kept for signature
    # compatibility; the stride arguments below ARE used.
    renormalize: tl.constexpr,
    routed_scaling_factor,
    logits_stride_0,
    logits_stride_1,
    bias_stride_0,
    scores_stride_0,
    scores_stride_1,
    tmp_scores_stride_0,
    tmp_scores_stride_1,
    BLOCK_SIZE: tl.constexpr,
):
    """Group-limited routing scores for one row of ``logits`` per program.

    Writes ``scores = sigmoid(logits)`` and ``tmp_scores``. ``tmp_scores`` equals ``scores + bias`` for
    experts in the ``topk_group`` best groups and 0.0 everywhere else. Groups are ranked by the sum of
    their two largest bias-adjusted scores. Final top-k expert selection is not done here.

    Requires ``BLOCK_SIZE == num_experts == n_group * group_size`` (for the reshape). ``tl.arange`` extents
    (BLOCK_SIZE, group_size, n_group) must be powers of two.
    """
    pid = tl.program_id(0)
    if pid >= batch_size:
        return
    idx = tl.arange(0, BLOCK_SIZE)
    mask = idx < num_experts  # always true if BLOCK_SIZE == num_experts, but kept for safety
    # 1. Load logits and bias
    logits = tl.load(logits_ptr + pid * logits_stride_0 + idx * logits_stride_1, mask=mask, other=0.0)
    bias = tl.load(bias_ptr + idx * bias_stride_0, mask=mask, other=0.0)
    # 2. Compute scores (sigmoid) and bias‑adjusted scores
    scores = tl.sigmoid(logits)  # original scores
    scores_fc = scores + bias  # bias‑adjusted scores
    # 3. Compute group scores: sum of top‑2 scores_fc per group
    # Reshape to (n_group, group_size) – requires BLOCK_SIZE == num_experts
    scores_fc_2d = tl.reshape(scores_fc, (n_group, group_size))
    # Max and argmax per group
    max_val = tl.max(scores_fc_2d, axis=1)
    max_idx = tl.argmax(scores_fc_2d, axis=1)  # index within group (0..group_size-1)
    # Second max per group: mask out the max element
    # (with group_size == 1 the second max is -inf)
    col_range = tl.arange(0, group_size)
    mask_max = col_range[None, :] == max_idx[:, None]
    scores_fc_masked = tl.where(mask_max, -float('inf'), scores_fc_2d)
    second_max = tl.max(scores_fc_masked, axis=1)
    group_scores = max_val + second_max
    # 4. Select top‑k groups and build selected_mask
    # Iterative argmax: pick the best remaining group, mark its experts, then knock it out with -inf.
    selected_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int1)
    group_scores_copy = group_scores
    for _ in range(topk_group):
        max_idx_g = tl.argmax(group_scores_copy, axis=0)  # group index
        # mark experts in this group
        group_start = max_idx_g * group_size
        group_end = group_start + group_size
        group_mask = (idx >= group_start) & (idx < group_end) & mask
        selected_mask = selected_mask | group_mask
        # remove this group
        g_idx = tl.arange(0, n_group)
        g_mask = g_idx == max_idx_g
        group_scores_copy = tl.where(g_mask, -float('inf'), group_scores_copy)
    # 5. Build masked scores (tmp_scores) – experts in selected groups keep scores_fc, others 0
    # NOTE(review): unselected experts get 0.0 rather than -inf. If scores_fc can be negative (negative
    # bias), an unselected expert may outrank a selected one in a later top-k; confirm this matches the
    # intended reference behavior.
    tmp_scores = tl.where(selected_mask, scores_fc, 0.0)
    # 6. Store outputs
    off_scores = pid * scores_stride_0 + idx * scores_stride_1
    tl.store(scores_ptr + off_scores, scores, mask=mask)
    off_tmp = pid * tmp_scores_stride_0 + idx * tmp_scores_stride_1
    tl.store(tmp_scores_ptr + off_tmp, tmp_scores, mask=mask)


# ---------------------------------------------------------------------------
# Python wrapper: launches the routing kernel, then finishes top-k selection in PyTorch
# ---------------------------------------------------------------------------


def fused_noaux_tc_routing(
    logits: torch.Tensor,
    bias: torch.Tensor,
    num_experts: int = 256,
    n_group: int = 8,
    topk_group: int = 4,
    top_k: int = 8,
    renormalize: bool = True,
    routed_scaling_factor: float = 2.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-limited ("noaux_tc") top-k expert routing.

    Expert scores are ``sigmoid(logits)``. Expert selection uses the bias-adjusted scores
    ``scores + bias``, restricted to the ``topk_group`` best of ``n_group`` groups. The returned weights
    are gathered from the unbiased scores.

    Args:
        logits: (batch_size, num_experts) router logits; converted to float32.
        bias: per-expert selection bias with ``num_experts`` elements.
        num_experts: number of experts. Must be a power of two (Triton ``tl.arange`` extent) and a
            multiple of 32.
        n_group: number of expert groups; must divide ``num_experts``.
        topk_group: number of groups kept per token.
        top_k: number of experts selected per token.
        renormalize: normalize the selected weights so they sum to 1.
        routed_scaling_factor: multiplier applied to the final weights.

    Returns:
        ``(topk_weight, topk_idx)``, each of shape (batch_size, top_k).

    Raises:
        AssertionError: if the sizes above are inconsistent.
    """
    assert num_experts % n_group == 0, 'num_experts must be divisible by n_group'
    # The kernel uses num_experts as a tl.arange extent (BLOCK_SIZE), which must be a power of two.
    # Because n_group divides num_experts, this also makes n_group and group_size powers of two.
    assert num_experts > 0 and (num_experts & (num_experts - 1)) == 0, 'num_experts must be a power of two'
    assert num_experts % 32 == 0, 'num_experts must be a multiple of 32'
    # The kernel reads exactly num_experts entries per row without bounds information, so reject
    # mismatched inputs instead of reading out of bounds.
    assert logits.dim() == 2 and logits.shape[1] == num_experts, 'logits must have shape (batch_size, num_experts)'
    assert bias.numel() == num_experts, 'bias must have num_experts elements'
    batch_size = logits.shape[0]
    group_size = num_experts // n_group
    # Convert to float32 and ensure contiguous; bias is flattened to 1-D since only stride 0 is used.
    logits = logits.float().contiguous()
    bias = bias.float().reshape(-1).contiguous()
    # Output tensors from the kernel
    scores = torch.empty(batch_size, num_experts, device=logits.device, dtype=torch.float32)
    tmp_scores = torch.empty(batch_size, num_experts, device=logits.device, dtype=torch.float32)
    # One program per row; the whole expert dimension is a single block (the kernel reshapes it to
    # (n_group, group_size), so the block cannot be padded).
    BLOCK_SIZE = num_experts
    grid = (batch_size, )
    _noaux_routing_kernel[grid](
        logits,
        bias,
        scores,
        tmp_scores,
        batch_size,
        num_experts=num_experts,
        n_group=n_group,
        group_size=group_size,
        topk_group=topk_group,
        renormalize=int(renormalize),  # not used inside kernel
        routed_scaling_factor=routed_scaling_factor,
        logits_stride_0=logits.stride(0),
        logits_stride_1=logits.stride(1),
        bias_stride_0=bias.stride(0),
        scores_stride_0=scores.stride(0),
        scores_stride_1=scores.stride(1),
        tmp_scores_stride_0=tmp_scores.stride(0),
        tmp_scores_stride_1=tmp_scores.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Final expert selection using PyTorch's topk (guarantees exact match)
    _, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)
    # Weights come from the unbiased sigmoid scores, not the bias-adjusted ones.
    topk_weight = scores.gather(1, topk_idx)
    if renormalize:
        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
    topk_weight = topk_weight * routed_scaling_factor
    return topk_weight, topk_idx


================================================
FILE: lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import tilelang
import tilelang.language as T
import torch
from tvm import tir

# Union of the TVM TIR types that can stand for a buffer operand: a whole tir.Buffer, a tir.BufferRegion,
# or a tir.BufferLoad expression.
# NOTE(review): not referenced in this part of the file; presumably used as an annotation elsewhere.
BufferLikeType = tir.Buffer | tir.BufferRegion | tir.BufferLoad


# L2-normalize k and q in place across one warp.
#
# Each lane holds `k_per_thr` elements of k (k_local) and q (q_local). Squared values are summed per lane,
# reduced over the warp with T.warp_reduce_sum, and every element is scaled by rsqrt(sum + 1e-6).
#
# NOTE(review): every element a lane holds is included in the sum, including any padded/duplicated
# entries. Callers must make the warp hold exactly K distinct elements (e.g. K a multiple of 32) — confirm.
@T.macro
def normalize_qk(k_local: T.Buffer, q_local: T.Buffer, k_per_thr: int) -> None:
    k_sum = T.alloc_var(T.float32)
    q_sum = T.alloc_var(T.float32)
    k_sum = 0
    q_sum = 0
    # per-lane partial sums of squares
    for i in T.Unroll(k_per_thr):
        k_sum += k_local[i] * k_local[i]
        q_sum += q_local[i] * q_local[i]
    # full squared L2 norms, shared by all lanes of the warp
    k_sum = T.warp_reduce_sum(k_sum)
    q_sum = T.warp_reduce_sum(q_sum)
    # 1e-6 guards against division by zero for all-zero vectors
    k_norm = T.rsqrt(k_sum + 1e-6)
    q_norm = T.rsqrt(q_sum + 1e-6)
    for i in T.Unroll(k_per_thr):
        k_local[i] = k_local[i] * k_norm
        q_local[i] = q_local[i] * q_norm


# Build the TileLang kernel for the fused recurrent gated delta rule over SEQLEN timesteps.
#
# Shapes and dtypes are compile-time parameters. B (batch) is dynamic, and so is N (state rows) when
# use_state_indices is set. The returned prim_func takes
# (Query, Key, Value, Out, G, Beta, State, StateIndices, CacheSeqlens).
#
# For each (batch, value head) and timestep t, with h the [K, V] state slice and q pre-multiplied by `scale`:
#   h  <- h * exp(g_t)                 (g_t = 0 when use_g is False)
#   v' <- (v_t - h^T k_t) * beta_t     (beta_t = 1 when use_beta is False)
#   h  <- h + k_t v'^T
#   o_t <- h^T q_t
# q and k are optionally L2-normalized first (use_qk_l2norm_in_kernel).
#
# Parallelization: grid = (ceildiv(V, v_per_cta), B * HV) with num_warps warps per CTA. Lane `l` of a warp
# owns K rows [l * k_per_thr, (l + 1) * k_per_thr). Each warp owns v_per_warp V columns per wave, and
# num_waves waves cover the CTA's v_per_cta columns. The state tile is staged in shared memory (h_smem).
#
# State slots: without is_circular_buffer the state is read from and written back to slot 0 once, after
# all timesteps. With is_circular_buffer the state is read from slot CacheSeqlens[b] % NUM_STATE. After
# every timestep it is written to the next slot, wrapping modulo NUM_STATE.
#
# NOTE(review): lane rows k_off + j index h_smem without a `< K` guard when loading h_local and in the
# non-circular write-back. This appears to require K to be a multiple of the warp size (32) — confirm
# before using other head dims.
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}, )
def fused_recurrent_gated_delta_rule_fwd(SEQLEN,
                                         H,
                                         K,
                                         HV,
                                         V,
                                         NUM_STATE,
                                         q_stride: Sequence[int],
                                         k_stride: Sequence[int],
                                         v_stride: Sequence[int],
                                         state_stride: Sequence[int],
                                         scale,
                                         dtype,
                                         state_dtype,
                                         g_dtype=None,
                                         beta_dtype=None,
                                         use_g: bool = False,
                                         use_beta: bool = False,
                                         use_qk_l2norm_in_kernel: bool = False,
                                         output_final_state: bool = False,
                                         use_state_indices: bool = False,
                                         is_circular_buffer: bool = False,
                                         num_warps: int = 1):

    num_threads = num_warps * 32
    state_num_bits = T.DataType(state_dtype).bits
    data_num_bits = T.DataType(dtype).bits
    # elements per 128-bit vector access
    state_vec_width = 128 // state_num_bits
    data_vec_width = 128 // data_num_bits
    warp_size = 32
    # K rows held by each lane
    k_per_thr = T.ceildiv(K, warp_size)
    # V columns handled by one warp per wave; at least one full vector for both state and data dtypes
    v_per_warp = max(state_vec_width, data_vec_width, 8)
    # Target v_per_cta >= V to minimize grid_V blocks.
    # More waves means fewer blocks but more sequential wave iterations.
    target_v_per_cta = max(V, v_per_warp * num_warps * 2)
    num_waves = T.ceildiv(target_v_per_cta, v_per_warp * num_warps)
    v_per_cta = v_per_warp * num_warps * num_waves

    B = T.dynamic('B')
    # With state indices the state tensor may have a different number of rows than the batch.
    N = B if not use_state_indices else T.dynamic('N')

    # dtype
    if g_dtype is None:
        g_dtype = dtype
    if beta_dtype is None:
        beta_dtype = dtype

    @T.prim_func
    def fused_recurrent_gated_delta_rule_main(
        Query: T.StridedTensor([B, SEQLEN, H, K], dtype=dtype, strides=q_stride),
        Key: T.StridedTensor([B, SEQLEN, H, K], dtype=dtype, strides=k_stride),
        Value: T.StridedTensor([B, SEQLEN, HV, V], dtype=dtype, strides=v_stride),
        Out: T.Tensor([B, SEQLEN, HV, V], dtype=dtype),
        G: T.Tensor([B, SEQLEN, HV], dtype=g_dtype),
        Beta: T.Tensor([B, SEQLEN, HV], dtype=beta_dtype),
        State: T.StridedTensor([N, NUM_STATE, HV, K, V], dtype=state_dtype, strides=state_stride),
        StateIndices: T.Tensor([B], dtype=torch.int64) = None,
        CacheSeqlens: T.Tensor([B], dtype=torch.int32) = None,
    ):
        with T.Kernel(T.ceildiv(V, v_per_cta), B * HV, threads=num_threads) as (v_start, bhv_idx):
            tidx = T.get_thread_binding(0)
            b_id = bhv_idx // HV
            hv_id = bhv_idx % HV
            # HV // H value heads share one q/k head
            h_id = hv_id // (HV // H)
            warp_id = tidx // warp_size
            lane_id = tidx % warp_size
            # first K row owned by this lane
            k_off = lane_id * k_per_thr

            # state_idx
            if use_state_indices:
                state_id = StateIndices[b_id]
            else:
                state_id = b_id

            # slot to read (state_seq_id) and first slot to write (state_update_id)
            if is_circular_buffer:
                state_seq_id = CacheSeqlens[b_id] % NUM_STATE
                state_update_id = T.alloc_var(T.int32)
                state_update_id = (state_seq_id + 1) % NUM_STATE
            else:
                state_seq_id = 0
                state_update_id = 0

            # load states
            # columns past V are zero-filled so they do not affect results
            # NOTE(review): this read is not guarded by `state_id >= 0` (unlike the stores) — confirm
            # negative state indices cannot reach here, or that the read is harmless.
            h_smem = T.alloc_shared([K, v_per_cta], state_dtype)
            T.annotate_layout({h_smem: tilelang.layout.make_swizzled_layout(h_smem)})
            for i, j in T.Parallel(K, v_per_cta):
                v_idx = v_start * v_per_cta + j
                if v_idx < V:
                    h_smem[i, j] = State[state_id, state_seq_id, hv_id, i, v_idx]
                else:
                    h_smem[i, j] = 0.0

            # since H is more heavy than qkv, we would put wave loop outside
            for wave_id in range(num_waves):
                # load states local

                # this warp's column offset within the CTA tile (v_warp_off) and within V (v_off)
                v_warp_off = wave_id * num_warps * v_per_warp + warp_id * v_per_warp
                v_off = v_start * v_per_cta + v_warp_off
                h_local = T.alloc_local([k_per_thr, v_per_warp], T.float32)
                # every wave replays all timesteps, so restart the circular write slot
                if is_circular_buffer:
                    state_update_id = (state_seq_id + 1) % NUM_STATE
                for j in T.Unroll(k_per_thr):
                    k_idx = k_off + j
                    for vg in T.Unroll(v_per_warp // state_vec_width):
                        for i in T.Vectorized(state_vec_width):
                            idx = vg * state_vec_width + i
                            h_local[j, idx] = h_smem[k_idx, v_warp_off + idx]

                for seq_id in range(SEQLEN):
                    # load q, k, g, beta
                    # `% K` keeps the global read in bounds for lanes whose rows extend past K
                    q_local = T.alloc_local([k_per_thr], T.float32)
                    k_local = T.alloc_local([k_per_thr], T.float32)
                    for i in T.Vectorized(k_per_thr):
                        k_idx = (k_off + i) % K
                        q_local[i] = Query[b_id, seq_id, h_id, k_idx]
                    for i in T.Vectorized(k_per_thr):
                        k_idx = (k_off + i) % K
                        k_local[i] = Key[b_id, seq_id, h_id, k_idx]

                    # normalize
                    if use_qk_l2norm_in_kernel:
                        normalize_qk(k_local, q_local, k_per_thr)

                    for i in T.Vectorized(k_per_thr):
                        q_local[i] = q_local[i] * scale

                    # load g, beta
                    if use_g:
                        g = T.cast(G[b_id, seq_id, hv_id], T.float32)
                    else:
                        g = 0.0
                    # per-step decay factor applied to the whole state
                    g_exp = T.exp(g)
                    if use_beta:
                        beta = T.cast(Beta[b_id, seq_id, hv_id], T.float32)
                    else:
                        beta = 1.0

                    # load v
                    # `% V` keeps the read in bounds; columns past V are never stored
                    v_local = T.alloc_local([v_per_warp], dtype)
                    for vg in T.Unroll(v_per_warp // data_vec_width):
                        for i in T.Vectorized(data_vec_width):
                            idx = vg * data_vec_width + i
                            v_idx = (v_off + idx) % V
                            v_local[idx] = Value[b_id, seq_id, hv_id, v_idx]

                    # update states
                    # per column i: decay h, hk = (h^T k)[i] reduced over the warp,
                    # then rank-1 update h += k * (v - hk) * beta
                    for i in T.Unroll(v_per_warp):
                        hk = T.alloc_var(T.float32)
                        hk = 0
                        for j in T.Unroll(k_per_thr):
                            h_local[j, i] = h_local[j, i] * g_exp
                            hk += h_local[j, i] * k_local[j]
                        hk = T.warp_reduce_sum(hk)
                        v = (v_local[i] - hk) * beta
                        for j in T.Unroll(k_per_thr):
                            h_local[j, i] = h_local[j, i] + k_local[j] * v

                    # store states
                    # circular buffer: write the state after every step into the next slot
                    if output_final_state and state_id >= 0:
                        if is_circular_buffer:
                            for j in T.Unroll(k_per_thr):
                                if (k_off + j) < K:
                                    for vg in T.Unroll(v_per_warp // state_vec_width):
                                        for i in T.Vectorized(state_vec_width):
                                            idx = vg * state_vec_width + i
                                            if v_off + idx < V:
                                                State[state_id, state_update_id, hv_id, k_off + j,
                                                      v_off + idx] = h_local[j, idx]
                            state_update_id = (state_update_id + 1) % NUM_STATE

                    # compute output
                    o_local = T.alloc_local([v_per_warp], dtype)
                    for i in T.Unroll(v_per_warp):
                        # o = q * h
                        o = T.alloc_var(T.float32)
                        o = 0.0
                        for j in T.Unroll(k_per_thr):
                            o += q_local[j] * h_local[j, i]
                        o = T.warp_reduce_sum(o)
                        o_local[i] = o

                    # after the warp reduction lane 0 holds the full result and stores it.
                    # NOTE(review): rows with a negative state id get no output written.
                    if lane_id == 0 and state_id >= 0:
                        for vg in T.Unroll(v_per_warp // data_vec_width):
                            for i in T.Vectorized(data_vec_width):
                                idx = vg * data_vec_width + i
                                v_idx = (v_off + idx)
                                if v_idx < V:
                                    Out[b_id, seq_id, hv_id, v_idx] = o_local[idx]

                # write h_local back to h_smem for coalesced global store
                if output_final_state and state_id >= 0 and not is_circular_buffer:
                    for j in T.Unroll(k_per_thr):
                        k_idx = k_off + j
                        for vg in T.Unroll(v_per_warp // state_vec_width):
                            for i in T.Vectorized(state_vec_width):
                                idx = vg * state_vec_width + i
                                h_smem[k_idx, v_warp_off + idx] = h_local[j, idx]

            # coalesced state writeback via shared memory
            if output_final_state and state_id >= 0 and not is_circular_buffer:
                for i, j in T.Parallel(K, v_per_cta):
                    v_idx = v_start * v_per_cta + j
                    if v_idx < V:
                        State[state_id, state_update_id, hv_id, i, v_idx] = h_smem[i, j]

    return fused_recurrent_gated_delta_rule_main


def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    state_indices: torch.Tensor | None = None,
    cache_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Fused recurrent gated delta rule.

    Args:
        q: [B, T, H, K]
        k: [B, T, H, K]
        v: [B, T, HV, V]
        g: [B, T, HV], optional
        beta: [B, T, HV], optional
        scale: float, optional
        initial_state: [N, HV, K, V], optional, if state_indices is not proviced, N=B
        use_qk_l2norm_in_kernel: whether to apply l2 normalization on q and k in the kernel
        state_indices: [B], optional, the indices to update in the recurrent state, required
        cache_seqlens: [B], optional, the cached sequence lengths for each batch element
    Returns:
        o: [B, T, HV, V]
        final_state: [N, HV, K, V] if output_final_state else None
    """
    # T is imported as tilelang.language, use seqlen instead
    _, seqlen, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    if scale is None:
        scale = 1 / (q.shape[-1]**0.5)
    g_dtype = torch.float32
    beta_dtype = torch.float32
    if g is not None:
        assert g.is_contiguous()
        g_dtype = g.dtype
    if beta is not None:
        assert beta.is_contiguous()
        beta_dtype = beta.dtype
    if state_indices is not None:
        assert state_indices.is_contiguous()
        assert initial_state is not None, 'initial_state is required when state_indices is provided'
        assert state_indices.shape == (q.shape[0], )

    o = torch.empty_like(v)
    final_state = initial_state
    state_dtype = q.dtype
    if final_state is not None:
        state_dim = final_state.dim()
        # expand dim
        if state_dim == 4:
            final_state = final_state.unsqueeze(1)
        state_stride = final_state.stride()
        state_dtype = final_state.dtype

        # set and check num states
        num_states = final_state.shape[1]
    else:
        state_dim = 4
        state_stride = (0, 0, 0, 0, 0)
        num_states = 1

    num_warps = 4
    kernel = fused_recurrent_gated_delta_rule_fwd(seqlen,
                                                  H,
                                                  K,
                                                  HV,
                                                  V,
                                                  NUM_STATE=num_states,
                                                  q_stride=q.stride(),
                                                  k_stride=k.stride(),
                                                  v_stride=v.stride(),
                                                  state_stride=state_stride,
                                                  scale=scale,
                                                  dtype=q.dtype,
                                                  state_dtype=state_dtype,
                                                  g_dtype=g_dtype,
                                                  beta_dtype=beta_dtype,
                                                  use_g=g is not None,
                                                  use_beta=beta is not None,
                                                  use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
                                                  output_final_state=output_final_state,
                                                  use_state_indices=state_indices is not None,
                                                  is_circular_buffer=cache_seqlens is not None,
                                                  num_warps=num_warps)

    kernel(q, k, v, o, g, beta, final_state, state_indices, cache_seqlens)

    if not output_final_state:
        final_state = None
    elif final_state is not None and state_dim == 4:
        final_state = final_state.squeeze(1)
    return o, final_state


================================================
FILE: lmdeploy/pytorch/kernels/cuda/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl


@triton.jit
def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, stride_sb, stride_st, stride_ib, stride_it,
                                 num_tokens, BLOCK_N: tl.constexpr):
    """Inverse-CDF sampling of one entry per batch row.

    Program ``batch_id`` draws ``samp = tl.rand(seed, offset)`` and scans the row's scores in ``BLOCK_N``
    chunks, keeping a running cumulative sum. The first position whose interval ``(cum - score, cum]``
    contains ``samp`` is selected, and the matching entry of ``Indices`` is written to
    ``Outputs[batch_id]``. Positions with zero score have an empty interval and are never chosen. If no
    interval contains ``samp`` (e.g. the scores sum to less than ``samp``), the row's first ``Indices``
    entry is written instead.
    """
    batch_id = tl.program_id(0)
    n_off = tl.arange(0, BLOCK_N)

    # sampling random seed
    seed = tl.load(Seeds + batch_id)
    # NOTE(review): the offset is truncated to int32 — confirm offsets stay within that range.
    offset = tl.load(Offsets + batch_id).to(tl.int32)
    samp = tl.rand(seed, offset)

    # initialize
    # acc: cumulative score of all previous chunks; output defaults to the first index of the row
    acc = 0.0
    score_ptr = Scores + batch_id * stride_sb + n_off * stride_st
    indice_ptr = Indices + batch_id * stride_ib
    output = tl.load(indice_ptr)

    found_mask = False
    for b_idx in tl.range(0, num_tokens, BLOCK_N):
        # triton does not have break statement, use mask to skip computation
        if not found_mask:
            s_off = b_idx + n_off
            s_mask = (s_off < num_tokens)
            scores = tl.load(score_ptr, mask=s_mask, other=0.0).to(tl.float32)
            c_scores = tl.cumsum(scores, 0)
            cum_scores = acc + c_scores
            # largest prefix sum == chunk total when scores are non-negative
            acc += tl.max(c_scores, 0)

            pre_cum_scores = cum_scores - scores
            valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
            found_mask = tl.sum(valid_mask, 0) > 0

            if found_mask:
                valid_pos = tl.argmax(valid_mask.to(tl.int32), 0)
                indice = tl.load(indice_ptr + valid_pos * stride_it)
                output = indice
        # pointers advance every iteration, even after the sample has been found
        score_ptr += stride_st * BLOCK_N
        indice_ptr += stride_it * BLOCK_N

    tl.store(Outputs + batch_id, output)


def multinomial_sampling(scores: torch.Tensor,
                         seeds: torch.LongTensor,
                         offsets: torch.LongTensor,
                         indices: torch.Tensor = None):
    """Multinomial sampling.

    Note that this kernel assumes the input scores are already sorted in descending order.

    scores: [batch_size, num_tokens], sorted softmax scores
    seeds: [batch_size]
    offsets: [batch_size]
    indices: [batch_size, num_tokens], original token indices before sorting.
        Defaults to ``arange(num_tokens)`` for every row.

    Returns:
        [batch_size] tensor holding the sampled entry of ``indices`` for each row.
    """
    assert scores.dim() == 2
    batch_size, num_tokens = scores.size()
    device = scores.device

    if indices is None:
        indices = torch.arange(num_tokens, device=device)
        indices = indices.expand_as(scores)

    assert indices.dim() == 2
    assert indices.size() == scores.size()

    # Start from the first (highest-scoring) candidate; the kernel overwrites it with the sampled one.
    outputs = indices[:, 0].clone()

    # With a single candidate there is nothing to sample. Return it directly, keeping the same
    # [batch_size] shape as the kernel path and honouring the caller's ``indices``.
    if num_tokens == 1:
        return outputs

    BLOCK_N = 128

    grid = [batch_size]
    _multinomial_sampling_kernel[grid](scores,
                                       seeds,
                                       offsets,
                                       indices,
                                       outputs,
                                       stride_sb=scores.stride(0),
                                       stride_st=scores.stride(1),
                                       stride_ib=indices.stride(0),
                                       stride_it=indices.stride(1),
                                       num_tokens=num_tokens,
                                       BLOCK_N=BLOCK_N,
                                       num_warps=1)

    return outputs


================================================
FILE: lmdeploy/pytorch/kernels/cuda/pagedattention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/ModelTC/lightllm
import math
from typing import Literal, Sequence

import torch
import triton
import triton.language as tl
from packaging import version
from torch import Tensor

from lmdeploy.utils import get_logger

from .utils import get_device_props

logger = get_logger('lmdeploy')

# Parsed Triton version; decides where the math intrinsics below are taken from.
TRITON_VERSION = version.parse(triton.__version__)
VERSION_300 = version.parse('3.0.0')

assert TRITON_VERSION >= version.parse('2.2.0')

# TODO: fast op might not work on non-nv device
if TRITON_VERSION < VERSION_300:
    # Triton 2.x: intrinsics are reached through ``tl.math``.
    tanh, fast_dividef = tl.math.tanh, tl.math.fast_dividef
    tl_log2, tl_exp2 = tl.math.log2, tl.math.exp2
else:
    # Triton >= 3.0: libdevice ops live under ``tl.extra.cuda``; log2/exp2 are core ``tl`` ops.
    tanh, fast_dividef = tl.extra.cuda.libdevice.tanh, tl.extra.cuda.libdevice.fast_dividef
    tl_log2, tl_exp2 = tl.log2, tl.exp2


@triton.jit
def _fwd_grouped_split_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    sm_scale: tl.constexpr,
    cache_seqlens_ptr,
    page_table_ptr,
    acc_out_ptr,
    alibi_slopes_ptr,
    stride_qbs: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qd: tl.constexpr,
    stride_kp: tl.constexpr,
    stride_kbs: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kd: tl.constexpr,
    stride_vp: tl.constexpr,
    stride_vbs: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_vd: tl.constexpr,
    stride_ok: tl.constexpr,
    stride_obs: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_od: tl.constexpr,
    stride_boffb,
    kv_group_num: tl.constexpr,
    seq_len: tl.constexpr,
    window_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    num_heads_q: tl.constexpr,
    logit_softcapping: tl.constexpr,
    shared_kv: tl.constexpr,
    SPLIT_K: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_DMODEL1: tl.constexpr,
):
    """First step kernel of split k attention.

    Program ids: axis 0 selects a tile of ``BLOCK_H`` (query token, query head)
    rows that share one kv head of one request, axis 1 is the split index and
    axis 2 is the request (``cur_batch``). Each program attends its rows over
    one contiguous share of the request's kv blocks.

    Partial results are written to ``acc_out_ptr`` addressed by
    ``stride_obs`` (token), ``stride_oh`` (head), ``stride_ok`` (split) and
    ``stride_od``: the unnormalized accumulator in the first ``head_size_v``
    entries, then the running max ``m_i`` at ``head_size_v`` and the running
    sum ``l_i`` at ``head_size_v + 1``. Scores are scaled by ``log2(e)`` and
    exponentiated with ``exp2``, so ``m_i`` is in log2 units.
    ``_reduce_split_kernel`` merges the splits.

    NOTE(review): ``q_seqlen`` is fixed to 1, so every query token of a request
    is masked against the same ``history_len = kv_seqlen - 1``; confirm this is
    intended when ``seq_len > 1``.
    """
    cur_batch = tl.program_id(2)
    tile_id = tl.program_id(0)
    split_k_id = tl.program_id(1)

    # A request owns kv_group_num * seq_len rows per kv head, ordered
    # (token, head-in-group); split them into tiles of BLOCK_H rows.
    HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len
    TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H)
    subtile_id = tile_id % TILES_PER_GROUP
    cur_kv_head = tile_id // TILES_PER_GROUP
    offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H)
    cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
    cur_token = cur_batch * seq_len + offs_h // kv_group_num

    # Drop rows outside this kv head's group, outside this request's tokens,
    # or beyond the number of query heads.
    mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num
    mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len)
    mask_h = mask_h & (cur_head < num_heads_q)

    q_seqlen = 1
    kv_seqlen = tl.load(cache_seqlens_ptr + cur_batch)
    # Empty request: nothing (not even m_i / l_i) is written.
    if kv_seqlen <= 0:
        return
    history_len = kv_seqlen - q_seqlen
    if alibi_slopes_ptr is not None:
        alibi_slopes = tl.load(alibi_slopes_ptr + cur_head, mask=mask_h, other=1.0) * tl_log2(math.e)
    else:
        alibi_slopes = None

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    mask_d = offs_d < head_size
    # Wrap padded lanes back into range so the unmasked k loads stay
    # in-bounds; q is loaded as 0 there, so they add nothing to q @ k.
    offs_d = offs_d % head_size
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_dv = offs_dv < head_size_v
    offs_dv = offs_dv % head_size_v
    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
    off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)

    off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
    q = tl.load(q_ptr + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)

    k_ptrs = k_ptr + off_k
    v_ptrs = v_ptr + off_v

    # Head dims past BLOCK_DMODEL (head_size not a power of two) are handled
    # by a second q1 @ k1 product.
    if BLOCK_DMODEL1 != 0:
        offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)
        mask_d1 = offs_d1 < head_size
        offs_d1 = offs_d1 % head_size
        off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
        q1 = tl.load(q_ptr + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)
        off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
        k1_ptrs = k_ptr + off_k1

    block_offset_ptrs = page_table_ptr + cur_batch * stride_boffb

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_H], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    # Give each split an equal number of kv blocks: [loop_start, loop_end).
    num_total_blocks = tl.cdiv(kv_seqlen, BLOCK_N)
    BLOCK_PER_CTA = tl.cdiv(num_total_blocks, SPLIT_K)
    kv_len_per_prog = BLOCK_PER_CTA * BLOCK_N
    loop_start = kv_len_per_prog * split_k_id
    loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)

    # load block offset
    # dirty
    start_block_id = loop_start // BLOCK_N
    # Sliding window: skip whole blocks that end before the window start.
    if window_size > 0:
        start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N
        kv_min_loc = tl.maximum(history_len - window_size, 0)

    loop_start = start_block_id * BLOCK_N
    block_offset_ptrs += start_block_id
    # One paged kv block per iteration; b_offset is the physical block id
    # read from the page table.
    for start_n in range(loop_start, loop_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        b_offset = tl.load(block_offset_ptrs)
        block_offset_ptrs += 1

        # -- compute qk ----
        k = tl.load(k_ptrs + b_offset * stride_kp)
        if BLOCK_DMODEL1 != 0:
            k1 = tl.load(k1_ptrs + b_offset * stride_kp)

        # k and v share storage: reuse the transposed k tile as v.
        if shared_kv:
            v = k.trans(1, 0)
        else:
            v = tl.load(v_ptrs + b_offset * stride_vp)

        qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if BLOCK_DMODEL1 != 0:
            qk += tl.dot(q1, k1)
        qk *= sm_scale
        if logit_softcapping > 0.0:
            qk = qk / logit_softcapping
            qk = tanh(qk)
            qk = qk * logit_softcapping
        # Switch to log2 units so exp2 can be used below.
        qk = qk * tl_log2(math.e)
        # NOTE: inf - inf = nan, and nan will leads to error
        if start_n + BLOCK_N > history_len or window_size > 0:
            qk_mask = history_len >= (start_n + offs_n)
            if window_size > 0:
                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)
            qk = tl.where(
                qk_mask[None, :],
                qk,
                -float('inf'),
            )

        if alibi_slopes_ptr is not None:
            relative_pos = kv_seqlen - start_n - offs_n[None, :]
            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slopes[:, None]
            qk += bias

        # -- compute p, m_i and l_i (online softmax) --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl_exp2(qk - m_i_new[:, None])
        alpha = tl_exp2(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 1)

        # -- update output accumulator --
        # scale acc
        acc = acc * alpha[:, None]

        # update acc
        p, v = _convert_pv(p, v)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    # The accumulator is stored only if this split visited at least one block.
    if loop_end > loop_start:
        off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
                   offs_dv[None, :] * stride_od)
        tl.store(acc_out_ptr + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])

    # m_i / l_i are always stored; an empty split leaves them at -inf / 0.
    off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
    tl.store(acc_out_ptr + off_meta, m_i, mask=mask_h)
    tl.store(acc_out_ptr + off_meta + 1, l_i, mask=mask_h)


@triton.jit
def _fwd_grouped_split_quant_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    KScalesZeros,
    VScalesZeros,
    sm_scale,
    cache_seqlens_ptr,
    page_table_ptr,
    acc_out_ptr,
    alibi_slopes_ptr,
    stride_qbs: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qd: tl.constexpr,
    stride_kp: tl.constexpr,
    stride_kbs: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kd: tl.constexpr,
    stride_vp: tl.constexpr,
    stride_vbs: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_vd: tl.constexpr,
    stride_kszp: tl.constexpr,
    stride_kszbs: tl.constexpr,
    stride_kszh: tl.constexpr,
    stride_kszd: tl.constexpr,
    stride_vszp: tl.constexpr,
    stride_vszbs: tl.constexpr,
    stride_vszh: tl.constexpr,
    stride_vszd: tl.constexpr,
    quant_policy: tl.constexpr,
    stride_ok: tl.constexpr,
    stride_obs: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_od: tl.constexpr,
    stride_boffb,
    kv_group_num: tl.constexpr,
    window_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    num_heads_q: tl.constexpr,
    logit_softcapping: tl.constexpr,
    SPLIT_K: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_DMODEL1: tl.constexpr,
):
    """First step kernel of split k attention over a quantized kv cache.

    Program ids: axis 0 selects a tile of ``HEAD_PER_CTA`` query heads, axis 1
    is the split index, axis 2 is the request. K and V are dequantized as
    ``(x - zero) * scale``, where scale and zero are read from
    ``KScalesZeros`` / ``VScalesZeros`` at element offsets 0 and +1. With
    ``quant_policy == 4`` each stored element packs two 4-bit values (the
    first half of the head dims in the low nibble), so the accumulator and
    the ``m_i`` / ``l_i`` meta offset use ``2 * head_size_v``. Output layout
    otherwise matches ``_fwd_grouped_split_kernel``.

    Args:
        stride_xp: stride of page num dim
        stride_xbs: stride of block size dim
        stride_xh: stride of head num dim
        stride_xd: stride of head size dim
        (``x`` is one of ``k``, ``v``, ``ksz``, ``vsz``.)

    NOTE(review): there is no ``seq_len`` argument and ``cur_batch`` is used
    directly as the q / accumulator token index, so this kernel assumes one
    query token per request.

    NOTE(review): for ``quant_policy == 4`` with ``BLOCK_DMODEL1 != 0`` the two
    halves do not obviously agree. ``q1`` covers unpacked dims starting at
    ``BLOCK_DMODEL``, but the ``k1`` offsets start at ``BLOCK_DMODEL // 2``.
    Under the nibble mapping used for ``k`` (dim ``j`` -> packed
    ``j % (head_size // 2)``, shift ``j // (head_size // 2) * 4``) these do
    not line up. Confirm against the cache packing layout.
    """
    cur_batch = tl.program_id(2)
    cur_kv_head = tl.program_id(0)
    split_k_id = tl.program_id(1)

    # When a kv group is larger than BLOCK_H, several programs share one kv
    # head; cur_kv_head is then recomputed from the first query head.
    if BLOCK_H < kv_group_num:
        HEAD_PER_CTA: tl.constexpr = BLOCK_H
    else:
        HEAD_PER_CTA: tl.constexpr = kv_group_num
    cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
    mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
    mask_h = mask_h & (cur_head < num_heads_q)
    if BLOCK_H < kv_group_num:
        cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

    q_seqlen = 1
    kv_seqlen = tl.load(cache_seqlens_ptr + cur_batch)
    # Empty request: nothing (not even m_i / l_i) is written.
    if kv_seqlen <= 0:
        return
    history_len = kv_seqlen - q_seqlen
    if alibi_slopes_ptr is not None:
        alibi_slopes = tl.load(alibi_slopes_ptr + cur_head, mask=mask_h, other=1.0) * tl_log2(math.e)
    else:
        alibi_slopes = None

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dsz = tl.arange(0, 1)
    mask_d = offs_d < head_size
    offs_d = offs_d % head_size
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_dv = offs_dv < head_size_v
    offs_dv = offs_dv % head_size_v
    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
    off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)
    off_ksz = (cur_kv_head * stride_kszh + offs_dsz[:, None] * stride_kszd + offs_n[None, :] * stride_kszbs)
    off_vsz = (cur_kv_head * stride_vszh + offs_dsz[None, :] * stride_vszd + offs_n[:, None] * stride_vszbs)

    off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
    q = tl.load(q_ptr + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)

    ksz_ptrs = KScalesZeros + off_ksz
    vsz_ptrs = VScalesZeros + off_vsz

    if BLOCK_DMODEL1 != 0:
        offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)
        mask_d1 = offs_d1 < head_size
        offs_d1 = offs_d1 % head_size
        off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
        q1 = tl.load(q_ptr + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)
        off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)

    block_offset_ptrs = page_table_ptr + cur_batch * stride_boffb

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_H], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
    # int4: recompute k/v offsets in packed units plus the nibble shift for
    # each unpacked dim; the v accumulator covers both nibbles (2x width).
    if quant_policy == 4:
        if BLOCK_DMODEL1 != 0:
            offs_d1 = BLOCK_DMODEL // 2 + tl.arange(0, BLOCK_DMODEL1)
            shift_k1d = (offs_d1 // (head_size // 2) * 4)[:, None]
            offs_d1 = offs_d1 % (head_size // 2)
            off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
        offs_d = tl.arange(0, BLOCK_DMODEL) % (head_size // 2)
        shift_kd = (tl.arange(0, BLOCK_DMODEL) // (head_size // 2) * 4)[:, None]
        off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
        offs_dv = tl.arange(0, BLOCK_DV * 2) % head_size_v
        shift_vd = (tl.arange(0, BLOCK_DV * 2) // head_size_v * 4)
        off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)
        acc = tl.zeros([BLOCK_H, BLOCK_DV * 2], dtype=tl.float32)  # v head_dim packed
        mask_dv = tl.arange(0, BLOCK_DV * 2) < (head_size_v * 2)
        offs_dv = tl.arange(0, BLOCK_DV * 2) % (head_size_v * 2)
    else:
        acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    # Give each split an equal number of kv blocks: [loop_start, loop_end).
    num_total_blocks = tl.cdiv(kv_seqlen, BLOCK_N)
    BLOCK_PER_CTA = tl.cdiv(num_total_blocks, SPLIT_K)
    kv_len_per_prog = BLOCK_PER_CTA * BLOCK_N
    loop_start = kv_len_per_prog * split_k_id
    loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)

    # load block offset
    # dirty
    start_block_id = loop_start // BLOCK_N
    # Sliding window: skip whole blocks that end before the window start.
    if window_size > 0:
        start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N
        kv_min_loc = tl.maximum(history_len - window_size, 0)

    loop_start = start_block_id * BLOCK_N
    for start_n in range(loop_start, loop_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Physical block id of this logical kv block, read from the page table.
        b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N)

        # -- compute qk ----
        # k = tl.load(k_ptrs + b_offset * stride_kp)
        k = tl.load(k_ptr + off_k + b_offset * stride_kp)
        if quant_policy == 4:
            k = (k >> shift_kd) & 0x0F
        ks = tl.load(ksz_ptrs + b_offset * stride_kszp)
        kz = tl.load(ksz_ptrs + b_offset * stride_kszp + 1)
        if BLOCK_DMODEL1 != 0:
            k1 = tl.load(k_ptr + off_k1 + b_offset * stride_kp)
            if quant_policy == 4:
                k1 = (k1 >> shift_k1d) & 0x0F
            k1 = ((k1 - kz) * ks).to(q.dtype)

        if quant_policy == 4:
            v = tl.load(v_ptr + off_v + b_offset * stride_vp)
            v = (v >> shift_vd) & 0x0F
        else:
            v = tl.load(v_ptr + off_v + b_offset * stride_vp)
        vs = tl.load(vsz_ptrs + b_offset * stride_vszp)
        vz = tl.load(vsz_ptrs + b_offset * stride_vszp + 1)

        # Dequantize: (x - zero) * scale, in q's dtype for the dot products.
        k = ((k - kz) * ks).to(q.dtype)
        v = ((v - vz) * vs).to(q.dtype)
        qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if BLOCK_DMODEL1 != 0:
            qk += tl.dot(q1, k1)
        qk *= sm_scale
        if logit_softcapping > 0.0:
            qk = qk / logit_softcapping
            qk = tanh(qk)
            qk = qk * logit_softcapping
        # Switch to log2 units so exp2 can be used below.
        qk = qk * tl_log2(math.e)
        # NOTE: inf - inf = nan, and nan will leads to error
        if start_n + BLOCK_N > history_len or window_size > 0:
            qk_mask = history_len >= (start_n + offs_n)
            if window_size > 0:
                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)
            qk = tl.where(
                qk_mask[None, :],
                qk,
                -float('inf'),
            )

        if alibi_slopes_ptr is not None:
            relative_pos = kv_seqlen - start_n - offs_n[None, :]
            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slopes[:, None]
            qk += bias

        # -- compute p, m_i and l_i (online softmax) --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl_exp2(qk - m_i_new[:, None])
        alpha = tl_exp2(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 1)

        # -- update output accumulator --
        # scale acc
        acc = acc * alpha[:, None]

        # update acc
        p, v = _convert_pv(p, v)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    # The accumulator is stored only if this split visited at least one block.
    if loop_end > loop_start:
        off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
                   offs_dv[None, :] * stride_od)
        tl.store(acc_out_ptr + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])

    # m_i / l_i follow the (possibly doubled) value dims; always stored, an
    # empty split leaves them at -inf / 0.
    if quant_policy == 4:
        off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v * 2)
    else:
        off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
    tl.store(acc_out_ptr + off_meta, m_i, mask=mask_h)
    tl.store(acc_out_ptr + off_meta + 1, l_i, mask=mask_h)


@triton.jit
def _reduce_split_kernel(
    acc_ptr,
    out_ptr,
    sinks_ptr,
    stride_ak,
    stride_abs,
    stride_ah,
    stride_ad,
    stride_obs,
    stride_oh,
    stride_od,
    head_size_v: tl.constexpr,
    SPLIT_K: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    """Second step kernel of split k attention.

    One program per (head, token): program axis 0 is the head and axis 1 the
    token. The program merges the ``SPLIT_K`` partial results in ``acc_ptr``.
    Each partial holds ``head_size_v`` accumulator values followed by the
    split's running max ``m`` (log2 units) and running sum ``l``.

    The merge rescales every split by ``exp2(m_k - m_max)``. Splits whose
    ``m`` is ``-inf`` (no kv visited) contribute nothing. The merged
    accumulator is divided by the merged sum and written to ``out_ptr``.

    If ``sinks_ptr`` is given, ``sinks[head]`` is treated as an extra natural-log
    logit that only enlarges the softmax denominator.
    """
    cur_batch = tl.program_id(1)
    cur_head = tl.program_id(0)

    # initialize offsets
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_k = tl.arange(0, SPLIT_K)
    mask_dv = offs_dv < head_size_v

    offs_acc = (cur_batch * stride_abs + cur_head * stride_ah + offs_k[:, None] * stride_ak +
                offs_dv[None, :] * stride_ad)
    offs_mi = (cur_batch * stride_abs + cur_head * stride_ah + stride_ak * offs_k + head_size_v)

    m_k = tl.load(acc_ptr + offs_mi)
    l_k = tl.load(acc_ptr + offs_mi + 1)
    # Empty splits never stored their accumulator; read them as 0.
    acc_k = tl.load(acc_ptr + offs_acc, mask=mask_dv[None, :] & (m_k[:, None] > -float('inf')), other=0.0)

    m_max = tl.max(m_k, 0)
    alpha = tl_exp2(m_k - m_max)
    acc_k = acc_k * alpha[:, None]
    l_k = l_k * alpha

    acc = tl.sum(acc_k, 0)
    l_sum = tl.sum(l_k, 0)

    if sinks_ptr is not None:
        sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype)
        # Use the version shim (tl.math.exp2 on Triton < 3.0) like the rest of
        # the file, instead of calling tl.exp2 directly.
        l_sum = l_sum + tl_exp2(sink * tl_log2(math.e) - m_max)
    acc = acc / l_sum

    out_offs = (cur_batch * stride_obs + cur_head * stride_oh + offs_dv * stride_od)
    tl.store(out_ptr + out_offs, acc, mask=mask_dv)


@triton.jit
def _convert_pv(p, v):
    """Cast the probability tile ``p`` to ``v``'s dtype (for ``tl.dot(p, v)``); ``v`` is returned as-is."""
    return p.to(v.dtype), v


# Lazily cached CUDA compute capability tuple (index 0 is read as the major
# version); filled on the first flash_attn_with_kvcache call and used there to
# choose num_warps / num_stages.
_nv_cap = None


def _kernel_meta_default(BLOCK_DMODEL: int, BLOCK_H: int):
    """Launch config ``(num_warps, num_stages)`` for GPUs older than sm80.

    Always ``(4, 2)``. The block sizes are accepted only so that every
    ``_kernel_meta_*`` variant shares one signature.
    """
    num_warps, num_stages = 4, 2
    return num_warps, num_stages


def _kernel_meta_sm8x(BLOCK_DMODEL: int, BLOCK_H: int):
    """Launch config ``(num_warps, num_stages)`` for sm8x GPUs.

    Two pipeline stages; 8 warps once the ``BLOCK_DMODEL x BLOCK_H`` tile has
    more than 8192 elements, otherwise 4.
    """
    tile_elems = BLOCK_DMODEL * BLOCK_H
    return (8 if tile_elems > 8192 else 4), 2


def _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int):
    """Launch config ``(num_warps, num_stages)`` for sm9x and newer GPUs.

    Always 4 warps; 3 pipeline stages for tiles of at most 4096 elements,
    2 for larger ``BLOCK_DMODEL x BLOCK_H`` tiles.
    """
    tile_elems = BLOCK_DMODEL * BLOCK_H
    return 4, (2 if tile_elems > 4096 else 3)


def _get_split_k(device_idx: int, head_grid: int, batch_size: int, num_warps: int):
    """Choose how many kv splits each (head tile, request) pair is divided into.

    Args:
        device_idx: CUDA device index whose properties are queried.
        head_grid: number of programs along the head axis of the launch grid.
        batch_size: number of requests.
        num_warps: warps per program of the split kernel.

    Returns:
        A power of two, at least 4 and otherwise capped by the largest power of
        two not exceeding the SM count.
    """
    props = get_device_props(device_idx)
    num_sm = props['multi_processor_count']
    # estimated occupancy 12.5%
    warps_per_sm = props['warps_per_sm'] // 8
    cta_per_sm = triton.cdiv(warps_per_sm, num_warps)
    cta_per_device = num_sm * cta_per_sm

    SPLIT_K = triton.cdiv(cta_per_device // head_grid, triton.next_power_of_2(batch_size))
    # If head_grid exceeds cta_per_device the estimate is 0 and the rounding
    # below would compute ``1 << -1`` (ValueError). Clamp first; the floor of 4
    # applied at the end still holds.
    SPLIT_K = max(SPLIT_K, 1)
    # round down to a power of two
    SPLIT_K = 1 << (SPLIT_K.bit_length() - 1)
    max_split = 1 << (num_sm.bit_length() - 1)
    SPLIT_K = max(min(SPLIT_K, max_split), 4)
    return SPLIT_K


def flash_attn_with_kvcache(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    cache_seqlens: Tensor,
    page_table: Tensor,
    cu_seqlens_q: Tensor = None,  # not used, for align with fa
    max_seqlen_q: int = None,
    softmax_scale: float = None,
    causal: bool = False,  # not used, for align with fa
    window_size: int = None,
    softcap: float = None,
    scheduler_metadata: Tensor = None,  # not used, for align with fa
    # args not in fa
    alibi_slopes: Tensor = None,
    k_scales_zeros: Tensor = None,
    v_scales_zeros: Tensor = None,
    quant_policy: Literal[0, 4, 8] = 0,
    sinks: Tensor = None,
    kv_layout: str = 'bshd',
):
    """Paged Attention forward.

    Note that this kernel is decoding-only. Every request contributes the same
    number of query tokens, ``q.shape[-3] // batch``.

    Args:
        q: queries; the last three dims are read as (token, head, head_dim).
        k_cache, v_cache: paged caches. ``kv_layout`` selects the dim order:
            'bshd' is (block, slot, head, dim) and 'bhsd' is
            (block, head, slot, dim).
        cache_seqlens: per-request kv length. Keys at positions
            ``< cache_seqlens[b]`` are attended.
        page_table: physical block ids, indexed ``[request, logical_block]``.
        max_seqlen_q: if given, must equal the per-request query length.
        softmax_scale: defaults to ``1 / sqrt(q.shape[-1])``.
        window_size: sliding-window size. A sequence uses its first element;
            None disables the window.
        softcap: logit soft-capping value; None disables it.
        alibi_slopes: optional per-head ALiBi slopes.
        k_scales_zeros, v_scales_zeros: scale/zero tensors indexed like the
            caches. Required when ``quant_policy > 0``.
        quant_policy: 0 (no quantization), 8, or 4 (two 4-bit values packed
            per cache element).
        sinks: optional per-head sink logits added to the softmax denominator.
        kv_layout: 'bshd' or 'bhsd'.

    Returns:
        A tensor shaped like ``q`` with the last dim replaced by the value head
        dim. That dim is doubled for ``quant_policy == 4``.

    Raises:
        RuntimeError: unsupported ``kv_layout``.
    """

    global _nv_cap
    if _nv_cap is None:
        # Query the device q lives on rather than whichever device is current.
        _nv_cap = torch.cuda.get_device_capability(q.device)

    if kv_layout == 'bshd':
        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
    elif kv_layout == 'bhsd':
        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
    else:
        raise RuntimeError('Unsupported layout.')

    if window_size is None:
        window_size = -1
    elif isinstance(window_size, Sequence):
        window_size = window_size[0]

    if softcap is None:
        softcap = -1.0

    # When both caches share storage, the non-quant kernel derives V from the K tile.
    shared_kv = k_cache.data_ptr() == v_cache.data_ptr()

    def _get_block_d(Lk, Lv):
        """Block sizes for the qk head dim and the v head dim.

        A qk head dim that is not a power of two is split into a power-of-two
        main block plus a remainder block (at least 16 wide).
        """
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DMODEL1 = 0
        if BLOCK_DMODEL != Lk:
            BLOCK_DMODEL = BLOCK_DMODEL // 2
            BLOCK_DMODEL1 = max(16, triton.next_power_of_2(Lk - BLOCK_DMODEL))
        BLOCK_DV = triton.next_power_of_2(Lv)
        return BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k_cache.shape[d_dim], v_cache.shape[d_dim]
    if quant_policy == 4:
        assert Lq == Lk * 2
        o = q.new_empty(q.shape[:-1] + (Lv * 2, ))
    else:
        assert Lq == Lk
        o = q.new_empty(q.shape[:-1] + (Lv, ))

    if softmax_scale is None:
        softmax_scale = 1.0 / (Lq**0.5)
    batch, head = cache_seqlens.shape[0], q.shape[-2]
    num_tokens = q.shape[-3]
    num_kv_heads = k_cache.shape[h_dim]
    kv_group_num = head // num_kv_heads

    if sinks is not None:
        assert sinks.is_contiguous()
        assert sinks.numel() == head

    BLOCK = k_cache.size(s_dim)
    assert BLOCK >= 16
    if Lq > 512 and BLOCK > 32:
        logger.warning(f'`head_dim={Lq}` and `block_size={BLOCK}` '
                       'might leads to bad performance. '
                       'Please reduce `block_size`.')

    valid = num_tokens % batch == 0
    assert valid, 'we only support decoding paged attention.'
    seq_len = num_tokens // batch
    if max_seqlen_q is not None:
        assert max_seqlen_q == seq_len, 'we only support decoding paged attention.'

    if quant_policy > 0:
        # The quant kernel has no seq_len argument and indexes q / acc by
        # request id, so with more than one query token per request it would
        # silently read and write the wrong rows.
        assert seq_len == 1, 'quantized kv cache only supports one query token per request.'
        assert k_scales_zeros is not None and v_scales_zeros is not None, (
            'k_scales_zeros and v_scales_zeros are required when quant_policy > 0.')

    BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq, Lv)
    # Each kv head serves kv_group_num * seq_len (token, head) rows per request.
    HEADS_PER_REQ = kv_group_num * seq_len
    BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ)))
    TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H)
    grid_1 = TILES_PER_GROUP * num_kv_heads

    if _nv_cap[0] < 8:
        num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
    elif _nv_cap[0] < 9:
        num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H)
    else:
        num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H)

    SPLIT_K = _get_split_k(q.device.index, grid_1, batch, num_warps)

    # Per (token, head, split): value accumulator followed by m_i and l_i.
    if quant_policy != 4:
        acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32)
    else:
        acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)

    grid = (
        grid_1,
        SPLIT_K,
        batch,
    )

    if quant_policy > 0:
        _fwd_grouped_split_quant_kernel[grid](q,
                                              k_cache,
                                              v_cache,
                                              k_scales_zeros,
                                              v_scales_zeros,
                                              softmax_scale,
                                              cache_seqlens,
                                              page_table,
                                              acc,
                                              alibi_slopes,
                                              stride_qbs=q.stride(-3),
                                              stride_qh=q.stride(-2),
                                              stride_qd=q.stride(-1),
                                              stride_kp=k_cache.stride(b_dim),
                                              stride_kbs=k_cache.stride(s_dim),
                                              stride_kh=k_cache.stride(h_dim),
                                              stride_kd=k_cache.stride(d_dim),
                                              stride_vp=v_cache.stride(b_dim),
                                              stride_vbs=v_cache.stride(s_dim),
                                              stride_vh=v_cache.stride(h_dim),
                                              stride_vd=v_cache.stride(d_dim),
                                              stride_kszp=k_scales_zeros.stride(b_dim),
                                              stride_kszbs=k_scales_zeros.stride(s_dim),
                                              stride_kszh=k_scales_zeros.stride(h_dim),
                                              stride_kszd=k_scales_zeros.stride(d_dim),
                                              stride_vszp=v_scales_zeros.stride(b_dim),
                                              stride_vszbs=v_scales_zeros.stride(s_dim),
                                              stride_vszh=v_scales_zeros.stride(h_dim),
                                              stride_vszd=v_scales_zeros.stride(d_dim),
                                              quant_policy=quant_policy,
                                              stride_ok=acc.stride(-2),
                                              stride_obs=acc.stride(-4),
                                              stride_oh=acc.stride(-3),
                                              stride_od=acc.stride(-1),
                                              stride_boffb=page_table.stride(0),
                                              kv_group_num=kv_group_num,
                                              window_size=window_size,
                                              head_size=Lq,
                                              head_size_v=Lv,
                                              num_heads_q=head,
                                              logit_softcapping=softcap,
                                              SPLIT_K=SPLIT_K,
                                              BLOCK_DMODEL=BLOCK_DMODEL,
                                              BLOCK_DV=BLOCK_DV,
                                              BLOCK_N=BLOCK,
                                              BLOCK_H=BLOCK_H,
                                              BLOCK_DMODEL1=BLOCK_DMODEL1,
                                              num_warps=num_warps,
                                              num_stages=num_stages)

    else:
        _fwd_grouped_split_kernel[grid](q,
                                        k_cache,
                                        v_cache,
                                        softmax_scale,
                                        cache_seqlens,
                                        page_table,
                                        acc,
                                        alibi_slopes,
                                        stride_qbs=q.stride(-3),
                                        stride_qh=q.stride(-2),
                                        stride_qd=q.stride(-1),
                                        stride_kp=k_cache.stride(b_dim),
                                        stride_kbs=k_cache.stride(s_dim),
                                        stride_kh=k_cache.stride(h_dim),
                                        stride_kd=k_cache.stride(d_dim),
                                        stride_vp=v_cache.stride(b_dim),
                                        stride_vbs=v_cache.stride(s_dim),
                                        stride_vh=v_cache.stride(h_dim),
                                        stride_vd=v_cache.stride(d_dim),
                                        stride_ok=acc.stride(-2),
                                        stride_obs=acc.stride(-4),
                                        stride_oh=acc.stride(-3),
                                        stride_od=acc.stride(-1),
                                        stride_boffb=page_table.stride(0),
                                        kv_group_num=kv_group_num,
                                        seq_len=seq_len,
                                        window_size=window_size,
                                        head_size=Lk,
                                        head_size_v=Lv,
                                        num_heads_q=head,
                                        logit_softcapping=softcap,
                                        shared_kv=shared_kv,
                                        SPLIT_K=SPLIT_K,
                                        BLOCK_DMODEL=BLOCK_DMODEL,
                                        BLOCK_DV=BLOCK_DV,
                                        BLOCK_N=BLOCK,
                                        BLOCK_H=BLOCK_H,
                                        BLOCK_DMODEL1=BLOCK_DMODEL1,
                                        num_warps=num_warps,
                                        num_stages=num_stages)

    num_warps = 2
    grid = (head, num_tokens)
    # int4 packs two values per element, so the reduced value dim is doubled.
    if quant_policy == 4:
        Lv *= 2
        BLOCK_DV *= 2
    _reduce_split_kernel[grid](acc,
                               o,
                               sinks,
                               stride_ak=acc.stride(2),
                               stride_abs=acc.stride(0),
                               stride_ah=acc.stride(1),
                               stride_ad=acc.stride(3),
                               stride_obs=o.stride(0),
                               stride_oh=o.stride(1),
                               stride_od=o.stride(2),
                               SPLIT_K=SPLIT_K,
                               head_size_v=Lv,
                               BLOCK_DV=BLOCK_DV,
                               num_warps=num_warps,
                               num_stages=1)
    return o


================================================
FILE: lmdeploy/pytorch/kernels/cuda/rms_norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor

from .utils import get_device_props


@triton.jit
def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
    """Normalize ``x`` by its root-mean-square over ``N_COLS`` and scale by ``w``.

    The reduction runs in float32. The normalized row is cast back to ``x``'s
    dtype before the multiply by ``w``.
    """
    x_f32 = x.to(tl.float32)
    mean_square = tl.sum(x_f32 * x_f32, 0) * float(1.0 / N_COLS)
    inv_rms = tl.math.rsqrt(mean_square + eps)
    return w * (x_f32 * inv_rms).to(x.dtype)


@triton.jit
def add_rms_norm_kernel(input, weight, residual, output, out_residual, num_feats, num_groups, stride_ib, stride_ih,
                        stride_id: tl.constexpr, stride_rb, stride_rh, stride_rd: tl.constexpr, stride_ob, stride_oh,
                        stride_od: tl.constexpr, stride_rob, stride_roh, stride_rod: tl.constexpr,
                        has_residual: tl.constexpr, eps: tl.constexpr, N_COLS: tl.constexpr, BLOCK_N: tl.constexpr,
                        NUM_STAGES: tl.constexpr):
    """Rms norm kernel, optionally fused with a residual add.

    Rows are addressed as ``(batch_id, head_id)`` with
    ``idx = batch_id * num_groups + head_id``. Each program handles rows
    ``prog_id, prog_id + num_programs, ...`` below ``num_feats``.

    When ``has_residual`` is set, ``input + residual`` is written to
    ``out_residual`` and that sum is what gets normalized into ``output``.
    """
    prog_id = tl.program_id(0)
    prog_stride = tl.num_programs(0)
    offsets = tl.arange(0, BLOCK_N)
    # Lanes at or beyond N_COLS are padding (BLOCK_N may exceed N_COLS).
    mask = offsets < N_COLS

    w = tl.load(weight + offsets, mask=mask)

    x_ptrs = input + offsets * stride_id
    res_ptrs = residual + offsets * stride_rd
    out_res_ptrs = out_residual + offsets * stride_rod
    out_ptrs = output + offsets * stride_od
    for idx in tl.range(prog_id, num_feats, prog_stride, num_stages=NUM_STAGES):
        batch_id = idx // num_groups
        head_id = idx % num_groups
        cur_x_ptrs = x_ptrs + batch_id * stride_ib + head_id * stride_ih
        cur_res_ptrs = res_ptrs + batch_id * stride_rb + head_id * stride_rh
        cur_out_ptrs = out_ptrs + batch_id * stride_ob + head_id * stride_oh
        cur_out_res_ptrs = out_res_ptrs + batch_id * stride_rob + head_id * stride_roh
        # Padding lanes must read as 0 because _compute_rms_norm sums over the
        # whole BLOCK_N tile. A masked load without ``other`` leaves them undefined.
        x = tl.load(cur_x_ptrs, mask=mask, other=0.0)
        if has_residual:
            res = tl.load(cur_res_ptrs, mask=mask, other=0.0)
            x += res
            tl.store(cur_out_res_ptrs, x, mask=mask)
        out = _compute_rms_norm(x, w, eps, N_COLS)
        tl.store(cur_out_ptrs, out, mask=mask)


def _unsqueeze_to_3d(tensor: Tensor) -> Tensor:
    """Prepend singleton dims so that ``tensor`` becomes 3d.

    A 3d tensor is returned as-is. A 2d tensor gets one leading dim and a 1d tensor
    gets two.

    Raises:
        ValueError: if ``tensor`` has any other number of dims.
    """
    ndim = tensor.dim()
    if ndim not in (1, 2, 3):
        raise ValueError(f'Unsupported tensor dim {tensor.dim()}')
    for _ in range(3 - ndim):
        tensor = tensor.unsqueeze(0)
    return tensor


def _squeeze_to_origin_dim(tensor: Tensor, origin_dim: int) -> Tensor:
    """Drop leading dims of a 3d ``tensor`` until it is back to ``origin_dim`` dims.

    ``origin_dim == 3`` returns ``tensor`` unchanged. Otherwise dim 0 is squeezed
    ``3 - origin_dim`` times (each squeeze is a no-op if that dim is not of size 1).

    Raises:
        ValueError: if ``origin_dim`` is not 1, 2 or 3.
    """
    if origin_dim not in (1, 2, 3):
        raise ValueError(f'Unsupported origin dim {origin_dim}')
    for _ in range(3 - origin_dim):
        tensor = tensor.squeeze(0)
    return tensor


def rms_norm(hidden_states: Tensor,
             weight: Tensor,
             eps: float = 1e-6,
             residual: Tensor = None,
             out: Tensor = None,
             out_residual: Tensor = None):
    """RMS-normalize ``hidden_states`` over its last dim, optionally adding ``residual`` first.

    Args:
        hidden_states: input with at most 3 dims whose last dim equals ``weight.shape[0]``.
        weight: per-feature scale; its last dim must have unit stride.
        eps: value added to the mean square before ``rsqrt``.
        residual: optional tensor of the same shape, added to ``hidden_states`` first.
        out: optional preallocated output of the same shape.
        out_residual: optional preallocated buffer for ``hidden_states + residual``
            (ignored when ``residual`` is None).

    Returns:
        ``out`` when ``residual`` is None, otherwise ``(out, out_residual)``.
    """
    assert hidden_states.dim() <= 3
    assert weight.stride(-1) == 1
    feat_size = weight.shape[0]
    assert hidden_states.size(-1) == feat_size

    origin_dim = hidden_states.dim()
    has_residual = residual is not None
    if out is None:
        out = torch.empty_like(hidden_states)
    if not has_residual:
        # The kernel does no residual work in this case; alias the buffers so every
        # pointer/stride argument is still a valid tensor.
        residual = hidden_states
        out_residual = out
    elif out_residual is None:
        out_residual = torch.empty_like(residual)

    shape = hidden_states.shape
    for buf in (residual, out, out_residual):
        assert buf.shape == shape

    hidden_states, residual, out, out_residual = (_unsqueeze_to_3d(t)
                                                  for t in (hidden_states, residual, out, out_residual))

    num_feats = hidden_states.numel() // hidden_states.size(-1)
    BLOCK_N = triton.next_power_of_2(feat_size)

    # Cap the grid at an occupancy-based CTA budget; the kernel loops over the remaining rows.
    props = get_device_props(hidden_states.device.index)
    num_warps = min(triton.cdiv(BLOCK_N, 2048), 4)
    cta_per_sm = min(props['blocks_per_sm'], props['warps_per_sm'] // num_warps)
    cta_per_device = props['multi_processor_count'] * cta_per_sm
    num_stages = 1
    grid = (min(num_feats, cta_per_device), )

    stride_ib, stride_ih, stride_id = hidden_states.stride()
    stride_rb, stride_rh, stride_rd = residual.stride()
    stride_ob, stride_oh, stride_od = out.stride()
    stride_rob, stride_roh, stride_rod = out_residual.stride()
    add_rms_norm_kernel[grid](
        hidden_states,
        weight,
        residual,
        out,
        out_residual,
        num_feats=num_feats,
        num_groups=hidden_states.size(1),
        stride_ib=stride_ib,
        stride_ih=stride_ih,
        stride_id=stride_id,
        stride_rb=stride_rb,
        stride_rh=stride_rh,
        stride_rd=stride_rd,
        stride_ob=stride_ob,
        stride_oh=stride_oh,
        stride_od=stride_od,
        stride_rob=stride_rob,
        stride_roh=stride_roh,
        stride_rod=stride_rod,
        has_residual=has_residual,
        eps=eps,
        N_COLS=feat_size,
        BLOCK_N=BLOCK_N,
        NUM_STAGES=num_stages,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    out = _squeeze_to_origin_dim(out, origin_dim)
    if not has_residual:
        return out
    return out, _squeeze_to_origin_dim(out_residual, origin_dim)


if __name__ == '__main__':
    import time

    def torch_forward(hidden_states, weight, variance_epsilon=1e-6):
        """Pytorch reference rms norm (float32 statistics, cast back before scaling)."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
        return weight * hidden_states.to(input_dtype)

    def _bench_ms(fn, n_repeats):
        """Return the mean wall time of ``fn()`` in milliseconds.

        CUDA kernels launch asynchronously. The device is therefore synchronized before
        the clock starts and after the last call; otherwise only launch overhead is timed.
        """
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_repeats):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / n_repeats * 1000

    def test_rms_norm(bsz, ctx_len, feat_len, dtype):
        """Check triton rms norm against the pytorch reference, then time both."""
        x = torch.empty((bsz, ctx_len, feat_len), dtype=dtype, device='cuda').normal_(mean=0., std=0.5).contiguous()
        weight = torch.empty((feat_len), dtype=dtype, device='cuda').normal_(mean=0., std=0.5).contiguous()
        triton_output = rms_norm(hidden_states=x, weight=weight)
        torch_output = torch_forward(hidden_states=x, weight=weight)
        assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)

        N_REPEATS = 20
        torch_cost = _bench_ms(lambda: torch_forward(hidden_states=x, weight=weight), N_REPEATS)
        triton_cost = _bench_ms(lambda: rms_norm(hidden_states=x, weight=weight), N_REPEATS)
        print('input {} weight {} dtype {}\n  torch {:.3f} triton {:.3f} (ms)\n'.format(
            x.shape, weight.shape, dtype, torch_cost, triton_cost))

    test_rms_norm(1, 8128, 5120, torch.float16)
    test_rms_norm(1, 8128, 5120, torch.float32)
    test_rms_norm(1, 992, 128, torch.float16)
    test_rms_norm(1, 65537, 128, torch.float32)


================================================
FILE: lmdeploy/pytorch/kernels/cuda/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch
import triton
from packaging import version

# Per-SM resident-warp limit used for occupancy estimates, keyed by CUDA compute
# capability ``(major, minor)``. ``get_device_props`` falls back to 32 for archs not listed.
# NOTE(review): the values appear to follow the per-arch limits in the CUDA programming
# guide's compute-capability table; re-check against it when adding new archs.
WARPS_PER_SM = {
    (8, 0): 64,
    (8, 6): 48,
    (8, 7): 48,
    (8, 9): 48,
    (9, 0): 64,
    (10, 0): 64,
    (10, 1): 48,
    (11, 0): 48,
    (12, 0): 48,
}

# Per-SM resident thread-block (CTA) limit, keyed like ``WARPS_PER_SM``.
# ``get_device_props`` falls back to ``warps_per_sm // 2`` for archs not listed.
BLOCKS_PER_SM = {
    (8, 0): 32,
    (8, 6): 16,
    (8, 7): 16,
    (8, 9): 24,
    (9, 0): 32,
    (10, 0): 32,
    (10, 1): 24,
    (11, 0): 24,
    (12, 0): 24,
}

# Parsed version of the installed Triton, used to gate version-dependent features
# (see ``supports_tma``).
TRITON_VERSION = version.parse(triton.__version__)


@functools.lru_cache
def get_device_props(device=None):
    """Return cached occupancy-related properties of a CUDA device.

    Args:
        device: anything ``torch.cuda.get_device_properties`` accepts (e.g. a device index);
            ``None`` means ``torch.cuda.current_device()`` at the time of the first call
            (the result is cached per argument value).

    Returns:
        dict with ``multi_processor_count``, ``warps_per_sm`` and ``blocks_per_sm``. The
        last two come from ``WARPS_PER_SM`` / ``BLOCKS_PER_SM`` keyed by the compute
        capability, defaulting to 32 and ``warps_per_sm // 2``.
    """
    props = torch.cuda.get_device_properties(torch.cuda.current_device() if device is None else device)
    capability = (props.major, props.minor)
    warps_per_sm = WARPS_PER_SM.get(capability, 32)
    return dict(
        multi_processor_count=props.multi_processor_count,
        warps_per_sm=warps_per_sm,
        blocks_per_sm=BLOCKS_PER_SM.get(capability, warps_per_sm // 2),
    )


def is_cuda():
    """Return True when Triton's active driver targets the ``'cuda'`` backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'cuda'


@functools.lru_cache
def supports_tma():
    """Return whether TMA tensor descriptors may be used (cached after the first call).

    Requires the CUDA backend, a current device of compute capability major >= 9,
    and Triton >= 3.4.0.
    """
    if not (is_cuda() and torch.cuda.get_device_capability()[0] >= 9):
        return False
    return TRITON_VERSION >= version.parse('3.4.0')


if supports_tma():
    from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F401


================================================
FILE: lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import torch
import triton
import triton.language as tl

from .activation import silu_and_mul
from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize, moe_reduce
from .w8a8_triton_kernels import per_token_quant_int8


def get_cuda_autotune_config():
    """Return the candidate tile configurations autotuned for ``fused_moe_w8a8_kernel``."""
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
    tiles = [
        (128, 128, 32, 4, 4),
        (64, 256, 32, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 128, 128, 3, 8),
    ]
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': block_k,
                'GROUP_SIZE_M': 1,
            },
            num_stages=stages,
            num_warps=warps,
        ) for block_m, block_n, block_k, stages, warps in tiles
    ]


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['N', 'K', 'M_NP2'],
)
@triton.jit
def fused_moe_w8a8_kernel(
    A,
    A_scale,
    B,
    B_scale,
    C,
    SortedIdx,
    ExpStart,
    ExpEnd,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_be: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bse: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    M_NP2: tl.constexpr,
    top_k: tl.constexpr,
    expert_offset: tl.constexpr,
    reindex_a: tl.constexpr,
    reindex_c: tl.constexpr,
    ACCUMULATOR_DTYPE: tl.constexpr,
):
    """Fused moe kernel.

    Grouped GEMM over experts. Program axis 1 selects the expert ``e`` and axis 0 a
    ``BLOCK_SIZE_M x BLOCK_SIZE_N`` output tile. The rows handled for expert ``e`` are the
    sorted positions ``ExpStart[e + expert_offset] .. ExpEnd[e + expert_offset]``, whose
    original indices are read from ``SortedIdx``.

    Each tile computes ``A_rows @ B[e]^T`` (``B[e]`` read as a ``(K, N)`` tile) with an
    ``ACCUMULATOR_DTYPE`` accumulator. It then dequantizes with the per-row ``A_scale``
    and the per-expert, per-column ``B_scale[e]`` and stores to ``C``.

    reindex_a: read A row ``sorted_idx // top_k`` instead of the sorted position.
    reindex_c: write C row ``sorted_idx`` instead of the sorted position.
    """
    exp_id = tl.program_id(1)
    pid = tl.program_id(0)

    # Range of sorted positions routed to this expert; nothing to do if empty.
    exp_start = tl.load(ExpStart + exp_id + expert_offset)
    exp_end = tl.load(ExpEnd + exp_id + expert_offset)
    M = exp_end - exp_start
    if M <= 0:
        return

    # The M tiling is sized from M_NP2 (a launch-wide constexpr bound), not from this
    # expert's M; tiles that start past M exit below.
    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    if GROUP_SIZE_M == 1:
        pid_m = pid % num_pid_m
        pid_n = pid // num_pid_m
    else:
        # Grouped tile ordering along M (same scheme as the Triton matmul tutorial).
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
        return

    # Sorted positions covered by this tile and the original indices they map to.
    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask_sid = offs_sid < exp_end
    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    if reindex_a:
        offs_am = sid // top_k
    else:
        offs_am = offs_sid
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    as_ptrs = A_scale + offs_am  # one A scale per A row
    # Wrap columns into [0, N) so loads stay in bounds. A wrapped column is stored again
    # with the value computed for that same column by another tile.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

    # deepseek has 160 experts: stride_be * exp_id can overflow int32, so promote to int64
    exp_id = exp_id.to(tl.int64)
    exp_off = stride_be * exp_id
    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    bs_ptrs = B_scale + exp_id * stride_bse + offs_bn  # per-expert, per-column B scale

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # K tail (and rows outside this expert for A) are zero-filled.
        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, acc=accumulator, out_dtype=ACCUMULATOR_DTYPE)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Dequantize: convert the accumulator to the scale dtype, apply row and column scales.
    ascale = tl.load(as_ptrs, mask=mask_sid)
    bscale = tl.load(bs_ptrs)
    c = accumulator.to(ascale.dtype)
    c = c * ascale[:, None] * bscale[None, :]

    c = c.to(C.dtype.element_ty)

    if reindex_c:
        offs_cm = sid
    else:
        offs_cm = offs_sid
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
    tl.store(c_ptrs, c, mask=mask_sid[:, None])


def fused_moe_w8a8_kernel_launcher(
    A: torch.Tensor,
    A_scale: torch.Tensor,
    B: torch.Tensor,
    B_scale: torch.Tensor,
    C: torch.Tensor,
    sorted_idx: torch.Tensor,
    exp_start: torch.Tensor,
    exp_end: torch.Tensor,
    top_k: int = 1,
    num_tokens: int = None,
    expert_offset: int = 0,
    reindex_a: bool = True,
    reindex_c: bool = True,
):
    """Launch ``fused_moe_w8a8_kernel`` once per expert held in ``B``.

    Args:
        A: activations; leading dims are flattened into rows.
        A_scale: contiguous per-row scales for ``A``.
        B: expert weights of shape ``(E, N, K)``.
        B_scale: contiguous per-expert, per-output-column scales.
        C: output buffer; leading dims are flattened into rows.
        sorted_idx, exp_start, exp_end: expert-sorted routing tables. Expert ``e`` owns
            sorted positions ``exp_start[e + expert_offset] .. exp_end[e + expert_offset]``.
        top_k: divisor mapping a sorted index to its A row when ``reindex_a`` is set.
        num_tokens: row count used to size the M grid (rounded up to a power of two,
            at least 64); defaults to ``A.size(0)``.
        expert_offset: index of ``B[0]`` within the routing tables.
        reindex_a: read A row ``sorted_idx // top_k`` instead of the sorted position.
        reindex_c: write C row ``sorted_idx`` instead of the sorted position.
    """
    if num_tokens is None:
        num_tokens = A.size(0)
    M_NP2 = max(64, triton.next_power_of_2(num_tokens))
    E, N, K = B.shape

    assert A_scale.is_contiguous()
    assert B_scale.is_contiguous()
    accumulator_dtype = tl.float32 if A.is_floating_point() else tl.int32

    A = A.flatten(0, -2)
    C = C.flatten(0, -2)

    def grid(meta):
        tiles_m = triton.cdiv(M_NP2, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n, E)

    fused_moe_w8a8_kernel[grid](
        A,
        A_scale,
        B,
        B_scale,
        C,
        sorted_idx,
        exp_start,
        exp_end,
        N=N,
        K=K,
        stride_am=A.stride(0),
        stride_ak=A.stride(1),
        stride_be=B.stride(0),
        stride_bn=B.stride(1),
        stride_bk=B.stride(2),
        stride_bse=B_scale.stride(0),
        stride_cm=C.stride(0),
        stride_cn=C.stride(1),
        top_k=top_k,
        expert_offset=expert_offset,
        reindex_a=reindex_a,
        reindex_c=reindex_c,
        M_NP2=M_NP2,
        ACCUMULATOR_DTYPE=accumulator_dtype,
    )


def fused_moe_w8a8(input: torch.Tensor,
                   input_scale: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
                   w2: torch.Tensor,
                   w2_scale: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
                   topk: int,
                   out_dtype: torch.dtype = torch.float16,
                   quant_dtype: torch.dtype = torch.int8,
                   expert_offset: int = 0,
                   num_experts: int = None,
                   renormalize: bool = False) -> torch.Tensor:
    """Fused moe with quantized activations and weights.

    The steps are:

    1. Sort the routed (token, expert) pairs by expert.
    2. Gate/up GEMM with ``w1``, then ``silu_and_mul``.
    3. Per-token re-quantization to ``quant_dtype``.
    4. Down GEMM with ``w2``.
    5. ``moe_reduce`` with the (optionally renormalized) ``topk_weights``.

    When ``num_experts`` differs from ``w1.shape[0]``, the intermediate buffers are
    requested zero-initialized (``zeros=True``). NOTE(review): presumably this makes rows
    routed to experts outside ``w1``/``w2`` contribute zero; confirm against
    ``_make_intermediate``.
    """
    device = input.device
    M = input.size(0)
    E, N, _ = w1.shape
    if num_experts is None:
        num_experts = E
    zero_init = num_experts != E

    topk_weights = _renormalize(topk_weights, renormalize)
    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
    routing = dict(sorted_idx=sorted_idx,
                   exp_start=exp_start,
                   exp_end=exp_end,
                   num_tokens=M,
                   expert_offset=expert_offset)

    # gate and up: A rows gathered per token (sorted_idx // topk), C written in sorted order
    gate_up = _make_intermediate((M, topk, N), dtype=out_dtype, device=device, zeros=zero_init)
    fused_moe_w8a8_kernel_launcher(input,
                                   input_scale,
                                   w1,
                                   w1_scale,
                                   gate_up,
                                   top_k=topk,
                                   reindex_a=True,
                                   reindex_c=False,
                                   **routing)

    # activation, then per-token re-quantization for the second GEMM
    lead_shape = gate_up.shape[:-1]
    gate_up_flat = gate_up.flatten(0, -2)
    act = silu_and_mul(gate_up_flat)
    del gate_up, gate_up_flat
    act = act.unflatten(0, lead_shape)
    act, act_scale = per_token_quant_int8(act, 1e-7, quant_dtype=quant_dtype)

    # down: A rows read in sorted order, C scattered back to the original sorted_idx rows
    down = _make_intermediate((M, topk, w2.shape[1]), dtype=out_dtype, device=device, zeros=zero_init)
    fused_moe_w8a8_kernel_launcher(act,
                                   act_scale,
                                   w2,
                                   w2_scale,
                                   down,
                                   top_k=1,
                                   reindex_a=False,
                                   reindex_c=True,
                                   **routing)

    return moe_reduce(down, topk_weights)


================================================
FILE: lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from packaging import version

from ..default.w8a8_kernels import per_channel_quant

# Installed Triton version; selects where ``round`` is taken from below.
TRITON_VERSION = version.parse(triton.__version__)
# Triton >= 3.0.0: use ``tl.extra.cuda.libdevice.round``; older versions: ``tl.math.round``.
# ``tl_round`` hides the difference from the kernels in this module.
if TRITON_VERSION >= version.parse('3.0.0'):
    tl_round = tl.extra.cuda.libdevice.round
else:
    tl_round = tl.math.round


@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_M': 128,
            'BLOCK_N': 256,
            'BLOCK_K': 128,
        }, num_stages=3, num_warps=8),
        triton.Config({
            'BLOCK_M': 256,
            'BLOCK_N': 128,
            'BLOCK_K': 128,
        }, num_stages=3, num_warps=8)
    ],
    key=['N', 'K'],
)
@triton.jit(do_not_specialize=['M'])
def _linear(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    rms_scale_ptr,
    linear_scale_ptr,
    ACCUMULATOR_DTYPE: tl.constexpr,
):
    """Quantized GEMM: ``C = (A @ B) * rms_scale[:, None] * linear_scale[None, :]``.

    ``A`` is ``(M, K)`` and ``B`` is indexed as ``(K, N)`` through ``stride_bk``/``stride_bn``.
    Products are accumulated in ``ACCUMULATOR_DTYPE``, converted to float32, and scaled by
    the per-row ``rms_scale`` and per-column ``linear_scale``. ``tl.store`` casts the result
    to C's element type. Tiles are visited in groups of ``GROUP_SIZE_M`` tile-rows.

    Autotuned over the block sizes and compiled Just-in-Time by Triton.
    """

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Rows/cols wrap modulo M/N so loads stay in bounds; the final store is masked instead.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill the K tail explicitly: masked lanes without ``other`` are unspecified
        # and would otherwise enter the dot product when K % BLOCK_K != 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(tl.float32)

    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
    c = c * rms_scale * linear_scale

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_M': 128,
            'BLOCK_N': 256,
            'BLOCK_K': 128,
        }, num_stages=3, num_warps=8),
        triton.Config({
            'BLOCK_M': 256,
            'BLOCK_N': 128,
            'BLOCK_K': 128,
        }, num_stages=3, num_warps=8)
    ],
    key=['N', 'K'],
)
@triton.jit(do_not_specialize=['M'])
def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                rms_scale_ptr, linear_scale_ptr, ACCUMULATOR_DTYPE: tl.constexpr):
    """Quantized GEMM with a fused residual add.

    Computes ``(A @ B) * rms_scale[:, None] * linear_scale[None, :]`` in float32, casts it to
    the residual's element type, adds ``residual`` and stores to ``C``. ``residual`` is
    addressed with C's strides (``stride_cm``/``stride_cn``), so it must share C's layout.

    Autotuned over the block sizes and compiled Just-in-Time by Triton.
    """

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Rows/cols wrap modulo M/N so loads stay in bounds; the final store is masked instead.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill the K tail explicitly: masked lanes without ``other`` are unspecified
        # and would otherwise enter the dot product when K % BLOCK_K != 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(tl.float32)

    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
    c = c * rms_scale * linear_scale
    c = c.to(residual_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    residual_ptrs = (residual_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    residual = tl.load(residual_ptrs, mask=c_mask, other=0.)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c + residual, mask=c_mask)


def matmul_kernel_dynamic_quant(a, b, rms_scale, linear_scale, residual=None, bias=None, output_dtype=torch.float16):
    """Quantized matmul ``a @ b.T`` with per-row / per-column dequantization.

    Computes ``(a @ b.T) * rms_scale * linear_scale``. ``rms_scale`` holds one value per
    row of ``a`` and ``linear_scale`` one per row of ``b``. It then optionally adds a
    contiguous ``residual`` of the output shape (fused into the kernel) and ``bias``.

    Args:
        a: quantized activations ``(..., K)``; leading dims are treated as rows.
        b: contiguous quantized weight ``(N, K)``.
        rms_scale: per-row scales of ``a``.
        linear_scale: per-output-column scales.
        residual: optional contiguous tensor of shape ``a.shape[:-1] + (N,)``.
        bias: optional tensor added to the result.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``a.shape[:-1] + (N,)`` and dtype ``output_dtype``.
    """

    assert a.shape[-1] == b.shape[-1]
    assert b.ndim == 2 and b.is_contiguous()
    M = a.numel() // a.shape[-1]
    N, K = b.shape
    c_shape = a.shape[:-1] + (N, )
    if residual is not None:
        assert residual.shape == c_shape
        assert residual.is_contiguous()
    # The kernels address row ``m`` of ``a`` as ``m * stride_am``. That is only valid when
    # all leading dims collapse into one row axis. ``reshape`` returns a view in that case
    # (with the same strides as before; a no-op for 2d input) and copies otherwise.
    a_2d = a.reshape(M, K)
    c = a.new_empty(c_shape, dtype=output_dtype)
    # ``c`` is freshly allocated and contiguous, so this is always a view.
    c_2d = c.view(M, N)
    accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32

    def grid(META):
        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )

    if residual is not None:
        _linear_add[grid](a_2d,
                          b,
                          c_2d,
                          residual,
                          M,
                          N,
                          K,
                          a_2d.stride(0),
                          a_2d.stride(1),
                          b.stride(1),
                          b.stride(0),
                          c_2d.stride(0),
                          c_2d.stride(1),
                          GROUP_SIZE_M=8,
                          rms_scale_ptr=rms_scale,
                          linear_scale_ptr=linear_scale,
                          ACCUMULATOR_DTYPE=accumulator_dtype)
    else:
        _linear[grid](a_2d,
                      b,
                      c_2d,
                      M,
                      N,
                      K,
                      a_2d.stride(0),
                      a_2d.stride(1),
                      b.stride(1),
                      b.stride(0),
                      c_2d.stride(0),
                      c_2d.stride(1),
                      GROUP_SIZE_M=8,
                      rms_scale_ptr=rms_scale,
                      linear_scale_ptr=linear_scale,
                      ACCUMULATOR_DTYPE=accumulator_dtype)
    if bias is not None:
        c += bias

    return c


@triton.jit
def _per_token_quant_int8(
        y_ptr,
        y_q_ptr,
        y_s_ptr,
        y_stride: tl.constexpr,
        yq_stride: tl.constexpr,
        N,  # number of columns in X
        eps: tl.constexpr,  # epsilon to avoid division by zero
        BLOCK: tl.constexpr,
        Q_MAX: tl.constexpr,
        IS_FLOATING_POINT: tl.constexpr,  # True for floating point dtype
):
    """Quantize one row per program with a symmetric per-row scale.

    For row ``r = program_id(0)`` (``N`` valid columns, ``BLOCK >= N``), the scale is
    ``max(absmax(row), eps) / Q_MAX``. ``row / scale`` is written to row ``r`` of ``y_q``,
    rounded and cast to int8 unless ``IS_FLOATING_POINT``, and the scale to ``y_s[r]``.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    in_row = cols < N

    src = y_ptr + row * y_stride
    vals = tl.load(src + cols, mask=in_row, other=0.).to(tl.float32)

    scale = tl.maximum(tl.max(tl.abs(vals)), eps) / Q_MAX
    quantized = vals / scale
    if not IS_FLOATING_POINT:
        quantized = tl_round(quantized).to(tl.int8)

    dst = y_q_ptr + row * yq_stride
    tl.store(dst + cols, quantized, mask=in_row)
    tl.store(y_s_ptr + row, scale)


def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
    """Symmetric per-token quantization of ``x`` (one scale per row over the last dim).

    For each row, ``scale = max(absmax(row), eps) / qmax``, where ``qmax`` is the largest
    value of ``quant_dtype``. The row is divided by ``scale`` and, for integer
    ``quant_dtype``, rounded. Despite the name, floating-point ``quant_dtype`` (e.g. fp8)
    is supported as well.

    Args:
        x: input with unit stride in the last dim (at least 2d).
        eps: lower bound on the absmax, avoiding division by zero.
        quant_dtype: dtype of the quantized output.

    Returns:
        ``(x_q, x_s)``: ``x_q`` has ``x.shape`` and dtype ``quant_dtype``; ``x_s`` has shape
        ``x.shape[:-1] + (1,)`` and dtype float32.
    """
    qdtype_info = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
    q_max = qdtype_info.max
    # Allocate a row-major output. The kernel writes flattened row ``m`` at
    # ``m * x_q.stride(-2)``, which ``empty_like`` would break for inputs with a permuted
    # (but dense) layout, since it preserves that layout.
    x_q = torch.empty(x.shape, device=x.device, dtype=quant_dtype)
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_s = torch.empty(x.shape[:-1] + (1, ), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)

    if x.dim() > 2:
        x = x.flatten(0, -2)
    assert x.stride(-1) == 1
    # enqueue kernel
    _per_token_quant_int8[(M, )](x,
                                 x_q,
                                 x_s,
                                 y_stride=x.stride(-2),
                                 yq_stride=x_q.stride(-2),
                                 N=N,
                                 eps=eps,
                                 BLOCK=BLOCK,
                                 Q_MAX=q_max,
                                 IS_FLOATING_POINT=quant_dtype.is_floating_point,
                                 num_warps=num_warps)

    return x_q, x_s


@triton.jit
def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
    """Normalize ``x`` by its root mean square over axis 0 and scale by ``w``.

    Statistics and the weight multiply are computed in float32. Only the final
    result is cast back to ``x.dtype``.
    """
    x_fp32 = x.to(tl.float32)
    mean_sq = tl.sum(x_fp32 * x_fp32, 0) * float(1.0 / N_COLS)
    scaled = w * (x_fp32 * tl.math.rsqrt(mean_sq + eps))
    return scaled.to(x.dtype)


@triton.jit
def rms_norm_quant_kernel(
    input,
    weight,
    output,
    out_scale,
    input_row_stride: tl.constexpr,
    eps: tl.constexpr,
    N_COLS: tl.constexpr,
    BLOCK_N: tl.constexpr,
    Q_MIN: tl.constexpr,
    Q_MAX: tl.constexpr,
    IS_FLOATING_POINT: tl.constexpr,
):
    """Rms norm followed by per-row dynamic quantization.

    Each program handles one row:

    * normalizes ``input`` row ``prog_id`` with ``weight``;
    * stores the row scale ``absmax / Q_MAX`` to ``out_scale[prog_id]``;
    * writes ``out / scale`` to ``output``, rounded for integer outputs and clamped
      to ``[Q_MIN, Q_MAX]``.

    Rows of both ``input`` and ``output`` are addressed with ``input_row_stride``.
    """
    prog_id = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_N)
    mask = offsets < N_COLS

    # Padding lanes (>= N_COLS) are loaded as 0, so they contribute nothing to the sum of
    # squares or to the absmax below; without ``other`` their contents are unspecified.
    w = tl.load(weight + offsets, mask=mask, other=0.)

    x_ptr = input + prog_id * input_row_stride
    x = tl.load(x_ptr + offsets, mask=mask, other=0.)
    out = _compute_rms_norm(x, w, eps, N_COLS)

    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
    out_s_ptr = out_scale + prog_id
    tl.store(out_s_ptr, scale)
    out = out / scale
    if not IS_FLOATING_POINT:
        out = tl_round(out)
    out = tl.clamp(out, Q_MIN, Q_MAX)
    out_ptr = output + prog_id * input_row_stride
    tl.store(out_ptr + offsets, out, mask=mask)


@triton.jit
def add_rms_norm_quant_kernel(
    input,
    weight,
    residual,
    output,
    out_scale,
    out_residual,
    input_row_stride: tl.constexpr,
    residual_row_stride: tl.constexpr,
    eps: tl.constexpr,
    N_COLS: tl.constexpr,
    BLOCK_N: tl.constexpr,
    Q_MIN: tl.constexpr,
    Q_MAX: tl.constexpr,
    IS_FLOATING_POINT: tl.constexpr,
):
    """Residual add, rms norm and per-row dynamic quantization.

    Each program handles one row:

    * stores ``new_x = input + residual`` to ``out_residual``;
    * normalizes ``new_x`` with ``weight``;
    * stores the row scale ``absmax / Q_MAX`` to ``out_scale[prog_id]``;
    * writes ``out / scale`` to ``output``, rounded for integer outputs and clamped
      to ``[Q_MIN, Q_MAX]``.

    ``input``/``output`` rows use ``input_row_stride``; ``residual``/``out_residual`` rows
    use ``residual_row_stride``.
    """
    prog_id = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_N)
    mask = offsets < N_COLS

    # Padding lanes (>= N_COLS) are loaded as 0, so they contribute nothing to the sum of
    # squares or to the absmax below; without ``other`` their contents are unspecified.
    w = tl.load(weight + offsets, mask=mask, other=0.)

    x_ptr = input + prog_id * input_row_stride
    x = tl.load(x_ptr + offsets, mask=mask, other=0.)

    res_ptr = residual + prog_id * residual_row_stride
    res = tl.load(res_ptr + offsets, mask=mask, other=0.)

    new_x = x + res
    out_res_ptr = out_residual + prog_id * residual_row_stride
    tl.store(out_res_ptr + offsets, new_x, mask=mask)

    out = _compute_rms_norm(new_x, w, eps, N_COLS)

    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
    out_s_ptr = out_scale + prog_id
    tl.store(out_s_ptr, scale)
    out = out / scale
    if not IS_FLOATING_POINT:
        out = tl_round(out)
    out = tl.clamp(out, Q_MIN, Q_MAX)
    out_ptr = output + prog_id * input_row_stride
    tl.store(out_ptr + offsets, out, mask=mask)


def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):
    """Rms norm with dynamic per-row quantization, optionally fused with a residual add.

    Every row (last dim) of ``x``, or of ``x + residual`` when ``residual`` is given, is
    normalized with weight ``w``. It is then quantized to ``quant_dtype`` with its own
    scale ``absmax / qmax``.

    Returns:
        ``(y, scale)`` without ``residual``; ``(y, scale, out_residual)`` with it, where
        ``out_residual = x + residual``. ``y`` has ``x.shape``; ``scale`` has shape
        ``x.shape[:-1] + (1,)`` and dtype float32.
    """
    qdtype_info = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
    # The kernels use flattened row ids, unit column stride, and a single row stride
    # shared by the input and the output written for it (``x``/``y`` and
    # ``residual``/``out_residual``). That only holds for contiguous inputs;
    # ``contiguous()`` returns the tensor itself when it already is.
    x = x.contiguous()
    y = torch.empty_like(x, dtype=quant_dtype)
    scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)

    feat_size = w.shape[0]
    seq_len = x.numel() // x.size(-1)
    input_stride = x.stride(-2)
    BLOCK_N = triton.next_power_of_2(feat_size)
    grid = (seq_len, )

    if residual is None:
        rms_norm_quant_kernel[grid](x,
                                    w,
                                    y,
                                    scale,
                                    input_row_stride=input_stride,
                                    eps=eps,
                                    N_COLS=feat_size,
                                    BLOCK_N=BLOCK_N,
                                    Q_MIN=qdtype_info.min,
                                    Q_MAX=qdtype_info.max,
                                    IS_FLOATING_POINT=quant_dtype.is_floating_point,
                                    num_warps=4,
                                    num_stages=2)
        return y, scale
    else:
        residual = residual.contiguous()
        out_residual = torch.empty_like(x)
        res_stride = residual.stride(-2)
        add_rms_norm_quant_kernel[grid](x,
                                        w,
                                        residual,
                                        y,
                                        scale,
                                        out_residual,
                                        input_row_stride=input_stride,
                                        residual_row_stride=res_stride,
                                        eps=eps,
                                        N_COLS=feat_size,
                                        BLOCK_N=BLOCK_N,
                                        Q_MIN=qdtype_info.min,
                                        Q_MAX=qdtype_info.max,
                                        IS_FLOATING_POINT=quant_dtype.is_floating_point,
                                        num_warps=4,
                                        num_stages=2)
        return y, scale, out_residual


def test_rms_and_linear(x, rms_weight, linear_weight, output_dtype=torch.float16, quant_dtype=torch.int8, eps=1e-5):
    """Compare quantized rms norm + quantized linear against a PyTorch float reference.

    Prints the mean magnitude of both outputs, their mean absolute difference and
    their cosine similarity.
    """

    def rms_norm_torch(x, w, eps):
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        return w * (x * torch.rsqrt(variance + eps))

    weight_q, weight_scale = per_channel_quant(linear_weight, quant_dtype)

    rms_q, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps, quant_dtype=quant_dtype)
    assert rms_q.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]
    linear_out = matmul_kernel_dynamic_quant(rms_q, weight_q, rms_scale, weight_scale, output_dtype=output_dtype)

    reference = F.linear(rms_norm_torch(x, rms_weight, eps).half(), linear_weight)
    print(f'linear_out.abs().mean() = {linear_out.abs().mean()}')
    print(f'linear_out_torch.abs().mean() = {reference.abs().mean()}')
    print('perchannel error: ', (linear_out - reference).abs().mean())
    cos = torch.nn.CosineSimilarity(0)
    print('Output cos', cos(linear_out.flatten().to(torch.float32), reference.flatten().to(torch.float32)))


def test_per_token_quant(x, eps, quant_dtype=torch.int8):
    """Compare ``per_token_quant_int8`` with a torch reference.

    Prints the cosine similarity of the quantized values and of the per-row
    scales.
    """

    def _reference_quant(inp, epsilon, qdtype):
        info = torch.finfo(qdtype) if qdtype.is_floating_point else torch.iinfo(qdtype)
        # Row-wise absmax, clamped to avoid dividing by zero.
        row_absmax = inp.abs().max(dim=-1, keepdim=True)[0].clamp(min=epsilon)
        scale = row_absmax / info.max
        quantized = inp / scale
        if not qdtype.is_floating_point:
            quantized = quantized.round()
        return quantized.clamp(min=info.min, max=info.max), scale

    q, s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype)
    q_ref, s_ref = _reference_quant(x, eps, quant_dtype)
    assert q.shape == q_ref.shape and s.shape == s_ref.shape
    cosine = torch.nn.CosineSimilarity(0)
    print('x_q cos', cosine(q.flatten().to(torch.float32), q_ref.flatten().to(torch.float32)))
    print('x_s cos', cosine(s.flatten().to(torch.float32), s_ref.flatten().to(torch.float32)))


def bench_rms_and_linear(M: int, provider: str, dtype: torch.dtype = torch.float16, eps: float = 1e-5):
    """Benchmark an rms norm followed by an (N=4096, K=4096) linear layer.

    Only the matmul is timed. The rms norm (torch reference or fused quantized
    kernel) is computed once before timing starts.

    Args:
        M: number of input rows.
        provider: one of ``'torch_fp16'``, ``'triton_int8'``,
            ``'triton_fp8_e4m3'`` or ``'triton_fp8_e5m2'``.
        dtype: dtype of activations and weights, and output dtype of the matmul.
        eps: rms norm epsilon.

    Returns:
        tuple: TFLOPS computed from the median, the 0.8-quantile and the
        0.2-quantile latency (i.e. central, lower and upper throughput).

    Raises:
        ValueError: if ``provider`` is not recognised.
    """
    quant_providers = ('triton_int8', 'triton_fp8_e4m3', 'triton_fp8_e5m2')
    if provider != 'torch_fp16' and provider not in quant_providers:
        # Fail fast and clearly instead of an UnboundLocalError on `quant_dtype`.
        raise ValueError(f'Unknown provider: {provider!r}')

    def rms_norm_torch(x, w, eps):
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return w * x

    def linear_torch(x, b):
        return F.linear(x, b)

    N = 4096
    K = 4096

    # Inference-only benchmark: no `requires_grad`, so the torch baseline does
    # not pay for autograd graph construction that the triton kernels never do.
    # Tensor creation order is kept so RNG draws match previous runs.
    rms_weight = torch.randn((K, ), dtype=dtype, device='cuda')
    x = torch.randn((M, K), dtype=dtype, device='cuda')
    linear_weight = torch.randn((N, K), dtype=dtype, device='cuda')

    if provider == 'torch_fp16':
        rms_out_torch = rms_norm_torch(x, rms_weight, eps).to(dtype)

        def y_fwd():
            linear_torch(rms_out_torch, linear_weight)
    else:
        if provider == 'triton_int8':
            quant_dtype = torch.int8
        elif provider == 'triton_fp8_e4m3':
            quant_dtype = torch.float8_e4m3fn
        else:
            quant_dtype = torch.float8_e5m2

        linear_weight_quant, linear_scale = per_channel_quant(linear_weight, quant_dtype)
        # The fused kernel returns both the quantized activations and their scales.
        rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps, quant_dtype=quant_dtype)

        def y_fwd():
            matmul_kernel_dynamic_quant(rms_out, linear_weight_quant, rms_scale, linear_scale, output_dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)

    def perf(ms):
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    # Longest latency -> lowest throughput, hence the swapped order.
    return perf(ms), perf(max_ms), perf(min_ms)


if __name__ == '__main__':
    torch.manual_seed(0)
    major_capability, _ = torch.cuda.get_device_capability()
    # fp8 paths are only exercised on compute capability 9.x and newer.
    is_fp8_supported = major_capability >= 9
    dtype = torch.float16

    quant_dtypes = [torch.int8]
    if is_fp8_supported:
        quant_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])

    # (bs, seq_len, dim) x (dim, out_dim)
    x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')
    rms_weight = torch.randn((4096, ), dtype=dtype, device='cuda', requires_grad=True)
    linear_weight = torch.randn((11008, 4096), dtype=dtype, device='cuda', requires_grad=True)
    for qdtype in quant_dtypes:
        test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=qdtype)

    # (M, K) x (K, N)
    x = torch.randn((4, 4096), dtype=dtype, device='cuda')
    rms_weight = torch.randn((4096, ), dtype=dtype, device='cuda', requires_grad=True)
    linear_weight = torch.randn((2048, 4096), dtype=dtype, device='cuda', requires_grad=True)
    for qdtype in quant_dtypes:
        test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=qdtype)

    # per-token quantization
    x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')
    eps = 1e-7
    for qdtype in quant_dtypes:
        test_per_token_quant(x, eps, quant_dtype=qdtype)

    # benchmark triton kernels
    providers = ['triton_int8', 'torch_fp16']
    if is_fp8_supported:
        providers.extend(['triton_fp8_e4m3', 'triton_fp8_e5m2'])
    config = triton.testing.Benchmark(
        x_names=['M'],
        x_vals=[1, 16, 32, 64, 128, 256] + [1024 * i for i in range(1, 5)],
        line_arg='provider',
        line_vals=providers,
        line_names=list(providers),
        styles=[('blue', '-'), ('green', '-'), ('orange', '-'), ('black', '-'), ('yellow', '-')],
        ylabel='TFLOPS',
        plot_name='bench-triton',
        args={'dtype': torch.float16},
    )
    bench_fn = triton.testing.perf_report(config)(bench_rms_and_linear)
    bench_fn.run(print_data=True)


================================================
FILE: lmdeploy/pytorch/kernels/default/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .multinomial_sampling import multinomial_sampling
from .w8a8_kernels import per_channel_quant

# Device-agnostic kernels. FunctionDispatcher imports this package as a
# fallback when the device-specific kernel package lacks a function.
__all__ = [
    'multinomial_sampling',
    'per_channel_quant',
]


================================================
FILE: lmdeploy/pytorch/kernels/default/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import LongTensor, Tensor


def multinomial_sampling(scores: Tensor, seeds: LongTensor, offsets: LongTensor, indices: Tensor = None):
    """Sample one entry per row of ``scores`` (reference implementation).

    Args:
        scores: non-negative sampling weights, assumed 2-D (rows, candidates).
        seeds: per-row random seeds. Ignored here; ``torch.multinomial`` uses
            torch's global RNG, so results are not reproducible per seed.
        offsets: per-row RNG offsets. Ignored for the same reason.
        indices: optional tensor, gathered along dim 1 with the sampled column
            index to produce the output. When ``None`` the sampled column
            index itself is returned.

    Returns:
        Tensor: 1-D tensor with one sampled value per row.
    """
    sampled_index = torch.multinomial(scores, num_samples=1, replacement=True)
    if indices is None:
        # The signature advertises `indices=None`; treat it as the identity
        # mapping instead of crashing inside torch.gather.
        return sampled_index.view(-1)
    outputs = torch.gather(indices, dim=1, index=sampled_index)
    return outputs.view(-1)


================================================
FILE: lmdeploy/pytorch/kernels/default/w8a8_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
    """Symmetrically quantize a 2-D tensor row by row (one scale per row).

    Each row is divided by ``absmax(row) / dtype_max``. Integer targets are
    rounded before being clamped to the dtype's representable range.

    Args:
        x (torch.Tensor): The 2-dimensional tensor to quantize. Computation is
            done in float32.
        dtype (torch.dtype): Target dtype, e.g. ``torch.int8`` or an fp8 dtype.

    Returns:
        tuple: ``(x_q, scale)``. ``x_q`` has dtype ``dtype`` and the shape of
        ``x``. ``scale`` is a float32 tensor of shape ``(x.shape[0], 1)`` such
        that ``x ≈ x_q * scale``.
    """
    assert x.ndim == 2
    x = x.to(torch.float32)
    x_absmax = x.abs().amax(dim=1, keepdim=True)
    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    q_max = qtype_info.max
    q_min = qtype_info.min
    # An all-zero row would give scale == 0 and 0 / 0 == NaN below. That NaN
    # propagates into fp8 weights (and is undefined when cast to int). Clamping
    # to the smallest normal fp32 value only affects such degenerate rows.
    scale = (x_absmax / q_max).clamp(min=torch.finfo(torch.float32).tiny)
    x_q = x / scale
    if not dtype.is_floating_point:
        x_q = torch.round(x_q)
    x_q = x_q.clamp(q_min, q_max).to(dtype)
    return x_q, scale


================================================
FILE: lmdeploy/pytorch/kernels/dispatcher.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import inspect
from typing import Callable

from lmdeploy.utils import get_logger

from ..devices import DeviceContext, get_device_manager

logger = get_logger('lmdeploy')


def _default_api(*args, **kwargs):
    """Default api."""
    ...


class ParamParser:
    """Render one ``inspect.Parameter`` as source text for a generated caller."""

    def __init__(self, param: inspect.Parameter) -> None:
        self.param = param

    def name(self):
        """Return the parameter name."""
        return self.param.name

    def func_arg(self):
        """Return the parameter as written in a ``def`` signature.

        Variadic parameters get their ``*``/``**`` prefix. A default value is
        appended as ``name=default``.
        """
        param = self.param
        name = self.name()
        kind = param.kind
        ret = name
        if kind == inspect.Parameter.VAR_POSITIONAL:
            ret = f'*{name}'
        elif kind == inspect.Parameter.VAR_KEYWORD:
            ret = f'**{name}'

        default = param.default
        # Identity check: `!=` would invoke the default's own __ne__ (an
        # array-like default would make this `if` ambiguous or raise).
        if default is not inspect.Parameter.empty:
            # Strings must be emitted quoted, otherwise the generated source
            # refers to a bare identifier. Other defaults keep their str() form
            # (e.g. `None`, numbers, `torch.float16`).
            default_src = repr(default) if isinstance(default, str) else default
            ret = f'{ret}={default_src}'

        return ret

    def func_input(self):
        """Return the parameter as passed in the forwarding call.

        Variadic parameters are unpacked with ``*``/``**``; all others are
        forwarded by keyword as ``name=name``.
        """
        param = self.param
        name = self.name()
        kind = param.kind
        ret = name
        if kind == inspect.Parameter.VAR_POSITIONAL:
            ret = f'*{name}'
        elif kind == inspect.Parameter.VAR_KEYWORD:
            ret = f'**{name}'
        else:
            ret = f'{name}={name}'
        return ret


class FunctionDispatcher:
    """Dispatch calls of a kernel name to the current device's implementation.

    The implementation is resolved lazily: on the first call, and again after
    every device-context change. It is looked up in
    ``lmdeploy.pytorch.kernels.<package>``, where ``<package>`` comes from
    ``device_map`` (an unmapped device type is used as the package name
    directly). If that fails, ``lmdeploy.pytorch.kernels.default`` is tried.
    """

    def __init__(self, func_name: str):
        self.device_manager = get_device_manager()
        # device type -> resolved implementation
        self.impl_map: dict[str, Callable] = dict()
        self.func_name = func_name
        # Until a device has been resolved, calls go through `load_and_call`.
        self.dispatched_func = self.load_and_call
        self.device_manager.register_context_callback(self.device_callback)
        # device type -> kernel sub-package; several accelerators share dlinfer.
        self.device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

    def device_callback(self, context: DeviceContext):
        """Device context callback.

        Resets ``dispatched_func`` so the next call re-resolves the
        implementation for the (possibly different) device.
        """
        self.dispatched_func = self.load_and_call

    def load_func(self, device: str):
        """Resolve the implementation for ``device`` and cache it in ``impl_map``.

        Raises:
            RuntimeError: if neither the device package nor the default package
                provides ``func_name``.
        """
        # Map the device type to its kernel package (e.g. 'ascend' -> 'dlinfer').
        # Importing by the raw device type would miss packages such as dlinfer
        # and silently fall back to the default implementation.
        package = self.device_map.get(device, device)
        try:
            mod = importlib.import_module(f'lmdeploy.pytorch.kernels.{package}')
            func = getattr(mod, self.func_name)
            self.impl_map[device] = func
        except Exception:
            # Deliberate best-effort: any import/lookup failure falls back to
            # the default package; keep the traceback available at debug level.
            logger.debug(f'Failed to load <{self.func_name}>'
                         f' for <{device}>, '
                         'try load default implementation.',
                         exc_info=True)
            mod = importlib.import_module('lmdeploy.pytorch.kernels.default')
            if not hasattr(mod, self.func_name):
                raise RuntimeError(f'<{self.func_name}> default and <{device}>'
                                   ' implementation not exists.')
            func = getattr(mod, self.func_name)
            self.impl_map[device] = func

    def load_and_call(self, *args, **kwargs):
        """Resolve the implementation for the current device, cache it and call it."""
        device = self.device_manager.current_context().device_type
        if device not in self.impl_map:
            self.load_func(device)
        # Later calls bypass resolution until the device context changes.
        self.dispatched_func = self.impl_map[device]
        return self.dispatched_func(*args, **kwargs)

    def make_caller(self, api: Callable = _default_api, globals=None):
        """Generate a function named ``func_name`` with the signature of ``api``.

        The generated function forwards its arguments to
        ``dispatcher.dispatched_func``. Entries of ``globals`` are added to the
        ``exec`` scope (presumably for names used by default values).

        Returns:
            Callable: the generated function.
        """
        signature = inspect.signature(api)
        params = signature.parameters

        param_parsers = [ParamParser(p) for p in params.values()]
        func_args = [p.func_arg() for p in param_parsers]
        func_inputs = [p.func_input() for p in param_parsers]
        func_args = ', '.join(func_args)
        func_inputs = ', '.join(func_inputs)

        src = f"""
def {self.func_name}({func_args}):
    return dispatcher.dispatched_func({func_inputs})
"""   # noqa: E501

        scope = dict(dispatcher=self, )
        if globals is not None:
            scope.update(globals)
        exec(src, scope)
        return scope[f'{self.func_name}']


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..default import multinomial_sampling, per_channel_quant
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .flash_attention import flash_attention_fwd
from .fused_moe import DlinferMoECommType, DlinferMoeMetadata, fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm

# Kernels exposed for dlinfer-backed devices. `multinomial_sampling` and
# `per_channel_quant` are re-exported from the device-agnostic `default` package.
__all__ = [
    'rms_norm',
    'apply_rotary_pos_emb',
    'awq_linear',
    'fill_kv_cache',
    'DlinferMoECommType',
    'DlinferMoeMetadata',
    'fused_moe',
    'paged_attention_fwd',
    'flash_attention_fwd',
    'linear',
    'moe_gating_topk_softmax',
    'multinomial_sampling',
    'per_channel_quant',
]


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def silu_and_mul(input_tensor: Tensor) -> Tensor:
    """Run dlinfer's fused ``silu_and_mul`` op on ``input_tensor``."""
    result = ext_ops.silu_and_mul(input_tensor)
    return result


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import dlinfer.ops as ext_ops
from torch import Tensor


def apply_rotary_pos_emb(
    query_states: Tensor,
    key_states: Tensor,
    cos: Tensor,
    sin: Tensor,
    q_embed: Optional[Tensor],
    k_embed: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Apply rotary position embedding through dlinfer.

    Each dlinfer result is reshaped to its input's shape. Then, separately for
    q and k:

    - without an output buffer, the reshaped result is returned;
    - with a buffer distinct from the input, the result is copied into it;
    - with the input tensor itself as buffer, nothing is copied.

    NOTE(review): the last case presumably relies on the dlinfer op updating
    its inputs in place; confirm against the dlinfer implementation.
    """
    q_rot, k_rot = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin)

    def _emit(rotated, src, dst):
        if dst is None:
            return rotated.view(src.shape)
        if dst is not src:
            dst.copy_(rotated.view(src.shape))
        return dst

    return _emit(q_rot, query_states, q_embed), _emit(k_rot, key_states, k_embed)


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/awq_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
from torch import Tensor


def awq_linear(x: Tensor,
               qweight: Tensor,
               scales: Tensor,
               qzeros: Tensor,
               bias: Optional[Tensor] = None,
               all_reduce: bool = False,
               group_size: int = 0):
    """AWQ weight-quantized linear through dlinfer ``weight_quant_matmul``.

    Dim 0 of ``x`` is squeezed before the call and re-added to the result.
    ``qzeros`` is passed as the op's ``offset``.

    NOTE(review): this assumes dim 0 of ``x`` has size 1. Otherwise squeeze is
    a no-op and the final unsqueeze adds an extra dim. Confirm with callers.
    """
    x_squeezed = x.squeeze(0)
    out = ext_ops.weight_quant_matmul(
        x_squeezed,
        qweight,
        scales,
        offset=qzeros,
        bias=bias,
        all_reduce=all_reduce,
        group_size=group_size,
    )
    return out.unsqueeze(0)


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import dlinfer.ops as ext_ops
from torch import Tensor


def fill_kv_cache(
    key_states: Tensor,
    value_states: Tensor,
    key_caches: Tensor,
    value_caches: Tensor,
    kv_start_indices: Tensor,
    k_scales_zeros: Sequence[Optional[Tensor]],
    v_scales_zeros: Sequence[Optional[Tensor]],
    quant_bits: int = 0,
):
    """Write key/value states into the caches through dlinfer.

    States, caches and start indices are passed positionally. The KV
    quantization parameters are passed by keyword.
    """
    positional = (key_states, value_states, key_caches, value_caches, kv_start_indices)
    quant_kwargs = dict(k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, quant_bits=quant_bits)
    return ext_ops.fill_kv_cache(*positional, **quant_kwargs)


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def flash_attention_fwd(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_start_loc: Tensor,
    kv_seqlens: Tensor,
    num_heads: int,
    num_kv_heads: int,
    max_q_seqlen: int = None,
    window_size: int = None,
    sm_scale: float = None,
    logit_softcapping: float = None,
    causal: bool = True,
):
    """Cache-less attention through ``ext_ops.prefill_attention``.

    ``sm_scale`` is forwarded as ``softmax_scale`` and ``attn_output`` as the
    op's ``attn_output`` keyword. An empty ``attn_mask`` list is passed.

    Note:
        ``kv_start_loc``, ``window_size``, ``logit_softcapping`` and
        ``causal`` are accepted but not forwarded.
        NOTE(review): non-causal attention, sliding windows and logit
        soft-capping are therefore silently ignored; confirm this is intended
        for dlinfer devices.
    """
    return ext_ops.prefill_attention(
        query_states,
        key_states,
        value_states,
        # These two positional slots take key_cache/value_cache in
        # pagedattention.prefill_attention; no cache is used here.
        None,
        None,
        q_start_loc,
        q_seqlens,
        kv_seqlens,
        max_q_seqlen,
        num_heads,
        num_kv_heads,
        attn_mask=[],
        softmax_scale=sm_scale,
        attn_output=attn_output,
    )


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import MoECommType as DlinferMoECommType  # noqa: F401
from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
from torch import Tensor


def fused_moe(
    hidden_states: Tensor,
    gate_up_weights: Tensor,
    down_weights: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    topk: int,
    renormalize: bool,
    moe_metadata: DlinferMoeMetadata,
):
    """Dlinfer fused mixture-of-experts.

    Every argument is forwarded positionally, in signature order, to
    ``ext_ops.fused_moe``.
    """
    args = (hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize, moe_metadata)
    return ext_ops.fused_moe(*args)


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/fused_rotary_emb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def fused_rotary_emb(
    query_states: Tensor,
    key_states: Tensor,
    position_ids: torch.LongTensor,
    inv_freq: Tensor,
    scaling_factor: float,
    out_q: Tensor = None,
    out_k: Tensor = None,
    context=None,
):
    """Apply rotary embedding computed from ``position_ids`` and ``inv_freq``.

    The cos/sin tables are built from ``position_ids / scaling_factor *
    inv_freq``. When ``context`` is truthy they are cached on it as
    ``context.cos``/``context.sin``. Once a context carries them they are reused
    on later calls with that context, regardless of ``position_ids``.

    Returns:
        tuple: ``(out_q, out_k)``. Without output buffers these are
        ``query_states``/``key_states`` themselves; otherwise the inputs are
        copied into the given buffers.

    NOTE(review): the return value of ``ext_ops.apply_rotary_pos_emb`` is
    discarded, so this relies on the op rotating its inputs in place. It is
    also called with six arguments here but with four in
    ``apply_rotary_pos_emb.py``. Confirm against the installed dlinfer API.
    NOTE(review): if ``context`` is falsy yet already has ``cos``/``sin``
    attributes, ``cos``/``sin`` below are never assigned (UnboundLocalError).
    """
    # query_states must be 4-D; unpacked as (batch, seqlen, head, dim).
    batch, seqlen, head, dim = query_states.shape
    num_kv_heads = key_states.shape[-2]
    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
    # Drop a leading size-1 dim (if any) and add a trailing axis to broadcast
    # against inv_freq.
    position_ids = position_ids.squeeze(0).unsqueeze(-1)
    pos_freq = position_ids / scaling_factor * inv_freq
    # Compute tables only when the context has neither cached attribute.
    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
        # The last dim is duplicated (repeat x2); presumably the half-rotation
        # layout expected by the op. TODO confirm.
        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2).to(query_states.dtype))
        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2).to(query_states.dtype))
        if context:
            setattr(context, 'cos', cos)
            setattr(context, 'sin', sin)
    cached_cos = context.cos if context else cos
    cached_sin = context.sin if context else sin
    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, None, None)
    if out_q is None:
        out_q = query_states
    else:
        out_q.copy_(query_states)
    if out_k is None:
        out_k = key_states
    else:
        out_k.copy_(key_states)
    return out_q, out_k


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/linear.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
from torch import Tensor


def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, all_reduce: bool = False, group: str = ''):
    """Linear layer through ``ext_ops.linear``.

    ``bias``, ``all_reduce`` and ``group`` are forwarded unchanged as keywords.
    """
    options = {'bias': bias, 'all_reduce': all_reduce, 'group': group}
    return ext_ops.linear(x, weight, **options)


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
from torch import Tensor


def moe_gating_topk_softmax(router_logits: Tensor, topk: int,
                            moe_metadata: DlinferMoeMetadata) -> Tuple[Tensor, Tensor]:
    """Top-k softmax gating over ``router_logits`` through dlinfer.

    Returns:
        tuple: ``(routing_weights, selected_experts)`` unpacked from the op.
    """
    weights, experts = ext_ops.moe_gating_topk_softmax(router_logits, topk, moe_metadata)
    return weights, experts


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import dlinfer.ops as ext_ops
from torch import Tensor


def prefill_attention(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    block_size: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size_v: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    is_unpaged_prefill: Optional[bool],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    """Prefill attention through dlinfer.

    If ``is_unpaged_prefill`` is truthy, ``ext_ops.prefill_attention`` is used.
    That path does not receive ``block_offsets``, ``block_size``,
    ``cu_seq_lens_kv``, ``max_kv_seq_len``, ``head_size_v`` or the KV
    quantization arguments. Otherwise ``ext_ops.paged_prefill_attention`` is
    called with the full argument set.

    NOTE(review): arguments are passed positionally in dlinfer's order, which
    differs from this function's parameter order (e.g. ``block_size`` precedes
    ``q_start_loc`` in the paged call). Keep both in sync when editing.
    """
    if is_unpaged_prefill:
        return ext_ops.prefill_attention(
            query_states,
            key_states,
            value_states,
            key_cache,
            value_cache,
            q_start_loc,
            q_seq_len,
            kv_seq_len,
            max_q_seq_len,
            num_q_heads,
            num_kv_heads,
            attn_mask,
            softmax_scale=softmax_scale,
            attn_output=attn_output,
        )
    else:
        return ext_ops.paged_prefill_attention(
            query_states,
            key_states,
            value_states,
            key_cache,
            value_cache,
            block_offsets,
            block_size,
            q_start_loc,
            q_seq_len,
            kv_seq_len,
            cu_seq_lens_kv,
            max_q_seq_len,
            max_kv_seq_len,
            num_q_heads,
            num_kv_heads,
            attn_mask,
            head_size_v=head_size_v,
            softmax_scale=softmax_scale,
            attn_output=attn_output,
            kv_scales=kv_scales,
            kv_zeros=kv_zeros,
            quant_bits=quant_bits,
        )


def paged_token_attention(
    q,
    k_cache,
    v_cache,
    attn_output,
    kv_seq_len,
    max_kv_seq_len,
    block_offsets,
    block_size,
    num_q_heads,
    num_kv_heads,
    head_size_v,
    softmax_scale: Optional[float],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
):
    """Decode-step attention over the caches via ``ext_ops.paged_decode_attention``.

    NOTE(review): the positional call order differs from this function's
    parameter order. ``block_offsets`` and ``block_size`` are passed before
    ``kv_seq_len``/``max_kv_seq_len``, and ``attn_output`` goes by keyword.
    """
    return ext_ops.paged_decode_attention(
        q,
        k_cache,
        v_cache,
        block_offsets,
        block_size,
        kv_seq_len,
        max_kv_seq_len,
        num_q_heads,
        num_kv_heads,
        head_size_v=head_size_v,
        softmax_scale=softmax_scale,
        attn_output=attn_output,
        kv_scales=kv_scales,
        kv_zeros=kv_zeros,
        quant_bits=quant_bits,
    )


def paged_attention_fwd(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_seqlens: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    is_decoding: bool,
    block_size: int,
    num_heads: int,
    num_kv_heads: int,
    v_head_size: int,
    attn_mask: Sequence[Optional[Tensor]] = (),
    softmax_scale: Optional[float] = None,
    is_unpaged_prefill: Optional[bool] = None,
    kv_scales: Optional[Tensor] = None,
    kv_zeros: Optional[Tensor] = None,
    quant_bits: Optional[int] = 0,
):
    """Dispatch to prefill or decode attention on dlinfer devices.

    When ``is_decoding`` is false, ``prefill_attention`` is used. Since
    ``is_unpaged_prefill`` defaults to ``None`` (falsy), the paged prefill path
    is taken unless it is set.

    When decoding, ``paged_token_attention`` is used. ``key_states``,
    ``value_states``, ``q_start_loc``, ``q_seqlens``, ``cu_seq_lens_kv``,
    ``max_q_seq_len``, ``attn_mask`` and ``is_unpaged_prefill`` are not used on
    that path.
    """
    if not is_decoding:
        return prefill_attention(
            query_states,
            key_states,
            value_states,
            attn_output,
            key_cache,
            value_cache,
            block_offsets,
            q_start_loc,
            q_seqlens,
            kv_seqlens,
            cu_seq_lens_kv,
            max_q_seq_len,
            max_kv_seq_len,
            block_size,
            num_heads,
            num_kv_heads,
            v_head_size,
            attn_mask,
            softmax_scale,
            is_unpaged_prefill,
            kv_scales=kv_scales,
            kv_zeros=kv_zeros,
            quant_bits=quant_bits,
        )
    else:
        return paged_token_attention(
            query_states,
            key_cache,
            value_cache,
            attn_output,
            kv_seqlens,
            max_kv_seq_len,
            block_offsets,
            block_size,
            num_heads,
            num_kv_heads,
            v_head_size,
            softmax_scale=softmax_scale,
            kv_scales=kv_scales,
            kv_zeros=kv_zeros,
            quant_bits=quant_bits,
        )


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from torch import Tensor


def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None):
    """RMS norm through dlinfer, optionally fused with a residual add.

    Without ``residual`` the normalized tensor is returned; it is copied into
    ``out`` (and ``out`` is returned) when a buffer is given. With ``residual``
    the result of ``ext_ops.add_rms_norm`` is returned as-is and ``out`` is
    not used.
    """
    if residual is not None:
        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon)

    normed = ext_ops.rms_norm(hidden_states, weight, epsilon)
    if out is None:
        return normed
    out.copy_(normed)
    return out


================================================
FILE: lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def dynamic_quant(x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = 'PER_TOKEN'):
    """Dynamically quantize ``x`` to ``quant_dtype`` through dlinfer.

    Returns:
        tuple: ``(quantized, scale)`` unpacked from the op's result.
    """
    quantized, scale = ext_ops.dynamic_quant(x, quant_dtype, quant_granularity)
    return quantized, scale


def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype,
    bias=None,
):
    """Quantized matrix multiplication through dlinfer.

    ``a`` and ``b`` are multiplied using their scales ``rms_scale`` and
    ``linear_scale``. ``bias`` is added when given, and the result is produced
    in ``out_dtype``. All arguments are forwarded positionally to
    ``ext_ops.linear_w8a8``.
    """
    operands = (a, b, rms_scale, linear_scale)
    return ext_ops.linear_w8a8(*operands, out_dtype, quant_dtype, bias)


def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
    residual: Tensor = None,
):
    """RMS norm with output quantization to ``quant_dtype`` through dlinfer.

    With ``residual`` the fused ``ext_ops.add_rms_norm_w8a8`` op is used
    instead of ``ext_ops.rms_norm_w8a8``.
    """
    if residual is not None:
        return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, epsilon, quant_dtype)
    return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon, quant_dtype)


================================================
FILE: lmdeploy/pytorch/kernels/w8a8_triton_kernels.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .dispatcher import FunctionDispatcher

# Device-dispatched entry points. Each caller is generated with the default
# `(*args, **kwargs)` signature and forwards to the implementation that
# FunctionDispatcher resolves for the current device on first use.
per_channel_quant = FunctionDispatcher('per_channel_quant').make_caller()

matmul_kernel_dynamic_quant = FunctionDispatcher('matmul_kernel_dynamic_quant').make_caller()

per_token_quant_int8 = FunctionDispatcher('per_token_quant_int8').make_caller()

rms_norm_dynamic_quant = FunctionDispatcher('rms_norm_dynamic_quant').make_caller()


================================================
FILE: lmdeploy/pytorch/messages.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List

import numpy as np
import torch
from torch import Tensor

from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor
from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
from lmdeploy.utils import get_logger

from .block import LogicalTokenBlocks

if TYPE_CHECKING:
    from lmdeploy.pytorch.paging.scheduler import Scheduler
    from lmdeploy.pytorch.paging.seq_states.states import StateBase
    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
    from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy

logger = get_logger('lmdeploy')

# vlm input types from the pipeline: a list of embedding arrays and a list of
# integer ranges (presumably one [start, end] pair per embedding; confirm).
InputEmbeddingType = List[np.ndarray]
InputEmbeddingRangeType = List[List[int]]


@dataclass
class InputEmbeddings:
    """Precomputed input embeddings together with their ``start``/``end`` positions."""
    embeddings: np.ndarray
    start: int
    end: int

    def move_position(self, offset: int = 0):
        """Shift ``start`` and ``end`` by ``offset`` in place; return ``self``."""
        if offset:
            self.start, self.end = self.start + offset, self.end + offset
        return self


@dataclass
class SamplingParam:
    """Sampling parameter."""
    top_p: float = 1.0
    top_k: int = 1
    min_p: float = 0.0
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
    random_seed: int = None
    stop_words: List[int] = field(default_factory=list)
    bad_words: List[int] = field(default_factory=list)
    max_new_tokens: int = 512
    min_new_tokens: int = 0
    response_format: None | str = None
    logits_processors: None | List[LogitsProcessor] = None
    out_logits: bool = False
    out_last_hidden_states: bool = False
    num_logprobs: int = -1  # -1 when the config requests no logprobs
    return_routed_experts: bool = False

    # ngram
    repetition_ngram_size: int = 0
    repetition_ngram_threshold: int = 0

    @classmethod
    def from_gen_config(cls, gen_config: GenerationConfig):
        """Build sampling parameters from a ``GenerationConfig``.

        Out-of-range values are replaced with safe defaults and a warning is
        logged. With ``ignore_eos`` the stop token ids are moved to the
        bad-word list. A random 64-bit seed is drawn when the config has none.
        """
        min_new_tokens = gen_config.min_new_tokens or 0

        # Copy the id lists: `bad_words` is extended below and both lists are
        # stored on the result, so they must not alias (or mutate) the caller's
        # config. Previously a reused config accumulated duplicate bad words.
        stop_words = list(gen_config.stop_token_ids or [])
        bad_words = list(gen_config.bad_token_ids or [])
        if gen_config.ignore_eos:
            bad_words += stop_words
            stop_words = []

        top_k = gen_config.top_k
        top_p = gen_config.top_p
        min_p = gen_config.min_p
        temperature = gen_config.temperature
        repetition_penalty = gen_config.repetition_penalty
        max_new_tokens = gen_config.max_new_tokens
        response_format = gen_config.response_format

        output_logits = gen_config.output_logits
        if output_logits:
            if (output_logits != 'all' or gen_config.max_new_tokens > 0):
                output_logits = None
                logger.warning('Pytorch Engine only support output_logits="all"'
                               ' with max_new_tokens=0')
        if gen_config.output_last_hidden_state is not None:
            logger.warning('Pytorch Engine does not support output last hidden states.')
        if top_p < 0 or top_p > 1.0:
            logger.warning('`top_p` has to be a float > 0 and < 1'
                           f' but is {top_p}')
            top_p = 1.0
        if min_p < 0 or min_p > 1.0:
            logger.warning('`min_p` has to be a float > 0 and < 1'
                           f' but is {min_p}')
            min_p = 0.0
        if temperature == 0:
            # Zero temperature means greedy decoding.
            logger.warning('`temperature` is 0, set top_k=1.')
            temperature = 1.0
            top_k = 1
        if temperature < 0:
            logger.warning('`temperature` has to be a strictly'
                           f' positive value, but is {temperature}')
            temperature = 1.0
        if repetition_penalty <= 0:
            logger.warning('`repetition_penalty` has to be a strictly'
                           f' positive value, but is {repetition_penalty}')
            repetition_penalty = 1.0
        if max_new_tokens < 0:
            logger.warning('`max_new_tokens` has to be a strictly'
                           f' positive value, but is {max_new_tokens}')
            max_new_tokens = 512
        if min_new_tokens < 0 or min_new_tokens > max_new_tokens:
            logger.warning('`min_new_tokens` has to be '
                           'a int >=0 and <= `max_new_tokens`,'
                           f' but is {min_new_tokens}')
            min_new_tokens = 0
        logprobs = gen_config.logprobs
        if logprobs is None:
            logprobs = -1

        random_seed = gen_config.random_seed
        if random_seed is None:
            import random
            random_seed = random.getrandbits(64)
        # `cls` (not the hard-coded class name) so subclasses get instances of
        # themselves.
        return cls(
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            ignore_eos=gen_config.ignore_eos,
            random_seed=random_seed,
            stop_words=stop_words,
            bad_words=bad_words,
            response_format=response_format,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            logits_processors=gen_config.logits_processors,
            # Truthiness, so a falsy non-None value does not enable logits output.
            out_logits=bool(output_logits),
            num_logprobs=logprobs,
            return_routed_experts=gen_config.return_routed_experts,
            repetition_ngram_size=gen_config.repetition_ngram_size,
            repetition_ngram_threshold=gen_config.repetition_ngram_threshold,
        )


class MessageStatus(enum.Enum):
    """Lifecycle status of a scheduler sequence."""

    # Regular scheduling states.
    WAITING = 1
    READY = 2
    STOPPED = 3
    RUNNING = 4

    # PD (prefill/decode) disaggregation states.
    # MIGRATION_WAITING tags requests that have not been migrated yet, in both
    # the prefill and the decode engine; MIGRATION_READY tags requests that are
    # being migrated in the decode engine.
    TO_BE_MIGRATED = 5
    MIGRATION_WAITING = 6
    MIGRATION_READY = 7
    MIGRATION_RUNNING = 8
    MIGRATION_DONE = 9


# Mapping ``seq_id -> sequence``; SequenceManager uses it both for its global
# index and for each per-status bucket.
SeqMap = Dict[int, 'SchedulerSequence']


@dataclass
class SequenceMeta:
    """Configuration shared by every sequence of one SequenceManager."""
    # block size, exposed per sequence as ``SchedulerSequence.block_size``
    block_size: int
    # factory used by ``SchedulerSession.add_sequence`` (``make_sequence``)
    strategy: 'SequenceStrategy' = field(default=None)
    sampling_strategy: 'SamplingStrategy' = field(default=None)


class SequenceManager:
    """Registry of scheduler sequences, indexed by id and by status."""

    def __init__(self, seq_meta: SequenceMeta) -> None:
        self.seq_meta = seq_meta
        # next id handed out by _new_seq_id
        self._seq_count = 0
        # seq_id -> sequence, for every registered sequence
        self._seq_map: SeqMap = dict()
        # status -> (seq_id -> sequence); a bucket is created on first access
        self._status_seq_map: Dict[MessageStatus, SeqMap] = defaultdict(dict)

    def _new_seq_id(self):
        """Return a fresh sequence id, counting up from 0."""
        new_id, self._seq_count = self._seq_count, self._seq_count + 1
        return new_id

    def get_all_sequences(self):
        """Return a live view over every registered sequence."""
        return self._seq_map.values()

    def get_sequences(self, states: MessageStatus):
        """Return the ``seq_id -> sequence`` bucket for one status."""
        return self._status_seq_map[states]

    def num_sequences(self, status: MessageStatus):
        """Count the sequences currently in ``status``."""
        bucket = self.get_sequences(status)
        return len(bucket)

    def add_sequence(self, seq: 'SchedulerSequence'):
        """Register ``seq`` under its id and under its current status."""
        bucket = self._status_seq_map[seq.status]
        self._seq_map[seq.seq_id] = seq
        bucket[seq.seq_id] = seq

    def remove_sequence(self, seq: 'SchedulerSequence'):
        """Unregister ``seq``; raises KeyError if it is not registered."""
        bucket = self._status_seq_map[seq.status]
        self._seq_map.pop(seq.seq_id)
        bucket.pop(seq.seq_id)

    def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageStatus):
        """Move ``seq`` from the bucket of its current status to ``new_status``."""
        old_status = seq.status
        if old_status == new_status:
            return
        sid = seq.seq_id
        src_bucket = self._status_seq_map[old_status]
        dst_bucket = self._status_seq_map[new_status]
        if sid not in src_bucket:
            # the sequence may already have been removed (e.g. by async_end)
            return
        del src_bucket[sid]
        dst_bucket[sid] = seq


def _to_ndarray(token_ids) -> np.ndarray:
    """Convert token ids (tensor, array, scalar or sequence) to an ndarray.

    Tensors go through ``Tensor.numpy()``; existing ndarrays are returned as
    is; anything else is passed to ``np.array``. A 0-d result is promoted to a
    1-element 1-D array.
    """
    if isinstance(token_ids, Tensor):
        arr = token_ids.numpy()
    elif isinstance(token_ids, np.ndarray):
        arr = token_ids
    else:
        arr = np.array(token_ids)
    return arr[None] if arr.ndim == 0 else arr


class SchedulerSession:
    """Owns the sequences created under one session id."""

    def __init__(self, session_id: int, seq_manager: SequenceManager, scheduler: 'Scheduler') -> None:
        self.session_id = session_id
        self.seq_manager = seq_manager
        self.seq_meta = seq_manager.seq_meta
        self.scheduler = scheduler
        # seq_id -> sequence for the sequences of this session
        self.sequences: SeqMap = dict()

    def add_sequence(self,
                     token_ids: Tensor,
                     sampling_param: SamplingParam = None,
                     adapter_name: str = None,
                     multimodals: MultiModalInputs = None,
                     input_embeddings: List[InputEmbeddings] = None,
                     migration_request: None | MigrationRequest = None,
                     resp_cache: bool = False,
                     preserve_cache: bool = False) -> 'SchedulerSequence':
        """Create a sequence for ``token_ids`` and register it.

        The sequence is built by ``seq_meta.strategy.make_sequence``, filled
        with the inputs, given a WAITING (or MIGRATION_WAITING when a
        migration request is attached) state, registered with the sequence
        manager and stamped with a QUEUED event.
        """
        from lmdeploy.pytorch.paging.seq_states.states import build_seq_state

        params = SamplingParam() if sampling_param is None else sampling_param

        new_id = self.seq_manager._new_seq_id()
        seq = self.seq_meta.strategy.make_sequence(seq_id=new_id,
                                                   session=self,
                                                   sampling_param=params,
                                                   adapter_name=adapter_name,
                                                   migration_request=migration_request,
                                                   resp_cache=resp_cache,
                                                   preserve_cache=preserve_cache)
        seq.update_token_ids(token_ids,
                             multimodals=multimodals,
                             embeddings=input_embeddings,
                             mode=UpdateTokenMode.INPUTS)
        self.sequences[seq.seq_id] = seq

        # the state must be set before registration: the manager buckets by seq.status
        if migration_request is None:
            init_status = MessageStatus.WAITING
        else:
            init_status = MessageStatus.MIGRATION_WAITING
        seq.set_state(build_seq_state(self.scheduler, seq, init_status))
        self.seq_manager.add_sequence(seq)

        seq.record_event(EventType.QUEUED)
        return seq

    def remove_sequence(self, seq: 'SchedulerSequence'):
        """Free the state of ``seq`` and drop it from this session and the manager."""
        assert seq.seq_id in self.sequences
        seq.state.free()
        del self.sequences[seq.seq_id]
        self.seq_manager.remove_sequence(seq)


def _div_up(x, n):
    """Ceiling division ``ceil(x / n)`` (intended for positive ``n``)."""
    padded = x + n - 1
    return padded // n


def _round_up(x, n):
    """Round ``x`` up to the nearest multiple of ``n``."""
    num_chunks = _div_up(x, n)
    return num_chunks * n


class HistoryEmbeddings:
    """List of input embeddings accumulated for a sequence.

    ``get_step`` assumes the entries are sorted by position and do not overlap.
    """

    def __init__(self, embeddings: List[InputEmbeddings] = None):
        self._embeddings: List[InputEmbeddings] = []
        if embeddings is not None:
            self._embeddings.extend(embeddings)

    def append(self, embeddings: List[InputEmbeddings]):
        """Append new embeddings at the end."""
        self._embeddings.extend(embeddings)

    def clone(self):
        """Return a shallow copy (new list, same embedding objects)."""
        ret = HistoryEmbeddings(self._embeddings)
        return ret

    def copy(self):
        """Alias of ``clone``."""
        return self.clone()

    def get_step(self, step: int) -> tuple[int, int, int]:
        """Split the embeddings around token position ``step``.

        Returns:
            tuple: ``(real_step, history_image_num, num_images)``:
            ``history_image_num`` counts embeddings that end at or before
            ``step``; ``num_images`` is the number of remaining embeddings;
            ``real_step`` is ``step``, moved back to the start of the next
            embedding when ``step`` falls inside it.
        """
        real_step = step
        num_all_images = len(self._embeddings)
        history_image_num = 0
        if num_all_images > 0:
            history_image_num = sum(1 for emb in self._embeddings if emb.end <= step)
            if history_image_num < num_all_images:
                emb = self._embeddings[history_image_num]
                # step lies in the middle of an image: back off to its start
                if emb.start < step:
                    real_step = emb.start
        num_images = num_all_images - history_image_num
        return real_step, history_image_num, num_images

    @property
    def embeddings(self):
        """The underlying list of embeddings (not a copy)."""
        return self._embeddings

    def __len__(self):
        """Get num images."""
        return len(self._embeddings)

    def __getitem__(self, *args, **kwargs):
        """Index or slice the underlying list."""
        return self._embeddings.__getitem__(*args, **kwargs)


class _HistoryDataBase:
    """Append-only, growable numpy buffer with a logical length.

    The backing array ``_data`` may be longer than its content; ``_num_real``
    counts the valid leading rows. Subclasses tune ``ALLOC_SIZE`` (growth
    granularity), ``COPY_ON_RESIZE`` (whether ``resize`` compacts the buffer)
    and the ``_create_empty_array`` / ``_get_pad_width`` hooks.
    """
    ALLOC_SIZE = 512
    COPY_ON_RESIZE = False

    def __init__(self, data: np.ndarray = None, dtype: np.dtype = np.int64):
        self.dtype = dtype
        self._data = None
        self._num_real = 0

        if data is None:
            self._data = self._create_empty_array(dtype)
            return
        self._data = data.astype(dtype) if hasattr(data, 'astype') else data
        self._num_real = len(data)

    def _create_empty_array(self, dtype):
        """Hook: allocate the initial buffer (1-D, ``ALLOC_SIZE`` rows here)."""
        return np.empty((self.ALLOC_SIZE, ), dtype=dtype)

    def _get_pad_width(self, reserve_size: int):
        """Hook: ``np.pad`` widths growing only the leading axis by ``reserve_size``."""
        return (0, reserve_size)

    def reserve(self, size: int):
        """Ensure the backing array has room for at least ``size`` rows."""
        if self._data is None:
            return
        capacity = len(self._data)
        if size <= capacity:
            return
        # grow in whole multiples of ALLOC_SIZE to amortise reallocations
        extra = _round_up(size - capacity, self.ALLOC_SIZE)
        self._data = np.pad(self._data, self._get_pad_width(extra))

    def get_real(self):
        """Return the valid prefix of the buffer, or None when unallocated."""
        data = self._data
        return None if data is None else data[:self._num_real]

    def resize(self, size: int):
        """Shrink the logical length to ``size`` (it may not grow)."""
        assert size <= self._num_real
        self._num_real = size
        if self._data is not None and self.COPY_ON_RESIZE:
            self._data = self._data[:size].copy()

    def append(self, new_data: np.ndarray):
        """Append the rows of ``new_data`` after the valid prefix."""
        if self._data is None:
            # lazily allocated buffers take their shape from the first chunk
            self._data = new_data.astype(self.dtype)
            self._num_real = len(new_data)
            return
        count = len(new_data)
        begin = self._num_real
        self.reserve(begin + count)
        self._num_real = begin + count
        self._data[begin:begin + count] = new_data

    def __setitem__(self, *args, **kwargs):
        """Assign into the valid prefix."""
        return self.get_real().__setitem__(*args, **kwargs)

    def __getitem__(self, *args, **kwargs):
        """Index into the valid prefix."""
        return self.get_real().__getitem__(*args, **kwargs)

    def __len__(self):
        """Logical length (number of valid rows)."""
        return self._num_real

    def clone(self):
        """Return an independent copy holding only the valid prefix."""
        data = self.get_real()
        if data is not None:
            data = data.copy()
        return type(self)(data, dtype=self.dtype)

    def copy(self):
        """Alias of ``clone``."""
        return self.clone()


class HistoryTokenIds(_HistoryDataBase):
    """Growable buffer of the token ids of a sequence (int64 by default)."""
    ALLOC_SIZE = 512

    def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = np.int64):
        super().__init__(data=token_ids, dtype=dtype)

    # Legacy alias of ``_data``, kept for backward compatibility.
    _token_ids = property(
        lambda self: self._data,
        lambda self, value: setattr(self, '_data', value),
        doc='For backward compatibility.',
    )


class HistoryRouterExperts(_HistoryDataBase):
    """Routed expert ids kept for router replay.

    Stored as a 3-D array (uint16 by default) whose leading axis grows on
    append. Nothing is allocated up front; the first ``append`` creates the
    array, and ``resize`` compacts it.
    """
    COPY_ON_RESIZE = True
    ALLOC_SIZE = 64

    def __init__(self, expert_ids: np.ndarray = None, dtype: np.dtype = np.uint16):
        super().__init__(data=expert_ids, dtype=dtype)

    def _get_pad_width(self, reserve_size: int):
        """Pad only the leading axis of the 3-D array."""
        leading = (0, reserve_size)
        untouched = (0, 0)
        return (leading, untouched, untouched)

    def _create_empty_array(self, dtype):
        """Defer allocation until the first ``append``."""
        return None


class HistoryLogits(_HistoryDataBase):
    """History logits.

    Rows are stored as a 2-D integer array (``np.int16`` by default) holding
    the raw bits of torch logits; ``set_torch_dtype`` records the original
    torch dtype so ``get_logits`` can reinterpret those bits.
    """
    ALLOC_SIZE = 64
    COPY_ON_RESIZE = True

    def __init__(self, logits: np.ndarray = None, dtype: np.dtype = np.int16):
        super().__init__(logits, dtype)
        # torch dtype of the stored bits; None until set_torch_dtype() is called
        self._torch_dtype = None

    def _create_empty_array(self, dtype):
        """No eager allocation: the first ``append`` creates the array."""
        return None

    def _get_pad_width(self, reserve_size: int):
        """Grow only the leading (row) axis of the 2-D array."""
        return ((0, reserve_size), (0, 0))

    def set_torch_dtype(self, torch_dtype):
        """Set torch dtype."""
        self._torch_dtype = torch_dtype

    def get_logits(self):
        """Get logits as a torch tensor of the recorded dtype.

        Returns:
            torch.Tensor | None: the stored rows reinterpreted as
            ``self._torch_dtype`` (sharing memory with the buffer when it is
            contiguous), or None when nothing was stored or the dtype is unset.
        """
        if self._data is None:
            return None
        if self._torch_dtype is None:
            return None

        logits_np = np.ascontiguousarray(self.get_real())
        # ``Tensor.view(dtype)`` rescales the last dim when element sizes differ
        # (e.g. float32 logits bit-cast to int16 occupy twice as many columns).
        # It also accepts empty arrays; ``torch.frombuffer`` raises on both.
        return torch.from_numpy(logits_np).view(self._torch_dtype)

    def clone(self):
        """Clone, carrying over the recorded torch dtype."""
        ret = super().clone()
        ret.set_torch_dtype(self._torch_dtype)
        return ret


class HistoryMultiModals:
    """Multimodal inputs accumulated for a sequence, keyed by modal type."""

    def __init__(self, multimodals: MultiModalInputs = None):
        if multimodals is None:
            multimodals = dict()
        self.multimodals = multimodals

    def get_datas(self, start=0, end=-1):
        """Get multimodal data overlapping the prompt positions [start, end).

        An item is kept when its first position (``start``) or its last
        position (``end - 1``) lies inside the range. With the default
        ``end=-1`` the range is empty and nothing is returned.

        NOTE(review): an item that strictly contains the whole range is not
        returned; confirm callers never split a multimodal item across ranges.
        """
        outs: MultiModalInputs = dict()
        test_range = range(start, end)
        for modal_type, modal_datas in self.multimodals.items():
            data = [d for d in modal_datas if d.start in test_range or d.end - 1 in test_range]
            if data:
                outs[modal_type] = data
        return outs

    def add_inputs(self, input_mms: MultiModalInputs):
        """Add new inputs, concatenating them per modal type."""
        for modal_type, vals in input_mms.items():
            if modal_type in self.multimodals:
                self.multimodals[modal_type] += vals
            else:
                # store a list copy so a later ``+=`` neither mutates the
                # caller's list nor fails on a stored tuple
                self.multimodals[modal_type] = list(vals)

    def empty(self):
        """Return True when no modal type holds any data."""
        # iterate the value lists; iterating the dict itself would yield the keys
        return all(len(vals) == 0 for vals in self.multimodals.values())

    @staticmethod
    def update_multimodals(input_mms: MultiModalInputs, prev_len: int):
        """Shift every item's ``start``/``end`` by ``prev_len`` in place and return the input."""
        for vals in input_mms.values():
            for val in vals:
                val.start += prev_len
                val.end += prev_len
        return input_mms


class UpdateTokenMode(enum.Enum):
    """How ``SchedulerSequence.update_token_ids`` should treat new tokens."""
    INPUTS = 1
    PREFILL = 2
    DECODE = 3


@dataclass
class SchedulerSequence:
    """A single request sequence tracked by the scheduler.

    Token ids, multimodal embeddings/inputs, routed experts and logits are
    kept in growable history buffers. ``_num_history_ids`` splits the token
    buffer into the history part (``history_ids``) and the current inputs
    (``token_ids``). ``update_token_ids`` and ``set_step`` raise
    ``NotImplementedError`` here; ``SchedulerSession.add_sequence`` builds
    sequences through ``seq_meta.strategy.make_sequence``, which presumably
    returns a subclass implementing them (confirm against the strategies).
    """
    seq_id: int
    session: SchedulerSession
    # all token ids: the history part followed by the current inputs
    history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds)
    history_embeddings: HistoryEmbeddings = field(default_factory=HistoryEmbeddings)
    history_multimodals: HistoryMultiModals = field(default_factory=HistoryMultiModals)
    # length of the trailing slice of valid ids returned by ``generated_ids``
    num_new_tokens: int = 0
    sampling_param: SamplingParam = field(default_factory=SamplingParam)
    logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks)
    logical_state: int = -1
    adapter_name: str = None
    arrive_time: float = 0.0
    output_start_pos: int = 0
    meta: Any = None
    num_ignored_history: int = 0
    model_meta: Dict[str, Any] = None

    # For Disaggregation
    migration_request: None | MigrationRequest = None
    resp_cache: bool = False
    preserve_cache: bool = False

    # For logging
    engine_events: List[EngineEvent] = field(default_factory=list)

    # for router replay
    all_routed_experts: HistoryRouterExperts = field(default_factory=HistoryRouterExperts)

    # logits
    all_logits: HistoryLogits = field(default_factory=HistoryLogits)

    def __post_init__(self):
        """Post init: derive the bookkeeping counters from the initial buffers."""
        self._seq_meta: SequenceMeta = self.session.seq_meta
        # number of embeddings that belong to the history part
        self._num_history_images: int = 0
        # token_ids == history_cache[_num_history_ids:_num_history_ids + _num_token_ids]
        self._num_history_ids: int = 0
        self._num_token_ids: int = len(self.history_cache)

        # vlm
        # number of embeddings that belong to the current inputs
        self._num_images: int = len(self.history_embeddings)
        # scheduling state object, assigned via set_state()
        self._state = None

    @property
    def block_size(self) -> int:
        """Block size."""
        return self._seq_meta.block_size

    @property
    def history_image_num(self) -> int:
        """Get history image number."""
        return self._num_history_images

    @property
    def history_image_token_len(self) -> int:
        """Get history image token length (summed ``end - start`` of history embeddings)."""
        return sum([emb.end - emb.start for emb in self.history_embeddings[:self._num_history_images]])

    @property
    def session_id(self) -> int:
        """Get session id."""
        return self.session.session_id

    @property
    def token_ids(self) -> np.ndarray:
        """Token ids of the current inputs (the slice after the history part)."""
        start = self.num_history_ids
        end = start + self._num_token_ids
        return self.history_cache[start:end]

    @property
    def input_embeddings(self) -> List[InputEmbeddings]:
        """Get current embeddings (the ones after the history embeddings)."""
        start = self.history_image_num
        end = start + self._num_images
        return self.history_embeddings[start:end]

    @property
    def history_ids(self) -> np.ndarray:
        """History ids."""
        return self.history_cache[:self.num_history_ids]

    @property
    def all_ids(self) -> np.ndarray:
        """Full token ids."""
        return self.history_cache[:self.num_all_ids]

    @property
    def valid_ids(self) -> np.ndarray:
        """Valid token ids."""
        return self.history_cache[:self.num_valid_ids]

    @property
    def generated_ids(self) -> np.ndarray:
        """The last ``num_new_tokens`` valid token ids."""
        end = self.num_valid_ids
        start = end - self.num_new_tokens
        return self.history_cache[start:end]

    @property
    def return_routed_experts(self) -> bool:
        """Whether routed experts are requested by the sampling params."""
        return self.sampling_param.return_routed_experts

    @property
    def routed_experts(self) -> np.ndarray:
        """Routed experts for the first ``num_all_ids - 1`` tokens.

        Returns None when not requested, or when the buffer holds fewer rows
        than that (or the count is 0). NOTE(review): the ``- 1`` presumably
        excludes the last token, which has not been through the model yet.
        """
        if (not self.return_routed_experts) or self.all_routed_experts is None:
            return None

        end = max(0, self.num_all_ids - 1)
        if 0 < end <= len(self.all_routed_experts):
            return self.all_routed_experts.get_real()[:end]
        else:
            return None

    def append_routed_experts(self, routed_experts: Tensor | np.ndarray):
        """Append routed experts (no-op if not requested or ``None``)."""
        if not self.return_routed_experts:
            return
        if routed_experts is None:
            return
        if isinstance(routed_experts, Tensor):
            routed_experts = routed_experts.cpu().numpy()
        self.all_routed_experts.append(routed_experts)

    @property
    def num_history_ids(self):
        """Num history ids."""
        return self._num_history_ids

    @property
    def num_token_ids(self):
        """Number of token ids in the current inputs."""
        return self._num_token_ids

    @property
    def num_valid_ids(self):
        """History ids plus current token ids (currently the same as ``num_all_ids``)."""
        return self._num_history_ids + self._num_token_ids

    @property
    def num_images(self):
        """Number of embeddings in the current inputs."""
        return self._num_images

    @property
    def num_all_ids(self):
        """Num all tokens."""
        return self._num_history_ids + self._num_token_ids

    @property
    def num_blocks(self):
        """Num blocks."""
        return len(self.logical_blocks)

    @property
    def state(self) -> 'StateBase':
        """The scheduling state object set by ``set_state`` (None before that)."""
        return self._state

    def set_state(self, state: 'StateBase'):
        """Set state."""
        self._state = state

    @property
    def status(self):
        """Status reported by the current state object."""
        return self.state.status

    @property
    def return_logits(self):
        """Whether logits output is requested (``sampling_param.out_logits``)."""
        return self.sampling_param.out_logits

    @property
    def logits(self):
        """Get logits."""
        return self.all_logits.get_logits()

    def append_logits(self, logits: Tensor | np.ndarray):
        """Append logits (no-op if not requested or ``None``).

        Tensors are bit-cast to int16 before storage and their dtype is
        recorded on ``all_logits``; presumably this lets dtypes without a
        numpy equivalent be stored. Assumes a CPU tensor (``Tensor.numpy``).
        """
        if not self.return_logits:
            return
        if logits is None:
            return
        if isinstance(logits, Tensor):
            self.all_logits.set_torch_dtype(logits.dtype)
            logits = logits.view(torch.int16).numpy()
        self.all_logits.append(logits)

    def get_input_multimodals(self):
        """Get multimodal data overlapping token positions [num_history_ids, num_all_ids)."""
        start = self.num_history_ids
        end = self.num_all_ids
        return self.history_multimodals.get_datas(start, end)

    def record_event(
        self,
        event_type: EventType,
        timestamp: None | float = None,
    ) -> None:
        """Append an engine event of ``event_type`` to ``engine_events``."""
        self.engine_events.append(EngineEvent.new_event(event_type, timestamp))

    def _update_embeddings(self, embeddings: List[InputEmbeddings]):
        """Update input embeddings.

        The current embeddings become history; the new ones are shifted with
        ``move_position(self._num_history_ids)`` and appended. NOTE(review):
        this offsets by ``_num_history_ids`` while ``_update_multimodals``
        offsets by ``num_valid_ids``; confirm the intended call order within
        ``update_token_ids`` implementations.
        """
        self._num_history_images += self._num_images
        if embeddings is None:
            self._num_images = 0
            return
        new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings]
        self._num_images = len(new_embeddings)
        self.history_embeddings.append(new_embeddings)

    def _update_multimodals(self, multimodals: MultiModalInputs):
        """Update input multimodals, shifting their positions by ``num_valid_ids``."""
        if multimodals is None:
            return
        multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids)
        self.history_multimodals.add_inputs(multimodals)

    def update_token_ids(self,
                         token_ids: Tensor,
                         multimodals: MultiModalInputs = None,
                         embeddings: List[InputEmbeddings] = None,
                         model_meta: Dict[str, Any] = None,
                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
                         **kwargs):
        """Update token ids, old token ids will be added to history.

        Not implemented here; presumably overridden by sequence subclasses.
        """
        raise NotImplementedError('NotImplemented')

    def set_step(self, step: int):
        """Set step. Not implemented here; presumably overridden by subclasses."""
        raise NotImplementedError('NotImplemented')


================================================
FILE: lmdeploy/pytorch/model_inputs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

import numpy as np
import torch
import torch.distributed as torch_dist
from torch.profiler import record_function

# from torch import distributed as dist
import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig, QuantizationConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.utils import CtxMgrBase, singleton

if TYPE_CHECKING:
    from lmdeploy.pytorch.strategies.base import StrategyFactoryBase


@dataclass
class DPMeta:
    """Per-rank token counts for the mlp and moe tensor-parallel groups."""
    tp_sizes: List[int] = None
    moe_tp_sizes: List[int] = None

    @staticmethod
    def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist.DistContext, layer_type: str):
        """Collect the token count of every rank in the ``layer_type`` gather group.

        Falls back to ``[seqlen]`` when ``tp`` is 1 or equals the attention tp.
        """
        attn_tp = dist_ctx.dist_config.attn_tp
        if tp <= 1 or tp == attn_tp:
            return [seqlen]
        gather_group = dist.get_dist_group(layer_type=layer_type).gpu_gather_group
        tp_sizes = [num_tokens[rank] for rank in torch_dist.get_process_group_ranks(gather_group)]
        assert all(size >= 0 for size in tp_sizes), (f'Invalid tp sizes: {tp_sizes}')
        return tp_sizes

    @classmethod
    def build(cls, seqlen: int, num_tokens: List[int]):
        """Build the meta for the current dist context, reusing mlp sizes for moe when their tp match."""
        dist_ctx = dist.get_dist_manager().current_context()
        config = dist_ctx.dist_config

        mlp_tp = config.mlp_tp
        mlp_sizes = cls._gather_tp_sizes(mlp_tp, seqlen, num_tokens, dist_ctx, layer_type='mlp')

        moe_tp = config.moe_tp
        if moe_tp != mlp_tp:
            moe_sizes = cls._gather_tp_sizes(moe_tp, seqlen, num_tokens, dist_ctx, layer_type='moe')
        else:
            moe_sizes = mlp_sizes

        return DPMeta(tp_sizes=mlp_sizes, moe_tp_sizes=moe_sizes)

    def sync_tp_size(self, tp_size: int):
        """Overwrite every entry of both size lists with ``tp_size``."""
        self.tp_sizes = [tp_size for _ in self.tp_sizes]
        self.moe_tp_sizes = [tp_size for _ in self.moe_tp_sizes]


@dataclass
class VisionModelInputs:
    """Vision/multimodal part of the model inputs."""
    history_lengths: torch.LongTensor = None
    input_embeddings: List[List[torch.Tensor]] = None
    input_embedding_ranges: List[torch.LongTensor] = None
    input_embedding_indexing: torch.BoolTensor = None
    input_multimodals: List[MultiModalData] = None

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with all tensors (including nested ones) moved to ``device``.

        Fields that are None are omitted and fall back to their defaults.
        """

        def _move(t):
            return t.to(device, non_blocking=non_blocking)

        moved = dict()
        for f in fields(self):
            name = f.name
            value = getattr(self, name)
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = _move(value)
            elif name == 'input_embedding_ranges':
                value = [_move(r) for r in value]
            elif name == 'input_embeddings':
                value = [[_move(e) for e in embs] for embs in value]
            elif name == 'input_multimodals':
                value = [{
                    modal_type: [d.to_device(device, non_blocking=non_blocking) for d in datas]
                    for modal_type, datas in mm_datas.items()
                } for mm_datas in value]
            moved[name] = value

        return VisionModelInputs(**moved)

    def get_inputs(self, history_lengths: torch.Tensor, seq_lengths: torch.Tensor):
        """Gather the embedding rows overlapping the current query windows.

        For each sequence the window is ``[his_len, his_len + seq_len)``; every
        embedding overlapping it contributes the overlapping slice.

        Returns:
            ``(input_embeddings, input_embedding_indexing)``: the concatenated
            slices and the flat positions (within the concatenated windows of
            ``self.input_embedding_indexing``) whose mask is set; ``(None, None)``
            when there are no embeddings or none overlap.
        """
        if self.input_embeddings is None or len(self.input_embeddings) == 0:
            return None, None

        pieces = []
        per_seq = zip(history_lengths, seq_lengths, self.input_embeddings, self.input_embedding_ranges)
        for his_len, seq_len, embeddings, emb_ranges in per_seq:
            for emb, (emb_start, emb_end) in zip(embeddings, emb_ranges):
                lo = max(emb_start, his_len) - emb_start
                hi = min(emb_end, his_len + seq_len) - emb_start
                if 0 <= lo < hi:
                    pieces.append(emb[lo:hi])
        if not pieces:
            return None, None

        input_embeddings = torch.cat(pieces, dim=0)
        # indexing masks are relative to the history lengths stored on self
        starts = history_lengths - self.history_lengths
        ends = starts + seq_lengths
        mask = torch.cat([indexing[s:e] for indexing, s, e in zip(self.input_embedding_indexing, starts, ends)],
                         dim=0)
        positions = torch.arange(mask.numel(), device=input_embeddings.device)
        return input_embeddings, positions[mask]


@dataclass
class ModelInputsDelta:
    """Delta of ModelInputs."""
    # valid indices
    indices: Optional[torch.Tensor]
    # new block offsets
    block_offsets: torch.Tensor
    # cpu copy of indices
    indice_cpu: np.ndarray
    max_q_seqlen: int
    max_kv_seqlen: int
    sum_kv_seqlen: int
    is_decoding: bool = True
    # sliding window
    num_ignored_history: Optional[torch.Tensor] = None

    @property
    def seq_length(self):
        """Per-sequence query length: ``max_q_seqlen`` for every row of ``block_offsets``."""
        batch_size = self.block_offsets.size(0)
        return torch.full((batch_size, ), self.max_q_seqlen, dtype=torch.long)

    def fill_tensors(self):
        """Materialize ``indices`` from ``indice_cpu`` if it is still unset."""
        if self.indices is None:
            # copy first so the tensor does not alias an array the caller may
            # presumably keep reusing
            self.indice_cpu = self.indice_cpu.copy()
            self.indices = torch.as_tensor(self.indice_cpu)

    @torch.inference_mode()
    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with tensor fields moved to ``device`` (fills ``indices`` first)."""
        out_dict = dict()
        self.fill_tensors()
        for f in fields(self):
            k = f.name
            v = getattr(self, k)
            if isinstance(v, torch.Tensor):
                v = v.to(device, non_blocking=non_blocking)
            out_dict[k] = v

        return ModelInputsDelta(**out_dict)

    def log_info(self):
        """Get log info.

        Reads ``indice_cpu``, which is always set, rather than ``indices``,
        which stays None until ``fill_tensors`` runs. Counts ``max_q_seqlen``
        tokens per sequence, matching ``seq_length``.
        """
        batch_size = int(np.size(self.indice_cpu))
        num_tokens = batch_size * self.max_q_seqlen
        ret = (f'num_tokens={num_tokens}, batch_size={batch_size}'
               f', is_decoding={self.is_decoding}')
        return ret


@dataclass
class ModelInputs:
    """Batched inputs for one model forward pass."""
    input_ids: torch.Tensor
    seq_length: torch.Tensor
    history_lengths: torch.Tensor
    block_offsets: torch.Tensor
    is_decoding: bool
    num_ignored_history: torch.Tensor
    max_q_seqlen: int
    max_kv_seqlen: int
    sum_kv_seqlen: int
    local_adapter_ids: torch.Tensor = None
    vision_inputs: VisionModelInputs = None
    model_metas: List[Dict[str, Any]] = None
    dp_meta: 'DPMeta' = None
    enable_microbatch: bool = False
    is_dummy: bool = False
    state_offsets: torch.Tensor = None
    target_hidden_states: torch.Tensor = None
    target_position_ids: torch.Tensor = None
    is_chunk: bool = False
    is_first_chunk: bool = True

    def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None):
        """Advance a decoding batch in place and return ``self``.

        ``history_lengths`` grows by ``step_seqlens`` (default ``seq_length``);
        ``max_kv_seqlen`` and ``sum_kv_seqlen`` grow by ``max_q_seqlen`` per
        sequence. A 1-D ``input_ids`` gets a leading batch axis.
        """
        assert self.is_decoding
        seqlens = self.seq_length if step_seqlens is None else step_seqlens
        self.history_lengths += seqlens
        self.max_kv_seqlen += self.max_q_seqlen
        self.sum_kv_seqlen += self.max_q_seqlen * self.seq_length.numel()
        self.input_ids = input_ids if input_ids.dim() != 1 else input_ids[None, :]
        return self

    @torch.inference_mode()
    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with tensors and vision inputs moved to ``device``."""

        def _move(value):
            if isinstance(value, torch.Tensor):
                return value.to(device, non_blocking=non_blocking)
            if isinstance(value, VisionModelInputs):
                return value.to_device(device, non_blocking=non_blocking)
            return value

        return ModelInputs(**{f.name: _move(getattr(self, f.name)) for f in fields(self)})

    def build_dp_meta(self, num_tokens: List[int]):
        """Build ``dp_meta`` from this batch's token count and per-rank ``num_tokens``."""
        local_tokens = self.input_ids.numel()
        self.dp_meta = DPMeta.build(local_tokens, num_tokens)

    def log_info(self):
        """Get log info."""
        num_tokens = self.input_ids.numel()
        batch_size = self.seq_length.numel()
        has_vision = self.vision_inputs is not None
        return (f'num_tokens={num_tokens}, batch_size={batch_size}'
                f', is_decoding={self.is_decoding}, has_vision={has_vision}')


@dataclass
class StepContext:
    """Context of Model.

    Patched models might need extra information to perform inference. This dataclass provides that
    information, together with helpers that build it from ``ModelInputs``.
    """
    input_ids: torch.LongTensor
    model_config: ModelConfig
    cache_config: CacheConfig
    block_offsets: torch.IntTensor
    position_ids: torch.LongTensor
    attention_mask: torch.LongTensor
    q_seqlens: torch.LongTensor
    kv_seqlens: torch.IntTensor
    q_start_loc: torch.LongTensor
    kv_caches: List
    is_decoding: bool
    sum_kv_seqlen: int
    max_kv_seqlen: int = None
    local_adapter_ids: torch.LongTensor = None
    input_embeddings: torch.Tensor = None
    input_embedding_indexing: torch.Tensor = None
    input_multimodals: List[MultiModalData] = None
    vision_inputs: VisionModelInputs = None
    attn_metadata: Any = None
    kv_quant_policy: Literal[0, 4, 8] = 0
    model_metas: List[Dict[str, Any]] = None
    dp_meta: DPMeta = None
    enable_microbatch: bool = False
    # for draft model
    target_hidden_states: torch.Tensor = None

    # states for ssm
    state_caches: List = None
    state_offsets: torch.LongTensor = None

    _outputs: Dict = field(default_factory=dict)

    @classmethod
    def new(
        cls,
        inputs: ModelInputs,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        kv_caches: List = None,
        state_caches: List = None,
        kv_quant_policy: Literal[0, 4, 8] = 0,
    ):
        """Build step context.

        Args:
            inputs (ModelInputs): packaged model inputs.
            model_config (ModelConfig): config of the model being run.
            cache_config (CacheConfig): config of the kv cache.
            kv_caches (List): kv cache storage stored on the context. Defaults to None.
            state_caches (List): ssm state caches stored on the context. Defaults to None.
            kv_quant_policy (Literal[0, 4, 8]): quantization policy of the kv cache. Defaults to 0.

        Returns:
            StepContext: the new context, after ``get_backend().update_step_context`` has processed it.
        """
        q_seqlens = inputs.seq_length
        history_seqlens = inputs.history_lengths

        input_multimodals = None
        if inputs.vision_inputs is not None:
            input_multimodals = inputs.vision_inputs.input_multimodals

        # for vlm
        input_embeddings, input_embedding_indexing = None, None
        if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None):
            input_embeddings, input_embedding_indexing = \
                inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)

        # position ids
        attention_mask, position_ids = cls.get_mask_and_position_ids(inputs)
        # exclusive prefix sum: offset of each sequence's first query token in the packed batch
        q_start_loc = q_seqlens.cumsum(0) - q_seqlens

        # seq_len + history_length, minus history that is not kept in the kv cache
        kv_seqlens = q_seqlens + history_seqlens
        kv_seqlens -= inputs.num_ignored_history

        ret = cls(
            input_ids=inputs.input_ids,
            model_config=model_config,
            cache_config=cache_config,
            block_offsets=inputs.block_offsets,
            position_ids=position_ids,
            input_embeddings=input_embeddings,
            input_embedding_indexing=input_embedding_indexing,
            input_multimodals=input_multimodals,
            attention_mask=attention_mask,
            q_seqlens=q_seqlens,
            kv_seqlens=kv_seqlens,
            q_start_loc=q_start_loc,
            kv_caches=kv_caches,
            is_decoding=inputs.is_decoding,
            sum_kv_seqlen=inputs.sum_kv_seqlen,
            max_kv_seqlen=inputs.max_kv_seqlen,
            local_adapter_ids=inputs.local_adapter_ids,
            vision_inputs=inputs.vision_inputs,
            kv_quant_policy=kv_quant_policy,
            model_metas=inputs.model_metas,
            dp_meta=inputs.dp_meta,
            enable_microbatch=inputs.enable_microbatch,
            state_caches=state_caches,
            state_offsets=inputs.state_offsets,
            target_hidden_states=inputs.target_hidden_states,
        )

        ret = get_backend().update_step_context(ret)
        return ret

    @classmethod
    def get_mask_and_position_ids(cls, inputs: ModelInputs):
        """Get attention mask and position ids.

        Three cases are handled:
        * ``max_q_seqlen == 1`` (at most one query token per sequence): mask of shape (batch, 1);
          positions are ``target_position_ids`` if given, otherwise the history lengths.
        * every sequence has exactly ``max_q_seqlen`` tokens: no mask; positions are
          ``history + arange(max_q_seqlen)`` per sequence, packed into shape (1, num_tokens).
        * variable lengths: a (batch, max_q_seqlen) 0/1 mask; positions are
          ``target_position_ids`` if given, otherwise packed per-token positions of shape (1, num_tokens).

        Returns:
            Tuple of ``(attention_mask, position_ids)``.
        """
        q_seqlens = inputs.seq_length
        history_seqlens = inputs.history_lengths
        max_q_seqlen = inputs.max_q_seqlen
        target_position_ids = inputs.target_position_ids
        # decoding
        if max_q_seqlen == 1:
            attention_mask = torch.ones_like(q_seqlens)[:, None]
            if target_position_ids is not None:
                position_ids = target_position_ids
            else:
                position_ids = history_seqlens.unsqueeze(0).clone()
            return attention_mask, position_ids

        num_tokens = inputs.input_ids.numel()
        batch_size = inputs.seq_length.numel()
        device = q_seqlens.device

        # batch with same seqlens
        # NOTE(review): unlike the other two branches, this one ignores ``target_position_ids``.
        # That is harmless when the target positions equal history + offset, but confirm for
        # draft models whose target positions may differ (e.g. a batch of one during prefill).
        if max_q_seqlen * batch_size == num_tokens:
            attention_mask = None
            ranges = torch.arange(0, max_q_seqlen, device=device)
            position_ids = history_seqlens[:, None] + ranges[None, :]
            position_ids = position_ids.flatten()
            return attention_mask, position_ids[None]

        # get mask: 1 for real tokens, 0 for padding, shape (batch, max_q_seqlen)
        mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
        attention_mask = (mask_range < q_seqlens[:, None]).long()
        if target_position_ids is not None:
            return attention_mask, target_position_ids

        # position_ids
        # offset of each token inside its own sequence; padded slots repeat the last valid offset
        indices = attention_mask.long().cumsum(-1) - 1
        position_ids = indices + history_seqlens.unsqueeze(-1)
        # shift offsets into indices of the packed 1-D token layout
        # (must happen after position_ids was computed from the unshifted offsets)
        indices[1:] += q_seqlens.cumsum(0)[:-1, None]
        position_ids_1d = position_ids.new_empty(num_tokens)
        # padded slots map to the same index, with the same value, as their sequence's last token
        position_ids_1d[indices.flatten()] = position_ids.flatten()
        position_ids = position_ids_1d[None]
        return attention_mask, position_ids


@dataclass
class BuildModelContext:
    """Context for building model.

    Options gathered before model construction. Every field has a default, so
    ``BuildModelContext()`` is a valid all-defaults configuration.
    """
    # Flag to skip the vision encoder (presumably honored by model builders — verify at use sites).
    disable_vision_encoder: bool = False
    # DLLM configuration; None when not provided.
    dllm_config: DLLMConfig = None
    # Strategy factory used while building; None when not provided.
    strategy_factory: 'StrategyFactoryBase' = None
    # Flag requesting routed-expert outputs (presumably for MoE models — confirm at use sites).
    enable_return_routed_experts: bool = False
    # Quantization settings; each instance gets its own fresh QuantizationConfig().
    quant_config: QuantizationConfig = field(default_factory=QuantizationConfig)
    # Flag for computing the lm_head in fp32 (inferred from the name — confirm at use sites).
    fp32_lm_head: bool = False
    # Flag for tied input/output embeddings (inferred from the name — confirm at use sites).
    tie_word_embeddings: bool = False


class StepContextManager(CtxMgrBase[StepContext]):
    """Holds the model ``BuildModelContext`` and builds a ``StepContext`` for each forward step."""

    def __init__(self, build_ctx: BuildModelContext = None):
        super().__init__(None)
        # Fall back to an all-defaults build context when none is supplied.
        self.build_ctx = build_ctx or BuildModelContext()

    @record_function('build_step_context')
    def build_context(
        self,
        inputs: ModelInputs,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        kv_caches: List = None,
        state_caches: List = None,
        kv_quant_policy: Literal[0, 4, 8] = 0,
    ):
        """Build the step context; a thin, profiled wrapper around ``StepContext.new``."""
        return StepContext.new(
            inputs=inputs,
            model_config=model_config,
            cache_config=cache_config,
            kv_caches=kv_caches,
            state_caches=state_caches,
            kv_quant_policy=kv_quant_policy,
        )


@singleton
class StepCtxMgrApi(CtxMgrBase[StepContextManager]):
    """Context manager for StepContextManager.

    A ``CtxMgrBase`` specialised to hold ``StepContextManager`` objects. Decorated with
    ``@singleton`` (presumably one shared instance per process — confirm in the decorator).
    """

    def __init__(self):
        # Initial context is ``None`` (presumably "no manager installed yet" — see CtxMgrBase).
        super().__init__(None)


# Module-level shortcuts to the ``StepCtxMgrApi`` context accessors. With ``@singleton`` the three
# calls presumably return one shared instance (depends on the decorator — confirm).
set_step_ctx_manager = StepCtxMgrApi().set_context
get_step_ctx_manager = StepCtxMgrApi().current_context
step_ctx_manager = StepCtxMgrApi().context


================================================
FILE: lmdeploy/pytorch/models/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .q_modules import QLinear, QRMSNorm

# Names re-exported as this package's public API.
__all__ = ['QLinear', 'QRMSNorm']


================================================
FILE: lmdeploy/pytorch/models/baichuan.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


def _is_baichuan_13b(config: Any):
    """Is baichuan 13b."""
    return config.num_hidden_layers == 40


class BaichuanAttention(nn.Module):
    """Packed-QKV self-attention for Baichuan.

    Rotary embeddings are applied unless the config is detected as the 13B variant, in which case
    the attention kernel is built with ``alibi=True`` and no rotary embedding is applied.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        # key/value head count equals query head count (no grouped-query attention)
        n_kv_heads = n_heads
        hidden = config.hidden_size
        head_dim = hidden // n_heads
        self.is_13b = _is_baichuan_13b(config)

        # q, k and v come out of a single packed projection
        self.W_pack = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            alibi=self.is_13b,
        )

        self.o_proj = build_o_proj(
            n_heads * head_dim,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Project to q/k/v, optionally rotate, attend against the kv cache, and project back.

        ``past_key_value`` is indexed unconditionally: entries 0/1 are the k/v caches, and when it
        holds more than two entries, 2/3 are passed as the k/v ``scales_zeros``.
        """
        packed = self.W_pack(hidden_states).flatten(0, -2)
        query_states, key_states, value_states = self.W_pack.split_qkv(packed)

        if not self.is_13b:
            cos, sin = rotary_pos_emb
            query_states, key_states = self.apply_rotary_pos_emb(
                query_states,
                key_states,
                cos,
                sin,
                inplace=True,
            )

        has_scales = len(past_key_value) != 2
        k_scales_zeros = past_key_value[2] if has_scales else None
        v_scales_zeros = past_key_value[3] if has_scales else None
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            inplace=True,
        )
        # restore the caller's leading dims before the output projection
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(attn_output)


class MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-multiply, down projection."""

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        inter_size = config.intermediate_size

        # gate and up projections share a single fused linear
        self.gate_up_proj = build_gateup_linear(
            config.hidden_size,
            [inter_size, inter_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(
            inter_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class DecoderLayer(nn.Module):
    """Baichuan decoder layer: pre-norm attention followed by a pre-norm gated MLP.

    The residual stream is threaded through the RMS norms (called with ``(x, residual)``) and is
    returned separately from the MLP output.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = BaichuanAttention(config, dtype=dtype, device=device)
        self.mlp = MLP(config, dtype=dtype, device=device)

        # norm applied before attention
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )
        # norm applied before the MLP
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run one layer and return ``(mlp_output, residual)``."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # first layer: the raw input starts the residual stream
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class BaichuanModel(nn.Module):
    """Baichuan model: token embedding, a stack of decoder layers and a final RMS norm."""

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(self.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        self.layers = nn.ModuleList(
            [DecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.is_13b = _is_baichuan_13b(config)
        if not self.is_13b:
            # Only non-13B variants use rotary embeddings; the 13B attention is built with alibi.
            head_dim = config.hidden_size // config.num_attention_heads
            self.rotary_emb = build_rotary_embedding(
                head_dim,
                config.max_position_embeddings,
                10000,
                1.0,
                emb_type=RopeType.LinearScaling,
            )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run all layers, and apply the final norm."""
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        if self.is_13b:
            rotary_pos_emb = (None, None)
        else:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
            rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # the final norm also folds the residual stream back in
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class BaichuanForCausalLM(nn.Module, CudaGraphMixin):
    """Baichuan causal LM (rewrite of LlamaForCausalLM): ``BaichuanModel`` plus an lm_head."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    # Vocabulary size of Baichuan2 checkpoints; used to tell Baichuan2 apart from Baichuan (v1),
    # which shares this architecture class but has a different vocabulary.
    _BAICHUAN2_VOCAB_SIZE = 125696

    def __init__(self,
                 config: Any,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build BaichuanModel
        self.model = BaichuanModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns final hidden states (logits come from ``get_logits``)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from the step context.

        When the context carries vision embeddings, they are written into ``inputs_embeds`` at
        ``input_embedding_indexing`` (embedding ``input_ids`` first if needed).
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Separate gate/up checkpoint weights are loaded into the fused ``gate_up_proj``. ``W_pack``
        weights are split into q/k/v shards. The lm_head is row-normalized only for Baichuan2.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # Baichuan2 wraps its output head in ``NormHead``, which L2-normalizes each row of the
        # lm_head weight; Baichuan (v1) uses a plain linear head. Both share this architecture
        # class, so normalizing unconditionally would corrupt v1 logits. Distinguish the two
        # generations by vocabulary size (the same heuristic vLLM uses).
        normalize_lm_head = self.config.vocab_size == self._BAICHUAN2_VOCAB_SIZE

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.W_pack' in name:
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                elif 'lm_head' in name:
                    if normalize_lm_head:
                        loaded_weight = nn.functional.normalize(loaded_weight)
                    param = params_dict[name]
                    load_weight(param, loaded_weight)
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/chatglm2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding,
                                 build_rotary_params)
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
                                        build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model

# Token-type ids; by their names, presumably marking text tokens vs. vision tokens
# (confirm at their use sites).
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1


class SelfAttention(torch.nn.Module):
    """Self-attention layer for ChatGLM-style models.

    Uses a packed qkv projection. Rotary embedding is applied only to the first half of each head
    dimension, whose elements are stored interleaved. The output keeps the input's leading dims,
    and ``dense`` projects the last dim back to ``hidden_size``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)

        self.projection_size = config.kv_channels * config.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # equals config.kv_channels
        self.head_size = (self.projection_size // config.num_attention_heads)
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        self.query_key_value = build_qkv_proj(config.hidden_size,
                                              num_q_heads=self.num_attention_heads,
                                              num_kv_heads=self.num_kv_heads,
                                              head_size=self.head_size,
                                              bias=config.add_bias_linear or config.add_qkv_bias,
                                              quant_config=quantization_config,
                                              dtype=dtype,
                                              device=device,
                                              num_replicate_kv_heads=num_replicate_kv_heads)

        # apply rotary
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            self.num_attention_heads,
            self.head_size,
            num_kv_heads=self.num_kv_heads,
        )

        # o_proj
        self.dense = build_o_proj(self.projection_size,
                                  config.hidden_size,
                                  bias=config.add_bias_linear,
                                  quant_config=quantization_config,
                                  dtype=dtype,
                                  device=device,
                                  is_tp=True)

    @staticmethod
    def _extract_rope(states: torch.Tensor) -> torch.Tensor:
        """Extract rope.

        Return a contiguous copy of the first half of the last dim of ``states``, de-interleaved:
        ``[x0, x1, x2, x3, ...]`` becomes ``[x0, x2, ..., x1, x3, ...]`` (presumably the
        half-split layout ``ApplyRotaryEmb`` expects — confirm).
        """
        rope = states.chunk(2, -1)[0]
        # (..., n) -> (..., n/2, 2): consecutive pairs
        rope = rope.unflatten(-1, (-1, 2))
        # (..., 2, n/2) -> flatten: all even-index elements first, then all odd-index elements
        rope = rope.transpose(-2, -1).flatten(-2, -1).contiguous()
        return rope

    @staticmethod
    def _fill_rope(states: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        """Fill rope.

        Inverse of ``_extract_rope``: re-interleave ``rope`` and copy it into the first half of the
        last dim of ``states`` in place. Returns ``states``.
        """
        # chunk returns a view, so copy_ below writes into ``states`` itself
        rope_part = states.chunk(2, -1)[0]
        # (..., 2, n/2) -> (..., n/2, 2) -> flatten: back to [x0, x1, x2, x3, ...]
        rope = rope.unflatten(-1, (2, -1))
        rope = rope.transpose(-2, -1).flatten(-2, -1)
        rope_part.copy_(rope)
        return states

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward.

        ``past_key_value`` is indexed unconditionally despite its ``None`` default: entries 0/1 are
        the k/v caches, and when it holds more than two entries, 2/3 are the k/v ``scales_zeros``.
        """
        # qkv proj
        qkv_states = self.query_key_value(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        (query_states, key_states, value_states) = self.query_key_value.split_qkv(qkv_states)

        # apply rotary embedding
        # rotate contiguous de-interleaved copies of the rotary halves, then write them back
        cos, sin = rotary_pos_emb
        q_rope = self._extract_rope(query_states)
        k_rope = self._extract_rope(key_states)
        q_rope, k_rope = self.apply_rotary_pos_emb(
            q_rope,
            k_rope,
            cos,
            sin,
            inplace=True,
        )
        query_states = self._fill_rope(query_states, q_rope)
        key_states = self._fill_rope(key_states, k_rope)

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.dense(attn_output)
        return attn_output


class MLP(nn.Module):
    """ChatGLM feed-forward block: fused h->4h gate/up projection, SiLU-and-multiply, 4h->h projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)

        self.add_bias = config.add_bias_linear
        ffn_size = config.ffn_hidden_size

        # gate and up halves come out of one fused linear
        self.dense_h_to_4h = build_gateup_linear(
            config.hidden_size,
            [ffn_size, ffn_size],
            bias=self.add_bias,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.dense_4h_to_h = build_down_linear(
            ffn_size,
            config.hidden_size,
            bias=self.add_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Return ``dense_4h_to_h(act_fn(dense_h_to_4h(x)))``."""
        return self.dense_4h_to_h(self.act_fn(self.dense_h_to_4h(x)))


class GLMBlock(torch.nn.Module):
    """One ChatGLM transformer layer: pre-norm self-attention followed by a pre-norm gated MLP.

    The residual stream is passed through the RMS norms (called with ``(x, residual)``), and
    ``forward`` returns ``(mlp_output, residual)`` rather than their sum.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_number: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_number = layer_number
        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        # only the pre-norm residual layout is implemented
        assert not self.apply_residual_connection_post_layernorm

        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attention = SelfAttention(config, dtype=dtype, device=device)
        self.mlp = MLP(config, dtype=dtype, device=device)

        # norm before attention
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            config.layernorm_epsilon,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )
        # norm before the MLP
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            config.layernorm_epsilon,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run one layer and return ``(mlp_output, residual)``."""
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            # first layer: the raw input starts the residual stream
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attention_output = self.self_attention(
            hidden_states=normed,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attention_output, residual)
        return self.mlp(normed), residual


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` decoder layers with an optional final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.num_layers = config.num_layers
        self.post_layer_norm = config.post_layer_norm

        # GLMBlock receives a 1-based layer number.
        self.layers = torch.nn.ModuleList(
            GLMBlock(config, layer_no, dtype=dtype, device=device) for layer_no in range(1, self.num_layers + 1))

        if self.post_layer_norm:
            assert config.rmsnorm
            self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, dtype=dtype, device=device)

    def _get_layer(self, layer_number: int):
        """Return the decoder layer stored at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def forward(
        self,
        hidden_states: torch.LongTensor,
        rotary_pos_emb: List[torch.Tensor],
        past_key_values: Optional[List[torch.FloatTensor]],
        attn_metadata: Any,
    ):
        """Run all decoder layers; ``past_key_values[i]`` is the cache of layer ``i``."""
        residual = None
        for idx in range(self.num_layers):
            block = self._get_layer(idx)
            hidden_states, residual = block(
                hidden_states,
                rotary_pos_emb,
                past_key_value=past_key_values[idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        if not self.post_layer_norm:
            return hidden_states
        # final norm also merges the last residual into the output
        normed, _ = self.final_layernorm(hidden_states, residual)
        return normed


class Embedding(nn.Module):
    """Token embedding table of the GLM language model."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, dtype=dtype, device=device)
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        """Look up token embeddings.

        Unlike the upstream implementation, the result is not transposed to a
        sequence-first layout. When ``fp32_residual_connection`` is set, the
        embeddings are upcast to float32.
        """
        looked_up = self.word_embeddings(input_ids)
        return looked_up.float() if self.fp32_residual_connection else looked_up


class PatchEmbedding(nn.Module):
    """Turn images into patch tokens with a leading CLS token.

    A convolution whose kernel equals its stride cuts the image into
    non-overlapping ``patch_size`` x ``patch_size`` patches. A learned CLS
    embedding is prepended, and the full learned position table is added.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        width = config.hidden_size
        patch = config.patch_size
        self.proj = nn.Conv2d(config.in_channels,
                              width,
                              kernel_size=patch,
                              stride=patch,
                              dtype=dtype,
                              device=device)
        self.cls_embedding = nn.Parameter(torch.empty(1, width, dtype=dtype, device=device))
        self.position_embedding = nn.Embedding(config.num_positions, width, dtype=dtype, device=device)

    def forward(self, images):
        """Embed ``(N, C, H, W)`` images into ``(N, 1 + num_patches, hidden)`` tokens.

        The number of patches must equal ``num_positions - 1``, because the
        whole position table is added.
        """
        patch_tokens = self.proj(images).flatten(2).transpose(1, 2)
        batch = patch_tokens.shape[0]
        cls_tokens = self.cls_embedding.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens


class EVA2CLIPAttention(nn.Module):
    """Multi-head self-attention of the EVA2-CLIP vision tower.

    Q/K/V come from one packed projection with the same number of heads for
    all three. Attention uses ``F.scaled_dot_product_attention`` with no mask.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        width = config.hidden_size
        heads = config.num_heads
        per_head = config.hidden_size // config.num_heads
        self.scale = per_head**-0.5

        # packed q/k/v projection (MHA: kv heads == q heads)
        self.query_key_value = build_qkv_proj(
            width,
            num_q_heads=heads,
            num_kv_heads=heads,
            head_size=per_head,
            bias=True,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        # output projection
        self.dense = build_rowwise_linear(width,
                                          width,
                                          bias=True,
                                          quant_config=quant_cfg,
                                          dtype=dtype,
                                          device=device,
                                          is_tp=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``."""
        packed = self.query_key_value(hidden_states)
        # split_qkv presumably yields (batch, seq, heads, head_dim) views; swap
        # the seq and head axes for SDPA.
        q, k, v = (t.transpose(1, 2) for t in self.query_key_value.split_qkv(packed))

        context = F.scaled_dot_product_attention(q, k, v, scale=self.scale)

        # undo the axis swap and merge heads before the output projection
        merged = context.transpose(1, 2).flatten(-2, -1)
        return self.dense(merged)


class EVA2CLIPMLP(nn.Module):
    """Two-layer feed-forward network of the EVA2-CLIP vision tower."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN

        quant_cfg = getattr(config, 'quantization_config', None)
        width = config.hidden_size
        inner = config.intermediate_size

        # up projection (column parallel)
        self.fc1 = build_colwise_linear(
            width,
            inner,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # Every GELU flavour listed here is served by torch's exact nn.GELU.
        # NOTE(review): 'gelu_fast' / 'quick_gelu' are approximations in HF, so
        # outputs can differ slightly from the reference for those configs.
        gelu_names = ('gelu', 'gelu_fast', 'quick_gelu', 'gelu_python')
        act_name = config.hidden_act
        self.activation_fn = nn.GELU() if act_name in gelu_names else ACT2FN[act_name]

        # down projection (row parallel)
        self.fc2 = build_rowwise_linear(inner,
                                        width,
                                        bias=True,
                                        quant_config=quant_cfg,
                                        dtype=dtype,
                                        device=device,
                                        is_tp=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``fc2(act(fc1(x)))``."""
        return self.fc2(self.activation_fn(self.fc1(x)))


class EVA2CLIPTransformerLayer(nn.Module):
    """One EVA2-CLIP vision block.

    Each LayerNorm is applied to the *output* of its sub-layer (attention or
    MLP), and the result is then added to the sub-layer's input.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps, dtype=dtype, device=device)
        self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
        self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps, dtype=dtype, device=device)

    def forward(self, hidden_states):
        """Apply the attention and MLP sub-layers with residual adds."""
        after_attn = hidden_states + self.input_layernorm(self.attention(hidden_states))
        return after_attn + self.post_attention_layernorm(self.mlp(after_attn))


class EVA2CLIPTransformer(nn.Module):
    """Sequential stack of ``EVA2CLIPTransformerLayer`` modules."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        blocks = [
            EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        out = hidden_states
        for block in self.layers:
            out = block(out)
        return out


class GLU(nn.Module):
    """Gated projector from vision features into the language hidden size.

    The input is projected by ``linear_proj``, then passed through LayerNorm
    and GELU. The result goes through a SiLU-gated feed-forward,
    ``silu(gate_proj(h)) * dense_h_to_4h(h)``, which ``dense_4h_to_h``
    projects back.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 in_features: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        width = config.hidden_size
        ffn = config.ffn_hidden_size
        linear_kw = dict(bias=False, dtype=dtype, device=device)
        self.linear_proj = nn.Linear(in_features, width, **linear_kw)
        self.norm1 = nn.LayerNorm(width, dtype=dtype, device=device)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(width, ffn, **linear_kw)
        self.gate_proj = nn.Linear(width, ffn, **linear_kw)
        self.dense_4h_to_h = nn.Linear(ffn, width, **linear_kw)

    def forward(self, x):
        """Project ``x`` (last dim ``in_features``) to ``hidden_size``."""
        h = self.act1(self.norm1(self.linear_proj(x)))
        gated = self.act2(self.gate_proj(h)) * self.dense_h_to_4h(h)
        return self.dense_4h_to_h(gated)


@vlm_model
class EVA2CLIPModel(nn.Module):
    """GLM-4V vision tower: EVA2-CLIP encoder plus a projector into LM space."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from argparse import Namespace
        vision_cfg = Namespace(**config.vision_config)
        lm_width = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_cfg, dtype=dtype, device=device)
        self.transformer = EVA2CLIPTransformer(vision_cfg, dtype=dtype, device=device)
        self.linear_proj = GLU(config, in_features=lm_width, dtype=dtype, device=device)
        # A 2x2 / stride-2 conv merges each 2x2 neighbourhood of patch tokens
        # and maps the vision width to the LM width.
        self.conv = nn.Conv2d(in_channels=vision_cfg.hidden_size,
                              out_channels=lm_width,
                              kernel_size=2,
                              stride=2,
                              dtype=dtype,
                              device=device)
        # learned embeddings placed before / after the image tokens
        self.boi = nn.Parameter(torch.empty(1, 1, lm_width, dtype=dtype, device=device))
        self.eoi = nn.Parameter(torch.empty(1, 1, lm_width, dtype=dtype, device=device))
        self.scaling_factor = vision_cfg.scaling_factor

    def forward(self, images):
        """Encode images into ``(batch, (grid // 2) ** 2 + 2, hidden)`` features.

        ``grid`` is the side length of the square patch grid. The output is
        divided by ``scaling_factor``.
        """
        encoded = self.transformer(self.patch_embedding(images))
        patch_tokens = encoded[:, 1:]  # drop the CLS token

        batch, num_tokens, width = patch_tokens.shape
        side = int(num_tokens**0.5)
        feature_map = patch_tokens.view(batch, side, side, width).permute(0, 3, 1, 2)
        merged = self.conv(feature_map).flatten(2).transpose(1, 2)
        merged = self.linear_proj(merged)

        wrapped = torch.cat((self.boi.expand(batch, -1, -1), merged, self.eoi.expand(batch, -1, -1)), dim=1)
        return wrapped / self.scaling_factor


class ChatGLMModel(nn.Module):
    """GLM backbone.

    Holds the token embedding, the rotary embedding, the decoder stack and the
    output head. It also holds an EVA2-CLIP vision tower when the config
    carries a ``vision_config``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.embedding = Embedding(config, dtype=dtype, device=device)

        # Rotary embedding: built with half of the per-head rotary width.
        if config.kv_channels is None:
            rotary_dim = config.hidden_size // config.num_attention_heads
        else:
            rotary_dim = config.kv_channels
        rope_params = dict(emb_type=RopeType.LinearScaling,
                           dim=rotary_dim // 2,
                           max_position_embeddings=1 << 20,
                           base=10000 * getattr(config, 'rope_ratio', 1.0))
        rope_params.update(build_rotary_params(config))
        self.rotary_pos_emb = build_rotary_embedding(**rope_params)

        # decoder stack
        self.encoder = GLMTransformer(config, dtype=dtype, device=device)

        # LM head
        self.output_layer = build_rowwise_linear(config.hidden_size,
                                                 config.padded_vocab_size,
                                                 bias=False,
                                                 dtype=dtype,
                                                 device=device)

        self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) if hasattr(config, 'vision_config') else None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        images: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Return final hidden states.

        When ``inputs_embeds`` is not given, tokens are embedded here. If
        ``images`` are supplied, their vision features are scattered into the
        positions selected by ``image_mask``.
        """
        if inputs_embeds is None:
            vision_feats = None
            if images is not None:
                # (num_images, tokens_per_image, hidden) -> (1, total_image_tokens, hidden)
                vision_feats = self.vision(images).flatten(0, 1)[None]
            inputs_embeds = self.embedding(input_ids)
            if vision_feats is not None:
                inputs_embeds.masked_scatter_(image_mask[..., None], vision_feats)

        cos, sin = self.rotary_pos_emb(inputs_embeds, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        return self.encoder(
            inputs_embeds,
            rotary_pos_emb=rotary_pos_emb,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embedding


class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
    """ChatGLM / GLM-4V model wrapper used by the PyTorch engine.

    Wraps the ``ChatGLMModel`` backbone and implements the engine hooks:
    - building forward inputs from a ``StepContext``;
    - tracking the per-sequence position offset introduced by image tokens
      and rewriting ``position_ids`` accordingly;
    - loading checkpoint weights.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build Model
        self.transformer = ChatGLMModel(config, dtype=dtype, device=device)

        self.input_processor = ChatGLMInputProcessor(self.config, dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        images: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its hidden states.

        Logits are computed separately by ``get_logits``. Extra ``kwargs`` are
        accepted and ignored.
        """
        hidden_states = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            images=images,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.transformer.output_layer(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.transformer.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from ``context``.

        Image tensors of all sequences are stacked into one batch. The image
        mask marks tokens equal to the first image's ``image_token_id``.
        Precomputed ``context.input_embeddings`` are written into
        ``inputs_embeds`` at ``context.input_embedding_indexing``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        images = None
        image_mask = None
        if context.input_multimodals is not None:
            images = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            images = [data for im_data in images for data in im_data]
            if len(images) != 0:
                image_token_id = images[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                images = torch.stack([data.data for data in images])
            else:
                images = None
                image_mask = None

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            images=images,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )

    def _get_model_metas(self, context: StepContext):
        """Return one meta dict per sequence, defaulting missing entries to zero offset."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # Independent dicts: ``[d] * n`` would alias one dict n times.
            return [dict(num_img_tokens=0) for _ in range(batch_size)]
        return [dict(num_img_tokens=0) if meta is None else meta for meta in model_metas]

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Rewrite ``context.position_ids`` for image tokens and return new metas.

        An image occupying L token slots gets only three distinct positions:
        the first token at ``p``, the L-2 middle tokens all at ``p + 1``, and
        the last token at ``p + 2``. Each image therefore shrinks all later
        positions by ``L - 3``. The cumulative shrink per sequence is stored in
        the meta key ``num_img_tokens``.

        Returns:
            list[dict]: updated per-sequence metas. When the config has no
            ``vision_config``, the (defaulted) metas are returned and
            ``position_ids`` is left untouched.
        """
        model_metas = self._get_model_metas(context)
        if not hasattr(self.config, 'vision_config'):
            return model_metas

        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_imgs = [[] for _ in model_metas]
        else:
            input_imgs = [[] if mm is None else mm.get('image', []) for mm in input_multimodals]

        vision_config = self.config.vision_config
        image_size: int = vision_config['image_size']
        patch_size: int = vision_config['patch_size']
        merged_grid = image_size // patch_size // 2
        # image token slots: merged patches plus the boi/eoi features
        vision_token_num = merged_grid * merged_grid + 2
        # positions saved per image (L slots -> 3 positions)
        num_pos_saved = vision_token_num - 3

        batched_num_img_tokens = []
        new_model_metas = []
        for meta, imgs in zip(model_metas, input_imgs):
            num_img_tokens = 0 if meta is None else meta.get('num_img_tokens', 0)
            batched_num_img_tokens.append(num_img_tokens)
            new_model_metas.append(dict(num_img_tokens=num_img_tokens + num_pos_saved * len(imgs)))

        # prepare cogvlm-style position_ids
        q_seqlens = context.q_seqlens
        position_ids = context.position_ids

        if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):
            num_img_tokens = torch.tensor(batched_num_img_tokens, device=position_ids.device)
            if not context.is_decoding:
                # Prefill packs all sequences into (1, sum(q_seqlens)); expand the
                # per-sequence offsets to per-token offsets so they line up.
                num_img_tokens = num_img_tokens.repeat_interleave(
                    q_seqlens.to(device=position_ids.device, dtype=torch.long),
                    output_size=position_ids.size(-1),
                )
            position_ids -= num_img_tokens[None]
        else:
            batched_position_ids = position_ids[0].split(q_seqlens)
            for pos_ids, num_img_tok, imgs in zip(batched_position_ids, batched_num_img_tokens, input_imgs):
                pos_ids -= num_img_tok
                if len(imgs) == 0:
                    continue

                seq_len = pos_ids.size(0)
                start = pos_ids[0].cpu().item()
                new_pos_ids = []

                for img in sorted(imgs, key=lambda img: img.start):
                    # img.start/img.end are absolute offsets; shift by the offset accumulated so far.
                    img_pad_pos = img.start + 1 - num_img_tok
                    num_pad = img.end - img.start - 2
                    new_pos_ids += list(range(start, img_pad_pos))
                    new_pos_ids += [img_pad_pos] * num_pad
                    start = img_pad_pos + 1
                    # L = num_pad + 2 slots now span 3 positions, saving num_pad - 1
                    # (matches ``num_pos_saved`` used for the metas above).
                    num_img_tok += num_pad - 1

                remain = seq_len - len(new_pos_ids)
                new_pos_ids += list(range(start, start + remain))

                pos_ids[:] = pos_ids.new_tensor(new_pos_ids)

            position_ids = torch.cat(batched_position_ids)[None]
        context.position_ids = position_ids

        return new_model_metas

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model (adapted from vLLM)."""

        def _load_packed_qkv(param, loaded_weight):
            """Split a fused qkv tensor and load the q/k/v shards."""
            q, k, v = param.weight_spliter(loaded_weight)
            load_weight(param, q, shard_id='q')
            load_weight(param, k, shard_id='k')
            load_weight(param, v, shard_id='v')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'transformer.vision' in name:
                # In the vision tower only the qkv is packed; e.g. the projector's
                # dense_h_to_4h is a plain nn.Linear and must not be gate/up split.
                param = params_dict[name]
                if '.query_key_value' in name:
                    _load_packed_qkv(param, loaded_weight)
                else:
                    load_weight(param, loaded_weight)
                continue

            # rotary buffers stored in the checkpoint are not loaded
            if 'rotary_pos_emb.inv_freq' in name:
                continue
            if ('rotary_pos_emb.cos_cached' in name or 'rotary_pos_emb.sin_cached' in name):
                continue
            # NOTE(review): with tied embeddings the head weight is skipped, but no
            # tying happens in this class — confirm it is handled elsewhere.
            if (self.config.tie_word_embeddings and 'output_layer.weight' in name):
                continue

            param = params_dict[name]
            if '.query_key_value' in name:
                _load_packed_qkv(param, loaded_weight)
            elif '.dense_h_to_4h' in name:
                gate, up = param.weight_spliter(loaded_weight)
                load_weight(param, gate, shard_id=0)
                load_weight(param, up, shard_id=1)
            else:
                load_weight(param, loaded_weight)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class ChatGLMInputProcessor(BaseModelInputProcessor):
    """Input processor that packs GLM-4V image inputs into ``MultiModalData``."""

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

        if hasattr(config, 'vision_config'):
            vision_config = config.vision_config
            self.image_size = vision_config['image_size']
            self.patch_size = vision_config['patch_size']
            self.num_patches = (self.image_size // self.patch_size)**2
            # presumably +1 for the vision CLS token
            self.num_positions = self.num_patches + 1
            # patch count after the 2x2 merge (boi/eoi not included)
            self.vision_token_num = self.num_patches // 4

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Wrap each image input as ``MultiModalData``.

        Args:
            input_ids: prompt token ids; returned unchanged.
            input_multimodals: one dict per image. Each dict has the keys
                ``pixel_values``, ``offset`` (start of the image span),
                ``image_tokens`` (span length, int or 0-d tensor) and
                ``image_token_id``.

        Returns:
            PreprocessInputResult: ``input_multimodals`` is ``dict(image=[...])``,
            where each image covers ``[offset, offset + image_tokens)``. When
            there is no multimodal input, it is passed through unchanged.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # Return the declared result type even when there is nothing to process.
            return PreprocessInputResult(input_ids=input_ids, input_multimodals=input_multimodals)

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            num_pad = input_mm['image_tokens']
            image_token_id = input_mm['image_token_id']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/cogvlm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from argparse import Namespace
from typing import Any, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model


class VisionExpertAttention(nn.Module):
    """Self-attention with separate language and vision expert projections.

    Language and vision tokens share one attention computation (and KV cache)
    but are projected by different qkv / output linears. Token routing is
    chosen per call via ``lang_ids`` / ``vision_ids``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        # CogVLM2 configs are detected by the presence of `num_multi_query_heads`.
        is_cogvlm2 = hasattr(config, 'num_multi_query_heads')
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        self.hidden_size = hidden_size
        self.num_kv_heads = num_key_value_heads
        self.head_dim = head_dim

        # packed qkv
        # Both experts use the same head layout; only the vision expert has a
        # bias, and only for CogVLM2.
        self.vision_expert_query_key_value = build_qkv_proj(hidden_size,
                                                            num_q_heads=num_heads,
                                                            num_kv_heads=num_key_value_heads,
                                                            head_size=head_dim,
                                                            bias=is_cogvlm2,
                                                            quant_config=quantization_config,
                                                            dtype=dtype,
                                                            device=device,
                                                            num_replicate_kv_heads=num_replicate_kv_heads)
        self.language_expert_query_key_value = build_qkv_proj(hidden_size,
                                                              num_q_heads=num_heads,
                                                              num_kv_heads=num_key_value_heads,
                                                              head_size=head_dim,
                                                              bias=False,
                                                              quant_config=quantization_config,
                                                              dtype=dtype,
                                                              device=device,
                                                              num_replicate_kv_heads=num_replicate_kv_heads)

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
        )

        # o_proj
        # Built with all_reduce=False: the two experts' partial outputs are
        # merged first, then reduced once at the end of forward().
        self.vision_expert_dense = build_rowwise_linear(hidden_size,
                                                        hidden_size,
                                                        bias=False,
                                                        quant_config=quantization_config,
                                                        dtype=dtype,
                                                        device=device,
                                                        is_tp=True,
                                                        all_reduce=False)
        self.language_expert_dense = build_rowwise_linear(hidden_size,
                                                          hidden_size,
                                                          bias=False,
                                                          quant_config=quantization_config,
                                                          dtype=dtype,
                                                          device=device,
                                                          is_tp=True,
                                                          all_reduce=False)
        world_size, _ = get_tp_world_rank()
        self.world_size = world_size
        self.all_reduce = world_size > 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
        lang_ids: torch.LongTensor = None,
        vision_ids: torch.LongTensor = None,
    ):
        """Attention forward with per-token expert routing.

        Args:
            hidden_states: 3-D input, unpacked below as (bsz, seqlen, hidden).
            rotary_pos_emb: ``(cos, sin)`` passed to the rotary kernel.
            past_key_value: ``[k_cache, v_cache]``, optionally followed by the
                k/v scales-zeros (used when its length is not 2).
            attn_metadata: forwarded to ``Attention``.
            lang_ids: sequence-axis indices of tokens routed to the language
                expert.
            vision_ids: sequence-axis indices of tokens routed to the vision
                expert. When both are None, every token uses the language
                expert.

        Returns:
            Attention output with the same shape as ``hidden_states``; it is
            all-reduced across TP ranks when the world size is > 1.
        """
        bsz, seqlen, _ = hidden_states.size()
        # Per-rank widths of the q and k/v slices of the packed qkv output.
        # NOTE(review): assumes num_heads * head_dim == config.hidden_size and that
        # kv heads split evenly over ranks (no kv replication) — TODO confirm.
        hidden_size = self.hidden_size // self.world_size
        kv_size = self.num_kv_heads * self.head_dim // self.world_size

        # qkv proj
        if lang_ids is None and vision_ids is None:
            qkv_states = self.language_expert_query_key_value(hidden_states)
        else:
            # NOTE(review): new_empty is uninitialised; positions covered by
            # neither lang_ids nor vision_ids keep garbage values.
            qkv_states = hidden_states.new_empty(bsz, seqlen, hidden_size + kv_size * 2)
            if lang_ids is not None:
                qkv_states[:, lang_ids] = self.language_expert_query_key_value(hidden_states[:, lang_ids])
            if vision_ids is not None:
                qkv_states[:, vision_ids] = self.vision_expert_query_key_value(hidden_states[:, vision_ids])
        # Merge batch and sequence dims: (bsz * seqlen, qkv_width). Both experts
        # share one qkv layout, so the language projection's split_qkv is used
        # for the combined tensor.
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = \
            self.language_expert_query_key_value.split_qkv(qkv_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        # restore the (bsz, seqlen, ...) leading dims
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        if lang_ids is None and vision_ids is None:
            attn_output = self.language_expert_dense(attn_output)
        else:
            new_attn_output = torch.empty_like(hidden_states)
            if lang_ids is not None:
                new_attn_output[:, lang_ids] = self.language_expert_dense(attn_output[:, lang_ids])
            if vision_ids is not None:
                new_attn_output[:, vision_ids] = self.vision_expert_dense(attn_output[:, vision_ids])
            attn_output = new_attn_output

        # single cross-rank reduction for both experts' partial sums
        if self.all_reduce:
            dist.all_reduce(attn_output)
        return attn_output


class MLP(nn.Module):
    """Gated SiLU feed-forward block used by each CogVLM expert.

    Computes ``down_proj(silu(gate) * up)`` where ``gate`` and ``up`` come
    from one fused column-parallel projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        # The fused SiluAndMul activation below only implements the SiLU gate.
        assert config.hidden_act == 'silu'

        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # Fused gate + up projection (two output chunks of size `inter`).
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # silu(gate) * up, computed in place on the fused projection output.
        self.act_fn = SiluAndMul(inplace=True)

        # Row-parallel down projection without its own cross-rank reduction;
        # VisionExpertMLP performs a single all_reduce after merging experts.
        self.down_proj = build_rowwise_linear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
            all_reduce=False,
        )

    def forward(self, x):
        """Apply the gate/up projection, gated activation and down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class VisionExpertMLP(nn.Module):
    """Pair of MLP experts selected per token position.

    Language positions go through ``language_mlp`` and vision positions
    through ``vision_mlp``; results are written back into one output tensor.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.language_mlp = MLP(config, dtype=dtype, device=device)
        self.vision_mlp = MLP(config, dtype=dtype, device=device)
        tp_size = get_tp_world_rank()[0]
        # The experts' down projections skip reduction, so reduce once here.
        self.all_reduce = tp_size > 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        lang_ids: torch.LongTensor = None,
        vision_ids: torch.LongTensor = None,
    ):
        """Route positions to their expert and merge the outputs.

        Args:
            hidden_states: activations; positions are indexed along dim 1.
            lang_ids: positions processed by the language expert.
            vision_ids: positions processed by the vision expert.

        When neither index tensor is given, every position uses the language
        expert.
        """
        if lang_ids is None and vision_ids is None:
            out = self.language_mlp(hidden_states)
        else:
            out = torch.empty_like(hidden_states)
            routes = ((lang_ids, self.language_mlp), (vision_ids, self.vision_mlp))
            for ids, expert in routes:
                if ids is not None:
                    out[:, ids] = expert(hidden_states[:, ids])
        if self.all_reduce:
            dist.all_reduce(out)
        return out


class CogVLMDecoderLayer(nn.Module):
    """One CogVLM decoder block.

    Vision-expert self-attention followed by a vision-expert MLP, each
    preceded by an RMSNorm that also threads the running residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        def make_norm():
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        self.self_attn = VisionExpertAttention(config, dtype=dtype, device=device)
        self.mlp = VisionExpertMLP(config, dtype=dtype, device=device)
        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        lang_ids: torch.LongTensor = None,
        vision_ids: torch.LongTensor = None,
    ):
        """Run attention and MLP; return ``(hidden_states, residual)``.

        When ``residual`` is given, RMSNorm is called with it and returns
        ``(normed, new_residual)``. It presumably adds the residual before
        normalising; confirm in RMSNorm.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
            lang_ids=lang_ids,
            vision_ids=vision_ids,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed, lang_ids=lang_ids, vision_ids=vision_ids)
        return mlp_out, residual


class PatchEmbedding(nn.Module):
    """Convert images into patch embeddings with a leading class token.

    A stride-``patch_size`` convolution yields one embedding per patch; a
    learned class embedding is prepended and learned absolute position
    embeddings (``num_positions`` of them, one per output token) are added.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        patch = config.patch_size
        width = config.hidden_size
        self.proj = nn.Conv2d(config.in_channels,
                              width,
                              kernel_size=patch,
                              stride=patch,
                              dtype=dtype,
                              device=device)
        self.cls_embedding = nn.Parameter(torch.empty(1, width, dtype=dtype, device=device))
        self.position_embedding = nn.Embedding(config.num_positions, width, dtype=dtype, device=device)

    def forward(self, images):
        """Return ``(batch, 1 + num_patches, hidden_size)`` embeddings."""
        patches = self.proj(images)
        # (B, C, H', W') -> (B, H'*W', C)
        patches = patches.flatten(2).transpose(1, 2)
        batch = patches.shape[0]
        cls_tok = self.cls_embedding.expand(batch, -1, -1)
        tokens = torch.cat((cls_tok, patches), dim=1)
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens


class EVA2CLIPAttention(nn.Module):
    """Full (unmasked, uncached) multi-head self-attention of the vision tower.

    Uses a fused QKV projection, ``F.scaled_dot_product_attention`` with an
    explicit ``head_dim**-0.5`` scale, and a row-parallel output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        width = config.hidden_size
        n_heads = config.num_heads
        head_dim = width // n_heads
        self.scale = head_dim**-0.5

        # Fused q/k/v projection (same head count for q and kv).
        self.query_key_value = build_qkv_proj(
            width,
            num_q_heads=n_heads,
            num_kv_heads=n_heads,
            head_size=head_dim,
            bias=True,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        # Output projection.
        self.dense = build_rowwise_linear(width,
                                          width,
                                          bias=True,
                                          quant_config=quant_cfg,
                                          dtype=dtype,
                                          device=device,
                                          is_tp=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Self-attend over all positions of ``hidden_states``."""
        fused = self.query_key_value(hidden_states)
        # split_qkv presumably yields (batch, seq, heads, head_dim); SDPA wants
        # the head axis before the sequence axis.
        q, k, v = (t.transpose(1, 2) for t in self.query_key_value.split_qkv(fused))

        ctx = F.scaled_dot_product_attention(q, k, v, scale=self.scale)

        # Restore (batch, seq, heads, head_dim) and merge the head dimensions.
        ctx = ctx.transpose(1, 2).flatten(-2, -1)
        return self.dense(ctx)


class EVA2CLIPMLP(nn.Module):
    """Two-layer feed-forward block of the vision tower (fc1 -> act -> fc2)."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN

        quant_cfg = getattr(config, 'quantization_config', None)

        # Up projection, column-parallel.
        self.fc1 = build_colwise_linear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # GELU-family names use torch's exact nn.GELU; anything else comes from
        # transformers' ACT2FN table.
        # NOTE(review): 'quick_gelu' / 'gelu_fast' are approximations in
        # transformers, so mapping them to exact GELU may differ slightly.
        gelu_like = ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']
        self.activation_fn = nn.GELU() if config.hidden_act in gelu_like else ACT2FN[config.hidden_act]

        # Down projection, row-parallel.
        self.fc2 = build_rowwise_linear(config.intermediate_size,
                                        config.hidden_size,
                                        bias=True,
                                        quant_config=quant_cfg,
                                        dtype=dtype,
                                        device=device,
                                        is_tp=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2."""
        return self.fc2(self.activation_fn(self.fc1(x)))


class EVA2CLIPTransformerLayer(nn.Module):
    """Vision transformer block with post-sublayer normalisation.

    Despite their names, ``input_layernorm`` and ``post_attention_layernorm``
    normalise the *outputs* of attention and MLP respectively, before each
    residual addition.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        eps = config.layer_norm_eps
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps, dtype=dtype, device=device)
        self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)
        self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps, dtype=dtype, device=device)

    def forward(self, hidden_states):
        """Return ``h + norm2(mlp(h))`` with ``h = x + norm1(attn(x))``."""
        after_attn = hidden_states + self.input_layernorm(self.attention(hidden_states))
        return after_attn + self.post_attention_layernorm(self.mlp(after_attn))


class EVA2CLIPTransformer(nn.Module):
    """Stack of ``num_hidden_layers`` vision transformer blocks applied in order."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        blocks = [
            EVA2CLIPTransformerLayer(config, dtype=dtype, device=device)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every block sequentially."""
        out = hidden_states
        for block in self.layers:
            out = block(out)
        return out


class GLU(nn.Module):
    """Projector from vision features to the language hidden size.

    ``linear_proj`` -> LayerNorm -> GELU, then a SiLU-gated feed-forward:
    ``dense_4h_to_h(silu(gate_proj(h)) * dense_h_to_4h(h))``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 in_features: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        width = config.hidden_size
        ffn = config.intermediate_size

        def linear(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False, dtype=dtype, device=device)

        self.linear_proj = linear(in_features, width)
        self.norm1 = nn.LayerNorm(width, dtype=dtype, device=device)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = linear(width, ffn)
        self.gate_proj = linear(width, ffn)
        self.dense_4h_to_h = linear(ffn, width)

    def forward(self, x):
        """Project ``x`` (last dim ``in_features``) to ``hidden_size``."""
        h = self.act1(self.norm1(self.linear_proj(x)))
        gated = self.act2(self.gate_proj(h)) * self.dense_h_to_4h(h)
        return self.dense_4h_to_h(gated)


@vlm_model
class EVA2CLIPModel(nn.Module):
    """EVA2-CLIP vision tower plus projector.

    Produces per-image embeddings in the language hidden size, framed by
    learned begin-of-image (``boi``) and end-of-image (``eoi``) vectors.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        vis_cfg = Namespace(**config.vision_config)

        self.patch_embedding = PatchEmbedding(vis_cfg, dtype=dtype, device=device)
        self.transformer = EVA2CLIPTransformer(vis_cfg, dtype=dtype, device=device)
        self.linear_proj = GLU(config, in_features=vis_cfg.hidden_size, dtype=dtype, device=device)

        # 1226 positions identifies cogvlm-chat-hf, which has no downsampling;
        # cogvlm2 halves each spatial side with a 2x2 stride-2 convolution.
        is_cogvlm_chat_hf = vis_cfg.num_positions == 1226
        self.conv = None if is_cogvlm_chat_hf else nn.Conv2d(in_channels=vis_cfg.hidden_size,
                                                             out_channels=vis_cfg.hidden_size,
                                                             kernel_size=2,
                                                             stride=2,
                                                             dtype=dtype,
                                                             device=device)
        self.boi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))
        self.eoi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))

    def forward(self, images):
        """Return ``[boi, projected patch features..., eoi]`` per image."""
        feats = self.transformer(self.patch_embedding(images))

        # Drop the class token.
        feats = feats[:, 1:]

        if self.conv is not None:
            # cogvlm2: restore the square patch grid, downsample, re-flatten.
            bsz, seq, width = feats.shape
            side = int(seq**0.5)
            grid = feats.view(bsz, side, side, width).permute(0, 3, 1, 2)
            feats = self.conv(grid).flatten(2).transpose(1, 2)

        feats = self.linear_proj(feats)
        n = feats.shape[0]
        return torch.cat((self.boi.expand(n, -1, -1), feats, self.eoi.expand(n, -1, -1)), dim=1)


class CogVLMModel(nn.Module):
    """CogVLM backbone.

    Token embedding, EVA2-CLIP vision tower, a stack of vision-expert decoder
    layers and a final RMSNorm.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        n_layers = config.num_hidden_layers
        self.layers = nn.ModuleList(
            [CogVLMDecoderLayer(config, i, dtype=dtype, device=device) for i in range(n_layers)])

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)

        # Rotary table: linear-scaling type with a hard-coded 2048 max
        # position and base 10000 (not read from config).
        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_emb = build_rotary_embedding(
            head_dim,
            2048,
            10000,
            emb_type=RopeType.LinearScaling,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        images: torch.Tensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        lang_ids: torch.LongTensor = None,
        vision_ids: torch.LongTensor = None,
    ):
        """Embed inputs (splicing in image features), run the decoder stack
        and return the final normalised hidden states.

        NOTE(review): when ``inputs_embeds`` is None, ``vision_ids`` must only
        be passed together with ``images`` (as prepare_inputs_for_generation
        does); otherwise the image features are undefined.
        """
        if inputs_embeds is None:
            if images is not None:
                images_features = self.vision(images)
            inputs_embeds = self.embed_tokens(input_ids)
            if vision_ids is not None:
                # Overwrite the image placeholder rows of batch row 0.
                inputs_embeds[0, vision_ids] = images_features.flatten(0, 1)

        hidden_states = inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[idx],
                residual=residual,
                attn_metadata=attn_metadata,
                lang_ids=lang_ids,
                vision_ids=vision_ids,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


# Token-type ids distinguishing text tokens (0) from image tokens (1).
LANGUAGE_TOKEN_TYPE, VISION_TOKEN_TYPE = 0, 1


class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):
    """CogVLM causal language model with its vision tower.

    Wraps :class:`CogVLMModel` with an LM head, builds per-step forward inputs
    from the engine's ``StepContext`` and remaps position ids so that every
    image contributes only three positions (begin token, shared pad position,
    end token).
    """

    packed_modules_mapping = {
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # preprocessor
        self.input_processor = CogVLMInputProcessor(self.config, dtype)
        # build model
        self.model = CogVLMModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        images: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        lang_ids: torch.LongTensor = None,
        vision_ids: torch.LongTensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states (logits via get_logits)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            images=images,
            inputs_embeds=inputs_embeds,
            lang_ids=lang_ids,
            vision_ids=vision_ids,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for :meth:`forward` from ``context``.

        When the step carries images, ``vision_ids`` / ``lang_ids`` are the
        indices in ``input_ids[0]`` that equal / differ from the image token
        id; otherwise both stay None and every token uses the language expert.
        """
        input_ids = context.input_ids
        # NOTE(review): presumably already remapped by update_model_metas;
        # confirm the engine's call order.
        position_ids = context.position_ids
        lang_ids = None
        vis_ids = None

        # vision inputs
        images = None
        if context.input_multimodals is not None:
            # Flatten per-sequence image lists into one batch-level list.
            # Entries may be None for sequences without multimodal input
            # (update_model_metas tolerates the same).
            images = [
                data for input_mm in context.input_multimodals if input_mm is not None
                for data in input_mm.get('image', [])
            ]
            if len(images) == 0:
                images = None

        if images is not None:
            image_token_id = images[0].meta['image_token_id']
            vis_mask = input_ids[0] == image_token_id
            images = torch.stack([data.data for data in images])

            # split token indices into vision / language positions
            vis_range = torch.arange(0, input_ids.size(-1), device=input_ids.device)
            vis_ids = vis_range[vis_mask]
            lang_ids = vis_range[~vis_mask]

        attn_metadata = context.attn_metadata

        # process precomputed vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            images=images,
            inputs_embeds=inputs_embeds,
            lang_ids=lang_ids,
            vision_ids=vis_ids,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model parameters.

        Language-model ``gate_proj`` / ``up_proj`` weights are loaded as shards
        of the merged ``gate_up_proj``. Fused query/key/value weights (decoder
        experts and vision tower) are split into q/k/v shards. Everything else
        is loaded directly.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            # The vision projector (GLU) has its own plain-linear ``gate_proj``,
            # so vision weights must never be remapped onto gate_up_proj.
            stacked = None
            if '.vision.' not in name:
                stacked = next(
                    ((param_name, weight_name, shard_id)
                     for param_name, weight_name, shard_id in stacked_params_mapping
                     if weight_name in name),
                    None,
                )

            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                param = params_dict[name.replace(weight_name, param_name)]
                load_weight(param, loaded_weight, shard_id=shard_id)
            elif '_expert_query_key_value' in name or '.query_key_value' in name:
                param = params_dict[name]
                q, k, v = param.weight_spliter(loaded_weight)
                load_weight(param, q, shard_id='q')
                load_weight(param, k, shard_id='k')
                load_weight(param, v, shard_id='v')
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

    def _get_model_metas(self, context: StepContext):
        """Return one meta dict per sequence, defaulting to no image tokens."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # Independent dicts: ``[d] * n`` would alias one shared dict.
            return [dict(num_img_tokens=0) for _ in range(batch_size)]
        return [dict(num_img_tokens=0) if meta is None else meta for meta in model_metas]

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Update model metas and remap ``context.position_ids`` for images.

        Each image occupies ``vision_token_num`` tokens but only three
        positions: the begin token at ``p``, every pad token at ``p + 1`` and
        the end token at ``p + 2``. Positions after an image therefore shrink
        by ``vision_token_num - 3``. The per-sequence accumulated shrink is
        stored as ``num_img_tokens`` in the returned metas.

        Returns:
            list of dict: new model metas, one per sequence.
        """
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_imgs = [[] for _ in model_metas]
        else:
            input_imgs = [[] if mm is None else mm.get('image', []) for mm in input_multimodals]

        # Position shrink contributed by one image (see docstring).
        img_pos_shrink = self.input_processor.vision_token_num - 3

        batched_num_img_tokens = []
        new_model_metas = []
        for meta, imgs in zip(model_metas, input_imgs):
            num_img_tokens = 0 if meta is None else meta.get('num_img_tokens', 0)
            batched_num_img_tokens.append(num_img_tokens)
            new_model_metas.append(dict(num_img_tokens=num_img_tokens + img_pos_shrink * len(imgs)))

        # prepare cogvlm position_ids
        position_ids = context.position_ids

        if context.is_decoding:
            # One query token per sequence: shift each by its accumulated shrink.
            num_img_tokens = torch.tensor(batched_num_img_tokens, device=position_ids.device)
            position_ids -= num_img_tokens[None]
        else:
            # Prefill: position_ids[0] holds every sequence's tokens back to
            # back, so shifts are applied per sequence (a (1, batch) offset
            # cannot broadcast against (1, num_tokens) when batch > 1).
            batched_position_ids = position_ids[0].split(context.q_seqlens.tolist())
            for pos_ids, num_img_tok, imgs in zip(batched_position_ids, batched_num_img_tokens, input_imgs):
                pos_ids -= num_img_tok
                if len(imgs) == 0:
                    continue

                seq_len = pos_ids.size(0)
                start = pos_ids[0].item()
                new_pos_ids = []

                for img in sorted(imgs, key=lambda img: img.start):
                    img_pad_pos = img.start + 1 - num_img_tok
                    num_pads = img.end - img.start - 2
                    # tokens up to and including the begin-of-image token
                    new_pos_ids += list(range(start, img_pad_pos))
                    # all pad tokens share one position
                    new_pos_ids += [img_pad_pos] * num_pads
                    start = img_pad_pos + 1
                    # ``num_pads`` tokens collapsed onto a single position, so
                    # later positions shrink by num_pads - 1, consistent with
                    # ``img_pos_shrink`` used for the metas above.
                    num_img_tok += num_pads - 1

                remain = seq_len - len(new_pos_ids)
                new_pos_ids += list(range(start, start + remain))

                pos_ids[:] = pos_ids.new_tensor(new_pos_ids)

            position_ids = torch.cat(batched_position_ids)[None]
        context.position_ids = position_ids

        return new_model_metas

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class CogVLMInputProcessor(BaseModelInputProcessor):
    """Input processor that packs CogVLM image inputs into MultiModalData."""

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype
        image_size: int = config.vision_config['image_size']
        patch_size: int = config.vision_config['patch_size']
        if config.vision_config['num_positions'] == 1226:
            # cogvlm-chat-hf: one token per patch, plus begin/end tokens.
            self.vision_token_num = 2 + (image_size // patch_size)**2
        else:
            # cogvlm2: the vision tower's stride-2 conv halves each spatial
            # side, plus begin/end tokens.
            self.vision_token_num = 2 + (image_size // patch_size // 2)**2

    def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwargs) -> PreprocessInputResult:
        """Wrap each image input as MultiModalData spanning its placeholder tokens.

        Args:
            input_ids: prompt token ids.
            input_multimodals: per-image dicts with ``pixel_values``,
                ``offset`` (index of the first placeholder token),
                ``image_token_id`` and ``image_tokens`` (placeholder count).

        Returns:
            PreprocessInputResult. With no multimodal input, ``input_ids`` and
            ``input_multimodals`` are passed through unchanged.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # Return the declared result type (not a bare tuple) so both paths
            # expose ``.input_ids`` / ``.input_multimodals``.
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_tokens = input_mm['image_tokens']
            if isinstance(num_tokens, torch.Tensor):
                num_tokens = num_tokens.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_tokens,
                                     meta=dict(image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/deepseek.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class DeepseekAttention(nn.Module):
    """Deepseek self-attention.

    Fused QKV projection, rotary position embedding, cached attention over
    ``past_key_value`` and a row-parallel output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_q_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        width = config.hidden_size
        head_dim = getattr(config, 'head_dim', width // n_q_heads)

        # fused q/k/v projection
        self.qkv_proj = build_qkv_proj(
            width,
            num_q_heads=n_q_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_q_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
        )

        # output projection
        self.o_proj = build_rowwise_linear(n_q_heads * head_dim,
                                           width,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attend over the cache in ``past_key_value`` and project the result.

        ``past_key_value`` holds (k_cache, v_cache) and, when it has more than
        two entries, the k/v scales-zeros tensors at indices 2 and 3.
        """
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        k_cache, v_cache = past_key_value[0], past_key_value[1]
        k_scales_zeros = v_scales_zeros = None
        if len(past_key_value) != 2:
            k_scales_zeros, v_scales_zeros = past_key_value[2], past_key_value[3]

        out = self.attn_fwd(
            q,
            k,
            v,
            k_cache,
            v_cache,
            attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class DeepseekMoE(nn.Module):
    """Deepseek mixture-of-experts feed-forward block.

    A replicated linear router plus ``SoftmaxTopK`` selects experts for each
    token, the fused experts compute the routed output, an optional
    shared-experts MLP is added on top, and the sum is all-reduced across the
    tensor-parallel group when that group has more than one rank.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        # renormalize routing weights only with more than one expert per token
        # and only when the config requests it
        self.renormalize = self.top_k > 1 and self.norm_topk_prob

        # router: non-TP projection producing one logit per routed expert
        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )
        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=getattr(config, 'router_n_groups', -1))

        # routed experts; the cross-rank reduction is done once in forward
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            all_reduce=False,
        )

        n_shared = config.n_shared_experts
        if n_shared is None:
            self.shared_experts = None
        else:
            # shared experts are one wide MLP; its reduction is also deferred
            self.shared_experts = DeepseekMLP(
                config=config,
                intermediate_size=config.moe_intermediate_size * n_shared,
                dtype=dtype,
                device=device,
                is_tp=True,
                all_reduce=False,
            )

        tp_size, _ = get_tp_world_rank()
        self._all_reduce = tp_size > 1

    def forward(self, hidden_states: torch.Tensor):
        """Route a ``(batch, seq, hidden)`` input through the experts.

        Returns a tensor reshaped to ``(batch, seq, -1)``.
        """
        bsz, seq_len, hidden = hidden_states.shape
        flat = hidden_states.view(-1, hidden)

        weights, expert_ids = self.softmax_topk(self.gate(flat))
        out = self.experts(flat, weights, expert_ids)

        if self.shared_experts is not None:
            out += self.shared_experts(flat)
        out = out.reshape(bsz, seq_len, -1)

        if self._all_reduce:
            dist.all_reduce(out)
        return out


class DeepseekMLP(nn.Module):
    """Gated SiLU feed-forward network.

    Serves as the dense MLP of non-MoE decoder layers and as the shared-experts
    MLP inside ``DeepseekMoE``.
    """

    def __init__(self,
                 config: Any,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        inner = config.intermediate_size if intermediate_size is None else intermediate_size
        hidden = config.hidden_size

        # gate and up projections fused into one layer with two equal halves
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inner, inner],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=is_tp,
        )

        # silu(gate) * up
        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=all_reduce,
        )

    def forward(self, x):
        """Apply down(silu(gate(x)) * up(x))."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class DeepseekDecoderLayer(nn.Module):
    """Deepseek decoder layer.

    Norm -> self-attention -> norm -> MLP. The residual stream is carried
    through the two-argument norm calls, which return ``(normed, residual)``.
    The MLP is a ``DeepseekMoE`` on MoE layers and a dense ``DeepseekMLP``
    otherwise.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer
        self.self_attn = DeepseekAttention(config, dtype=dtype, device=device)

        # MoE only when routed experts exist, from layer `first_k_dense_replace`
        # onward, on every `moe_layer_freq`-th layer; dense MLP otherwise
        is_moe_layer = (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
                        and layer_idx % config.moe_layer_freq == 0)
        if is_moe_layer:
            self.mlp = DeepseekMoE(config, dtype=dtype, device=device)
        else:
            self.mlp = DeepseekMLP(config, dtype=dtype, device=device)

        # build input layer norm
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device)

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run one decoder layer.

        Returns ``(hidden_states, residual)``. When ``residual`` is None (first
        layer) the incoming ``hidden_states`` becomes the residual.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states, residual)
        return outputs


class DeepseekModel(nn.Module):
    """Deepseek decoder stack: token embedding, decoder layers, final norm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        # one decoder layer per configured hidden layer, in index order
        self.layers = nn.ModuleList(
            [DeepseekDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)])

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run the decoder stack and return the final normalized hidden states.

        ``inputs_embeds``, when given, replaces the embedding lookup of
        ``input_ids``; ``past_key_values[i]`` is handed to layer ``i``.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # rotary tables are computed once and shared by all layers
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class DeepseekForCausalLM(nn.Module, CudaGraphMixin):
    """Deepseek causal language model: ``DeepseekModel`` plus ``lm_head``.

    ``forward`` returns the final hidden states; logits are computed
    separately by ``get_logits``.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = DeepseekModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns hidden states (see ``get_logits`` for logits)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments of ``forward`` from the step context."""
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # scatter vision embeddings into the token embeddings at the given indices
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one expert weight into the fused expert parameter.

        The first mapping entry whose checkpoint substring occurs in ``name``
        renames it and loads with that entry's expert id and shard id. A name
        that matches no entry is loaded directly.
        """
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names onto the fused modules.

        ``q/k/v_proj`` load into ``qkv_proj`` and ``gate/up_proj`` into
        ``gate_up_proj``. Per-expert ``experts.{i}.{gate,up,down}_proj`` load
        into the fused ``experts.gate_up`` / ``experts.down`` parameters.
        Rotary caches are skipped, and so is ``lm_head.weight`` when
        ``tie_word_embeddings`` is set.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # n_routed_experts may be None (DeepseekDecoderLayer then builds only
        # dense MLPs); range(None) would raise, so treat it as zero experts.
        num_experts = self.config.n_routed_experts or 0
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # NOTE(review): lm_head.weight is skipped when tied, but this class
            # never points lm_head at embed_tokens; confirm the tie is applied elsewhere.
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/deepseek_mtp.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding,
                                 build_rotary_params)
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from lmdeploy.utils import get_logger

from .deepseek_v2 import DeepseekV2Attention, DeepseekV2DecoderLayer, MoEGate, yarn_get_mscale
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin

# Module-level logger obtained from lmdeploy's get_logger under the 'lmdeploy' name.
logger = get_logger('lmdeploy')


class DeepseekV2BMM(nn.Module):
    """Per-batch matmul against a ``(batch, in_features, out_features)`` weight.

    ``forward`` writes into a caller-provided buffer instead of returning a
    tensor. For ``x`` of shape ``(tokens, batch, in_features)``, the result
    ``(tokens, batch, out_features)`` is stored in ``output``.
    """

    def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):
        super().__init__()

        param = torch.nn.Parameter(
            self.create_weight(batch, in_features, out_features, dtype=dtype, device=device),
            requires_grad=False,
        )
        self.register_parameter('weight', param)
        # expose the loader hook on the parameter itself so weight-loading
        # code can find it (presumably via load_weight; confirm)
        param.weight_loader = self.weight_loader

        self.batch, self.in_features, self.out_features = batch, in_features, out_features
        self.dtype, self.device = dtype, device

    def create_weight(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):
        """Allocate an uninitialized weight of shape (batch, in_features, out_features)."""
        shape = (batch, in_features, out_features)
        return torch.empty(shape, dtype=dtype, device=device)

    def weight_loader(self, param: nn.Parameter, weight: torch.Tensor):
        """Copy ``weight`` into ``param`` in place."""
        param.data.copy_(weight)

    def forward(self, x: torch.Tensor, output: torch.Tensor):
        """Compute ``output[t, b] = x[t, b] @ weight[b]`` in place; returns None."""
        lhs = x.transpose(0, 1)
        dst = output.transpose(0, 1)
        torch.bmm(lhs, self.weight, out=dst)


class DeepseekV2Attention(DeepseekV2Attention):
    """Deepseekv2 multi-head latent attention used by the MTP decoder layer.

    Subclasses the ``DeepseekV2Attention`` imported from ``.deepseek_v2``.
    ``__init__`` is re-implemented, calling ``nn.Module.__init__`` directly
    rather than the parent's, and every projection is built with
    ``is_tp=False``. ``_qkv_proj`` is inherited from the parent.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        nn.Module.__init__(self)
        quantization_config = getattr(config, 'quantization_config', None)
        self.q_lora_rank = config.q_lora_rank
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        # the next three fall back to 1 / 1 / False when absent from the config
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
        use_flash_mla = getattr(config, 'use_flash_mla', False)

        # query path: one projection when q_lora_rank is None, otherwise a
        # low-rank down-projection -> RMSNorm -> up-projection
        if self.q_lora_rank is None:
            self.q_proj = build_colwise_linear(
                self.hidden_size,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=False,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )
        else:
            self.q_a_proj = build_colwise_linear(
                self.hidden_size,
                config.q_lora_rank,
                bias=config.attention_bias,
                dtype=dtype,
                device=device,
                is_tp=False,
                quant_config=quantization_config,
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank,
                                         1e-6,
                                         quant_config=quantization_config,
                                         dtype=dtype,
                                         device=device)
            self.q_b_proj = build_colwise_linear(
                config.q_lora_rank,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=False,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )

        # output width is kv_lora_rank (compressed kv) + qk_rope_head_dim (rope
        # key); the split presumably happens in the inherited _qkv_proj (confirm)
        self.kv_a_proj_with_mqa = build_colwise_linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=quantization_config,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
                                      1e-6,
                                      quant_config=quantization_config,
                                      dtype=dtype,
                                      device=device)
        # per-head (qk_nope_head_dim x kv_lora_rank) weights. Not used in this
        # class's forward; presumably consumed by the inherited _qkv_proj (TODO confirm)
        self.kc = DeepseekV2BMM(self.num_heads,
                                config.qk_nope_head_dim,
                                config.kv_lora_rank,
                                dtype=dtype,
                                device=device)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # base scale 1/sqrt(q_head_dim); multiplied by mscale**2 when the rope
        # scaling parameters set a non-zero mscale_all_dim
        self.softmax_scale = self.q_head_dim**(-0.5)

        rope_scaling = get_rope_parameters(config)
        if rope_scaling is not None:
            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
            scaling_factor = rope_scaling.get('factor', 1.0)
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

        # attention over the latent space: q/k head size is
        # kv_lora_rank + qk_rope_head_dim, v head size is kv_lora_rank
        self.attn_fwd = Attention(self.num_heads,
                                  config.kv_lora_rank + self.qk_rope_head_dim,
                                  scale=self.softmax_scale,
                                  num_kv_heads=num_key_value_heads,
                                  v_head_size=config.kv_lora_rank,
                                  num_replicate_kv_heads=num_replicate_kv_heads,
                                  use_flash_mla=use_flash_mla)

        # per-head (kv_lora_rank x v_head_dim) weights applied to the attention output
        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)
        self.o_proj = build_o_proj(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=quantization_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward.

        Builds latent queries/keys/values, applies the rotary embedding to the
        rope channels, attends using the cache in ``past_key_value``, then
        maps the result through ``vc`` and ``o_proj``.
        """
        num_heads = self.num_heads
        # channels [:kv_lora_rank] hold the latent part; the rope part is
        # written after this offset below
        nope_size = self.kv_lora_rank
        # assumes hidden_states is (1, seq_len, hidden); TODO confirm
        q_len = hidden_states.size(1)

        # qkv_proj
        query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj(hidden_states, num_heads=num_heads)

        cos, sin = rotary_pos_emb
        q_pe, k_pe = self.apply_rotary_pos_emb(
            q_pe,
            k_pe,
            cos,
            sin,
            inplace=False,
        )
        query_states[..., nope_size:] = q_pe
        key_states[..., nope_size:] = k_pe

        # the value cache passed in is a view of the first kv_lora_rank channels
        # of the key cache; scales/zeros are present only when past_key_value
        # has more than two entries
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[0][..., :nope_size],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)

        # project each head from the latent back to v_head_dim, then merge heads
        # and restore a leading dim of size 1 before o_proj
        self.vc(attn_output, attn_bmm_out)
        attn_output = attn_bmm_out.flatten(-2, -1)[None]
        attn_output = self.o_proj(attn_output)

        return attn_output


class DeepseekV2MoE(nn.Module):
    """Deepseek v2 mixture-of-experts block used inside the MTP decoder layer.

    ``MoEGate`` produces routing weights and expert ids. The fused experts run
    with ``renormalize=False`` and no all-reduce, and an optional
    shared-experts MLP is added to the routed output.
    """

    def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)

        # routing hyper-parameters copied from the config
        self.hidden_dim, self.ffn_dim = config.hidden_size, config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.renormalize = self.top_k > 1 and self.norm_topk_prob
        self.topk_method = config.topk_method
        self.n_group, self.topk_group = config.n_group, config.topk_group

        self.gate = MoEGate(config, dtype=dtype, device=device, info=None)
        # the fused experts do not renormalize here; presumably MoEGate
        # already produces final weights (confirm)
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=False,
            dtype=dtype,
            device=device,
            all_reduce=False,
            quant_config=quant_cfg,
            layer_idx=layer_idx,
        )

        n_shared = config.n_shared_experts
        if n_shared is None:
            self.shared_experts = None
        else:
            self.shared_experts = DeepseekV2MLP(
                config=config,
                intermediate_size=config.moe_intermediate_size * n_shared,
                dtype=dtype,
                device=device,
            )

    def forward(self, hidden_states: torch.Tensor):
        """Route a ``(batch, seq, hidden)`` input through the experts."""
        bsz, seq_len, hidden = hidden_states.shape
        flat = hidden_states.view(-1, hidden)
        weights, expert_ids = self.gate(flat)

        out = self.experts(flat, weights, expert_ids)
        if self.shared_experts is not None:
            out += self.shared_experts(flat)

        return out.reshape(bsz, seq_len, -1)


class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward for the MTP layer.

    Built with ``is_tp=False`` and no all-reduce, so the full weights live on
    every rank.
    """

    def __init__(self,
                 config: Any,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()

        quant_cfg = getattr(config, 'quantization_config', None)
        inner = config.intermediate_size if intermediate_size is None else intermediate_size
        hidden = config.hidden_size

        # gate and up projections fused into one layer with two equal halves
        self.gate_up_proj = build_gateup_linear(
            hidden,
            [inner, inner],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=False,
        )

        # silu(gate) * up
        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=False,
            all_reduce=False,
        )

    def forward(self, x):
        """Apply down(silu(gate(x)) * up(x))."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class DeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
    """Deepseekv2 decoder layer used as the MTP block.

    Only ``__init__`` is overridden, bypassing the parent's constructor, so
    that the MTP-local attention, MoE and MLP classes are used; ``forward``
    is inherited.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        nn.Module.__init__(self)
        self.layer_idx = layer_idx

        self.self_attn = DeepseekV2Attention(config, dtype=dtype, device=device)

        # MoE only when routed experts exist, from `first_k_dense_replace`
        # onward, on every `moe_layer_freq`-th layer; dense MLP otherwise
        is_moe_layer = (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
                        and layer_idx % config.moe_layer_freq == 0)
        if is_moe_layer:
            self.mlp = DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = DeepseekV2MLP(config, dtype=dtype, device=device)

        # the input norm is built without a quantization config here
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=None,
                                       dtype=dtype,
                                       device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)


# modify from vllm


class SharedHead(nn.Module):
    """Deepseekv2 shared head: a final RMSNorm and a vocabulary projection.

    ``forward`` applies only the norm; ``head`` is stored for callers and is
    not invoked here.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        hidden = config.hidden_size
        self.norm = RMSNorm(hidden, config.rms_norm_eps, dtype=dtype, device=device)
        self.head = build_rowwise_linear(hidden, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the RMS-normalized hidden states."""
        normed = self.norm(hidden_states)
        return normed


def build_deepseek_rotary_embedding(config: PretrainedConfig):
    """Build the rotary embedding for Deepseek models.

    The rotary dim is ``qk_rope_head_dim`` when ``use_mla`` is truthy (the
    default when the attribute is missing) and ``hidden_size //
    num_attention_heads`` otherwise. It starts from linear-scaling params and
    lets ``build_rotary_params(config)`` override them.
    """
    if getattr(config, 'use_mla', True):
        rope_dim = config.qk_rope_head_dim
    else:
        rope_dim = config.hidden_size // config.num_attention_heads

    rope_params = {
        'emb_type': RopeType.LinearScaling,
        'dim': rope_dim,
        'max_position_embeddings': config.max_position_embeddings,
        'base': get_rope_theta(config),
    }
    rope_params.update(build_rotary_params(config))
    return build_rotary_embedding(**rope_params)


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One DeepSeek multi-token-prediction (MTP) layer.

    The normalized token embedding and the normalized hidden state from the
    previous step are concatenated, projected back to ``hidden_size`` by
    ``eh_proj`` and passed through a single decoder block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        dtype: torch.dtype = None,
        device: torch.device = None,
        decoder_layer_cls=DeepseekV2DecoderLayer,
        build_rotary_embedding_func=build_deepseek_rotary_embedding,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        self.embed_tokens = nn.Embedding(config.vocab_size, hidden, self.padding_idx, dtype=dtype, device=device)

        self.enorm = RMSNorm(hidden, eps=eps, dtype=dtype, device=device)
        self.hnorm = RMSNorm(hidden, eps=eps, dtype=dtype, device=device)
        # maps concat([embedding, previous hidden]) of width 2 * hidden back to hidden
        self.eh_proj = build_colwise_linear(
            hidden * 2,
            hidden,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=getattr(config, 'quantization_config', None),
            dp_disable_tp=True,
        )

        self.shared_head = SharedHead(config=config, dtype=dtype, device=device)
        self.mtp_block = decoder_layer_cls(config, layer_idx=layer_idx, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_func(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        past_key_value: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Predict hidden states for the next token.

        ``spec_step_index`` is accepted for interface compatibility and not
        used here. Returns ``residual + hidden_states`` from the decoder block.
        """
        embeds = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        assert embeds is not None

        # masking inputs at position 0, as not needed by MTP; this writes into
        # the tensor passed as inputs_embeds when one is given
        embeds[position_ids == 0] = 0
        fused = torch.cat([self.enorm(embeds), self.hnorm(previous_hidden_states)], dim=-1)
        hidden_states = self.eh_proj(fused)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        block_out, residual = self.mtp_block(
            hidden_states,
            rotary_pos_emb,
            past_key_value,
            attn_metadata=attn_metadata,
        )
        return residual + block_out


class DeepSeekMultiTokenPredictor(nn.Module):
    """Collection of DeepSeek MTP layers.

    Layers live in a ``ModuleDict`` keyed by their checkpoint layer index,
    starting at ``config.num_hidden_layers``, so parameter paths line up with
    the checkpoint. Speculative step ``k`` is served by MTP layer
    ``k % config.num_nextn_predict_layers``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        dtype: torch.dtype = None,
        device: torch.device = None,
        decoder_layer_cls=DeepseekV2DecoderLayer,
        build_rotary_embedding_func=build_deepseek_rotary_embedding,
    ):
        super().__init__()
        self.config = config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        first_idx = self.mtp_start_layer_idx
        stop_idx = first_idx + self.num_mtp_layers
        mtp_layers = {}
        for layer_idx in range(first_idx, stop_idx):
            # string keys reproduce the exact layer index used by the weights
            mtp_layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                config,
                layer_idx,
                dtype=dtype,
                device=device,
                decoder_layer_cls=decoder_layer_cls,
                build_rotary_embedding_func=build_rotary_embedding_func,
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

    def _step_layer(self, spec_step_idx: int):
        """Return ``(step_idx, mtp_layer)`` for a speculative step, wrapping
        around the available MTP layers."""
        step_idx = spec_step_idx % self.num_mtp_layers
        return step_idx, self.layers[str(self.mtp_start_layer_idx + step_idx)]

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        ``past_key_values`` is indexed by the wrapped step index, not by the
        checkpoint layer index.
        """
        step_idx, mtp_layer = self._step_layer(spec_step_idx)
        return mtp_layer(
            input_ids,
            position_ids,
            previous_hidden_states,
            past_key_values[step_idx],
            inputs_embeds=inputs_embeds,
            attn_metadata=attn_metadata,
            spec_step_index=step_idx,
        )

    def get_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Project hidden states to logits with the selected layer's shared head."""
        _, mtp_layer = self._step_layer(spec_step_idx)
        shared_head = mtp_layer.shared_head
        normed = shared_head(hidden_states)
        return shared_head.head(normed)


class DeepseekMTPModel(nn.Module, CudaGraphMixin):
    """DeepSeek multi-token-prediction (MTP) draft model.

    Wraps :class:`DeepSeekMultiTokenPredictor`. It feeds the target model's
    hidden states (``target_hidden_states``) through the selected MTP layer,
    adds a cudagraph buffer for those hidden states, and loads only the MTP
    ("nextn") layers from a full-model checkpoint.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        decoder_layer_cls=DeepseekV2DecoderLayer,
        build_rotary_embedding_func=build_deepseek_rotary_embedding,
    ):
        super().__init__()
        self.config = config
        self.quantization_config = getattr(config, 'quantization_config', None)
        self.dtype = dtype
        self.ctx_mgr = ctx_mgr
        self.model = DeepSeekMultiTokenPredictor(config,
                                                 dtype=dtype,
                                                 device=device,
                                                 decoder_layer_cls=decoder_layer_cls,
                                                 build_rotary_embedding_func=build_rotary_embedding_func)

        # Staging area for blocked-fp8 ``kv_b_proj`` tensors: a weight and its
        # ``weight_scale_inv`` are held here until both halves have been seen.
        self._load_buffers = dict()

    def get_logits(self, hidden_states: torch.Tensor, spec_step_idx: int = 0):
        """Compute logits of the model output with the shared head of the MTP
        layer selected by ``spec_step_idx``."""
        return self.model.get_logits(hidden_states, spec_step_idx=spec_step_idx)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        target_hidden_states: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step.

        Args:
            input_ids: draft input token ids.
            position_ids: positions of ``input_ids``.
            target_hidden_states: hidden states from the target model, passed
                to the MTP layer as ``previous_hidden_states``.
            past_key_values: kv caches, indexed by the wrapped step index.
            attn_metadata: attention metadata of the current step.
            inputs_embeds: optional precomputed input embeddings.
            spec_step_idx: speculative step; selects the MTP layer.

        Returns:
            torch.Tensor: hidden states produced by the selected MTP layer.
        """
        hidden_states = self.model(input_ids,
                                   position_ids,
                                   target_hidden_states,
                                   inputs_embeds=inputs_embeds,
                                   past_key_values=past_key_values,
                                   attn_metadata=attn_metadata,
                                   spec_step_idx=spec_step_idx)
        return hidden_states

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs.

        Adds a zero-initialised ``target_hidden_states`` buffer of shape
        ``(1, max_tokens, hidden_size)`` to the mixin's buffers.
        """
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        # NOTE(review): if self.dtype is None, ``new_zeros`` falls back to the dtype of
        # ``input_ids`` (an integer tensor) — confirm dtype is always provided.
        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,
                                                                                     max_tokens,
                                                                                     self.config.hidden_size,
                                                                                     dtype=self.dtype)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: torch.Tensor, **kwargs):
        """Fill cudagraph buffers from forward inputs.

        Copies ``target_hidden_states`` into the first ``input_ids.size(-1)``
        token slots of the persistent buffer and returns the buffer as input.
        """

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, input_ids=input_ids, **kwargs)

        num_tokens = input_ids.size(-1)
        input_buffers = graph_meta.input_buffers
        target_hidden_states = kwargs.get('target_hidden_states')
        assert target_hidden_states is not None
        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states
        new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']
        return new_inputs

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input: collect forward kwargs from the step ``context``."""
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        target_hidden_states = context.target_hidden_states
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            target_hidden_states=target_hidden_states,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load a MoE expert weight.

        Per-expert checkpoint names are rewritten to the fused expert
        parameter and loaded with their ``expert_id``/``shard_id``. Names
        without a mapping are loaded directly.
        """
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                               update_pe_mapping: List):
        """Load an MLA attention weight.

        - Projections in ``update_pe_mapping`` get their rotary rows permuted;
          their ``weight_scale_inv`` tensors are loaded unchanged.
        - ``kv_b_proj`` is split into the ``kc``/``vc`` parameters. For
          ``fp8`` checkpoints it is first dequantised, which needs both the
          weight and its ``weight_scale_inv``.
        - Everything else is loaded directly.
        """
        device = next(iter(params_dict.values())).device

        def _update_pe(weight, head_dim: int, pe_dim_offset: int):
            # (num_heads, head_dim, input_dim)
            weight = weight.unflatten(0, (-1, head_dim))
            # rotary rows of every head: (num_heads, rope_dim, input_dim)
            w_pe = weight[:, pe_dim_offset:]
            # pair consecutive rows: (num_heads, rope_dim // 2, 2, input_dim)
            new_w_pe = w_pe.unflatten(1, (-1, 2))
            # even rows first, then odd rows: (num_heads, rope_dim, input_dim)
            # (presumably interleaved -> rotate-half rotary layout; confirm against the rotary kernel)
            new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2)
            weight[:, pe_dim_offset:] = new_w_pe
            weight = weight.flatten(0, 1)
            return weight

        def _load_kcvc(name: str, weight: torch.Tensor):
            """Split a kv_b_proj weight into the kc and vc parameters."""
            config = self.config
            v_head_dim = config.v_head_dim
            qk_nope_head_dim = config.qk_nope_head_dim
            w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split([qk_nope_head_dim, v_head_dim],
                                                                                        dim=1)
            w_vc = w_vc.transpose(1, 2).contiguous()
            kc_param_name = name.replace('.kv_b_proj', '.kc')
            param_kc = params_dict[kc_param_name]
            load_weight(param_kc, w_kc)
            vc_param_name = name.replace('.kv_b_proj', '.vc')
            param_vc = params_dict[vc_param_name]
            load_weight(param_vc, w_vc)

        def _dequant_weight(weight: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
            """Dequantize a 2-D block-quantized weight with a per-block scale."""
            dim_w0, dim_w1 = weight.shape
            dim_s0, dim_s1 = scale.shape
            assert dim_w0 % dim_s0 == 0
            assert dim_w1 % dim_s1 == 0
            group0 = dim_w0 // dim_s0
            group1 = dim_w1 // dim_s1
            weight = weight.reshape(dim_s0, group0, dim_s1, group1)
            scale = scale.reshape(dim_s0, 1, dim_s1, 1)
            weight = weight.to(scale.dtype) * scale
            weight = weight.to(dtype)
            weight = weight.reshape(dim_w0, dim_w1)
            return weight

        def _load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
            """Buffer one half of a blocked-fp8 kv_b_proj; once the weight and its
            ``weight_scale_inv`` are both present, dequantize and load kc/vc.

            Tensors may arrive in either order.
            """
            weight_suffix = '.weight'
            scale_suffix = '.weight_scale_inv'
            if name.endswith(weight_suffix):
                weight_name = name
                # the partner must use the name the scale is actually stored under
                scale_name = name[:-len(weight_suffix)] + scale_suffix
            elif name.endswith(scale_suffix):
                weight_name = name[:-len(scale_suffix)] + weight_suffix
                scale_name = name
            else:
                raise ValueError(f'Unexpected blocked fp8 kv_b_proj tensor: {name}')
            self._load_buffers[name] = loaded_weight
            if (weight_name in self._load_buffers and scale_name in self._load_buffers):
                weight = self._load_buffers.pop(weight_name)
                scale = self._load_buffers.pop(scale_name)
                kc_param_name = weight_name.replace('.kv_b_proj', '.kc')
                dtype = params_dict[kc_param_name].dtype
                weight = _dequant_weight(weight, scale, dtype)
                _load_kcvc(weight_name, weight)

        for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:
            if mod_name not in name:
                continue
            if name.endswith('.weight_scale_inv'):
                weight = loaded_weight
            else:
                loaded_weight = loaded_weight.to(device)
                weight = _update_pe(loaded_weight, head_dim, pe_dim_offset)
            param = params_dict[name]
            load_weight(param, weight)
            break
        else:
            if '.kv_b_proj' in name:
                quantization_config = self.quantization_config
                quant_method = None
                if quantization_config is not None:
                    quant_method = quantization_config.get('quant_method')

                loaded_weight = loaded_weight.to(device)
                if quant_method == 'fp8':
                    # update blocked fp8 weight
                    _load_kcvc_blocked_fp8(name, loaded_weight)
                else:
                    _load_kcvc(name, loaded_weight)
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Only tensors of the MTP layers are kept: layers
        ``num_hidden_layers .. num_hidden_layers + num_nextn_predict_layers - 1``.
        All other checkpoint tensors are skipped.
        """

        def _is_nextn_weight(name, nextn_keys):
            """Whether ``name`` belongs to one of the MTP (nextn) layers."""
            return any(nextn_key in name for nextn_key in nextn_keys)

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        config = self.config

        qk_rope_head_dim = config.qk_rope_head_dim
        kv_lora_rank = config.kv_lora_rank
        qk_nope_head_dim = config.qk_nope_head_dim
        q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        kv_dim = kv_lora_rank + qk_rope_head_dim
        # (module name, per-head dim, offset where the rotary rows start)
        update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), ('q_b_proj', q_head_dim, qk_nope_head_dim),
                             ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)]

        num_experts = self.config.n_routed_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        num_hidden_layers = self.config.num_hidden_layers

        num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1)
        # trailing '.' so e.g. '.layers.6.' cannot match '.layers.60.'
        nextn_keys = [f'.layers.{num_hidden_layers+i}.' for i in range(num_nextn_predict_layers)]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # keep nextn (MTP) weights only
            if not _is_nextn_weight(name, nextn_keys):
                continue
            if '.layers' in name:
                layer_idx = int(name.split('layers.')[1].split('.')[0])
                name = self._rewrite_spec_layer_name(layer_idx, name)
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            elif '.self_attn' in name and getattr(config, 'use_mla', True):
                # attention
                self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping)
            else:
                # other
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """Rewrite the weight name to match the format of the original model.

        Weights of the MTP-specific modules (embed_tokens, enorm, hnorm,
        eh_proj, shared_head) keep their name. Every other weight of the spec
        layer belongs to the transformer block and gets a ``.mtp_block``
        segment inserted after ``model.layers.{spec_layer}``.
        """
        spec_layer_weight_names = ['embed_tokens', 'enorm', 'hnorm', 'eh_proj', 'shared_head']
        spec_layer_weight = any(weight_name in name for weight_name in spec_layer_weight_names)
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
            name = name.replace(f'model.layers.{spec_layer}.', f'model.layers.{spec_layer}.mtp_block.')
        return name


================================================
FILE: lmdeploy/pytorch/models/deepseek_v2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from copy import deepcopy
from enum import Enum, auto
from os import getenv
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager, get_step_ctx_manager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, ParallelEmbedding, RMSNorm, RopeType, SiluAndMul,
                                 build_rotary_embedding, build_rotary_params)
from lmdeploy.pytorch.nn.eplb import EPLBDispatchInfo, EPLBManager
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.moe import MoeType, SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


# microbatch
class ExecType(Enum):
    """Micro-batch execution schedule consumed by ``execute_batch``.

    Values are spelled out explicitly (1..6, identical to ``auto()``).
    """
    # one batch, stepped until it finishes
    One = 1
    # two batches stepped alternately: 0, 1, 0, 1, ...
    Two0101 = 2
    # two batches, alternating the order each round: (0, 1), (1, 0), ...
    Two0110 = 3
    # two batches run back to back, the first fully before the second
    TwoLikeOne = 4
    # two prefill batches interleaved by a fixed stage pipeline
    TwoPrefill = 5
    # two decoding batches interleaved by a fixed stage pipeline
    TwoDecode = 6


class BatchWorker:
    """Step a generator one stage at a time and capture its return value.

    The generator signals completion by returning a non-None value. That
    value becomes ``output`` and marks the worker as done.
    """

    def __init__(self, tag: str, generator):
        self._tag = tag
        self._generator = generator
        self._count = 0
        self.output = None

    def next(self):
        """Advance the wrapped generator by one stage."""
        assert not self.done

        try:
            next(self._generator)
        except StopIteration as stop:
            # the generator finished; its return value is the batch output
            result = stop.value
            assert result is not None
            self.output = result

        self._count += 1

    @property
    def done(self):
        """True once the generator has returned its (non-None) output."""
        return self.output is not None


def execute_batch(inputs: list, fn, delta_stages: int = 0, exec_type: ExecType = ExecType.One, extern_tag: str = ''):
    """Run one or two micro batches, interleaving their stages per ``exec_type``.

    ``fn(**kwargs, tag=...)`` must return a generator. Each ``next`` runs one
    stage, and the generator's (non-None) return value is that batch's output
    (see ``BatchWorker``).

    Args:
        inputs (list): keyword-argument dicts, one per micro batch.
        fn: generator function implementing the staged computation.
        delta_stages (int): number of stages batch 0 runs ahead of batch 1
            before interleaving in ``Two0101``/``Two0110``. ``TwoPrefill`` and
            ``TwoDecode`` always use a fixed head start (2 and 3 stages) and
            only validate this value.
        exec_type (ExecType): interleaving schedule.
        extern_tag (str): suffix appended to each worker tag.

    Returns:
        list: outputs in the same order as ``inputs``.

    Raises:
        ValueError: if ``exec_type`` is not a supported schedule.
    """
    workers = [
        BatchWorker(str(idx), fn(**kwargs, tag=str(idx) + extern_tag)) for idx, kwargs in enumerate(inputs)
    ]

    def _run_to_end(worker: BatchWorker):
        """Step ``worker`` until its generator finishes."""
        while not worker.done:
            worker.next()

    def _run_pipeline(pipeline: List[str], warmup_stages: int):
        """Give batch 0 ``warmup_stages`` head-start stages, then step workers in
        ``pipeline`` order (each entry starts with its worker index) until batch
        0 finishes, and finally drain batch 1."""
        for _ in range(warmup_stages):
            workers[0].next()
        step = 0
        while not workers[0].done:
            workers[int(pipeline[step % len(pipeline)][0])].next()
            step += 1
        _run_to_end(workers[1])

    if exec_type == ExecType.One:
        assert len(inputs) == 1
        _run_to_end(workers[0])
    elif exec_type == ExecType.TwoLikeOne:
        assert len(inputs) == 2
        _run_to_end(workers[0])
        _run_to_end(workers[1])
    elif exec_type == ExecType.Two0101:
        assert len(inputs) == 2
        for _ in range(delta_stages):
            workers[0].next()
        while not workers[0].done:
            workers[0].next()
            workers[1].next()
        _run_to_end(workers[1])
    elif exec_type == ExecType.Two0110:
        assert len(inputs) == 2
        for _ in range(delta_stages):
            workers[0].next()
        step = 0
        while not workers[0].done:
            # even rounds step (0, 1); odd rounds step (1, 0)
            if step % 2 == 0:
                workers[0].next()
                workers[1].next()
            else:
                workers[1].next()
                workers[0].next()
            step += 1
        _run_to_end(workers[1])
    elif exec_type == ExecType.TwoPrefill:
        # before: A-attn0->A-attn1
        # roll:   B-attn0->B-attn1->A-dis->A-dis_wait->A-moe->B-dis->B-dis_wait->A-comb->
        #         B-moe->(A-share->A-comb_wait)->B-comb->A-attn0->A-attn1->(B-share->B-comb_wait)
        # after:  B-dis_wait->B-moe->B-comb->B-comb_wait and end
        assert len(inputs) == 2 and delta_stages in [0, 2]
        _run_pipeline([
            '1-attn0', '1-attn1', '0-dis', '0-dis_wait', '0-moe', '1-dis', '1-dis_wait', '0-comb', '1-moe',
            '0-share+0-comb_wait', '1-comb', '0-attn0', '0-attn1', '1-share+1-comb_wait'
        ],
                      warmup_stages=2)
    elif exec_type == ExecType.TwoDecode:
        # before: A-attn0->A-attn1->(A-dis->A-share)
        # roll:   B-attn0->A-dis_wait->A-moe->A-comb->B-attn1->A-comb_wait->(B-dis->B-share)->
        #         A-attn0->B-dis_wait->B-moe->B-comb->A-attn1->B-comb_wait->(A-dis->A-share)
        # after:  B-dis_wait->B-moe->B-comb->B-comb_wait and end
        assert len(inputs) == 2 and delta_stages in [0, 3]
        _run_pipeline([
            '1-attn0', '0-dis_wait', '0-moe', '0-comb', '1-attn1', '0-comb_wait', '1-dis+1-share', '0-attn0',
            '1-dis_wait', '1-moe', '1-comb', '0-attn1', '1-comb_wait', '0-dis+0-share'
        ],
                      warmup_stages=3)
    else:
        raise ValueError(f'Unsupported exec_type: {exec_type}')

    for worker in workers:
        assert worker.done
    return [worker.output for worker in workers]


def get_new_meta(attn_metadata, start_idx: int, end_idx: int):
    """Build attention metadata for the sequence range ``[start_idx, end_idx)``.

    The metadata is deep-copied. Per-sequence tensors are sliced, and start
    locations are rebased so the sub batch starts at 0. When the source
    metadata carries ``num_splits``, the flash-mla buffers are rebuilt for the
    new sub batch.
    """
    seq_range = slice(start_idx, end_idx)
    sub_meta = deepcopy(attn_metadata)

    sub_meta.block_offsets = attn_metadata.block_offsets[seq_range, ...]

    q_start_loc = attn_metadata.q_start_loc
    sub_meta.q_start_loc = q_start_loc[seq_range] - q_start_loc[start_idx]

    kv_start_loc = attn_metadata.kv_start_loc
    if kv_start_loc is None:
        sub_meta.kv_start_loc = None
    else:
        sub_meta.kv_start_loc = kv_start_loc[seq_range] - kv_start_loc[start_idx]

    sub_meta.q_seqlens = attn_metadata.q_seqlens[seq_range]

    kv_seqlens = attn_metadata.kv_seqlens
    sub_meta.kv_seqlens = None if kv_seqlens is None else kv_seqlens[seq_range]

    if attn_metadata.kv_flatten_size is None:
        sub_meta.kv_flatten_size = None
    else:
        sub_meta.kv_flatten_size = sum(sub_meta.kv_seqlens.tolist())

    # create buffers for flash mla
    if attn_metadata.num_splits is not None:
        num_heads = get_step_ctx_manager().current_context().model_config.num_attention_heads
        Attention.update_meta_flashmla(sub_meta, num_heads)
    return sub_meta


def get_new_rotary_pos_emb(rotary_pos_emb, start_loc, end_loc):
    """Slice the ``(cos, sin)`` pair to the token range ``[start_loc, end_loc)``.

    Both slices are made contiguous and returned as a new tuple.
    """
    token_range = slice(start_loc, end_loc)
    cos_part = rotary_pos_emb[0][token_range, ...].contiguous()
    sin_part = rotary_pos_emb[1][token_range, ...].contiguous()
    return cos_part, sin_part


def get_new_input(hidden_states, rotary_pos_emb, past_key_values, residual, attn_metadata, start_idx, end_idx,
                  start_loc, end_loc):
    """Cut the forward inputs down to one sub batch.

    Token-indexed tensors (dim 1 of ``hidden_states``/``residual`` and the
    rotary embeddings) are sliced to ``[start_loc, end_loc)``. The attention
    metadata is restricted to sequences ``[start_idx, end_idx)``. The kv caches
    are shared and passed through unchanged.

    Returns:
        tuple: (hidden_states, rotary_pos_emb, past_key_values, residual, attn_metadata)
    """
    token_range = slice(start_loc, end_loc)
    sub_hidden = hidden_states[:, token_range, :].contiguous()
    sub_rotary = get_new_rotary_pos_emb(rotary_pos_emb, start_loc, end_loc)
    sub_residual = None if residual is None else residual[:, token_range, :].contiguous()
    sub_meta = get_new_meta(attn_metadata, start_idx, end_idx)
    return sub_hidden, sub_rotary, past_key_values, sub_residual, sub_meta


def get_split_flags(attn_metadata, num=2):
    """Split flags for seqlens and startloc, support 2 only.

    Decoding batches are halved by sequence count; their token range equals
    the sequence range, i.e. one query token per sequence is assumed.
    Prefill batches are cut at the sequence boundary that best balances the
    number of query tokens on both sides. On ties the earliest boundary wins,
    and boundary 1 is used if no boundary beats the total length. Prefill
    requires at least two sequences.

    Args:
        attn_metadata: metadata exposing ``is_decoding``, ``q_start_loc`` and
            ``q_seqlens``.
        num (int): number of splits; must be 2.

    Returns:
        list[dict]: two dicts with ``start_idx``/``end_idx`` (sequence range)
        and ``start_loc``/``end_loc`` (token range).
    """
    assert num == 2
    if attn_metadata.is_decoding:
        batch_size = attn_metadata.q_start_loc.numel()
        half = batch_size // 2
        flag_a = {
            'start_idx': 0,
            'end_idx': half,
            'start_loc': 0,
            'end_loc': half,
        }
        flag_b = {
            'start_idx': half,
            'end_idx': batch_size,
            'start_loc': half,
            'end_loc': batch_size,
        }
    else:
        q_start_loc = attn_metadata.q_start_loc.tolist()
        q_seqlens = attn_metadata.q_seqlens.tolist()
        total_len = sum(q_seqlens)
        min_diff = total_len
        split_flag = 1
        # running prefix sum keeps the search linear instead of re-summing both halves per candidate
        left_len = 0
        for idx in range(1, len(q_seqlens)):
            left_len += q_seqlens[idx - 1]
            # |left - right| with right = total - left
            diff = abs(2 * left_len - total_len)
            if diff < min_diff:
                min_diff = diff
                split_flag = idx
        flag_a = {
            'start_idx': 0,
            'end_idx': split_flag,
            'start_loc': q_start_loc[0],
            'end_loc': q_start_loc[split_flag],
        }
        flag_b = {
            'start_idx': split_flag,
            'end_idx': len(q_seqlens),
            'start_loc': q_start_loc[split_flag],
            'end_loc': q_start_loc[-1] + q_seqlens[-1],
        }
    return [flag_a, flag_b]


def split_input(hidden_states,
                rotary_pos_emb,
                past_key_values,
                residual,
                attn_metadata,
                moe_start_idx,
                moe_end_idx,
                num=2):
    """Split input, support 1 or 2 only.

    Returns:
        tuple: ``(inputs, exec_type, delta_stages, extern_tag)``, where
        ``inputs`` is a list of keyword dicts (one per micro batch) and
        ``extern_tag`` is ``'D'`` for decoding, ``'P'`` for prefill.
    """

    def _pack(hs, rotary, res, meta):
        # kv caches and the moe layer range are shared by every micro batch
        return {
            'hidden_states': hs,
            'rotary_pos_emb': rotary,
            'past_key_values': past_key_values,
            'residual': res,
            'attn_metadata': meta,
            'start_idx': moe_start_idx,
            'end_idx': moe_end_idx
        }

    if num == 1:
        # one batch
        tag = 'D' if attn_metadata.is_decoding else 'P'
        return [_pack(hidden_states, rotary_pos_emb, residual, attn_metadata)], ExecType.One, 0, tag

    # two batch or more
    micro_batches = []
    for flag in get_split_flags(attn_metadata, num=num):
        sub_hidden, sub_rotary, _, sub_residual, sub_meta = get_new_input(hidden_states, rotary_pos_emb,
                                                                          past_key_values, residual, attn_metadata,
                                                                          flag['start_idx'], flag['end_idx'],
                                                                          flag['start_loc'], flag['end_loc'])
        micro_batches.append(_pack(sub_hidden, sub_rotary, sub_residual, sub_meta))

    if attn_metadata.is_decoding:
        return micro_batches, ExecType.TwoDecode, 0, 'D'
    return micro_batches, ExecType.TwoPrefill, 0, 'P'


def merge_output(output_list):
    """Merge per-micro-batch ``(hidden_states, residual)`` outputs.

    A single output is returned as-is. Otherwise the tensors are concatenated
    along dim 1. The residual is ``None`` when the first output's residual is
    ``None``.
    """
    if len(output_list) == 1:
        return output_list[0]

    merged_hidden = torch.concat([out[0] for out in output_list], dim=1)
    if output_list[0][1] is None:
        return merged_hidden, None
    merged_residual = torch.concat([out[1] for out in output_list], dim=1)
    return merged_hidden, merged_residual


def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention magnitude scale.

    Returns 1.0 when ``scale <= 1``, otherwise ``0.1 * mscale * ln(scale) + 1``.
    """
    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)


class DeepseekV2BMM(nn.Module):
    """Wrapped bmm.

    Holds a weight of shape ``(batch, in_features, out_features)`` whose
    batch dim is divided by the attention tensor-parallel world size. The
    ``weight_loader`` takes this rank's chunk of the full weight along dim 0.
    """

    def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):
        super().__init__()
        local_batch = self._update_batch(batch)

        param = torch.nn.Parameter(
            self.create_weight(local_batch, in_features, out_features, dtype=dtype, device=device),
            requires_grad=False,
        )
        self.register_parameter('weight', param)
        param.weight_loader = self.weight_loader

        self.batch = local_batch
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.device = device

    def _update_batch(self, batch: int):
        """Return the per-rank batch size (``batch // tp_world_size``)."""
        world_size, _ = get_tp_world_rank('attn')
        return batch // world_size

    def create_weight(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):
        """Allocate an uninitialised ``(batch, in_features, out_features)`` weight."""
        shape = (batch, in_features, out_features)
        return torch.empty(shape, dtype=dtype, device=device)

    def weight_loader(self, param: nn.Parameter, weight: torch.Tensor):
        """Copy this rank's chunk (along dim 0) of ``weight`` into ``param``."""
        world_size, rank = get_tp_world_rank('attn')
        shard = weight.chunk(world_size, 0)[rank]
        param.data.copy_(shard)

    def forward(self, x: torch.Tensor, output: torch.Tensor):
        """Write ``x[:, b] @ weight[b]`` for every batch ``b`` into ``output``.

        Both ``x`` and ``output`` have the bmm batch on dim 1; they are viewed
        with dims 0/1 swapped so ``torch.bmm`` writes straight into ``output``.
        Returns None.
        """
        x_by_batch = x.transpose(0, 1)
        out_by_batch = output.transpose(0, 1)
        torch.bmm(x_by_batch, self.weight, out=out_by_batch)


class DeepseekV2Attention(nn.Module):
    """Deepseekv2 attention (multi-head latent attention).

    Pipeline:

    * Queries come from ``q_proj``, or from ``q_a_proj`` -> ``q_a_layernorm`` ->
      ``q_b_proj`` when ``q_lora_rank`` is set.
    * The no-position part of each query head is mapped by ``self.kc`` into
      the ``kv_lora_rank`` latent space.
    * Keys are the latent ``kv_a_proj_with_mqa`` output. Its first
      ``kv_lora_rank`` channels (after ``kv_a_layernorm``) double as the values.
    * Attention runs in latent space, then ``self.vc`` maps each head back to
      ``v_head_dim`` before ``o_proj``.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        self.q_lora_rank = config.q_lora_rank
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # per-head query width before the kc projection: no-position part + rotary part
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
        use_flash_mla = getattr(config, 'use_flash_mla', False)

        if self.q_lora_rank is None:
            # direct query projection (no low-rank bottleneck)
            self.q_proj = build_colwise_linear(
                self.hidden_size,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )
        else:
            # low-rank query path: down-project (not TP-split), normalise, up-project per head
            self.q_a_proj = build_colwise_linear(
                self.hidden_size,
                config.q_lora_rank,
                bias=config.attention_bias,
                dtype=dtype,
                device=device,
                is_tp=False,
                quant_config=quantization_config,
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank,
                                         1e-6,
                                         quant_config=quantization_config,
                                         dtype=dtype,
                                         device=device)
            self.q_b_proj = build_colwise_linear(
                config.q_lora_rank,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )

        # single projection producing the latent kv (kv_lora_rank) plus the shared rotary key part
        self.kv_a_proj_with_mqa = build_colwise_linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=quantization_config,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
                                      1e-6,
                                      quant_config=quantization_config,
                                      dtype=dtype,
                                      device=device)
        # per-head (qk_nope_head_dim -> kv_lora_rank) map applied to the no-position query part
        self.kc = DeepseekV2BMM(self.num_heads,
                                config.qk_nope_head_dim,
                                config.kv_lora_rank,
                                dtype=dtype,
                                device=device)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # the scale is based on the original q_head_dim, not the latent head size used below
        self.softmax_scale = self.q_head_dim**(-0.5)

        rope_scaling = get_rope_parameters(config)
        if rope_scaling is not None:
            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
            scaling_factor = rope_scaling.get('factor', 1.0)
            if mscale_all_dim:
                # YaRN: the magnitude correction is applied twice (to q and k), hence squared
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

        # attention operates in latent space: qk head = kv_lora_rank + rope dim, v head = kv_lora_rank
        self.attn_fwd = Attention(self.num_heads,
                                  config.kv_lora_rank + self.qk_rope_head_dim,
                                  scale=self.softmax_scale,
                                  num_kv_heads=num_key_value_heads,
                                  v_head_size=config.kv_lora_rank,
                                  num_replicate_kv_heads=num_replicate_kv_heads,
                                  use_flash_mla=use_flash_mla)

        # per-head (kv_lora_rank -> v_head_dim) map applied to the latent attention output
        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)
        self.o_proj = build_o_proj(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=True,
            quant_config=quantization_config,
        )

    def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int):
        """Q proj.

        Returns ``(query_states, q_pe)``:

        * ``query_states`` is a fresh ``(q_len, num_heads, nope_size + pe_size)``
          buffer. Its first ``nope_size`` channels are filled by ``self.kc``; the
          trailing ``pe_size`` channels are left uninitialised for the caller to
          fill after rotary embedding.
        * ``q_pe`` is the un-rotated rotary part of the query.

        ``hidden_states`` must have a leading dim of 1: the ``view`` below folds
        the whole tensor into ``(q_len, num_heads, q_head_dim)``.
        """
        q_len = hidden_states.size(1)

        query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size)

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(q_len, num_heads, self.q_head_dim)
        # q_pe: (q_len, num_heads, qk_rope_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # q_nope: (q_len, num_heads, qk_nope_head_dim); kc writes its projection into the
        # nope_size slot of query_states through this view
        q_nope_out = query_states[..., :nope_size]
        self.kc(q_nope, q_nope_out)
        return query_states, q_pe

    def _kv_proj(self, hidden_states, nope_size: int):
        """Kv proj.

        Returns ``(key_states, value_states, k_pe)``:

        * ``key_states`` is ``(q_len, 1, nope_size + pe_size)``; its first
          ``nope_size`` channels are overwritten with the normalised latent.
        * ``value_states`` is that normalised latent.
        * ``k_pe`` is a view of the trailing (un-rotated) rotary channels.

        Only batch row 0 of ``hidden_states`` is used.
        """
        # (q_len, 1, nope_size + pe_size)
        key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None])
        # (q_len, 1, pe_size)
        k_pe = key_states[..., nope_size:]
        # kv_a_layernorm
        value_states = key_states[..., :nope_size]
        value_states = self.kv_a_layernorm(value_states)
        # the key's latent part is the same normalised tensor used as the value
        key_states[..., :nope_size] = value_states
        return key_states, value_states, k_pe

    def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int):
        """Qkv proj.

        Runs ``_q_proj`` and ``_kv_proj`` with ``nope_size = kv_lora_rank`` and
        ``pe_size = qk_rope_head_dim``. Returns
        ``(query_states, key_states, value_states, q_pe, k_pe)``.
        """
        nope_size = self.kv_lora_rank
        pe_size = self.qk_rope_head_dim
        query_states, q_pe = self._q_proj(hidden_states, num_heads, nope_size, pe_size)
        key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size)

        return query_states, key_states, value_states, q_pe, k_pe

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run MLA attention and return the ``o_proj`` output.

        ``past_key_value[0]`` is the key cache. The value cache passed to the
        kernel is a slice of it (its first ``kv_lora_rank`` channels), not a
        separate tensor. When ``past_key_value`` has more than two entries,
        items 2 and 3 are forwarded as k/v scales-zeros. The output has a
        leading dim of 1 re-added (``[None]``).
        """
        dist_config = get_dist_manager().current_config()
        # with data parallelism every rank keeps all heads; otherwise heads are split across ranks
        if dist_config.dp > 1:
            num_heads = self.num_heads
        else:
            world_size = dist_config.world_size
            num_heads = self.num_heads // world_size
        nope_size = self.kv_lora_rank
        q_len = hidden_states.size(1)

        # qkv_proj
        query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj(hidden_states, num_heads=num_heads)

        cos, sin = rotary_pos_emb
        q_pe, k_pe = self.apply_rotary_pos_emb(
            q_pe,
            k_pe,
            cos,
            sin,
            inplace=False,
        )
        # place the rotated parts into the trailing pe_size channels of q and k
        query_states[..., nope_size:] = q_pe
        key_states[..., nope_size:] = k_pe

        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[0][..., :nope_size],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        # presumably attn_output is (q_len, num_heads, kv_lora_rank) — TODO confirm against Attention
        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)

        # latent -> v_head_dim per head, then flatten heads and restore the leading batch dim
        self.vc(attn_output, attn_bmm_out)
        attn_output = attn_bmm_out.flatten(-2, -1)[None]
        attn_output = self.o_proj(attn_output)

        return attn_output


class MoEGate(nn.Module):
    """Deepseek Gate.

    Projects hidden states onto per-expert router logits and selects
    ``top_k`` experts per token. The selection strategy is ``topk_method``:
    'greedy', 'group_limited_greedy' or 'noaux_tc'. ``forward`` returns
    ``(topk_weight, topk_idx)``. When EPLB dispatch info is supplied, the
    logical expert ids are remapped to physical ones.
    """

    def __init__(self,
                 config: Any,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 info: EPLBDispatchInfo = None):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        # renormalise the selected weights only when more than one expert is picked
        self.renormalize = self.top_k > 1 and self.norm_topk_prob
        # default -1 (presumably "no router groups"); note that ``x % -1 == 0`` in Python
        self.router_n_groups = getattr(config, 'router_n_groups', -1)
        assert self.top_k % self.router_n_groups == 0, f'{self.top_k} cannot be divided by {self.router_n_groups}'
        self.gating_dim = config.hidden_size
        # router weights are always float32, independent of ``dtype``
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim), dtype=torch.float32, device=device))
        if self.topk_method == 'noaux_tc':
            from lmdeploy.pytorch.nn.moe.route import NoauxTCRouter
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts, ), dtype=torch.float32, device=device))
            self.noaux_tc_router = NoauxTCRouter(self.scoring_func,
                                                 top_k=self.top_k,
                                                 n_group=self.n_group,
                                                 topk_group=self.topk_group,
                                                 n_routed_experts=self.n_routed_experts,
                                                 routed_scaling_factor=self.routed_scaling_factor,
                                                 renormalize=self.renormalize,
                                                 router_n_groups=self.router_n_groups)
        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=self.router_n_groups)
        self.fake_eplb = getenv('LMDEPLOY_FAKE_EPLB', 'False').lower() == 'true'
        self.eplb_dispatch_info = info

    def _compute_scores(self, logits: torch.Tensor):
        """Compute scores.

        Applies ``scoring_func`` ('softmax' in float32, or 'sigmoid') to
        ``logits``. Raises ``NotImplementedError`` for any other function.
        """
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        elif self.scoring_func == 'sigmoid':
            scores = logits.sigmoid()
        else:
            raise NotImplementedError('unsupported scoring function '
                                      f'for MoE gating: {self.scoring_func}')
        return scores

    def _postprocess_topk_weight(self, topk_weight: torch.Tensor):
        """Either renormalise the selected weights to sum to 1, or scale them by ``routed_scaling_factor``."""
        if self.renormalize:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
            if not topk_weight.is_contiguous():
                topk_weight = topk_weight.contiguous()
        else:
            topk_weight = topk_weight * self.routed_scaling_factor
        return topk_weight

    def _group_limited_topk(self, router_logits: torch.Tensor):
        """Pick ``top_k`` experts restricted to the ``topk_group`` best expert groups.

        Mirrors the reference DeepSeek-V2 ``group_limited_greedy`` gate:

        * Scores are the softmax over *all* experts.
        * Groups are ranked by their best expert score.
        * Experts outside the chosen groups get score 0.
        * Top-k runs on the masked scores.

        Masking must happen on the (non-negative) scores, not on the logits.
        A logit of 0 is not neutral: it can outrank the negative logits of
        allowed experts, and it adds ``exp(0)`` per masked expert to the
        softmax denominator.

        Args:
            router_logits: ``[num_tokens, n_routed_experts]`` router logits.

        Returns:
            ``(topk_weight, topk_idx)``, each ``[num_tokens, top_k]``.
        """
        scores = router_logits.softmax(dim=-1, dtype=torch.float32)
        grouped_scores = scores.unflatten(-1, (self.n_group, -1))  # [n, n_group, experts_per_group]
        # softmax is monotonic per row, so this ranks groups exactly like max-logit would
        group_scores = grouped_scores.max(-1).values  # [n, n_group]
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, topk_group]
        keep_group = torch.zeros_like(group_scores, dtype=torch.bool)
        keep_group.scatter_(1, group_idx, True)
        masked_scores = grouped_scores.masked_fill(~keep_group[..., None], 0.0).flatten(1, 2)

        if self.router_n_groups > 0:
            # NOTE(review): keep SoftmaxTopK's per-router-group selection. Feed it
            # log-probabilities so masked experts (log 0 = -inf) can never be
            # chosen, then undo the renormalisation over the kept experts.
            # Assumes SoftmaxTopK is a plain softmax over the last dim — confirm.
            topk_weight, topk_idx = self.softmax_topk(masked_scores.log())
            return topk_weight * masked_scores.sum(dim=-1, keepdim=True), topk_idx

        topk_weight, topk_idx = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
        return topk_weight, topk_idx

    def forward(self, hidden_states: torch.Tensor):
        """Route ``hidden_states`` (``[num_tokens, hidden]``) to experts.

        Returns ``(topk_weight, topk_idx)``.
        Raises ``RuntimeError`` for an unknown ``topk_method``.
        """
        router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight)
        if self.fake_eplb:
            # Forcefully manipulate router_logits to simulate expert load balancing (EPLB).
            # This is a benchmark-only hack to achieve optimal performance metrics.
            router_logits = torch.randn_like(router_logits)

        if self.topk_method == 'greedy':
            topk_weight, topk_idx = self.softmax_topk(router_logits)
            topk_weight = self._postprocess_topk_weight(topk_weight)
        elif self.topk_method == 'group_limited_greedy':
            topk_weight, topk_idx = self._group_limited_topk(router_logits)
            topk_weight = self._postprocess_topk_weight(topk_weight)
        elif self.topk_method == 'noaux_tc':
            topk_weight, topk_idx = self.noaux_tc_router(router_logits, self.e_score_correction_bias)
        else:
            raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')

        if self.eplb_dispatch_info is not None:
            topk_idx = EPLBManager.topk_ids_logical_to_physical(topk_idx, self.eplb_dispatch_info)

        return topk_weight, topk_idx


class DeepseekV2MoE(nn.Module):
    """Deepseek v2 MoE block: routed experts, plus optional shared experts.

    Tokens are flattened to ``[num_tokens, hidden]`` and routed by
    ``self.gate``. The routed result is combined with the shared-expert
    output. Without data parallelism but with more than one rank, the result
    is all-reduced here.
    """

    def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.renormalize = self.top_k > 1 and self.norm_topk_prob
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        dist_ctx = get_dist_manager().current_context()
        dist_config = dist_ctx.dist_config
        dp = dist_config.dp
        world_size = dist_config.world_size

        # With EPLB the expert count becomes the number of physical experts and
        # the gate remaps ids; otherwise the gate gets no dispatch info.
        eplb_dispatch_info = None
        if get_dist_manager().current_context().dist_config.enable_eplb:
            eplb_dispatch_info = EPLBManager.get_dispatch_info(
                ep_rank=dist_ctx.ep_rank,
                layer_idx=layer_idx,
            )
            self.num_experts = EPLBManager.num_physical_experts()
        self.gate = MoEGate(config, dtype=dtype, device=device, info=eplb_dispatch_info)

        # the fused experts reduce themselves only when both dp and tp are > 1
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=False,
            dtype=dtype,
            device=device,
            all_reduce=dp > 1 and dist_config.tp > 1,
            quant_config=quantization_config,
            layer_idx=layer_idx,
        )

        self.shared_experts = None
        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
                config=config,
                intermediate_size=config.moe_intermediate_size * config.n_shared_experts,
                dtype=dtype,
                device=device,
                is_shared_expert=True,
            )

        # pure TP (no dp): reduce the combined routed + shared output in forward
        self._all_reduce = dp == 1 and world_size > 1

    def forward(self, hidden_states: torch.Tensor):
        """Apply the MoE to ``(batch, seq, hidden)`` input; returns the same leading shape."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        flat_states = hidden_states.view(-1, hidden_dim)
        topk_weights, topk_ids = self.gate(flat_states)

        out_states = self.experts(flat_states, topk_weights, topk_ids)
        if self.shared_experts is not None:
            out_states += self.shared_experts(flat_states)

        out_states = out_states.reshape(batch_size, sequence_length, -1)
        if self._all_reduce:
            dist.all_reduce(out_states)
        return out_states


class DeepseekV2MLP(nn.Module):
    """Deepseek v2 gated MLP: ``down_proj(silu_and_mul(gate_up_proj(x)))``.

    Dense layers are TP-split and all-reduce in ``down_proj``. Shared experts
    never all-reduce here; the owning MoE block does it. They are TP-split only
    when data parallelism is off.
    """

    def __init__(self,
                 config: Any,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_shared_expert: bool = False):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)

        if not is_shared_expert:
            is_tp = True
            all_reduce = True
        else:
            # dp == 1: split weight, the all reduce happens in the moe block.
            # dp > 1: keep the full weight on every rank.
            # TODO: support dp+tp?
            is_tp = get_dist_manager().current_config().dp == 1
            all_reduce = False

        hidden_size = config.hidden_size
        inter_size = config.intermediate_size if intermediate_size is None else intermediate_size

        # fused gate and up projections
        self.gate_up_proj = build_gateup_linear(
            hidden_size,
            [inter_size, inter_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=is_tp,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(
            inter_size,
            hidden_size,
            bias=False,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=all_reduce,
        )

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class DeepseekV2DecoderLayer(nn.Module):
    """Deepseekv2 decoder layer.

    Pre-norm block: ``input_layernorm`` -> self attention ->
    ``post_attention_layernorm`` -> MLP/MoE. The residual is threaded through
    the two-argument RMSNorm calls, which return ``(hidden, residual)``; the
    layer never adds the residual itself, so the add appears to be fused into
    the norm. ``(hidden_states, residual)`` is returned to the caller.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        # the norms in this layer are built without a quantization config
        quantization_config = None

        # build attention layer
        if getattr(config, 'use_mla', True):
            self.self_attn = DeepseekV2Attention(config, dtype=dtype, device=device)
        else:
            # deepseek-vl2-tiny uses MHA LlamaAttention structure
            from lmdeploy.pytorch.models.llama import LlamaAttention
            self.self_attn = LlamaAttention(config, dtype=dtype, device=device)

        # mlp: MoE when routed experts exist, the layer is past the dense prefix
        # (first_k_dense_replace) and falls on the moe_layer_freq stride; dense MLP otherwise
        self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if
                    (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
                     and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device))

        # build input layer norm
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device)

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run one decoder layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` (first layer), the input itself becomes the residual.
        """

        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states, residual)
        return outputs

    def forward_yield(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        tag: Any = None,
    ):
        """forward_yield.

        Generator version of ``forward`` for micro-batch overlap. It expands
        the MoE into the stages before_dispatch -> dispatch -> wait -> gemm ->
        combine -> wait. Each bare ``yield`` hands control back to the driver
        (``DeepseekV2Model.forward_yieldlayers`` under ``execute_batch``) so
        another micro-batch can run. The final ``(hidden_states, residual)`` is
        the generator's return value.

        Requires ``self.mlp`` to be a ``DeepseekV2MoE`` (uses ``.gate``,
        ``.experts``, ``.shared_experts``). ``tag`` is accepted but not used here.

        Shared experts run during dispatch when decoding, and before the
        combine wait during prefill.
        """
        is_decoding = attn_metadata.is_decoding
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # yield for attn0 and attn1
        yield
        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # MOE
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        topk_weights, topk_idx = self.mlp.gate(hidden_states)

        # the staged expert API is fed float32 weights and int64 ids
        topk_weights = self.mlp.experts.renormalize(topk_weights)
        topk_weights = topk_weights.to(torch.float32)
        topk_idx = topk_idx.to(torch.int64)
        hidden_shape = hidden_states.shape
        shared_states = None

        state = {
            'hidden_states': hidden_states,
            'topk_idx': topk_idx,
            'topk_weights': topk_weights,
            'raw_hidden_shape': hidden_shape,
            'moe_type': MoeType.DSAsyncDecode if is_decoding else MoeType.DSAsyncPrefill,
        }

        self.mlp.experts.before_dispatch(state)

        # yield for attn1, dis (+share)
        yield
        recv_state = self.mlp.experts.dispatch(state)
        if self.mlp.shared_experts is not None and is_decoding:
            shared_states = self.mlp.shared_experts(hidden_states)
        # yield for dis, dis_wait
        yield
        self.mlp.experts.wait(recv_state)
        # yield for dis_wait, moe
        yield
        gemm_state = self.mlp.experts.gemm(recv_state)
        # yield for moe, comb
        yield
        out_state = self.mlp.experts.combine(gemm_state)
        # yield for comb, (+share) comb_wait
        yield
        if self.mlp.shared_experts is not None and not is_decoding:
            shared_states = self.mlp.shared_experts(hidden_states)
        self.mlp.experts.wait(out_state)
        # yield for (+share) comb_wait, (+share) attn0
        yield
        out_hidden_states = out_state['hidden_states'].view(hidden_shape)
        if shared_states is not None:
            out_hidden_states += shared_states
        # NOTE(review): when shared_experts exists, shared_states was already computed above
        # (decoding or prefill branch), so this fallback looks unreachable; kept defensively.
        elif self.mlp.shared_experts is not None:
            shared_states = self.mlp.shared_experts(hidden_states)
            out_hidden_states += shared_states
        else:
            pass
        out_hidden_states = out_hidden_states.reshape(batch_size, sequence_length, -1)
        outputs = (out_hidden_states, residual)
        return outputs


class DeepseekV2Model(nn.Module):
    """Deepseek v2 decoder stack: token embedding, decoder layers, final RMSNorm and rotary embedding.

    ``forward`` runs the layers sequentially. ``forward_microbatch`` runs the
    leading dense layers sequentially, then splits the MoE layers into two
    interleaved micro-batches.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = ParallelEmbedding(config.vocab_size,
                                              config.hidden_size,
                                              self.padding_idx,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=True)

        # EPLB metadata is initialised before the layers are built because each
        # DeepseekV2MoE queries EPLBManager in its constructor.
        if get_dist_manager().current_context().dist_config.enable_eplb:
            ep_size_, _ = get_ep_world_rank()
            EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers)
        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device)

        # rotary dim: qk_rope_head_dim for MLA, full head size for the MHA (use_mla=False) variant
        emb_type = RopeType.LinearScaling
        rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //
                                                                                     config.num_attention_heads)
        rope_max_pos_emb = config.max_position_embeddings
        rope_base = get_rope_theta(config)

        # base params; build_rotary_params(config) may override them (e.g. scaling settings)
        rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
        update_params = build_rotary_params(config)
        rope_params.update(update_params)
        self.rotary_emb = build_rotary_embedding(**rope_params)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run all layers sequentially and return the final-normed hidden states.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``past_key_values[idx]`` is handed to layer ``idx``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        residual = None
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        # drop the leading dim of the rotary outputs (presumably the batch dim of 1)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def forward_microbatch(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """forward_microbatch.

        Same contract as ``forward``. The steps are:

        1. Layers before ``first_k_dense_replace`` run normally.
        2. The remaining layers are split by ``split_input`` into ``num=2``
           micro-batches.
        3. ``execute_batch`` drives these through ``forward_yieldlayers``.
        4. The results are recombined with ``merge_output``.

        ``moe_layer_freq`` must be 1 because ``forward_yield`` requires every
        layer from ``moe_start_idx`` onward to be a MoE layer.
        """
        assert self.config.moe_layer_freq == 1
        moe_start_idx = min(self.config.first_k_dense_replace, len(self.layers))

        # embed tokens, then run the leading dense-MLP layers
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        residual = None
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        for idx, decoder_layer in enumerate(self.layers[:moe_start_idx]):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        if moe_start_idx < len(self.layers):
            # run two micro batch
            num = 2
            input_list, exec_type, delta_stages, extern_tag = split_input(hidden_states,
                                                                          rotary_pos_emb,
                                                                          past_key_values,
                                                                          residual,
                                                                          attn_metadata,
                                                                          moe_start_idx,
                                                                          len(self.layers),
                                                                          num=num)

            output_list = execute_batch(inputs=input_list,
                                        fn=self.forward_yieldlayers,
                                        delta_stages=delta_stages,
                                        exec_type=exec_type,
                                        extern_tag=extern_tag)
            hidden_states, residual = merge_output(output_list)

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def forward_yieldlayers(self,
                            hidden_states: torch.Tensor,
                            rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
                            past_key_values: Optional[List[torch.FloatTensor]] = None,
                            residual: Optional[torch.Tensor] = None,
                            attn_metadata: Any = None,
                            start_idx: int = -1,
                            end_idx: int = -1,
                            tag: Any = None):
        """forward_yieldlayers.

        Generator that runs layers ``[start_idx, end_idx)`` for one micro-batch
        by delegating to each layer's ``forward_yield``. All of their ``yield``
        points are re-exposed to the driver. Returns ``(hidden_states, residual)``
        as the generator's return value.
        """
        for idx in range(start_idx, end_idx):
            past_key_value = past_key_values[idx]
            hidden_states, residual = yield from self.layers[idx].forward_yield(hidden_states,
                                                                                rotary_pos_emb=rotary_pos_emb,
                                                                                past_key_value=past_key_value,
                                                                                residual=residual,
                                                                                attn_metadata=attn_metadata,
                                                                                tag=tag)
        return hidden_states, residual

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class DeepseekV2ForCausalLM(nn.Module, CudaGraphMixin):
    """Mixture model for causalLM.

    Wraps ``DeepseekV2Model`` with an ``lm_head`` and implements checkpoint
    loading. Loading includes the MLA-specific transforms applied to
    attention weights and the dequantization of blocked-fp8 ``kv_b_proj``
    weights.
    """

    def __init__(self,
                 config: Any,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.quantization_config = getattr(config, 'quantization_config', None)
        self.dtype = dtype
        self.ctx_mgr = ctx_mgr
        self.model = DeepseekV2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)
        # Staging area for blocked-fp8 ``kv_b_proj`` tensors. A weight and its
        # ``weight_scale_inv`` arrive as separate checkpoint entries and are
        # kept here until both halves of the pair have been seen.
        self._load_buffers = dict()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return its final hidden states.

        Uses ``self.model.forward_microbatch`` when the current step context
        has ``enable_microbatch`` set, and ``self.model.forward`` otherwise.
        Extra ``kwargs`` are accepted and ignored.
        """
        model_inputs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        if get_step_ctx_manager().current_context().enable_microbatch:
            hidden_states = self.model.forward_microbatch(**model_inputs)
        else:
            hidden_states = self.model.forward(**model_inputs)
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for :meth:`forward` from a step context.

        ``input_ids``, ``position_ids`` and ``attn_metadata`` are read from
        ``context``. ``past_key_values`` and ``inputs_embeds`` are passed
        through unchanged.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one checkpoint tensor whose name contains ``.experts``.

        Per-expert projections listed in ``expert_params_mapping`` are
        renamed to the fused expert parameter. They are loaded with their
        ``expert_id`` and ``shard_id``. Any other ``.experts`` tensor is
        loaded under its own name.
        """
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                               update_pe_mapping: List):
        """Load one MLA self-attention checkpoint tensor.

        - Weights of modules listed in ``update_pe_mapping`` get their rope
          channels reordered by ``__update_pe``. Their ``weight_scale_inv``
          tensors are loaded unchanged.
        - ``kv_b_proj`` is split into the ``kc`` / ``vc`` parameters. When
          ``quant_method == 'fp8'``, the weight is first dequantized with its
          block scale, once both tensors have arrived.
        - Everything else is loaded as-is.
        """
        device = next(iter(params_dict.values())).device

        def __update_pe(weight, head_dim: int, pe_dim_offset: int):
            """Reorder channels ``[pe_dim_offset:]`` of every head from
            ``0, 1, 2, 3, ...`` to ``0, 2, 4, ..., 1, 3, 5, ...``.

            The reorder is presumably there to match the rotary kernel's
            channel layout.
            """
            # (num_heads, head_dim, ...)
            weight = weight.unflatten(0, (-1, head_dim))
            # rope slice: (num_heads, pe_dim, ...)
            w_pe = weight[:, pe_dim_offset:]
            # (num_heads, pe_dim // 2, 2, ...)
            new_w_pe = w_pe.unflatten(1, (-1, 2))
            # (num_heads, pe_dim, ...): even channels first, then odd channels
            new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2)
            weight[:, pe_dim_offset:] = new_w_pe
            weight = weight.flatten(0, 1)
            return weight

        def __load_kcvc(name: str, weight: torch.Tensor):
            """Split a ``kv_b_proj`` weight per head into ``kc`` and ``vc``.

            ``vc`` is stored transposed.
            """
            config = self.config
            v_head_dim = config.v_head_dim
            qk_nope_head_dim = config.qk_nope_head_dim
            w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split([qk_nope_head_dim, v_head_dim],
                                                                                        dim=1)
            w_vc = w_vc.transpose(1, 2).contiguous()
            kc_param_name = name.replace('.kv_b_proj', '.kc')
            param_kc = params_dict[kc_param_name]
            load_weight(param_kc, w_kc)
            vc_param_name = name.replace('.kv_b_proj', '.vc')
            param_vc = params_dict[vc_param_name]
            load_weight(param_vc, w_vc)

        def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
            """Dequantize a 2-D block-quantized weight.

            Each ``(dim_w0 // dim_s0, dim_w1 // dim_s1)`` tile of ``weight`` is
            multiplied by its entry in ``scale``. The result is cast to
            ``dtype``.
            """
            dim_w0, dim_w1 = weight.shape
            dim_s0, dim_s1 = scale.shape
            assert dim_w0 % dim_s0 == 0
            assert dim_w1 % dim_s1 == 0
            group0 = dim_w0 // dim_s0
            group1 = dim_w1 // dim_s1
            weight = weight.reshape(dim_s0, group0, dim_s1, group1)
            scale = scale.reshape(dim_s0, 1, dim_s1, 1)
            weight = weight.to(scale.dtype) * scale
            weight = weight.to(dtype)
            weight = weight.reshape(dim_w0, dim_w1)
            return weight

        def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
            """Buffer a blocked-fp8 ``kv_b_proj`` weight or scale.

            Once both the weight and its scale are buffered, the weight is
            dequantized and loaded into ``kc``/``vc``. The two tensors may
            arrive in either order.
            """
            if name.endswith('.weight'):
                weight_name = name
                # The scale is buffered under '<prefix>.weight_scale_inv' (see the
                # branch below). The key must match, otherwise a weight that
                # arrives after its scale is never loaded.
                scale_name = name + '_scale_inv'
            elif name.endswith('.weight_scale_inv'):
                weight_name = name[:-len('_scale_inv')]
                scale_name = name
            else:
                raise ValueError(f'Unexpected blocked fp8 kv_b_proj tensor: {name}')
            self._load_buffers[name] = loaded_weight
            if (weight_name in self._load_buffers and scale_name in self._load_buffers):
                weight = self._load_buffers.pop(weight_name)
                scale = self._load_buffers.pop(scale_name)
                kc_param_name = weight_name.replace('.kv_b_proj', '.kc')
                dtype = params_dict[kc_param_name].dtype
                weight = __dequant_weight(weight, scale, dtype)
                __load_kcvc(weight_name, weight)

        for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:
            if mod_name not in name:
                continue
            if name.endswith('.weight_scale_inv'):
                weight = loaded_weight
            else:
                loaded_weight = loaded_weight.to(device)
                weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
            param = params_dict[name]
            load_weight(param, weight)
            break
        else:
            if '.kv_b_proj' in name:
                quantization_config = self.quantization_config
                quant_method = None
                if quantization_config is not None:
                    quant_method = quantization_config.get('quant_method')

                loaded_weight = loaded_weight.to(device)
                if quant_method == 'fp8':
                    # update blocked fp8 weight
                    __load_kcvc_blocked_fp8(name, loaded_weight)
                else:
                    __load_kcvc(name, loaded_weight)
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model's parameters.

        The following tensors are skipped:
        - rotary caches;
        - the extra nextn layers (indices ``num_hidden_layers`` and up);
        - layers at or beyond ``config.num_hidden_layers``;
        - a tied ``lm_head``.

        MoE, MLA-attention and stacked (gate/up, q/k/v) tensors are routed to
        the matching loader.
        """
        import re

        config = self.config
        use_mla = getattr(config, 'use_mla', True)
        num_hidden_layers = config.num_hidden_layers
        layer_id_pattern = re.compile(r'\.layers\.(\d+)\.')

        def __skip_nextn(name, nextn_keys):
            """Return True if ``name`` contains any of ``nextn_keys``."""
            return any(nextn_key in name for nextn_key in nextn_keys)

        def __skip_layers(name):
            """We might change the number of layers so we can debug the model
            with less gpus."""
            match = layer_id_pattern.search(name)
            if match is None:
                # '.layers' present without a numeric layer index: nothing to skip.
                return False
            return int(match.group(1)) >= num_hidden_layers

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        update_pe_mapping = []
        if use_mla:
            qk_rope_head_dim = config.qk_rope_head_dim
            kv_lora_rank = config.kv_lora_rank
            qk_nope_head_dim = config.qk_nope_head_dim
            q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
            kv_dim = kv_lora_rank + qk_rope_head_dim
            # (module name, per-head dim, offset of the rope channels inside a head)
            update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), ('q_b_proj', q_head_dim, qk_nope_head_dim),
                                 ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)]
        else:
            # deepseek-vl2-tiny uses MHA LlamaAttention, weight loading differs from MLA
            stacked_params_mapping.extend([
                # (param_name, shard_name, shard_id)
                ('.qkv_proj', '.q_proj', 'q'),
                ('.qkv_proj', '.k_proj', 'k'),
                ('.qkv_proj', '.v_proj', 'v'),
            ])

        # n_routed_experts may be None for dense configs; then there are no
        # per-expert tensors to map.
        num_experts = config.n_routed_experts or 0
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        num_nextn_predict_layers = getattr(config, 'num_nextn_predict_layers', 1)
        nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if '.layers' in name:
                # skip nextn
                if __skip_nextn(name, nextn_keys):
                    continue

                if __skip_layers(name):
                    continue

            if config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            elif '.self_attn' in name and use_mla:
                # attention
                self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping)
            else:
                # other
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/deepseek_v32.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank
from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, build_rotary_embedding,
                                 build_rotary_params)
from lmdeploy.pytorch.nn.eplb import EPLBManager
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.nsa import IndexerTopKFP8
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta

from .deepseek_v2 import (DeepseekV2Attention, DeepseekV2BMM, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
                          DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, yarn_get_mscale)


def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply ``fast_hadamard_transform.hadamard_transform`` over the last dim.

    The transform is scaled by ``1 / sqrt(x.size(-1))``. Only bfloat16
    input is accepted; the dtype check runs before the optional dependency
    is imported.
    """
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform
    last_dim = x.shape[-1]
    return hadamard_transform(x, scale=last_dim**-0.5)


class LayerNorm(nn.Module):
    """Layer Normalization.

    The affine parameters are float32. The input is normalized in float32,
    and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6, device: torch.device = None):
        super().__init__()
        # Parameters default to CUDA when no device is given.
        param_device = 'cuda' if device is None else device
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32, device=param_device))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32, device=param_device))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its last ``dim`` channels."""
        normalized = F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias, self.eps)
        return normalized.type_as(x)


class Indexer(nn.Module):
    """Indexer that scores key positions for DeepSeek-V3.2 sparse attention.

    Two projections feed the indexer:
    - the query latent ``qr`` becomes ``index_n_heads`` query heads;
    - the hidden states ``x`` become one shared key head.

    Rotary embedding is applied to the first ``qk_rope_head_dim`` channels of
    each, then a Hadamard rotation. The result goes to ``IndexerTopKFP8``,
    whose output the attention layer consumes as top-k indices.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        try:
            import fast_hadamard_transform  # noqa: F401
        except ImportError as err:
            # Fail when the module is built instead of on the first forward call.
            raise ImportError('Please install fast_hadamard_transform package.') from err
        quant_config = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        self.dim: int = config.hidden_size
        self.n_heads: int = config.index_n_heads
        self.n_local_heads = config.index_n_heads
        self.head_dim: int = config.index_head_dim
        self.rope_head_dim: int = config.qk_rope_head_dim
        self.index_topk: int = config.index_topk
        self.q_lora_rank: int = config.q_lora_rank
        # Query projection from the q_lora_rank latent to all indexer heads (not TP-split).
        self.wq_b = build_colwise_linear(self.q_lora_rank,
                                         self.n_heads * self.head_dim,
                                         bias=False,
                                         dtype=dtype,
                                         device=device,
                                         is_tp=False,
                                         quant_config=quant_config)
        # Single key head shared by all indexer query heads.
        self.wk = build_colwise_linear(self.dim,
                                       self.head_dim,
                                       bias=False,
                                       dtype=dtype,
                                       device=device,
                                       is_tp=False,
                                       quant_config=quant_config)
        self.k_norm = LayerNorm(self.head_dim, device=device)
        # Per-head weights used by the top-k scorer; built without quant_config.
        self.weights_proj = build_colwise_linear(self.dim,
                                                 self.n_heads,
                                                 bias=False,
                                                 dtype=dtype,
                                                 device=device,
                                                 is_tp=False)
        self.softmax_scale = self.head_dim**-0.5
        self.apply_rotary_pos_emb = ApplyRotaryEmb()
        self.indexer_topk = IndexerTopKFP8(self.index_topk, self.softmax_scale, block_size=128, fill=-1)

    def forward(self,
                x: torch.Tensor,
                qr: torch.Tensor,
                freqs_cis: torch.Tensor,
                index_cache: Tuple[torch.Tensor, torch.Tensor],
                attn_metadata: Any = None):
        """Compute the indexer's top-k selection for every query token.

        Args:
            x: hidden states. The ``[0]`` indexing below assumes a leading batch
                dim of size 1, i.e. ``(1, seq_len, hidden_size)``. TODO: confirm.
            qr: query latent of width ``q_lora_rank`` (the input of ``wq_b``).
            freqs_cis: ``(cos, sin)`` pair for the rotary embedding.
            index_cache: two cache tensors forwarded to ``indexer_topk``.
            attn_metadata: forwarded to ``indexer_topk``.

        Returns:
            The output of ``IndexerTopKFP8``.
        """
        q = self.wq_b(qr)
        q = q.unflatten(-1, (-1, self.head_dim))
        # Unlike the main MLA head layout, the rope channels come first here.
        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        k = self.wk(x)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)

        # apply rotary embedding (k gets a singleton head axis first)
        cos, sin = freqs_cis
        q_pe, k_pe = self.apply_rotary_pos_emb(
            q_pe,
            k_pe[..., None, :],
            cos,
            sin,
            inplace=False,
        )
        # Drop the batch dim; k keeps its singleton head axis.
        k_pe = k_pe[0, :]
        k_nope = k_nope[0, :, None]
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe, k_nope], dim=-1)
        q = rotate_activation(q)
        k = rotate_activation(k)

        weights = self.weights_proj(x) * self.n_heads**-0.5

        return self.indexer_topk(q[0], k[:, 0], weights[0], index_cache[0], index_cache[1], attn_metadata=attn_metadata)


class DeepseekV32Attention(DeepseekV2Attention):
    """MLA attention for DeepSeek-V3.2 with an ``Indexer`` for sparse attention.

    All projections are built here rather than in
    ``DeepseekV2Attention.__init__``. The indexer's output is passed to
    ``attn_fwd`` as ``nsa_indices``.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        # Initialize nn.Module directly: the parent constructor is intentionally skipped.
        nn.Module.__init__(self)
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)
        self.q_lora_rank = config.q_lora_rank
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
        use_flash_mla = getattr(config, 'use_flash_mla', False)

        # Query path: one direct projection, or a low-rank down/norm/up path when q_lora_rank is set.
        # NOTE(review): the Indexer below always builds wq_b from config.q_lora_rank, so the
        # q_lora_rank-is-None branch presumably cannot be used with V3.2 checkpoints — confirm.
        if self.q_lora_rank is None:
            self.q_proj = build_colwise_linear(
                self.hidden_size,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )
        else:
            self.q_a_proj = build_colwise_linear(
                self.hidden_size,
                config.q_lora_rank,
                bias=config.attention_bias,
                dtype=dtype,
                device=device,
                is_tp=False,
                quant_config=quantization_config,
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank,
                                         1e-6,
                                         quant_config=quantization_config,
                                         dtype=torch.float32,
                                         device=device)
            self.q_b_proj = build_colwise_linear(
                config.q_lora_rank,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
                quant_config=quantization_config,
                dp_disable_tp=True,
            )

        # Compressed KV latent (kv_lora_rank) plus the shared rope key (qk_rope_head_dim).
        self.kv_a_proj_with_mqa = build_colwise_linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=quantization_config,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
                                      1e-6,
                                      quant_config=quantization_config,
                                      dtype=torch.float32,
                                      device=device)
        # Per-head batched matmul from qk_nope_head_dim to kv_lora_rank (used on q_nope in _q_proj).
        self.kc = DeepseekV2BMM(self.num_heads,
                                config.qk_nope_head_dim,
                                config.kv_lora_rank,
                                dtype=dtype,
                                device=device)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.softmax_scale = self.q_head_dim**(-0.5)

        # With a yarn-style rope config, the softmax scale is multiplied by mscale**2.
        rope_scaling = get_rope_parameters(config)
        if rope_scaling is not None:
            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
            if mscale_all_dim:
                scaling_factor = rope_scaling['factor']
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

        # Attention runs in the latent space: head size kv_lora_rank + rope dim, value size kv_lora_rank.
        self.attn_fwd = Attention(self.num_heads,
                                  config.kv_lora_rank + self.qk_rope_head_dim,
                                  scale=self.softmax_scale,
                                  num_kv_heads=num_key_value_heads,
                                  v_head_size=config.kv_lora_rank,
                                  num_replicate_kv_heads=num_replicate_kv_heads,
                                  use_flash_mla=use_flash_mla)

        # Per-head batched matmul from kv_lora_rank back to v_head_dim (used on the attention output).
        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)
        self.o_proj = build_o_proj(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=True,
            quant_config=quantization_config,
        )

        self.indexer = Indexer(config, layer_idx, dtype=dtype, device=device)

    def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int):
        """Project queries into the latent space.

        Returns:
            query_states: ``(q_len, num_heads, nope_size + pe_size)``. Only the
                ``[..., :nope_size]`` part is filled here; the caller writes
                the rope part.
            q_pe: rope part of the projected query.
            qr: the query latent (``hidden_states`` itself when
                ``q_lora_rank`` is None), reused by the indexer.
        """
        # assumes hidden_states is (1, q_len, hidden_size) — TODO confirm batch dim is always 1
        q_len = hidden_states.size(1)

        query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size)

        if self.q_lora_rank is None:
            qr = hidden_states
            q = self.q_proj(hidden_states)
        else:
            qr = self.q_a_layernorm(self.q_a_proj(hidden_states))
            q = self.q_b_proj(qr)
        q = q.view(q_len, num_heads, self.q_head_dim)
        # q_pe: (q_len, num_heads, qk_rope_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # q_nope_out: (q_len, num_heads, nope_size), a view into query_states
        q_nope_out = query_states[..., :nope_size]
        # NOTE: kc is called with q_nope_out as its output argument; presumably it writes
        # q_nope @ kc into that view (and hence into query_states) — confirm DeepseekV2BMM semantics.
        self.kc(q_nope, q_nope_out)
        return query_states, q_pe, qr

    def _kv_proj(self, hidden_states, nope_size: int):
        """Project the compressed KV latent and the shared rope key.

        The first ``nope_size`` channels of ``key_states`` are normalized
        with ``kv_a_layernorm`` in place. The same normalized values are
        returned as ``value_states``.
        """
        # (q_len, 1, nope_size + pe_size)
        key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None])
        # (q_len, 1, pe_size)
        k_pe = key_states[..., nope_size:]
        # kv_a_layernorm
        value_states = key_states[..., :nope_size]
        value_states = self.kv_a_layernorm(value_states)
        # write the normalized latent back so key_states carries it too
        key_states[..., :nope_size] = value_states
        return key_states, value_states, k_pe

    def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int):
        """Run the query and KV projections; also return the query latent ``qr``."""
        nope_size = self.kv_lora_rank
        pe_size = self.qk_rope_head_dim
        query_states, q_pe, qr = self._q_proj(hidden_states, num_heads, nope_size, pe_size)
        key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size)

        return query_states, key_states, value_states, q_pe, k_pe, qr

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Sequence[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward.

        Steps:
        1. Project q/k/v into the latent space and apply rotary embedding to
           the rope channels.
        2. Let the indexer pick key positions (using the last two entries of
           ``past_key_value`` as its cache).
        3. Run sparse attention with those indices.
        4. Map the result back through ``vc`` and ``o_proj``.
        """
        dist_ctx = get_dist_manager().current_context()
        tp_world_size = dist_ctx.dist_config.attn_tp
        num_heads = self.num_heads // tp_world_size
        nope_size = self.kv_lora_rank
        q_len = hidden_states.size(1)

        # qkv_proj
        query_states, key_states, value_states, q_pe, k_pe, qr = self._qkv_proj(hidden_states, num_heads=num_heads)

        cos, sin = rotary_pos_emb
        q_pe, k_pe = self.apply_rotary_pos_emb(
            q_pe,
            k_pe,
            cos,
            sin,
            inplace=False,
        )
        # fill the rope channels of the latent query/key
        query_states[..., nope_size:] = q_pe
        key_states[..., nope_size:] = k_pe

        topk_indices = self.indexer(hidden_states, qr, rotary_pos_emb, past_key_value[-2:], attn_metadata=attn_metadata)

        # The value cache is a slice of the key cache (past_key_value[0]).
        # NOTE(review): past_key_value[-2:] is treated as the indexer cache above, so the
        # len(...) == 2 checks below may not distinguish quantized from unquantized layouts
        # for this model; confirm the cache layout before relying on k/v_scales_zeros.
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[0][..., :nope_size],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            nsa_indices=topk_indices,
        )
        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)

        # vc is called with attn_bmm_out as its output argument (presumably filled in place).
        self.vc(attn_output, attn_bmm_out)
        attn_output = attn_bmm_out.flatten(-2, -1)[None]
        attn_output = self.o_proj(attn_output)

        return attn_output


class DeepseekV32DecoderLayer(DeepseekV2DecoderLayer):
    """Decoder layer for DeepSeek-V3.2.

    Reuses ``DeepseekV2DecoderLayer``'s behavior. When ``config.use_mla`` is
    unset or true, it builds ``DeepseekV32Attention`` (with its indexer) as
    the attention module.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        # The parent constructor is skipped on purpose; every submodule is built here.
        nn.Module.__init__(self)
        self.layer_idx = layer_idx
        norm_quant_config = None

        # attention
        if not getattr(config, 'use_mla', True):
            # deepseek-vl2-tiny uses MHA LlamaAttention structure
            from lmdeploy.pytorch.models.llama import LlamaAttention
            self.self_attn = LlamaAttention(config, dtype=dtype, device=device)
        else:
            self.self_attn = DeepseekV32Attention(config, layer_idx, dtype=dtype, device=device)

        # MoE when routed experts exist, layer_idx >= first_k_dense_replace and
        # layer_idx is a multiple of moe_layer_freq; dense MLP otherwise.
        is_moe_layer = (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
                        and layer_idx % config.moe_layer_freq == 0)
        if is_moe_layer:
            self.mlp = DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = DeepseekV2MLP(config, dtype=dtype, device=device)

        # float32 RMS norms before attention and before the MLP
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=norm_quant_config,
                                       dtype=torch.float32,
                                       device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                dtype=torch.float32,
                                                device=device)


class DeepseekV32Model(DeepseekV2Model):
    """DeepSeek-V3.2 decoder stack.

    Same attributes as built here for ``DeepseekV2Model``, but the layers
    are ``DeepseekV32DecoderLayer`` instances.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        # The parent constructor is skipped on purpose; every submodule is built here.
        nn.Module.__init__(self)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        dist_config = get_dist_manager().current_context().dist_config
        if dist_config.enable_eplb:
            ep_size, _ = get_ep_world_rank()
            EPLBManager.init_global_eplb_metadata(ep_size, config.n_routed_experts, config.num_hidden_layers)

        decoder_layers = [
            DeepseekV32DecoderLayer(config, idx, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=torch.float32, device=device)

        # rotary embedding: rope dim is the MLA rope head dim, or the full head dim for MHA models
        if getattr(config, 'use_mla', True):
            rope_dim = config.qk_rope_head_dim
        else:
            rope_dim = config.hidden_size // config.num_attention_heads
        rope_params = dict(emb_type=RopeType.LinearScaling,
                           dim=rope_dim,
                           max_position_embeddings=config.max_position_embeddings,
                           base=get_rope_theta(config))
        # config-derived rotary settings take precedence over the defaults above
        rope_params.update(build_rotary_params(config))
        self.rotary_emb = build_rotary_embedding(**rope_params)


class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
    """Causal LM for DeepSeek-V3.2.

    Inherits forward and weight loading from ``DeepseekV2ForCausalLM`` but
    builds a ``DeepseekV32Model`` as its decoder.
    """

    def __init__(self,
                 config: Any,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        # DeepseekV2ForCausalLM.__init__ would build a DeepseekV2Model, so only nn.Module is initialized.
        nn.Module.__init__(self)
        self.config = config
        self.quantization_config = getattr(config, 'quantization_config', None)
        self.dtype, self.ctx_mgr = dtype, ctx_mgr
        self.model = DeepseekV32Model(config, dtype=dtype, device=device)
        self.lm_head = build_rowwise_linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
        # staging area for blocked-fp8 tensors used by the inherited weight loader
        self._load_buffers = {}


================================================
FILE: lmdeploy/pytorch/models/deepseek_vl2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/main/deepseek_vl2/models/modeling_deepseek_vl_v2.py

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .deepseek_v2 import DeepseekV2ForCausalLM
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model


@vlm_model
class MlpProjector(nn.Module):
    """Project vision-encoder patch features into the language embedding space.

    The layer stack is selected by ``cfg.projector_type``:

    - ``'identity'``: features pass through unchanged.
    - ``'linear'``: one ``input_dim -> n_embed`` linear layer.
    - ``'mlp_gelu'``: ``cfg.depth`` linear layers separated by GELU.
    - ``'downsample_mlp_gelu'``: ``forward`` first merges every
      ``downsample_ratio x downsample_ratio`` neighbourhood of patches into one
      token (channel concatenation); an MLP with hidden width
      ``n_embed * mlp_ratio`` (``max(cfg.depth, 2)`` linear layers separated by
      GELU) then maps it to ``n_embed``.

    If ``cfg.token_pooling`` is set, ``forward`` instead pools 2x2 patch
    neighbourhoods through ``token_pooling_layer`` before the layer stack (the
    downsample step is then skipped).

    Args:
        cfg: projector config; the attributes read depend on ``projector_type``.
        dtype: dtype of the created linear layers.

    Raises:
        ValueError: if ``cfg.projector_type`` is not one of the types above.
    """

    def __init__(self, cfg, dtype):

        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == 'identity':
            modules = nn.Identity()

        elif cfg.projector_type == 'linear':
            modules = nn.Linear(cfg.input_dim, cfg.n_embed, dtype=dtype)

        elif cfg.projector_type == 'mlp_gelu':
            mlp_depth = cfg.depth
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed, dtype=dtype)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed, dtype=dtype))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == 'downsample_mlp_gelu':
            mlp_depth = cfg.depth
            mlp_ratio = cfg.mlp_ratio
            # input width is input_dim * ratio**2 because forward() concatenates
            # the channels of each ratio x ratio neighbourhood
            modules = [
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio,
                          cfg.n_embed * mlp_ratio,
                          dtype=dtype)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio, dtype=dtype))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed, dtype=dtype))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f'Unknown projector type: {cfg.projector_type}')

        if cfg.token_pooling:
            # 2x2 neighbourhood -> 4 * input_dim channels -> back to input_dim
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim, dtype=dtype)

        self.layers = modules

    def forward(self, x):
        """Project a batch of patch features.

        Args:
            x: tensor of shape ``[batch, num_patches, input_dim]``; both
                pre-processing paths assume ``num_patches`` is a perfect square
                (a square patch grid).

        Returns:
            The output of ``self.layers`` applied to the (optionally pooled or
            downsampled) features.
        """
        if self.cfg.token_pooling:
            x = self._token_pool(x)
        elif self.cfg.projector_type == 'downsample_mlp_gelu':
            x = self._downsample(x)

        return self.layers(x)

    def _token_pool(self, x):
        """Concatenate each 2x2 patch neighbourhood and project it back to
        ``input_dim`` with ``token_pooling_layer``."""
        batch_size, wxh, channels = x.shape
        w = h = int(wxh**0.5)
        x = x.view(batch_size, w, h, channels)
        x = x.permute(0, 3, 1, 2)
        patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
        batch_size, channels, h_patches, w_patches, _, _ = patches.size()
        # concatenate along the channel dimension
        patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

        # pass through the linear layer
        patches = patches.permute(0, 2, 1, 3).contiguous()
        patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

        return self.token_pooling_layer(patches)

    def _downsample(self, x):
        """Merge every ``downsample_ratio x downsample_ratio`` neighbourhood of
        the (zero-padded) square patch grid into one token by concatenating
        channels: ``[B, H*W, C] -> [B, ceil(H/r)*ceil(W/r), C*r*r]``."""
        ratio = self.cfg.downsample_ratio
        bs, hw, input_dim = x.shape
        h = w = int((hw)**0.5)

        # zero-pad bottom/right so the grid side becomes a multiple of `ratio`
        pad = (ratio - h % ratio) % ratio
        x = x.reshape(bs, h, w, input_dim)
        if pad > 0:
            x = F.pad(x, (0, 0, 0, pad, 0, pad), 'constant', 0)

        # ratio x ratio -> 1 concat: [B, H, W, C] -> [B, C, H, W] -> [B, C*r*r, L]
        x = x.permute(0, 3, 1, 2)
        x = F.unfold(x, kernel_size=ratio, stride=ratio, padding=0)
        return x.permute(0, 2, 1)


class DeepseekVLV2ForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):
    """DeepSeek-VL2 model.

    The model has three parts:

    - a timm SigLIP vision tower (``self.vision``);
    - an ``MlpProjector`` (``self.projector``);
    - a ``DeepseekV2ForCausalLM`` language model (``self.language``).

    Projected image features are written into the language model's input
    embeddings at the image-token positions.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # ----------- vision encoder ------------
        self.vision = self._init_vision_module(dtype=dtype)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = MlpProjector(projector_config, dtype)

        # image token format
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special tokens used to format image token sequence
        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == '2D':
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
            # NOTE: the misspelled name `view_seperator` is kept on purpose:
            # load_weights() resolves non-language weights by parameter name, so
            # it presumably has to match the checkpoint key.
            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
        elif self.tile_tag == '1D':
            # <|tile_x|>, <|tile_global|>
            candidate_resolutions = config.candidate_resolutions
            if len(candidate_resolutions) == 0:
                raise ValueError(
                    f'len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}')
            tile_variants_num = len(candidate_resolutions)
            # row 0 tags the global view, rows 1.. tag local tiles; they are
            # concatenated with projector outputs, so the width is the projector's
            # n_embed (the previous `config.aligner.params.n_embed` is not part of
            # this config)
            self.tile_indicators = nn.Parameter(
                torch.randn(size=(tile_variants_num + 1, projector_config.n_embed)) * embed_std)
        else:
            raise ValueError(f'tile tag should be either 1D or 2D, but got {self.tile_tag}')

        # ----------- language model ------------
        language_config = config.language_config
        self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

        #  ----------- input processor ------------
        self.input_processor = DeepSeekVLV2InputProcessor(config, dtype)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_vl2.py#L359
    def _init_vision_module(
        self,
        dtype: torch.dtype,
    ) -> nn.Module:
        """Build the (randomly initialised) timm SigLIP vision tower in
        ``dtype``; weights are expected to be loaded later by load_weights().

        Raises:
            ImportError: if ``timm`` is not installed.
        """
        try:
            import timm
        except ImportError as err:
            raise ImportError('Please install timm') from err

        model = timm.create_model(
            'vit_so400m_patch14_siglip_384.webli',
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )
        model = model.to(dtype=dtype)
        return model

    def prepare_inputs_embeds(self,
                              input_ids: torch.LongTensor,
                              images: Optional[torch.FloatTensor] = None,
                              images_seq_mask: Optional[torch.LongTensor] = None,
                              images_spatial_crop: Optional[torch.LongTensor] = None,
                              **ignore_kwargs):
        """Build input embeddings with image features scattered into the
        image-token positions.

        Each image contributes ``1 + num_width_tiles * num_height_tiles`` tiles
        (one global view followed by the local tiles). An all-zero crop entry
        terminates the image list of a batch row. When there are no images
        (``images`` is None or all crops are zero) plain token embeddings are
        returned.

        Args:
            input_ids (torch.LongTensor): [b, T]
            images (torch.FloatTensor): [b, max_n_images, 3, height, width]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        if images is None or images_spatial_crop.sum() == 0:
            return self.language.get_input_embeddings()(input_ids)

        # count tiles (1 global + w*h local per image) for every batch row
        bs, max_n_images, _ = images_spatial_crop.shape
        batch_num_tiles = [0 for _ in range(bs)]
        total_tiles = []
        for idx in range(bs):
            for jdx in range(max_n_images):
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break
                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)

            total_tiles.append(images[idx, :batch_num_tiles[idx]])

        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(total_tiles, dim=0)
        assert total_tiles.shape[0] == sum(batch_num_tiles)
        if total_tiles.shape[0] == 0:
            return self.language.get_input_embeddings()(input_ids)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision.forward_features(total_tiles)  # timm siglip forward_features

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)

        # put image tokens into the input_embeds, [b, T, D]
        input_embeds = self.language.get_input_embeddings()(input_ids)

        # fill image token sequence according to self.tile_tag & self.global_view_pos
        tile_index = 0
        for idx in range(images_spatial_crop.shape[0]):
            images_in_this_batch = []
            for jdx in range(images_spatial_crop.shape[1]):

                # extra global & local features
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break

                num_tiles_in_image = num_width_tiles * num_height_tiles

                # [hw, D]
                global_features = images_embeds[tile_index]

                # [num_height_tiles * num_width_tiles, hw, D]
                local_features = images_embeds[tile_index + 1:tile_index + 1 + num_tiles_in_image]

                tile_index += num_tiles_in_image + 1

                # format global and local features
                if self.tile_tag == '2D':

                    # ----------------- global view add newline -----------------
                    # [hw, D] -> [h, w, D]
                    global_features = global_features.view(h, w, n_dim)
                    # [D]     -> [h, 1, D]
                    new_lines_in_global = repeat(self.image_newline, 'd -> h 1 d', h=h)
                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)
                    # [h, w + 1, D] -> [h * (w + 1), D]
                    global_features = global_features.view(-1, n_dim)

                    # ----------------- local view add newline -----------------
                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]
                    local_features = rearrange(local_features,
                                               '(th tw) (h w) d -> (th h) (tw w) d',
                                               th=num_height_tiles,
                                               tw=num_width_tiles,
                                               h=h,
                                               w=w)

                    # [D] -> [num_height_tiles * h, 1, D]
                    new_lines_in_local = repeat(self.image_newline, 'd -> (th h) 1 d', th=num_height_tiles, h=h)

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
                    local_features = local_features.view(-1, n_dim)

                    # ----------------- merge global and local tiles -----------------
                    if self.global_view_pos == 'head':
                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :], local_features], dim=0)
                    else:
                        global_local_features = torch.cat(
                            [local_features, self.view_seperator[None, :], global_features], dim=0)

                else:
                    # 1D tile tag: abandoned upstream, not expected to be reached
                    global_features = torch.cat([self.tile_indicators[0:1], global_features], dim=0)
                    local_features = torch.cat(
                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1)
                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')

                    if self.global_view_pos == 'head':
                        global_local_features = torch.cat([global_features, local_features], dim=0)
                    else:
                        global_local_features = torch.cat([local_features, global_features], dim=0)

                images_in_this_batch.append(global_local_features)

            if len(images_in_this_batch) > 0:
                # the number of True entries in the mask row must equal the number
                # of feature rows produced above, in order
                images_in_this_batch = torch.cat(images_in_this_batch, dim=0).to(input_embeds.dtype)
                crt_image_mask = images_seq_mask[idx].unsqueeze(-1).to(input_embeds.device)
                input_embeds[idx].masked_scatter_(crt_image_mask, images_in_this_batch)

        return input_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        images_spatial_crop: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the language model, first merging image features into the
        embeddings when ``pixel_values`` is given and no ``inputs_embeds`` were
        supplied. Extra ``kwargs`` are ignored."""
        # process image embeddings
        if inputs_embeds is None and pixel_values is not None:
            inputs_embeds = self.prepare_inputs_embeds(input_ids=input_ids,
                                                       images=pixel_values,
                                                       images_seq_mask=image_mask,
                                                       images_spatial_crop=images_spatial_crop)

        outputs = self.language.forward(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )
        return outputs

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.language.get_logits(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.language.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare forward() kwargs from the step context.

        Images of all sequences are flattened, in batch order, into a single
        batch row: ``pixel_values`` is the concatenation of every image's tiles
        with a leading dim of 1, and ``images_spatial_crop`` holds one crop
        entry per image in the same order. Because of this alignment, the tile
        counts derived from the crops match the concatenated tiles.
        Sequences without images contribute nothing.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # vision inputs
        pixel_values = None
        images_spatial_crop = None
        image_mask = None
        if context.input_multimodals is not None:
            # flatten batch; sequences without images yield empty lists
            image_data = [data for input_mm in context.input_multimodals for data in input_mm.get('image', [])]
            if len(image_data) > 0:
                image_token_id = image_data[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                pixel_values = torch.cat([data.data for data in image_data]).unsqueeze(0)
                # one crop entry per image entry, aligned with the tiles above
                crops = [data.meta.get('images_spatial_crop', None) for data in image_data]
                images_spatial_crop = torch.cat(crops).unsqueeze(0)

        return dict(
            input_ids=input_ids,  # [b, T]
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            pixel_values=pixel_values,  # [b, max_n_images, 3, height, width]
            images_spatial_crop=images_spatial_crop,  # [b, max_n_images, 2]
            image_mask=image_mask,  # [b, T]
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Weights prefixed with ``language.`` are stripped of the prefix and
        forwarded to the language model. All other weights are loaded directly
        into this module's parameters, looked up by name. Non-vision ``qkv``
        weights are split into q/k/v shards first.
        """

        lang_prefix = 'language.'
        lang_prefix_length = len(lang_prefix)
        new_weights = dict()
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith(lang_prefix):
                new_key = name[lang_prefix_length:]
                new_weights[new_key] = loaded_weight
                continue

            if 'qkv' in name and 'vision' not in name:
                param = params_dict[name]
                q, k, v = param.weight_spliter(loaded_weight)
                load_weight(param, q, shard_id='q')
                load_weight(param, k, shard_id='k')
                load_weight(param, v, shard_id='v')
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

        self.language.load_weights(new_weights.items())

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class DeepSeekVLV2InputProcessor(BaseModelInputProcessor):
    """Deepseek-vl2 input processor.

    Converts the raw per-image dicts produced upstream into ``MultiModalData``
    entries. Each entry spans the image's placeholder tokens and carries the
    image token id and spatial crop as metadata.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype
        vision_config = config.vision_config
        self.patch_size = vision_config.patch_size

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: prompt token ids.
            input_multimodals: one dict per image with keys ``pixel_values``,
                ``offset``, ``image_token_id``, ``image_tokens`` and optionally
                ``images_spatial_crop``.

        Returns:
            PreprocessInputResult with ``input_ids`` unchanged and the images
            under ``input_multimodals['image']``. When there are no multimodal
            inputs, they are passed through unchanged. A ``PreprocessInputResult``
            is still returned, as the annotation promises (previously a bare
            tuple).
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            images_spatial_crop = input_mm.get('images_spatial_crop', None)
            # the token count may arrive as a 0-d tensor
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(
                                         image_token_id=image_token_id,
                                         images_spatial_crop=images_spatial_crop,
                                     ))

            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )

        return result


================================================
FILE: lmdeploy/pytorch/models/gemma.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from typing import Any, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, GeluAndMul, RMSNorm, RopeType, build_rotary_embedding,
                                 build_rotary_embedding_from_config)
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class GemmaAttention(nn.Module):
    """Self-attention for Gemma, Gemma2 and Gemma3-text models.

    Uses a packed QKV projection, rotary position embedding and the fused
    ``Attention`` op. Model-type differences handled here:

    - ``gemma3_text``: RMSNorm on queries and keys. Every
      ``sliding_window_pattern``-th layer (default 6) attends globally; the
      others use ``config.sliding_window``.
    - other model types: even-indexed layers use ``config.sliding_window``,
      odd-indexed layers attend globally.
    - ``query_pre_attn_scalar`` (if present in the config) replaces the default
      ``1/sqrt(head_dim)`` softmax scale.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = config.head_dim
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        # packed qkv
        self.qkv_proj = build_qkv_proj(hidden_size,
                                       num_q_heads=num_heads,
                                       num_kv_heads=num_key_value_heads,
                                       head_size=head_dim,
                                       bias=config.attention_bias,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=num_replicate_kv_heads)

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()
        self.model_type = config.model_type

        # attention
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_key_value_heads
        self.scaling = 1 / math.sqrt(config.head_dim)
        if hasattr(config, 'query_pre_attn_scalar'):
            self.scaling = config.query_pre_attn_scalar**-0.5
        # -1 means "no sliding window" (global attention)
        if self.model_type == 'gemma3_text':
            sliding_window_pattern = getattr(config, 'sliding_window_pattern', 6)
            is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
            self.sliding_window = (getattr(config, 'sliding_window', -1) if is_sliding else -1)
        else:
            self.sliding_window = (getattr(config, 'sliding_window', -1) if not bool(layer_idx % 2) else -1)
        logit_softcapping = getattr(config, 'attn_logit_softcapping', 0.0)
        if logit_softcapping is None:
            logit_softcapping = 0.0
        self.attn_fwd = Attention(num_heads,
                                  head_dim,
                                  scale=self.scaling,
                                  num_kv_heads=num_key_value_heads,
                                  sliding_window=self.sliding_window,
                                  logit_softcapping=logit_softcapping)

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=config.attention_bias,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

        if self.model_type == 'gemma3_text':
            self.q_norm = RMSNorm(config.head_dim,
                                  config.rms_norm_eps,
                                  quant_config=quantization_config,
                                  dtype=dtype,
                                  device=device)
            self.k_norm = RMSNorm(config.head_dim,
                                  config.rms_norm_eps,
                                  quant_config=quantization_config,
                                  dtype=dtype,
                                  device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        rotary_pos_emb_local: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
        global_attn_masks: torch.Tensor = None,
        local_attn_masks: torch.Tensor = None,
    ):
        """Gemma attention forward.

        Args:
            hidden_states: input activations.
            rotary_pos_emb: ``(cos, sin)`` used by global-attention layers.
            rotary_pos_emb_local: optional ``(cos, sin)`` used instead by
                sliding-window layers.
            past_key_value: kv cache, either ``(k_cache, v_cache)`` or
                ``(k_cache, v_cache, k_scales_zeros, v_scales_zeros)``.
            attn_metadata: attention metadata (``q_seqlens`` is read on the
                masked path).
            global_attn_masks, local_attn_masks: per-sequence masks used by
                Gemma3 VL. When both are given, the fused op still runs (to fill
                the kv cache) and its output is then recomputed with
                ``naive_attn_with_masks``.

        Returns:
            The attention output after ``o_proj``.
        """
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        if self.model_type == 'gemma3_text':
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        # apply rotary embedding
        if rotary_pos_emb_local is not None and self.sliding_window != -1:
            cos, sin = rotary_pos_emb_local
        else:
            cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        gemma3_naive_attn_with_masks = global_attn_masks is not None and local_attn_masks is not None
        # attention; in-place mode is disabled when q/k/v are needed again by
        # the masked recomputation below (presumably in-place reuses q storage)
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=not gemma3_naive_attn_with_masks,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # gemma3 VL applied different attn masks
        # intentionally compute attn twice to fill kv cache
        if gemma3_naive_attn_with_masks:
            attn_masks = local_attn_masks if self.sliding_window > 0 else global_attn_masks

            attn_output = self.naive_attn_with_masks(query_states,
                                                     key_states,
                                                     value_states,
                                                     out=attn_output,
                                                     attn_masks=attn_masks,
                                                     seq_lens=attn_metadata.q_seqlens)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output

    # adapted from https://github.com/vllm-project/vllm/blob/5eeabc2a4400fde9b030f2f72746a2b03db059bd/vllm/model_executor/models/gemma3.py#L218  # noqa
    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        attn_masks: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Recompute attention sequence by sequence with explicit masks.

        Only the keys/values of the current step are attended (no kv-cache
        history), so this presumably targets prefill only.

        Args:
            q, k, v: packed states of all tokens; ``q.shape[0]`` is the total
                token count. k/v heads are repeated to match the query heads
                (GQA).
            out: output buffer, overwritten in place token row by token row.
            attn_masks: one mask per sequence, passed to SDPA as ``attn_mask``.
            seq_lens: query length of each packed sequence, in packing order.

        Returns:
            ``out``.
        """
        q_len = q.shape[0]
        q = q.view(q_len, -1, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(q_len, -1, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(q_len, -1, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        # write through a [num_tokens, heads * head_dim] view so token rows are
        # sliced correctly even when `out` carries a leading batch dim of 1
        out_tokens = out.view(-1, out.shape[-1])

        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            # [seq, heads, head_dim] -> [1, heads, seq, head_dim]
            query = q[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
            key = k[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
            value = v[start_idx:end_idx].unsqueeze(0).transpose(1, 2)

            # `scale` must be passed by keyword: the 5th positional parameter of
            # scaled_dot_product_attention is `dropout_p`, which would apply
            # random dropout at inference and ignore the configured scale.
            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attn_mask,
                scale=self.scaling,
            )
            # [1, heads, seq, head_dim] -> [seq, heads * head_dim]
            out_tokens[start_idx:end_idx] = output.transpose(1, 2).flatten(-2, -1).squeeze(0)
            start_idx = end_idx
        return out


class GemmaMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, tanh-approximated
    GELU gating, then the down projection back to ``hidden_size``."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        model_dim = config.hidden_size
        ffn_dim = config.intermediate_size

        # gate and up projections fused into one column-parallel linear
        self.gate_up_proj = build_gateup_linear(
            model_dim,
            [ffn_dim, ffn_dim],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # NOTE(review): the activation is always tanh-approximated GELU; the
        # config's `hidden_activation` value is never validated or used.
        # Confirm this holds for every supported checkpoint.
        self.act_fn = GeluAndMul(approximate='tanh')

        self.down_proj = build_down_linear(ffn_dim,
                                           model_dim,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """Apply gate/up projection, gated activation and down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class GemmaDecoderLayer(nn.Module):
    """Gemma decoder layer: attention and MLP sub-blocks with RMSNorms.

    For ``gemma2`` / ``gemma3_text`` the attention output is normalised before
    the residual add, and the MLP is wrapped by extra pre- and post-feedforward
    norms. Other model types use the plain pre-norm layout.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        def _make_norm():
            # all layer norms share the same shape, eps and placement
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        self.self_attn = GemmaAttention(config, layer_idx, dtype=dtype, device=device)
        self.mlp = GemmaMLP(config, dtype=dtype, device=device)
        self.input_layernorm = _make_norm()
        self.post_attention_layernorm = _make_norm()

        self.model_type = config.model_type
        if self.model_type in ('gemma2', 'gemma3_text'):
            self.pre_feedforward_layernorm = _make_norm()
            self.post_feedforward_layernorm = _make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        rotary_pos_emb_local: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        global_attn_masks: torch.Tensor = None,
        local_attn_masks: torch.Tensor = None,
    ):
        """Run one decoder layer.

        Returns:
            ``(hidden_states, residual)``. The residual stream is carried
            separately so the next layer's input norm can fuse the add.
        """
        # input norm; on the first layer the raw input becomes the residual
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual, hidden_states = hidden_states, self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_emb_local=rotary_pos_emb_local,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
            global_attn_masks=global_attn_masks,
            local_attn_masks=local_attn_masks,
        )

        if self.model_type not in ('gemma2', 'gemma3_text'):
            # plain pre-norm: fused residual add + norm, then MLP
            hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
            return self.mlp(hidden_states), residual

        # normalise attention output, then fused residual add + pre-FFN norm
        attn_normed = self.post_attention_layernorm(hidden_states)
        mlp_in, residual = self.pre_feedforward_layernorm(attn_normed, residual)
        mlp_out = self.post_feedforward_layernorm(self.mlp(mlp_in))
        return mlp_out, residual


class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """Token embedding whose output is multiplied by a constant ``embed_scale``.

    Used for ``gemma3_text`` models, where the ``sqrt(hidden_size)`` scale is
    folded into the embedding lookup instead of being applied in the model's
    forward pass.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int,
                 dtype: torch.dtype = None,
                 embed_scale: Optional[float] = 1.0,
                 device: torch.device = None):
        """
        Args:
            num_embeddings: vocabulary size.
            embedding_dim: size of each embedding vector.
            padding_idx: index of the padding token (forwarded to nn.Embedding).
            dtype: parameter dtype; ``None`` uses the torch default.
            embed_scale: multiplier applied to every looked-up embedding;
                ``None`` is treated as ``1.0``.
            device: device on which the embedding table is allocated.
        """
        super().__init__(num_embeddings, embedding_dim, padding_idx, dtype=dtype, device=device)
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor):
        """Look up ``input_ids`` and scale the result by ``embed_scale``."""
        return super().forward(input_ids) * self.embed_scale


class GemmaModel(nn.Module):
    """Gemma decoder stack: token embedding, decoder layers, final norm, rotary embeddings."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if self.config.model_type == 'gemma3_text':
            # gemma3 applies the sqrt(hidden_size) normalizer inside the embedding module.
            self.embed_tokens = Gemma3TextScaledWordEmbedding(config.vocab_size,
                                                              config.hidden_size,
                                                              self.padding_idx,
                                                              dtype=dtype,
                                                              embed_scale=config.hidden_size**0.5,
                                                              device=device)
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size,
                                             config.hidden_size,
                                             self.padding_idx,
                                             dtype=dtype,
                                             device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        self.build_rope_emb(config)

    def build_rope_emb(self, config: PretrainedConfig):
        """Build ``self.rotary_emb`` and, for gemma3_text, ``self.rotary_emb_local``.

        gemma3_text uses one rope for global-attention layers and another
        (unscaled, different base) for sliding-window layers.
        """
        if self.model_type != 'gemma3_text':
            self.rotary_emb = build_rotary_embedding_from_config(config)
            return

        rope_dim = config.head_dim
        rope_max_pos_emb = config.max_position_embeddings

        if hasattr(config, 'rope_local_base_freq'):
            # transformers < 5: the global rope comes from the config, the local
            # rope uses ``rope_local_base_freq`` as its base.
            self.rotary_emb = build_rotary_embedding_from_config(config)
            self.rotary_emb_local = build_rotary_embedding(
                rope_dim,
                rope_max_pos_emb,
                config.rope_local_base_freq,
                emb_type=RopeType.Default,
            )
            return

        # transformers >= 5: per-attention-type rope parameters.
        from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
        rope_parameters = get_rope_parameters(config)
        full_attention = rope_parameters['full_attention']
        sliding_attention = rope_parameters['sliding_attention']
        # The global rope is always built as linear scaling. When the config
        # carries no 'factor' (unscaled rope), a factor of 1.0 is equivalent to
        # the default rope instead of raising KeyError.
        self.rotary_emb = build_rotary_embedding(
            rope_dim,
            rope_max_pos_emb,
            base=full_attention['rope_theta'],
            scaling_factor=full_attention.get('factor', 1.0),
            emb_type=RopeType.LinearScaling,
        )
        self.rotary_emb_local = build_rotary_embedding(
            rope_dim,
            rope_max_pos_emb,
            base=sliding_attention['rope_theta'],
            emb_type=RopeType.Default,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        global_attn_masks: torch.Tensor = None,
        local_attn_masks: torch.Tensor = None,
    ):
        """Embed tokens (unless ``inputs_embeds`` is given), run all decoder
        layers and return the final-normed hidden states.

        ``past_key_values`` is indexed per layer. The attention masks are
        forwarded unchanged to every layer.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        if self.model_type != 'gemma3_text':
            # gemma/gemma2 scale embeddings here; gemma3 does it inside embed_tokens.
            hidden_states = hidden_states * (self.config.hidden_size**0.5)

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)
        rotary_pos_emb_local = None
        if self.model_type == 'gemma3_text':
            cos_local, sin_local = self.rotary_emb_local(hidden_states, position_ids)
            cos_local, sin_local = cos_local[0], sin_local[0]
            rotary_pos_emb_local = (cos_local, sin_local)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_emb_local=rotary_pos_emb_local,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
                global_attn_masks=global_attn_masks,
                local_attn_masks=local_attn_masks,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class GemmaForCausalLM(nn.Module, CudaGraphMixin):
    """Gemma causal LM: :class:`GemmaModel` plus a (tied) LM head.

    Logits are optionally soft-capped with ``config.final_logit_softcapping``.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = GemmaModel(config, dtype=dtype, device=device)
        # build lm_head (its weight is replaced by the embedding weight in update_weights)
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)
        self.final_logit_softcapping = getattr(config, 'final_logit_softcapping', None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        global_attn_masks: torch.Tensor = None,
        local_attn_masks: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns hidden states (logits come from :meth:`get_logits`).

        Extra ``kwargs`` are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            global_attn_masks=global_attn_masks,
            local_attn_masks=local_attn_masks,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits, applying ``cap * tanh(logits / cap)`` soft-capping when configured."""
        logits = self.lm_head(hidden_states)
        if self.final_logit_softcapping is not None:
            logits = logits / self.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.final_logit_softcapping
        return logits

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build forward kwargs from the step context.

        Vision embeddings carried by the context, if any, are written into the
        token embeddings at ``context.input_embedding_indexing``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def update_weights(self):
        """Tie the LM head to the input embedding weight."""
        self.lm_head.weight = self.model.embed_tokens.weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        Rotary caches and ``lm_head`` entries are skipped (the head is tied in
        :meth:`update_weights`). q/k/v and gate/up shards are loaded into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. Norm weights are
        loaded as ``weight + 1``, presumably because the checkpoint stores
        Gemma norm gains as offsets from 1 while the RMSNorm used here
        multiplies by the weight directly.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]
        norm_layers = [
            '.norm', '.input_layernorm', '.post_attention_layernorm', 'pre_feedforward_layernorm',
            'post_feedforward_layernorm', 'q_norm', 'k_norm'
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if 'lm_head' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if any(norm_name in name for norm_name in norm_layers):
                    # Out of place: never mutate the caller's tensor (it may be
                    # shared, re-used for a later reload, or non-writable).
                    loaded_weight = loaded_weight + 1
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/gemma3_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import build_model_from_hf_config
from .siglip import SiglipVisionModel
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin


class Gemma3RMSNorm(nn.Module):
    """RMS normalization with a zero-initialized gain, as used by Gemma3.

    The learned ``weight`` is an offset from one: the applied scale is
    ``1 + weight``. Normalization and scaling run in float32, and the result
    is cast back to the input dtype only at the end.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        gain = 1.0 + self.weight.float()
        # Scale before downcasting (Gemma3 order); Llama downcasts first and then
        # multiplies. See https://github.com/huggingface/transformers/pull/29402
        scaled = self._norm(x.float()) * gain
        return scaled.type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f'{shape}, eps={self.eps}'


class Gemma3MultiModalProjector(nn.Module):
    """Pool the vision tower's patch grid and project it into the text embedding space.

    The square patch grid is average-pooled down to ``tokens_per_side`` per
    side, RMS-normalized, and multiplied by ``mm_input_projection_weight``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        vision_cfg = config.vision_config
        proj_shape = (vision_cfg.hidden_size, config.text_config.hidden_size)
        self.mm_input_projection_weight = nn.Parameter(torch.zeros(*proj_shape, dtype=dtype, device=device))

        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_cfg.hidden_size, eps=vision_cfg.layer_norm_eps)

        # Patches per side of the input image, and output tokens per side after pooling.
        self.patches_per_image = int(vision_cfg.image_size // vision_cfg.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        pool = self.patches_per_image // self.tokens_per_side
        self.kernel_size = pool
        self.avg_pool = nn.AvgPool2d(kernel_size=pool, stride=pool)

    def forward(self, vision_outputs: torch.Tensor):
        """Project ``vision_outputs`` (images, patches, channels) to text-space tokens."""
        num_images, _, channels = vision_outputs.shape
        side = self.patches_per_image

        # (images, patches, channels) -> (images, channels, side, side) for 2-D pooling.
        grid = vision_outputs.transpose(1, 2).reshape(num_images, channels, side, side).contiguous()

        # Pool, then back to (images, pooled_tokens, channels).
        pooled = self.avg_pool(grid).flatten(2).transpose(1, 2)

        projected = torch.matmul(self.mm_soft_emb_norm(pooled), self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)


class Gemma3VLInputProcessor(BaseModelInputProcessor):
    """Gemma3 VL input processor.

    Turns each preprocessed image dict into a :class:`MultiModalData` that
    spans the image's placeholder tokens in ``input_ids``.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

        vision_config = config.vision_config
        self.image_size = vision_config.image_size
        self.patch_size = vision_config.patch_size
        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1
        # NOTE(review): the projector pools each side by
        # patches_per_side // sqrt(mm_tokens_per_image), so the real per-image
        # token count is config.mm_tokens_per_image, not num_patches // 4.
        # This attribute is not used in this class — confirm before relying on it.
        self.vision_token_num = self.num_patches // 4

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: prompt token ids.
            input_multimodals: one dict per image with keys ``pixel_values``,
                ``offset`` (index of the first placeholder token),
                ``image_token_id`` and ``image_tokens`` (placeholder count,
                int or 1-element tensor).

        Returns:
            PreprocessInputResult with ``input_multimodals=dict(image=[...])``.
            When there are no multimodal inputs, the inputs are returned
            unchanged, still wrapped in a PreprocessInputResult so callers
            always get the declared type.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


class Gemma3ForConditionalGeneration(nn.Module, CudaGraphMixin, DeployModelMixin):
    """Gemma3 vision-language model.

    A SigLIP vision tower plus :class:`Gemma3MultiModalProjector` produce image
    features that replace the image placeholder tokens. The result is fed to
    the text model built from ``config.text_config``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        text_config = config.text_config
        self.sliding_window = text_config.sliding_window
        self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device)
        self.vision_tower = SiglipVisionModel(config=config.vision_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
        self.multi_modal_projector = Gemma3MultiModalProjector(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
        self.input_processor = Gemma3VLInputProcessor(self.config, dtype=dtype)

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.language_model.get_input_embeddings()

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.language_model.get_logits(hidden_states)

    def get_image_features(self, pixel_values: torch.Tensor):
        """Project the vision tower's output into the language model's embedding space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_tower(pixel_values=pixel_values)
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.FloatTensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        vision_embedding_indexing: torch.Tensor = None,
        text_embedding_indexing: torch.Tensor = None,
        **kwargs,
    ):
        """Run the multimodal forward pass and return the language model's hidden states.

        Args:
            input_ids: token ids. The attention-mask path uses ``input_ids[0]``,
                so a leading batch dimension is expected.
            position_ids: position ids, indexed the same way as ``input_ids``.
            past_key_values: per-layer KV caches, forwarded to the language model.
            attn_metadata: attention metadata, forwarded to the language model.
            pixel_values: preprocessed images, or ``None`` for text-only steps.
            image_mask: boolean mask over ``input_ids`` marking image
                placeholder tokens. Required when ``pixel_values`` is given and
                ``inputs_embeds`` is ``None``.
            inputs_embeds: precomputed input embeddings. When given, image
                features are not computed here.
            vision_embedding_indexing: accepted for interface compatibility; unused.
            text_embedding_indexing: accepted for interface compatibility; unused.
            **kwargs: forwarded to the language model. When images are present,
                attention masks are added (see :meth:`prepare_attn_masks`).
        """

        if inputs_embeds is None and pixel_values is not None:
            # extract feature
            vit_embeds = self.get_image_features(pixel_values)
            lang_embeds = self.get_input_embeddings()(input_ids)
            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)

            inputs_embeds = lang_embeds
        if pixel_values is not None:
            kwargs = self.prepare_attn_masks(input_ids[0], position_ids[0], mask_dtype=pixel_values.dtype, **kwargs)

        hidden_states = self.language_model(input_ids,
                                            position_ids,
                                            inputs_embeds=inputs_embeds,
                                            past_key_values=past_key_values,
                                            attn_metadata=attn_metadata,
                                            **kwargs)

        return hidden_states

    # modified from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py#L539
    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        """Build per-sequence global and sliding-window attention masks for a packed batch.

        Sequences are delimited by positions equal to 0. Each mask has shape
        ``(1, 1, seq_len, seq_len)``. The global mask is causal, except that
        image tokens attend to each other bidirectionally. The local mask also
        blocks keys ``self.sliding_window`` or more positions behind the query.

        Returns:
            ``kwargs`` updated with ``has_images``, ``seq_lens``,
            ``global_attn_masks`` and ``local_attn_masks``.
        """
        kwargs['has_images'] = True
        # Each packed sequence starts where its position id resets to 0.
        start_indices = (positions == 0).cpu().nonzero().flatten().tolist()
        end_indices = start_indices[1:] + [len(input_ids)]
        seq_lens = [end - start for start, end in zip(start_indices, end_indices)]
        kwargs['seq_lens'] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx
            # Create a global causal mask.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float('-inf'))
            # Keep -inf strictly above the diagonal; the lower triangle becomes 0.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Allow bidirectional attention between image tokens: entries where
            # both the query and the key are image tokens get a count of 2.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = (input_token_ids == self.config.image_token_index)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Local mask: like the global mask, but keys that are
            # self.sliding_window or more positions behind the query are blocked.
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-self.sliding_window)
            local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float('-inf'))
            local_attn_masks.append(local_attn_mask)
        kwargs['global_attn_masks'] = global_attn_masks
        kwargs['local_attn_masks'] = local_attn_masks
        return kwargs

    def prepare_inputs_for_generation(
        self,
        past_key_values=None,
        inputs_embeds=None,
        context: StepContext = None,
        **kwargs,
    ):
        """Build forward kwargs: the language model's inputs plus ``pixel_values`` / ``image_mask``."""
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = self.language_model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
            **kwargs,
        )

        # vision inputs
        pixel_values = None
        image_mask = None
        if context.input_multimodals is not None:
            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            pixel_values = [data for im_data in pixel_values for data in im_data]
            if len(pixel_values) > 0:
                image_token_id = pixel_values[0].meta['image_token_id']
                image_mask = model_inputs['input_ids'] == image_token_id
                pixel_values = torch.cat([data.data for data in pixel_values])
            else:
                pixel_values = None
                image_mask = None
        model_inputs['image_mask'] = image_mask
        model_inputs['pixel_values'] = pixel_values
        return model_inputs

    def tie_weights(self):
        """Tie the language model's output head to its input embeddings.

        Delegates to ``language_model.tie_weights`` when it exists. Otherwise
        it falls back to ``language_model.update_weights``, which is where the
        Gemma causal LM in this codebase ties ``lm_head`` to the embedding.
        Presumably ``language_model`` is that class — confirm for other text
        configs.
        """
        tie_fn = getattr(self.language_model, 'tie_weights', None)
        if tie_fn is None:
            tie_fn = self.language_model.update_weights
        return tie_fn()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        ``language_model.``-prefixed entries are stripped of the prefix and
        handed to the language model's own loader. All other (vision tower /
        projector) entries are loaded here, with q/k/v shards routed into the
        fused ``qkv_proj``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        lang_prefix = 'language_model.'
        lang_prefix_length = len(lang_prefix)
        new_weights = dict()
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name.startswith(lang_prefix):
                new_key = name[lang_prefix_length:]
                new_weights[new_key] = loaded_weight
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

        self.language_model.load_weights(new_weights.items())

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


================================================
FILE: lmdeploy/pytorch/models/glm4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class Glm4Attention(nn.Module):
    """GLM4 self-attention with a fused QKV projection and partial rotary embedding.

    Rotary embedding is applied only to the first half of each head's channels
    (see ``_extract_rope`` / ``_fill_rope``); the second half passes through
    unchanged.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        # Falls back to hidden_size // num_heads when the config has no explicit head_dim.
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        # NOTE(review): presumably the replication count for KV heads (e.g. when
        # tensor parallelism exceeds the KV head count); defaults to 1 — confirm
        # against build_qkv_proj.
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)

        # packed qkv
        self.qkv_proj = build_qkv_proj(hidden_size,
                                       num_q_heads=num_heads,
                                       num_kv_heads=num_key_value_heads,
                                       head_size=head_dim,
                                       bias=config.attention_bias,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=num_replicate_kv_heads)

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(num_heads, head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim)

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=False,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    @staticmethod
    def _extract_rope(states: torch.Tensor):
        """Return a contiguous copy of the rotary half of ``states`` in split-half order.

        Takes the first half of the last dimension and reorders interleaved
        pairs ``(x0, x1, x2, x3, ...)`` into ``(x0, x2, ..., x1, x3, ...)``.
        Presumably this is the layout ``ApplyRotaryEmb`` expects — confirm.
        """
        rope = states.chunk(2, -1)[0]
        # (..., d) -> (..., d/2, 2) -> (..., 2, d/2) -> (..., d): de-interleave pairs.
        rope = rope.unflatten(-1, (-1, 2))
        rope = rope.transpose(-2, -1).flatten(-2, -1).contiguous()
        return rope

    @staticmethod
    def _fill_rope(states: torch.Tensor, rope: torch.Tensor):
        """Inverse of ``_extract_rope``: write ``rope`` back into the first half of ``states`` in place.

        Re-interleaves ``rope`` and copies it into a view of ``states``, so
        ``states`` itself is modified. Returns ``states``.
        """
        # chunk() returns a view, so copy_ below writes into `states`.
        rope_part = states.chunk(2, -1)[0]
        # (..., d) -> (..., 2, d/2) -> (..., d/2, 2) -> (..., d): re-interleave pairs.
        rope = rope.unflatten(-1, (2, -1))
        rope = rope.transpose(-2, -1).flatten(-2, -1)
        rope_part.copy_(rope)
        return states

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Project to q/k/v, apply partial rotary embedding, attend, and project out.

        Args:
            hidden_states: layer input.
            rotary_pos_emb: ``(cos, sin)`` pair for the rotary half of each head.
            past_key_value: ``(k_cache, v_cache)``, optionally followed by
                ``(k_scales_zeros, v_scales_zeros)`` — presumably for a
                quantized KV cache; confirm.
            attn_metadata: attention metadata passed to the attention kernel.
        """
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply rotary embedding
        # chatglm series, glm4-0414 have special treatments for rope
        cos, sin = rotary_pos_emb
        q_rope = self._extract_rope(query_states)
        k_rope = self._extract_rope(key_states)
        q_rope, k_rope = self.apply_rotary_pos_emb(
            q_rope,
            k_rope,
            cos,
            sin,
            inplace=True,
        )
        # Write the rotated halves back into query/key states in place.
        query_states = self._fill_rope(query_states, q_rope)
        key_states = self._fill_rope(key_states, k_rope)

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        # Merge heads back to the input's leading shape.
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class Glm4MLP(nn.Module):
    """GLM4 feed-forward block.

    A fused gate/up projection is followed by the ``SiluAndMul`` activation
    and a down projection back to ``hidden_size``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # Fused projection producing both the gate and the up branch.
        self.gate_up_proj = build_gateup_linear(
            hidden,
            [inter, inter],
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

        # Activation on the gate branch combined with the up branch, in place.
        self.act_fn = SiluAndMul(inplace=True)

        # Projection back to the model width.
        self.down_proj = build_down_linear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Glm4DecoderLayer(nn.Module):
    """One GLM4 decoder layer with four RMSNorms.

    Besides the pre-attention (``input_layernorm``) and pre-MLP
    (``post_attention_layernorm``) norms, GLM4 also normalizes the attention
    output (``post_self_attn_layernorm``) and the MLP output
    (``post_mlp_layernorm``). The layer returns ``(mlp_output, residual)``
    without adding them; the caller feeds both into the next fused norm call.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = Glm4Attention(config, dtype=dtype, device=device)
        self.mlp = Glm4MLP(config, dtype=dtype, device=device)

        def _new_norm():
            # every norm in the layer shares width, eps and quantization settings
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        # registration order kept stable so parameter iteration order is unchanged
        self.input_layernorm = _new_norm()
        self.post_attention_layernorm = _new_norm()
        self.post_self_attn_layernorm = _new_norm()
        self.post_mlp_layernorm = _new_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        # pre-attention norm; the two-argument form also threads the residual stream
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )
        attn_out = self.post_self_attn_layernorm(attn_out)

        # combine with the residual stream and normalize before the MLP
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.post_mlp_layernorm(self.mlp(mlp_in))

        return (mlp_out, residual)


class Glm4Model(nn.Module):
    """GLM4 decoder-only backbone: token embedding, decoder stack and final
    norm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            Glm4DecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def get_input_embeddings(self):
        """Get input embeddings.

        ``Glm4ForCausalLM.get_input_embeddings`` delegates to this accessor, so it
        must exist on the backbone.
        """
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward.

        Args:
            input_ids: token ids; ignored when ``inputs_embeds`` is given.
            position_ids: positions fed to the rotary embedding.
            past_key_values: per-layer cache entries, indexed by layer position.
            attn_metadata: attention metadata forwarded to every layer.
            inputs_embeds: precomputed input embeddings (e.g. with vision features merged in).

        Returns:
            The final-normed hidden states.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # decoding; each layer returns its output and the residual separately,
        # the next fused norm call consumes both
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm (also folds in the last layer's pending residual)
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class Glm4ForCausalLM(nn.Module, CudaGraphMixin):
    """ModelForCausalLM.

    ``Glm4Model`` backbone plus a row-wise linear LM head.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = Glm4Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states (logits come from
        ``get_logits``)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config asks for
        it."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Get input embeddings.

        Returns the backbone's token embedding module directly, so this works
        regardless of which accessors the backbone class exposes.
        """
        return self.model.embed_tokens

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input."""
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # scatter vision embeddings (if any) into the token embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Separate q/k/v checkpoint tensors are loaded as shards of the fused
        ``qkv_proj``; the packed ``gate_up_proj`` tensor is split into its gate
        and up shards. Everything else is loaded as-is.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # rotary caches are recomputed, never loaded
            if 'rotary_emb.inv_freq' in name:
                continue
            if 'rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name:
                continue
            # tied head shares the embedding weight (see update_weights)
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                if '.gate_up_proj' in name:
                    # GLM4 gate up proj weights are packed in a single tensor
                    gate, up = param.weight_spliter(loaded_weight)
                    load_weight(param, gate, shard_id=0)
                    load_weight(param, up, shard_id=1)
                else:
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/glm4_1v.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .glm4 import Glm4DecoderLayer
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model


def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],
                           position_ids: torch.Tensor, rotary_emb_func: Callable):
    """Build multimodal-RoPE ``cos``/``sin`` tables.

    ``mrope_position_ids`` (3 rows, possibly shorter than the sequence) is
    zero-padded to the length of ``position_ids`` and passed through
    ``rotary_emb_func``, which yields one table per position axis. The rotary
    dimension is then cut into bands following ``mrope_section`` (repeated once
    to cover both halves), and band ``i`` is taken from axis ``i % 3``.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, rotary_dim)``.
    """
    seq_len = position_ids.shape[-1]
    padded_ids = position_ids.new_zeros((3, seq_len))
    padded_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids
    cos, sin = rotary_emb_func(hidden_states, padded_ids)

    bands = mrope_section * 2

    def _interleave(table: torch.Tensor) -> torch.Tensor:
        pieces = table.split(bands, dim=-1)
        return torch.cat([piece[axis % 3] for axis, piece in enumerate(pieces)], dim=-1)

    return _interleave(cos), _interleave(sin)


class Glm4vTextModel(nn.Module):
    """GLM-4V language backbone: GLM4 decoder layers driven by multimodal
    RoPE."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.mrope_section = config.rope_scaling['mrope_section']

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            Glm4DecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mrope_position_ids: torch.LongTensor = None,
    ):
        """Rewrite of LlamaModel.forward.

        When ``mrope_position_ids`` is given, cos/sin are assembled per band
        from the three position axes (see ``_apply_mrope_selection``);
        otherwise plain 1D rotary positions are used.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        if mrope_position_ids is None:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
            cos, sin = cos[0], sin[0]
        else:
            cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids,
                                              self.rotary_emb)
        rotary_pos_emb = (cos, sin)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm (also folds in the last layer's pending residual)
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class Glm4VisionMLP(nn.Module):
    """Vision MLP.

    Gated feed-forward of a vision block: merged gate/up projection
    (``hidden_size -> 2 * out_hidden_size``), ``SiluAndMul``, then a row-wise
    down projection back to ``hidden_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 bias: bool = False,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        width = config.hidden_size
        inner = config.out_hidden_size

        self.gate_up_proj = build_merged_colwise_linear(
            in_features=width,
            all_out_features=[inner, inner],
            bias=bias,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(in_features=inner,
                                              out_features=width,
                                              bias=bias,
                                              quant_config=quant_cfg,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=True)

    def forward(self, x):
        """forward."""
        gated = self.act_fn(self.gate_up_proj(x))
        return self.down_proj(gated)


class Glm4vVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a non-overlapping Conv3d.

    Each input row holds one ``(in_channels, temporal_patch_size, patch_size,
    patch_size)`` patch; the kernel and stride both equal that patch shape, so
    every patch maps to a single ``hidden_size`` vector.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        patch_kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels,
                              self.embed_dim,
                              kernel_size=patch_kernel,
                              stride=patch_kernel,
                              dtype=dtype,
                              device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        patches = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,
                                     self.patch_size)
        # cast to the conv weight dtype before projecting
        patches = patches.to(dtype=self.proj.weight.dtype)
        embedded = self.proj(patches)
        return embedded.view(-1, self.embed_dim)


class Glm4vVisionRotaryEmbedding(nn.Module):
    """Vision rotary embedding.

    Holds ``dim // 2`` inverse frequencies (non-persistent buffer) and returns
    the ``(seqlen, dim // 2)`` table of ``position * inv_freq`` angles.
    """

    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim
        self.register_buffer('inv_freq', 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        freq = self.inv_freq
        positions = torch.arange(seqlen, device=freq.device, dtype=freq.dtype)
        return torch.einsum('i,j->ij', positions, freq)


class Glm4vVisionPatchMerger(nn.Module):
    """Final projection of merged vision patches into the text hidden space.

    ``Linear -> LayerNorm -> GELU`` followed by a gated MLP
    (merged gate/up, ``SiluAndMul``, down projection).

    NOTE(review): ``hidden_act`` is accepted but never used; the gated
    activation is always ``SiluAndMul``.
    """

    def __init__(self,
                 dim: int,
                 context_dim: int,
                 hidden_act: str,
                 bias: bool = False,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__()
        factory = dict(dtype=dtype, device=device)

        self.proj = nn.Linear(dim, dim, bias=bias, **factory)
        self.post_projection_norm = nn.LayerNorm(dim, **factory)

        self.gate_up_proj = build_merged_colwise_linear(
            in_features=dim,
            all_out_features=[context_dim, context_dim],
            bias=bias,
            is_tp=True,
            **factory,
        )

        self.down_proj = build_rowwise_linear(in_features=context_dim,
                                              out_features=dim,
                                              bias=bias,
                                              is_tp=True,
                                              **factory)

        self.act1 = nn.GELU()
        self.act_fn = SiluAndMul(inplace=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        normed = self.post_projection_norm(self.proj(hidden_state))
        activated = self.act1(normed)
        gated = self.act_fn(self.gate_up_proj(activated))
        return self.down_proj(gated)


class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D position embeddings for the GLM-4V vision tower.

    The table is stored as a square ``(image_size // patch_size) ** 2`` grid and
    is resampled with bicubic ``grid_sample`` to each image's actual patch grid.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim, dtype=dtype, device=device)
        self.register_buffer('position_ids', torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
        """Forward pass with integrated position encoding adaptation using 2D
        interpolation.

        Args:
            embeddings: Input embeddings tensor
            lengths (torch.Tensor | list): Sequence lengths for each image in the batch.
            image_shapes (torch.Tensor | list): Shape [batch_size, 3] holding each image's (t, h, w).
            h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
            w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.

        Returns:
            torch.Tensor: Embeddings with adapted position encoding added.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        if total_seq == 0:
            adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
            return embeddings + adapted_pos_embed

        lengths = torch.as_tensor(lengths, device=device, dtype=torch.long)
        image_shapes = torch.as_tensor(image_shapes, device=device, dtype=torch.long)
        num_images = lengths.shape[0]

        # (1, hidden, side, side) float32 grid of the learned table
        orig_size = int(pos_embed_weight.shape[0]**0.5)
        pos_embed_2d = (pos_embed_weight.view(orig_size, orig_size,
                                              hidden_size).permute(2, 0, 1).unsqueeze(0).to(device=device,
                                                                                            dtype=torch.float32))

        # per-patch target grid size: each image's (h, w) repeated once per patch of that
        # image, done in one vectorized call instead of a per-image Python loop
        target_h = torch.repeat_interleave(image_shapes[:num_images, 1], lengths).to(torch.float32)
        target_w = torch.repeat_interleave(image_shapes[:num_images, 2], lengths).to(torch.float32)

        # normalize patch-center coordinates to [-1, 1] for grid_sample (align_corners=False)
        h_coords = h_coords.to(device=device, dtype=torch.float32)
        w_coords = w_coords.to(device=device, dtype=torch.float32)
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # grid_sample expects (x, y) = (w, h); shape (1, total_seq, 1, 2)
        grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

        interpolated_embed_fp32 = F.grid_sample(pos_embed_2d,
                                                grid,
                                                mode='bicubic',
                                                align_corners=False,
                                                padding_mode='border')

        # (1, hidden, total_seq, 1) -> (total_seq, hidden), back to the table dtype
        adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
        adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)

        return embeddings + adapted_pos_embed


class Glm4vVisionAttention(nn.Module):
    """Vision attention.

    Packed-QKV, non-causal varlen attention over the vision tokens of all
    images; image boundaries come from ``cu_seqlens``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        width = config.hidden_size
        heads = config.num_heads
        self.head_dim = width // heads

        # fused q/k/v projection (same head count for q and kv)
        self.qkv = build_qkv_proj(
            width,
            num_q_heads=heads,
            num_kv_heads=heads,
            head_size=self.head_dim,
            bias=config.attention_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # bidirectional attention inside each image
        self.attention = FlashAttention(
            heads,
            self.head_dim,
            causal=False,
        )

        self.proj = build_rowwise_linear(width,
                                         width,
                                         bias=True,
                                         quant_config=quant_cfg,
                                         dtype=dtype,
                                         device=device,
                                         is_tp=True)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:
        num_tokens = hidden_states.shape[0]

        packed = self.qkv(hidden_states).flatten(0, -2)
        query, key, value = self.qkv.split_qkv(packed)

        cos, sin = rotary_pos_emb
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

        # per-image start offsets and lengths derived from cumulative lengths
        starts = cu_seqlens[:-1]
        seqlens = cu_seqlens[1:] - starts
        attended = self.attention(
            query,
            key,
            value,
            q_start_loc=starts,
            q_seqlens=seqlens,
        )

        return self.proj(attended.reshape(num_tokens, -1))


class Glm4vVisionBlock(nn.Module):
    """Pre-norm vision transformer block (RMSNorm -> attention, RMSNorm ->
    MLP).

    Returns ``(mlp_output, residual)`` without adding them; the next block's
    ``norm1`` (or the caller) consumes both.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:
        super().__init__()
        width, eps = config.hidden_size, config.rms_norm_eps
        self.norm1 = RMSNorm(width, eps=eps, dtype=dtype, device=device)
        self.norm2 = RMSNorm(width, eps=eps, dtype=dtype, device=device)
        self.attn = Glm4vVisionAttention(config, dtype=dtype, device=device)
        self.mlp = Glm4VisionMLP(config, bias=False, dtype=dtype, device=device)

    def forward(self,
                hidden_states,
                cu_seqlens,
                rotary_pos_emb,
                residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        if residual is not None:
            hidden_states, residual = self.norm1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm1(hidden_states)

        attn_out = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        mlp_in, residual = self.norm2(attn_out, residual)
        return self.mlp(mlp_in), residual


class Glm4vVisionModel(nn.Module):
    """Vision transformer.

    Pipeline in ``forward``: Conv3d patch embedding -> RMSNorm -> interpolated
    2D position embedding -> ``config.depth`` vision blocks -> RMSNorm ->
    Conv2d downsampling over ``spatial_merge_size x spatial_merge_size``
    windows -> patch merger.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size

        self.embeddings = Glm4vVisionEmbeddings(config, dtype=dtype, device=device)
        self.patch_embed = Glm4vVisionPatchEmbed(config, dtype=dtype, device=device)

        # rot_pos_emb concatenates h- and w-frequencies per token, so the table is
        # built for half the head dim
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2, device=device)

        self.blocks = nn.ModuleList([Glm4vVisionBlock(config, dtype=dtype, device=device) for _ in range(config.depth)])
        self.merger = Glm4vVisionPatchMerger(dim=config.out_hidden_size,
                                             context_dim=config.intermediate_size,
                                             hidden_act=config.hidden_act,
                                             dtype=dtype,
                                             device=device)

        self.post_conv_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)
        # kernel == stride == spatial_merge_size: one output vector per merge window
        self.downsample = nn.Conv2d(
            in_channels=config.hidden_size,
            out_channels=config.out_hidden_size,
            kernel_size=config.spatial_merge_size,
            stride=config.spatial_merge_size,
            dtype=dtype,
            device=device,
        )
        self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotary position embedding.

        Args:
            grid_thw: one ``(t, h, w)`` row per image/video.

        Returns:
            ``(rotary_pos_emb, pos_ids)``: ``pos_ids`` holds an ``(h, w)`` pair per
            patch; ``rotary_pos_emb`` concatenates the h- and w-frequencies of
            each pair.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # row index of every patch, reordered so that the patches of each
            # spatial_merge_size x spatial_merge_size window are contiguous
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # column index of every patch, same window-major ordering
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # the same spatial ids are reused for each of the t temporal slices
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # (num_patches, 2, n_freq) -> (num_patches, 2 * n_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,
                grid_thw: torch.Tensor, image_type_ids: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            hidden_states: flattened pixel patches (viewed by ``patch_embed``).
            cu_seqlens: cumulative per-image patch counts WITHOUT the leading 0
                (it is prepended here) — presumably built by the caller; verify.
            rotary_pos_emb: forwarded unchanged to every block, where the attention
                unpacks it as ``cos, sin``.
            grid_thw: ``(t, h, w)`` per image, used as target grid sizes for the
                interpolated position embedding.
            image_type_ids: per-patch ``(h, w)`` coordinates, indexed as
                ``[:, 0]`` / ``[:, 1]`` — presumably the ``pos_ids`` from
                ``rot_pos_emb``; confirm with the caller.

        Returns:
            One ``out_hidden_size`` vector per spatial merge window, after the merger.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = self.post_conv_layernorm(hidden_states)

        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])

        residual = None
        for blk in self.blocks:
            hidden_states, residual = blk(hidden_states,
                                          cu_seqlens=cu_seqlens,
                                          rotary_pos_emb=rotary_pos_emb,
                                          residual=residual)

        # blocks return (output, residual) unsummed; fold in the final residual
        hidden_states = hidden_states + residual

        hidden_states = self.post_layernorm(hidden_states)

        # group consecutive patches into merge windows: (N, m, m, C) -> (N, C, m, m)
        hidden_states = hidden_states.view(-1, self.spatial_merge_size, self.spatial_merge_size,
                                           hidden_states.shape[-1])
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)

        hidden_states = self.merger(hidden_states)
        return hidden_states


@vlm_model
class Glm4vForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
    """GLM-4V causal LM: vision tower, text model and ``lm_head``.

    Image features from ``self.visual`` are scattered into the token
    embeddings at image-token positions before the language model runs.
    Three-row mrope position ids are derived per step and carried across
    steps through the per-sequence ``model_metas`` (``mrope_delta``).
    """
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # preprocessor
        self.input_processor = Glm4vInputProcessor(self.config)

        # build vision model
        self.visual = Glm4vVisionModel(config.vision_config, dtype=dtype, device=device)

        # build language model
        self.language_model = Glm4vTextModel(config, dtype=dtype, device=device)

        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        mrope_position_ids: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: Tuple[torch.Tensor, torch.Tensor] = None,
        image_type_ids: torch.Tensor = None,
        grid_thw: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Run the model and return the language model's hidden states.

        Logits are produced separately by ``get_logits``. When
        ``inputs_embeds`` is not supplied, tokens are embedded here and, if
        ``pixel_values`` is given, the vision tower output overwrites the
        embeddings at the positions selected by ``image_mask``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))
                image_embeds = self.visual(pixel_values,
                                           cu_seqlens=vis_cu_seqlens,
                                           rotary_pos_emb=vis_pos_emb,
                                           image_type_ids=image_type_ids,
                                           grid_thw=grid_thw)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)

        hidden_states = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config asks for it."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.language_model.embed_tokens.weight

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.language_model.embed_tokens

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from the step context.

        Collects all images of the batch (if any) into one ``pixel_values``
        tensor together with the vision rotary embedding, cumulative patch
        lengths and the image-token mask.
        """

        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_type_ids = None
        image_mask = None
        grid_thw = None
        if context.input_multimodals is not None:
            # Flatten the per-sequence image lists into one batch-level list.
            # Sequences without multimodal inputs (None entries) contribute nothing.
            image_data = [
                data for input_mm in context.input_multimodals if input_mm is not None
                for data in input_mm.get('image', [])
            ]
            # Check emptiness AFTER flattening: a batch whose sequences carry no
            # images would otherwise reach torch.cat([]) / image_data[0] and fail.
            if len(image_data) > 0:
                pixel_values = torch.cat([data.data for data in image_data])
                image_token_id = image_data[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()
                vis_pos_emb, image_type_ids = self.visual.rot_pos_emb(grid_thw)
                # one entry per frame: h * w patches, repeated t times per image
                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                         grid_thw[:, 0]).to(pixel_values.device)
                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                vis_pos_emb = vis_pos_emb.repeat(1, 2)
                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

        mrope_position_ids = getattr(context, 'mrope_position_ids', None)

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_type_ids=image_type_ids,
            grid_thw=grid_thw,
            image_mask=image_mask,
        )

    @classmethod
    def rename_weight(cls, name: str) -> str:
        """Map HF checkpoint names (``model.*``) onto this module's attribute names."""
        if name.startswith('model.language_model.'):
            return 'language_model.' + name[len('model.language_model.'):]
        elif name.startswith('model.visual.'):
            return 'visual.' + name[len('model.visual.'):]
        elif name.startswith('model.'):
            return name[len('model.'):]
        return name

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, packing q/k/v and gate/up shards as needed."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    # already-packed qkv tensor: split it into its q/k/v shards
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                elif '.gate_up_proj' in name:
                    # already-packed gate_up tensor: split into gate and up shards
                    param = params_dict[name]
                    gate, up = param.weight_spliter(loaded_weight)
                    load_weight(param, gate, shard_id=0)
                    load_weight(param, up, shard_id=1)
                    continue
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs."""
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs."""

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        input_ids = kwargs.get('input_ids')
        num_tokens = input_ids.size(-1)
        new_batch_size = graph_meta.max_batchs

        is_decoding = graph_meta.is_decoding
        input_buffers = graph_meta.input_buffers
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
            if is_decoding:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
            else:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']

        return new_inputs

    def _get_model_metas(self, context: StepContext):
        """Return one model-meta dict per sequence, defaulting to ``mrope_delta=0``."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # Build a distinct dict per sequence; ``[d] * n`` would alias one dict
            # across the whole batch so an in-place edit would leak everywhere.
            return [dict(mrope_delta=0) for _ in range(batch_size)]
        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]

    def _update_model_meta_decoding(self, context: StepContext):
        """Shift decode positions by each sequence's ``mrope_delta`` (same for all 3 rows)."""
        model_metas = self._get_model_metas(context)
        position_ids = context.position_ids

        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
        mrope_deltas = position_ids.new_tensor(mrope_deltas)
        mrope_position_ids = position_ids + mrope_deltas[None]
        mrope_position_ids = mrope_position_ids.expand(3, -1)

        context.mrope_position_ids = mrope_position_ids
        return model_metas

    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
        """Return a (3, t*h*w) tensor of (t, h, w) indices for the merged image grid.

        ``h`` and ``w`` are halved first. NOTE(review): the factor 2 is assumed
        to equal the vision ``spatial_merge_size`` — confirm.
        """
        t, h, w = grid_thw
        h //= 2
        w //= 2
        stride = torch.tensor([h * w, w, 1], device=device)[:, None]
        size = torch.tensor([t, h, w], device=device)[:, None]
        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
        pos_ids = pos_ids // stride % size
        return pos_ids

    def _update_model_meta_prefilling(self, context: StepContext):
        """Build prefill mrope ids and the updated ``mrope_delta`` per sequence."""
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_multimodals = [None] * len(model_metas)
        position_ids = context.position_ids
        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
        mrope_position_ids = []
        new_model_metas = []
        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):
            images = []
            if input_mm is not None:
                images = input_mm.get('image', [])
            if model_meta is None or 'mrope_delta' not in model_meta:
                mrope_delta = 0
            else:
                mrope_delta = model_meta['mrope_delta']

            pos_start = pos_ids[0].item()
            mrope_pos_ids = pos_ids + mrope_delta
            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
            for img in images:
                grid_thw = img.meta['grid_thw'][0].tolist()
                _, h, w = grid_thw
                h //= 2
                w //= 2
                # an image of N tokens only advances the text position by max(h, w),
                # so every later token is pulled back by the difference
                num_pad = img.end - img.start - max(h, w)
                mrope_delta -= num_pad
                fill_start = img.start - pos_start
                fill_end = img.end - pos_start
                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)
                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
                mrope_pos_ids[:, fill_end:] -= num_pad
                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids

            mrope_position_ids.append(mrope_pos_ids)
            new_model_metas.append(dict(mrope_delta=mrope_delta))

        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
        context.mrope_position_ids = mrope_position_ids

        return new_model_metas

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Update model meta (decode vs. prefill path)."""
        if context.is_decoding:
            return self._update_model_meta_decoding(context)
        else:
            return self._update_model_meta_prefilling(context)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Glm4vInputProcessor(BaseModelInputProcessor):
    """Glm4v input processor: wraps per-image preprocessor outputs as ``MultiModalData``."""

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = config

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Each entry of ``input_multimodals`` is read for the keys
        ``pixel_values``, ``image_grid_thw``, ``offset``, ``image_token_id``
        and ``image_tokens``; it becomes one ``MultiModalData`` spanning
        ``[offset, offset + image_tokens)``.

        Returns:
            ``PreprocessInputResult`` in every case. With no multimodal input,
            ``input_multimodals`` is ``None`` (previously a bare tuple was
            returned here, contradicting the declared return type).
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            return PreprocessInputResult(input_ids=input_ids, input_multimodals=None)

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values']
            image_grid_thw = input_mm['image_grid_thw']
            start = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=start,
                                     end=start + num_pad,
                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/glm4_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class Glm4MoeAttention(nn.Module):
    """Rewrite module of Qwen3MoeAttention."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        self.use_qk_norm = config.use_qk_norm

        # packed qkv
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=config.attention_bias,
            quant_config=quantization_config,
            num_replicate_kv_heads=num_replicate_kv_heads,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
        )

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(num_heads, head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim)

        # o_proj
        self.o_proj = build_rowwise_linear(num_heads * head_dim,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quantization_config,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=is_tp)

        # q, k norm
        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim,
                                  config.rms_norm_eps,
                                  quant_config=quantization_config,
                                  dtype=dtype,
                                  device=device)
            self.k_norm = RMSNorm(head_dim,
                                  config.rms_norm_eps,
                                  quant_config=quantization_config,
                                  dtype=dtype,
                                  device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward."""
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply q, k norm
        if self.use_qk_norm:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class Glm4MoeMLP(nn.Module):
    """mlp."""

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        if intermediate_size is None:
            intermediate_size = config.intermediate_size

        # gate up
        self.gate_up_proj = build_merged_colwise_linear(
            config.hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=is_tp,
        )

        # silu and mul
        self.act_fn = SiluAndMul(inplace=True)

        # down
        self.down_proj = build_rowwise_linear(intermediate_size,
                                              config.hidden_size,
                                              bias=False,
                                              quant_config=quantization_config,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=is_tp,
                                              all_reduce=all_reduce)

    def forward(self, x):
        """forward."""
        gate_up = self.gate_up_proj(x)
        act = self.act_fn(gate_up)
        return self.down_proj(act)


class Glm4MoE(nn.Module):
    """Moe block."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.renormalize = self.norm_topk_prob

        self.routed_scaling_factor = config.routed_scaling_factor

        # build gate
        # refers to https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/glm4_moe.py
        # NOTE In the transformers implementation, the gate isn't an nn.Linear,
        # https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
        self.gate = nn.Linear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            dtype=torch.float32,
        )
        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(config.n_routed_experts, dtype=torch.float32))

        self.softmax_topk = SoftmaxTopK(
            self.top_k,
            n_groups=getattr(config, 'router_n_groups', -1),
        )

        # build experts
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            all_reduce=False,
            layer_idx=layer_idx,
        )

        # build shared experts
        intermediate_size = config.moe_intermediate_size * config.n_shared_experts
        self.shared_experts = Glm4MoeMLP(
            config=config,
            intermediate_size=intermediate_size,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=False,
        )

        # get all reduce
        world_size, _ = get_tp_world_rank()
        if world_size > 1:
            self._all_reduce = True
        else:
            self._all_reduce = False

    def forward(self, hidden_states: torch.Tensor):
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # gate
        router_logits = self.gate(hidden_states.to(dtype=torch.float32))
        topk_weights, topk_ids = self.softmax_topk(router_logits)

        # experts
        out_states = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
        )
        out_states = out_states * self.routed_scaling_factor

        # shared experts
        shared_states = self.shared_experts(hidden_states)

        out_states += shared_states
        out_states = out_states.reshape(batch_size, sequence_length, -1)

        if self._all_reduce:
            dist.all_reduce(out_states)
        return out_states


class Glm4MoeDecoderLayer(nn.Module):
    """Decoder layer."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer
        self.self_attn = Glm4MoeAttention(config, dtype=dtype, device=device)

        if layer_idx >= config.first_k_dense_replace:
            self.mlp = Glm4MoE(config, layer_idx=layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = Glm4MoeMLP(config, dtype=dtype, device=device)

        # build input layer norm
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device)

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # self attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        # fully connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states, residual)
        return outputs


class Glm4MoeModel(nn.Module):
    """model."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            Glm4MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        self.rotary_emb = self._build_rotary_embedding(config)

    def _build_rotary_embedding(self, config: PretrainedConfig):
        """Build rotary embedding."""
        return build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward."""

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class Glm4MoeForCausalLM(nn.Module, CudaGraphMixin):
    """GLM-4 MoE causal LM: a ``Glm4MoeModel`` backbone plus a row-wise ``lm_head``."""

    # fused parameter name -> checkpoint sub-module names packed into it
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # build model
        self.model = Glm4MoeModel(
            config=config,
            dtype=dtype,
            device=device,
        )

        # build lm_head
        self.lm_head = build_rowwise_linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        Logits are not computed here; call ``get_logits`` on the result.
        Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from the step ``context``.

        If the context carries input embeddings (e.g. vision features), they are
        scattered into ``inputs_embeds`` at ``context.input_embedding_indexing``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one routed-expert checkpoint tensor.

        Names containing ``fused_w1w3`` / ``fused_w2`` are delegated to
        ``_load_weight_fused_experts``. Otherwise the first matching entry of
        ``expert_params_mapping`` (``(param_name, weight_name, expert_id, shard_id)``)
        renames the weight and loads it into that expert/shard. Names with no
        match are loaded directly by name.
        """
        # load fused weights
        if any(k in name for k in ('fused_w1w3', 'fused_w2')):
            return self._load_weight_fused_experts(name, loaded_weight, params_dict)

        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Split a checkpoint tensor holding all experts stacked along dim 0.

        ``fused_w1w3``: each expert's chunk is ``[gate; up]``, split in half.
        ``fused_w2``: each expert's chunk is its down projection.
        """
        # The expert count comes from ``n_routed_experts``, the same field
        # ``load_weights`` uses to build the expert map. The previous
        # ``self.config.num_experts`` lookup is not defined on GLM4-MoE configs.
        num_experts = self.config.n_routed_experts
        fused_gateup_name = 'fused_w1w3'
        fused_down_name = 'fused_w2'
        if fused_gateup_name in name:
            # every expert writes into the same fused parameter: look it up once
            param = params_dict[name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up')]
            chunk_size = loaded_weight.shape[0] // num_experts
            half = chunk_size // 2

            for expert_id in range(num_experts):
                start = chunk_size * expert_id
                w1 = loaded_weight.narrow(dim=0, start=start, length=half)
                w3 = loaded_weight.narrow(dim=0, start=start + half, length=half)
                load_weight(param, w1, expert_id=expert_id, shard_id='gate')
                load_weight(param, w3, expert_id=expert_id, shard_id='up')

        elif fused_down_name in name:
            param = params_dict[name.replace(f'experts.{fused_down_name}', 'experts.down')]
            chunk_size = loaded_weight.shape[0] // num_experts

            for expert_id in range(num_experts):
                w2 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size)
                load_weight(param, w2, expert_id=expert_id, shard_id='down')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model.

        MTP (NextN) layers (``.layers.{num_hidden_layers + i}``), rotary caches
        and, when embeddings are tied, ``lm_head.weight`` are skipped. Expert
        weights go through ``_load_weight_experts``. q/k/v and gate/up
        projections are packed into ``qkv_proj`` / ``gate_up_proj``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        mtp_param_list = []
        if hasattr(self.config, 'num_nextn_predict_layers'):
            num_hidden_layers = self.config.num_hidden_layers
            num_nextn_predict_layers = self.config.num_nextn_predict_layers
            mtp_param_list = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]

        # expert map: (param_name, weight_name, expert_id, shard_id)
        num_experts = self.config.n_routed_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # skip MTP related weights
            if any(mtp_param_name in name for mtp_param_name in mtp_param_list):
                continue
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/glm4moe_mtp.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Iterable

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .deepseek_mtp import DeepseekMTPModel
from .glm4_moe import Glm4MoE, Glm4MoeAttention, Glm4MoeDecoderLayer, Glm4MoeMLP


class Glm4MoeMTPDecoderLayer(Glm4MoeDecoderLayer):
    """Decoder layer for the GLM-4 MoE MTP head.

    Reuses ``Glm4MoeDecoderLayer`` but builds its own sub-modules. It calls
    ``nn.Module.__init__`` directly so the parent constructor is bypassed.
    Attention and MLP are built with ``is_tp=False`` and with all-reduce
    disabled.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        nn.Module.__init__(self)
        self.layer_idx = layer_idx
        quant_config = getattr(config, 'quantization_config', None)

        # attention (not tensor-parallel)
        self.self_attn = Glm4MoeAttention(config, dtype=dtype, device=device, is_tp=False)

        # layers from ``first_k_dense_replace`` on use MoE, earlier ones a dense MLP
        use_moe = layer_idx >= config.first_k_dense_replace
        if not use_moe:
            self.mlp = Glm4MoeMLP(config, dtype=dtype, device=device, is_tp=False, all_reduce=False)
        else:
            self.mlp = Glm4MoE(config, layer_idx=layer_idx, dtype=dtype, device=device, is_tp=False)
            self.mlp._all_reduce = False

        # pre-attention norm (may be quantized)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_config,
                                       dtype=dtype,
                                       device=device)

        # post-attention norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)


class Glm4MoeMTPModel(DeepseekMTPModel):
    """GLM-4 MoE multi-token-prediction (MTP / NextN) model.

    Built on ``DeepseekMTPModel`` with ``Glm4MoeMTPDecoderLayer`` as the decoder
    layer and ``build_rotary_embedding_from_config`` for rotary embeddings.
    """

    # fused parameter name -> checkpoint sub-module names packed into it
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__(
            config,
            ctx_mgr,
            dtype=dtype,
            device=device,
            decoder_layer_cls=Glm4MoeMTPDecoderLayer,
            build_rotary_embedding_func=build_rotary_embedding_from_config,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter],
                             expert_params_mapping: list[tuple[str, str, int, str]]):
        """Load one routed-expert checkpoint tensor.

        Args:
            name: checkpoint parameter name.
            loaded_weight: tensor read from the checkpoint.
            params_dict: this module's parameters keyed by name.
            expert_params_mapping: ``(param_name, weight_name, expert_id, shard_id)``
                entries. The first entry whose ``weight_name`` occurs in ``name``
                renames it and selects the expert/shard. Names with no match are
                loaded directly.
        """
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load only the NextN (MTP) layer weights from a checkpoint.

        A weight is kept only if its name contains one of
        ``.layers.{num_hidden_layers + i}`` for ``i < num_nextn_predict_layers``
        (default 1). Kept layer names are rewritten by
        ``_rewrite_spec_layer_name`` before being matched to parameters.
        """

        def _is_nextn_weight(name: str, nextn_keys: list[str]) -> bool:
            """Return True if ``name`` belongs to one of the NextN layers."""
            return any(nextn_key in name for nextn_key in nextn_keys)

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        num_hidden_layers = self.config.num_hidden_layers

        num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1)
        nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]

        # expert map: (param_name, weight_name, expert_id, shard_id)
        num_experts = self.config.n_routed_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # keep only weights of the NextN layers
            if not _is_nextn_weight(name, nextn_keys):
                continue
            if '.layers' in name:
                layer_idx = int(name.split('layers.')[1].split('.')[0])
                name = self._rewrite_spec_layer_name(layer_idx, name)
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/gpt_oss.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import get_build_model_context
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class GptOssAttention(nn.Module):
    """GPT-OSS self-attention.

    Projects to q/k/v, applies rotary embedding, runs attention (sliding-window
    or full, with learnable per-head sinks) and projects back with ``o_proj``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 attention_type: str,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // n_heads)
        replicate_kv = getattr(config, 'num_replicate_key_value_heads', 1)
        use_bias = config.attention_bias

        # fused q/k/v projection
        self.qkv_proj = build_qkv_proj(hidden_size,
                                       num_q_heads=n_heads,
                                       num_kv_heads=n_kv_heads,
                                       head_size=head_dim,
                                       bias=use_bias,
                                       quant_config=quant_config,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=replicate_kv)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # only sliding layers get a window; unknown layer types are rejected
        if attention_type == 'full_attention':
            sliding_window = None
        elif attention_type == 'sliding_attention':
            sliding_window = config.sliding_window
        else:
            raise ValueError(f'Unsupported attention type: {attention_type}')

        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            scale=head_dim**-0.5,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=sliding_window,
            learnable_sink=True,
        )

        # output projection
        self.o_proj = build_o_proj(n_heads * head_dim,
                                   hidden_size,
                                   bias=use_bias,
                                   quant_config=quant_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

        # per-head sink logits, sharded over tensor-parallel ranks
        self.sinks = self.build_sinks(config, device)

    @classmethod
    def build_sinks(cls, config: PretrainedConfig, device):
        """Allocate this TP rank's slice of the sink parameter.

        The tensor is created with ``torch.empty`` and no ``dtype`` (so torch's
        default dtype). It is filled later through ``weight_loader_sinks``.
        """
        from lmdeploy.pytorch.distributed import get_tp_world_rank
        tp_size, _ = get_tp_world_rank()
        total_heads = config.num_attention_heads
        assert total_heads % tp_size == 0, (
            f'num_attention_heads={total_heads} should be divisible by TP={tp_size}')
        sinks = nn.Parameter(torch.empty(total_heads // tp_size, device=device))
        sinks.weight_loader = cls.weight_loader_sinks
        return sinks

    @classmethod
    def weight_loader_sinks(cls, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this TP rank's dim-0 chunk of the checkpoint sinks into ``param``."""
        from lmdeploy.pytorch.distributed import get_tp_world_rank
        tp_size, tp_rank = get_tp_world_rank()
        shard = loaded_weight.chunk(tp_size, dim=0)[tp_rank]
        param.data.copy_(shard)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attend over ``hidden_states`` using the KV cache in ``past_key_value``.

        ``past_key_value`` holds ``(k_cache, v_cache)`` and, when it has more than
        two entries, the k/v scale-zero tensors at indices 2 and 3.
        """
        # project, then flatten to (-1, heads, head_dim) before splitting
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        k_scales_zeros = past_key_value[2] if has_scales else None
        v_scales_zeros = past_key_value[3] if has_scales else None

        attn_output = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            s_aux=self.sinks,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(attn_output)


class GateupAct:
    """Clamped gated activation used by the GPT-OSS experts.

    The last dimension of the input holds the ``[gate, up]`` halves.
    ``gate`` is clamped from above at ``limit`` and ``up`` is clamped to
    ``[-limit, limit]``. The result is ``(up + 1) * gate * sigmoid(alpha * gate)``.
    The first call tries to ``torch.compile`` the implementation and falls
    back to eager execution if that fails.
    """

    def __init__(self, limit: float = 7.0, alpha: float = 1.702):
        self.limit = limit
        self.alpha = alpha
        # resolved on the first call: either compiled ``_impl`` or plain ``_impl``
        self._run: Callable = None

    def _impl(self, gateup: torch.Tensor) -> torch.Tensor:
        """Eager implementation of the activation."""
        gate, up = gateup.chunk(2, dim=-1)
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * glu

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def build(limit: float, alpha: float):
        """Return one shared instance per ``(limit, alpha)``.

        All callers with the same constants share the instance, and therefore
        share its resolved ``_run``.
        """
        return GateupAct(limit, alpha)

    def _try_compile(self, gateup: torch.Tensor) -> Callable:
        """Choose the callable used by ``__call__``, store it in ``self._run`` and return it.

        The compiled function is run once on ``gateup`` so that compilation
        errors surface here rather than at a later call. Any failure falls back
        to the eager ``_impl``.
        """
        try:
            run = torch.compile(self._impl, dynamic=True)
            run(gateup)
        except Exception:
            # deliberate best effort: compilation can fail for environment-specific reasons
            run = self._impl
        self._run = run
        return run

    def __call__(self, gateup: torch.Tensor) -> torch.Tensor:
        """Call the act function."""
        run = self._run
        if run is None:
            run = self._try_compile(gateup)
        return run(gateup)


class GptOssExperts(nn.Module):
    """Routed experts of a GPT-OSS MoE layer, backed by one fused-MoE module."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        self.top_k = config.num_experts_per_tok
        # constants of the clamped gated activation (see GateupAct)
        self.alpha = 1.702
        self.limit = 7.0
        self._gateup_act = GateupAct.build(self.limit, self.alpha)

        self.experts = build_fused_moe(
            self.hidden_size,
            self.expert_dim,
            self.num_experts,
            bias=True,
            top_k=self.top_k,
            renormalize=False,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            all_reduce=True,
            layer_idx=layer_idx,
            act_func=self._gateup_act,
        )

    def forward(self, hidden_states: torch.Tensor, router_indices, routing_weights) -> torch.Tensor:
        """Dispatch tokens to their experts and restore the 3-D input layout.

        ``hidden_states`` must be 3-D. Only ``hidden_states[0]`` is passed to the
        fused MoE, and the output is reshaped to ``(batch_size, seq_len, -1)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        tokens = hidden_states[0]
        expert_out = self.experts(tokens, routing_weights, router_indices)
        return expert_out.reshape(batch_size, seq_len, -1)


class GptOssTopKRouter(nn.Module):
    """Linear gate followed by top-k selection and a softmax over the selected logits."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.empty(self.num_experts, dtype=dtype, device=device))

    def forward(self, hidden_states):
        """Return ``(scores, indices)``, each of shape (num_tokens, top_k)."""
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)
        top_logits, top_indices = torch.topk(logits, self.top_k, dim=-1)
        # normalise only among the chosen experts, keeping the logits' dtype
        top_scores = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        return top_scores, top_indices


class GptOssMLP(nn.Module):
    """GPT-OSS MoE block: top-k router plus fused routed experts."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.router = GptOssTopKRouter(config, dtype=dtype, device=device)
        self.experts = GptOssExperts(config, layer_idx, dtype=dtype, device=device)

    def forward(self, hidden_states, all_routed_experts: torch.Tensor = None):
        """Route tokens and run the experts.

        If ``all_routed_experts`` is given, this layer's chosen expert indices
        are written into ``all_routed_experts[:, layer_idx, :]``.
        """
        # both tensors are (num_tokens, top_k)
        scores, indices = self.router(hidden_states)
        if all_routed_experts is not None:
            all_routed_experts[:, self.layer_idx, :] = indices
        return self.experts(hidden_states, router_indices=indices, routing_weights=scores)


class GptOssDecoderLayer(nn.Module):
    """Pre-norm GPT-OSS decoder layer: attention then MoE, with fused residual norms."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        # 'sliding_attention' or 'full_attention' for this layer
        self.attention_type = config.layer_types[layer_idx]

        quant_config = getattr(config, 'quantization_config', None)

        self.self_attn = GptOssAttention(config, self.attention_type, layer_idx=layer_idx, dtype=dtype, device=device)

        self.mlp = GptOssMLP(config, layer_idx, dtype=dtype, device=device)

        norm_kwargs = dict(quant_config=quant_config, dtype=dtype, device=device)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        all_routed_experts: torch.Tensor = None,
    ):
        """Return ``(hidden_states, residual)`` for the next layer.

        On the first layer (``residual is None``) the input itself becomes the
        residual. After that, the norm receives and returns the running residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed, all_routed_experts=all_routed_experts)
        return mlp_out, residual


class GptOssModel(nn.Module):
    """GPT-OSS backbone: token embedding, decoder layers and final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            GptOssDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # shared by all layers
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        all_routed_experts: torch.Tensor = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run all layers, apply the final norm.

        Layer ``i`` uses ``past_key_values[i]``. ``all_routed_experts`` is
        forwarded to every layer for recording routing decisions.
        """
        hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_id, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_id],
                residual=residual,
                attn_metadata=attn_metadata,
                all_routed_experts=all_routed_experts,
            )

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class GptOssForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """ModelForCausalLM."""

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = GptOssModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

        # for router replay
        bm_ctx = get_build_model_context()
        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits."""
        # router replay
        all_routed_experts = None
        if self.enable_return_routed_experts:
            if inputs_embeds is not None:
                num_tokens = inputs_embeds.size(1)
            else:
                num_tokens = input_ids.size(1)
            all_routed_experts = position_ids.new_empty(
                (num_tokens, self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.uint16)

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            all_routed_experts=all_routed_experts,
        )

        if all_routed_experts is None:
            return hidden_states
        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input."""
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts_gate_up(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str,
                                                                                                     nn.Parameter]):
        """Load weight of experts gate up."""
        num_experts = self.config.num_local_experts

        loaded_weight = loaded_weight.cuda()
        if 'gate_up_proj_bias' in name:
            param_name = name.replace('experts.gate_up_proj_bias', 'experts.experts.gate_up.bias')
        elif 'gate_up_proj' in name:
            param_name = name.replace('experts.gate_up_proj', 'experts.experts.gate_up.weight')
            loaded_weight = loaded_weight.transpose(1, 2)
        param = params_dict[param_name]
        for expert_id in range(num_experts):
            w1 = loaded_weight[expert_id, ::2]
            w3 = loaded_weight[expert_id, 1::2]
            load_weight(param, w1, expert_id=expert_id, shard_id='gate')
            load_weight(param, w3, expert_id=expert_id, shard_id='up')

    def _load_weight_experts_down(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Load weight of experts down."""
        num_experts = self.config.num_local_experts

        loaded_weight = loaded_weight.cuda()
        if 'down_proj_bias' in name:
            param_name = name.replace('experts.down_proj_bias', 'experts.experts.down.bias')
        elif 'down_proj' in name:
            param_name = name.replace('experts.down_proj', 'experts.experts.down.weight')
            loaded_weight = loaded_weight.transpose(1, 2)
        param = params_dict[param_name]
        for expert_id in range(num_experts):
            w2 = loaded_weight[expert_id]
            load_weight(param, w2, expert_id=expert_id, shard_id='down')

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Load weight of fused expert weights."""
        if 'gate_up' in name:
            self._load_weight_experts_gate_up(name, loaded_weight, params_dict)

        elif 'down' in name:
            self._load_weight_experts_down(name, loaded_weight, params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/internlm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class InternLMAttention(nn.Module):
    """Rewrite module of LlamaAttention."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        # packed qkv
        self.qkv_proj = build_qkv_proj(hidden_size,
                                       num_q_heads=num_heads,
                                       num_kv_heads=num_key_value_heads,
                                       head_size=head_dim,
                                       bias=config.bias,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=num_replicate_kv_heads)

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
        )

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=config.bias,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward."""
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class InternLMMLP(nn.Module):
    """mlp."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        # gate up
        self.gate_up_proj = build_gateup_linear(
            config.hidden_size,
            [config.intermediate_size, config.intermediate_size],
            bias=config.bias,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )

        # silu and mul
        self.act_fn = SiluAndMul(inplace=True)

        # down
        self.down_proj = build_down_linear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=config.bias,
                                           quant_config=quantization_config,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """forward."""
        gate_up = self.gate_up_proj(x)
        act = self.act_fn(gate_up)
        return self.down_proj(act)


class InternLMDecoderLayer(nn.Module):
    """Decoder layer."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer
        self.self_attn = InternLMAttention(config, dtype=dtype, device=device)

        # build MLP
        self.mlp = InternLMMLP(config, dtype=dtype, device=device)

        # build input layer norm
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device)

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                quant_config=quantization_config,
                                                dtype=dtype,
                                                device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):

        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states, residual)
        return outputs


class InternLMModel(nn.Module):
    """model."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding in LlamaModel
        rope_dim = config.hidden_size // config.num_attention_heads
        rope_max_pos_emb = config.max_position_embeddings
        scaling_factor = 1.0
        rope_scaling = config.rotary
        rope_base = rope_scaling['base']
        rope_type = rope_scaling['type']
        if rope_type == 'dynamic':
            emb_type = RopeType.DynamicNTKScaling
            scaling_factor = rope_scaling.get('scaling_factor', 1.0)
        elif rope_type == 'origin':
            emb_type = RopeType.LinearScaling
        else:
            raise RuntimeError(f'Unsupported rope type: {rope_type}')

        self.rotary_emb = build_rotary_embedding(
            rope_dim,
            rope_max_pos_emb,
            rope_base,
            scaling_factor,
            emb_type=emb_type,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward."""

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class InternLMForCausalLM(nn.Module, CudaGraphMixin):
    """Rewrote model of LlamaForCausalLM."""

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build LLamaModel
        self.model = InternLMModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input."""
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/internlm2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class InternLM2Attention(nn.Module):
    """Rewrite module of InternLM2Attention."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = hidden_size // num_heads
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        # packed qkv
        self.wqkv = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=config.bias,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            num_replicate_kv_heads=num_replicate_kv_heads,
        )

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
        )

        # o_proj
        self.wo = build_o_proj(num_heads * head_dim,
                               hidden_size,
                               bias=config.bias,
                               quant_config=quantization_config,
                               dtype=dtype,
                               device=device,
                               is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of InternLM2Attention.forward."""
        # qkv proj
        qkv_states = self.wqkv(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.wqkv.split_qkv(qkv_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.wo(attn_output)
        return attn_output


class InternLM2MLP(nn.Module):
    """mlp."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        # gate up
        self.gate_up_proj = build_gateup_linear(
            config.hidden_size,
            [config.intermediate_size, config.intermediate_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )

        # silu and mul
        self.act_fn = SiluAndMul(inplace=True)

        # down
        self.w2 = build_down_linear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=False,
                                    quant_config=quantization_config,
                                    dtype=dtype,
                                    device=device,
                                    is_tp=True)

    def forward(self, x):
        """forward."""
        gate_up = self.gate_up_proj(x)
        act = self.act_fn(gate_up)
        return self.w2(act)


class InternLM2DecoderLayer(nn.Module):
    """Decoder layer."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer
        self.attention = InternLM2Attention(config, dtype=dtype, device=device)

        # build MLP
        self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device)

        # build input layer norm
        self.attention_norm = RMSNorm(config.hidden_size,
                                      config.rms_norm_eps,
                                      quant_config=quantization_config,
                                      dtype=dtype,
                                      device=device)

        # build attention layer norm
        self.ffn_norm = RMSNorm(config.hidden_size,
                                config.rms_norm_eps,
                                quant_config=quantization_config,
                                dtype=dtype,
                                device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):

        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(hidden_states, residual)

        # Self Attention
        hidden_states = self.attention(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)

        outputs = (hidden_states, residual)
        return outputs


class InternLM2Model(nn.Module):
    """Internlm2 model."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.tok_embeddings = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )
        # build all decode layers
        self.layers = nn.ModuleList([
            InternLM2DecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding in Model
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of forward."""

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.tok_embeddings


class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """InternLM2 causal LM rewritten for the lmdeploy PyTorch engine.

    Wraps :class:`InternLM2Model` with a vocabulary projection that is
    registered as ``output`` and also exposed as ``lm_head`` (same module).
    """

    # the fused ``gate_up_proj`` parameter is built from checkpoint ``w1``/``w3``
    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # decoder backbone
        self.model = InternLM2Model(config, dtype=dtype, device=device)
        # vocabulary head, reachable under both attribute names
        vocab_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
        self.output = vocab_head
        self.lm_head = vocab_head

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        Extra keyword arguments are accepted and ignored.
        """
        backbone_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return self.model(**backbone_kwargs)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Assemble ``forward`` kwargs from a step context.

        When the context carries vision embeddings, they are written into the
        token embeddings at ``context.input_embedding_indexing``. Token
        embeddings are computed first if none were supplied.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        vision_embeddings = context.input_embeddings
        vision_positions = context.input_embedding_indexing

        has_vision = vision_embeddings is not None and len(vision_embeddings) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # cast vision features to the text embedding dtype/device before scattering
            inputs_embeds[:, vision_positions, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
        """Load LoRA adapter weights, regrouping packed ``wqkv`` lora_B rows.

        The checkpoint's ``wqkv.lora_B`` rows are grouped per kv head as
        ``[q * group_size, k, v]`` blocks of ``head_dim`` rows. They are
        reordered to ``[all q, all k, all v]`` before loading.
        """
        from lmdeploy.pytorch.adapter.adapter import load_lora_weights

        cfg = self.config
        n_heads = cfg.num_attention_heads
        n_kv_heads = cfg.num_key_value_heads
        head_dim = cfg.hidden_size // n_heads
        q_per_kv = n_heads // n_kv_heads

        def _regroup_wqkv(stream):
            for weight_name, tensor in stream:
                if 'wqkv.lora_B' in weight_name:
                    grouped = tensor.unflatten(0, (-1, q_per_kv + 2, head_dim))
                    tensor = torch.cat([
                        grouped[:, :-2].flatten(0, 2),
                        grouped[:, -2].flatten(0, 1),
                        grouped[:, -1].flatten(0, 1),
                    ], dim=0)
                yield weight_name, tensor

        load_lora_weights(self, _regroup_wqkv(weights), adapter_id)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model (adapted from vLLM).

        Rotary caches are skipped. ``w1``/``w3`` are loaded as shards 0/1 of
        ``gate_up_proj``. ``wqkv`` is split into q/k/v shards using the 'hgd'
        layout. Everything else is loaded directly by name.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ('.gate_up_proj', '.w1', 0),
            ('.gate_up_proj', '.w3', 1),
        ]
        ignored_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_markers):
                continue
            for fused_name, ckpt_name, shard_idx in stacked_params_mapping:
                if ckpt_name in name:
                    fused_param = params_dict[name.replace(ckpt_name, fused_name)]
                    load_weight(fused_param, loaded_weight, shard_id=shard_idx)
                    break
            else:
                param = params_dict[name]
                if '.wqkv' in name:
                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')
                    for part_id, part in (('q', q), ('k', k), ('v', v)):
                        load_weight(param, part, shard_id=part_id)
                else:
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/internlm2_reward.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .internlm2 import InternLM2Model
from .utils.cudagraph import CudaGraphMixin


class InternLM2ForRewardModel(nn.Module, CudaGraphMixin):
    """InternLM2 reward model rewritten for the lmdeploy PyTorch engine.

    Shares the :class:`InternLM2Model` backbone. Instead of a vocabulary
    head, it has ``v_head``, a linear layer mapping ``hidden_size`` to one
    output per position.
    """

    # the fused ``gate_up_proj`` parameter is built from checkpoint ``w1``/``w3``
    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = InternLM2Model(config, dtype=dtype, device=device)
        # scalar value head: hidden_size -> 1
        self.v_head = build_rowwise_linear(config.hidden_size, 1, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states (scores come
        from ``get_logits``).

        Extra keyword arguments are ignored.
        """
        backbone_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return self.model(**backbone_kwargs)

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states through ``v_head`` (one output per position)."""
        scores = self.v_head(hidden_states)
        return scores

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Assemble ``forward`` kwargs from a step context.

        Raises:
            ValueError: if the context carries any vision embeddings, which
                this model cannot consume.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        vision_embeddings = context.input_embeddings
        has_vision = vision_embeddings is not None and len(vision_embeddings) > 0
        if has_vision:
            raise ValueError('InternLM2RewardModel does not support vision embedding')

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
        """Load LoRA adapter weights, regrouping packed ``wqkv`` lora_B rows.

        The checkpoint's ``wqkv.lora_B`` rows are grouped per kv head as
        ``[q * group_size, k, v]`` blocks of ``head_dim`` rows. They are
        reordered to ``[all q, all k, all v]`` before loading.
        """
        from lmdeploy.pytorch.adapter.adapter import load_lora_weights

        cfg = self.config
        n_heads = cfg.num_attention_heads
        n_kv_heads = cfg.num_key_value_heads
        head_dim = cfg.hidden_size // n_heads
        q_per_kv = n_heads // n_kv_heads

        def _regroup_wqkv(stream):
            for weight_name, tensor in stream:
                if 'wqkv.lora_B' in weight_name:
                    grouped = tensor.unflatten(0, (-1, q_per_kv + 2, head_dim))
                    tensor = torch.cat([
                        grouped[:, :-2].flatten(0, 2),
                        grouped[:, -2].flatten(0, 1),
                        grouped[:, -1].flatten(0, 1),
                    ], dim=0)
                yield weight_name, tensor

        load_lora_weights(self, _regroup_wqkv(weights), adapter_id)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model (adapted from vLLM).

        Rotary caches are skipped. ``w1``/``w3`` are loaded as shards 0/1 of
        ``gate_up_proj``. ``wqkv`` is split into q/k/v shards using the 'hgd'
        layout. Everything else (including ``v_head``) is loaded by name.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ('.gate_up_proj', '.w1', 0),
            ('.gate_up_proj', '.w3', 1),
        ]
        ignored_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_markers):
                continue
            for fused_name, ckpt_name, shard_idx in stacked_params_mapping:
                if ckpt_name in name:
                    fused_param = params_dict[name.replace(ckpt_name, fused_name)]
                    load_weight(fused_param, loaded_weight, shard_id=shard_idx)
                    break
            else:
                param = params_dict[name]
                if '.wqkv' in name:
                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')
                    for part_id, part in (('q', q), ('k', k), ('v', v)):
                        load_weight(param, part, shard_id=part_id)
                else:
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/internlm2_ve.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.models.internlm2 import InternLM2Attention, InternLM2MLP
from lmdeploy.pytorch.nn import RMSNorm, RopeType, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with an extra "visual expert" MLP.

    After attention, the positions selected by ``vision_embedding_indexing``
    go through ``feed_forward_ve``, and those selected by
    ``text_embedding_indexing`` go through ``feed_forward``. Without vision
    indexing, every position uses ``feed_forward``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        quant_cfg = getattr(config, 'quantization_config', None)

        # attention, text MLP and visual-expert MLP
        self.attention = InternLM2Attention(config, dtype=dtype, device=device)
        self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device)
        self.feed_forward_ve = InternLM2MLP(config, dtype=dtype, device=device)

        # pre-attention and pre-MLP norms share identical settings
        norm_kwargs = dict(quant_config=quant_cfg, dtype=dtype, device=device)
        self.attention_norm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)
        self.ffn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)

    def _run_mlp_on(self, mlp, hidden_states: torch.Tensor, indexing: torch.Tensor):
        """Apply ``mlp`` to the positions picked by ``indexing`` along dim 1.

        The selection is flattened to ``(-1, hidden_size)`` for the MLP, and
        ``unsqueeze(0)`` restores a leading dim of size 1. This presumably
        assumes a single packed batch; confirm against callers.
        """
        picked = hidden_states[:, indexing, :].reshape(-1, self.hidden_size)
        return mlp(picked).unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
        vision_embedding_indexing: Optional[torch.Tensor] = None,
        text_embedding_indexing: Optional[torch.Tensor] = None,
    ):
        """Return ``(hidden_states, residual)`` after attention and the routed
        MLPs."""
        # first layer starts the residual stream; later layers fuse add + norm
        if residual is not None:
            hidden_states, residual = self.attention_norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)

        hidden_states = self.attention(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.ffn_norm(hidden_states, residual)

        if vision_embedding_indexing is None:
            hidden_states = self.feed_forward(hidden_states)
            return (hidden_states, residual)

        # route positions in place: vision -> expert MLP, text -> regular MLP
        hidden_states[:, vision_embedding_indexing, :] = self._run_mlp_on(self.feed_forward_ve, hidden_states,
                                                                          vision_embedding_indexing)
        if text_embedding_indexing is not None:
            hidden_states[:, text_embedding_indexing, :] = self._run_mlp_on(self.feed_forward, hidden_states,
                                                                            text_embedding_indexing)
        return (hidden_states, residual)


class InternLM2VEModel(nn.Module):
    """InternLM2 backbone whose decoder layers carry a visual-expert MLP."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.tok_embeddings = nn.Embedding(config.vocab_size,
                                           config.hidden_size,
                                           self.padding_idx,
                                           dtype=dtype,
                                           device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            InternLM2VEDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding in Model
        rope_scaling = get_rope_parameters(config)
        scaling_factor = 1.0
        emb_type = RopeType.LinearScaling
        if rope_scaling is not None:
            scaling_factor = rope_scaling.get('factor', scaling_factor)
            # HF configs name the key ``rope_type``; legacy configs use ``type``.
            rope_type = rope_scaling.get('rope_type', rope_scaling.get('type'))
            # must be an ``elif`` chain: with two independent ``if``s a
            # 'linear' config fell into the ``else`` and was rejected.
            if rope_type == 'linear':
                emb_type = RopeType.LinearScaling
            elif rope_type == 'dynamic':
                emb_type = RopeType.DynamicNTKScaling
            else:
                raise RuntimeError(f'Unsupported rope type: {rope_type}')
        rope_dim = config.hidden_size // config.num_attention_heads
        rope_max_pos_emb = config.max_position_embeddings
        rope_base = get_rope_theta(config)
        self.rotary_emb = build_rotary_embedding(
            rope_dim,
            rope_max_pos_emb,
            rope_base,
            scaling_factor,
            emb_type=emb_type,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_embedding_indexing: Optional[torch.Tensor] = None,
        text_embedding_indexing: Optional[torch.Tensor] = None,
    ):
        """Run the decoder stack and return the final normed hidden states.

        ``vision_embedding_indexing`` / ``text_embedding_indexing`` are
        forwarded to every layer to route positions between the expert MLPs.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding (computed once, shared by all layers)
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
                vision_embedding_indexing=vision_embedding_indexing,
                text_embedding_indexing=text_embedding_indexing,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.tok_embeddings


class InternLM2VEForCausalLM(nn.Module, CudaGraphMixin):
    """InternLM2 causal LM whose backbone has a visual-expert MLP per layer.

    The vocabulary projection is registered as ``output``.
    """

    # the fused ``gate_up_proj`` parameter is built from checkpoint ``w1``/``w3``
    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = InternLM2VEModel(config, dtype=dtype, device=device)
        self.output = build_rowwise_linear(config.hidden_size,
                                           config.vocab_size,
                                           bias=False,
                                           dtype=dtype,
                                           device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        vision_embedding_indexing: Optional[torch.Tensor] = None,
        text_embedding_indexing: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        Extra keyword arguments are ignored.
        """
        backbone_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            vision_embedding_indexing=vision_embedding_indexing,
            text_embedding_indexing=text_embedding_indexing,
        )
        return self.model(**backbone_kwargs)

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states onto the vocabulary via ``output``."""
        vocab_scores = self.output(hidden_states)
        return vocab_scores

    def support_cuda_graph(
        self,
        input_ids: torch.Tensor,
        attn_metadata: Any = None,
        **kwargs,
    ):
        """Allow CUDA graphs only for decoding steps with ``input_ids.size(1) <= 512``."""
        if not attn_metadata.is_decoding:
            return False
        return input_ids.size(1) <= 512

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Assemble ``forward`` kwargs from a step context.

        When the context carries vision embeddings, they are written into the
        token embeddings at ``context.input_embedding_indexing``. Token
        embeddings are computed first if none were supplied.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        vision_embeddings = context.input_embeddings
        vision_positions = context.input_embedding_indexing

        has_vision = vision_embeddings is not None and len(vision_embeddings) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # cast vision features to the text embedding dtype/device before scattering
            inputs_embeds[:, vision_positions, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model (adapted from vLLM).

        Rotary caches are skipped. ``w1``/``w3`` are loaded as shards 0/1 of
        ``gate_up_proj`` (this covers both the text and the visual-expert
        MLPs). ``wqkv`` is split into q/k/v shards using the 'hgd' layout.
        Everything else is loaded by name.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ('.gate_up_proj', '.w1', 0),
            ('.gate_up_proj', '.w3', 1),
        ]
        ignored_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_markers):
                continue
            for fused_name, ckpt_name, shard_idx in stacked_params_mapping:
                if ckpt_name in name:
                    fused_param = params_dict[name.replace(ckpt_name, fused_name)]
                    load_weight(fused_param, loaded_weight, shard_id=shard_idx)
                    break
            else:
                param = params_dict[name]
                if '.wqkv' in name:
                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')
                    for part_id, part in (('q', q), ('k', k), ('v', v)):
                        load_weight(param, part, shard_id=part_id)
                else:
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/internlm3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class InternLM3Attention(nn.Module):
    """Self-attention for InternLM3: packed QKV projection, rotary embedding,
    cached attention, and output projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        # Read the bias flags defensively, matching the MLP's
        # ``getattr(config, 'bias', False)``: configs that omit them get
        # bias-free projections instead of an AttributeError.
        qkv_bias = getattr(config, 'qkv_bias', False)
        o_bias = getattr(config, 'bias', False)

        # packed qkv
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=qkv_bias,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            num_replicate_kv_heads=num_replicate_kv_heads,
        )

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
        )

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=o_bias,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Compute attention for ``hidden_states`` and return the projected
        output.

        Args:
            hidden_states: input activations; the output keeps the same
                leading dims.
            rotary_pos_emb: ``(cos, sin)`` tables applied in place to q and k.
            past_key_value: cache entries. Entries 0/1 are the k/v caches.
                When more than two entries are present, entries 2/3 are
                passed as the k/v ``scales_zeros``.
            attn_metadata: metadata forwarded to the attention kernel.
        """
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention; a 2-entry cache means no quantization scales
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class InternLM3MLP(nn.Module):
    """InternLM3 feed-forward block: fused gate/up projection, SiluAndMul, down
    projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        use_bias = getattr(config, 'bias', False)
        inter_size = config.intermediate_size

        # fused gate + up projection (two intermediate_size outputs)
        self.gate_up_proj = build_gateup_linear(
            config.hidden_size,
            [inter_size, inter_size],
            bias=use_bias,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )
        self.act_fn = SiluAndMul(inplace=True)
        self.down_proj = build_down_linear(inter_size,
                                           config.hidden_size,
                                           bias=use_bias,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """Apply gate_up_proj -> act_fn -> down_proj to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class InternLM3DecoderLayer(nn.Module):
    """InternLM3 decoder layer: pre-norm attention followed by a pre-norm
    MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = InternLM3Attention(config, dtype=dtype, device=device)
        self.mlp = InternLM3MLP(config, dtype=dtype, device=device)

        # both norms share identical settings
        norm_kwargs = dict(quant_config=quant_cfg, dtype=dtype, device=device)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Return ``(hidden_states, residual)`` after attention and MLP.

        When ``residual`` is None (first layer), the input itself starts the
        residual stream. Otherwise the norm fuses the residual add.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return (self.mlp(hidden_states), residual)


class InternLM3Model(nn.Module):
    """InternLM3 backbone: token embedding, decoder stack, final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            InternLM3DecoderLayer(config, layer_no, dtype=dtype, device=device)
            for layer_no in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run the decoder stack and return the final normed hidden states.

        ``input_ids`` are embedded only when ``inputs_embeds`` is None.
        ``past_key_values`` is indexed by layer position.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # cos/sin computed once; only entry 0 of the leading dim is used
        rope_cos, rope_sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (rope_cos[0], rope_sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # the final norm returns (normed, residual); keep the normed tensor
        return self.norm(hidden_states, residual)[0]

    def get_input_embeddings(self):
        """Return the token embedding module (``embed_tokens``)."""
        embedding_module = self.embed_tokens
        return embedding_module


class InternLM3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Rewrote model of InternLM3ForCausalLM.

    Wraps ``InternLM3Model`` together with an lm_head built by
    ``build_lm_head``.
    """

    # fused module name -> checkpoint sub-module names merged into it
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build InternLM3Model
        self.model = InternLM3Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states.

        The lm_head is NOT applied here; the returned tensor is the output of
        ``self.model`` (the final-normed decoder hidden states).
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a step context.

        If the context carries precomputed input embeddings, they are written
        into ``inputs_embeds`` (created from ``input_ids`` when absent) at
        ``context.input_embedding_indexing`` along dim 1.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters with the matching shard id.
        Rotary caches and (with tied embeddings) ``lm_head.weight`` are
        skipped. Any other name must exist in ``named_parameters()`` or a
        ``KeyError`` is raised.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # tied lm_head shares the embedding weight, so nothing to load
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/interns1_pro.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from lmdeploy.vl.constants import Modality

from .interns1_pro_ts import InternS1ProTimeSeriesModel
from .patch import add_prefix, get_build_model_context
from .qwen3_moe import Qwen3MoeModel
from .qwen3_vl import Qwen3VLVisionModel
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1


class InternS1ProForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """InternS1-Pro conditional generation model.

    Combines a ``Qwen3VLVisionModel`` vision tower, a ``Qwen3MoeModel``
    language model, an lm_head and, when ``config`` has ``ts_config``, an
    ``InternS1ProTimeSeriesModel`` encoder.
    """

    # fused module name -> checkpoint sub-module names merged into it
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()

        self.config = config
        self.ctx_mgr = ctx_mgr

        # build preprocessor
        self.input_processor = InternS1ProInputProcessor(self.config, dtype)

        # build vision model
        self.visual = Qwen3VLVisionModel(
            config.vision_config,
            dtype=dtype,
            device=device,
            prefix=add_prefix('visual', prefix=prefix),
        )

        # build text model
        self.language_model = Qwen3MoeModel(config.text_config,
                                            dtype=dtype,
                                            device=device,
                                            prefix=add_prefix('language_model', prefix=prefix))

        # build lm_head
        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
                                          config.text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)

        # build time series model
        if hasattr(config, 'ts_config'):
            self.time_series = InternS1ProTimeSeriesModel(config.ts_config, dtype=dtype, device=device)

        # for router replay
        bm_ctx = get_build_model_context()
        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        pos_embeds: torch.Tensor = None,
        grid_thw: torch.Tensor = None,
        # for time series
        ts_values: torch.Tensor = None,
        ts_lens: torch.Tensor = None,
        ts_sr: torch.Tensor = None,
        ts_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states.

        Image/video or time-series embeddings are merged into the token
        embeddings only when ``inputs_embeds`` is not supplied. The lm_head is
        NOT applied here.

        Returns:
            The language-model hidden states, or a dict with ``hidden_states``
            and ``all_routed_experts`` when routed-expert recording is enabled.
        """

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))

                # get image embeds
                # different from qwen3vl, interns1_1 does not use deepstack visual embeds
                image_embeds, _ = self.visual(pixel_values,
                                              cu_seqlens=vis_cu_seqlens,
                                              rotary_pos_emb=vis_pos_emb,
                                              pos_embeds=pos_embeds)

                # split image embeds per sample; torch.split also checks that the
                # per-image token counts add up to the number of embeddings
                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
                image_embeds = torch.split(image_embeds, split_sizes)
                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)

                # mask and scatter to create final input embeddings
                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)

            elif ts_values is not None:
                ts_embeds = self.time_series(ts_values, ts_lens, ts_sr)  # [B, T, C]
                # masked_scatter_ requires matching dtypes; align like the image branch
                ts_embeds = ts_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter_(ts_mask[..., None], ts_embeds)

        # router replay: one uint16 expert id per (token, layer, top-k slot)
        all_routed_experts = None
        if self.enable_return_routed_experts:
            all_routed_experts = input_ids.new_empty((input_ids.size(1), self.config.text_config.num_hidden_layers,
                                                      self.config.text_config.num_experts_per_tok),
                                                     dtype=torch.uint16)

        hidden_states = self.language_model(input_ids=input_ids,
                                            position_ids=position_ids,
                                            past_key_values=past_key_values,
                                            attn_metadata=attn_metadata,
                                            inputs_embeds=inputs_embeds,
                                            all_routed_experts=all_routed_experts)

        if all_routed_experts is None:
            return hidden_states
        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.language_model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a step context.

        Multimodal items of all sequences in the batch are flattened. The
        modality of the FIRST item decides which branch is prepared:
        time series or image/video.
        """

        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_mask = None
        grid_thw = None
        pos_embeds = None
        # for time series
        ts_values = None
        ts_lens = None
        ts_sr = None
        ts_mask = None
        if context.input_multimodals is not None:
            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]
            # flatten batch
            mm_inputs = [item for sublist in mm_inputs for item in sublist]

            if len(mm_inputs) > 0:
                modality = mm_inputs[0].modality
                image_token_id = mm_inputs[0].meta.get('image_token_id')
                video_token_id = mm_inputs[0].meta.get('video_token_id')
                ts_token_id = mm_inputs[0].meta.get('ts_token_id')

                if modality == Modality.TIME_SERIES:
                    ts_values = torch.cat([inp.data for inp in mm_inputs])
                    ts_mask = input_ids == ts_token_id

                    # NOTE(review): ts_lens / ts_sr come from the first item only
                    # although ts_values concatenates all items; confirm that a
                    # batch never holds more than one time-series item.
                    ts_lens = mm_inputs[0].meta['ts_lens']
                    ts_sr = mm_inputs[0].meta['ts_sr']
                else:
                    # NOTE(review): a single token id is chosen from the first
                    # item's modality, so mixed image+video batches would mask
                    # only one of them; confirm the scheduler prevents this.
                    pixel_values = torch.cat([inp.data for inp in mm_inputs])
                    mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id
                    image_mask = (input_ids == mm_token_id)

                    grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()
                    vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
                    pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw)
                    vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                             grid_thw[:, 0]).to(pixel_values.device)
                    vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                    vis_pos_emb = vis_pos_emb.repeat(1, 2)
                    vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_mask=image_mask,
            grid_thw=grid_thw,
            pos_embeds=pos_embeds,
            # for time series
            ts_values=ts_values,
            ts_lens=ts_lens,
            ts_sr=ts_sr,
            ts_mask=ts_mask,
        )

    @classmethod
    def rename_weight(cls, name: str) -> str:
        """Map checkpoint names onto this module's attribute tree.

        ``model.language_model.`` -> ``language_model.``,
        ``model.visual.`` -> ``visual.``, any other ``model.`` prefix is
        dropped, and remaining names are returned unchanged.
        """
        if name.startswith('model.language_model.'):
            return 'language_model.' + name[len('model.language_model.'):]
        elif name.startswith('model.visual.'):
            return 'visual.' + name[len('model.visual.'):]
        elif name.startswith('model.'):
            return name[len('model.'):]
        return name

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one per-expert weight into the fused expert parameter.

        Falls back to a plain parameter load when no expert mapping matches.
        """

        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    # modify from vllm qwen3vlmoe fused expert loading
    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                                   fused_expert_params_mapping: List):
        """Load a checkpoint tensor holding all experts' weights at once.

        The last two dims are transposed, then the leading dim is indexed per
        expert. For ``gate_up`` the transposed tensor is split in two along
        dim -2 into gate and up halves. Names matching no mapping are ignored.
        """
        num_experts = self.config.text_config.num_experts

        for (param_name, weight_name) in fused_expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]

            # presumably stored as (experts, in, out) in the checkpoint — TODO confirm; no bias
            weight = loaded_weight.transpose(-1, -2)
            if 'gate_up' in name:
                gate_chunks = weight.chunk(2, dim=-2)
                w1 = gate_chunks[0]
                w3 = gate_chunks[1]
                for expert_id in range(num_experts):
                    load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate')
                    load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up')
            elif 'down' in name:
                for expert_id in range(num_experts):
                    load_weight(param, weight[expert_id], expert_id=expert_id, shard_id='down')
            # a name matches at most one mapping entry
            break

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        Handles per-expert and fused-expert MoE weights, stacked q/k/v and
        gate/up projections, fused ``.qkv.`` weights split via the parameter's
        ``weight_spliter``, and buffers. Names found in neither parameters nor
        buffers are silently skipped.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # expert mapping
        num_experts = self.config.text_config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            # (param_name, weight_name, expert_id, shard_id)
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        # fused expert mapping
        fused_expert_params_mapping = [
            # (param_name, weight_name)
            ('.experts.gate_up.weight', '.experts.gate_up_proj'),
            ('.experts.down.weight', '.experts.down_proj'),
        ]

        params_dict = dict(self.named_parameters())
        buffers_dict = dict(self.named_buffers())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # tied lm_head shares the embedding weight, so nothing to load
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name)
                if is_fused_expert:
                    self._load_weight_fused_experts(name,
                                                    loaded_weight,
                                                    params_dict,
                                                    fused_expert_params_mapping=fused_expert_params_mapping)
                else:
                    self._load_weight_experts(name,
                                              loaded_weight,
                                              params_dict,
                                              expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    if '.qkv.' in name:
                        # already-fused qkv checkpoint tensor: split and load each shard
                        param = params_dict[name]
                        q, k, v = param.weight_spliter(loaded_weight)
                        load_weight(param, q, shard_id='q')
                        load_weight(param, k, shard_id='k')
                        load_weight(param, v, shard_id='v')
                    else:
                        if name in params_dict:
                            param = params_dict[name]
                            load_weight(param, loaded_weight)
                        elif name in buffers_dict:
                            param = buffers_dict[name]
                            load_weight(param, loaded_weight)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class InternS1ProInputProcessor(BaseModelInputProcessor):
    """InternS1Pro input processor.

    Converts raw multimodal request entries (image, video, time series) into
    ``MultiModalData`` items spanning ``[offset, offset + num_pad)``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype) -> None:
        self.config = config
        self.dtype = dtype

    @staticmethod
    def _unwrap_count(num_pad: Any) -> Any:
        """Return ``num_pad.item()`` for a tensor, otherwise ``num_pad`` unchanged."""
        if isinstance(num_pad, torch.Tensor):
            return num_pad.item()
        return num_pad

    def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:
        """Make image MultiModalData."""
        pixel_values = input_mm['pixel_values'].to(self.dtype)
        image_grid_thw = input_mm['image_grid_thw']
        start = input_mm['offset']
        image_token_id = input_mm['image_token_id']
        num_pad = self._unwrap_count(input_mm['image_tokens'])

        mm_data = MultiModalData(modality=Modality.IMAGE,
                                 data=pixel_values,
                                 start=start,
                                 end=start + num_pad,
                                 meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
        return mm_data

    def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:
        """Make video MultiModalData."""
        pixel_values_videos = input_mm['pixel_values_videos'].to(self.dtype)
        video_grid_thw = input_mm['video_grid_thw']
        start = input_mm['offset']
        video_token_id = input_mm['video_token_id']
        num_pad = self._unwrap_count(input_mm['video_tokens'])

        mm_data = MultiModalData(modality=Modality.VIDEO,
                                 data=pixel_values_videos,
                                 start=start,
                                 end=start + num_pad,
                                 meta=dict(
                                     grid_thw=video_grid_thw,
                                     video_token_id=video_token_id,
                                 ))
        return mm_data

    def _make_time_series_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:
        """Make time series MultiModalData."""
        ts_values = input_mm['ts_values'].to(self.dtype)
        offset = input_mm['offset']
        ts_token_id = input_mm['ts_token_id']
        ts_lens = input_mm['ts_lens']
        ts_sr = input_mm['ts_sr']
        num_pad = self._unwrap_count(input_mm['ts_tokens'])

        mm_data = MultiModalData(modality=Modality.TIME_SERIES,
                                 data=ts_values,
                                 start=offset,
                                 end=offset + num_pad,
                                 meta=dict(ts_lens=ts_lens, ts_sr=ts_sr, ts_token_id=ts_token_id))
        return mm_data

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Raises:
            ValueError: if an entry's ``modality`` is not image, video or time
                series. Previously such an entry either raised
                ``UnboundLocalError`` or silently re-appended the previous
                entry's data.
        """
        # NOTE(review): this early return yields a tuple rather than a
        # PreprocessInputResult; kept as-is for caller compatibility.
        if input_multimodals is None or len(input_multimodals) == 0:
            return input_ids, input_multimodals

        input_mm_data = []
        for input_mm in input_multimodals:
            modality = input_mm.get('modality')
            # compare with == (not a dict lookup) so str-valued modalities still match
            if modality == Modality.IMAGE:
                mm_data = self._make_image_mm_data(input_mm)
            elif modality == Modality.VIDEO:
                mm_data = self._make_video_mm_data(input_mm)
            elif modality == Modality.TIME_SERIES:
                mm_data = self._make_time_series_mm_data(input_mm)
            else:
                raise ValueError(f'Unsupported modality for InternS1Pro: {modality!r}')
            input_mm_data.append(mm_data)

        result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data))

        return result


================================================
FILE: lmdeploy/pytorch/models/interns1_pro_ts.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.nn import LayerNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear

from .whisper import WhisperEncoderLayer


class InternS1ProTimeSeriesEncoder(nn.Module):

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config

        self.embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, self.embed_dim, kernel_size=3, padding=1, dtype=dtype, device=device)
        self.conv2 = nn.Conv1d(self.embed_dim,
                               self.embed_dim,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               dtype=dtype,
                               device=device)
        self.embed_positions = nn.Embedding(self.max_source_positions, self.embed_dim, dtype=dtype, device=device)

        self.layers = nn.ModuleList(
            [WhisperEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.encoder_layers)])
        self.layer_norm = LayerNorm(config.d_model, eps=1e-5, dtype=dtype, device=device)

        self.adapt_in = build_colwise_linear(
            in_features=config.ts_adapt_in_dim,
            out_features=80,
            bias=True,
            dtype=dtype,
            device=device,
        )
        self.adapt_out = build_rowwise_linear(
            in_features=self.embed_dim,
            out_features=config.ts_adapt_out_dim,
            bias=True,
            dtype=dtype,
            device=device,
        )

    def _make_causal_mask(self,
                          input_ids_shape: torch.Size,
                          dtype: torch.dtype,
                          device: torch.device,
                          past_key_values_length: int = 0):
        """Make causal mask used for bi-directional self-attention."""
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    def _prepare_decoder_attention_mask(self, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        if input_shape[-1] > 1:
            combined_attention_mask = self._make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        return combined_attention_mask

    def forward(self, input_features):
        # (N, T, C) -> (T, N, C) -> (N, C, T)
        input_features = input_features.permute(1, 0, 2)
        input_features = self.adapt_in(input_features)
        input_features = input_features.permute(1, 2, 0)

        # (N, C, T) -> (N, C, T//2)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (N, C, T) -> (N, T, C)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        if inputs_embeds.shape[1] > embed_pos.shape[0]:
            target_len = inputs_embeds.shape[1]
            padding = [0, 0, 0, target_len - embed_pos.shape[0]]

            embed_pos = nn.functional.pad(embed_pos, pad=padding, mode='constant', value=0)
            hidden_states = inputs_embeds[:, :embed_pos.shape[0], :] + embed_pos
        else:
            hidden_states = inputs_embeds + embed_pos[:inputs_embeds.shape[1], :]

        input_shape = inputs_embeds.size()[:-1]
        past_key_values_length = 0
        attention_mask = self._prepare_decoder_attention_mask(input_shape, inputs_embeds, past_key_values_length)

        for idx, encoder_layer in enumerate(self.layers):
            layer_outputs = encoder_layer(hidden_states, attention_mask)
            hidden_states = layer_outputs

        # (N, T, C) -> (T, N, C)
        hidden_states = hidden_states.permute(1, 0, 2)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.adapt_out(hidden_states)

        # (T, N, C) -> (N, T, C)
        hidden_states = hidden_states.permute(1, 0, 2)

        return hidden_states


class InternS1ProTimeSeriesConcatSubsampling(nn.Module):
    """Halve the time axis (dim 1) by joining adjacent frame pairs along dim 2.

    Frames ``2k`` and ``2k + 1`` are concatenated, and a trailing unpaired
    frame is dropped. ``out_channels`` reports ``in_channels * concat_size``,
    while ``forward`` always pairs exactly two frames.
    """

    def __init__(self, in_channels: int, concat_size: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = concat_size * in_channels

    def forward(self, ts_signals: torch.Tensor, ts_lens: torch.Tensor):
        has_unpaired_frame = ts_signals.shape[1] % 2 != 0
        if has_unpaired_frame:
            ts_signals = ts_signals[:, :-1, :]
        first_of_pair = ts_signals[:, 0::2, :]
        second_of_pair = ts_signals[:, 1::2, :]
        paired = torch.cat((first_of_pair, second_of_pair), dim=2)
        return paired, ts_lens // 2


class InternS1ProTimeSeriesFixPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    Even feature columns hold ``sin(pos * div)`` and odd columns hold
    ``cos(pos * div)``. The table is stored as the persistent buffer ``pe``
    with shape ``(max_len, 1, d_model)``. Odd ``d_model`` is supported; the
    last column is then a sine column.
    """

    def __init__(self, d_model, max_len=20000, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # for odd d_model there is one fewer cosine column than div_term entries
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # hf forces float32 during init, but becomes bf16 during forward
        pe = pe.unsqueeze(0).transpose(0, 1).to(dtype=dtype, device=device)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe, persistent=True)

    def forward(self, x):
        """Add the encoding to ``x`` of shape (seq_len, batch_size, d_model).

        The addition already yields a fresh tensor, so no extra copy is made.
        """
        return x + self.pe[:x.size(0), :]


class InternS1ProTimeSeriesMultiChannelAdaptiveSubsampling(nn.Module):
    """Patchify time series with a sampling-rate-dependent stride and embed each patch.

    For each sample, a stride ``floor(160 / (1 + exp(-sr / 100)) ** 6)`` is derived
    from its sampling rate, and patches two strides long are cut from the
    zero-padded sequence. Each channel of each patch is encoded by
    Conv1d + ReLU + positional encoding + Transformer encoder, then averaged over
    patch time and over channels, giving one ``hidden_dim`` vector per patch.
    The padded batch of patch vectors is finally halved in length by
    concatenating neighbouring frames.

    Args:
        hidden_dim: Width of the per-patch embedding.
        nhead: Number of attention heads of the Transformer encoder layer.
        num_encoder_layers: Number of Transformer encoder layers.
        dtype: dtype of the parameters.
        device: device of the parameters.
    """

    def __init__(self,
                 hidden_dim=128,
                 nhead=8,
                 num_encoder_layers=1,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=1,
                              out_channels=hidden_dim,
                              kernel_size=5,
                              stride=1,
                              padding=2,
                              dtype=dtype,
                              device=device)
        encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dtype=dtype, device=device)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        self.pos_encoder = InternS1ProTimeSeriesFixPositionalEncoding(d_model=hidden_dim, dtype=dtype, device=device)
        # the subsampler consumes hidden_dim-wide frames; previously hard-coded to 128
        self.subsampling = InternS1ProTimeSeriesConcatSubsampling(hidden_dim, 2)

    def forward(self, inputs, input_lens, sr):
        """Embed a batch of variable-length multichannel series.

        Args:
            inputs: Indexable batch; ``inputs[i]`` is ``[seq_len, num_channel]``
                (a 1-D sequence is treated as a single channel).
            input_lens: Valid length of each sample.
            sr: Sampling rate of each sample.

        Returns:
            Tuple ``(outputs, output_lens)``: the padded batch of patch embeddings
            after frame concatenation, and the number of valid frames per sample.
        """
        sr = torch.as_tensor(sr, dtype=torch.float32)
        # the stride grows with the sampling rate; patches span two strides (50% overlap)
        strides = torch.floor(160 / ((1 + torch.exp(-sr / 100))**6))
        patch_sizes = strides * 2
        patched_outputs = []
        output_lens = []

        for i in range(len(inputs)):
            seq = inputs[i]  # [seq_len, num_channel]
            ps = patch_sizes[i].item()
            st = strides[i].item()
            le = input_lens[i]

            # number of patches needed to cover `le` samples, and the right-padding it requires
            output_len = torch.ceil((le - ps) / st) + 1
            pad_len = ((output_len - 1) * st + ps - le).long().item()
            if seq.ndim == 1:
                seq = seq.unsqueeze(-1)
            seq = nn.functional.pad(seq, (0, 0, 0, pad_len), 'constant', 0)
            assert output_len > 0, (seq.shape, ps, st, le, output_len)
            output_lens.append(output_len)
            # row p gathers samples [p * st, p * st + ps) -> [num_patch, ps, num_channel]
            indices = (torch.arange(0, output_len * st, st).unsqueeze(1) + torch.arange(ps)).long()
            patched = seq[indices]

            output = self.forward_encoder(patched)  # [num_patch, D]
            patched_outputs.append(output)

        outputs = nn.utils.rnn.pad_sequence(patched_outputs, batch_first=True)
        output_lens = torch.tensor(output_lens).squeeze().to(outputs.device).long()
        if output_lens.ndim == 0:
            output_lens = output_lens.unsqueeze(0)

        outputs, output_lens = self.subsampling(outputs.clone(), output_lens.clone())
        return outputs, output_lens

    def forward_encoder(self, x):
        """Encode patches ``x`` of shape ``[num_patch, patch_len, C]`` into ``[num_patch, D]``.

        Every channel of every patch is encoded independently, averaged over the
        patch's time steps, and the per-channel vectors are then averaged.
        """
        num_patch, patch_len, C = x.shape
        # conv1
        # treat each channel as an independent sample and feed it into conv1
        x = x.reshape(num_patch * C, 1, patch_len)
        x = nn.functional.relu((self.conv(x)))  # [B*C, D1, L]
        x = x.permute(2, 0, 1)  # [L, B*C, D1]

        x = self.pos_encoder(x)  # [L, B*C, D1]
        x = self.transformer_encoder(x)
        x = x.mean(0)

        x = x.reshape(num_patch, C, -1)

        return x.mean(1)


class InternS1ProTimeSeriesProjector(nn.Module):
    """Map time-series encoder features into the language-model hidden size.

    Applies, in order: LayerNorm, a column-parallel linear layer, the activation
    named by ``config.activation_function``, and a row-parallel linear layer.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        in_dim = config.ts_hidden_dim
        out_dim = config.out_hidden_size
        self.layer_norm = LayerNorm(in_dim, eps=1e-5, dtype=dtype, device=device)
        self.linear_1 = build_colwise_linear(in_features=in_dim,
                                             out_features=out_dim,
                                             bias=True,
                                             dtype=dtype,
                                             device=device)
        self.act = ACT2FN[config.activation_function]
        self.linear_2 = build_rowwise_linear(in_features=out_dim,
                                             out_features=out_dim,
                                             bias=True,
                                             dtype=dtype,
                                             device=device)

    def forward(self, ts_features):
        """Project ``ts_features`` whose last dim is ``config.ts_hidden_dim``."""
        normed = self.layer_norm(ts_features)
        return self.linear_2(self.act(self.linear_1(normed)))


class InternS1ProTimeSeriesModel(nn.Module):
    """Time-series tower: adaptive patch embedding, encoder and projector.

    Args:
        config: Time-series config. ``ts_adapt_in_dim`` is read here; the
            sub-modules read further fields.
        dtype: dtype of the parameters.
        device: device of the parameters.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.encoder_embed = InternS1ProTimeSeriesMultiChannelAdaptiveSubsampling(dtype=dtype, device=device)
        self.encoder = InternS1ProTimeSeriesEncoder(config, dtype=dtype, device=device)
        self.projector = InternS1ProTimeSeriesProjector(config, dtype=dtype, device=device)

    @staticmethod
    def _first_signal_shape(time_series_signals):
        """Best-effort ``time_series_signals[0].shape`` for error messages (never raises)."""
        try:
            return time_series_signals[0].shape
        except (TypeError, IndexError, AttributeError):
            return f'unavailable ({type(time_series_signals).__name__})'

    def forward(
        self,
        time_series_signals: Optional[torch.FloatTensor] = None,
        ts_lens: Optional[torch.Tensor] = None,
        sr: Optional[torch.Tensor] = None,
        time_series_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Encode time series into language-model-space features.

        Pre-computed ``time_series_embeds`` are used directly when they are 3-D
        with last dim ``config.ts_adapt_in_dim``; otherwise
        ``time_series_signals`` are embedded with ``encoder_embed``.

        Args:
            time_series_signals: ``[B, T, C]`` tensor or non-empty list of ``[T_i, C]`` tensors.
            ts_lens: Valid length of each sample; may be ``None`` when embeddings are given.
            sr: Sampling rate of each sample, consumed by ``encoder_embed``.
            time_series_embeds: Optional pre-computed embeddings.

        Returns:
            Projected encoder hidden states.

        Raises:
            ValueError: if neither input is given, or the signals have an unsupported shape.
        """
        if time_series_signals is None and time_series_embeds is None:
            raise ValueError('You have to specify time_series_signals or time_series_embeds')

        # embedded values can be passed in directly, but the dimensions must match
        embeds_usable = (time_series_embeds is not None and len(time_series_embeds.shape) == 3
                         and time_series_embeds.shape[-1] == self.config.ts_adapt_in_dim)
        if not embeds_usable:
            signals = time_series_signals
            is_signal_list = isinstance(signals, list) and len(signals) > 0 and len(signals[0].shape) == 2
            is_signal_tensor = isinstance(signals, torch.Tensor) and len(signals.shape) == 3
            if not (is_signal_list or is_signal_tensor):
                # signals may be None here (only malformed embeds given), so describe them safely
                raise ValueError(f'wrong time_series_signals size: {self._first_signal_shape(signals)}')
            time_series_embeds, ts_lens = self.encoder_embed(signals, ts_lens, sr)

        # [B, 64000, 1] -> [B, 200, 256] -> [B, 100, 1024]
        last_hidden_state = self.encoder(input_features=time_series_embeds)

        # ts_lens after encoder (time axis halved, rounding up); only used for validation
        if ts_lens is not None:
            ts_lens = (ts_lens + 1) // 2
            assert torch.all(ts_lens > 0), f'The length of time_series_embeds is so small. ts_lens: {ts_lens}'

        last_hidden_state = self.projector(last_hidden_state)
        return last_hidden_state


================================================
FILE: lmdeploy/pytorch/models/internvl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from packaging import version
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, vlm_model


class Gating(nn.Module):
    """Residual MLP stack followed by a two-way softmax gate.

    Four residual MLP blocks refine the input; a LayerNorm + Linear head then
    produces two logits that are turned into probabilities with softmax.
    """

    def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None):
        super().__init__()
        inner_dim = hidden_size * expansion_factor

        def make_block():
            # the Identity entries are no-ops that only occupy Sequential indices;
            # presumably kept so parameter names (e.g. ``block1.3.weight``) match checkpoints
            return nn.Sequential(
                nn.Linear(hidden_size, inner_dim, bias=True, dtype=dtype, device=device),
                nn.GELU(),
                nn.Identity(),
                nn.Linear(inner_dim, hidden_size, bias=True, dtype=dtype, device=device),
                nn.Identity(),
                nn.LayerNorm(hidden_size, dtype=dtype, device=device),
            )

        self.block1, self.block2, self.block3, self.block4 = (make_block() for _ in range(4))

        self.gate = nn.Sequential(
            nn.LayerNorm(hidden_size, dtype=dtype, device=device),
            nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device),  # one logit per expert
        )

    def forward(self, x):
        """Return gate probabilities of shape ``[..., 2]`` for features ``x``."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = x + block(x)
        return torch.softmax(self.gate(x), dim=-1)


class CrossAttentionPooling(nn.Module):
    """Pool a variable-length token set into one vector with a learned query.

    A single learnable query token attends over each sample's zero-padded
    tokens through a chain of ``nn.MultiheadAttention`` + ``nn.LayerNorm``
    steps; the result is one ``D``-dim vector per sample.

    NOTE(review): ``attn3``/``norm3`` and ``attn4``/``norm4`` are constructed
    but ``forward`` reuses ``attn2``/``norm2`` for the third and fourth steps.
    This presumably mirrors the reference implementation the checkpoint was
    trained with; confirm before "fixing" it, since switching to attn3/attn4
    would change the pooled outputs.
    """

    def __init__(self, dim, num_heads=16, dtype=None, device=None):
        super().__init__()
        self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device))  # [1, D]

        # one (attention, norm) pair per pooling step; see the class NOTE about attn3/attn4
        self.attn1 = nn.MultiheadAttention(embed_dim=dim,
                                           num_heads=num_heads,
                                           batch_first=True,
                                           dtype=dtype,
                                           device=device)
        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)

        self.attn2 = nn.MultiheadAttention(embed_dim=dim,
                                           num_heads=num_heads,
                                           batch_first=True,
                                           dtype=dtype,
                                           device=device)
        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)

        self.attn3 = nn.MultiheadAttention(embed_dim=dim,
                                           num_heads=num_heads,
                                           batch_first=True,
                                           dtype=dtype,
                                           device=device)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)

        self.attn4 = nn.MultiheadAttention(embed_dim=dim,
                                           num_heads=num_heads,
                                           batch_first=True,
                                           dtype=dtype,
                                           device=device)
        self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device)

    def forward(self, batched_tokens: list[torch.Tensor]) -> torch.Tensor:
        """Pool each sample's tokens into a single vector.

        Args:
            batched_tokens: List of length B; element i is a ``[T_i, D]`` tensor.

        Returns:
            ``[B, D]`` tensor, one pooled vector per sample.
        """
        B = len(batched_tokens)
        D = batched_tokens[0].shape[-1]
        device = batched_tokens[0].device

        # 1. Pad to a common length; padded slots stay zero.
        max_len = max(t.shape[0] for t in batched_tokens)
        dtype = self.query_token.dtype
        padded = torch.zeros(B, max_len, D, dtype=dtype, device=device)
        # True marks padding positions, which nn.MultiheadAttention ignores as keys
        padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device)

        for i, t in enumerate(batched_tokens):
            L = t.shape[0]
            padded[i, :L] = t  # copied (and cast) into the query dtype
            padding_mask[i, :L] = False

        # 2. Query token: [B, 1, D]
        query = self.query_token.unsqueeze(0).expand(B, -1, -1)  # learnable token for each sample

        # 3. First attention
        out1, _ = self.attn1(query, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
        out1 = self.norm1(out1)

        # 4. Second attention
        out2, _ = self.attn2(out1, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
        out2 = self.norm2(out2)

        # 5./6. Third and fourth steps reuse attn2/norm2 (attn3/attn4 are not used; see class NOTE)
        out3, _ = self.attn2(out2, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
        out3 = self.norm2(out3)

        out4, _ = self.attn2(out3, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
        out4 = self.norm2(out4)

        return out4.squeeze(1)


class InternVisionEmbeddings(nn.Module):
    """Patch + class-token embedding with a resizable position table.

    Images are cut into ``patch_size`` patches by a strided Conv2d, a learned
    class token is prepended, and the learned position table (defined for an
    ``image_size`` input) is bicubically resized to the actual patch grid.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device))

        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size,
                                         dtype=dtype,
                                         device=device)

        grid = self.image_size // self.patch_size
        self.num_patches = grid**2
        self.num_positions = self.num_patches + 1  # patches + class token

        self.position_embedding = nn.Parameter(
            torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))

    def _get_pos_embed(self, pos_embed, H, W):
        """Resize the patch position table to an ``H x W`` grid (interpolated in float32)."""
        orig_dtype = pos_embed.dtype
        grid = self.image_size // self.patch_size
        table = pos_embed.float().reshape(1, grid, grid, -1).permute(0, 3, 1, 2)  # [1, C, grid, grid]
        table = F.interpolate(table, size=(H, W), mode='bicubic', align_corners=False)
        return table.reshape(1, -1, H * W).permute(0, 2, 1).to(orig_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` into ``[N, 1 + h * w, embed_dim]`` tokens."""
        compute_dtype = self.patch_embedding.weight.dtype
        patches = self.patch_embedding(pixel_values)  # [N, embed_dim, h, w]
        num_images, _, grid_h, grid_w = patches.shape
        patches = patches.flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(num_images, 1, -1).to(compute_dtype)
        tokens = torch.cat([cls_tokens, patches], dim=1)

        cls_pos = self.position_embedding[:, :1, :]
        patch_pos = self._get_pos_embed(self.position_embedding[:, 1:, :], grid_h, grid_w)
        positions = torch.cat([cls_pos, patch_pos], dim=1)
        return tokens + positions.to(compute_dtype)


# Normalization layer class selected by ``config.norm_type``.
NORM2FN = dict(
    rms_norm=RMSNorm,
    layer_norm=LayerNorm,
)


@torch.compile(dynamic=True)
def pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """First half of a split RMS norm: float32 sums of squares of ``q`` and ``k``.

    Returns the two ``sum(x * x, -1, keepdim=True)`` results stacked along a new
    leading dim. These are sums, not means; ``post_rms_norm`` divides them by
    the feature size.
    """
    q32 = q.to(torch.float32)
    k32 = k.to(torch.float32)
    sq_sum_q = (q32 * q32).sum(-1, keepdim=True)
    sq_sum_k = (k32 * k32).sum(-1, keepdim=True)
    return torch.stack((sq_sum_q, sq_sum_k), dim=0)


@torch.compile(dynamic=True)
def post_rms_norm(q: torch.Tensor, k: torch.Tensor, weight_q: torch.Tensor, weight_k: torch.Tensor,
                  variance: torch.Tensor, eps: float, embed_dim: int, dtype: torch.dtype):
    """Second half of a split RMS norm for ``q`` and ``k``.

    Args:
        q, k: Features to normalize (normalized in float32).
        weight_q, weight_k: Scales applied after casting to ``dtype``.
        variance: Stacked squared sums of q and k (as returned by
            ``pre_rms_norm``, possibly summed across ranks).
        eps: Epsilon added to the mean square.
        embed_dim: Feature count used to turn the sums into means.
        dtype: Output dtype.

    Returns:
        Tuple of normalized ``(q, k)``.
    """
    q32 = q.to(torch.float32)
    k32 = k.to(torch.float32)
    mean_sq_q, mean_sq_k = (variance / embed_dim + eps).unbind(0)
    q_out = (q32 * torch.rsqrt(mean_sq_q)).to(dtype) * weight_q
    k_out = (k32 * torch.rsqrt(mean_sq_k)).to(dtype) * weight_k
    return q_out, k_out


class InternAttention(nn.Module):
    """Intern vl attention.

    Multi-head self-attention of the vision encoder: fused QKV projection,
    optional RMS normalization of q/k over the full ``embed_dim``,
    ``F.scaled_dot_product_attention`` and an output projection. Projections and
    norms are built with tensor-parallel options (``is_tp`` / ``tp=True``).
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        quantization_config = getattr(config, 'quantization_config', None)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.qkv = build_qkv_proj(
            self.embed_dim,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_heads,
            head_size=self.head_dim,
            bias=config.qkv_bias,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
        )

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # norms span the whole embed_dim; with tp=True each rank presumably holds a
            # head-aligned slice (align=head_dim) — hence the all-reduce path in qkv_norm
            self.q_norm = RMSNorm(
                self.embed_dim,
                eps=config.layer_norm_eps,
                dtype=dtype,
                device=device,
                tp=True,
                align=self.head_dim,
            )
            self.k_norm = RMSNorm(
                self.embed_dim,
                eps=config.layer_norm_eps,
                dtype=dtype,
                device=device,
                tp=True,
                align=self.head_dim,
            )

        self.scale = self.head_dim**-0.5

        # o_proj
        self.proj = build_o_proj(self.embed_dim,
                                 self.embed_dim,
                                 bias=True,
                                 quant_config=quantization_config,
                                 dtype=dtype,
                                 device=device,
                                 is_tp=True,
                                 tp_align_size=self.head_dim)

    def pre_rms_norm(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Pre rms norm: stacked float32 sums of squares of the local q and k."""
        return pre_rms_norm(q, k)

    def post_rms_norm(self, q: torch.Tensor, k: torch.Tensor, variance: torch.Tensor,
                      dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Post rms norm: normalize q/k with the (reduced) squared sums and apply norm weights."""
        eps = self.config.layer_norm_eps
        # divide by the full embed_dim, not the local slice width
        return post_rms_norm(q, k, self.q_norm.weight, self.k_norm.weight, variance, eps, self.embed_dim, dtype)

    def qkv_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """RMS-normalize q and k across all heads, even when heads are split across TP ranks.

        Returns q and k with their original shapes.
        """
        import lmdeploy.pytorch.distributed as dist
        q_shape = q.shape
        k_shape = k.shape
        # merge the trailing (heads, head_dim) axes so the norm covers all local heads
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)

        tp, _ = get_tp_world_rank()
        if tp == 1:
            q = self.q_norm(q).view(q_shape)
            k = self.k_norm(k).view(k_shape)
            return q, k

        # variance: local squared sums, combined across ranks (all_reduce, presumably a sum)
        variance = self.pre_rms_norm(q, k)
        dist.all_reduce(variance)
        q, k = self.post_rms_norm(q, k, variance, q.dtype)
        q = q.view(q_shape)
        k = k.view(k_shape)

        return q, k

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """forward.

        NOTE(review): the transposes below assume ``split_qkv`` yields
        ``[batch, seq, heads, head_dim]`` tensors — confirm against the qkv builder.
        """

        # qkv proj
        qkv_states = self.qkv(hidden_states)
        q, k, v = self.qkv.split_qkv(qkv_states)

        if self.qk_normalization:
            q, k = self.qkv_norm(q, k)

        # heads become dim 1 for scaled_dot_product_attention
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)

        # o proj
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.flatten(-2, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class InternMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder.

    ``fc1`` is column-parallel and ``fc2`` row-parallel (both built with
    ``is_tp=True, dp_disable_tp=True``), with ``config.hidden_act`` in between.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN
        self.config = config
        quant_config = getattr(config, 'quantization_config', None)
        self.act = ACT2FN[config.hidden_act]

        # both projections share the same build options
        linear_kwargs = dict(
            bias=True,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
            is_tp=True,
            dp_disable_tp=True,
        )
        self.fc1 = build_colwise_linear(config.hidden_size, config.intermediate_size, **linear_kwargs)
        self.fc2 = build_rowwise_linear(config.intermediate_size, config.hidden_size, **linear_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(act(fc1(hidden_states)))``."""
        return self.fc2(self.act(self.fc1(hidden_states)))


class InternVisionEncoderLayer(nn.Module):
    """Pre-norm vision transformer layer with per-channel layer scale.

    Computes ``x = x + attn(norm1(x)) * ls1`` then ``x = x + mlp(norm2(x)) * ls2``.
    The norm class comes from ``config.norm_type`` (default ``'rms_norm'``).
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = getattr(config, 'norm_type', 'rms_norm')

        self.attn = InternAttention(config, dtype=dtype, device=device)
        self.mlp = InternMLP(config, dtype=dtype, device=device)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)

        # layer-scale vectors multiplying each residual branch
        self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
        self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _attn(self, hidden_states):
        """Attention residual branch scaled by ``ls1``."""
        # cast the normed input back to the residual dtype, exactly like the MLP branch
        normed = self.norm1(hidden_states).to(hidden_states.dtype)
        hidden_states = hidden_states + self.attn(normed) * self.ls1
        return hidden_states

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _mlp(self, hidden_states):
        """MLP residual branch scaled by ``ls2``."""
        normed = self.norm2(hidden_states).to(hidden_states.dtype)
        hidden_states = hidden_states + self.mlp(normed) * self.ls2
        return hidden_states

    def forward(
        self,
        hidden_states,
    ):
        """Apply the attention branch, then the MLP branch."""
        hidden_states = self._attn(hidden_states)
        hidden_states = self._mlp(hidden_states)
        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` vision encoder layers applied in order."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            InternVisionEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers))

    def forward(
        self,
        inputs_embeds,
    ):
        """Run ``inputs_embeds`` through every layer and return the final hidden states."""
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


@vlm_model
class InternVisionModel(nn.Module):
    """InternViT: patch embeddings followed by the transformer encoder."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config

        self.embeddings = InternVisionEmbeddings(config, dtype=dtype, device=device)
        self.encoder = InternVisionEncoder(config, dtype=dtype, device=device)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Embed 4-D ``pixel_values`` and return the encoder's last hidden state."""
        assert pixel_values.dim() == 4
        embeddings = self.embeddings(pixel_values)
        return self.encoder(inputs_embeds=embeddings)


class InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        """Build the vision tower, the language model and the vision projector(s).

        Args:
            config: InternVL chat config (``vision_config``, ``llm_config``,
                ``select_layer``, ``downsample_ratio`` and optional Flash fields).
            ctx_mgr: Step context manager; its current context is read when
                compressing visual tokens.
            dtype: dtype of the parameters.
            device: device of the parameters.
        """
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.select_layer = config.select_layer

        llm_config = config.llm_config
        self.llm_arch_name = llm_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        vision_config = config.vision_config
        if self.is_mono:
            from .internvl_patch import InternVisionPatchModel
            self.vision_model = InternVisionPatchModel(
                vision_config,
                dtype=dtype,
                device=device,
            )
        else:
            self.vision_model = InternVisionModel(vision_config, dtype=dtype, device=device)

        self.language_model = build_model_from_hf_config(llm_config, dtype=dtype, device=device)
        self.lm_head = self.language_model.lm_head
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size
        self.downsample_ratio = config.downsample_ratio
        # projector for pixel-shuffled ViT tokens (channels grow by (1 / ratio) ** 2)
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, dtype=dtype, device=device),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size,
                      bias=True,
                      dtype=dtype,
                      device=device), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size, bias=True, dtype=dtype, device=device))

        # for Mono-InternVL
        if self.is_mono:
            assert dtype != torch.float16, ('Currently Mono-InternVL does not support FP16 due to '
                                            'numerical instability. Please use BF16 instead.')

        self.input_processor = InternVLInputProcessor(self.config, dtype)

        self.compile_vit = False
        # also reset by compile_model(); initialized here so the attribute always exists
        self.has_compiled_vit = False

        self.flash_mode = getattr(config, 'flash_mode', None)
        if self.flash_mode is not None:
            self.flash_relative_threshold = config.flash_relative_threshold
            self.flash_absolute_threshold = config.flash_absolute_threshold

            # projector for doubly pixel-shuffled tokens (channels grow by (1 / ratio) ** 4)
            self.mlp2 = nn.Sequential(
                nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**4, dtype=dtype, device=device),
                nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**4,
                          llm_hidden_size * 2,
                          bias=True,
                          dtype=dtype,
                          device=device), nn.GELU(), nn.Identity(),
                nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),
                nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))

            self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)
            self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)

    def compile_model(self):
        """Compile ``extract_feature`` with ``torch.compile`` (no-op before torch 2.5).

        With torch >= 2.6 and tensor parallelism, inductor's compute/communication
        reordering is enabled, and a plain ``InternVisionModel`` encoder gets its
        ``forward`` wrapped by ``split_batch`` over ``inputs_embeds``.
        """
        current = version.parse(torch.__version__)
        if current < version.parse('2.5.0'):
            return

        tp, _ = get_tp_world_rank()
        overlap_tp_comm = current >= version.parse('2.6.0') and tp > 1
        if overlap_tp_comm:
            torch._inductor.config.reorder_for_compute_comm_overlap = True
            if isinstance(self.vision_model, InternVisionModel):
                encoder = self.vision_model.encoder
                encoder.forward = split_batch(encoder.forward, 'inputs_embeds', index=0)

        self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune-no-cudagraphs')
        self.compile_vit = True
        self.has_compiled_vit = False

    def _mark_dynamic_once(self, pixel_values, dims):
        """Call ``torch._dynamo.mark_dynamic`` on ``dims`` once, to avoid recompiles.

        Does nothing unless the ViT is compiled, it has not been marked yet, and
        ``pixel_values`` is given.
        """
        should_mark = self.compile_vit and not self.has_compiled_vit and pixel_values is not None
        if should_mark:
            torch._dynamo.mark_dynamic(pixel_values, dims)
            self.has_compiled_vit = True

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an ``[N, W, H, C]`` grid.

        Returns ``[N, W * s, H * s, C / s**2]`` with ``s = scale_factor``.
        """
        n, w, h, c = x.size()
        h_scaled = int(h * scale_factor)
        w_scaled = int(w * scale_factor)
        # fold H into channels: [N, W, H * s, C / s]
        out = x.view(n, w, h_scaled, int(c / scale_factor))
        # swap spatial axes so W can be folded next: [N, H * s, W, C / s]
        out = out.permute(0, 2, 1, 3).contiguous()
        # fold W into channels: [N, H * s, W * s, C / s^2]
        out = out.view(n, h_scaled, w_scaled, int(c / (scale_factor * scale_factor)))
        # restore axis order: [N, W * s, H * s, C / s^2]
        return out.permute(0, 2, 1, 3).contiguous()

    def extract_feature(self, pixel_values):
        """Encode image tiles with the ViT and project them into the LLM hidden space.

        The first token is dropped (for Mono-InternVL only when the token count is
        not a perfect square); the remaining square patch grid is pixel-shuffled by
        ``downsample_ratio`` and projected by ``mlp1``.
        """
        assert self.select_layer == -1
        feats = self.vision_model(pixel_values)
        num_tokens = feats.shape[1]
        strip_first_token = (not self.is_mono) or int(num_tokens**0.5)**2 != num_tokens
        if strip_first_token:
            feats = feats[:, 1:, :]

        side = int(feats.shape[1]**0.5)
        feats = feats.reshape(feats.shape[0], side, side, -1)
        feats = self.pixel_shuffle(feats, scale_factor=self.downsample_ratio)
        feats = feats.reshape(feats.shape[0], -1, feats.shape[-1])
        return self.mlp1(feats)

    def compress_visual_tokens_in_sentence(
        self,
        input_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        img_context_token_id: int,
        gate_result,
    ) -> tuple:
        """Drop the tail of gated image blocks from a packed token sequence.

        Each contiguous run of ``img_context_token_id`` is split into 256-token
        blocks. For every block whose entry in ``gate_result`` is truthy, only its
        first 64 tokens are kept.

        Args:
            input_embeds: ``[B, N, C]`` input embeddings.
            input_ids: token ids matching ``input_embeds`` (``B * N`` elements).
            img_context_token_id: id of the image placeholder token.
            gate_result: indexable per-block compress flags, consumed in order of appearance.

        Returns:
            ``(new_input_embeds, new_input_ids, new_image_mask, new_seq_lengths)``:
            embeddings reshaped to ``[B, -1, C]``, ids reshaped to ``[B, -1]``, a flat
            mask of image tokens among the kept ids, and the kept token count of
            each sequence in the current step context.

        Raises:
            ValueError: if an image run length is not a multiple of 256.
        """
        block_size = 256  # tokens per image block
        kept_per_block = 64  # tokens kept when a block is compressed

        # flatten the batch: multiple sequences may be packed together
        batch, seq_len, hidden = input_embeds.shape
        flat_embeds = input_embeds.reshape(batch * seq_len, hidden)
        flat_ids = input_ids.reshape(batch * seq_len)

        lengths, starts, _ = self.get_image_num_per_sample(flat_ids, img_context_token_id)

        keep_mask = torch.ones(batch * seq_len, dtype=torch.bool, device=flat_embeds.device)

        # validate every run before touching the mask
        block_counts = []
        for length in lengths.tolist():
            if length % block_size != 0:
                raise ValueError(f'l % 256 != 0, l = {length}')
            block_counts.append(length // block_size)

        flag_idx = 0
        for start, num_blocks in zip(starts.tolist(), block_counts):
            for i in range(num_blocks):
                block_start = start + i * block_size
                if gate_result[flag_idx]:
                    keep_mask[block_start + kept_per_block:block_start + block_size] = False
                flag_idx += 1

        # update
        new_input_embeds = flat_embeds[keep_mask.to(flat_embeds.device), :]
        new_input_ids = flat_ids[keep_mask.to(flat_ids.device)]
        new_image_mask = (new_input_ids == img_context_token_id)

        # reshape back
        new_input_ids = new_input_ids.reshape(batch, -1)
        new_input_embeds = new_input_embeds.reshape(batch, -1, hidden)

        # since multiple sequences may be concatenated, recompute each sequence's length
        # as the number of tokens kept inside its slice of the mask
        seq_lengths = self.ctx_mgr.current_context().q_seqlens
        mask_chunks = torch.split(keep_mask, seq_lengths.tolist())
        new_seq_lengths = [chunk.sum().item() for chunk in mask_chunks]

        return new_input_embeds, new_input_ids, new_image_mask, new_seq_lengths

    def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id: int):
        """Locate contiguous runs of ``img_context_token_id`` in ``input_ids``.

        Args:
            input_ids: 1-D token ids (a leading size-1 dim is squeezed away).
            img_context_token_id: id to search for.

        Returns:
            ``(lengths, starts, ends)`` of every run, with exclusive ``ends``.
        """
        flat_ids = input_ids.squeeze(0)
        is_image = (flat_ids == img_context_token_id).int()
        zero = torch.tensor([0], device=is_image.device)
        # zero-pad both ends so runs touching the boundaries still yield +1 / -1 edges
        edges = torch.diff(torch.cat([zero, is_image, zero]))

        starts = torch.nonzero(edges == 1, as_tuple=True)[0]
        ends = torch.nonzero(edges == -1, as_tuple=True)[0]
        return ends - starts, starts, ends

    def split_and_merge(self, features: torch.Tensor, split_sizes: List[int]):
        """Group per-tile features by sample and flatten each group into a token list.

        Args:
            features: Per-tile features ``[T, ..., D]`` (``extract_feature_flash``
                passes a ``[T, h, w, D]`` patch grid), ``T`` being the total tile count.
            split_sizes: Number of tiles of each sample as a list of ints; must sum to ``T``.

        Returns:
            List with one ``[tokens_i, D]`` tensor per sample, where ``tokens_i`` is
            the sample's tile count times the per-tile token count.
        """
        # split features -> one tile group per sample
        per_sample = torch.split(features, split_sizes, dim=0)

        # merge every leading dimension of a group into a single token axis
        return [group.reshape(-1, group.shape[-1]) for group in per_sample]

    def extract_feature_flash(self, pixel_values, lengths):
        """Run the vision tower and build both compression levels plus gate scores.

        Args:
            pixel_values: image tiles fed to ``self.vision_model``.
            lengths: per-sample tile counts (tensor; converted to Python ints before
                being handed to ``split_and_merge``).

        Returns:
            ``(vit_embeds_64, vit_embeds_256, gate)``:
            - ``vit_embeds_64``: features pixel-shuffled with ``downsample_ratio**2`` and
              projected by ``mlp2``;
            - ``vit_embeds_256``: features pixel-shuffled with ``downsample_ratio`` and
              projected by ``mlp1``;
            - ``gate``: ``gating`` output computed from the per-sample grouped features.
        """
        features = self.vision_model(pixel_values)
        # drop the first token (presumably the class token) and fold the remaining
        # patch tokens into a square (side x side) grid
        features = features[:, 1:, :]
        side = int(features.shape[1]**0.5)
        features = features.reshape(features.shape[0], side, side, -1)

        # gate scores are computed from the features grouped per sample
        tile_counts = [int(n) for n in lengths.tolist()]
        grouped = self.split_and_merge(features, tile_counts)
        gate = self.gating(self.pooling_before_gating(grouped))

        with torch.no_grad():
            coarse = self.pixel_shuffle(features, scale_factor=self.downsample_ratio**2)
            coarse = self.mlp2(coarse.reshape(coarse.shape[0], -1, coarse.shape[-1]))

            fine = self.pixel_shuffle(features, scale_factor=self.downsample_ratio)
            fine = self.mlp1(fine.reshape(fine.shape[0], -1, fine.shape[-1]))

        return coarse, fine, gate

    def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, img_context_token_id: int):
        """Encode images and compress the visual tokens of tiles selected by the gate.

        Every 256 image-context tokens are treated as one tile and each tile is gated
        on its own. A tile is compressed when its first gate score is both above the
        ``flash_relative_threshold`` quantile of all tile scores and at least
        ``flash_absolute_threshold``. Compressed tiles use ``vit_embeds_64``; the rest
        keep ``vit_embeds_256``.

        Args:
            pixel_values: image tiles for this step.
            input_ids: token ids containing the image placeholder tokens.
            img_context_token_id: id of the image placeholder token.

        Returns:
            ``(vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths)``.
            ``vit_embeds`` concatenates the selected per-tile features. The remaining
            items are produced by ``compress_visual_tokens_in_sentence``.
        """
        lang_embeds = self.language_model.get_input_embeddings()(input_ids)

        self._mark_dynamic_once(pixel_values, [0])

        image_token_counts, _, _ = self.get_image_num_per_sample(input_ids, img_context_token_id)
        # 256 context tokens per tile; every tile is its own gating group of size 1
        num_tiles = int((image_token_counts // 256).sum().item())
        lengths = torch.ones(num_tiles, dtype=torch.int64)
        vit_embeds_64, vit_embeds_256, gate_result = self.extract_feature_flash(pixel_values, lengths)

        scores = gate_result[:, 0]
        relative_threshold_value = torch.quantile(scores.to(torch.float32), self.flash_relative_threshold)
        # True -> compress this tile (use the 64-token features)
        gate_result = (scores > relative_threshold_value) & (scores >= self.flash_absolute_threshold)

        selected_embeds = [
            vit_embeds_64[i] if gate_result[i] else vit_embeds_256[i] for i in range(gate_result.size(0))
        ]
        vit_embeds = torch.cat(selected_embeds, dim=0)

        # drop the placeholder tokens of compressed tiles from the text sequence
        new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence(
            input_embeds=lang_embeds,
            input_ids=input_ids,
            img_context_token_id=img_context_token_id,
            gate_result=gate_result,
        )

        return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths

    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int],
                              context: StepContext) -> Tuple[torch.Tensor, Any]:
        """Update the forward inputs, position_ids and attention metadata.

        ``context.model_metas`` holds a running per-sequence ``new_seqlen``: the number
        of post-compression tokens of each sequence seen so far. This step's
        ``new_seqlens`` are added to it. A new context is then built from the updated
        lengths, and ``context.q_seqlens`` is replaced with the compressed query lengths.

        Args:
            input_ids: compressed input ids of this step.
            new_seqlens: number of tokens kept for each sequence in this step.
            context: step context of the model agent; updated in place.

        Returns:
            ``(position_ids, attn_metadata)`` of the newly built context.
        """
        from lmdeploy.pytorch.model_inputs import ModelInputs

        crt_ctx = self.ctx_mgr.current_context()
        assert crt_ctx is not None, 'Current context cannot be None.'

        # Update model metas. prev_lens is sized from new_seqlens because
        # context.model_metas may be None here (the else-branch below handles that
        # case); len(None) would raise TypeError.
        prev_lens = [0] * len(new_seqlens)
        # an empty list counts as "no model metas" instead of raising IndexError
        has_model_metas = bool(context.model_metas) and context.model_metas[0] is not None
        context.is_model_meta_updated = has_model_metas
        if has_model_metas:
            prev_lens = [meta.get('new_seqlen', 0) for meta in context.model_metas]

            for idx, meta in enumerate(context.model_metas):
                meta.update({'new_seqlen': prev_lens[idx] + new_seqlens[idx]})
        else:
            context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]

        # create new model inputs and context, to get updated position_ids and attn_metadata
        device = input_ids.device
        total_msgs = len(new_seqlens)
        kv_seqlens = torch.tensor([meta['new_seqlen'] for meta in context.model_metas], dtype=torch.long)
        # NOTE(review): max_q_seqlen is taken from the kv lengths, which is an upper bound
        # of the real query length max(new_seqlens); kept unchanged, confirm intent.
        new_model_inputs = ModelInputs(input_ids=input_ids,
                                       seq_length=torch.tensor(new_seqlens, device=device, dtype=torch.long),
                                       history_lengths=torch.tensor(prev_lens, device=device, dtype=torch.long),
                                       block_offsets=crt_ctx.block_offsets,
                                       is_decoding=False,
                                       num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long),
                                       max_q_seqlen=kv_seqlens.max().item(),
                                       max_kv_seqlen=kv_seqlens.max().item(),
                                       sum_kv_seqlen=kv_seqlens.sum().item(),
                                       model_metas=context.model_metas)
        new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config, crt_ctx.cache_config)

        # keep the model agent's context in sync with the compressed query lengths
        context.q_seqlens = new_ctx.q_seqlens

        return new_ctx.position_ids, new_ctx.attn_metadata

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        vision_embedding_indexing: torch.Tensor = None,
        text_embedding_indexing: torch.Tensor = None,
        image_token_id: int = None,
        context: StepContext = None,
        **kwargs,
    ):
        """Merge vision features into the token embeddings when needed, then run the LM.

        When ``pixel_values`` are given and no precomputed ``inputs_embeds`` exist,
        vision features are scattered into the positions flagged by ``image_mask``.
        In flash mode the visual tokens are compressed first. That also replaces
        ``input_ids``, ``image_mask``, ``position_ids`` and ``attn_metadata`` with
        values matching the shortened sequence.
        """
        needs_vision = inputs_embeds is None and pixel_values is not None
        if needs_vision:
            if not self.flash_mode:
                self._mark_dynamic_once(pixel_values, [0])
                vit_embeds = self.extract_feature(pixel_values)
                lang_embeds = self.language_model.get_input_embeddings()(input_ids)
            else:
                vit_embeds, lang_embeds, input_ids, image_mask, new_seqlens = self.extract_and_compress(
                    pixel_values, input_ids, image_token_id)
                # the sequence got shorter: rebuild positions and attention metadata
                position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens, context)

            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
            inputs_embeds = lang_embeds

        lm_kwargs = dict(input_ids=input_ids,
                         inputs_embeds=inputs_embeds,
                         past_key_values=past_key_values,
                         position_ids=position_ids,
                         attn_metadata=attn_metadata)
        if self.is_mono:
            # mono variants also receive the vision / text index tensors
            lm_kwargs.update(vision_embedding_indexing=vision_embedding_indexing,
                             text_embedding_indexing=text_embedding_indexing)
        return self.language_model.forward(**lm_kwargs)

    def get_input_embeddings(self):
        """Delegate to ``self.language_model.get_input_embeddings()``."""
        language_model = self.language_model
        return language_model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare input.

        Builds the keyword arguments for ``forward`` from the step context:
        - collects pixel values and the image mask from ``context.input_multimodals``;
        - applies precomputed ``context.input_embeddings``;
        - when model metas are present, maintains a running per-sequence ``new_seqlen``
          in ``context.model_metas`` and rewrites ``position_ids`` / ``attn_metadata``
          to reflect compressed (flash-mode) sequence lengths.

        Returns:
            dict of keyword arguments for ``forward``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = None

        # vision inputs
        pixel_values = None
        image_mask = None
        image_token_id = None
        if context.input_multimodals is not None:
            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            pixel_values = [data for im_data in pixel_values for data in im_data]
            if len(pixel_values) > 0:
                image_token_id = pixel_values[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                pixel_values = torch.cat([data.data for data in pixel_values])
            else:
                pixel_values = None
                image_mask = None

        if self.is_mono and pixel_values is not None:
            vision_embedding_indexing = torch.arange(input_ids.shape[1], device=input_ids.device)
            vision_embedding_indexing = vision_embedding_indexing[image_mask[0]]

        # get inputs from context
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            vision_embedding_indexing = context.input_embedding_indexing
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # an empty list counts as "no model metas" instead of raising IndexError
        has_model_metas = bool(context.model_metas) and context.model_metas[0] is not None
        context.is_model_meta_updated = has_model_metas
        if context.is_decoding:
            if has_model_metas:
                # NOTE, zhouxinyu, we need to consider the increasing batch in the decoding stage
                # currently implementation will keep the batch size same as the prefill stage

                # model meta from the previous step, therefore +1 for the current decoding step
                new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]

                # update model metas for the next step
                context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]

                # update position ids, attn_metadata
                new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)
                position_ids = new_kv_seqlens - 1
                attn_metadata.kv_seqlens = new_kv_seqlens
                attn_metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
        else:
            # in the case of long context, messages may be split into multiple segments and perform prefill sequentially
            # 1. this will only be done when flash_mode is on
            # 2. if it is a text segment, we update model metas before forward
            # 3. if it is an image segment, we update model metas later, after vision forward / compression
            is_text_segment = (inputs_embeds is None) and (pixel_values is None)

            if self.flash_mode and is_text_segment:
                crt_ctx = self.ctx_mgr.current_context()
                seq_lengths = crt_ctx.q_seqlens

                if has_model_metas:
                    prev_lens = [meta.get('new_seqlen', 0) for meta in context.model_metas]

                    for idx, meta in enumerate(context.model_metas):
                        meta.update({'new_seqlen': prev_lens[idx] + seq_lengths[idx].item()})

                    # update position ids, attn_metadata
                    # NOTE(review): position_ids below has shape (num_seqs, input_ids.shape[1]); that
                    # matches a packed single-row layout only when there is one sequence. Also, unlike
                    # the decoding branch, cu_seqlens_k is not refreshed here. Confirm both for
                    # multi-sequence prefill.
                    prev_lens = torch.tensor(prev_lens, device=input_ids.device, dtype=torch.long)
                    ranges = torch.arange(0, input_ids.shape[1], device=input_ids.device)
                    position_ids = prev_lens[:, None] + ranges[None, :]
                    attn_metadata.kv_seqlens = prev_lens + seq_lengths
                else:
                    # init model metas
                    context.model_metas = [{'new_seqlen': seqlen} for seqlen in seq_lengths.tolist()]

        if self.is_mono and vision_embedding_indexing is not None:
            all_indices = torch.arange(input_ids.shape[1]).to(input_ids)
            text_embedding_indexing = all_indices[~torch.isin(all_indices, vision_embedding_indexing)]
            if vision_embedding_indexing.numel() == 0:
                vision_embedding_indexing = None
            if text_embedding_indexing.numel() == 0:
                text_embedding_indexing = None
            return dict(input_ids=input_ids,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                        attn_metadata=attn_metadata,
                        pixel_values=pixel_values,
                        image_mask=image_mask,
                        inputs_embeds=inputs_embeds,
                        vision_embedding_indexing=vision_embedding_indexing,
                        text_embedding_indexing=text_embedding_indexing,
                        image_token_id=image_token_id,
                        context=context)
        else:
            return dict(input_ids=input_ids,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                        attn_metadata=attn_metadata,
                        pixel_values=pixel_values,
                        image_mask=image_mask,
                        inputs_embeds=inputs_embeds,
                        image_token_id=image_token_id,
                        context=context)

    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
        """Load LoRA adapter weights.

        Uses the language model's own ``load_lora_weights`` when it has one. Otherwise
        it falls back to ``lmdeploy.pytorch.adapter.adapter.load_lora_weights``.
        """
        if not hasattr(self.language_model, 'load_lora_weights'):
            from lmdeploy.pytorch.adapter.adapter import load_lora_weights

            return load_lora_weights(weights, adapter_id)

        return self.language_model.load_lora_weights(weights, adapter_id)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        - Entries prefixed with ``language_model.`` are collected with the prefix
          stripped and passed to ``self.language_model.load_weights`` in one call at
          the end.
        - Every other entry is loaded directly into the same-named parameter of this
          module; names containing ``qkv`` are first split into q/k/v shards via the
          parameter's ``weight_spliter``.
        """
        lang_prefix = 'language_model.'
        cut = len(lang_prefix)
        lang_weights = dict()
        own_params = dict(self.named_parameters())

        for name, tensor in weights:
            if name.startswith(lang_prefix):
                lang_weights[name[cut:]] = tensor
                continue

            param = own_params[name]
            if 'qkv' not in name:
                load_weight(param, tensor)
                continue

            q, k, v = param.weight_spliter(tensor)
            for shard_id, shard in zip(('q', 'k', 'v'), (q, k, v)):
                load_weight(param, shard, shard_id=shard_id)

        self.language_model.load_weights(lang_weights.items())

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Return the input processor stored on this model (``self.input_processor``)."""
        processor = self.input_processor
        return processor


class InternVLInputProcessor(BaseModelInputProcessor):
    """Internvl input processor.

    Wraps each image's preprocessed pixel values in a ``MultiModalData`` record that
    spans the image's placeholder tokens in ``input_ids``.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

        vision_config = config.vision_config
        self.image_size = vision_config.image_size
        self.patch_size = vision_config.patch_size
        self.num_patches = (self.image_size // self.patch_size)**2
        # +1 presumably for the class token of the vision encoder
        self.num_positions = self.num_patches + 1
        # NOTE(review): // 4 presumably reflects a 2x2 pixel-shuffle downsample; confirm
        self.vision_token_num = self.num_patches // 4

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: prompt token ids.
            input_multimodals: one dict per image with keys ``pixel_values``,
                ``offset`` (start index of the image's placeholder tokens),
                ``image_token_id`` and ``image_tokens`` (number of placeholder
                tokens; an int or a one-element tensor).

        Returns:
            PreprocessInputResult. Its ``input_multimodals`` is ``dict(image=[...])``, or
            ``None`` when no multimodal input was given.
        """
        if not input_multimodals:
            # return the declared result type rather than a bare tuple, so callers can
            # always read ``result.input_ids`` / ``result.input_multimodals``
            return PreprocessInputResult(input_ids=input_ids, input_multimodals=None)

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/internvl3_hf.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from packaging import version
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, vlm_model


@torch.compile(dynamic=True)
def pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """First half of an RMS norm over q and k: per-row sums of squares in float32.

    The norm is split so the sums can be all-reduced across tensor-parallel ranks
    before normalizing with ``post_rms_norm``.

    Returns:
        Tensor of shape ``(2, *q.shape[:-1], 1)``; index 0 holds q's sums and
        index 1 holds k's.
    """
    q32 = q.to(torch.float32)
    k32 = k.to(torch.float32)
    sum_sq_q = (q32 * q32).sum(dim=-1, keepdim=True)
    sum_sq_k = (k32 * k32).sum(dim=-1, keepdim=True)
    return torch.stack((sum_sq_q, sum_sq_k))


@torch.compile(dynamic=True)
def post_rms_norm(q: torch.Tensor, k: torch.Tensor, weight_q: torch.Tensor, weight_k: torch.Tensor,
                  variance: torch.Tensor, eps: float, embed_dim: int, dtype: torch.dtype):
    """Second half of the RMS norm: normalize q and k from precomputed sums of squares.

    Args:
        q, k: tensors to normalize (computed in float32).
        weight_q, weight_k: scales applied after the normalized values are cast to ``dtype``.
        variance: stacked sums of squares as returned by ``pre_rms_norm`` (index 0
            for q, index 1 for k), possibly already reduced across ranks.
        eps: epsilon added to the mean square.
        embed_dim: number of elements the sums were taken over (the divisor).
        dtype: dtype q and k are cast to before the weight multiply.

    Returns:
        ``(q, k)`` normalized and scaled.
    """
    mean_sq = variance / embed_dim + eps
    mean_sq_q, mean_sq_k = mean_sq
    q_normed = (q.to(torch.float32) * torch.rsqrt(mean_sq_q)).to(dtype) * weight_q
    k_normed = (k.to(torch.float32) * torch.rsqrt(mean_sq_k)).to(dtype) * weight_k
    return q_normed, k_normed


class InternVLVisionPatchEmbeddings(nn.Module):
    """Turn ``pixel_values`` of shape ``(batch_size, num_channels, height, width)``
    into patch embeddings of shape ``(batch_size, seq_length, hidden_size)`` for the
    transformer, using a convolution whose kernel and stride equal the patch size."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        image_size = config.image_size
        patch_size = config.patch_size
        # number of patches along (height, width)
        grid = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid[0] * grid[1]
        self.patch_shape = grid

        self.projection = nn.Conv2d(self.num_channels,
                                    config.hidden_size,
                                    kernel_size=patch_size,
                                    stride=patch_size,
                                    dtype=dtype,
                                    device=device)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        _, channels, _, _ = pixel_values.shape
        if channels != self.num_channels:
            raise ValueError(
                'Make sure that the channel dimension of the pixel values match with the one set in the configuration.')

        # (B, hidden, H / patch_h, W / patch_w) -> (B, num_patches, hidden)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class InternVLVisionEmbeddings(nn.Module):
    """Intern vision embedding.

    Patch embeddings with a prepended class token plus optional absolute position
    embeddings. The position table is bicubically resized when the input's patch
    grid differs from the configured one.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        def _token_param():
            return nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device))

        self.cls_token = _token_param()
        self.mask_token = _token_param() if config.use_mask_token else None
        self.patch_embeddings = InternVLVisionPatchEmbeddings(config, dtype=dtype, device=device)

        # one position per patch plus one for the class token
        self.num_positions = self.patch_embeddings.num_patches + 1

        self.position_embeddings = None
        if config.use_absolute_position_embeddings:
            self.position_embeddings = nn.Parameter(
                torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int):
        """Return position embeddings that match the patch grid of ``embeddings``.

        The stored table is returned unchanged when the patch count matches and the
        input is square. Otherwise the patch part is bicubically resized in float32 to
        ``(height // patch_h, width // patch_w)``, cast to ``embeddings.dtype``, and
        the class-token entry is prepended. Assumes the stored patch table is a
        square grid.
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings

        dim = embeddings.shape[-1]
        grid = int(num_positions**0.5)
        out_h = height // self.patch_size[0]
        out_w = width // self.patch_size[1]

        cls_pos = self.position_embeddings[:, :1]
        patch_pos = self.position_embeddings[:, 1:].float()
        patch_pos = patch_pos.reshape(1, grid, grid, -1).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(patch_pos, size=(out_h, out_w), mode='bicubic', align_corners=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim).to(embeddings.dtype)
        return torch.cat((cls_pos, patch_pos), dim=1)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        patch_embeds = self.patch_embeddings(pixel_values)  # (batch, num_patches, hidden)
        cls_tokens = self.cls_token.expand(patch_embeds.shape[0], 1, -1)
        embeddings = torch.cat([cls_tokens, patch_embeds], dim=1)
        if self.position_embeddings is None:
            return embeddings
        return embeddings + self.interpolate_pos_encoding(embeddings, height, width)


# maps the vision config's ``norm_type`` string to the norm layer class to build
NORM2FN = dict(
    rms_norm=RMSNorm,
    layer_norm=LayerNorm,
)


class InternVLVisionAttention(nn.Module):
    """Intern vl attention.

    Multi-head self-attention with a fused qkv projection, optional RMS norm on q and
    k, and an output projection (``projection_layer``).
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        quantization_config = getattr(config, 'quantization_config', None)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.qkv_proj = build_qkv_proj(
            self.embed_dim,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_heads,
            head_size=self.head_dim,
            bias=config.attention_bias,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
        )

        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            norm_kwargs = dict(eps=config.layer_norm_eps, dtype=dtype, device=device, tp=True, align=self.head_dim)
            self.q_norm = RMSNorm(self.embed_dim, **norm_kwargs)
            self.k_norm = RMSNorm(self.embed_dim, **norm_kwargs)

        self.scale = self.head_dim**-0.5

        # o_proj
        self.projection_layer = build_o_proj(self.embed_dim,
                                             self.embed_dim,
                                             bias=True,
                                             quant_config=quantization_config,
                                             dtype=dtype,
                                             device=device,
                                             is_tp=True,
                                             tp_align_size=self.head_dim)

    def pre_rms_norm(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Sums of squares of q and k (see module-level ``pre_rms_norm``)."""
        return pre_rms_norm(q, k)

    def post_rms_norm(self, q: torch.Tensor, k: torch.Tensor, variance: torch.Tensor, dtype: torch.dtype):
        """Normalize q and k from (reduced) sums of squares using this layer's norm weights."""
        return post_rms_norm(q, k, self.q_norm.weight, self.k_norm.weight, variance, self.config.layer_norm_eps,
                             self.embed_dim, dtype)

    def qkv_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RMS norm over the flattened head dimensions of q and k.

        With a single TP rank the norm modules are used directly. Otherwise the
        sums of squares are all-reduced across ranks before normalizing, so the
        statistics cover the full embedding.
        """
        import lmdeploy.pytorch.distributed as dist
        q_shape, k_shape = q.shape, k.shape
        q_flat = q.flatten(-2, -1)
        k_flat = k.flatten(-2, -1)

        world_size, _ = get_tp_world_rank()
        if world_size == 1:
            return self.q_norm(q_flat).view(q_shape), self.k_norm(k_flat).view(k_shape)

        sum_sq = self.pre_rms_norm(q_flat, k_flat)
        dist.all_reduce(sum_sq)
        q_out, k_out = self.post_rms_norm(q_flat, k_flat, sum_sq, q_flat.dtype)
        return q_out.view(q_shape), k_out.view(k_shape)

    def forward(self, hidden_states):
        """forward."""
        q, k, v = self.qkv_proj.split_qkv(self.qkv_proj(hidden_states))

        if self.use_qk_norm:
            q, k = self.qkv_norm(q, k)

        # swap dims 1 and 2 (presumably seq <-> heads) for scaled_dot_product_attention
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)

        # back to the original layout, merge heads, then o proj
        attn_output = attn_output.transpose(1, 2).flatten(-2, -1)
        return self.projection_layer(attn_output)


class InternVLVisionMLP(nn.Module):
    """Intern vl mlp: ``fc2(act(fc1(x)))`` with column-/row-parallel linears."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()

        self.config = config
        quant_cfg = getattr(config, 'quantization_config', None)
        self.act = ACT2FN[config.hidden_act]

        # hidden -> intermediate, split column-wise across TP ranks
        self.fc1 = build_colwise_linear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
            dp_disable_tp=True,
        )

        # intermediate -> hidden, split row-wise across TP ranks
        self.fc2 = build_rowwise_linear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
            dp_disable_tp=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(hidden_states)))


class InternVLVisionLayer(nn.Module):
    """Intern vision layer.

    Pre-norm block with per-channel scales on both residual branches:
    ``h = h + attention(layernorm_before(h)) * lambda_1``, then
    ``h = h + mlp(layernorm_after(h)) * lambda_2``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = getattr(config, 'norm_type', 'rms_norm')

        self.attention = InternVLVisionAttention(config, dtype=dtype, device=device)
        self.mlp = InternVLVisionMLP(config, dtype=dtype, device=device)
        self.layernorm_before = NORM2FN[self.norm_type](self.embed_dim,
                                                        eps=config.layer_norm_eps,
                                                        dtype=dtype,
                                                        device=device)
        self.layernorm_after = NORM2FN[self.norm_type](self.embed_dim,
                                                       eps=config.layer_norm_eps,
                                                       dtype=dtype,
                                                       device=device)

        # per-channel scales of the attention / mlp residual branches
        self.lambda_1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
        self.lambda_2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _attn(self, hidden_states):
        """Attention residual branch."""
        # Cast the norm output back to the residual dtype. Read the dtype from the
        # tensor itself (as _mlp does); ``hidden_states[0]`` fails on an empty
        # leading dimension.
        hidden_states = hidden_states + self.attention(self.layernorm_before(hidden_states).to(
            hidden_states.dtype)) * self.lambda_1
        return hidden_states

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _mlp(self, hidden_states):
        """MLP residual branch."""
        hidden_states = hidden_states + self.mlp(self.layernorm_after(hidden_states).to(
            hidden_states.dtype)) * self.lambda_2
        return hidden_states

    def forward(
        self,
        hidden_states,
    ):
        hidden_states = self._attn(hidden_states)
        hidden_states = self._mlp(hidden_states)
        return hidden_states


class InternVLVisionEncoder(nn.Module):
    """Intern vision encoder: ``config.num_hidden_layers`` vision layers applied in order."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        layers = [InternVLVisionLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        inputs_embeds,
    ):
        """forward."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layer:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states


@vlm_model
class InternVLVisionModel(nn.Module):
    """Intern vision model: embeddings, encoder and an optional final LayerNorm.

    The final LayerNorm is built only when ``config.use_mean_pooling`` is false.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config

        self.embeddings = InternVLVisionEmbeddings(config, dtype=dtype, device=device)
        self.encoder = InternVLVisionEncoder(config, dtype=dtype, device=device)
        self.layernorm = None
        if not config.use_mean_pooling:
            self.layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """forward.

        Returns:
            ``(hidden_states, last_hidden_state)``: the raw encoder output and the
            same output after the final LayerNorm (identical when there is none).
        """
        assert pixel_values.dim() == 4
        encoder_out = self.encoder(inputs_embeds=self.embeddings(pixel_values))
        normed = encoder_out if self.layernorm is None else self.layernorm(encoder_out)
        return encoder_out, normed


class InternVLMultiModalProjector(nn.Module):
    """Project pixel-shuffled vision features into the text model's hidden size.

    Pipeline: LayerNorm -> column-parallel linear -> activation -> row-parallel linear.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        text_config = config.text_config
        text_hidden = text_config.hidden_size
        # Pixel shuffle folds int(1 / downsample_ratio)**2 neighbouring patches into the channel dim.
        shuffle_factor = int(1 / config.downsample_ratio)
        input_dim = config.vision_config.hidden_size * shuffle_factor**2
        self.layer_norm = LayerNorm(input_dim, eps=1e-5, dtype=dtype, device=device)

        quant_cfg = getattr(text_config, 'quantization_config', None)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_1 = build_colwise_linear(input_dim,
                                             text_hidden,
                                             bias=True,
                                             dtype=dtype,
                                             device=device,
                                             quant_config=quant_cfg,
                                             is_tp=True,
                                             dp_disable_tp=True)
        self.linear_2 = build_rowwise_linear(text_hidden,
                                             text_hidden,
                                             bias=True,
                                             quant_config=quant_cfg,
                                             dtype=dtype,
                                             device=device,
                                             is_tp=True,
                                             dp_disable_tp=True)

    def forward(self, image_features):
        """Apply layer_norm, linear_1, act and linear_2 in sequence."""
        projected = self.linear_1(self.layer_norm(image_features))
        return self.linear_2(self.act(projected))


class InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """InternVL (HF layout) for the PyTorch engine.

    Composes ``InternVLVisionModel`` (vision tower), ``InternVLMultiModalProjector`` and a language
    model built from ``config.text_config``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        self.vision_tower = InternVLVisionModel(config.vision_config, dtype=dtype, device=device)
        self.multi_modal_projector = InternVLMultiModalProjector(config, dtype=dtype, device=device)
        self.language_model = build_model_from_hf_config(config.text_config, dtype=dtype, device=device)
        self.lm_head = self.language_model.lm_head
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

        self.input_processor = InternVLProcessor(self.config, dtype)

        # ViT torch.compile state, switched on by `compile_model`. `has_compiled_vit` is initialised
        # here too, so the attribute exists even if `compile_model` is never called.
        self.compile_vit = False
        self.has_compiled_vit = False

    def compile_model(self):
        """Wrap ``get_image_features`` with ``torch.compile``. This is a no-op for torch < 2.5.

        With torch >= 2.6 and tp > 1, it also enables inductor's compute/communication reordering
        and wraps the vision encoder's forward with ``split_batch`` on ``inputs_embeds``.
        """
        torch_version = version.parse(torch.__version__)
        if torch_version < version.parse('2.5.0'):
            return

        tp, _ = get_tp_world_rank()
        if torch_version >= version.parse('2.6.0') and tp > 1:
            torch._inductor.config.reorder_for_compute_comm_overlap = True
            if isinstance(self.vision_tower, InternVLVisionModel):
                self.vision_tower.encoder.forward = split_batch(self.vision_tower.encoder.forward,
                                                                'inputs_embeds',
                                                                index=0)

        self.get_image_features = torch.compile(self.get_image_features, mode='max-autotune-no-cudagraphs')
        self.compile_vit = True
        self.has_compiled_vit = False

    def _mark_dynamic_once(self, pixel_values, dims):
        """Call torch._dynamo.mark_dynamic to avoid recompile.

        This acts at most once, and only after ``compile_model`` has enabled ViT compilation.
        """
        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:
            return

        torch._dynamo.mark_dynamic(pixel_values, dims)
        self.has_compiled_vit = True

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.language_model.get_input_embeddings()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Union[int, List[int]],
        vision_feature_select_strategy: str,
        **kwargs,
    ):
        """Obtain image hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            vision_feature_layer (`int` or `List[int]`):
                Layer to extract features from. Only `-1` (the last layer) is supported, because the
                vision tower returns only its final layer's output.
            vision_feature_select_strategy (`str`):
                `'default'` drops the first token; any other value keeps every token.
            **kwargs: Ignored.

        Returns:
            vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.

        Raises:
            NotImplementedError: if `vision_feature_layer` is not `-1`.
        """
        if vision_feature_layer != -1:
            # `InternVLVisionModel` does not return per-layer hidden states. Indexing its output with
            # a layer index would therefore slice the batch dimension instead of selecting a layer.
            raise NotImplementedError(f'Only vision_feature_layer=-1 is supported, got {vision_feature_layer}.')

        downsample_ratio = self.config.downsample_ratio
        _, vision_features = self.vision_tower(pixel_values=pixel_values)
        if vision_feature_select_strategy == 'default':
            # Drop the first token (presumably the class token).
            vision_features = vision_features[:, 1:, :]

        # Assumes a square patch grid: num_patches == feature_size ** 2.
        num_patches = vision_features.shape[1]
        feature_size = int(num_patches**0.5)
        batch_size = vision_features.shape[0]

        # Reshape tensor to spatial dimensions
        vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)

        # Apply downsampling using pixel shuffle
        vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)

        # Reshape tensor to prepare for projection
        vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])

        # Project features through multi-modal projector
        vision_features = self.multi_modal_projector(vision_features)

        return vision_features

    def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
        """Perform pixel shuffle downsampling on vision features.

        Args:
            vision_features (`torch.Tensor`):
                Input tensor of shape (batch_size, width, height, channels).
            scale_factor (`float`, *optional*, defaults to `0.5`):
                Factor by which to downsample. Default is 0.5, which halves the dimensions.

        Returns:
            vision_features (`torch.Tensor`):
                Downsampled tensor of shape (batch_size, width*scale_factor,
                                                height*scale_factor, channels/(scale_factor^2)).

        Raises:
            ValueError: if `width * scale_factor` or `height * scale_factor` is not a whole number.
        """
        batch_size, width, height, channels = vision_features.size()

        # The reshapes below need whole-number scaled sizes. The previous `height % scale_factor`
        # test could never fire for fractional factors, because any integer % 0.5 == 0.
        if (height * scale_factor) % 1 != 0 or (width * scale_factor) % 1 != 0:
            raise ValueError('Height and width multiplied by scale_factor must be integers for proper downsampling.')

        # Reshape to allow downsampling
        vision_features = vision_features.view(batch_size, width, int(height * scale_factor),
                                               int(channels / scale_factor))
        # Permute dimensions to align downsampled axis correctly
        vision_features = vision_features.permute(0, 2, 1, 3).contiguous()

        # Reshape to achieve final downsampled dimensions
        vision_features = vision_features.view(batch_size, int(height * scale_factor), int(width * scale_factor),
                                               int(channels / (scale_factor**2)))

        # Swap height and width back for proper orientation
        vision_features = vision_features.permute(0, 2, 1, 3).contiguous()

        return vision_features

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Merge image features into the token embeddings, then run the language model.

        When ``pixel_values`` is given and ``inputs_embeds`` is not, image features are computed and
        scattered into the embedding positions selected by ``image_mask``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is available.
        """
        if inputs_embeds is None and pixel_values is not None:
            # extract feature
            self._mark_dynamic_once(pixel_values, [0])
            vit_embeds = self.get_image_features(
                pixel_values,
                self.vision_feature_layer,
                self.vision_feature_select_strategy,
            )
            lang_embeds = self.get_input_embeddings()(input_ids)
            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)

            inputs_embeds = lang_embeds
            input_ids = None

        # `prepare_inputs_for_generation` returns BOTH input_ids and inputs_embeds when it injects
        # precomputed vision embeddings. The old "exactly one" XOR check rejected that case, so only
        # the case where both are missing is an error here.
        if input_ids is None and inputs_embeds is None:
            raise ValueError('You must specify at least one of input_ids or inputs_embeds')

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        outputs = self.language_model.forward(input_ids=input_ids,
                                              inputs_embeds=inputs_embeds,
                                              past_key_values=past_key_values,
                                              position_ids=position_ids,
                                              attn_metadata=attn_metadata)
        return outputs

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare input."""
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = None

        # vision inputs
        pixel_values = None
        image_mask = None
        if context.input_multimodals is not None:
            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            pixel_values = [data for im_data in pixel_values for data in im_data]
            if len(pixel_values) > 0:
                # The image token id is taken from the first image's meta.
                image_token_id = pixel_values[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                pixel_values = torch.cat([data.data for data in pixel_values])
            else:
                pixel_values = None
                image_mask = None

        # get inputs from context
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            vision_embedding_indexing = context.input_embedding_indexing
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            pixel_values=pixel_values,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )

    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
        """Load lora weights.

        Delegates to the language model when it implements ``load_lora_weights``.
        """
        # The text model is stored as `self.language_model`. This class has no `self.model`, so the
        # previous `self.model.language_model` lookup always raised AttributeError.
        if hasattr(self.language_model, 'load_lora_weights'):
            return self.language_model.load_lora_weights(weights, adapter_id)
        else:
            from lmdeploy.pytorch.adapter.adapter import load_lora_weights

            # NOTE(review): confirm the signature of `adapter.load_lora_weights`. If it expects the
            # target module as its first argument, this call also needs that module.
            return load_lora_weights(weights, adapter_id)

    @classmethod
    def rename_weight(cls, name: str) -> str:
        """Map checkpoint weight names onto this module's parameter names."""
        if name == 'lm_head.weight':
            return 'language_model.lm_head.weight'
        elif name.startswith('model.language_model.'):
            return 'language_model.model.' + name[len('model.language_model.'):]
        elif name.startswith('model.'):
            return name[len('model.'):]
        return name

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        ``language_model.*`` weights are forwarded (prefix stripped) to the language model. All
        other weights are loaded here, with vision q/k/v shards packed into ``qkv_proj``.
        """

        lang_prefix = 'language_model.'
        lang_prefix_length = len(lang_prefix)
        new_weights = dict()
        params_dict = dict(self.named_parameters())
        vision_stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]
        for name, loaded_weight in weights:

            if name.startswith(lang_prefix):
                new_key = name[lang_prefix_length:]
                new_weights[new_key] = loaded_weight
                continue

            for (param_name, weight_name, shard_id) in vision_stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

        self.language_model.load_weights(new_weights.items())

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class InternVLProcessor(BaseModelInputProcessor):
    """Input processor that packs InternVL image inputs into ``MultiModalData`` entries."""

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Convert raw multimodal dicts into image ``MultiModalData``.

        Each entry is read for the keys ``pixel_values``, ``offset``, ``image_token_id`` and
        ``image_tokens``. The image spans token positions ``[offset, offset + image_tokens)``, and
        pixel values are cast to ``self.dtype``.

        NOTE: with no multimodal inputs, the raw ``(input_ids, input_multimodals)`` tuple is
        returned rather than a ``PreprocessInputResult``.
        """
        if not input_multimodals:
            return input_ids, input_multimodals

        def _to_image(mm: Dict[str, Any]) -> MultiModalData:
            pixels = mm['pixel_values'].to(self.dtype)
            start = mm['offset']
            token_id = mm['image_token_id']
            pad_len = mm['image_tokens']
            if isinstance(pad_len, torch.Tensor):
                pad_len = pad_len.item()
            return MultiModalData(data=pixels, start=start, end=start + pad_len, meta=dict(image_token_id=token_id))

        images = [_to_image(mm) for mm in input_multimodals]
        return PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(image=images))


================================================
FILE: lmdeploy/pytorch/models/internvl_patch.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig


class InternVisionEmbeddings(nn.Module):
    """Class token plus conv patch embeddings, with interpolated learned position embeddings."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device))

        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size,
                                         dtype=dtype,
                                         device=device)

        grid_side = self.image_size // self.patch_size
        self.num_patches = grid_side**2
        # One extra slot for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resize patch position embeddings from the native square grid to ``H x W``.

        The interpolation runs in float32, and the result is cast back to ``pos_embed``'s dtype.
        """
        orig_dtype = pos_embed.dtype
        grid_side = self.image_size // self.patch_size
        grid = pos_embed.float().reshape(1, grid_side, grid_side, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
        return resized.reshape(1, -1, H * W).permute(0, 2, 1).to(orig_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return ``[class token, patch tokens] + position embeddings`` in the conv weight dtype."""
        compute_dtype = self.patch_embedding.weight.dtype
        # Conv2d output layout: (batch, embed_dim, grid_h, grid_w)
        patches = self.patch_embedding(pixel_values)
        batch_size, _, height, width = patches.shape
        # -> (batch, grid_h * grid_w, embed_dim)
        patches = patches.flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(compute_dtype)
        tokens = torch.cat([cls_tokens, patches], dim=1)

        cls_pos = self.position_embedding[:, :1, :]
        patch_pos = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        pos = torch.cat([cls_pos, patch_pos], dim=1)
        return tokens + pos.to(compute_dtype)


class InternVisionPatchModel(nn.Module):
    """Vision model made of the patch embedding stage only (no encoder layers)."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config, dtype=dtype, device=device)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Return patch embeddings with the leading class-token position removed.

        Raises:
            ValueError: if ``pixel_values`` is not 4-D.
        """
        ndim = len(pixel_values.shape)
        if ndim != 4:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        embedded = self.embeddings(pixel_values)
        return embedded[:, 1:]


================================================
FILE: lmdeploy/pytorch/models/llama.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.models.llama import LlamaConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class LlamaAttention(nn.Module):
    """Llama self-attention: packed QKV projection, rotary embedding, cached attention, output projection."""

    def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden // n_heads)
        use_bias = config.attention_bias

        # q, k and v share a single packed projection
        self.qkv_proj = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=use_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            num_replicate_kv_heads=getattr(config, 'num_replicate_key_value_heads', 1),
            is_tp=is_tp,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(n_heads, head_dim, num_kv_heads=n_kv_heads, v_head_size=head_dim)

        self.o_proj = build_o_proj(n_heads * head_dim,
                                   hidden,
                                   bias=use_bias,
                                   quant_config=quant_cfg,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=is_tp)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward.

        ``past_key_value`` holds ``(k_cache, v_cache)``, optionally followed by
        ``(k_scales_zeros, v_scales_zeros)`` when its length is not 2.
        """
        # Collapse all leading dims before splitting into q / k / v heads.
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class LlamaMLP(nn.Module):
    """Llama MLP: packed gate/up projection, ``SiluAndMul`` activation, down projection."""

    def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        use_bias = getattr(config, 'mlp_bias', False)
        hidden, inter = config.hidden_size, config.intermediate_size

        self.gate_up_proj = build_gateup_linear(hidden, [inter, inter],
                                                bias=use_bias,
                                                dtype=dtype,
                                                device=device,
                                                quant_config=quant_cfg,
                                                is_tp=is_tp)
        self.act_fn = SiluAndMul(inplace=True)
        self.down_proj = build_down_linear(inter,
                                           hidden,
                                           bias=use_bias,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=is_tp)

    def forward(self, x):
        """Apply gate_up_proj, then act_fn, then down_proj."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class LlamaDecoderLayer(nn.Module):
    """One Llama block: RMSNorm -> self-attention -> RMSNorm -> MLP, with the residual threaded through."""

    def __init__(self,
                 config: LlamaConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = LlamaAttention(config, dtype=dtype, device=device, is_tp=is_tp)
        self.mlp = LlamaMLP(config, dtype=dtype, device=device, is_tp=is_tp)

        def _make_norm():
            return RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=quant_cfg, dtype=dtype, device=device)

        self.input_layernorm = _make_norm()
        self.post_attention_layernorm = _make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run the block.

        When called with a residual, RMSNorm returns ``(normed, residual)``, presumably fusing the
        residual add. On the first layer (``residual is None``), the input itself becomes the residual.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the running residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class LlamaModel(nn.Module):
    """Llama decoder stack: token embedding, decoder layers, final RMSNorm, rotary embedding."""

    def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Indices of the layers whose input stream is captured for eagle3 speculative decoding.
        self.aux_hidden_state_layers: Tuple[int] = getattr(config, 'aux_hidden_state_layers', tuple())
        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding in LlamaModel
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward.

        Returns:
            The final normed hidden states. If any layer listed in ``aux_hidden_state_layers`` was
            captured, returns instead a dict with ``hidden_states`` and ``aux_hidden_states`` (the
            captured tensors concatenated on the last dim).
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # for eagle3
        aux_hidden_states = []
        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # Before the first layer there is no separate residual yet, so the stream is
                    # just `hidden_states`. The old `hidden_states + residual` raised TypeError here.
                    # Copy it so later in-place kernels cannot alias the captured value.
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
            return dict(hidden_states=hidden_states, aux_hidden_states=aux_hidden_states)
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class LlamaForCausalLM(nn.Module, CudaGraphMixin):
    """Llama causal LM for the PyTorch engine: ``LlamaModel`` plus a separate ``lm_head``."""

    packed_modules_mapping = {
        'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
        'gate_up_proj': ['gate_proj', 'up_proj'],
    }

    def __init__(self,
                 config: LlamaConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.dtype = dtype
        self.model = LlamaModel(config, dtype=dtype, device=device)
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return whatever ``self.model`` returns.

        That is the final hidden states, or a dict that also carries ``aux_hidden_states``.
        Logits are computed separately by ``get_logits``.
        """
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def update_weights(self):
        """Share the token-embedding weight with ``lm_head`` when embeddings are tied."""
        if not self.config.tie_word_embeddings:
            return
        self.lm_head.weight = self.model.embed_tokens.weight

    def get_logits(self, hidden_states: torch.Tensor):
        """Cast hidden states to the model dtype and project them to vocabulary logits."""
        return self.lm_head(hidden_states.to(dtype=self.dtype))

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: torch.Tensor, **kwargs):
        """Slice the CUDA-graph output buffers down to the real token count."""
        num_tokens = input_ids.size(-1)
        keys = ['hidden_states']
        if 'aux_hidden_states' in output_buffers:
            keys.append('aux_hidden_states')
        return {key: output_buffers[key][:, :num_tokens] for key in keys}

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build forward kwargs from the step context.

        Precomputed vision embeddings, if any, are written into the token embeddings at
        ``context.input_embedding_indexing``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        has_vision = vision_embeddings is not None and len(vision_embeddings) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, packing q/k/v and gate/up shards into the fused parameters.

        Rotary caches are skipped, and so is ``lm_head.weight`` when embeddings are tied.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]
        skipped_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    load_weight(params_dict[name], loaded_weight, shard_id=shard_id)
                    break
            else:
                load_weight(params_dict[name], loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/llama4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.models.llama4 import Llama4Config, Llama4TextConfig, Llama4VisionConfig

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_theta
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class Llama4TextAttention(nn.Module):
    """Self-attention for one Llama4 text decoder layer.

    Layers with ``(layer_idx + 1) % 4 != 0`` apply rotary position embedding
    and, when ``config.use_qk_norm`` is set, a shared RMSNorm on queries and
    keys after the rotation. Every fourth layer does neither.
    """

    def __init__(self,
                 config: Llama4TextConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        # NOTE(review): `scaling` is not passed to `Attention` below; presumably
        # Attention uses its own default softmax scale — confirm.
        self.scaling = self.head_dim**-0.5
        # NOTE(review): the three temperature-tuning settings below are stored
        # but never read by `forward`. The HF reference scales the queries of
        # the no-rope layers by a position-dependent factor built from them;
        # confirm whether that is applied elsewhere or intentionally omitted.
        self.attn_scale = config.attn_scale
        self.floor_scale = config.floor_scale
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.is_causal = True
        self.use_rope = int((layer_idx + 1) % 4 != 0)  # 0 (no rope) on every fourth layer, 1 otherwise
        self.attn_bias = config.attention_bias

        # fused q/k/v projection
        self.qkv_proj = build_qkv_proj(
            config.hidden_size,
            num_q_heads=self.num_attention_heads,
            num_kv_heads=self.num_key_value_heads,
            head_size=self.head_dim,
            bias=self.attn_bias,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            self.num_attention_heads,
            self.head_dim,
            num_kv_heads=self.num_key_value_heads,
            v_head_size=self.head_dim,
        )

        # output projection, tensor-parallel along its input dim
        self.o_proj = build_rowwise_linear(config.num_attention_heads * self.head_dim,
                                           config.hidden_size,
                                           bias=self.attn_bias,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

        # only created on rope layers; `forward` checks for the attribute
        if self.config.use_qk_norm and self.use_rope:
            self.qk_norm = RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Compute self-attention for this layer.

        Args:
            hidden_states: input activations. The attention output is reshaped
                back to ``hidden_states.shape[:-1]`` before ``o_proj``.
            rotary_pos_emb: ``(cos, sin)`` pair. It is only used when
                ``use_rope`` is set.
            past_key_value: KV cache for this layer. ``[0]`` and ``[1]`` are
                passed as the key and value caches. When it has more than two
                entries, ``[2]`` and ``[3]`` are passed as k/v scales-zeros.
                Although typed Optional, it is indexed unconditionally and must
                be provided.
            attn_metadata: forwarded unchanged to ``attn_fwd``.

        Returns:
            The output of ``o_proj``.
        """
        qkv_states = self.qkv_proj(hidden_states)
        # collapse all leading dims into one token dim; split_qkv presumably
        # yields (tokens, heads, head_dim) tensors — TODO confirm
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        if self.use_rope:
            cos, sin = rotary_pos_emb
            # TODO: fuse apply rotary pos emb
            # Regroup the last dim from interleaved pairs [x0, x1, x2, x3, ...]
            # into halves [x0, x2, ..., x1, x3, ...]. Presumably this is the
            # layout ApplyRotaryEmb expects. The interleaved order is restored
            # right after the rotation.
            query_states = query_states.unflatten(-1, (-1, 2)).transpose(-1, -2).flatten(-2)
            key_states = key_states.unflatten(-1, (-1, 2)).transpose(-1, -2).flatten(-2)
            query_states, key_states = self.apply_rotary_pos_emb(
                query_states,
                key_states,
                cos,
                sin,
            )
            query_states = query_states.unflatten(-1, (2, -1)).transpose(-1, -2).flatten(-2)
            key_states = key_states.unflatten(-1, (2, -1)).transpose(-1, -2).flatten(-2)

        # qk_norm is only defined on rope layers with use_qk_norm (see __init__)
        if hasattr(self, 'qk_norm'):
            query_states = self.qk_norm(query_states)
            key_states = self.qk_norm(key_states)

        # extra cache entries (presumably a quantized KV cache) carry scales/zeros
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        # restore the input's leading dims and merge the heads
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        attn_output = self.o_proj(attn_output)

        return attn_output


class Llama4TextMLP(nn.Module):
    """Gated feed-forward network.

    The network applies a fused gate/up projection, then SiLU-and-multiply,
    then a down projection. It is used directly as the dense feed-forward and
    as the shared expert of the MoE block.

    Args:
        config: text config. ``hidden_size`` is always read, and
            ``intermediate_size`` is read when no explicit size is given.
        intermediate_size: width of the gated hidden layer. Defaults to
            ``config.intermediate_size``.
        dtype: parameter dtype.
        device: parameter device.
        is_tp: whether the projections are tensor-parallel.
        all_reduce: forwarded to the down projection. Pass False when the
            caller performs the tensor-parallel reduction itself.
    """

    def __init__(self,
                 config: Llama4TextConfig,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True):
        super().__init__()

        ffn_size = config.intermediate_size if intermediate_size is None else intermediate_size
        self.config = config

        # keyword arguments shared by both projections; this MLP has no bias
        linear_kwargs = {'bias': False, 'dtype': dtype, 'device': device, 'is_tp': is_tp}

        # gate and up projections merged into one column-parallel linear
        self.gate_up_proj = build_merged_colwise_linear(config.hidden_size, [ffn_size, ffn_size], **linear_kwargs)

        # SiLU on the gate half multiplied by the up half, computed in place
        self.act_fn = SiluAndMul(inplace=True)

        # row-parallel projection back to the hidden size
        self.down_proj = build_rowwise_linear(ffn_size, config.hidden_size, all_reduce=all_reduce, **linear_kwargs)

    def forward(self, x):
        """Apply the gated feed-forward to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Llama4TextMoe(nn.Module):
    """Mixture-of-experts feed-forward with an always-active shared expert.

    Each token picks ``config.num_experts_per_tok`` experts by router logit.
    The sigmoid of each selected logit scales the expert's *input*, and the
    routed outputs are summed and added to the shared expert's output.
    """

    def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts

        # router is replicated (is_tp=False) and never quantized
        self.router = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
            quant_config=None,
        )
        # top_k=1: forward expands every (token, selected expert) pair into its
        # own row with a unit weight, so each row goes to exactly one expert
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=1,
            renormalize=False,
            dtype=dtype,
            device=device,
            all_reduce=False,
            quant_config=quantization_config,
        )
        # all_reduce=False on both paths; forward does one combined all-reduce
        self.shared_expert = Llama4TextMLP(config, dtype=dtype, device=device, is_tp=True, all_reduce=False)

        dist_config = dist.get_dist_manager().current_config()
        self.tp = dist_config.tp

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens through the experts and add the shared expert.

        Args:
            hidden_states: 3-D tensor ``(batch, seq_len, hidden_dim)``. The
                shape is unpacked explicitly, so other ranks raise.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden)``. It is all-reduced
            across the tensor-parallel group when ``tp > 1``.
        """
        batch, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.router(hidden_states)

        # select experts by raw logit, then gate with sigmoid (computed in fp32)
        topk_weights, topk_ids = torch.topk(router_logits, self.top_k, dim=-1)
        input_weight = topk_weights.float().sigmoid().to(hidden_states.dtype)

        # scale a copy of each token per selected expert: (tokens * top_k, hidden)
        moe_hidden_states = hidden_states[:, None, :] * input_weight[:, :, None]
        moe_hidden_states = moe_hidden_states.view(-1, hidden_dim)
        # gating is already applied to the inputs, so each row gets weight 1
        topk_weights = torch.ones_like(input_weight).reshape(-1, 1)
        topk_ids = topk_ids.reshape(-1, 1)

        out_states = self.experts(
            moe_hidden_states,
            topk_weights,
            topk_ids,
        )

        # fold the per-expert rows back into tokens by summing over top_k
        out_states = out_states.reshape(-1, self.top_k, hidden_dim)
        out_states = out_states.sum(1)

        shared_states = self.shared_expert(hidden_states)
        out_states += shared_states
        out_states = out_states.reshape(batch, seq_len, -1)

        # one reduction covers both the routed experts and the shared expert
        if self.tp > 1:
            dist.all_reduce(out_states)

        return out_states


class Llama4TextDecoderLayer(nn.Module):
    """One Llama4 decoder block.

    The block runs pre-norm self-attention followed by a pre-norm
    feed-forward. The feed-forward is ``Llama4TextMoe`` for layers listed in
    ``config.moe_layers`` and a dense ``Llama4TextMLP`` (width
    ``config.intermediate_size_mlp``) otherwise.
    """

    def __init__(self,
                 config: Llama4TextConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4TextAttention(config, layer_idx, dtype=dtype, device=device)
        # same predicate as the attention's `use_rope`: 0 on every fourth layer
        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)
        self.is_moe_layer = layer_idx in config.moe_layers
        if not self.is_moe_layer:
            self.feed_forward = Llama4TextMLP(config,
                                              intermediate_size=config.intermediate_size_mlp,
                                              dtype=dtype,
                                              device=device)
        else:
            # the 128E model interleaves dense and sparse layers
            self.feed_forward = Llama4TextMoe(config, dtype=dtype, device=device)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps, dtype=dtype, device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run attention and feed-forward for this layer.

        When ``residual`` is None (the first layer), the input itself becomes
        the residual stream. Otherwise the norm receives both tensors and
        returns the updated ``(hidden_states, residual)`` pair.

        Returns:
            tuple: ``(hidden_states, residual)`` for the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.feed_forward(normed), residual


class Llama4TextModel(nn.Module):
    """Llama4 text backbone.

    The backbone consists of the token embedding, the decoder layer stack, a
    final RMSNorm and a shared rotary embedding.
    """

    def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)
        decoder_layers = [
            Llama4TextDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = self.build_llama4_rotary_embedding(config)

    @staticmethod
    def build_llama4_rotary_embedding(config: Llama4TextConfig):
        """Create the rotary embedding module described by ``config``."""
        return build_rotary_embedding_from_config(config)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        **kwargs,
    ):
        """Run the decoder stack over ``inputs_embeds``.

        ``past_key_values[i]`` is handed to layer ``i``. The returned hidden
        states have passed through the final norm.
        """
        cos, sin = self.rotary_emb(inputs_embeds, position_ids)
        # keep only entry 0 along the leading dim (presumably batch) — TODO confirm
        rotary_pos_emb = (cos[0], sin[0])

        hidden_states = inputs_embeds
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Llama4ForCausalLM(nn.Module):
    """Llama4 text backbone paired with a bias-free language-modelling head.

    ``forward`` returns hidden states only. Logits are produced separately by
    ``get_logits``.
    """

    def __init__(self,
                 config: Llama4TextConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = Llama4TextModel(config, dtype=dtype, device=device)
        self.vocab_size = config.vocab_size
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            device=device,
                                            dtype=dtype)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        **kwargs,
    ):
        """Delegate to the backbone and return its final hidden states."""
        return self.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            **kwargs,
        )

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states to vocabulary logits."""
        return self.lm_head(hidden_states)


class Llama4MultiModalProjector(nn.Module):
    """Bias-free linear map from vision features to the text hidden size.

    The map goes from ``vision_config.vision_output_dim`` to
    ``text_config.hidden_size``.
    """

    def __init__(self, config: Llama4Config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        in_features = config.vision_config.vision_output_dim
        out_features = config.text_config.hidden_size
        self.linear_1 = nn.Linear(in_features, out_features, bias=False, dtype=dtype, device=device)

    def forward(self, image_features):
        """Project ``image_features`` into the text embedding space."""
        return self.linear_1(image_features)


class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` followed by a bias-free linear.

    The kernel size and stride are both ``config.patch_size``. The linear maps
    each flattened ``num_channels * kh * kw`` patch to ``hidden_size``.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        patch = config.patch_size
        kernel_size = (patch, patch) if isinstance(patch, int) else patch
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        patch_dim = config.num_channels * kernel_size[0] * kernel_size[1]
        self.linear = nn.Linear(patch_dim, config.hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed every patch of ``hidden_states``.

        Unfold yields ``(N, C*kh*kw, L)``. It is transposed to
        ``(N, L, C*kh*kw)`` before the linear projection.
        """
        patches = self.unfold(hidden_states).permute(0, 2, 1)
        return self.linear(patches)


class Llama4VisionRotaryEmbedding(nn.Module):
    """Precomputed 2-D rotary table for the vision encoder.

    The table holds one complex rotation factor per position for a square
    grid of ``image_size // patch_size`` patches per side (row-major), plus a
    trailing entry for the class token whose angles are zeroed, i.e. an
    identity rotation.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        idx = config.image_size // config.patch_size
        # one index per grid patch, plus one extra row for the class token
        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (get_rope_theta(config)**(torch.arange(0, freq_dim, 2)[:(freq_dim // 2)].float() / freq_dim))
        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        # zero angle for the class token -> cos=1, sin=0 -> identity rotation
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        # complex, shape (idx**2 + 1, 1, 2 * (freq_dim // 2)); equal to freq_dim when it is even.
        # Stored as a plain attribute (not a registered buffer), so Module.to() does not move it.
        self.freqs_ci = freq_cis.to(device)

    def forward(self, hidden_states):
        """Return the rotary table on the same device as ``hidden_states``.

        ``freqs_ci`` is not a registered buffer, so moving the module with
        ``.to()`` leaves it behind. ``Tensor.to`` returns the tensor itself
        when it is already on the requested device, so the common case is free.
        """
        return self.freqs_ci.to(hidden_states.device)


def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
    """View ``freqs_ci`` so it broadcasts against ``query``.

    Axis 1 and the last axis keep ``query``'s sizes; every other axis becomes
    size 1. ``freqs_ci`` must hold exactly that many elements.
    """
    keep_axes = {1, query.ndim - 1}
    broadcast_shape = [size if axis in keep_axes else 1 for axis, size in enumerate(query.shape)]
    return freqs_ci.view(*broadcast_shape)


def vision_apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    freqs_ci: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``query`` and ``key`` by the complex factors in ``freqs_ci``.

    Adjacent pairs along the last dim are treated as (real, imag) parts. The
    rotation is computed in float32 and cast back to the input dtypes.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d // 2) complex, pairing neighbouring elements
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(query)
    k_complex = _as_complex(key)
    # broadcast shape is derived from the complex query (freqs_ci[:, :, None, :])
    rotation = reshape_for_broadcast(freqs_ci=freqs_ci, query=q_complex).to(q_complex.device)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    # the cast back loses precision (original note: error grows to ~8e-3)
    return q_rotated.type_as(query), k_rotated.type_as(key)


class Llama4VisionAttention(nn.Module):
    """Bidirectional multi-head attention for the vision encoder.

    Rotary embedding is applied to queries and keys. The attention kernel is
    transformers' ``sdpa`` attention function with ``is_causal=False``.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        # fused q/k/v projection with as many kv heads as query heads
        self.qkv_proj = build_qkv_proj(
            self.embed_dim,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_heads,
            head_size=self.head_dim,
            bias=True,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

        self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim,
                                           self.embed_dim,
                                           bias=True,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        """Attend over all tokens of each sequence in ``hidden_states``.

        Args:
            hidden_states: input activations. The leading dims are preserved
                in the output.
            freqs_ci: complex rotary table passed to ``vision_apply_rotary_emb``.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # assumes the fused projection lays out [q | k | v] along the last dim
        fused = self.qkv_proj(hidden_states)
        query, key, value = (part.reshape(per_head_shape) for part in fused.chunk(3, dim=-1))

        query, key = vision_apply_rotary_emb(query, key, freqs_ci=freqs_ci)

        # move heads before the sequence axis for the attention kernel
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        sdpa_forward = ALL_ATTENTION_FUNCTIONS['sdpa']
        context, _ = sdpa_forward(
            self,
            query,
            key,
            value,
            None,
            dropout=0.0,
            scaling=None,
            is_causal=False,  # vision attention must be bidirectional
            output_attentions=False,
        )

        context = context.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(context)


class Llama4VisionMLP(nn.Module):
    """Two-layer GELU MLP used inside each vision encoder layer.

    The layers are ``hidden_size -> intermediate_size -> hidden_size``, both
    with bias and tensor-parallel.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.activation_fn = nn.GELU()
        linear_kwargs = dict(bias=True, dtype=dtype, device=device, is_tp=True)
        self.fc1 = build_colwise_linear(config.hidden_size, config.intermediate_size, **linear_kwargs)
        self.fc2 = build_rowwise_linear(config.intermediate_size, config.hidden_size, **linear_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute ``fc2(gelu(fc1(hidden_states)))``."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class Llama4VisionEncoderLayer(nn.Module):
    """Pre-norm vision transformer block.

    The block computes ``x + attn(ln1(x))`` and then ``x + mlp(ln2(x))``.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config

        self.self_attn = Llama4VisionAttention(config, dtype=dtype, device=device)
        self.mlp = Llama4VisionMLP(config, dtype=dtype, device=device)

        self.input_layernorm = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        """Apply the attention and MLP sub-blocks, each with a residual add."""
        attn_out = self.self_attn(self.input_layernorm(hidden_state), freqs_ci=freqs_ci)
        hidden_state = hidden_state + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        return hidden_state + mlp_out


class Llama4VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` vision encoder layers applied in order."""

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        layer_count = config.num_hidden_layers
        self.layers = nn.ModuleList(
            Llama4VisionEncoderLayer(config, dtype=dtype, device=device) for _ in range(layer_count))

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        """Feed ``hidden_states`` through every layer with the same rotary table."""
        out = hidden_states
        for layer in self.layers:
            out = layer(hidden_state=out, freqs_ci=freqs_ci)
        return out


def pixel_shuffle(input_tensor: torch.Tensor, shuffle_ratio: int):
    """Trade spatial resolution for channels on a square grid of patch tokens.

    Args:
        input_tensor: ``(batch, num_patches, channels)``. ``num_patches`` is
            treated as a ``side x side`` grid with ``side = int(sqrt(num_patches))``.
        shuffle_ratio: scale applied to each spatial side. Channels are scaled
            by ``1 / shuffle_ratio**2``. The arithmetic accepts fractional
            values (e.g. 0.5 merges each 2x2 neighbourhood into one token),
            despite the ``int`` annotation.

    Returns:
        Tensor of shape ``(batch, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2)``.
    """
    import math

    batch_size, num_patches, _ = input_tensor.shape
    side = int(math.sqrt(num_patches))

    # lay the token sequence out as a square (batch, side, side, channels) grid
    grid = input_tensor.view(batch_size, side, side, -1)
    batch_size, height, width, channels = grid.size()

    # fold along the width axis, then swap so the height axis is next to channels
    folded = grid.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
    folded = folded.permute(0, 2, 1, 3).contiguous()

    # fold along the (former) height axis, then restore the (h, w) axis order
    folded = folded.view(batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio),
                         int(channels / (shuffle_ratio**2)))
    folded = folded.permute(0, 2, 1, 3).contiguous()

    return folded.view(batch_size, -1, folded.shape[-1])


class Llama4VisionMLP2(torch.nn.Module):
    """Bias-free projector MLP applied after pixel shuffle.

    The layers are ``intermediate_size -> projector_input_dim`` and
    ``projector_output_dim -> projector_output_dim``, with GELU after both.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        linear_kwargs = dict(bias=False, dtype=dtype, device=device, is_tp=True)
        self.fc1 = build_colwise_linear(self.intermediate_size, config.projector_input_dim, **linear_kwargs)
        self.fc2 = build_rowwise_linear(config.projector_output_dim, config.projector_output_dim, **linear_kwargs)
        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Compute ``gelu(fc2(gelu(fc1(hidden_states))))``."""
        intermediate = self.activation_fn(self.fc1(hidden_states))
        return self.activation_fn(self.fc2(intermediate))


class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: ``pixel_shuffle`` by ``config.pixel_shuffle_ratio``, then the projector MLP."""

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP2(config, dtype=dtype, device=device)

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        """Shuffle the patch tokens and project them."""
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)


class Llama4VisionModel(nn.Module):
    """Llama4 vision tower.

    The tower applies the unfold patch embedding, appends a class token, adds
    learned position embeddings, runs the rotary ViT encoder and finally the
    pixel-shuffle adapter.
    """

    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # one token per grid patch plus the class token
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(config, dtype=dtype, device=device)

        # torch.empty: contents are uninitialized; presumably overwritten when weights load
        self.class_embedding = nn.Parameter(self.scale * torch.empty(self.hidden_size, dtype=dtype, device=device))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.empty(self.num_patches, self.hidden_size, dtype=dtype, device=device))
        self.rotary_embedding = Llama4VisionRotaryEmbedding(config, dtype=dtype, device=device)

        self.layernorm_pre = nn.LayerNorm(self.hidden_size, dtype=dtype, device=device)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, dtype=dtype, device=device)

        self.model = Llama4VisionEncoder(config, dtype=dtype, device=device)
        self.vision_adapter = Llama4VisionPixelShuffleMLP(config, dtype=dtype, device=device)

    def get_input_embeddings(self):
        """Return the first embedding layer (used to activate grads on inputs)."""
        return self.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Encode image tiles into adapter-projected patch features.

        Args:
            pixel_values: 4-D tensor ``(num_tiles, channels, height, width)``.

        Returns:
            Output of ``vision_adapter`` for the patch tokens. The class token
            is dropped before the adapter.
        """
        num_tiles, _, _, _ = pixel_values.shape
        patch_embeds = self.patch_embedding(pixel_values)
        _, num_patches, hidden_dim = patch_embeds.shape

        # append the learned class token after the patch tokens
        patch_embeds = patch_embeds.reshape(num_tiles, num_patches, hidden_dim)
        cls_token = self.class_embedding.expand(patch_embeds.shape[0], 1, patch_embeds.shape[-1])
        tokens = torch.cat([patch_embeds, cls_token], dim=1)
        num_tokens = num_patches + 1

        # absolute position embeddings, one per token including the class token
        tokens = tokens.reshape(num_tiles, 1, num_tokens, hidden_dim)
        pos_embed = self.positional_embedding_vlm.to(dtype=tokens.dtype, device=tokens.device)
        tokens = self.layernorm_pre(tokens + pos_embed)
        tokens = tokens.view(num_tiles, -1, hidden_dim)

        freqs_ci = self.rotary_embedding(pixel_values)
        encoded = self.layernorm_post(self.model(tokens, freqs_ci=freqs_ci))

        # drop the trailing class token before the pixel-shuffle adapter
        return self.vision_adapter(encoded[:, :-1, :])


class Llama4ForConditionalGeneration(nn.Module, CudaGraphMixin):
    """Llama4 multimodal model: vision tower + projector + Llama4 text model.

    Image features from ``vision_model`` are projected by
    ``multi_modal_projector`` and scattered into the token embeddings at the
    positions selected by ``image_mask`` before running ``language_model``.
    """

    def __init__(self,
                 config: Llama4Config,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        self.vision_model = Llama4VisionModel(config.vision_config, dtype=dtype, device=device)

        self.multi_modal_projector = Llama4MultiModalProjector(config, dtype=dtype, device=device)

        # Called before building the text model so that it receives the
        # (possibly rewritten) ``text_config.quantization_config``.
        self._update_quant_config(config)
        self.language_model = Llama4ForCausalLM(config.text_config, ctx_mgr, dtype=dtype, device=device)
        self.vocab_size = config.text_config.vocab_size

        self.input_processor = Llama4InputProcessor(config, dtype)

    @staticmethod
    def _update_quant_config(config: Llama4Config):
        """Update quant config.

        If the top-level config carries a ``quantization_config``, overwrite
        ``config.text_config.quantization_config`` with a fixed
        ``float8_e4m3fn`` / ``smooth_quant`` setting. The contents of the
        original quantization config are not inspected, only its presence.
        ``config`` is mutated in place and also returned.
        """
        quant_config = getattr(config, 'quantization_config', None)

        if quant_config is None:
            return config

        quantization_config = dict(
            quant_dtype='float8_e4m3fn',
            quant_method='smooth_quant',
        )
        text_config = config.text_config
        setattr(text_config, 'quantization_config', quantization_config)

        return config

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs,
    ):
        """Get image features.

        Runs the vision tower on ``pixel_values``; keyword arguments whose
        value is ``None`` are dropped before being forwarded.
        """
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        hidden_state = self.vision_model(pixel_values, **kwargs)
        return hidden_state

    def get_input_embeddings(self):
        """Input embeddings (delegates to the text model)."""
        return self.language_model.get_input_embeddings()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.FloatTensor = None,
        image_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward.

        Token embeddings are always computed from ``input_ids``. When
        ``pixel_values`` is given, the projected image features overwrite the
        embeddings at positions where ``image_mask`` is True.

        NOTE(review): ``prepare_inputs_for_generation`` passes an
        ``inputs_embeds`` entry, but it lands in ``**kwargs`` and is ignored
        here — confirm this is intended.
        """
        image_embeds = None
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values=pixel_values, )
            # Flatten all image tokens into rows of feature vectors before projecting.
            vision_flat = image_features.view(-1, image_features.size(-1))
            image_embeds = self.multi_modal_projector(vision_flat)

        lang_embeds: torch.Tensor = self.get_input_embeddings()(input_ids)

        if image_embeds is not None:
            # In-place scatter; assumes the number of True entries in
            # image_mask matches the number of projected image rows
            # (not checked here).
            lang_embeds.masked_scatter_(image_mask[..., None], image_embeds)

        inputs_embeds = lang_embeds

        return self.language_model(
            inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.language_model.get_logits(hidden_states)

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input.

        Builds the keyword arguments for ``forward`` from the step context.
        All image entries of all sequences are flattened into a single
        ``pixel_values`` tensor (``torch.cat``), and ``image_mask`` marks the
        positions of ``input_ids`` equal to the first image's
        ``image_token_id``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # vision inputs
        pixel_values = None
        image_mask = None
        if context.input_multimodals is not None:
            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            pixel_values = [data for im_data in pixel_values for data in im_data]
            if len(pixel_values) > 0:
                # The token id is read from the first image only; presumably
                # all images share the same image token id.
                image_token_id = pixel_values[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                pixel_values = torch.cat([data.data for data in pixel_values])
            else:
                pixel_values = None
                image_mask = None

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_mask=image_mask,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Args:
            weights: iterable of ``(name, tensor)`` pairs from the checkpoint.

        Names containing ``.experts`` go to the MoE expert loaders (bf16 when
        ``self.config`` has no ``quantization_config``, fp8 otherwise);
        q/k/v and gate/up projections are packed into ``qkv_proj`` /
        ``gate_up_proj``; all other names are loaded by exact parameter name.
        """

        # NOTE: the helpers below close over ``params_dict``, ``device``,
        # ``num_experts`` and ``expert_params_mapping``, which are assigned
        # later in this method; they are only called after those exist.
        def _load_experts_bf16(name, loaded_weight):
            # Unquantized layout: all experts stacked in one tensor, indexed
            # by expert id along dim 0.
            if '.gate_up_proj' in name:
                loaded_weight = loaded_weight.to(device)
                name = name.replace('.gate_up_proj', '.gate_up.weight')
                param = params_dict[name]
                for exp_id in range(num_experts):
                    # gate and up halves are concatenated along the last dim;
                    # each half is transposed before loading.
                    weight_gate, weight_up = loaded_weight[exp_id].chunk(2, -1)
                    load_weight(param, weight_gate.t(), expert_id=exp_id, shard_id='gate')
                    load_weight(param, weight_up.t(), expert_id=exp_id, shard_id='up')
            elif '.down_proj' in name:
                loaded_weight = loaded_weight.to(device)
                name = name.replace('.down_proj', '.down.weight')
                param = params_dict[name]
                for exp_id in range(num_experts):
                    weight = loaded_weight[exp_id].t()
                    load_weight(param, weight, expert_id=exp_id, shard_id='down')
            # Other '.experts' names are silently ignored on this path.

        def _load_experts_fp8(name, loaded_weight):
            # Quantized layout: one tensor per expert and projection
            # ('.experts.{i}.gate_proj' etc.); checkpoint '.weight_scale'
            # tensors map to '.scale' parameters.
            name = name.replace('.weight_scale', '.scale')
            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

        def _load_experts(name, loaded_weight):
            """Load experts weight (dispatch on presence of quantization_config)."""
            quantization_config = getattr(self.config, 'quantization_config', None)
            if quantization_config is None:
                _load_experts_bf16(name, loaded_weight)
            else:
                _load_experts_fp8(name, loaded_weight)

        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        num_experts = self.config.text_config.num_local_experts
        # (param_name, checkpoint_name, expert_id, shard_id) for every
        # per-expert projection; used by the fp8 loader.
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        # bf16 expert tensors are moved to the device of the first parameter.
        device = next(iter(params_dict.values())).device
        for name, loaded_weight in weights:
            # Rotary caches are recomputed, and a tied lm_head is not loaded.
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            if '.experts' in name:
                _load_experts(name, loaded_weight)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Llama4InputProcessor(BaseModelInputProcessor):
    """Turn per-image request payloads into Llama4 multimodal engine inputs."""

    def __init__(self, config: Llama4Config, dtype) -> None:
        self.config = config
        self.dtype = dtype
        self.vision_config = config.vision_config

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Wrap every image payload in a ``MultiModalData`` record.

        Each entry of ``input_multimodals`` must provide ``pixel_values``,
        ``offset``, ``image_token_id`` and ``image_tokens``. The resulting
        record spans ``[offset, offset + image_tokens)`` and stores the image
        token id in its ``meta``.

        NOTE(review): with no multimodal input a plain
        ``(input_ids, input_multimodals)`` tuple is returned instead of a
        ``PreprocessInputResult``; confirm callers never rely on this path.
        """
        if input_multimodals is None or not len(input_multimodals):
            return input_ids, input_multimodals

        def _wrap(entry):
            # Pixels are cast to the model dtype.
            pixels = entry['pixel_values'].to(self.dtype)
            start = entry['offset']
            token_id = entry['image_token_id']
            num_tokens = entry['image_tokens']
            # ``image_tokens`` may arrive as a single-element tensor.
            if isinstance(num_tokens, torch.Tensor):
                num_tokens = num_tokens.item()
            return MultiModalData(data=pixels,
                                  start=start,
                                  end=start + num_tokens,
                                  meta=dict(image_token_id=token_id))

        images = [_wrap(entry) for entry in input_multimodals]
        return PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=images),
        )


================================================
FILE: lmdeploy/pytorch/models/llama_eagle.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.nn import build_rotary_embedding_from_config
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .llama import LlamaDecoderLayer
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin


class EagleLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer for the EAGLE draft model (built with ``is_tp=False``)."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False)

        if layer_idx != 0:
            return

        # EAGLE skips the input layernorm of the first layer: replace the
        # module with an identity callable.
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        del self.input_layernorm
        self.input_layernorm = lambda x: x


class EagleLlamaModel(nn.Module):
    """EAGLE draft backbone: token embedding, fc fusion, Llama decoder layers."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        hidden_size = config.hidden_size
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # Decoder stack.
        self.layers = nn.ModuleList([
            EagleLlamaDecoderLayer(config, idx, dtype=dtype, device=device)
            for idx in range(config.num_hidden_layers)
        ])
        # Fuses [token embedding, target hidden state] (2 * hidden_size)
        # back down to hidden_size.
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        previous_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        """Run the draft backbone.

        Token embeddings (looked up from ``input_ids`` unless ``inputs_embeds``
        is given) are concatenated with ``previous_hidden_states`` on the last
        dim, projected by ``fc`` and passed through every decoder layer. No
        final norm is applied: the result is ``hidden_states + residual``.
        """
        if inputs_embeds is None:
            assert input_ids is not None
            inputs_embeds = self.embed_tokens(input_ids)

        target_states = previous_hidden_states.to(inputs_embeds)
        hidden_states = self.fc(torch.cat([inputs_embeds, target_states], dim=-1))

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )
        return hidden_states + residual

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class EagleLlamaForCausalLM(nn.Module, CudaGraphMixin):
    """EAGLE (v1) draft model wrapper around ``EagleLlamaModel``.

    ``forward`` returns hidden states only; this class does not build an
    ``lm_head`` of its own.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self, config, ctx_mgr, dtype=None, device=None):
        nn.Module.__init__(self)
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.dtype = dtype
        # build LLamaModel
        self.model = EagleLlamaModel(config, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        target_hidden_states: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states of the draft backbone.

        ``target_hidden_states`` is passed to the backbone as
        ``previous_hidden_states``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            previous_hidden_states=target_hidden_states,
        )
        return hidden_states

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input (including ``target_hidden_states`` from the context)."""
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        target_hidden_states = context.target_hidden_states
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            target_hidden_states=target_hidden_states,
        )

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs.

        Adds a zero ``target_hidden_states`` buffer of shape
        ``(1, max_tokens, hidden_size)`` in ``self.dtype``.
        """
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,
                                                                                     max_tokens,
                                                                                     self.config.hidden_size,
                                                                                     dtype=self.dtype)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs.

        Copies ``target_hidden_states`` into the head of its static buffer.
        When decoding, the buffer is sliced to the padded token count;
        otherwise the whole buffer is used.
        """

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        num_tokens = kwargs['input_ids'].size(-1)

        is_decoding = graph_meta.is_decoding
        input_buffers = graph_meta.input_buffers
        padded_num_tokens = new_inputs['input_ids'].size(-1)

        target_hidden_states = kwargs.get('target_hidden_states')
        assert target_hidden_states is not None
        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states
        if is_decoding:
            new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'][:, :padded_num_tokens, :]
        else:
            new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']

        return new_inputs

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config asks for it.

        ``__init__`` does not create an ``lm_head``, so tying is done only if
        one has been attached to this module; previously a
        ``tie_word_embeddings=True`` config raised ``AttributeError`` here.
        """
        if not self.config.tie_word_embeddings:
            return
        lm_head = getattr(self, 'lm_head', None)
        if lm_head is None:
            return
        lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Checkpoint names are prefixed with ``model.`` to match this wrapper;
        q/k/v and gate/up projections are packed into ``qkv_proj`` /
        ``gate_up_proj``; rotary caches and a tied ``lm_head`` are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = 'model.' + name
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/llama_eagle3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .llama import LlamaDecoderLayer
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin


class Eagle3LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE-3 draft decoder layer.

    Attention consumes the concatenation of the normalized token embeddings
    and the normalized hidden states, so the QKV projection takes
    ``2 * hidden_size`` input features.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False)
        self.layer_idx = layer_idx

        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        # Replace the parent's QKV projection with one over [embeds, hidden].
        self.self_attn.qkv_proj = build_qkv_proj(
            2 * hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=config.num_key_value_heads,
            head_size=getattr(config, 'head_dim', hidden_size // num_heads),
            bias=False,
            quant_config=getattr(config, 'quantization_config', None),
            dtype=dtype,
            device=device,
            is_tp=False,
        )

        self.hidden_norm = RMSNorm(hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        attn_metadata: Any = None,
    ):
        """Run one EAGLE-3 layer.

        Returns:
            ``(mlp_output, residual)`` where ``residual`` is the value returned
            by ``post_attention_layernorm`` alongside its normalized output,
            seeded with the incoming ``hidden_states``.
        """
        residual = hidden_states
        attn_input = torch.cat([self.input_layernorm(embeds), self.hidden_norm(hidden_states)], dim=-1)

        attn_out = self.self_attn(
            hidden_states=attn_input,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class Eagle3LlamaModel(nn.Module):
    """EAGLE-3 draft backbone with a single decoder layer (``midlayer``)."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        hidden_size = config.hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        self.midlayer = Eagle3LlamaDecoderLayer(config, layer_idx=0, dtype=dtype, device=device)
        # Projects 3 * target_hidden_size input features down to hidden_size.
        target_hidden_size = getattr(config, 'target_hidden_size', hidden_size)
        self.fc = build_rowwise_linear(
            3 * target_hidden_size,
            hidden_size,
            bias=False,
            dtype=dtype,
            device=device,
        )

        self.norm = RMSNorm(hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        previous_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        """Run the draft backbone.

        Returns:
            dict with ``hidden_states`` (after the final norm) and
            ``hidden_states_prenorm`` (the residual value returned by the
            norm alongside it).
        """
        if inputs_embeds is None:
            assert input_ids is not None
            inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

        target_states = previous_hidden_states.to(inputs_embeds)
        # States whose width differs from the embeddings (e.g. those coming
        # from the target model) are projected by ``fc`` first.
        if target_states.shape[-1] != inputs_embeds.shape[-1]:
            target_states = self.fc(target_states)

        cos, sin = self.rotary_emb(target_states, position_ids)
        hidden_states, residual = self.midlayer(
            inputs_embeds,
            target_states,
            rotary_pos_emb=(cos[0], sin[0]),
            past_key_value=past_key_values[0],
            attn_metadata=attn_metadata,
        )
        normed, prenorm = self.norm(hidden_states, residual)
        return dict(hidden_states=normed, hidden_states_prenorm=prenorm)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Eagle3LlamaForCausalLM(nn.Module, CudaGraphMixin):
    """EAGLE-3 draft model: single-layer backbone plus a draft-vocab ``lm_head``.

    ``draft_id_to_target_id`` maps each draft-vocab id to a target-vocab id.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self, config, ctx_mgr, dtype=None, device=None):
        nn.Module.__init__(self)
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.dtype = dtype

        if config.num_hidden_layers != 1:
            raise ValueError('eagle3 only supports 1 decode layer')

        # build LLamaModel
        self.model = Eagle3LlamaModel(config, dtype=dtype, device=device)
        # build lm_head over the (smaller) draft vocabulary
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.draft_vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)
        self.draft_id_to_target_id = nn.Parameter(
            torch.zeros(self.config.draft_vocab_size, dtype=torch.long, device=device),
            requires_grad=False,
        )
        # Set to True by load_weights when the checkpoint has its own embed_tokens.
        self.include_embed_tokens = False

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        target_hidden_states: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns the backbone's output dict (not logits)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            previous_hidden_states=target_hidden_states,
        )
        return hidden_states

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input (including ``target_hidden_states`` from the context)."""
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        target_hidden_states = context.target_hidden_states
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            target_hidden_states=target_hidden_states,
        )

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute draft-vocab logits of the model output."""
        logits = self.lm_head(hidden_states)
        return logits

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs.

        The ``target_hidden_states`` buffer is sized from the width of the
        incoming ``target_hidden_states`` (which may differ from
        ``hidden_size``): shape ``(1, max_tokens, width)`` in ``self.dtype``.
        """
        max_tokens = graph_meta.max_tokens
        target_hidden_states = kwargs.get('target_hidden_states')
        assert target_hidden_states is not None
        target_hidden_size = target_hidden_states.size(-1)
        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,
                                                                                     max_tokens,
                                                                                     target_hidden_size,
                                                                                     dtype=self.dtype)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs.

        Copies ``target_hidden_states`` into the head of its static buffer and
        always passes the full buffer to the graph.
        """

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        num_tokens = kwargs['input_ids'].size(-1)

        input_buffers = graph_meta.input_buffers

        target_hidden_states = kwargs.get('target_hidden_states')
        assert target_hidden_states is not None
        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states

        new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']

        return new_inputs

    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: torch.Tensor, **kwargs):
        """Get outputs from buffers, trimmed to the real token count."""
        num_tokens = input_ids.size(-1)
        outputs = dict()
        outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens]
        outputs['hidden_states_prenorm'] = output_buffers['hidden_states_prenorm'][:, :num_tokens]
        return outputs

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        * ``d2t`` is loaded into ``draft_id_to_target_id``; the stored values
          are offsets that are added to the draft index to obtain target ids.
        * ``t2d`` and rotary caches are skipped.
        * Except for ``lm_head.weight`` and ``d2t``, checkpoint names are
          prefixed with ``model.``.
        * q/k/v and gate/up projections are packed into ``qkv_proj`` /
          ``gate_up_proj``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'd2t' in name:
                name = 'draft_id_to_target_id'
                base = torch.arange(self.config.draft_vocab_size,
                                    device=loaded_weight.device,
                                    dtype=loaded_weight.dtype)
                # Out-of-place add: the caller's checkpoint tensor must not
                # be mutated (an in-place add would corrupt shared/cached
                # tensors and fails on read-only storage).
                loaded_weight = loaded_weight + base
            elif 'lm_head.weight' not in name:
                name = 'model.' + name
            if 'embed_tokens' in name:
                self.include_embed_tokens = True
            if 't2d' in name:
                continue
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/llava.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.llava.configuration_llava import LlavaConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model


class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision-tower features to the text hidden size.

    ``linear_1`` goes from ``vision_config.hidden_size`` to
    ``text_config.hidden_size``; ``linear_2`` keeps the text hidden size.
    The activation in between is ``config.projector_hidden_act``.
    """

    def __init__(self, config: LlavaConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN

        vision_dim = config.vision_config.hidden_size
        text_dim = config.text_config.hidden_size
        linear_kwargs = dict(bias=True, dtype=dtype, device=device)

        self.linear_1 = nn.Linear(vision_dim, text_dim, **linear_kwargs)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_dim, text_dim, **linear_kwargs)

    def forward(self, image_features):
        """Apply ``linear_1`` -> activation -> ``linear_2``."""
        return self.linear_2(self.act(self.linear_1(image_features)))


class CLIPVisionEmbeddings(nn.Module):
    """Clip vision embedding.

    Splits the image into non-overlapping patches with a strided conv,
    prepends a class token and adds position embeddings.
    """

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # class token, prepended to the patch sequence in ``forward``
        self.class_embedding = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

        # kernel_size == stride, so patches do not overlap
        self.patch_embedding = nn.Conv2d(in_channels=config.num_channels,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size,
                                         bias=False,
                                         dtype=dtype,
                                         device=device)

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        # one extra position for the class token
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim, dtype=dtype, device=device)
        position_ids = torch.arange(self.num_positions, device=device).expand((1, -1))
        self.register_buffer('position_ids', position_ids, persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Resize the stored position embeddings to the patch grid of a
        ``height`` x ``width`` input, so larger images can be used.

        If the token count already matches and the input is square, the
        stored embeddings are returned as-is. Under ``torch.jit`` tracing the
        resize always happens, so the traced graph supports dynamic shapes.
        """
        pos_weight = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1] - 1
        num_positions = pos_weight.shape[1] - 1

        must_resize = torch.jit.is_tracing() or num_patches != num_positions or height != width
        if not must_resize:
            return self.position_embedding(self.position_ids)

        from transformers.utils import torch_int

        dim = embeddings.shape[-1]
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size
        side = torch_int(num_positions**0.5)

        cls_pos = pos_weight[:, :1]
        # (1, N, dim) -> (1, dim, side, side) for spatial interpolation
        grid_pos = pos_weight[:, 1:].reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid_pos = nn.functional.interpolate(
            grid_pos,
            size=(grid_h, grid_w),
            mode='bicubic',
            align_corners=False,
        )
        grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((cls_pos, grid_pos), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """Embed ``pixel_values`` (unpacked as batch, channels, height, width).

        Returns a ``(batch, num_patches + 1, embed_dim)`` tensor with the class
        token first.

        Raises:
            ValueError: if the spatial size differs from ``image_size`` and
                ``interpolate_pos_encoding`` is False.
        """
        batch_size, _, height, width = pixel_values.shape
        size_matches = height == self.image_size and width == self.image_size
        if not (interpolate_pos_encoding or size_matches):
            raise ValueError(f"Input image size ({height}*{width}) doesn't match model"
                             f' ({self.image_size}*{self.image_size}).')

        conv_dtype = self.patch_embedding.weight.dtype
        # (B, D, gh, gw) -> (B, gh * gw, D)
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=conv_dtype)).flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(tokens, height, width)
        else:
            positions = self.position_embedding(self.position_ids)
        return tokens + positions


class CLIPAttention(nn.Module):
    """Clip multi-head self-attention with a fused q/k/v projection."""

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        quant_config = getattr(config, 'quantization_config', None)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        # one fused projection; k/v use the same head count as q
        self.qkv_proj = build_qkv_proj(self.embed_dim,
                                       num_q_heads=self.num_heads,
                                       num_kv_heads=self.num_heads,
                                       head_size=self.head_dim,
                                       bias=True,
                                       quant_config=quant_config,
                                       dtype=dtype,
                                       device=device)

        self.scale = self.head_dim**-0.5

        self.out_proj = build_rowwise_linear(self.embed_dim,
                                             self.embed_dim,
                                             bias=True,
                                             quant_config=quant_config,
                                             dtype=dtype,
                                             device=device,
                                             is_tp=True)

    def forward(
        self,
        hidden_states,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ):
        """Attend over ``hidden_states``.

        The two optional masks are summed when both are given; otherwise
        whichever one is present (or neither) is passed to
        ``F.scaled_dot_product_attention``.
        """
        qkv_states = self.qkv_proj(hidden_states)
        query, key, value = [t.transpose(1, 2) for t in self.qkv_proj.split_qkv(qkv_states)]

        if causal_attention_mask is None:
            attn_mask = attention_mask
        elif attention_mask is None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask + causal_attention_mask

        context = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, scale=self.scale)

        # undo the axis-1/2 transpose, then merge the last two axes (heads) before out_proj
        context = context.transpose(1, 2).flatten(-2, -1)
        return self.out_proj(context)


class CLIPMLP(nn.Module):
    """Clip mlp: ``fc1`` (column-wise linear) -> activation -> ``fc2``
    (row-wise linear)."""

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN

        self.config = config
        quantization_config = getattr(config, 'quantization_config', None)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = build_colwise_linear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )
        self.fc2 = build_rowwise_linear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1`` -> ``activation_fn`` -> ``fc2``."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    """Clip encoder layer.

    Pre-norm self-attention followed by a pre-norm MLP, each wrapped in a
    residual connection.
    """

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config, dtype=dtype, device=device)
        norm_kwargs = dict(eps=config.layer_norm_eps, dtype=dtype, device=device)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, **norm_kwargs)
        self.mlp = CLIPMLP(config, dtype=dtype, device=device)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ):
        """Return ``h + mlp(ln2(h))`` where ``h = x + attn(ln1(x))``."""
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """Clip encoder: a stack of ``config.num_hidden_layers`` encoder
    layers."""

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [CLIPEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        vision_feature_layer: int = -1,
    ):
        """Run the layers up to ``vision_feature_layer`` and return that
        output; later layers are skipped.

        The index addresses the list ``[inputs_embeds, out_1, ..., out_L]``:
        a negative ``-k`` runs ``L + 1 - k`` layers (``-1`` runs all of
        them), and a non-negative ``n`` runs the first ``n`` layers (``0``
        returns ``inputs_embeds`` unchanged).
        """
        if vision_feature_layer < 0:
            num_vision_layers = len(self.layers) + vision_feature_layer + 1
        else:
            # previously a non-negative index fell into the negative-index
            # formula and silently ran every layer
            num_vision_layers = vision_feature_layer

        hidden_states = inputs_embeds
        for encoder_layer in self.layers[:num_vision_layers]:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask=causal_attention_mask,
            )

        return hidden_states


class CLIPVisionTransformer(nn.Module):
    """Clip vision transformer: embeddings -> ``pre_layrnorm`` -> encoder.

    ``forward`` returns a ``BaseModelOutputWithPooling``:
    - ``last_hidden_state`` is the output of the encoder layer chosen by
      ``vision_feature_layer``, which is not necessarily the final layer.
    - ``pooler_output`` is that output's first token after ``post_layernorm``.
    """

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        hidden = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config, dtype=dtype, device=device)
        # NOTE: the `pre_layrnorm` spelling is kept; it is presumably the name
        # used by checkpoint weight keys — verify before renaming.
        self.pre_layrnorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps, dtype=dtype, device=device)
        self.encoder = CLIPEncoder(config, dtype=dtype, device=device)
        self.post_layernorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: bool = False,
        vision_feature_layer: int = -1,
    ) -> BaseModelOutputWithPooling:
        """Encode ``pixel_values`` and pool the first token."""
        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        embedded = self.pre_layrnorm(embedded)

        last_hidden_state = self.encoder(inputs_embeds=embedded, vision_feature_layer=vision_feature_layer)
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=None,
            attentions=None,
        )


@vlm_model
class CLIPVisionModel(nn.Module):
    """Clip vision model: holds a ``CLIPVisionTransformer`` as
    ``vision_model`` and forwards calls to it."""

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(config, dtype=dtype, device=device)

    def forward(self,
                pixel_values: torch.FloatTensor,
                interpolate_pos_encoding: bool = False,
                vision_feature_layer: int = -1,
                **kwargs):
        """Delegate to ``vision_model``; extra ``kwargs`` are ignored."""
        transformer = self.vision_model
        return transformer(pixel_values,
                           interpolate_pos_encoding=interpolate_pos_encoding,
                           vision_feature_layer=vision_feature_layer)


def build_vision_model(vision_config, dtype: torch.dtype = None, device: torch.device = None):
    """Build the vision tower for ``vision_config``.

    Only ``model_type == 'clip_vision_model'`` is supported.

    Raises:
        NotImplementedError: for any other ``model_type``.
    """
    model_type = vision_config.model_type
    if model_type != 'clip_vision_model':
        raise NotImplementedError(f'<{model_type}> is not implemented.')
    return CLIPVisionModel(vision_config, dtype, device)


class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin, DeployModelMixin):
    """Llava model.

    A vision tower plus ``LlavaMultiModalProjector`` produce image features.
    These overwrite the image-placeholder token embeddings of a language model
    built from ``config.text_config``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        text_config = config.text_config

        self.vision_tower = build_vision_model(config.vision_config, dtype=dtype, device=device)
        self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device)
        self.multi_modal_projector = LlavaMultiModalProjector(config, dtype=dtype, device=device)
        self.input_processor = LLavaInputProcessor(config, dtype)

    def get_image_features(self,
                           pixel_values,
                           vision_feature_layer: int = -1,
                           vision_feature_select_strategy: str = 'default'):
        """Encode ``pixel_values`` and project them to the text hidden size.

        ``'default'`` drops the first token of every image (index 0 along
        dim 1); ``'full'`` keeps all tokens.

        Returns:
            Projected features of all images flattened into one sequence,
            with a leading dim of size 1.

        Raises:
            ValueError: for any other ``vision_feature_select_strategy``.
        """
        features = self.vision_tower(pixel_values, vision_feature_layer=vision_feature_layer)[0]
        if vision_feature_select_strategy == 'default':
            features = features[:, 1:]
        elif vision_feature_select_strategy != 'full':
            raise ValueError(f'Unexpected select feature strategy: {vision_feature_select_strategy}')
        projected = self.multi_modal_projector(features)
        return projected.flatten(0, 1)[None]

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the language model, splicing in image features if needed.

        When ``inputs_embeds`` is None, token embeddings are looked up from
        ``input_ids``. If ``pixel_values`` is also given, the rows where
        ``image_mask`` is True are overwritten in place (``masked_scatter_``)
        with the projected image features. Extra ``kwargs`` are ignored.
        """
        if inputs_embeds is None:
            has_images = pixel_values is not None
            image_features = None
            if has_images:
                image_features = self.get_image_features(
                    pixel_values,
                    vision_feature_layer=self.config.vision_feature_layer,
                    vision_feature_select_strategy=self.config.vision_feature_select_strategy)
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            if has_images:
                inputs_embeds.masked_scatter_(image_mask[..., None], image_features)

        return self.language_model.forward(input_ids=input_ids,
                                           inputs_embeds=inputs_embeds,
                                           past_key_values=past_key_values,
                                           position_ids=position_ids,
                                           attn_metadata=attn_metadata)

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output via the language model."""
        return self.language_model.get_logits(hidden_states)

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Build the ``forward`` keyword arguments from the step ``context``.

        Image entries of every sequence in the batch are concatenated into one
        ``pixel_values`` tensor. ``image_mask`` marks tokens equal to the
        first image's ``image_token_id``. Precomputed ``context.input_embeddings``,
        if any, are written into ``inputs_embeds`` at
        ``context.input_embedding_indexing``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        image_mask = None
        if context.input_multimodals is not None:
            images = [mm for input_mm in context.input_multimodals for mm in input_mm.get('image', [])]
            if images:
                image_mask = input_ids == images[0].meta['image_token_id']
                pixel_values = torch.cat([mm.data for mm in images])

        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            pixel_values=pixel_values,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        Keys prefixed with ``language_model.`` are stripped of the prefix and
        handed to ``self.language_model.load_weights``. All other keys are
        loaded here. Separate ``q_proj``/``k_proj``/``v_proj`` tensors go into
        the fused ``qkv_proj`` parameter as shards ``q``/``k``/``v``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        lang_prefix = 'language_model.'
        lang_weights = dict()
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name.startswith(lang_prefix):
                lang_weights[name[len(lang_prefix):]] = loaded_weight
                continue

            # first mapping whose shard name occurs in the key, if any
            stacked = next((entry for entry in stacked_params_mapping if entry[1] in name), None)
            if stacked is None:
                load_weight(params_dict[name], loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                load_weight(params_dict[name.replace(weight_name, param_name)], loaded_weight, shard_id=shard_id)

        self.language_model.load_weights(lang_weights.items())

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class LLavaInputProcessor(BaseModelInputProcessor):
    """Llava input processor.

    Wraps each image's pixel values in a ``MultiModalData`` that spans its
    placeholder tokens.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Each entry of ``input_multimodals`` is read for ``pixel_values``
        (cast to ``self.dtype``), ``offset``, ``image_token_id`` and
        ``image_tokens``. It becomes a ``MultiModalData`` with
        ``start=offset`` and ``end=offset + image_tokens``.

        Returns:
            ``PreprocessInputResult`` with ``input_ids`` unchanged and the
            images under ``input_multimodals['image']``. When there is no
            multimodal input, ``input_multimodals`` is ``None``.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # return the declared result type (not a bare tuple) so callers
            # can always read ``.input_ids`` / ``.input_multimodals``
            return PreprocessInputResult(input_ids=input_ids, input_multimodals=None)

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Return ``(rows, cols)`` of ``patch_size`` tiles for the resolution
    that ``select_best_resolution`` picks for ``image_size``.

    ``image_size`` may be a list/tuple or any object with ``.tolist()``.

    Raises:
        TypeError: if ``grid_pinpoints`` is not a list.
    """
    from transformers.image_processing_utils import select_best_resolution

    if not isinstance(grid_pinpoints, list):
        raise TypeError('grid_pinpoints should be a list of tuples or lists')

    size = image_size if isinstance(image_size, (list, tuple)) else image_size.tolist()
    best_height, best_width = select_best_resolution(size, grid_pinpoints)
    return best_height // patch_size, best_width // patch_size


def unpad_image(tensor, original_size):
    """Unpads a PyTorch tensor of a padded and resized image.

    ``tensor`` dims 1 and 2 are treated as the current height and width; the
    leading dim is left untouched. ``original_size`` is ``(height, width)``
    as a list/tuple or anything with ``.tolist()``. Equal amounts are trimmed
    from both sides of the axis that received padding, so with odd padding
    one extra row/column stays.
    """
    if not isinstance(original_size, (list, tuple)):
        original_size = original_size.tolist()
    orig_h, orig_w = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # relatively wider original: it was fitted to the width, rows were padded
        new_h = int(round(orig_h * (cur_w / orig_w), 7))
        pad = (cur_h - new_h) // 2
        return tensor[:, pad:cur_h - pad, :]

    # otherwise it was fitted to the height and columns were padded
    new_w = int(round(orig_w * (cur_h / orig_h), 7))
    pad = (cur_w - new_w) // 2
    return tensor[:, :, pad:cur_w - pad]


def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """Calculate the number of patches after the preprocessing for images of
    any resolution.

    The count is the number of ``patch_size`` tiles covering the resolution
    chosen by ``select_best_resolution``, plus one for the base patch.

    Raises:
        TypeError: if ``grid_pinpoints`` is not a list.
    """
    from transformers.image_processing_utils import select_best_resolution
    if not isinstance(grid_pinpoints, list):
        raise TypeError('grid_pinpoints should be a list of tuples or lists')

    size = image_size if isinstance(image_size, (list, tuple)) else image_size.tolist()
    best_height, best_width = select_best_resolution(size, grid_pinpoints)

    tile_count = (best_height // patch_size) * (best_width // patch_size)
    return tile_count + 1


class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration):
    """Llava-Next model.

    Reuses the Llava vision tower, projector and language model. Each image
    is a stack of patches whose features are reassembled by
    ``pack_image_features`` before being scattered into the placeholder
    token embeddings.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
        # appended to packed image features (per grid row for multi-patch images)
        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=dtype, device=device))
        # replaces the processor set by the parent so ``image_sizes`` are kept
        self.input_processor = LLavaNextInputProcessor(config, dtype)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: int,
        vision_feature_select_strategy: str,
    ):
        """Encode every patch of every image and project it.

        Args:
            pixel_values: 4-D ``(total_patches, channels, height, width)``,
                or 5-D ``(batch, max_patches, channels, height, width)``. For
                5-D input only each image's first ``num_patches`` entries are
                kept.
            image_sizes: per-image sizes used to infer each image's patch
                count.
            vision_feature_layer: encoder layer whose output is used.
            vision_feature_select_strategy: ``'default'`` drops the first
                token of each patch; ``'full'`` keeps all tokens.

        Returns:
            A tuple with one projected feature tensor per image (split along
            dim 0).

        Raises:
            ValueError: on an unexpected ``pixel_values`` rank or an unknown
                ``vision_feature_select_strategy``.
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            ) for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is
            # (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of
            # (num_patches, num_channels, height, width)
            raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
                             'expect to be of 4 or 5 dimensions')

        selected_image_feature = self.vision_tower(pixel_values, vision_feature_layer=vision_feature_layer)[0]
        if vision_feature_select_strategy == 'default':
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy != 'full':
            # same check as LlavaForConditionalGeneration.get_image_features;
            # an unknown strategy must not pass through silently
            raise ValueError(f'Unexpected select feature strategy: {vision_feature_select_strategy}')
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
        """Pack per-image patch features into one token sequence.

        Multi-patch images are handled in these steps:
        - The first patch is the base feature.
        - The remaining patches are laid out on the grid from
          ``get_anyres_image_grid_shape`` and unpadded with ``unpad_image``.
        - ``image_newline`` (if given) is appended to every grid row.
        - The result is flattened and placed after the base feature.

        Single-patch images keep their only patch, followed by
        ``image_newline`` if given. All images are concatenated along dim 0.

        Raises:
            ValueError: if the base patch token count does not match the
                vision grid, or on an unknown ``vision_feature_select_strategy``.
        """
        vision_config = self.config.vision_config
        # tokens per side of one patch's feature grid
        height = width = (vision_config.image_size // vision_config.patch_size)

        new_image_features = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]

                if vision_feature_select_strategy == 'default':
                    expected_num_patches = height * width
                elif vision_feature_select_strategy == 'full':
                    expected_num_patches = height * width + 1
                else:
                    # previously fell through to an UnboundLocalError below
                    raise ValueError(f'Unexpected select feature strategy: {vision_feature_select_strategy}')
                if expected_num_patches != base_image_feature.shape[0]:
                    raise ValueError('The number of patches is '
                                     'not consistent with the image size.')

                (num_patch_height, num_patch_width) = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                # -> (dim, num_patch_height * height, num_patch_width * width)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
        image_features = torch.cat(new_image_features, dim=0)
        return image_features

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_sizes: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the language model, splicing packed image features in.

        When ``inputs_embeds`` is None and ``pixel_values`` is given, the
        packed features overwrite (``masked_scatter_``) the token embeddings
        where ``image_mask`` is True. Extra ``kwargs`` are ignored.
        """
        if inputs_embeds is None:
            image_features = None
            if pixel_values is not None:
                vision_feature_layer = self.config.vision_feature_layer
                select_strategy = self.config.vision_feature_select_strategy
                image_sizes = image_sizes.tolist()
                image_features = self.get_image_features(pixel_values,
                                                         image_sizes,
                                                         vision_feature_layer=vision_feature_layer,
                                                         vision_feature_select_strategy=select_strategy)
                image_features = self.pack_image_features(
                    image_features,
                    image_sizes,
                    vision_feature_select_strategy=select_strategy,
                    image_newline=self.image_newline,
                )
                image_features = image_features[None]
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                inputs_embeds.masked_scatter_(image_mask[..., None], image_features)

        return self.language_model.forward(input_ids=input_ids,
                                           inputs_embeds=inputs_embeds,
                                           past_key_values=past_key_values,
                                           position_ids=position_ids,
                                           attn_metadata=attn_metadata)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Build the ``forward`` keyword arguments from the step ``context``.

        Image entries of the whole batch are gathered:
        - ``pixel_values``: each image's data flattened over its first two
          dims, then concatenated.
        - ``image_sizes``: the per-image ``image_sizes`` meta, concatenated.
        - ``image_mask``: marks tokens equal to the first image's
          ``image_token_id``.

        Precomputed ``context.input_embeddings`` are written into
        ``inputs_embeds``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # vision inputs
        pixel_values = None
        image_sizes = None
        image_mask = None
        if context.input_multimodals is not None:
            img_mms = [input_mm.get('image', []) for input_mm in context.input_multimodals]
            # flatten batch
            img_mms = [data for im_data in img_mms for data in im_data]
            if len(img_mms) > 0:
                image_token_id = img_mms[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                pixel_values = torch.cat([data.data.flatten(0, 1) for data in img_mms])
                image_sizes = torch.cat([data.meta['image_sizes'] for data in img_mms])
            else:
                pixel_values = None
                image_sizes = None

        # get inputs from context
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing

        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )


class LLavaNextInputProcessor(BaseModelInputProcessor):
    """Llava input processor.

    Wraps each preprocessed image into a ``MultiModalData`` record that spans the
    image's placeholder tokens in the prompt.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        # pixel values are cast to this dtype in ``preprocess_input``
        self.dtype = dtype

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: token ids of the prompt (returned unchanged).
            input_multimodals: one dict per image, read for the keys
                ``pixel_values``, ``image_sizes``, ``offset``, ``image_token_id``
                and ``image_tokens`` (placeholder count, int or tensor).

        Returns:
            PreprocessInputResult: always returned, including when there are no
            images. Previously the empty case returned a bare tuple, which did not
            match the declared return type.
        """
        if not input_multimodals:
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values'].to(self.dtype)
            image_sizes = input_mm['image_sizes']
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            # the placeholder count may arrive as a tensor; MultiModalData wants ints
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_sizes=image_sizes, image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/minicpm3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_merged_colwise_linear, build_rowwise_linear
from lmdeploy.pytorch.nn.rotary_embedding import (ApplyRotaryEmb, LongRoPEScalingParameters, get_rope_parameters,
                                                  get_rope_theta)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


# TODO use MLA of pytorch engine
class MiniCPMAttention(nn.Module):
    """Minicpm3 attention.

    Queries come either from a single ``q_proj`` (when ``q_lora_rank`` is None) or
    from the low-rank path ``q_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``.
    Keys and values are decompressed from a latent produced by
    ``kv_a_proj_with_mqa``. The rotary part of that projection is a single key
    slice shared by all heads. Values are zero-padded to ``q_head_dim`` before
    attention and the padding is cut off again afterwards.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = None
        self.q_lora_rank = config.q_lora_rank
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.hidden_size // config.num_attention_heads
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        if self.q_lora_rank is None:
            self.q_proj = build_colwise_linear(
                self.hidden_size,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
            )
        else:
            self.q_a_proj = build_colwise_linear(
                self.hidden_size,
                config.q_lora_rank,
                bias=config.attention_bias,
                dtype=dtype,
                device=device,
                is_tp=False,
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank,
                                         1e-6,
                                         quant_config=quantization_config,
                                         dtype=dtype,
                                         device=device)
            self.q_b_proj = build_colwise_linear(
                config.q_lora_rank,
                self.num_heads * self.q_head_dim,
                bias=False,
                dtype=dtype,
                device=device,
                is_tp=True,
            )

        # latent kv plus the shared rotary key slice
        self.kv_a_proj_with_mqa = build_colwise_linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=False,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
                                      1e-6,
                                      quant_config=quantization_config,
                                      dtype=dtype,
                                      device=device)
        # NOTE(review): kv_b_proj is replicated (is_tp=False) while forward views its
        # output with the per-rank head count; tensor parallel > 1 would need this
        # projection sharded by heads as well — confirm before enabling TP.
        self.kv_b_proj = build_colwise_linear(
            config.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()
        self.softmax_scale = self.q_head_dim**(-0.5)
        self.attn_fwd = Attention(self.num_heads,
                                  config.kv_lora_rank + self.qk_rope_head_dim,
                                  scale=self.softmax_scale,
                                  num_kv_heads=config.num_key_value_heads)

        self.o_proj = build_rowwise_linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def _project_q(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to queries via whichever q path __init__ built."""
        if self.q_lora_rank is None:
            return self.q_proj(hidden_states)
        return self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward.

        Args:
            hidden_states: 3-d input unpacked as ``(bsz, q_len, hidden)``.
            rotary_pos_emb: ``(cos, sin)`` applied to the rope slices of q and k.
            past_key_value: key and value caches at indices 0 and 1.
            attn_metadata: forwarded unchanged to ``attn_fwd``.

        Returns:
            Output of ``o_proj`` over the concatenated per-head attention output.
        """
        world_size, _ = get_tp_world_rank()
        # heads owned by this rank; every per-head shape below uses this count
        num_heads = self.num_heads // world_size
        bsz, q_len, _ = hidden_states.size()

        q = self._project_q(hidden_states)
        q = q.view(bsz, q_len, num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # a single rope key slice, broadcast over all heads below
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, num_heads,
                                                                      self.qk_nope_head_dim + self.v_head_dim))

        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        q_pe, k_pe = self.apply_rotary_pos_emb(
            q_pe,
            k_pe,
            cos,
            sin,
            inplace=True,
        )

        query_states = k_pe.new_empty(bsz, q_len, num_heads, self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, q_len, num_heads, self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        # pad v so q, k and v share one head size inside the attention kernel
        if self.q_head_dim != self.v_head_dim:
            value_states = torch.nn.functional.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            inplace=False,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, :self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, num_heads * self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output


class MiniCPMMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, ``SiluAndMul``, down projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # gate and up projections merged into one column-parallel linear
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # activation combining the two halves produced by gate_up_proj
        self.act_fn = SiluAndMul(inplace=True)

        # row-parallel projection back to the hidden size
        self.down_proj = build_rowwise_linear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Apply gate/up, activation and down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class MiniCPMDecoderLayer(nn.Module):
    """Pre-norm decoder layer with depth-scaled residual branches.

    Both the attention and MLP outputs are multiplied by
    ``scale_depth / sqrt(num_hidden_layers)`` before being added to the residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = MiniCPMAttention(config, dtype=dtype, device=device)
        self.mlp = MiniCPMMLP(config, dtype=dtype, device=device)

        # normalisation before attention and before the MLP
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                quant_config=quant_cfg,
                                                dtype=dtype,
                                                device=device)

        # residual-branch scaling inputs
        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        attn_metadata: Any = None,
    ):
        """Run attention then MLP, each added back with a depth-scaled residual.

        Returns:
            ``(hidden_states, residual)``, where ``residual`` is the input that
            was fed to the MLP sub-block.
        """
        branch_scale = self.scale_depth / math.sqrt(self.num_hidden_layers)

        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out * branch_scale

        residual = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = residual + mlp_out * branch_scale

        return hidden_states, residual


class MiniCPM3Model(nn.Module):
    """model.

    Token embedding (scaled by ``scale_emb``), a stack of ``MiniCPMDecoderLayer``
    and a final RMSNorm. Rotary embeddings are built either without scaling or
    with LongRoPE scaling, depending on the rope parameters in the config.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.scale_emb = config.scale_emb

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        rope_dim = config.qk_rope_head_dim
        rope_max_pos_emb = config.max_position_embeddings
        rope_base = get_rope_theta(config)
        rope_scaling = get_rope_parameters(config)
        rope_type = None
        if rope_scaling is not None:
            # legacy configs use 'type'; newer transformers releases use 'rope_type'
            rope_type = rope_scaling.get('type', rope_scaling.get('rope_type'))

        if rope_type is None or rope_type == 'default':
            # no scaling requested: same embedding the unscaled path always built
            self.rotary_emb = build_rotary_embedding(
                rope_dim,
                rope_max_pos_emb,
                rope_base,
                emb_type=RopeType.LinearScaling,
            )
        else:
            assert rope_type in ['longrope', 'su'], f'Unsupported rope type for MiniCPM3: {rope_type}'
            # prefer the config attribute; MiniCPM3 configs keep the value inside
            # the rope-scaling dict instead, so fall back to that next
            ori_pos_emb = getattr(config, 'original_max_position_embeddings', None)
            if ori_pos_emb is None:
                ori_pos_emb = rope_scaling.get('original_max_position_embeddings', rope_max_pos_emb)

            longrope_params = LongRoPEScalingParameters(short_factor=rope_scaling['short_factor'],
                                                        long_factor=rope_scaling['long_factor'],
                                                        original_max_position_embeddings=ori_pos_emb)
            self.rotary_emb = build_rotary_embedding(
                rope_dim,
                rope_max_pos_emb,
                rope_base,
                longrope_params=longrope_params,
                emb_type=RopeType.LongRoPEScaling,
            )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward.

        ``scale_emb`` is applied only when the embeddings are computed here from
        ``input_ids``. A caller passing ``inputs_embeds`` must apply it itself.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.scale_emb

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)
        # decoding
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, _ = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                attn_metadata=attn_metadata,
            )

        # norm
        hidden_states = self.norm(hidden_states)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class MiniCPM3ForCausalLM(nn.Module, CudaGraphMixin):
    """Rewrote model of MiniCPM3ForCausalLM."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build LLamaModel
        self.model = MiniCPM3Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits.

        The final hidden states are divided by ``hidden_size / dim_model_base``
        before ``lm_head`` is applied.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

        logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
        return logits

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config asks for it."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input.

        When the context carries precomputed embeddings, the text tokens are
        embedded here and the precomputed vectors are written at
        ``context.input_embedding_indexing``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                # MiniCPM3Model.forward applies ``scale_emb`` only when it embeds the
                # tokens itself, so scale the text embeddings here. Otherwise they
                # would silently differ from the pure token-id path.
                # NOTE(review): injected embeddings are written as given (unscaled),
                # matching how caller-provided inputs_embeds are treated.
                inputs_embeds = self.get_input_embeddings()(input_ids) * self.model.scale_emb
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, routing gate/up shards into ``gate_up_proj``."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/minicpmv26.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class MiniCPMV26Attention(nn.Module):
    """Attention for the MiniCPM-V 2.6 language model.

    Uses a packed q/k/v projection (with bias), rotary embedding on q and k, the
    engine ``Attention`` op, and a row-parallel output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden // n_heads)
        kv_replicas = getattr(config, 'num_replicate_key_value_heads', 1)

        # single projection producing q, k and v
        self.qkv_proj = build_qkv_proj(hidden,
                                       num_q_heads=n_heads,
                                       num_kv_heads=n_kv_heads,
                                       head_size=head_dim,
                                       bias=True,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=kv_replicas)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=config.sliding_window,
        )

        self.o_proj = build_rowwise_linear(n_heads * head_dim,
                                           hidden,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Project, rotate, attend, and project back to the hidden size.

        ``past_key_value`` holds the k/v caches at indices 0 and 1. When it has
        more than two entries, indices 2 and 3 are passed as the k/v scale-zero
        tensors.
        """
        # flatten every leading dim so split_qkv sees (tokens, features)
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        k_scales_zeros = past_key_value[2] if has_scales else None
        v_scales_zeros = past_key_value[3] if has_scales else None

        attn_out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            inplace=True,
        )
        # restore the caller's leading dims, merging heads into the last axis
        attn_out = attn_out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(attn_out)


class MiniCPMV26MLP(nn.Module):
    """Feed-forward block: fused gate/up projection, ``SiluAndMul``, down projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # gate and up projections merged into one column-parallel linear
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # activation combining the two halves produced by gate_up_proj
        self.act_fn = SiluAndMul(inplace=True)

        # row-parallel projection back to the hidden size
        self.down_proj = build_rowwise_linear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Apply gate/up, activation and down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class MiniCPMV26DecoderLayer(nn.Module):
    """Pre-norm decoder layer that threads a residual through fused RMSNorm calls."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = MiniCPMV26Attention(config, dtype=dtype, device=device)
        self.mlp = MiniCPMV26MLP(config, dtype=dtype, device=device)

        # normalisation before attention and before the MLP
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                quant_config=quant_cfg,
                                                dtype=dtype,
                                                device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run attention and MLP sub-blocks.

        When ``residual`` is given, the two-argument RMSNorm call returns both the
        normalised input and the updated residual. On the first layer
        (``residual`` is None) the input itself becomes the residual.

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class MiniCPMV26Model(nn.Module):
    """Language backbone of MiniCPM-V 2.6: embedding, decoder stack, final norm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        decoder_layers = [
            MiniCPMV26DecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # rotary embedding built from the rope settings in the config
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run every layer, then normalise.

        ``past_key_values`` is indexed by layer position.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_id, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_id],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # final norm also folds in the last residual
        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class MiniCPMVForCausalLM(nn.Module, CudaGraphMixin):
    """Language-model half of MiniCPM-V 2.6 for the PyTorch engine.

    ``forward`` returns hidden states; ``get_logits`` applies ``lm_head``. Vision
    tower (``vpm``) and ``resampler`` weights are skipped while loading.
    """

    packed_modules_mapping = {
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = MiniCPMV26Model(config, dtype=dtype, device=device)
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states (not logits)."""
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states to vocabulary logits."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Share the embedding weight with ``lm_head`` when embeddings are tied."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build ``forward`` kwargs from the step context.

        Precomputed embeddings in the context are written into ``inputs_embeds``
        at ``context.input_embedding_indexing``. If ``inputs_embeds`` is None, it
        is first built from the token ids.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        extra_embeds = context.input_embeddings
        extra_index = context.input_embedding_indexing
        if extra_embeds is not None and len(extra_embeds) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, extra_index, :] = extra_embeds.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load LLM weights (parameter names prefixed with ``llm``).

        Separate q/k/v and gate/up checkpoint tensors are routed into the packed
        ``qkv_proj`` / ``gate_up_proj`` parameters with their shard ids.
        """
        # modify from vllm
        stacked_params_mapping = [
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]
        # checkpoint entries that never map onto a parameter of this module
        ignored_keys = ('vpm', 'resampler', 'rotary_emb.inv_freq', 'rotary_emb.cos_cached',
                        'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters(prefix='llm'))
        for name, loaded_weight in weights:
            if any(key in name for key in ignored_keys):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    load_weight(params_dict[name], loaded_weight, shard_id=shard_id)
                    break
            else:
                load_weight(params_dict[name], loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/mistral.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class MistralAttention(nn.Module):
    """Rewrite module of MistralAttention.

    Packed q/k/v projection, rotary position embedding, lmdeploy ``Attention``
    configured with the config's sliding window, and an output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        # ``head_dim`` may be missing from the config object entirely or be set
        # to None; in either case derive it from the hidden size.
        head_dim = getattr(config, 'head_dim', None)
        if head_dim is None:
            head_dim = hidden_size // num_heads

        # packed qkv
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=False,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
        )

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
            sliding_window=config.sliding_window,
        )

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=False,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attention forward.

        Args:
            hidden_states: Input activations. The attention output is reshaped
                back to their leading dimensions before ``o_proj``.
            rotary_pos_emb: ``(cos, sin)`` tables for the rotary embedding.
            past_key_value: Cache entries; ``[0]`` / ``[1]`` are the key and
                value caches. When it does not have exactly two entries,
                ``[2]`` / ``[3]`` are passed as k/v ``scales_zeros``
                (presumably for a quantized KV cache).
            attn_metadata: Metadata forwarded to the attention kernel.

        Returns:
            torch.Tensor: Output of ``o_proj``.
        """
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply rotary embedding (in place on q/k)
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class MistralMLP(nn.Module):
    """Feed-forward block.

    A packed gate/up projection, a ``SiluAndMul`` activation over that packed
    output, and a down projection back to the hidden size.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # fused gate + up projection: two output shards of intermediate width
        self.gate_up_proj = build_gateup_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        # activation consuming the packed gate/up output
        self.act_fn = SiluAndMul(inplace=True)

        # projection from the intermediate width back to the hidden size
        self.down_proj = build_down_linear(
            inter,
            hidden,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
        )

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class MistralDecoderLayer(nn.Module):
    """One Mistral decoder layer: norm -> self attention -> norm -> MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        # sub-modules, created in the same order as before
        self.self_attn = MistralAttention(config, dtype=dtype, device=device)
        self.mlp = MistralMLP(config, dtype=dtype, device=device)

        # both norms share the same construction arguments
        norm_kwargs = dict(quant_config=quant_cfg, dtype=dtype, device=device)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run the layer.

        Returns:
            tuple: ``(hidden_states, residual)``. ``residual`` is the value
            returned by the post-attention norm and is fed to the next layer.
        """
        # When a residual is carried in, RMSNorm(x, residual) returns the pair
        # (normed, new_residual); on the first layer the input is the residual.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class MistralModel(nn.Module):
    """Mistral decoder stack: token embedding, decoder layers and final norm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # token embedding table; padding_idx taken from config.pad_token_id
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        # one decoder layer per configured hidden layer
        decoder_layers = [
            MistralDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        # norm applied after the last layer
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # rotary embedding built from the model config
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run every decoder layer and return the final normalized states.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``past_key_values[i]`` is handed to layer ``i``.
        """
        hidden_states = inputs_embeds
        if hidden_states is None:
            hidden_states = self.embed_tokens(input_ids)

        # rotary tables are computed once and shared by all layers
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_id, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_id],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class MistralForCausalLM(nn.Module, CudaGraphMixin):
    """Mistral causal language model: ``MistralModel`` plus an ``lm_head``."""

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = MistralModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return its hidden states.

        Logits are produced separately by ``get_logits``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the ``forward`` kwargs from the step context.

        If the context carries input embeddings, they are scattered into
        ``inputs_embeds`` at ``context.input_embedding_indexing``. In that case
        ``inputs_embeds`` is first computed from ``input_ids`` if it was not
        given.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into the parameters.

        q/k/v and gate/up checkpoint tensors are routed into the packed
        parameters. ``lm_head.weight`` is skipped for tied embeddings; see
        ``update_weights``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config requests it.

        ``load_weights`` never loads ``lm_head.weight`` for tied checkpoints.
        Without this hook the head would keep the values it was constructed
        with. The head shares the embedding parameter instead.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight


================================================
FILE: lmdeploy/pytorch/models/mixtral.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class MixtralAttention(nn.Module):
    """Mixtral attention.

    Packed q/k/v projection, rotary position embedding, lmdeploy ``Attention``
    with an optional sliding window, and an output projection.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)

        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        # Honour an explicit ``head_dim`` when the config provides one, so the
        # projection shapes match checkpoints where head_dim differs from
        # hidden_size // num_heads. Otherwise fall back to that quotient.
        head_dim = getattr(config, 'head_dim', None)
        if head_dim is None:
            head_dim = hidden_size // num_heads

        # qkv
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=False,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention; a falsy config.sliding_window is mapped to -1
        self.window_size = config.sliding_window or -1
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
            sliding_window=self.window_size,
        )

        # o_proj
        self.o_proj = build_rowwise_linear(num_heads * head_dim,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quantization_config,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attention forward.

        Args:
            hidden_states: Input activations. The attention output is reshaped
                back to their leading dimensions before ``o_proj``.
            rotary_pos_emb: ``(cos, sin)`` tables for the rotary embedding.
            past_key_value: Cache entries; ``[0]`` / ``[1]`` are the key and
                value caches. When it does not have exactly two entries,
                ``[2]`` / ``[3]`` are passed as k/v ``scales_zeros``
                (presumably for a quantized KV cache).
            attn_metadata: Metadata forwarded to the attention kernel.

        Returns:
            torch.Tensor: Output of ``o_proj``.
        """
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        attn_output = self.o_proj(attn_output)

        return attn_output


class MixtralSparseMoeBlock(nn.Module):
    """Mixtral sparse mixture-of-experts feed-forward block.

    An un-sharded (``is_tp=False``), unquantized router scores each token.
    ``SoftmaxTopK`` selects ``top_k`` experts, and a fused MoE module computes
    their combined output.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # router: one logit per expert for every token
        self.gate = build_rowwise_linear(self.hidden_dim,
                                         self.num_experts,
                                         bias=False,
                                         dtype=dtype,
                                         device=device,
                                         is_tp=False,
                                         quant_config=None)

        router_groups = getattr(config, 'router_n_groups', -1)
        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=router_groups)

        self.experts = build_fused_moe(self.hidden_dim,
                                       self.ffn_dim,
                                       self.num_experts,
                                       top_k=self.top_k,
                                       renormalize=True,
                                       dtype=dtype,
                                       device=device,
                                       all_reduce=True,
                                       quant_config=quant_cfg)

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens through the experts.

        Expects a 3-D ``(batch, seq, hidden)`` input (enforced by the shape
        unpacking).

        Returns:
            tuple: ``(output reshaped to (batch, seq, -1), router_logits)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(tokens)
        topk_weights, topk_ids = self.softmax_topk(router_logits)
        expert_out = self.experts(tokens, topk_weights, topk_ids)

        return expert_out.reshape(batch_size, sequence_length, -1), router_logits


class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: norm -> self attention -> norm -> sparse MoE."""

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        # sub-modules, created in the same order as before
        self.self_attn = MixtralAttention(config, dtype=dtype, device=device)
        self.block_sparse_moe = MixtralSparseMoeBlock(config, dtype=dtype, device=device)

        self.input_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )
        # NOTE(review): unlike input_layernorm this norm gets no quant_config;
        # possibly because its output feeds the unquantized router gate - confirm.
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run the layer.

        Returns:
            tuple: ``(hidden_states, residual)``. The MoE router logits are
            discarded.
        """
        # RMSNorm(x, residual) returns (normed, new_residual); on the first
        # layer the raw input itself becomes the residual.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        moe_out, _router_logits = self.block_sparse_moe(normed)
        return moe_out, residual


class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers and final norm."""

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # token embedding table; padding_idx taken from config.pad_token_id
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            MixtralDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        # final norm (explicitly unquantized)
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run every decoder layer and return the final normalized states.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``past_key_values[i]`` is handed to layer ``i``.
        """
        hidden_states = inputs_embeds
        if hidden_states is None:
            hidden_states = self.embed_tokens(input_ids)

        residual = None
        # rotary tables are computed once and shared by all layers
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        for layer_id, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_id],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class MixtralForCausalLM(nn.Module, CudaGraphMixin):
    """Mixture model for causalLM: ``MixtralModel`` plus an ``lm_head``."""

    def __init__(self,
                 config: Any,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = MixtralModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return its hidden states.

        Logits are produced separately by ``get_logits``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the ``forward`` kwargs from the step context."""
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into the parameters.

        Per-expert ``w1`` / ``w3`` / ``w2`` tensors are routed into the fused
        expert ``gate_up`` / ``down`` parameters. q/k/v tensors are routed
        into ``qkv_proj``. ``lm_head.weight`` is skipped for tied embeddings;
        see ``update_weights``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        # (param_name, checkpoint_name, expert_id, shard_id); the leading and
        # trailing dots keep e.g. expert 1 from matching expert 11.
        num_experts = self.config.num_local_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
                break
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def update_weights(self):
        """Tie ``lm_head`` to the token embedding when the config requests it.

        ``load_weights`` never loads ``lm_head.weight`` for tied checkpoints.
        Without this hook the head would keep the values it was constructed
        with. The head shares the embedding parameter instead.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight


================================================
FILE: lmdeploy/pytorch/models/module_map.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

# Package hosting the rewritten model implementations; every registered value
# below is a dotted "<module>.<ClassName>" path under this package.
LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models'

# Generic name -> rewrite-class registry (MODULE_MAP), plus per-device
# registries (ascend / maca / camb) collected in DEVICE_SPECIAL_MODULE_MAP.
# NOTE(review): keys appear to be model architecture class names (e.g. as listed
# in HF configs) - confirm against callers. patch._get_rewrite_qualname falls
# back to regex-searching the keys in insertion order, so entry order may matter.
MODULE_MAP = dict()
ASCEND_MODULE_MAP = dict()
MACA_MODULE_MAP = dict()
CAMB_MODULE_MAP = dict()

DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP, maca=MACA_MODULE_MAP, camb=CAMB_MODULE_MAP)

# llama
MODULE_MAP.update({
    'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM',
})

# llama4
MODULE_MAP.update({
    'Llama4ForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama4.Llama4ForConditionalGeneration',
})

# baichuan
MODULE_MAP.update({
    'BaichuanForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanForCausalLM',
})

# chatglm
MODULE_MAP.update({
    'ChatGLMForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.ChatGLMForConditionalGeneration',  # noqa: E501
})

# glm4-0414
MODULE_MAP.update({
    'Glm4ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4.Glm4ForCausalLM',
})

# glm4.1-v
MODULE_MAP.update({
    'Glm4vForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4_1v.Glm4vForConditionalGeneration',
})

# glm4.5
MODULE_MAP.update({
    'Glm4MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4_moe.Glm4MoeForCausalLM',
})

# glm4.7 (served by the DeepseekV2 implementation)

MODULE_MAP.update({'Glm4MoeLiteForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})

# glm4.7 mtp
MODULE_MAP.update({
    'Glm4MoeMTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4moe_mtp.Glm4MoeMTPModel',
})

# glm5 (served by the DeepseekV32 implementation)
MODULE_MAP.update({'GlmMoeDsaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'})

# internlm
MODULE_MAP.update({
    'InternLMForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.InternLMForCausalLM',
})

# internlm2
MODULE_MAP.update({
    'InternLM2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.InternLM2ForCausalLM',
})

# mistral
MODULE_MAP.update({
    'MistralForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralForCausalLM',
})

# mixtral
MODULE_MAP.update({
    'MixtralForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM',
})

# gemma
MODULE_MAP.update({
    'GemmaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',
})

# gemma2 (shares the gemma implementation)
MODULE_MAP.update({
    'Gemma2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',
})

# gemma3 text (shares the gemma implementation)
MODULE_MAP.update({
    'Gemma3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',
})

# gemma3 VL
MODULE_MAP.update({
    'Gemma3ForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma3_vl.Gemma3ForConditionalGeneration',
})

# deepseek
MODULE_MAP.update({
    'DeepseekForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.DeepseekForCausalLM',
})

# deepseek-v2
MODULE_MAP.update({'DeepseekV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})

# deepseek-v3 (served by the DeepseekV2 implementation)
MODULE_MAP.update({'DeepseekV3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})

# deepseek-v32
MODULE_MAP.update({'DeepseekV32ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'})

# deepseek-vl2
MODULE_MAP.update({'DeepseekVLV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_vl2.DeepseekVLV2ForCausalLM'})

# llava
MODULE_MAP.update({
    'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration',  # noqa: E501
    'LlavaNextForConditionalGeneration':  # noqa: E501
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration'  # noqa: E501
})

# qwen
MODULE_MAP.update({
    'QWenLMHeadModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.QWenLMHeadModel',
})

# qwen1.5 (architecture name Qwen2ForCausalLM)
MODULE_MAP.update({
    'Qwen2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.Qwen2ForCausalLM',
})

# qwen2 moe
MODULE_MAP.update({
    'Qwen2MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.Qwen2MoeForCausalLM',
})

# qwen3
MODULE_MAP.update({
    'Qwen3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3.Qwen3ForCausalLM',
})

# qwen3 moe
MODULE_MAP.update({
    'Qwen3MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_moe.Qwen3MoeForCausalLM',
})

# qwen2_vl
MODULE_MAP.update({
    'Qwen2VLForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_vl.Qwen2VLForConditionalGeneration',
})

# qwen2_5_vl
MODULE_MAP.update({
    'Qwen2_5_VLForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration',
})

# qwen3_vl
MODULE_MAP.update({
    'Qwen3VLForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl.Qwen3VLForConditionalGeneration',
})

# qwen3_vl_moe
MODULE_MAP.update({
    'Qwen3VLMoeForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration',
})

# qwen3.5
MODULE_MAP.update({
    'Qwen3_5ForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5.Qwen3_5ForConditionalGeneration',
})

# qwen3.5 moe
MODULE_MAP.update({
    'Qwen3_5MoeForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',
})

# starcoder2
MODULE_MAP.update({
    'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM',
})

# phi-3
MODULE_MAP.update({
    'Phi3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3ForCausalLM',
})

# cogvlm
MODULE_MAP.update({
    'CogVLMForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.CogVLMForCausalLM',
})

# internvl
MODULE_MAP.update({'InternVLChatModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.InternVLChatModel'})

# internvl3-hf
MODULE_MAP.update({
    'InternVLForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl3_hf.InternVLForConditionalGeneration'
})

# interns1-hf (served by the internvl3-hf implementation)
MODULE_MAP.update({
    'InternS1ForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl3_hf.InternVLForConditionalGeneration'
})

# interns1-pro (InternS1_1 reuses the same implementation)
MODULE_MAP.update({
    'InternS1ProForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.interns1_pro.InternS1ProForConditionalGeneration',
})
MODULE_MAP.update({
    'InternS1_1_ForConditionalGeneration':
    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.interns1_pro.InternS1ProForConditionalGeneration',
})

# mono-internvl
MODULE_MAP.update({
    'InternLM2VEForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_ve.InternLM2VEForCausalLM',
})

# phi3 vision
MODULE_MAP.update({
    'Phi3VForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM',
})

# phi-3.5-moe
MODULE_MAP.update({
    'PhiMoEForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_moe.PhiMoEForCausalLM',
})

# minicpm3
MODULE_MAP.update({
    'MiniCPM3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpm3.MiniCPM3ForCausalLM',
})

# minicpmv2_6
MODULE_MAP.update({
    'MiniCPMV': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpmv26.MiniCPMVForCausalLM',
})

# internlm3
MODULE_MAP.update({
    'InternLM3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm3.InternLM3ForCausalLM',
})

# internlm2 reward model
MODULE_MAP.update(
    {'InternLM2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_reward.InternLM2ForRewardModel'})

# qwen2 reward model
MODULE_MAP.update({'Qwen2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_reward.Qwen2ForRewardModel'})

# gpt-oss
MODULE_MAP.update({
    'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM',
})

# qwen3 next model
MODULE_MAP.update({
    'Qwen3NextForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_next.Qwen3NextForCausalLM',
})

# SDAR
MODULE_MAP.update({
    'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM',
    'SDARMoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar_moe.SDARMoeForCausalLM',
})

# Empty here; imported by patch.py alongside MODULE_MAP.
# NOTE(review): presumably populated elsewhere for user-registered models - confirm.
CUSTOM_MODULE_MAP = dict()

# spec (speculative-decoding draft) models
# eagle llama
MODULE_MAP.update({'EagleLlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle.EagleLlamaForCausalLM'})

# eagle3 llama
MODULE_MAP.update({'Eagle3LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle3.Eagle3LlamaForCausalLM'})

# deepseek mtp
MODULE_MAP.update({'DeepseekMTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_mtp.DeepseekMTPModel'})


================================================
FILE: lmdeploy/pytorch/models/patch.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import contextlib
import importlib
import inspect
import os.path as osp
import re
import sys
from typing import Any, Dict, Optional

import torch
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager
from lmdeploy.utils import get_logger

from ..config import ModelConfig
from ..devices import get_device_manager
from .module_map import CUSTOM_MODULE_MAP, DEVICE_SPECIAL_MODULE_MAP, MODULE_MAP

# Shared 'lmdeploy' logger; used below to report which rewrite a module maps to.
logger = get_logger('lmdeploy')


def _get_rewrite_qualname(origin_qualname: str, module_map: Dict[str, str]) -> str:
    """Get rewrite module from origin module name.

    Args:
        origin_qualname (str): The origin qualname of the module.

    Returns:
        str: The rewrite qualname.
    """
    if origin_qualname in module_map:
        return module_map[origin_qualname]
    for key, value in module_map.items():
        if re.search(key, origin_qualname):
            return value
    return None


def _class_from_qualname(qualname: str) -> Any:
    """Import class with qualname.

    Args:
        qualname (str): Qualname of the class

    Returns:
        Any: class or builder of the class
    """
    last_dot = qualname.rfind('.')
    modname = qualname[:last_dot]
    clsname = qualname[last_dot + 1:]

    # get class at runtime
    mod = importlib.import_module(modname)
    assert mod is not None, f'failed to import module: {modname}'
    cls_type = getattr(mod, clsname)
    return cls_type


def _find_rewrite_module_qualname(model, module_map: Dict[str, str]):
    """Find the rewrite qualname registered for ``model``.

    Candidate names are tried in priority order and the first hit wins:
    ``<full.module.path>.<Class>``, then ``<Class>``, then
    ``<last_module_component>.<Class>``.

    Returns:
        The rewrite qualname, or ``None`` if no candidate matches.
    """
    full_module = inspect.getmodule(model).__name__
    cls_name = model.__class__.__name__
    full_qualname = f'{full_module}.{cls_name}'
    short_module = full_module.rpartition('.')[2]

    candidates = (full_qualname, cls_name, f'{short_module}.{cls_name}')
    found = None
    for candidate in candidates:
        found = _get_rewrite_qualname(candidate, module_map)
        if found is not None:
            break

    if found is not None:
        logger.debug('Find rewrite of module\n'
                     f'{full_qualname} <=> {found}')
    return found


def get_rewrite_cls(model: torch.nn.Module, module_map: Dict[str, str] = None):
    """Return the rewrite class registered for ``model``, or ``None``.

    When ``module_map`` is not given, the effective map from ``_get_module_map``
    is used.
    """
    lookup = _get_module_map() if module_map is None else module_map
    qualname = _find_rewrite_module_qualname(model, module_map=lookup)
    return None if qualname is None else _class_from_qualname(qualname)


def _get_module_map():
    """Build the effective module map.

    Starts from a copy of ``MODULE_MAP``, overlays the device-specific entries
    for non-cuda devices, then overlays the user ``CUSTOM_MODULE_MAP``.
    """
    merged = MODULE_MAP.copy()
    device_type = get_device_manager().current_context().device_type
    device_overrides = dict() if device_type == 'cuda' else DEVICE_SPECIAL_MODULE_MAP.get(device_type, dict())
    for overrides in (device_overrides, CUSTOM_MODULE_MAP):
        merged.update(overrides)
    return merged


def update_custom_module_map(module_map_path: str):
    """Load a custom module map from a python file into ``CUSTOM_MODULE_MAP``.

    The file is loaded as module ``map_mod`` and also registered in
    ``sys.modules`` as ``f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod'``.
    Its ``MODULE_MAP`` and/or ``CUSTOM_MODULE_MAP`` dicts are merged; the
    latter wins on duplicate keys. Values without a ``'.'`` are treated as names
    defined in the custom file itself and are prefixed with the ``_custom_mod``
    module path so they can be imported later.

    Args:
        module_map_path (str): Path to the python file defining the map(s).

    Raises:
        AssertionError: If the path does not exist or a map is not a dict.
        RuntimeError: If the file defines neither ``MODULE_MAP`` nor
            ``CUSTOM_MODULE_MAP``.
    """
    from importlib.machinery import SourceFileLoader

    from lmdeploy.pytorch.models.module_map import LMDEPLOY_PYTORCH_MODEL_PATH
    assert osp.exists(module_map_path), (f'custom module map path: "{module_map_path}" not exists.')

    module_map_path = osp.abspath(module_map_path)
    # Let the custom file import its sibling modules; add the folder only once
    # so repeated calls do not keep growing sys.path.
    folder = osp.dirname(module_map_path)
    if folder not in sys.path:
        sys.path.append(folder)
    # NOTE: load_module() is deprecated in favour of exec_module().
    custom_mod = SourceFileLoader('map_mod', module_map_path).load_module()
    custom_mod_name = f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod'
    sys.modules[custom_mod_name] = custom_mod

    new_mod_map = dict()
    has_map = False
    for attr_name in ('MODULE_MAP', 'CUSTOM_MODULE_MAP'):
        if not hasattr(custom_mod, attr_name):
            continue
        has_map = True
        mod_map = getattr(custom_mod, attr_name)
        assert isinstance(mod_map, dict)
        new_mod_map.update(mod_map)

    if not has_map:
        raise RuntimeError(f'Found no map in "{module_map_path}".')

    # bare names refer to classes defined in the custom file itself
    new_mod_map = {k: (v if '.' in v else f'{custom_mod_name}.{v}') for k, v in new_mod_map.items()}

    CUSTOM_MODULE_MAP.update(new_mod_map)


def _get_model_class(config, module_map):
    """Get model class."""
    auto_map = getattr(config, 'auto_map', dict())
    if 'AutoModelForCausalLM' in auto_map:
        mapname = auto_map['AutoModelForCausalLM']
        if '.' in mapname:
            mapname = mapname.split('.')[-1]
        if mapname in module_map:
            qualname = module_map[mapname]
            module_cls = _class_from_qualname(qualname)
            return module_cls
        raise RuntimeError(f'Can not found rewrite for auto_map: {mapname}')

    architectures = getattr(config, 'architectures', [])

    if architectures is None:
        # only for deepseek-vl2, which has different config formats
        # https://huggingface.co/deepseek-ai/deepseek-vl2/blob/main/config.json
        assert getattr(config.language_config, 'architectures', []) is not None
        qualname = module_map['DeepseekVLV2ForCausalLM']
        module_cls = _class_from_qualname(qualname)
        return module_cls

    for arch in architectures:
        if arch in module_map:
            qualname = module_map[arch]
            module_cls = _class_from_qualname(qualname)
            return module_cls

    raise RuntimeError(f'Can not found rewrite for architectures: {architectures}')


def build_model_from_hf_config(model_config: PretrainedConfig,
                               dtype: torch.dtype = None,
                               device: torch.device = None,
                               ctx_mgr: StepContextManager = None,
                               build_model_ctx: 'BuildModelContext' = None):
    """Instantiate the lmdeploy rewrite of a HuggingFace model in eval mode.

    A ``StepContextManager`` is created from ``build_model_ctx`` when
    ``ctx_mgr`` is not given; ``device`` defaults to cuda. When the rewrite
    class exposes ``update_quant_config``, it may adjust
    ``build_model_ctx.quant_config`` before construction.
    """
    ctx_mgr = StepContextManager(build_model_ctx) if ctx_mgr is None else ctx_mgr
    model_cls = _get_model_class(model_config, _get_module_map())
    device = torch.device('cuda') if device is None else device

    # let the model class adjust the quant config before it is built
    wants_quant_update = build_model_ctx is not None and hasattr(model_cls, 'update_quant_config')
    if wants_quant_update:
        build_model_ctx.quant_config = model_cls.update_quant_config(build_model_ctx.quant_config)

    with build_model_context(build_model_ctx):
        built = model_cls(model_config, ctx_mgr, dtype=dtype, device=device)
    return built.eval()


@torch.inference_mode()
def build_patched_model(config: ModelConfig, device: torch.device = None, build_model_ctx: 'BuildModelContext' = None):
    """Build the patched model described by an lmdeploy ``ModelConfig``.

    Uses ``config.hf_config`` and ``config.dtype`` and runs under inference mode.
    """
    return build_model_from_hf_config(config.hf_config,
                                      dtype=config.dtype,
                                      device=device,
                                      build_model_ctx=build_model_ctx)


@torch.inference_mode()
def add_adapters(model: torch.nn.Module,
                 adapters: Dict[str, str],
                 dtype: torch.dtype = torch.float16,
                 device: torch.device = None):
    """Attach LoRA adapters to ``model`` and load their weights.

    Adapters are sorted by name and given ids starting at 1. Id 0 is reserved
    for "no adapter" (a rank-0 ``LoraConfig`` with no targets). For each target
    module named in any adapter config, a ``LoRA`` module is stored in
    ``mod.lora_adapters`` on every matching layer. Its A/B buffers are sized by
    the summed rank over all adapters. Weights are then read from
    ``<path>/adapter_model.bin``, falling back to
    ``<path>/adapter_model.safetensors``.

    Args:
        model (torch.nn.Module): The patched model, or a wrapper exposing
            ``get_model()`` (e.g. a graph runner).
        adapters (Dict[str, str]): Adapter name -> adapter directory.
        dtype (torch.dtype): Dtype of the allocated LoRA buffers.
        device (torch.device): Device of the LoRA buffers; defaults to cuda.

    Returns:
        ``None`` when ``adapters`` is empty, otherwise an empty dict (no
        per-target information is collected).
    """
    from peft import PeftConfig
    from peft.tuners.lora import LoraConfig
    from transformers.modeling_utils import load_state_dict

    from lmdeploy.pytorch.adapter.adapter import find_all_target, get_ranks_and_scalings, load_lora_weights
    from lmdeploy.pytorch.nn.linear import LoRA
    if len(adapters) == 0:
        return

    if device is None:
        device = torch.device('cuda')

    # model could be graph runner
    if hasattr(model, 'get_model'):
        model = model.get_model()
    ctx_mgr = model.ctx_mgr

    adapter_names = sorted(adapters.keys())
    adapter_cfgs = [PeftConfig.from_pretrained(adapters[name]) for name in adapter_names]

    # insert one for no adapter (id 0)
    adapter_cfgs = [LoraConfig(r=0, target_modules=[])] + adapter_cfgs
    adapter_names = [None] + adapter_names
    adapter_id_map = {name: idx for idx, name in enumerate(adapter_names)}

    # union of target layer names over all adapters, in a deterministic order
    target_names = sorted(set().union(*(cfg.target_modules for cfg in adapter_cfgs)))

    for target_name in target_names:
        # get ranks and scalings
        ranks, scalings = get_ranks_and_scalings(target_name, adapter_cfgs, device=device)
        # split in case target_name has '.' like 'attention.wo'
        # which cannot be used as name of a module
        # and it's not aligned with key in model.packed_modules_mapping
        target_name = target_name.split('.')[-1]
        found_mods, pack_idx = find_all_target(model, target_name)
        sum_rank = ranks.sum().item()

        for _, mod in found_mods:
            assert hasattr(mod, 'lora_adapters')
            in_features = mod.in_features
            colwise = mod.colwise
            if pack_idx is None:
                base_slice = slice(0, mod.out_features)
                out_features = mod.out_features
                lora_b_spliter = getattr(mod, 'weight_spliter_lora_b', None)
            else:
                # packed module: this target owns one slice of the packed output
                prev_feats = sum(mod.all_out_features[:pack_idx])
                out_features = mod.all_out_features[pack_idx]
                base_slice = slice(prev_feats, prev_feats + out_features)
                lora_b_spliter = None
            lora_a = torch.empty((sum_rank, in_features), dtype=dtype, device=device)
            lora_b = torch.empty((sum_rank, out_features), dtype=dtype, device=device)

            mod.lora_adapters[target_name] = LoRA(
                in_features,
                out_features,
                ranks=ranks,
                scalings=scalings,
                lora_a=lora_a,
                lora_b=lora_b,
                base_slice=base_slice,
                ctx_mgr=ctx_mgr,
                colwise=colwise,
                is_tp=mod.is_tp,
                lora_b_spliter=lora_b_spliter,
            )

    # fill adapter data
    for name, path in adapters.items():
        adapter_id = adapter_id_map[name]
        checkpoint_path = osp.join(path, 'adapter_model.bin')
        if not osp.exists(checkpoint_path):
            checkpoint_path = osp.join(path, 'adapter_model.safetensors')
        state_dict = load_state_dict(checkpoint_path, map_location=device)

        if hasattr(model, 'load_lora_weights'):
            model.load_lora_weights(state_dict.items(), adapter_id=adapter_id)
        else:
            load_lora_weights(model, state_dict.items(), adapter_id=adapter_id)

    # NOTE: no per-target info is collected; kept as a dict for API compatibility.
    return dict()


# Module-level default build context. build_model_context() swaps it temporarily
# and get_build_model_context() returns whichever one is currently installed.
BUILD_MODEL_CTX = BuildModelContext()


@contextlib.contextmanager
def build_model_context(ctx: 'BuildModelContext'):
    """Context manager for building model.

    Installs ``ctx`` as the module-level build context for the duration of the
    ``with`` block; ``None`` keeps the current context. The previous context is
    restored on exit, including when the body raises.
    """
    global BUILD_MODEL_CTX
    old_ctx = BUILD_MODEL_CTX
    BUILD_MODEL_CTX = old_ctx if ctx is None else ctx
    try:
        yield
    finally:
        # restore even if model construction failed inside the block
        BUILD_MODEL_CTX = old_ctx


def get_build_model_context() -> BuildModelContext:
    """Return the build context that is currently installed at module level."""
    # reading a module global needs no ``global`` declaration
    return BUILD_MODEL_CTX


def add_prefix(name: str, prefix: str) -> str:
    """Return ``'<prefix>.<name>'``, or ``name`` unchanged when ``prefix`` is empty/None."""
    if prefix:
        return f'{prefix}.{name}'
    return name


================================================
FILE: lmdeploy/pytorch/models/phi3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1


class Phi3Attention(nn.Module):
    """Phi3 self-attention: packed qkv projection, rotary embedding, attention, output projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        # use the explicit head_dim when the config provides one
        head_size = getattr(config, 'head_dim', hidden // n_heads)

        # fused q/k/v projection
        self.qkv_proj = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_size,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_heads,
            head_size,
            num_kv_heads=n_kv_heads,
            v_head_size=head_size,
            sliding_window=getattr(config, 'sliding_window', None),
        )

        self.o_proj = build_o_proj(n_heads * head_size,
                                   hidden,
                                   bias=False,
                                   quant_config=quant_cfg,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attention forward.

        Entries 0 and 1 of ``past_key_value`` are passed to the attention op as
        the key/value caches. When its length is not 2, entries 2 and 3 are
        passed as ``k_scales_zeros`` / ``v_scales_zeros``.
        """
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class Phi3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiluAndMul, down projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        self.gate_up_proj = build_gateup_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(inter,
                                           hidden,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """Compute ``down_proj(act_fn(gate_up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Phi3DecoderLayer(nn.Module):
    """One Phi3 layer: pre-norm self-attention followed by pre-norm MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = Phi3Attention(config, dtype=dtype, device=device)
        self.mlp = Phi3MLP(config, dtype=dtype, device=device)

        def _make_norm():
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        self.input_layernorm = _make_norm()
        self.post_attention_layernorm = _make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Return ``(hidden_states, residual)`` after attention and MLP.

        The first layer gets ``residual=None`` and starts the residual stream from
        its input. Later layers let the norm combine the input with ``residual``.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Phi3Model(nn.Module):
    """Phi3 decoder stack: token embedding, decoder layers, final norm and rotary embedding."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        self.layers = nn.ModuleList(
            Phi3DecoderLayer(config, idx, dtype=dtype, device=device) for idx in range(config.num_hidden_layers))

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run all decoder layers and return the final normed hidden states.

        ``input_ids`` are embedded only when ``inputs_embeds`` is not given.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Phi3 causal LM: ``Phi3Model`` plus an ``lm_head`` built by the deploy mixin."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = Phi3Model(config, dtype=dtype, device=device)
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run ``self.model`` and return its hidden states (``lm_head`` is not applied here)."""
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of the inner model."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Collect forward kwargs from ``context``.

        When the context carries vision embeddings, they are written into
        ``inputs_embeds`` at ``context.input_embedding_indexing`` along dim 1.
        ``inputs_embeds`` is created from ``input_ids`` first if needed.
        """
        ids = context.input_ids
        positions = context.position_ids
        metadata = context.attn_metadata
        vision_embeds = context.input_embeddings
        vision_index = context.input_embedding_indexing

        has_vision = vision_embeds is not None and len(vision_embeds) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(ids)
            inputs_embeds[:, vision_index, :] = vision_embeds.to(inputs_embeds)

        return dict(
            input_ids=ids,
            position_ids=positions,
            past_key_values=past_key_values,
            attn_metadata=metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, splitting packed qkv / gate_up tensors into shards."""
        # modify from vllm
        rotary_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        def _skip(weight_name: str) -> bool:
            if any(marker in weight_name for marker in rotary_markers):
                return True
            if self.config.tie_word_embeddings and 'lm_head.weight' in weight_name:
                return True
            return 'vision_embed_tokens' in weight_name

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if _skip(name):
                continue
            param = params_dict[name]
            if '.qkv_proj' in name:
                q, k, v = param.weight_spliter(loaded_weight)
                for shard_id, shard in (('q', q), ('k', k), ('v', v)):
                    load_weight(param, shard, shard_id=shard_id)
            elif '.gate_up_proj' in name:
                gate, up = param.weight_spliter(loaded_weight)
                for shard_id, shard in ((0, gate), (1, up)):
                    load_weight(param, shard, shard_id=shard_id)
            else:
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/phi3_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, build_rotary_embedding,
                                                  get_rope_parameters, get_rope_theta)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


def sparsemixer(scores, top_k, jitter_eps):
    """Top-2 sparsemixer routing (inference form) used by PhiMoE.

    Each round picks the highest remaining score. Experts whose gap to that
    score, relative to ``max(|scores|, best)``, exceeds ``2 * jitter_eps`` are
    masked out. The weight of the picked expert is its softmax probability over
    the unmasked candidates. The second round excludes the first pick from the
    candidates, but the gaps are still measured against the original ``scores``.

    Args:
        scores: Router logits; experts are along the last dim.
        top_k: Must be 2.
        jitter_eps: Relative masking threshold.

    Returns:
        tuple: ``(multiplier, selected_experts)``, both with a trailing dim of 2.
    """
    assert top_k == 2

    def _pick(candidates):
        with torch.no_grad():
            best, best_idx = candidates.max(dim=-1, keepdim=True)
            factor = scores.abs().clamp(min=best)
            too_far = ((best - scores) / factor) > (2 * jitter_eps)
        gates = torch.softmax(candidates.masked_fill(too_far, float('-inf')), dim=-1)
        return gates.gather(dim=-1, index=best_idx), best_idx

    first_weight, first_expert = _pick(scores)
    # remove the first choice before selecting the second expert
    remaining = torch.scatter(scores, -1, first_expert, float('-inf'))
    second_weight, second_expert = _pick(remaining)

    return (
        torch.concat((first_weight, second_weight), dim=-1),
        torch.concat((first_expert, second_expert), dim=-1),
    )


class PhiMoEAttention(nn.Module):
    """PhiMoE self-attention: packed qkv, rotary embedding, attention, row-wise output projection."""

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = None

        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        head_size = hidden // n_heads

        self.qkv_proj = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_size,
            bias=config.attention_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.sliding_window = getattr(config, 'sliding_window', None)
        self.attn_fwd = Attention(
            n_heads,
            head_size,
            num_kv_heads=n_kv_heads,
            sliding_window=self.sliding_window,
        )

        self.o_proj = build_rowwise_linear(n_heads * head_size,
                                           hidden,
                                           bias=config.attention_bias,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Attention forward.

        Entries 0 and 1 of ``past_key_value`` are passed to the attention op as
        the key/value caches. When its length is not 2, entries 2 and 3 are
        passed as ``k_scales_zeros`` / ``v_scales_zeros``.
        """
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class PhiMoESparseMoeBlock(nn.Module):
    """PhiMoE sparse moe block: sparsemixer top-2 routing over fused experts."""

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )

        # sparsemixer only supports top-2 routing
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=2,
            renormalize=False,
            dtype=dtype,
            device=device,
            all_reduce=True,
        )

        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens to two experts and combine their outputs.

        Args:
            hidden_states (torch.Tensor): 3-D input ``(batch, seq, hidden)``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: expert output reshaped to
            ``(batch, seq, -1)``, and the router logits computed on the
            flattened tokens.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Input jitter is a training-time regularizer (as in the HF reference
        # implementation). Applying it at inference makes outputs random.
        # Multiply out of place so the caller's tensor is not modified.
        if self.training and self.input_jitter_noise > 0:
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)

        # router_jitter_noise acts as the relative masking threshold of sparsemixer
        topk_weights, topk_ids = sparsemixer(router_logits, top_k=2, jitter_eps=self.router_jitter_noise)
        out_states = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
        )

        out_states = out_states.reshape(batch_size, sequence_length, -1)
        return out_states, router_logits


class PhiMoEDecoderLayer(nn.Module):
    """A single PhiMoE transformer layer.

    Pre-norm attention followed by a pre-norm sparse mixture-of-experts block.
    The layer norms are called with the running residual and hand back the
    updated residual alongside the normalised activations.
    """

    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx

        # Sub-modules are registered in the same order as the checkpoint layout
        # expects: attention, MoE, then the two norms.
        self.self_attn = PhiMoEAttention(config, dtype=dtype, device=device)
        self.block_sparse_moe = PhiMoESparseMoeBlock(config, dtype=dtype, device=device)

        norm_dim = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.input_layernorm = LayerNorm(norm_dim, eps=norm_eps, dtype=dtype, device=device)
        self.post_attention_layernorm = LayerNorm(norm_dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run attention and the MoE block.

        Returns:
            Tuple ``(hidden_states, residual)`` for the next layer; router
            logits produced by the MoE block are discarded.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: start the residual stream from the raw input.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        moe_in, residual = self.post_attention_layernorm(attn_out, residual)
        moe_out, _router_logits = self.block_sparse_moe(moe_in)
        return moe_out, residual


class PhiMoEModel(nn.Module):
    """PhiMoE model.

    Token embedding, a stack of ``PhiMoEDecoderLayer``, a final norm and a
    rotary position embedding shared by every layer.
    """

    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)
        self.layers = nn.ModuleList([
            PhiMoEDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # build norm
        self.norm = LayerNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding
        self.rotary_emb = self._build_rotary_embedding(config)

    @staticmethod
    def _build_rotary_embedding(config: Any):
        """Build the rotary embedding.

        Uses linear scaling when the config has no rope-scaling parameters and
        LongRoPE scaling ('longrope' / legacy 'su') otherwise.
        """
        rope_dim = config.hidden_size // config.num_attention_heads
        rope_max_pos_emb = config.max_position_embeddings
        rope_base = get_rope_theta(config)
        rope_scaling = get_rope_parameters(config)
        if rope_scaling is None:
            return build_rotary_embedding(
                rope_dim,
                rope_max_pos_emb,
                rope_base,
                emb_type=RopeType.LinearScaling,
            )

        # Standardised HF configs name the scaling kind 'rope_type'; older
        # checkpoints only carry 'type'. Accept either instead of a KeyError.
        scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type'))
        assert scaling_type in ['longrope', 'su'], f'unsupported rope scaling type: {scaling_type}'
        ori_pos_emb = getattr(config, 'original_max_position_embeddings', rope_max_pos_emb)
        longrope_params = LongRoPEScalingParameters(short_factor=rope_scaling['short_factor'],
                                                    long_factor=rope_scaling['long_factor'],
                                                    original_max_position_embeddings=ori_pos_emb,
                                                    short_mscale=rope_scaling['short_mscale'],
                                                    long_mscale=rope_scaling['long_mscale'])
        return build_rotary_embedding(
            rope_dim,
            rope_max_pos_emb,
            rope_base,
            longrope_params=longrope_params,
            emb_type=RopeType.LongRoPEScaling,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """forward.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``past_key_values`` is indexed per layer. Returns the final-normed
        hidden states.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class PhiMoEForCausalLM(nn.Module, CudaGraphMixin):
    """Mixture model for causalLM.

    Wraps ``PhiMoEModel`` with an output projection (``lm_head``) and the
    weight-loading logic that maps checkpoint parameter names onto the fused
    parameters used by this implementation.
    """

    def __init__(self,
                 config: Any,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = PhiMoEModel(config, dtype=dtype, device=device)
        # build lm_head
        # NOTE(review): lm_head is a standalone linear that is never tied to
        # model.embed_tokens here, yet load_weights skips 'lm_head.weight' when
        # config.tie_word_embeddings is set. Confirm that tied PhiMoE
        # checkpoints are not expected (otherwise lm_head stays unloaded).
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=config.lm_head_bias,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return hidden states (logits are computed
        separately by ``get_logits``)."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input.

        Collects ``input_ids``, ``position_ids`` and ``attn_metadata`` from the
        step context into the keyword arguments accepted by ``forward``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Each checkpoint tensor is routed by name, in priority order:
        1. per-expert MoE weights (w1/w3/w2) -> fused ``experts.gate_up`` /
           ``experts.down`` with an ``expert_id``;
        2. q/k/v projections -> shards of the fused ``qkv_proj``;
        3. anything else -> the parameter with the same name.
        Rotary caches, and ``lm_head.weight`` when embeddings are tied, are
        skipped.
        """
        # modify from vllm
        # Separate q/k/v checkpoint tensors are loaded as shards of qkv_proj.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        # Per-expert mapping: w1 -> gate half and w3 -> up half of gate_up,
        # w2 -> down. Entries are (param_name, weight_name, expert_id, shard_id).
        num_experts = self.config.num_local_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            # `name` is rewritten in place to the fused parameter name on a match.
            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
                break
            else:
                # Not an expert weight: try the fused qkv mapping next.
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    # Plain 1:1 parameter.
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/phi3_v.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .phi3 import Phi3ForCausalLM, Phi3Model
from .utils.model import vlm_model

# Vision-tower hyperparameters for `openai/clip-vit-large-patch14-336`, the only
# `img_processor.model_name` accepted by Phi3ImageEmbedding.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    # input geometry: 336x336 RGB images cut into 14x14 patches
    image_size=336,
    patch_size=14,
    num_channels=3,
    # transformer body
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=24,
    num_attention_heads=16,
    hidden_act='quick_gelu',
    layer_norm_eps=1e-05,
    projection_dim=768,
    # dropout disabled; initialisation settings
    attention_dropout=0.0,
    dropout=0.0,
    initializer_factor=1.0,
    initializer_range=0.02,
)


@vlm_model
class Phi3ImageEmbedding(nn.Module):
    """Image embedding.

    Encodes image crops with a CLIP vision tower and rearranges the patch
    features into the Phi-3-vision HD layout: 2x2 patch merge, a learned
    vector at the end of every feature row, and sub-image features followed
    by a separator and the global-image features. It then projects them to
    the language model's hidden size and scatters them into the token
    embeddings.
    """

    # from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/c45209e90a4c4f7d16b2e9d48503c7f3e83623ed/image_embedding_phi3_v.py#L83 # noqa: E501
    def __init__(self,
                 config: PretrainedConfig,
                 wte=None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs):
        """Build the vision tower, optional HD separators and the projection.

        Args:
            config: model config. ``img_processor`` must describe the CLIP
                vision model; ``vocab_size`` and ``n_embd``/``hidden_size``
                are also read.
            wte: token-embedding module used by ``forward`` to embed ``input_ids``.
            dtype: dtype of the created parameters.
            device: device of the created parameters.
            **kwargs: embedding options. ``use_hd_transform``,
                ``with_learnable_separator``, ``hd_transform_order`` and
                ``projection_cls`` are read here; others are ignored.

        Raises:
            NotImplementedError: unsupported ``img_processor`` or ``projection_cls``.
        """
        super().__init__()
        self.config = config
        hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size

        self.wte = wte

        # Only the CLIP ViT-L/14-336 tower is supported; its architecture comes
        # from the module-level CLIP_VIT_LARGE_PATCH14_336_CONFIG.
        if (isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model'):
            assert 'model_name' in config.img_processor, ('model_name must be provided for CLIPVisionModel')
            assert 'image_dim_out' in config.img_processor, ('image_dim_out must be provided for CLIPVisionModel')
            assert 'num_img_tokens' in config.img_processor, ('num_img_tokens must be provided for CLIPVisionModel')
            assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336'
            clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
            self.img_processor = CLIPVisionModel(clip_config).to(device).to(dtype)
            image_dim_out = config.img_processor['image_dim_out']
            self.num_img_tokens = config.img_processor['num_img_tokens']
        else:
            raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')

        self.image_dim_out = image_dim_out
        self.img_sizes = None

        self.use_hd_transform = kwargs.get('use_hd_transform', False)
        self.with_learnable_separator = kwargs.get('with_learnable_separator', False)
        # NOTE(review): hd_feature_transform only accepts 'sub_glb', so this
        # 'glb_sub' default fails at runtime unless the config overrides it.
        self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert (self.use_hd_transform == self.with_learnable_separator), (
            'use_hd_transform and with_learnable_separator '
            'should have same value')
        if self.with_learnable_separator:
            assert self.use_hd_transform, ('learnable separator is only for hd transform')
            # 1024 * 4, merge spatial to channel dimension
            # glb_GN: separator placed between sub-image and global features;
            # sub_GN: vector appended at the end of every feature row.
            self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4], dtype=dtype, device=device))
            self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4], dtype=dtype, device=device))

        projection_cls = kwargs.get('projection_cls', 'linear')
        if projection_cls == 'linear':
            self.img_projection = nn.Linear(image_dim_out, hidden_size, dtype=dtype, device=device)
        elif projection_cls == 'mlp' and self.use_hd_transform:
            # HD features are 2x2-merged, hence the 4 * image_dim_out input width.
            dim_projection = hidden_size
            depth = 2
            layers = [nn.Linear(image_dim_out * 4, dim_projection, dtype=dtype, device=device)]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection, dtype=dtype, device=device)])
            self.img_projection = nn.Sequential(*layers)
        elif projection_cls == 'mlp':
            dim_projection = hidden_size
            depth = 2
            layers = [nn.Linear(image_dim_out, dim_projection, dtype=dtype, device=device)]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection, dtype=dtype, device=device)])
            self.img_projection = nn.Sequential(*layers)
        else:
            raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented')

        self.vocab_size = config.vocab_size
        self.img_features = None

        # Which CLIP hidden layer to read and whether to keep its first token.
        if isinstance(config.img_processor, dict):
            self.layer_idx = config.img_processor.get('layer_idx', -2)
            self.type_feature = config.img_processor.get('type_feature', 'patch')
        else:
            self.layer_idx = -2
            self.type_feature = 'patch'

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the CLIP tower and return hidden states of ``self.layer_idx``.

        ``'patch'`` drops the first token of every sequence (CLIP's class
        token); ``'cls_patch'`` keeps all tokens.

        Raises:
            NotImplementedError: unknown ``type_feature``.
        """
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)
        img_feature = img_processor_output.hidden_states[LAYER_IDX]

        if TYPE_FEATURE == 'patch':
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == 'cls_patch':
            return img_feature

        raise NotImplementedError

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        image_sizes=None,
        image_mask: torch.Tensor = None,
    ) -> torch.FloatTensor:
        """forward.

        Embeds ``input_ids`` with ``wte`` and overwrites, in order, the
        positions where ``image_mask`` is True with the projected HD image
        features.

        Args:
            input_ids: token ids.
            pixel_values: (num_images, num_crops, 3, 336, 336) crops.
            image_sizes: per-image (h, w) used to count real crops.
            image_mask: boolean mask over ``input_ids`` marking image tokens;
                presumably has exactly one True per projected feature row.
        """
        inputs_embeds = self.wte(input_ids)
        assert self.use_hd_transform
        num_images, num_crops, c, h, w = pixel_values.shape
        assert c == 3 and h == w == 336
        img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(num_images, num_crops, -1,
                                                                                 self.image_dim_out)
        image_features_proj = self.hd_feature_transform(img_features, image_sizes)
        # update image feature to inputs_embeds
        inputs_embeds.masked_scatter_(image_mask[..., None], image_features_proj)
        return inputs_embeds

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)

        Crop 0 of every image is its global view; the following crops are
        sub-images. For each image the output rows are
        [sub-image features, glb_GN separator, global features], concatenated
        over images and passed through ``img_projection``.
        """
        assert (self.hd_transform_order == 'sub_glb'), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)

        all_image_embeddings = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            # Image extent divided by the 336-pixel crop size gives the crop grid.
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)

            # [sub features, separator, global features]
            all_image_embeddings.extend([
                sub_image_features_hd_newline.squeeze(0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])

        image_features_proj = self.img_projection(
            torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype))

        return image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops

        Each 2x2 block of neighbouring patches is folded into the channel
        dimension, then the crops of one image are tiled back into a single
        (h_crop*12) x (w_crop*12) grid in row-major crop order.
        """
        N, L, C = image_features.shape
        assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)

        Appends the learned ``sub_GN`` vector after every row of the feature
        grid, then flattens the grid into a token sequence.
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings],
                                              dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline


class Phi3VModel(Phi3Model):
    """Phi3 language model with an optional image-embedding front end.

    When ``config.embd_layer`` is a dict, a ``Phi3ImageEmbedding`` that shares
    this model's token-embedding table is attached as ``vision_embed_tokens``;
    otherwise ``vision_embed_tokens`` is ``None``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__(config=config, dtype=dtype, device=device)

        embd_layer = config.embd_layer
        if not isinstance(embd_layer, dict):
            self.vision_embed_tokens = None
            return

        # 'embedding_cls' is required; it is forwarded together with every
        # other embd_layer option as keyword arguments.
        embedding_config = dict(embd_layer, embedding_cls=embd_layer['embedding_cls'])
        self.vision_embed_tokens = Phi3ImageEmbedding(config,
                                                      wte=self.embed_tokens,
                                                      dtype=dtype,
                                                      device=device,
                                                      **embedding_config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward.

        When images are supplied and no embeddings were precomputed, the
        vision front end builds ``inputs_embeds`` (text embeddings with image
        features scattered in) before the language model runs.
        """
        needs_vision = pixel_values is not None and inputs_embeds is None
        if needs_vision:
            inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values, image_sizes, image_mask)

        return super().forward(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )


class Phi3VForCausalLM(Phi3ForCausalLM):
    """Phi-3-vision causal LM.

    Reuses the Phi3 causal-LM wrapper but replaces its backbone with
    ``Phi3VModel`` so image features can be injected into the token
    embeddings, and adds the multimodal input processor.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__(config, ctx_mgr, dtype=dtype, device=device)
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = Phi3VModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

        self.input_processor = Phi3VInputProcessor(config, dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        pixel_values: torch.Tensor = None,
        image_sizes: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """forward."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            image_mask=image_mask,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Prepare input.

        Extends the text inputs built by the parent with ``pixel_values``,
        ``image_sizes`` and ``image_mask`` when the step carries images.
        The mask marks positions of ``input_ids`` equal to the image token id.
        """
        output = super().prepare_inputs_for_generation(past_key_values=past_key_values,
                                                       inputs_embeds=inputs_embeds,
                                                       context=context)

        # vision inputs
        if context.input_multimodals is not None:
            # flatten batch: one list of image entries across all sequences
            input_mms = [data for input_mm in context.input_multimodals for data in input_mm.get('image', [])]
            if len(input_mms) > 0:
                pixel_values = torch.cat([data.data for data in input_mms])
                image_sizes = torch.cat([data.meta['image_sizes'] for data in input_mms])
                image_token_id = input_mms[0].meta['image_token_id']
                image_mask = output['input_ids'] == image_token_id
                output['pixel_values'] = pixel_values
                output['image_sizes'] = image_sizes
                output['image_mask'] = image_mask

        return output

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Language-model weights are streamed to the parent loader; weights whose
        name contains ``'vision_embed_tokens.'`` are set aside during that same
        pass and loaded afterwards. The input is traversed only once, and only
        the vision-tower tensors are held in memory. The language-model
        tensors are not buffered, which a two-way ``itertools.tee`` split
        would do when one branch is drained before the other.
        """
        vis_prefix = 'vision_embed_tokens.'
        vlm_weights = []

        def _llm_weights():
            # Yield language-model weights; stash vision weights for later.
            for name, tensor in weights:
                if vis_prefix in name:
                    vlm_weights.append((name, tensor))
                else:
                    yield name, tensor

        llm_weights = _llm_weights()
        super().load_weights(llm_weights)
        # Drain whatever the parent loader left unconsumed so that every
        # vision weight is still collected even if it stopped early.
        for _ in llm_weights:
            pass

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in vlm_weights:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Phi3VInputProcessor(BaseModelInputProcessor):
    """Phi3V input processor.

    Turns the raw per-image dicts attached to a request into ``MultiModalData``
    entries grouped under the ``'image'`` key.
    """

    def __init__(self, config: PretrainedConfig, dtype) -> None:
        self.config = config
        self.dtype = dtype

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Each entry of ``input_multimodals`` must provide ``pixel_values``,
        ``image_sizes``, ``offset`` (used as the start index of the image's
        placeholder tokens), ``image_token_id`` and ``image_tokens`` (number
        of placeholder tokens, int or 0-d tensor).

        Returns:
            PreprocessInputResult in every case. When there is no multimodal
            input, it wraps the inputs unchanged, so callers can always use
            attribute access on the result.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values']
            image_sizes = input_mm['image_sizes']
            offset = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=offset,
                                     end=offset + num_pad,
                                     meta=dict(image_sizes=image_sizes, image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/q_modules.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from dataclasses import dataclass, fields

import torch
import torch.nn as nn

from ..kernels.w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,
                                           rms_norm_dynamic_quant)


@dataclass
class QTensor:
    """A data class representing a Quantized Tensor.

    This class wraps around a regular Pytorch tensor and adds quantization-
    specific parameters. Attribute lookups that are not dataclass fields are
    forwarded to ``tensor``.
    """
    tensor: torch.Tensor
    scale: torch.Tensor
    zero_point: torch.Tensor = None

    def __post_init__(self):
        # Names of the dataclass fields; these are never forwarded to `tensor`.
        self.fields = [field.name for field in fields(self)]

    def __getattr__(self, name: str):
        """Allows attribute access to be forwarded to the wrapped tensor when
        the attribute doesn't exist in QTensor.

        Python only calls this after normal lookup fails. The instance
        ``__dict__`` is read directly because looking up ``self.fields`` on an
        instance whose state is not populated yet (e.g. during ``copy.copy``
        or unpickling, which bypass ``__init__``) would re-enter
        ``__getattr__`` and recurse forever.

        Raises:
            AttributeError: for missing dataclass fields, for any lookup on a
                not-yet-initialised instance, or when ``tensor`` lacks ``name``.
        """
        field_names = self.__dict__.get('fields')
        if field_names is None or name in field_names:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return getattr(self.tensor, name)


class QRMSNorm(nn.Module):
    """RMS normalisation whose output is dynamically quantized.

    The result of ``forward`` is a ``QTensor`` holding the quantized
    activations (``quant_dtype``, int8 by default) and their scale.
    """

    def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.quant_dtype = quant_dtype

    @classmethod
    def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=torch.int8):
        """Create a QRMSNorm mirroring the floating-point norm ``mod``.

        ``mod`` must expose ``weight`` and ``variance_epsilon``. With
        ``initialization=False`` only the shape and epsilon are copied and the
        weight keeps its all-ones placeholder (dummy init).
        """
        q_mod = cls(mod.weight.shape[0], mod.variance_epsilon, quant_dtype=quant_dtype)
        if not initialization:
            return q_mod
        q_mod.weight = nn.Parameter(mod.weight.detach())
        return q_mod

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and quantize the result in one kernel.

        Returns:
            QTensor wrapping the quantized tensor and its scale factor.
        """
        quantized, quant_scale = rms_norm_dynamic_quant(hidden_states,
                                                        self.weight,
                                                        self.variance_epsilon,
                                                        quant_dtype=self.quant_dtype)
        return QTensor(quantized, quant_scale)


class QLinear(nn.Module):
    """A Linear layer that operates on quantized inputs and weights.

    The weight is stored quantized (``quant_dtype``, ``torch.int8`` by default)
    together with one float32 scale per output row. ``forward`` handles the
    input in one of two ways:

    - a plain tensor is quantized per token;
    - a ``QTensor`` is consumed as-is, using its data and scale.

    It then runs ``matmul_kernel_dynamic_quant`` and returns the
    floating-point result.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 quant_dtype=torch.int8) -> None:
        """Allocate uninitialized quantized storage.

        Args:
            in_features: size of each input row.
            out_features: size of each output row.
            bias: whether to allocate a bias buffer.
            device: device for all buffers.
            dtype: dtype of the bias buffer. The weight uses ``quant_dtype``
                and the scale is always float32.
            quant_dtype: storage dtype of the quantized weight.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.quant_dtype = quant_dtype
        self.register_buffer('weight', torch.empty((out_features, in_features), device=device, dtype=quant_dtype))
        self.register_buffer('scale', torch.empty((out_features, 1), device=device, dtype=torch.float32))
        if bias:
            self.register_buffer('bias', torch.empty(out_features, **factory_kwargs))
        else:
            # ``bias`` is a buffer when present, so register the absent case as a
            # buffer as well. A plain tensor can then be assigned to ``self.bias``
            # later, which a ``None`` *parameter* slot would reject with TypeError.
            self.register_buffer('bias', None)

    @classmethod
    def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=torch.int8):
        """Create a ``QLinear`` from a floating-point linear module.

        Args:
            mod: source module. ``in_features``, ``out_features``, ``weight``
                and ``bias`` are read from it.
            initialization: when True, ``mod.weight`` is quantized with
                ``per_channel_quant`` into this module (real init). When False,
                weight and scale stay uninitialized (dummy init).
            quant_dtype: storage dtype of the quantized weight.

        Returns:
            The new ``QLinear``. If ``mod`` has a bias, it is copied in both
            modes.
        """
        q_mod = cls(mod.in_features,
                    mod.out_features,
                    mod.bias is not None,
                    device=mod.weight.device,
                    dtype=mod.weight.dtype,
                    quant_dtype=quant_dtype)

        if initialization:
            weight_quant, scale = per_channel_quant(mod.weight.detach(), quant_dtype)
            q_mod.weight.data = weight_quant
            q_mod.scale = scale

        if mod.bias is not None:
            q_mod.bias.data = mod.bias.detach()
        return q_mod

    def forward(self, input):
        """Apply the quantized linear transformation.

        A ``torch.Tensor`` input is first quantized per token with
        ``per_token_quant_int8``. Any other input must be a ``QTensor`` whose
        ``tensor``/``scale`` are used directly.

        Returns:
            The float16 output of ``matmul_kernel_dynamic_quant``.
        """

        if isinstance(input, torch.Tensor):
            input_quant, input_scale = per_token_quant_int8(input, 1e-7, quant_dtype=self.quant_dtype)
        else:
            assert isinstance(input, QTensor)
            input_quant, input_scale = input.tensor, input.scale

        # NOTE(review): the output dtype is fixed to float16 regardless of the
        # dtype of the source module (e.g. bf16 models) — confirm intended.
        out = matmul_kernel_dynamic_quant(input_quant,
                                          self.weight,
                                          input_scale,
                                          self.scale,
                                          output_dtype=torch.float16,
                                          bias=self.bias)
        return out

    def extra_repr(self) -> str:
        """Summarize the layer shape and whether a bias is present."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


================================================
FILE: lmdeploy/pytorch/models/qwen.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class QWenAttention(torch.nn.Module):
    """Self-attention of the original QWen (v1) architecture.

    The layer is built from four pieces:

    - ``c_attn``: a packed QKV projection (always with bias);
    - rotary embedding applied in place;
    - ``attn_fwd``: the attention kernel;
    - ``c_proj``: the output projection.

    The number of KV heads equals the number of query heads.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads

        self.hidden_size = config.hidden_size
        self.split_size = config.hidden_size
        self.num_heads = n_heads
        self.projection_size = config.kv_channels * n_heads
        self.num_attention_heads = n_heads
        # one KV head per query head (no grouped-query attention)
        self.num_kv_heads = n_heads
        self.head_dim = config.hidden_size // n_heads

        self.c_attn = build_qkv_proj(
            config.hidden_size,
            num_q_heads=n_heads,
            num_kv_heads=n_heads,
            head_size=self.head_dim,
            bias=True,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
        )
        self.apply_rotary_pos_emb = ApplyRotaryEmb()
        self.attn_fwd = Attention(n_heads, self.head_dim, num_kv_heads=n_heads)
        self.c_proj = build_o_proj(self.projection_size,
                                   config.hidden_size,
                                   bias=not config.no_bias,
                                   quant_config=quant_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run self-attention for one layer and project the result back.

        ``past_key_value`` must be supplied despite its ``None`` default, since
        it is indexed unconditionally. Entries 0 and 1 are the key and value
        caches. When it has more than two entries, entries 2 and 3 are passed
        as ``k_scales_zeros`` / ``v_scales_zeros``.
        """
        qkv = self.c_attn(hidden_states).flatten(0, -2)
        q, k, v = self.c_attn.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        k_cache, v_cache = past_key_value[0], past_key_value[1]
        has_cache_scales = len(past_key_value) != 2
        attn_out = self.attn_fwd(
            q,
            k,
            v,
            k_cache,
            v_cache,
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_cache_scales else None,
            v_scales_zeros=past_key_value[3] if has_cache_scales else None,
            inplace=True,
        )
        attn_out = attn_out.reshape(*hidden_states.shape[:-1], -1)
        return self.c_proj(attn_out)


class QWenMLP(nn.Module):
    """Gated SiLU feed-forward block of QWen (v1).

    The fused gate/up projection has two halves, each ``intermediate_size // 2``
    wide. Both projections carry a bias unless ``config.no_bias`` is set.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        use_bias = not config.no_bias
        half_width = config.intermediate_size // 2

        self.gate_up_proj = build_gateup_linear(
            config.hidden_size,
            [half_width, half_width],
            bias=use_bias,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            is_tp=True,
        )
        self.act_fn = SiluAndMul(inplace=True)
        self.c_proj = build_down_linear(half_width,
                                        config.hidden_size,
                                        bias=use_bias,
                                        quant_config=quant_config,
                                        dtype=dtype,
                                        device=device,
                                        is_tp=True)

    def forward(self, x):
        """Apply the gate/up projection, ``SiluAndMul``, then the down projection."""
        return self.c_proj(self.act_fn(self.gate_up_proj(x)))


class QWenBlock(torch.nn.Module):
    """One QWen (v1) transformer layer: attention then MLP, each after an RMSNorm.

    When a residual is available, the norms are called as ``norm(x, residual)``
    and return ``(normed, new_residual)``. Presumably they add ``x`` into the
    residual before normalizing — confirm against ``RMSNorm``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_number: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_number = layer_number
        self.bf16 = config.bf16
        quant_config = getattr(config, 'quantization_config', None)

        self.attn = QWenAttention(config, dtype=dtype, device=device)
        self.mlp = QWenMLP(config, dtype=dtype, device=device)

        def _make_norm():
            return RMSNorm(config.hidden_size,
                           config.layer_norm_epsilon,
                           quant_config=quant_config,
                           dtype=dtype,
                           device=device)

        # ln_1 precedes attention, ln_2 precedes the MLP
        self.ln_1 = _make_norm()
        self.ln_2 = _make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run the layer.

        Returns:
            tuple: ``(mlp_output, residual)`` for the next layer.
        """
        if residual is not None:
            normed, residual = self.ln_1(hidden_states, residual)
        else:
            # first layer: the raw input starts the residual stream
            residual = hidden_states
            normed = self.ln_1(hidden_states)

        attn_out = self.attn(
            hidden_states=normed,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen (v1) decoder stack.

    Components: ``wte`` embedding, ``h`` layers, a rotary embedding and the
    final ``ln_f`` norm.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(self.vocab_size, self.embed_dim, dtype=dtype, device=device)

        blocks = [QWenBlock(config, idx, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)]
        self.h = nn.ModuleList(blocks)

        # With rotary_pct < 1 the rotary embedding is built for only
        # int(kv_channels * rotary_pct) dims; rotary_ndims is None otherwise.
        rotary_pct = config.rotary_pct
        if rotary_pct != 1.0:
            assert rotary_pct < 1
            self.rotary_ndims = int(config.kv_channels * rotary_pct)
        else:
            self.rotary_ndims = None
        rope_dim = config.kv_channels if self.rotary_ndims is None else self.rotary_ndims
        self.rotary_emb = build_rotary_embedding(
            rope_dim,
            getattr(config, 'max_position_embeddings', 4096),
            config.rotary_emb_base,
            emb_type=RopeType.LinearScaling,
        )

        self.ln_f = RMSNorm(self.embed_dim, eps=config.layer_norm_epsilon, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run all layers, apply ``ln_f``.

        ``past_key_values[i]`` is handed to layer ``i``.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, block in enumerate(self.h):
            hidden_states, residual = block(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module (``wte``)."""
        return self.wte


class QWenLMHeadModel(nn.Module, CudaGraphMixin):
    """Rewrote model.

    QWen (v1) causal LM for the PyTorch engine. ``forward`` returns the final
    hidden states of ``transformer``; ``get_logits`` applies ``lm_head`` to them.
    """

    # Fused module name -> checkpoint module names packed into it, in shard
    # order (w2 is shard 0, w1 is shard 1 in load_weights).
    packed_modules_mapping = {
        'gate_up_proj': [
            'w2',
            'w1',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build Model
        self.transformer = QWenModel(config, dtype=dtype, device=device)

        # output_layers
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        Logits are not computed here; call ``get_logits`` on the result.
        Extra ``kwargs`` are accepted and ignored.
        """
        hidden_states = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.transformer.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a step context.

        If the context carries vision embeddings, they are written into
        ``inputs_embeds`` at ``context.input_embedding_indexing`` (along dim 1).
        ``inputs_embeds`` is first computed from ``input_ids`` if not given;
        a caller-provided ``inputs_embeds`` is modified in place.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model's parameters.

        Skips the following entries:

        - visual tower weights;
        - rotary caches;
        - ``lm_head.weight`` when ``tie_word_embeddings`` is set.

        ``w2``/``w1`` are loaded into ``gate_up_proj`` shards 0/1. Fused
        ``c_attn`` weights are split into q/k/v. Every other name is loaded
        directly. A name with no matching parameter raises ``KeyError``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.gate_up_proj', '.w2', 0),
            ('.gate_up_proj', '.w1', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'visual' in name:
                continue
            if 'rotary_pos_emb.inv_freq' in name:
                continue
            if ('rotary_pos_emb.cos_cached' in name or 'rotary_pos_emb.sin_cached' in name):
                continue
            # NOTE(review): nothing in this class copies wte weights into lm_head
            # when embeddings are tied — confirm tying is handled elsewhere.
            if (self.config.tie_word_embeddings and 'lm_head.weight' in name):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.c_attn' in name:
                    # the checkpoint stores q/k/v fused; the parameter is
                    # expected to expose ``weight_spliter`` to separate them
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/qwen2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class Qwen2Attention(nn.Module):
    """Rewrite module of Qwen2Attention.

    The layer is built from four pieces:

    - ``qkv_proj``: a packed QKV projection (with bias);
    - rotary embedding applied in place;
    - ``attn_fwd``: attention, optionally sliding-window;
    - ``o_proj``: a bias-free output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        # A ``head_dim`` attribute that is present but None is treated like a
        # missing one instead of being passed on as the head size.
        head_dim = getattr(config, 'head_dim', None) or hidden_size // num_heads
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
        # Configs lacking the sliding-window switches are treated as "disabled"
        # instead of raising AttributeError.
        use_sliding_window = getattr(config, 'use_sliding_window', False)
        sliding_window = getattr(config, 'sliding_window', None) if use_sliding_window else None
        # packed qkv
        self.qkv_proj = build_qkv_proj(hidden_size,
                                       num_q_heads=num_heads,
                                       num_kv_heads=num_key_value_heads,
                                       head_size=head_dim,
                                       bias=True,
                                       quant_config=quantization_config,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=num_replicate_kv_heads)

        # rotary embedding
        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # attention
        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
            sliding_window=sliding_window,
        )

        # o_proj
        self.o_proj = build_o_proj(num_heads * head_dim,
                                   hidden_size,
                                   bias=False,
                                   quant_config=quantization_config,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run self-attention for one layer and apply ``o_proj``.

        ``past_key_value`` must be supplied despite its ``None`` default.
        Entries 0 and 1 are the key and value caches. When it has more than two
        entries, entries 2 and 3 are passed as ``k_scales_zeros`` /
        ``v_scales_zeros``.
        """
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)
        # (-1, heads, head_dim)
        qkv_states = qkv_states.flatten(0, -2)
        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)

        # apply rotary embedding
        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        # attention
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        # o proj
        attn_output = self.o_proj(attn_output)
        return attn_output


class Qwen2MLP(nn.Module):
    """Bias-free gated SiLU feed-forward block.

    The fused gate/up projection has two halves, each ``intermediate_size``
    wide.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        inter_size = config.intermediate_size
        shared = dict(quant_config=quant_config, dtype=dtype, device=device, is_tp=True)

        self.gate_up_proj = build_gateup_linear(config.hidden_size, [inter_size, inter_size], bias=False, **shared)
        self.act_fn = SiluAndMul(inplace=True)
        self.down_proj = build_down_linear(inter_size, config.hidden_size, bias=False, **shared)

    def forward(self, x):
        """Apply the gate/up projection, ``SiluAndMul``, then ``down_proj``."""
        gated = self.act_fn(self.gate_up_proj(x))
        return self.down_proj(gated)


class Qwen2DecoderLayer(nn.Module):
    """One Qwen2 layer: attention then MLP, each preceded by an RMSNorm.

    When a residual is available, the norms are called as ``norm(x, residual)``
    and return ``(normed, new_residual)``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_config = getattr(config, 'quantization_config', None)

        self.self_attn = Qwen2Attention(config, dtype=dtype, device=device)
        self.mlp = Qwen2MLP(config, dtype=dtype, device=device)

        norm_kwargs = dict(quant_config=quant_config, dtype=dtype, device=device)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run the layer.

        Returns:
            tuple: ``(mlp_output, residual)`` for the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # first layer: the raw input starts the residual stream
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Qwen2Model(nn.Module):
    """Qwen2 decoder stack.

    Components: token embedding, decoder layers, final RMSNorm and a rotary
    embedding built from the config.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            Qwen2DecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run every layer, apply ``norm``.

        ``past_key_values[i]`` is handed to layer ``i``.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[i],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module (``embed_tokens``)."""
        return self.embed_tokens


class Qwen2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """ModelForCausalLM.

    Qwen2 causal LM for the PyTorch engine. ``forward`` returns the backbone's
    final hidden states; ``lm_head`` is created via ``build_lm_head``.
    """

    # Fused module name -> checkpoint module names packed into it
    # (see the shard ids in load_weights).
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = Qwen2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        No logits are computed in this method. Extra ``kwargs`` are accepted
        and ignored.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a step context.

        If the context carries vision embeddings, they are written into
        ``inputs_embeds`` at ``context.input_embedding_indexing`` (along dim 1).
        ``inputs_embeds`` is first computed from ``input_ids`` if not given;
        a caller-provided ``inputs_embeds`` is modified in place.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model's parameters.

        Rotary caches are skipped, and so is ``lm_head.weight`` when
        ``tie_word_embeddings`` is set. The separate checkpoint projections are
        renamed into the fused modules:

        - ``q/k/v_proj`` go into ``qkv_proj`` (shards 'q'/'k'/'v');
        - ``gate/up_proj`` go into ``gate_up_proj`` (shards 0/1).

        Every other name is loaded directly. A name with no matching parameter
        raises ``KeyError``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # NOTE(review): tying lm_head to the embedding is presumably handled
            # by DeployModelMixinV1 / build_lm_head — confirm.
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/qwen2_5_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.models.qwen2_vl import Qwen2Model
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixinV1, vlm_model


class Qwen2_5_PatchEmbed(nn.Module):
    """Project flattened vision patches to embedding vectors.

    Every input row holds the raw values of one patch
    (``in_channels * temporal_patch_size * patch_size * patch_size`` numbers).
    A bias-free ``Conv3d`` whose kernel equals its stride maps each patch to a
    single ``embed_dim`` vector.
    """

    def __init__(self,
                 patch_size: int = 14,
                 temporal_patch_size: int = 2,
                 in_channels: int = 3,
                 embed_dim: int = 1152,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # kernel == stride: exactly one output voxel per non-overlapping patch
        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=patch_shape,
                              stride=patch_shape,
                              bias=False,
                              dtype=dtype,
                              device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed the patches; returns one ``embed_dim`` row per patch."""
        patch_volume = (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        # cast to the conv weight dtype so mixed-precision inputs are accepted
        patches = hidden_states.view(*patch_volume).to(dtype=self.proj.weight.dtype)
        embeddings = self.proj(patches)
        return embeddings.view(-1, self.embed_dim)


class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    """Rotary angle table for the vision tower.

    Stores one inverse frequency per even index below ``dim`` in a
    non-persistent buffer, so it never appears in the state dict. Calling the
    module returns the ``position * inv_freq`` table for ``seqlen`` positions.
    """

    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim
        self.register_buffer('inv_freq', 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` table of angles."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        # broadcasted product == outer product of positions and frequencies
        return positions.unsqueeze(-1) * inv_freq.unsqueeze(0)


class Qwen2_5_VLVisionAttention(nn.Module):
    """Self-attention of a Qwen2.5-VL vision block.

    Packed qkv projection -> rotary embedding on q/k -> non-causal attention
    over the variable-length segments given by ``cu_seqlens`` -> row-parallel
    output projection.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        embed_dim = config.hidden_size
        num_heads = config.num_heads
        self.head_dim = embed_dim // num_heads

        # fused q/k/v projection; q and kv use the same number of heads
        self.qkv = build_qkv_proj(embed_dim,
                                  num_q_heads=num_heads,
                                  num_kv_heads=num_heads,
                                  head_size=self.head_dim,
                                  bias=True,
                                  quant_config=quant_config,
                                  dtype=dtype,
                                  device=device,
                                  prefix=add_prefix('qkv', prefix))

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # non-causal: every token attends to the whole segment it belongs to
        self.attention = FlashAttention(num_heads, self.head_dim, causal=False)

        self.proj = build_rowwise_linear(embed_dim,
                                         embed_dim,
                                         bias=True,
                                         quant_config=quant_config,
                                         dtype=dtype,
                                         device=device,
                                         is_tp=True,
                                         prefix=add_prefix('proj', prefix))

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment.

        Args:
            hidden_states: one row per token.
            cu_seqlens: cumulative segment boundaries including the leading 0.
            rotary_pos_emb: ``(cos, sin)`` pair applied to q and k.
        """
        num_tokens = hidden_states.shape[0]

        packed = self.qkv(hidden_states).flatten(0, -2)
        query, key, value = self.qkv.split_qkv(packed)

        cos, sin = rotary_pos_emb
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

        seq_starts = cu_seqlens[:-1]
        seq_lens = cu_seqlens[1:] - seq_starts
        context = self.attention(query, key, value, q_start_loc=seq_starts, q_seqlens=seq_lens)

        return self.proj(context.reshape(num_tokens, -1))


class Qwen2_5_VLMLP(nn.Module):
    """Gated feed-forward network of a Qwen2.5-VL vision block.

    One merged column-parallel gate/up projection, the fused ``SiluAndMul``
    activation, and a row-parallel down projection. All linears carry a bias.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        hidden_size = config.hidden_size
        inter_size = config.intermediate_size

        self.gate_up_proj = build_merged_colwise_linear(
            in_features=hidden_size,
            all_out_features=[inter_size, inter_size],
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            is_tp=True,
        )
        self.act_fn = SiluAndMul(inplace=True)
        self.down_proj = build_rowwise_linear(in_features=inter_size,
                                              out_features=hidden_size,
                                              bias=True,
                                              quant_config=quant_config,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=True)

    def forward(self, x):
        """Apply gate/up projection, activation and down projection."""
        gate_up = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        return self.down_proj(activated)


class Qwen2_5_VLVisionBlock(nn.Module):
    """Pre-norm transformer block of the Qwen2.5-VL vision tower.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.norm1 = RMSNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.norm2 = RMSNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)

        self.attn = Qwen2_5_VLVisionAttention(config, dtype=dtype, device=device)

        self.mlp = Qwen2_5_VLMLP(config, dtype=dtype, device=device)

    def forward(self,
                hidden_states: torch.Tensor,
                cu_seqlens: torch.Tensor,
                rotary_pos_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        """Run attention and MLP with residual connections.

        Args:
            hidden_states: one row per vision token.
            cu_seqlens: cumulative attention-segment boundaries (with leading 0)
                forwarded to the attention layer.
            rotary_pos_emb: ``(cos, sin)`` pair forwarded to the attention
                layer. Despite the ``None`` default it is required: the
                attention unpacks it unconditionally.

        Returns:
            Updated hidden states, same shape as the input.
        """
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Qwen2_5_VLPatchMerger(nn.Module):
    """Fuse neighbouring vision tokens and project them to the LLM width.

    After RMS norm, every ``spatial_merge_size ** 2`` consecutive tokens are
    concatenated into one vector of ``context_dim * spatial_merge_size ** 2``
    features and passed through a Linear-GELU-Linear MLP producing ``dim``
    features.
    """

    def __init__(self,
                 dim: int,
                 context_dim: int,
                 spatial_merge_size: int = 2,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__()
        merged_dim = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.ln_q = RMSNorm(context_dim, eps=1e-6, dtype=dtype, device=device)

        # keep the nn.Sequential layout so parameter names (mlp.0 / mlp.2) are unchanged
        layers = [
            nn.Linear(merged_dim, merged_dim, dtype=dtype, device=device),
            nn.GELU(),
            nn.Linear(merged_dim, dim, dtype=dtype, device=device),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, group and project the tokens."""
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.hidden_size)
        return self.mlp(grouped)


@vlm_model
class Qwen2_5_VisionTransformerPretrainedModel(nn.Module):
    """Qwen2.5-VL vision transformer.

    Patch embedding -> ``config.depth`` vision blocks -> patch merger. Blocks
    whose index is in ``config.fullatt_block_indexes`` attend over the full
    ``cu_seqlens`` segments; every other block attends only within the local
    windows described by ``cu_window_seqlens`` (see ``get_window_index``).
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # number of patches fused into one token by the merger
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen2_5_PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
            dtype=dtype,
            device=device,
        )

        head_dim = config.hidden_size // config.num_heads
        # head_dim // 2 rotary dims: rot_pos_emb concatenates the h- and w-angles
        # and forward() duplicates them to span the full head dim
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device)

        self.blocks = nn.ModuleList(
            [Qwen2_5_VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)])
        self.merger = Qwen2_5_VLPatchMerger(dim=config.out_hidden_size,
                                            context_dim=config.hidden_size,
                                            spatial_merge_size=config.spatial_merge_size,
                                            dtype=dtype,
                                            device=device)

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Rotary position embedding.

        Args:
            grid_thw: one ``(t, h, w)`` row per image/video; ``t * h * w`` is
                the number of patches it contributes.

        Returns:
            Rotary angles with one row per patch: the patch's row-angles and
            column-angles concatenated along the last dim. Rows follow the
            patch order produced by the permutes below (each
            ``spatial_merge_size x spatial_merge_size`` group contiguous),
            repeated for every temporal slice.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # row index of every patch, reordered so that each merge group is contiguous
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # column index of every patch, with the same reordering
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # one angle table large enough for the largest h or w, gathered per (h, w) pair
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> Tuple[torch.Tensor, list]:
        """Compute the token ordering used by the windowed-attention blocks.

        The merged tokens (``spatial_merge_size x spatial_merge_size`` patches
        each) of every image are tiled into square windows.

        Args:
            grid_thw: one ``(t, h, w)`` row per image/video, in patches.

        Returns:
            window_index: permutation of merged-token indices (global over all
                images) that lists the tokens window by window.
            cu_window_seqlens: python list of cumulative window lengths measured
                in patches, starting with 0. Fully padded windows contribute a
                length of 0, so consecutive values may repeat.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # window side length measured in merged tokens
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # pad up to a multiple of the window size; when the size already is a
            # multiple this adds one extra all-padding row/column of windows
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding positions; they are filtered out below
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # number of real (non-padding) merged tokens in each window
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # convert merged-token counts to patch counts, offset by preceding images
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(self,
                hidden_states: torch.Tensor,
                cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor,
                window_index: torch.Tensor = None,
                cu_window_seqlens: torch.Tensor = None) -> torch.Tensor:
        """Encode patches into merged, LLM-width embeddings.

        Args:
            hidden_states: flattened pixel patches, one row per patch.
            cu_seqlens: cumulative full-attention segment lengths in patches,
                WITHOUT the leading 0 (it is prepended here).
            rotary_pos_emb: per-patch rotary angles as returned by ``rot_pos_emb``.
            window_index: merged-token permutation from ``get_window_index``.
                Required despite the ``None`` default (it is used for indexing
                and ``argsort``).
            cu_window_seqlens: cumulative window boundaries including the
                leading 0, used by the non-full-attention blocks.

        Returns:
            One row per merged token, restored to the original (pre-window) order.
        """
        hidden_states = self.patch_embed(hidden_states)

        # for window-based attention: permute whole merge groups (spatial_merge_unit
        # patches each) into window order; the rotary angles get the same permutation
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # duplicate the (h, w) angles to span the full head dim, then precompute cos/sin
        rotary_pos_emb = rotary_pos_emb.repeat(1, 2)
        rotary_pos_emb = (rotary_pos_emb.cos(), rotary_pos_emb.sin())

        # prepend the leading 0 boundary
        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # full attention only in the configured blocks, windowed attention elsewhere
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)

        # merge groups are still contiguous, so the merger fuses the right patches
        hidden_states = self.merger(hidden_states)
        # undo the window permutation
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states


class Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen2.5-VL for causal LM: vision tower + Qwen2 text decoder + LM head.

    Besides the forward / weight-loading hooks it keeps a per-sequence
    ``mrope_delta`` in the model metas. That delta is used to build the 3-row
    multimodal rotary position ids (``mrope_position_ids``) for prefill and
    decoding steps.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # preprocessor
        self.input_processor = Qwen2_5_VLInputProcessor(self.config)

        # build vision model
        self.visual = Qwen2_5_VisionTransformerPretrainedModel(
            config.vision_config,
            dtype=dtype,
            device=device,
        )
        # LLM settings live in `text_config` when present, otherwise on the top-level config
        text_config = getattr(config, 'text_config', config)
        # build model
        self.model = Qwen2Model(text_config, dtype=dtype, device=device)
        # build lm_head from the same config as the decoder, so its input width
        # always matches the decoder's hidden size
        self.lm_head = self.build_lm_head(text_config.hidden_size,
                                          text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        mrope_position_ids: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: torch.Tensor = None,
        window_index: torch.Tensor = None,
        cu_window_seqlens: List = None,
        image_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns the decoder hidden states.

        When ``inputs_embeds`` is not given, token embeddings are looked up and,
        if ``pixel_values`` is given, the vision-tower outputs are scattered
        into the positions where ``image_mask`` is True. The LM head is not
        applied here.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                image_embeds = self.visual(pixel_values,
                                           cu_seqlens=vis_cu_seqlens,
                                           rotary_pos_emb=vis_pos_emb.to(dtype),
                                           window_index=window_index,
                                           cu_window_seqlens=cu_window_seqlens)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments of ``forward`` from the step context.

        Collects the images of every sequence in the batch, and when there is
        at least one, prepares pixel values, the image-token mask, vision rotary
        angles and the full/windowed attention boundaries for the vision tower.
        Precomputed vision embeddings in the context are written straight into
        ``inputs_embeds``.
        """

        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_mask = None
        window_index = None
        cu_window_seqlens = None
        if context.input_multimodals is not None:
            # flatten the images of the whole batch; a sequence may have no
            # multimodal input (None) or no 'image' entry
            image_data = [
                data for input_mm in context.input_multimodals if input_mm is not None
                for data in input_mm.get('image', [])
            ]

            # checked AFTER flattening: a batch without any image must not reach
            # torch.cat([]) or image_data[0]
            if len(image_data) > 0:
                pixel_values = torch.cat([data.data for data in image_data])
                image_token_id = image_data[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()
                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)

                # calculation for window-based attention
                window_index, cu_window_seqlens = self.visual.get_window_index(grid_thw)
                cu_window_seqlens = torch.tensor(
                    cu_window_seqlens,
                    device=pixel_values.device,
                    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
                )
                # fully padded windows repeat a boundary; drop the duplicates
                cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

                # one full-attention segment of h * w patches per temporal slice
                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                         grid_thw[:, 0]).to(pixel_values.device)
                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)

        mrope_position_ids = getattr(context, 'mrope_position_ids', None)

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            window_index=window_index,
            cu_window_seqlens=cu_window_seqlens,
            image_mask=image_mask,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Separate q/k/v and gate/up checkpoint tensors go into the fused
        parameters; the vision tower's packed ``.qkv.`` tensors are split with
        the parameter's ``weight_spliter`` and loaded shard by shard.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs."""
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs."""

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        input_ids = kwargs.get('input_ids')
        num_tokens = input_ids.size(-1)
        new_batch_size = graph_meta.max_batchs

        is_decoding = graph_meta.is_decoding
        input_buffers = graph_meta.input_buffers
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
            if is_decoding:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
            else:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']

        return new_inputs

    def _get_model_metas(self, context: StepContext):
        """Get model metas, filling missing entries with ``mrope_delta=0``.

        A fresh dict is created for every sequence so the returned metas never
        alias one another.
        """
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            return [dict(mrope_delta=0) for _ in range(batch_size)]
        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]

    def _update_model_meta_decoding(self, context: StepContext):
        """Update model meta for decoding.

        Each decoding token's M-RoPE id is its text position plus the
        sequence's ``mrope_delta``, identical on all 3 rotary rows.
        """
        model_metas = self._get_model_metas(context)
        position_ids = context.position_ids

        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
        mrope_deltas = position_ids.new_tensor(mrope_deltas)
        mrope_position_ids = position_ids + mrope_deltas[None]
        # expand(3, -1) requires position_ids to have a leading dim of size 1
        mrope_position_ids = mrope_position_ids.expand(3, -1)

        context.mrope_position_ids = mrope_position_ids
        return model_metas

    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
        """Get mrope ids.

        Returns a ``(3, t * h' * w')`` tensor of (t, h', w') coordinates of
        every merged image token, with ``h' = h // 2`` and ``w' = w // 2``.
        """
        t, h, w = grid_thw
        h //= 2
        w //= 2
        stride = torch.tensor([h * w, w, 1], device=device)[:, None]
        size = torch.tensor([t, h, w], device=device)[:, None]
        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
        pos_ids = pos_ids // stride % size
        return pos_ids

    def _update_model_meta_prefilling(self, context: StepContext):
        """Update model meta for prefilling.

        Builds per-sequence M-RoPE ids: text tokens keep their (shifted)
        position on all 3 rows, image tokens get their (t, h, w) grid
        coordinates offset by the image start position. An image spanning
        ``end - start`` tokens only advances later positions by ``max(h, w)``,
        so the excess is subtracted from everything after it and accumulated
        into ``mrope_delta``.
        """
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_multimodals = [None] * len(model_metas)
        position_ids = context.position_ids
        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
        mrope_position_ids = []
        new_model_metas = []
        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):
            images = []
            if input_mm is not None:
                images = input_mm.get('image', [])
            if model_meta is None or 'mrope_delta' not in model_meta:
                mrope_delta = 0
            else:
                mrope_delta = model_meta['mrope_delta']

            pos_start = pos_ids[0].item()
            mrope_pos_ids = pos_ids + mrope_delta
            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
            for img in images:
                grid_thw = img.meta['grid_thw'][0].tolist()
                _, h, w = grid_thw
                h //= 2
                w //= 2
                num_pad = img.end - img.start - max(h, w)
                mrope_delta -= num_pad
                fill_start = img.start - pos_start
                fill_end = img.end - pos_start
                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)
                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
                mrope_pos_ids[:, fill_end:] -= num_pad
                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids

            mrope_position_ids.append(mrope_pos_ids)
            new_model_metas.append(dict(mrope_delta=mrope_delta))

        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
        context.mrope_position_ids = mrope_position_ids

        return new_model_metas

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Update model meta (and ``context.mrope_position_ids``) for this step."""
        if context.is_decoding:
            return self._update_model_meta_decoding(context)
        else:
            return self._update_model_meta_prefilling(context)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Qwen2_5_VLInputProcessor(BaseModelInputProcessor):
    """Qwen2.5-VL input processor.

    Converts per-image preprocessing dicts into ``MultiModalData`` entries that
    record where each image's placeholder tokens sit in ``input_ids``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = config

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: prompt token ids (returned unchanged).
            input_multimodals: one dict per image with keys ``pixel_values``,
                ``image_grid_thw``, ``offset`` (start position of the image's
                placeholder tokens), ``image_token_id`` and ``image_tokens``
                (number of placeholder tokens, int or 0-dim tensor).

        Returns:
            PreprocessInputResult whose ``input_multimodals`` is
            ``dict(image=[MultiModalData, ...])``. When there is no multimodal
            input, ``input_multimodals`` is passed through as given.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # return the declared result type, not a bare (input_ids, mms) tuple
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values']
            image_grid_thw = input_mm['image_grid_thw']
            start = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=start,
                                     end=start + num_pad,
                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/qwen2_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class Qwen2MoeAttention(nn.Module):
    """Self-attention layer of Qwen2-MoE.

    Packed qkv projection (with bias), rotary embedding on q/k, attention
    against the KV cache tensors in ``past_key_value`` (kv head count may
    differ from q head count), and a bias-free output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        hidden_size = config.hidden_size
        num_q_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        # falls back to hidden_size // num_attention_heads when the config has no head_dim
        head_dim = getattr(config, 'head_dim', hidden_size // num_q_heads)

        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_q_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_dim,
            bias=True,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            num_q_heads,
            head_dim,
            num_kv_heads=num_kv_heads,
            v_head_size=head_dim,
            sliding_window=config.sliding_window,
        )

        self.o_proj = build_rowwise_linear(num_q_heads * head_dim,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Compute attention output for ``hidden_states``.

        Args:
            hidden_states: token hidden states; output has the same leading dims.
            rotary_pos_emb: ``(cos, sin)`` pair applied in place to q and k.
            past_key_value: ``(k_cache, v_cache)`` or, with 4 entries,
                additionally the k/v ``scales_zeros`` tensors.
            attn_metadata: backend attention metadata.
        """
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        query, key, value = self.qkv_proj.split_qkv(qkv)

        cos, sin = rotary_pos_emb
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin, inplace=True)

        # entries 2 and 3 (when present) are forwarded as k/v scales_zeros
        has_scales_zeros = len(past_key_value) != 2
        attn_output = self.attn_fwd(
            query,
            key,
            value,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales_zeros else None,
            v_scales_zeros=past_key_value[3] if has_scales_zeros else None,
            inplace=True,
        )
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)

        return self.o_proj(attn_output)


class Qwen2MoeMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiluAndMul, down projection.

    Used as the dense MLP of non-sparse layers and as the shared expert of
    ``Qwen2MoeSparseMoeBlock``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True):
        """Build the projections.

        Args:
            config: model config; ``hidden_size``, ``intermediate_size`` and
                the optional ``quantization_config`` are read from it.
            intermediate_size: hidden width of the MLP; ``None`` falls back to
                ``config.intermediate_size``.
            dtype: parameter dtype.
            device: parameter device.
            is_tp: whether both projections are tensor-parallel.
            all_reduce: forwarded to ``down_proj``. The MoE block passes
                ``False`` and performs a single reduction itself.
        """
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        inter_size = config.intermediate_size if intermediate_size is None else intermediate_size
        hidden = config.hidden_size

        # fused gate + up projection, two output halves of width inter_size
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inter_size, inter_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=is_tp,
        )

        # activation over the fused gate/up output, computed in place
        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(inter_size,
                                              hidden,
                                              bias=False,
                                              quant_config=quant_cfg,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=is_tp,
                                              all_reduce=all_reduce)

    def forward(self, x):
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts block with a sigmoid-gated shared expert.

    Tokens are routed by ``SoftmaxTopK`` over the ``gate`` logits to the
    fused experts. An always-on shared expert, scaled by
    ``sigmoid(shared_expert_gate(x))``, is added to the routed output.
    Neither the experts nor the shared expert all-reduce on their own. One
    ``dist.all_reduce`` is issued at the end when the TP world size is > 1.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        # mirrors norm_topk_prob; handed to the fused MoE as ``renormalize``
        self.renormalize = self.norm_topk_prob

        # router: one logit per expert, not split across TP ranks
        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )

        self.softmax_topk = SoftmaxTopK(
            self.top_k,
            n_groups=getattr(config, 'router_n_groups', -1),
        )

        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            all_reduce=False,
        )

        self.shared_expert = Qwen2MoeMLP(
            config=config,
            intermediate_size=config.shared_expert_intermediate_size,
            dtype=dtype,
            device=device,
            is_tp=True,
            all_reduce=False,
        )
        # single-output gate scaling the shared expert's contribution
        self.shared_expert_gate = build_rowwise_linear(config.hidden_size,
                                                       1,
                                                       bias=False,
                                                       dtype=dtype,
                                                       device=device,
                                                       all_reduce=False)

        # partial sums from every rank are reduced once, in forward()
        tp_world_size, _ = get_tp_world_rank()
        self._all_reduce = tp_world_size > 1

    def forward(self, hidden_states: torch.Tensor):
        """Route a ``(batch, seq, hidden)`` input through the experts.

        Returns:
            Tensor reshaped to ``(batch, seq, -1)``; all-reduced across TP
            ranks when the world size is > 1.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        flat = hidden_states.view(-1, hidden_dim)

        topk_weights, topk_ids = self.softmax_topk(self.gate(flat))
        routed = self.experts(flat, topk_weights, topk_ids)

        shared_out = self.shared_expert(flat)
        shared_out = F.sigmoid(self.shared_expert_gate(flat)) * shared_out
        routed += shared_out
        result = routed.reshape(batch_size, seq_len, -1)

        if self._all_reduce:
            dist.all_reduce(result)
        return result


class Qwen2MoeDecoderLayer(nn.Module):
    """Decoder layer: pre-norm self-attention followed by a dense or sparse MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = Qwen2MoeAttention(config, dtype=dtype, device=device)

        # Sparse unless the layer is listed in mlp_only_layers, the model has
        # no experts, or the layer falls off the decoder_sparse_step cadence.
        use_moe = (layer_idx not in config.mlp_only_layers and config.num_experts > 0
                   and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = Qwen2MoeSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device)

        # NOTE(review): built without quant_config, unlike input_layernorm;
        # presumably because the MoE router gate is not quantized — confirm.
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run one decoder layer.

        On the first layer ``residual`` is ``None``, so the raw input becomes
        the residual. Afterwards the two-argument RMSNorm call returns an
        updated ``(hidden, residual)`` pair (presumably residual-add followed
        by normalization).

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            Qwen2MoeDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed the tokens, run every decoder layer and apply the final norm.

        ``inputs_embeds``, when given, replaces the token-embedding lookup.
        ``past_key_values`` is indexed per layer and must hold one cache
        entry per decoder layer.

        Returns:
            Final normalized hidden states (logits are computed elsewhere).
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Qwen2MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen2-MoE causal language model: decoder stack plus LM head."""

    # fused parameter name -> checkpoint sub-module names packed into it
    packed_modules_mapping = {
        'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
        'gate_up_proj': ['gate_proj', 'up_proj'],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = Qwen2MoeModel(config, dtype=dtype, device=device)
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return its final hidden states.

        The LM head is not applied here (presumably the mixin does that when
        computing logits). Extra ``kwargs`` are ignored.
        """
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of the decoder stack."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Assemble ``forward`` kwargs from the step ``context``.

        When the context carries vision embeddings, they are written in place
        into ``inputs_embeds`` (computed from ``input_ids`` if not supplied)
        at ``context.input_embedding_indexing``.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata
        vision_embeds = context.input_embeddings
        vision_index = context.input_embedding_indexing

        has_vision = vision_embeds is not None and len(vision_embeds) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_index, :] = vision_embeds.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one expert-related checkpoint tensor.

        The first mapping entry whose checkpoint fragment occurs in ``name``
        decides the fused target parameter, expert id and shard id. Without
        a match, the tensor is loaded into ``params_dict[name]`` as-is.
        """
        match = next((entry for entry in expert_params_mapping if entry[1] in name), None)
        if match is None:
            load_weight(params_dict[name], loaded_weight)
            return

        param_name, weight_name, expert_id, shard_id = match
        target = params_dict[name.replace(weight_name, param_name)]
        load_weight(target, loaded_weight, expert_id=expert_id, shard_id=shard_id)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``weights`` into the model (adapted from vLLM).

        q/k/v and gate/up tensors become shards of the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters. Per-expert tensors go into the fused MoE
        parameters. Rotary caches, and ``lm_head.weight`` when
        ``tie_word_embeddings`` is set, are skipped.
        """
        # (fused param name, checkpoint name fragment, shard id)
        stacked_params_mapping = [
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # (fused param name, checkpoint name fragment, expert id, shard id)
        expert_params_mapping = [
            entry for exp_id in range(self.config.num_experts) for entry in (
                ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate'),
                ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up'),
                ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down'),
            )
        ]

        ignored_fragments = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
                continue

            match = next((entry for entry in stacked_params_mapping if entry[1] in name), None)
            if match is None:
                load_weight(params_dict[name], loaded_weight)
            else:
                param_name, shard_name, shard_id = match
                load_weight(params_dict[name.replace(shard_name, param_name)], loaded_weight, shard_id=shard_id)


================================================
FILE: lmdeploy/pytorch/models/qwen2_reward.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .qwen2 import Qwen2Model
from .utils.cudagraph import CudaGraphMixin


class Qwen2ForRewardModel(nn.Module, CudaGraphMixin):
    """Qwen2 reward model: Qwen2 decoder stack followed by a scalar scoring head."""

    # fused parameter name -> checkpoint sub-module names packed into it
    packed_modules_mapping = {
        'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
        'gate_up_proj': ['gate_proj', 'up_proj'],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = Qwen2Model(config, dtype=dtype, device=device)

        # NOTE(review): get_logits() uses ``score``, not this head — confirm
        # whether lm_head is still needed.
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

        self.num_labels = 1
        # Linear(hidden, hidden) -> ReLU -> Linear(hidden, num_labels)
        self.score = nn.Sequential(
            build_rowwise_linear(config.hidden_size, config.hidden_size, bias=True, dtype=dtype, device=device),
            nn.ReLU(),
            build_rowwise_linear(config.hidden_size, self.num_labels, bias=True, dtype=dtype, device=device),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the Qwen2 decoder stack and return its hidden states.

        Scoring happens in ``get_logits``. Extra ``kwargs`` are ignored.
        """
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def get_logits(self, hidden_states: torch.Tensor):
        """Score ``hidden_states`` with the reward head."""
        return self.score(hidden_states)

    def update_weights(self):
        """Update weights (nothing to do for this model)."""

    def get_input_embeddings(self):
        """Return the token embedding module of the decoder stack."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Assemble ``forward`` kwargs from the step ``context``."""
        # NOTE(review): ``inputs_embeds`` is accepted but not forwarded to
        # forward(); confirm this is intended.
        return dict(
            input_ids=context.input_ids,
            position_ids=context.position_ids,
            past_key_values=past_key_values,
            attn_metadata=context.attn_metadata,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``weights`` into the model (adapted from vLLM).

        q/k/v and gate/up tensors become shards of the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters. Rotary caches, and ``lm_head.weight``
        when ``tie_word_embeddings`` is set, are skipped.
        """
        # (fused param name, checkpoint name fragment, shard id)
        stacked_params_mapping = [
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]
        ignored_fragments = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            match = next((entry for entry in stacked_params_mapping if entry[1] in name), None)
            if match is None:
                load_weight(params_dict[name], loaded_weight)
            else:
                param_name, shard_name, shard_id = match
                load_weight(params_dict[name.replace(shard_name, param_name)], loaded_weight, shard_id=shard_id)


================================================
FILE: lmdeploy/pytorch/models/qwen2_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, LayerNorm, RMSNorm, SiluAndMul,
                                 build_rotary_embedding_from_config)
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding, vlm_model


def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],
                           position_ids: torch.Tensor, rotary_emb_func: Callable):
    """Build multimodal-RoPE ``cos``/``sin`` tables from 3-row position ids.

    ``mrope_position_ids`` (3 rows) is zero-padded along its last dim up to
    ``position_ids.shape[-1]``. ``rotary_emb_func`` then yields one table per
    row. The rotary dim is cut into the chunks of ``mrope_section`` repeated
    twice, and chunk ``i`` of the output is taken from row ``i % 3``.

    Args:
        hidden_states: forwarded to ``rotary_emb_func``.
        mrope_position_ids: ``(3, n)`` position ids with ``n`` at most
            ``position_ids.shape[-1]``.
        mrope_section: chunk widths; doubled, they must sum to the rotary dim.
            The caller's list is not modified.
        position_ids: only its length, dtype and device are used.
        rotary_emb_func: ``(hidden_states, ids) -> (cos, sin)``, each indexed
            as ``[row, token, dim]``.

    Returns:
        ``(cos, sin)``, each of shape ``(cos.shape[1], cos.shape[-1])`` and with
        the dtype and device of ``cos``.
    """
    padded_ids = position_ids.new_zeros(3, position_ids.shape[-1])
    padded_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids
    cos, sin = rotary_emb_func(hidden_states, padded_ids)

    out_cos = cos.new_zeros(cos.shape[1], cos.shape[-1])
    out_sin = torch.zeros_like(out_cos)
    sections = mrope_section * 2

    for table, target in ((cos, out_cos), (sin, out_sin)):
        offset = 0
        for chunk_no, (chunk, width) in enumerate(zip(table.split(sections, dim=-1), sections)):
            # chunk k of the rotary dim comes from position row k % 3
            target[:, offset:offset + width] = chunk[chunk_no % 3]
            offset += width

    return out_cos, out_sin


class Qwen2Attention(nn.Module):
    """Self-attention of the Qwen2-VL language model (packed QKV with bias)."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        # an explicit config.head_dim wins; otherwise derive it from hidden size
        head_dim = getattr(config, 'head_dim', hidden // n_heads)

        self.qkv_proj = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=True,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=config.sliding_window,
        )

        self.o_proj = build_rowwise_linear(n_heads * head_dim,
                                           hidden,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Self-attention forward pass (rewrite of ``LlamaAttention.forward``).

        Args:
            hidden_states: input activations; all dims but the last are
                flattened before the q/k/v split.
            rotary_pos_emb: ``(cos, sin)`` pair, applied to q and k in place.
            past_key_value: ``(k_cache, v_cache)`` or a 4-tuple whose extra
                entries are passed as ``k_scales_zeros`` / ``v_scales_zeros``.
                Indexed unconditionally, so it must not be ``None``.
            attn_metadata: backend metadata forwarded to ``attn_fwd``.

        Returns:
            ``o_proj`` applied to the attention output, reshaped to the
            leading dims of ``hidden_states``.
        """
        leading_shape = hidden_states.shape[:-1]

        packed = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(packed)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        has_cache_scales = len(past_key_value) != 2
        k_scales_zeros = past_key_value[2] if has_cache_scales else None
        v_scales_zeros = past_key_value[3] if has_cache_scales else None

        attn_out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            inplace=True,
        )
        attn_out = attn_out.reshape(*leading_shape, -1)

        return self.o_proj(attn_out)


class Qwen2MLP(nn.Module):
    """Tensor-parallel gated MLP: fused gate/up projection, SiluAndMul, down projection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # fused gate + up projection, two output halves of width `inter`
        self.gate_up_proj = build_merged_colwise_linear(
            hidden,
            [inter, inter],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(inter,
                                              hidden,
                                              bias=False,
                                              quant_config=quant_cfg,
                                              dtype=dtype,
                                              device=device,
                                              is_tp=True)

    def forward(self, x):
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen2DecoderLayer(nn.Module):
    """Qwen2-VL language decoder layer: pre-norm self-attention followed by an MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        self.self_attn = Qwen2Attention(config, dtype=dtype, device=device)

        self.mlp = Qwen2MLP(config, dtype=dtype, device=device)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device)

        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                quant_config=quant_cfg,
                                                dtype=dtype,
                                                device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run one decoder layer.

        On the first layer ``residual`` is ``None`` and the raw input becomes
        the residual. Afterwards the two-argument RMSNorm call returns an
        updated ``(hidden, residual)`` pair.

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Qwen2Model(nn.Module):
    """Qwen2-VL language decoder stack with optional multimodal RoPE."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # per-axis chunk widths of the rotary dim used by _apply_mrope_selection
        self.mrope_section = config.rope_scaling['mrope_section']

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        decoder_layers = [
            Qwen2DecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mrope_position_ids: torch.LongTensor = None,
    ):
        """Embed the tokens, run every decoder layer and apply the final norm.

        When ``mrope_position_ids`` is given, the rotary tables are assembled
        per section by ``_apply_mrope_selection``. Otherwise plain
        ``position_ids`` are used. ``past_key_values`` is indexed per layer.

        Returns:
            Final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        if mrope_position_ids is not None:
            rotary_pos_emb = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section,
                                                    position_ids, self.rotary_emb)
        else:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
            rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class PatchEmbed(nn.Module):
    """Embed flattened image/video patches with a bias-free strided 3D convolution.

    Kernel and stride are both ``(temporal_patch_size, patch_size,
    patch_size)``, so every patch maps to exactly one ``embed_dim`` vector.
    """

    def __init__(self,
                 patch_size: int = 14,
                 temporal_patch_size: int = 2,
                 in_channels: int = 3,
                 embed_dim: int = 1152,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        window = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=window,
                              stride=window,
                              bias=False,
                              dtype=dtype,
                              device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project flattened patches to ``(num_patches, embed_dim)``.

        The input is viewed as ``(-1, in_channels, temporal_patch_size,
        patch_size, patch_size)`` and cast to the conv weight dtype first.
        """
        patch_shape = (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        patches = hidden_states.view(*patch_shape).to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)


class VisionRotaryEmbedding(nn.Module):
    """Rotary angle table for the vision tower.

    Stores ``ceil(dim / 2)`` inverse frequencies ``theta ** (-2i / dim)`` in a
    non-persistent buffer and, on call, returns the outer product of the
    positions ``0 .. seqlen - 1`` with them.
    """

    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim
        inverse_frequencies = 1.0 / (theta**exponents)
        # not saved in checkpoints; recomputed at construction
        self.register_buffer('inv_freq', inverse_frequencies, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` table of angles."""
        freq_table = self.inv_freq
        positions = torch.arange(seqlen, device=freq_table.device, dtype=freq_table.dtype)
        return torch.outer(positions, freq_table)


class VisionAttention(nn.Module):
    """Non-causal multi-head self-attention of the vision tower.

    A packed q/k/v projection is followed by rotary embedding on q and k,
    variable-length flash attention over the sequences delimited by
    ``cu_seqlens``, and an output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        embed_dim = config.embed_dim
        n_heads = config.num_heads
        self.head_dim = embed_dim // n_heads

        # packed qkv; key/value use as many heads as query
        self.qkv = build_qkv_proj(
            embed_dim,
            num_q_heads=n_heads,
            num_kv_heads=n_heads,
            head_size=self.head_dim,
            bias=True,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attention = FlashAttention(
            n_heads,
            self.head_dim,
            causal=False,
        )

        # output projection
        self.proj = build_rowwise_linear(embed_dim,
                                         embed_dim,
                                         bias=True,
                                         quant_config=quant_cfg,
                                         dtype=dtype,
                                         device=device,
                                         is_tp=True)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:
        """Attend within each vision sequence.

        ``cu_seqlens`` holds cumulative sequence boundaries starting at 0;
        ``rotary_pos_emb`` is a ``(cos, sin)`` pair.
        """
        num_tokens = hidden_states.shape[0]
        packed = self.qkv(hidden_states).flatten(0, -2)
        query, key, value = self.qkv.split_qkv(packed)

        cos, sin = rotary_pos_emb
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

        seq_starts = cu_seqlens[:-1]
        seq_lens = cu_seqlens[1:] - seq_starts
        attn_out = self.attention(
            query,
            key,
            value,
            q_start_loc=seq_starts,
            q_seqlens=seq_lens,
        )
        return self.proj(attn_out.reshape(num_tokens, -1))


class VisionMlp(nn.Module):
    """Feed-forward network of a vision block: ``fc2(act(fc1(x)))``."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        from transformers.activations import ACT2FN
        dim = config.embed_dim
        hidden_dim = int(config.embed_dim * config.mlp_ratio)
        quantization_config = getattr(config, 'quantization_config', None)
        # up projection: dim -> hidden_dim
        self.fc1 = build_colwise_linear(
            dim,
            hidden_dim,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )

        # activation
        if config.hidden_act in ('gelu', 'gelu_python'):
            # exact erf-based GELU, same function ACT2FN provides for these names
            self.act = nn.GELU()
        else:
            # Other GELU variants are different functions: 'quick_gelu' is
            # x * sigmoid(1.702 * x) and 'gelu_fast' is the tanh approximation.
            # Substituting exact GELU would change the outputs the checkpoint
            # was trained with, so use the reference implementation.
            self.act = ACT2FN[config.hidden_act]

        # down projection: hidden_dim -> dim
        self.fc2 = build_rowwise_linear(hidden_dim,
                                        dim,
                                        bias=True,
                                        quant_config=quantization_config,
                                        dtype=dtype,
                                        device=device,
                                        is_tp=True)

    def forward(self, x):
        """forward."""
        return self.fc2(self.act(self.fc1(x)))


class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block of the vision tower.

    The residual stream is carried between blocks: on every call after the
    first, ``norm1`` / ``norm2`` receive the running residual and return the
    normalized activations together with the updated residual. The block
    returns ``(mlp_output, residual)`` rather than their sum.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        width = config.embed_dim
        self.norm1 = LayerNorm(width, eps=1e-6, dtype=dtype, device=device)
        self.norm2 = LayerNorm(width, eps=1e-6, dtype=dtype, device=device)
        self.attn = VisionAttention(config, dtype=dtype, device=device)
        self.mlp = VisionMlp(config, dtype=dtype, device=device)

    def forward(self,
                hidden_states,
                cu_seqlens,
                rotary_pos_emb,
                residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention and MLP; returns ``(hidden_states, residual)``."""
        if residual is not None:
            hidden_states, residual = self.norm1(hidden_states, residual)
        else:
            # first block: the raw input starts the residual stream
            residual = hidden_states
            hidden_states = self.norm1(hidden_states)

        attn_out = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        normed, residual = self.norm2(attn_out, residual)
        return self.mlp(normed), residual


class PatchMerger(nn.Module):
    """Merge each run of ``spatial_merge_size ** 2`` consecutive patch tokens.

    Tokens are layer-normalized (``context_dim``), concatenated in groups of
    ``spatial_merge_size ** 2`` and projected to ``dim`` by a GELU MLP.
    """

    def __init__(self,
                 dim: int,
                 context_dim: int,
                 spatial_merge_size: int = 2,
                 dtype: torch.dtype = None,
                 device: torch.device = None) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6, dtype=dtype, device=device)
        # submodule indices (mlp.0 / mlp.2) are part of the checkpoint keys
        layers = [
            nn.Linear(merged_width, merged_width, dtype=dtype, device=device),
            nn.GELU(),
            nn.Linear(merged_width, dim, dtype=dtype, device=device),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(num_tokens // spatial_merge_size**2, dim)`` tensor."""
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.hidden_size)
        return self.mlp(grouped)


@vlm_model
class Qwen2VisionTransformerPretrainedModel(nn.Module):
    """Qwen2-VL vision tower: patch embedding, transformer blocks, patch merger."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
            dtype=dtype,
            device=device,
        )

        # half of each head is rotated by row ids, half by column ids
        rotary_dim = (config.embed_dim // config.num_heads) // 2
        self.rotary_pos_emb = VisionRotaryEmbedding(rotary_dim, device=device)

        self.blocks = nn.ModuleList([
            Qwen2VLVisionBlock(config, block_idx, dtype=dtype, device=device) for block_idx in range(config.depth)
        ])
        self.merger = PatchMerger(dim=config.hidden_size,
                                  context_dim=config.embed_dim,
                                  spatial_merge_size=config.spatial_merge_size,
                                  dtype=dtype,
                                  device=device)

    def rot_pos_emb(self, grid_thw):
        """Build per-patch rotary angles from ``(t, h, w)`` grids.

        Row and column ids are reordered so the patches of each
        ``spatial_merge_size x spatial_merge_size`` window are contiguous, then
        repeated ``t`` times. Returns the row and column angles concatenated on
        the last dimension.
        """
        merge = self.spatial_merge_size
        per_grid_ids = []
        for t, h, w in grid_thw:
            rows = torch.arange(h).unsqueeze(1).expand(-1, w)
            cols = torch.arange(w).unsqueeze(0).expand(h, -1)
            window_shape = (h // merge, merge, w // merge, merge)
            rows = rows.reshape(*window_shape).permute(0, 2, 1, 3).flatten()
            cols = cols.reshape(*window_shape).permute(0, 2, 1, 3).flatten()
            per_grid_ids.append(torch.stack([rows, cols], dim=-1).repeat(t, 1))

        all_ids = torch.cat(per_grid_ids, dim=0)
        # one table large enough for the biggest row/column id of any grid
        angle_table = self.rotary_pos_emb(grid_thw[:, 1:].max())
        return angle_table[all_ids].flatten(1)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
        """Encode pixel patches and return merged vision embeddings."""
        hidden_states = self.patch_embed(hidden_states)
        # prepend the leading 0 boundary expected by the attention layers
        boundaries = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)

        residual = None
        for block in self.blocks:
            hidden_states, residual = block(hidden_states,
                                            cu_seqlens=boundaries,
                                            rotary_pos_emb=rotary_pos_emb,
                                            residual=residual)

        # blocks return the residual separately; fold it back in once at the end
        return self.merger(hidden_states + residual)


class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen2-VL causal LM: vision tower + Qwen2 text model + lm head.

    Image pixels are encoded by ``self.visual`` and scattered into the token
    embeddings at image-token positions. The text model is fed 3-row M-RoPE
    position ids that ``update_model_metas`` stores on the step context.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # preprocessor
        self.input_processor = Qwen2VLInputProcessor(self.config)

        # build vision model
        self.visual = Qwen2VisionTransformerPretrainedModel(
            config.vision_config,
            dtype=dtype,
            device=device,
        )
        # language-model fields may live under ``text_config``; fall back to the top level
        text_config = getattr(config, 'text_config', config)
        # build model
        self.model = Qwen2Model(text_config, dtype=dtype, device=device)
        # build lm_head from the same config as the text model so their sizes agree
        self.lm_head = self.build_lm_head(text_config.hidden_size,
                                          text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        mrope_position_ids: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return hidden states (logits are computed by ``lm_head``)."""
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))
                image_embeds = self.visual(pixel_values, cu_seqlens=vis_cu_seqlens, rotary_pos_emb=vis_pos_emb)
                # write the image embeddings into the image-token slots, in order
                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments of ``forward`` from the step context.

        When the step carries images, all of them are concatenated into one
        ``pixel_values`` tensor together with the vision-tower cumulative
        sequence lengths, the rotary ``(cos, sin)`` pair and the image-token mask.
        """

        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_mask = None
        if context.input_multimodals is not None:
            # flatten the per-sequence image lists of the whole batch
            image_data = [
                data for input_mm in context.input_multimodals if input_mm is not None
                for data in input_mm.get('image', [])
            ]
            # Test emptiness AFTER flattening: sequences may carry multimodal
            # dicts without any image in this step, and ``torch.cat([])`` /
            # ``image_data[0]`` would fail on an empty list.
            if len(image_data) > 0:
                pixel_values = torch.cat([data.data for data in image_data])
                image_token_id = image_data[0].meta['image_token_id']
                image_mask = input_ids == image_token_id
                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()
                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
                # one vision sequence of h * w patches per temporal slice
                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                         grid_thw[:, 0]).to(pixel_values.device)
                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                vis_pos_emb = vis_pos_emb.repeat(1, 2)
                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

        mrope_position_ids = getattr(context, 'mrope_position_ids', None)

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_mask=image_mask,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights, packing separate q/k/v and gate/up checkpoints."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    # vision attention checkpoints store qkv fused; split before loading shards
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs, plus a (3, max_tokens) M-RoPE buffer."""
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs."""

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        input_ids = kwargs.get('input_ids')
        num_tokens = input_ids.size(-1)
        new_batch_size = graph_meta.max_batchs

        is_decoding = graph_meta.is_decoding
        input_buffers = graph_meta.input_buffers
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
            if is_decoding:
                # presumably one token per sequence while decoding, so the
                # graph consumes max_batchs columns
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
            else:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']

        return new_inputs

    def _get_model_metas(self, context: StepContext):
        """Return one model-meta dict per sequence, defaulting to ``mrope_delta=0``."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # build a distinct dict per sequence; ``[d] * n`` would alias one dict
            return [dict(mrope_delta=0) for _ in range(batch_size)]
        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]

    def _update_model_meta_decoding(self, context: StepContext):
        """Decoding: all three M-RoPE rows are ``position_ids + mrope_delta``."""
        model_metas = self._get_model_metas(context)
        position_ids = context.position_ids

        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
        mrope_deltas = position_ids.new_tensor(mrope_deltas)
        mrope_position_ids = position_ids + mrope_deltas[None]
        mrope_position_ids = mrope_position_ids.expand(3, -1)

        context.mrope_position_ids = mrope_position_ids
        return model_metas

    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device, merge_size: int = 2):
        """Return (3, t*h*w) temporal/row/column ids of a merged image grid.

        ``h`` and ``w`` are divided by ``merge_size`` because the patch merger
        emits one token per ``merge_size x merge_size`` window.
        """
        t, h, w = grid_thw
        h //= merge_size
        w //= merge_size
        stride = torch.tensor([h * w, w, 1], device=device)[:, None]
        size = torch.tensor([t, h, w], device=device)[:, None]
        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
        pos_ids = pos_ids // stride % size
        return pos_ids

    def _update_model_meta_prefilling(self, context: StepContext):
        """Compute M-RoPE position ids for a prefill step.

        Text tokens get ``position + mrope_delta`` on all three rows. The
        tokens of each image are replaced by grid ids offset from the image's
        first position, and every later token is shifted back so that the
        image advances the position by ``max(h, w)`` instead of its token
        count. The accumulated shift becomes the sequence's new ``mrope_delta``.
        """
        merge_size = self.visual.spatial_merge_size
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_multimodals = [None] * len(model_metas)
        position_ids = context.position_ids
        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
        mrope_position_ids = []
        new_model_metas = []
        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):
            images = []
            if input_mm is not None:
                images = input_mm.get('image', [])
            if model_meta is None or 'mrope_delta' not in model_meta:
                mrope_delta = 0
            else:
                mrope_delta = model_meta['mrope_delta']

            pos_start = pos_ids[0].item()
            mrope_pos_ids = pos_ids + mrope_delta
            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
            for img in images:
                grid_thw = img.meta['grid_thw'][0].tolist()
                _, h, w = grid_thw
                h //= merge_size
                w //= merge_size
                num_pad = img.end - img.start - max(h, w)
                mrope_delta -= num_pad
                fill_start = img.start - pos_start
                fill_end = img.end - pos_start
                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device, merge_size)
                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
                mrope_pos_ids[:, fill_end:] -= num_pad
                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids

            mrope_position_ids.append(mrope_pos_ids)
            new_model_metas.append(dict(mrope_delta=mrope_delta))

        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
        context.mrope_position_ids = mrope_position_ids

        return new_model_metas

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Update model meta (M-RoPE deltas) for the current step."""
        if context.is_decoding:
            return self._update_model_meta_decoding(context)
        else:
            return self._update_model_meta_prefilling(context)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Qwen2VLInputProcessor(BaseModelInputProcessor):
    """Qwen2 input processor.

    Turns raw per-image multimodal dicts into ``MultiModalData`` records that
    span each image's placeholder tokens.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = config

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Each entry of ``input_multimodals`` is read for ``pixel_values``,
        ``image_grid_thw``, ``offset`` (start of the image tokens),
        ``image_token_id`` and ``image_tokens`` (image token count, int or
        scalar tensor).

        Returns:
            PreprocessInputResult: ``input_ids`` unchanged and the images under
            ``input_multimodals['image']``. With no multimodal inputs, the given
            ``input_multimodals`` is passed through as-is.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # return the declared result type so callers can always read
            # ``.input_ids`` / ``.input_multimodals`` (a bare tuple cannot)
            return PreprocessInputResult(
                input_ids=input_ids,
                input_multimodals=input_multimodals,
            )

        input_imgs = []
        for input_mm in input_multimodals:
            pixel_values = input_mm['pixel_values']
            image_grid_thw = input_mm['image_grid_thw']
            start = input_mm['offset']
            image_token_id = input_mm['image_token_id']
            num_pad = input_mm['image_tokens']
            if isinstance(num_pad, torch.Tensor):
                num_pad = num_pad.item()

            mm_data = MultiModalData(data=pixel_values,
                                     start=start,
                                     end=start + num_pad,
                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
            input_imgs.append(mm_data)

        result = PreprocessInputResult(
            input_ids=input_ids,
            input_multimodals=dict(image=input_imgs),
        )
        return result


================================================
FILE: lmdeploy/pytorch/models/qwen3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class Qwen3Attention(nn.Module):
    """Qwen3 self-attention with per-head RMSNorm on queries and keys."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        n_q_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        # use the config's head_dim when present, else hidden_size // num_heads
        head_dim = getattr(config, 'head_dim', hidden_size // n_q_heads)
        replicate_kv = getattr(config, 'num_replicate_key_value_heads', 1)

        # packed qkv; bias follows config.attention_bias
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=n_q_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=config.attention_bias,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
            num_replicate_kv_heads=replicate_kv,
            prefix=add_prefix('qkv_proj', prefix),
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_q_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=getattr(config, 'sliding_window', None),
        )

        self.o_proj = build_o_proj(
            n_q_heads * head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('o_proj', prefix),
        )

        # per-head normalization of q and k
        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Project to q/k/v, normalize q/k, apply RoPE, attend, project out."""
        packed = self.qkv_proj(hidden_states).flatten(0, -2)
        query, key, value = self.qkv_proj.split_qkv(packed)

        # q/k RMSNorm happens before the rotary embedding
        query = self.q_norm(query)
        key = self.k_norm(key)

        cos, sin = rotary_pos_emb
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin, inplace=True)

        # presumably (k_cache, v_cache) or (k_cache, v_cache, k_scales_zeros, v_scales_zeros)
        has_kv_scales = len(past_key_value) != 2
        attn_out = self.attn_fwd(
            query,
            key,
            value,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_kv_scales else None,
            v_scales_zeros=past_key_value[3] if has_kv_scales else None,
            inplace=True,
        )
        attn_out = attn_out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(attn_out)


class Qwen3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiluAndMul, down projection."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        quant_config = getattr(config, 'quantization_config', None)
        hidden_size = config.hidden_size
        inter_size = config.intermediate_size

        # gate and up packed into one projection of width 2 * intermediate_size
        self.gate_up_proj = build_gateup_linear(
            hidden_size,
            [inter_size, inter_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            is_tp=True,
            prefix=add_prefix('gate_up_proj', prefix),
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(
            inter_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('down_proj', prefix),
        )

    def forward(self, x):
        """Apply gate/up, activation and down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen3DecoderLayer(nn.Module):
    """Pre-norm Qwen3 decoder layer (self-attention followed by MLP).

    The residual stream is carried between layers: the layer returns
    ``(mlp_output, residual)`` and the RMSNorms, when given a residual,
    return the normalized activations together with the updated residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        self.layer_idx = layer_idx
        quant_config = getattr(config, 'quantization_config', None)

        self.self_attn = Qwen3Attention(config, dtype=dtype, device=device, prefix=add_prefix('self_attn', prefix))
        self.mlp = Qwen3MLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))

        def make_norm(name: str) -> RMSNorm:
            # both layer norms share everything except their weight prefix
            return RMSNorm(
                config.hidden_size,
                config.rms_norm_eps,
                quant_config=quant_config,
                dtype=dtype,
                device=device,
                prefix=add_prefix(name, prefix),
            )

        self.input_layernorm = make_norm('input_layernorm')
        self.post_attention_layernorm = make_norm('post_attention_layernorm')

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Run attention then MLP; returns ``(hidden_states, residual)``."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # first layer: the embeddings start the residual stream
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Qwen3model(nn.Module):
    """Qwen3 decoder stack: token embedding, decoder layers and final norm.

    No lm head lives here; ``forward`` returns normalized hidden states.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        # One decoder layer per hidden layer, each under its own weight prefix.
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            layer_prefix = add_prefix(f'layers.{layer_idx}', prefix)
            decoder_layers.append(Qwen3DecoderLayer(config, layer_idx, dtype=dtype, device=device,
                                                    prefix=layer_prefix))
        self.layers = nn.ModuleList(decoder_layers)

        # final norm applied after the last decoder layer
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # rotary embedding shared by every layer
        self.rotary_emb = build_rotary_embedding_from_config(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed (unless embeddings are given), run all layers, then normalize.

        Args:
            input_ids: token ids; ignored when ``inputs_embeds`` is provided.
            position_ids: positions fed to the rotary embedding.
            past_key_values: per-layer cache entries, indexed by layer number.
            attn_metadata: attention metadata passed to every layer.
            inputs_embeds: precomputed input embeddings.

        Returns:
            torch.Tensor: hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # Only the first entry along dim 0 of cos/sin is used.
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # Fold the last residual into the final norm; the updated residual is unused.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Qwen3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen3 causal language model: decoder stack plus lm head.

    ``forward`` returns the final hidden states of ``self.model``; the lm head
    is built here but not applied inside ``forward``.
    """

    # fused parameter name -> checkpoint parameter names packed into it
    packed_modules_mapping = dict(
        qkv_proj=['q_proj', 'k_proj', 'v_proj'],
        gate_up_proj=['gate_proj', 'up_proj'],
    )

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.model = Qwen3model(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack and return its final hidden states.

        Extra keyword arguments are accepted and ignored.
        """
        return self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of the decoder stack."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build ``forward`` kwargs from the step context.

        When the context carries vision embeddings, they are written into the
        input embeddings at ``context.input_embedding_indexing`` along dim 1.
        Token embeddings are computed first if no ``inputs_embeds`` was given.
        """
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        has_vision = vision_embeddings is not None and len(vision_embeddings) > 0
        if has_vision:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, packing q/k/v and gate/up shards.

        Rotary caches are skipped, as is ``lm_head.weight`` when word
        embeddings are tied. Adapted from vllm.
        """
        # (fused param suffix, checkpoint suffix, shard id); the first match wins
        stacked_params_mapping = [
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]
        skipped_markers = ('rotary_emb.inv_freq', 'rotary_emb.cos_cached', 'rotary_emb.sin_cached')

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            shard = next((entry for entry in stacked_params_mapping if entry[1] in name), None)
            if shard is None:
                load_weight(params_dict[name], loaded_weight)
            else:
                fused_name, shard_name, shard_id = shard
                load_weight(params_dict[name.replace(shard_name, fused_name)], loaded_weight, shard_id=shard_id)


================================================
FILE: lmdeploy/pytorch/models/qwen3_5.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from functools import lru_cache
from typing import Any, Iterable, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

import lmdeploy.pytorch.nn.gated_delta as gated_delta_util
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, SiluAndMul
from lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight
from lmdeploy.vl.constants import Modality

from .patch import add_prefix
from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding
from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3_5VisionAttention
from .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5InputProcessor
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixinV1, vlm_model


class Qwen3_5VisionPatchEmbed(nn.Module):

    def __init__(self, config, dtype: torch.dtype | None = None, device: torch.device | None = None) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        kernel_size = (self.temporal_patch_size, self.patch_size, self.patch_size)
        self.proj = nn.Conv3d(self.in_channels,
                              self.embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=True,
                              dtype=dtype,
                              device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,
                                           self.patch_size)
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states


class Qwen3_5VisionMLP(nn.Module):
    """Feed-forward network of a Qwen3.5 vision block.

    ``linear_fc1`` (column-parallel) expands to ``config.intermediate_size``,
    the activation named by ``config.hidden_act`` is applied, and
    ``linear_fc2`` (row-parallel) projects back to ``config.hidden_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        from transformers.activations import ACT2FN
        hidden_dim = config.hidden_size
        intermediate_size = config.intermediate_size

        # keyword arguments shared by both projections
        linear_kwargs = dict(
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=getattr(config, 'quantization_config', None),
            is_tp=True,
        )

        # expansion projection (no gate)
        self.linear_fc1 = build_colwise_linear(hidden_dim, intermediate_size,
                                               prefix=add_prefix('linear_fc1', prefix), **linear_kwargs)
        self.act = ACT2FN[config.hidden_act]
        # contraction projection
        self.linear_fc2 = build_rowwise_linear(intermediate_size, hidden_dim,
                                               prefix=add_prefix('linear_fc2', prefix), **linear_kwargs)

    def forward(self, x):
        """Apply ``linear_fc1 -> act -> linear_fc2``."""
        expanded = self.linear_fc1(x)
        activated = self.act(expanded)
        return self.linear_fc2(activated)


class Qwen3_5VisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then MLP, each residual."""

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        self.layer_idx = layer_idx
        width = config.hidden_size
        self.norm1 = LayerNorm(width, eps=1e-6, dtype=dtype, device=device)
        self.norm2 = LayerNorm(width, eps=1e-6, dtype=dtype, device=device)

        self.attn = Qwen3_5VisionAttention(config, dtype=dtype, device=device, prefix=add_prefix('attn', prefix))
        self.mlp = Qwen3_5VisionMLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))

    def forward(self,
                hidden_states: torch.Tensor,
                cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor | None = None) -> torch.Tensor:
        """Return ``x + attn(norm1(x))`` followed by ``+ mlp(norm2(.))``."""
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class Qwen3_5VisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighbouring patches into one token.

    The features of each merge window are concatenated into a vector of
    ``hidden_size * spatial_merge_size**2`` values, then passed through
    ``linear_fc1 -> GELU -> linear_fc2``, ending at ``config.out_hidden_size``.
    With ``use_postshuffle_norm`` the LayerNorm runs on the concatenated vector;
    otherwise it runs per patch before concatenation.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 use_postshuffle_norm=False,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None) -> None:
        super().__init__()
        merged_dim = config.hidden_size * (config.spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.use_postshuffle_norm = use_postshuffle_norm

        norm_dim = merged_dim if use_postshuffle_norm else config.hidden_size
        self.norm = LayerNorm(norm_dim, eps=1e-6, dtype=dtype, device=device)

        proj_kwargs = dict(bias=True, dtype=dtype, device=device, is_tp=True)
        self.linear_fc1 = build_colwise_linear(merged_dim, merged_dim, **proj_kwargs)
        self.act_fn = nn.GELU()
        self.linear_fc2 = build_rowwise_linear(merged_dim, config.out_hidden_size, **proj_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, group into merge windows and project."""
        if self.use_postshuffle_norm:
            x = x.view(-1, self.hidden_size)
        merged = self.norm(x).view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))


@vlm_model
class Qwen3_5VisionModel(nn.Module):
    """qwen3.5 vision model.

    Patch-embeds the pixel input, adds the learned position embeddings passed
    in as ``pos_embeds``, runs ``config.depth`` vision blocks with 2D rotary
    position embeddings, and finally merges each
    ``spatial_merge_size x spatial_merge_size`` patch window into one token
    through ``self.merger``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = Qwen3_5VisionPatchEmbed(config=config, dtype=dtype, device=device)

        # Learned position table laid out as a square grid of
        # num_grid_per_side x num_grid_per_side entries (interpolated to the
        # actual image grid in fast_pos_embed_interpolate).
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        # Rotary dim is half the head dim: one half for row ids, one for column ids
        # (see rot_pos_emb, which concatenates both lookups).
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2, device=device)

        self.blocks = nn.ModuleList([
            Qwen3_5VisionBlock(config,
                               layer_idx,
                               dtype=dtype,
                               device=device,
                               prefix=add_prefix(f'blocks.{layer_idx}', prefix)) for layer_idx in range(config.depth)
        ])
        self.merger = Qwen3_5VisionPatchMerger(config=config, use_postshuffle_norm=False, dtype=dtype, device=device)

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return ``(h * w, 2)`` (row, column) ids for an ``h x w`` patch grid.

        Patches are ordered window by window: each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous, which
        matches the grouping later used by the patch merger. The reshapes require
        ``h`` and ``w`` to be multiples of ``spatial_merge_size``.

        Results are memoized per ``(h, w, spatial_merge_size)``; the cached
        tensor is shared between calls, so callers must not modify it in place.
        """
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size

        # row index of every patch, regrouped to (h_div, w_div, merge, merge) order
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        # column index of every patch, same regrouping
        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Rotary position embedding.

        Args:
            grid_thw: 2D tensor with one ``(t, h, w)`` row per image/video.

        Returns:
            One row per patch (all items concatenated, each frame repeated ``t``
            times): the row-id and column-id lookups of the rotary table,
            concatenated by ``flatten(1)``.
        """
        pos_ids = []

        for t, h, w in grid_thw:
            base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size)
            # every temporal frame shares the same spatial ids
            pos_ids.append(base if t == 1 else base.repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        # the table only needs to cover the largest row/column index seen
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)

        return rotary_pos_emb

    # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474
    def fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor:
        """Bilinearly resample the learned position grid to each ``(t, h, w)`` grid.

        The ``num_grid_per_side x num_grid_per_side`` table in ``self.pos_embed``
        is sampled at ``h x w`` evenly spaced points. The result is reordered into
        merge windows (same order as ``rot_pos_ids``) and repeated for each of
        the ``t`` frames. Outputs for all items are concatenated along dim 0.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim
        device = self.pos_embed.weight.device

        outputs = []
        for t, h, w in grid_thw:
            # fractional sample coordinates into the learned grid
            h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device)
            w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device)

            # the four neighbouring integer grid points (clamped at the border)
            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij')
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij')

            # original computation of weights
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # we reuse w11 here to avoid duplicate
            # dh_grid * dw_grid computation
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - w01

            # flat table indices of the (00, 01, 10, 11) corners, row-major
            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
            h_grid_idx = h_grid * num_grid_per_side

            indices = (h_grid_idx + w_grid).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device)

            # weighted sum of the four corner embeddings
            embeds = self.pos_embed(indices)
            embeds *= weights
            combined = embeds.sum(dim=0)

            # reorder into merge windows, then repeat for every temporal frame
            combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,
                pos_embeds: torch.Tensor) -> torch.Tensor:
        """forward.

        ``rotary_pos_emb`` and ``pos_embeds`` are precomputed by the caller —
        presumably via ``rot_pos_emb`` and ``fast_pos_embed_interpolate``; confirm
        against the input processor.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states + pos_embeds
        # Prepend a leading 0 to cu_seqlens (presumably cumulative per-sequence
        # patch counts — TODO confirm with caller).
        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)

        for _, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        hidden_states = self.merger(hidden_states)

        return hidden_states


class Qwen3_5MLP(nn.Module):
    """Gated SiLU MLP: ``down_proj(silu(gate) * up)`` with a fused gate/up projection."""

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int | None = None,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 is_tp: bool = True,
                 all_reduce: bool = True,
                 prefix: str = ''):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        # fall back to the config's width when no explicit size is given
        ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
        model_dim = config.hidden_size

        # fused gate and up projections, each ffn_dim wide
        self.gate_up_proj = build_merged_colwise_linear(
            model_dim,
            [ffn_dim, ffn_dim],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=is_tp,
            prefix=add_prefix('gate_up_proj', prefix),
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(
            ffn_dim,
            model_dim,
            bias=False,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=all_reduce,
            prefix=add_prefix('down_proj', prefix),
        )

    def forward(self, x, all_routed_experts: torch.Tensor | None = None):
        """Apply the gated MLP; ``all_routed_experts`` is accepted but unused."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen3_5GatedDeltaNet(nn.Module):
    """Gated DeltaNet linear-attention layer of Qwen3.5.

    One forward pass:

    1. project hidden states to packed q/k/v (``in_proj_qkv``) and to
       ``(z, b, a)`` (``in_proj_zba``);
    2. run packed q/k/v through a depthwise ``CausalConv1d`` whose state lives
       in the layer cache;
    3. apply ``GatedDelta`` with ``beta = sigmoid(b)`` and
       ``g = -exp(A_log) * softplus(a + dt_bias)``;
    4. normalize with the gated RMSNorm (gated by ``z``) and project out.

    Per-head parameters (``A_log``, ``dt_bias``) and q/k/v projections are
    sharded across tensor-parallel ranks by the custom weight loaders.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        prefix: str = '',
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        # number of value heads served by each key head
        self.kv_ratio = self.num_v_heads // self.num_k_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.layer_norm_epsilon = config.rms_norm_eps

        # depthwise (groups == channels) causal conv over the packed q/k/v channels
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = CausalConv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            kernel_size=self.conv_kernel_size,
            split=[self.key_dim, self.key_dim, self.value_dim],
            bias=False,
            groups=self.conv_dim,
            dtype=dtype,
            device=device,
        )

        # projection of the input hidden states
        projection_size_qkv = self.key_dim * 2 + self.value_dim
        self.in_proj_qkv = build_colwise_linear(self.hidden_size,
                                                projection_size_qkv,
                                                bias=False,
                                                dtype=dtype,
                                                device=device,
                                                is_tp=True)
        # q, k and v are sharded separately (see weight_loader_qkv)
        self.in_proj_qkv.weight.weight_loader = self.weight_loader_qkv
        self.in_proj_zba = build_merged_colwise_linear(
            self.hidden_size,
            [self.value_dim, self.num_v_heads, self.num_v_heads],
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=True,
            out_names=['z', 'b', 'a'],
        )

        # Per-head A_log / dt_bias placeholders; real values come from the checkpoint.
        self.make_params(self.num_v_heads, device=device)
        # lazily computed -exp(A_log); reset whenever A_log/dt_bias are (re)loaded
        self.A_log_exp = None

        self.norm = build_rmsnorm_gated(self.head_v_dim,
                                        eps=self.layer_norm_epsilon,
                                        activation=self.activation,
                                        dtype=dtype,
                                        device=device)
        self.out_proj = build_o_proj(self.value_dim,
                                     self.hidden_size,
                                     bias=False,
                                     dtype=dtype,
                                     device=device,
                                     is_tp=True)

        self.gated_delta = GatedDelta()

    def get_A_log_exp(self):
        """Return ``-exp(A_log)`` in float32, computing it on first use."""
        if self.A_log_exp is None:
            self.A_log_exp = -self.A_log.float().exp()

        return self.A_log_exp

    def make_params(self, num_v_heads: int, device: torch.device | None):
        """Register this rank's ``A_log`` and ``dt_bias`` parameters.

        Both have ``num_v_heads // tp`` entries. Their initial contents are
        placeholders that are expected to be overwritten by weight loading.
        """
        tp, _ = get_tp_world_rank()
        num_v_heads = num_v_heads // tp
        A = torch.empty(num_v_heads, device=device)
        dt_bias = torch.empty(num_v_heads, device=device)

        self.register_parameter('A_log', nn.Parameter(torch.log(A)))
        self.register_parameter('dt_bias', nn.Parameter(dt_bias))
        self.A_log.weight_loader = self.weight_loader_a_dt
        self.dt_bias.weight_loader = self.weight_loader_a_dt

    def weight_loader_qkv(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load this rank's shard of the packed q/k/v projection.

        The full weight is ``[q; k; v]`` along dim 0. Each of q, k and v is
        chunked across ranks on its own, so every rank keeps a ``[q_r; k_r; v_r]``
        layout that matches the conv ``split`` above.
        """
        tp, rank = get_tp_world_rank()
        q, k, v = loaded_weight.split([self.key_dim, self.key_dim, self.value_dim], dim=0)
        q = q.chunk(tp, dim=0)[rank]
        k = k.chunk(tp, dim=0)[rank]
        v = v.chunk(tp, dim=0)[rank]
        loaded_weight = torch.cat([q, k, v], dim=0)
        default_weight_loader(param, loaded_weight)

    def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load this rank's dim-0 shard of ``A_log`` or ``dt_bias``.

        Loading invalidates the cached ``-exp(A_log)``. Without this,
        ``get_A_log_exp`` would keep returning a value derived from previously
        loaded (or placeholder) weights after a reload.
        NOTE(review): a copy of the old value already captured by a CUDA graph is
        not refreshed by this.
        """
        tp, rank = get_tp_world_rank()
        loaded_weight = loaded_weight.chunk(tp, dim=0)[rank]
        default_weight_loader(param, loaded_weight)
        self.A_log_exp = None

    def fix_zba_ordering(self, mixed_zba: torch.Tensor):
        """Split the projected ``mixed_zba`` tensor into ``z``, ``b`` and ``a``.

        The last dim holds ``[z | b | a]`` with widths
        ``(n_local_v_heads * head_v_dim, n_local_v_heads, n_local_v_heads)``.

        Returns:
            tuple: ``z`` reshaped to ``(..., n_local_v_heads, head_v_dim)``, then
            ``b`` and ``a`` of shape ``(..., n_local_v_heads)``.
        """
        # widths per local key head; the local key-head count is derived from the input width
        per_k_head = [self.head_v_dim * self.kv_ratio, self.kv_ratio, self.kv_ratio]
        num_local_k_heads = mixed_zba.size(-1) // sum(per_k_head)
        split_sizes = [num_local_k_heads * width for width in per_k_head]
        z, b, a = torch.split(mixed_zba, split_sizes, dim=-1)
        # [..., n_v_heads * head_v_dim] -> [..., n_v_heads, head_v_dim]
        z = z.unflatten(-1, (-1, self.head_v_dim))
        return z, b, a

    def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):
        """Load conv and recurrent states from cache."""
        return gated_delta_util.load_state(past_key_value=past_key_value, gated_delta_meta=gated_delta_meta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor, torch.Tensor],
        gated_delta_meta: GatedDeltaMeta,
    ):
        """Run the gated delta layer and return ``out_proj`` of its output.

        The output keeps the leading dims of ``hidden_states``.
        """
        conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta)

        # input projections
        mixed_qkv = self.in_proj_qkv(hidden_states)
        z, b, a = self.fix_zba_ordering(self.in_proj_zba(hidden_states))

        mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta)

        # Infer the TP degree from the local width of the packed projection.
        tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // tp,
                self.key_dim // tp,
                self.value_dim // tp,
            ],
            dim=-1,
        )
        query = query.unflatten(-1, (-1, self.head_k_dim))
        key = key.unflatten(-1, (-1, self.head_k_dim))
        value = value.unflatten(-1, (-1, self.head_v_dim))

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = self.get_A_log_exp() * F.softplus(a.float() + self.dt_bias)
        if self.kv_ratio > 1:
            # expand key heads so every value head has a matching q/k head
            query = query.repeat_interleave(self.kv_ratio, dim=-2)
            key = key.repeat_interleave(self.kv_ratio, dim=-2)

        core_attn_out, recurrent_state = self.gated_delta(
            query,
            key,
            value,
            g=g,
            beta=beta,
            recurrent_state=recurrent_state,
            gated_delta_meta=gated_delta_meta,
        )

        # gated RMSNorm works on flat (tokens * heads, head_v_dim) rows
        leading_shape = z.shape[:-2]
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        # restore the leading dims of the input and merge heads
        core_attn_out = core_attn_out.reshape(*leading_shape, -1)

        output = self.out_proj(core_attn_out)
        return output


class Qwen3_5Attention(nn.Module):
    """Full-attention layer of Qwen3.5 with a sigmoid output gate.

    The packed query projection is twice the usual width: for every head it
    holds the query followed by an equally wide gate. After attention, the
    output is multiplied by ``sigmoid(gate)`` before ``o_proj``. Queries and
    keys are RMS-normalized per head before the rotary embedding.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
        self.head_dim = head_dim
        self.layer_idx = layer_idx
        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)

        # q heads doubled: each head carries [query | gate]
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=num_heads * 2,
            num_kv_heads=num_key_value_heads,
            head_size=head_dim,
            bias=config.attention_bias,
            quant_config=quantization_config,
            num_replicate_kv_heads=num_replicate_kv_heads,
            dtype=dtype,
            device=device,
            prefix=add_prefix('qkv_proj', prefix),
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            num_heads,
            head_dim,
            num_kv_heads=num_key_value_heads,
            v_head_size=head_dim,
        )

        self.o_proj = build_o_proj(
            num_heads * head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('o_proj', prefix),
        )

        # per-head q / k norms share everything except their prefix
        norm_kwargs = dict(quant_config=quantization_config, dtype=dtype, device=device)
        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, prefix=add_prefix('q_norm', prefix), **norm_kwargs)
        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, prefix=add_prefix('k_norm', prefix), **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: Any,
    ):
        """Gated attention forward.

        ``past_key_value`` holds the k and v caches, optionally followed by
        k and v scale/zero tensors when it has more than two entries.
        """
        # flatten all leading dims into one token axis
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        query_and_gate, key_states, value_states = self.qkv_proj.split_qkv(qkv)
        per_head = query_and_gate.view(*query_and_gate.shape[:-2], -1, 2 * self.head_dim)
        query_states, gate = per_head.chunk(2, dim=-1)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            inplace=True,
        )

        has_scales_zeros = len(past_key_value) != 2
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales_zeros else None,
            v_scales_zeros=past_key_value[3] if has_scales_zeros else None,
            inplace=True,
        )

        leading_shape = hidden_states.shape[:-1]
        gated = attn_output.reshape(*leading_shape, -1) * gate.reshape(*leading_shape, -1).sigmoid()
        return self.o_proj(gated)


class Qwen3_5DecoderLayer(nn.Module):
    """Qwen3.5 hybrid decoder layer.

    The token mixer is selected per layer from ``config.layer_types[layer_idx]``:
    ``'linear_attention'`` builds a ``Qwen3_5GatedDeltaNet`` (``self.linear_attn``)
    and ``'full_attention'`` builds a ``Qwen3_5Attention`` (``self.self_attn``).
    Either is followed by a dense ``Qwen3_5MLP``, with an RMSNorm in front of each
    sub-block.

    Raises:
        ValueError: if the configured layer type is neither of the two above.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == 'linear_attention':
            self.linear_attn = Qwen3_5GatedDeltaNet(config,
                                                    layer_idx,
                                                    dtype=dtype,
                                                    device=device,
                                                    prefix=add_prefix('linear_attn', prefix))
        elif self.layer_type == 'full_attention':
            self.self_attn = Qwen3_5Attention(config,
                                              layer_idx,
                                              dtype=dtype,
                                              device=device,
                                              prefix=add_prefix('self_attn', prefix))
        else:
            # Fail fast: otherwise the layer would have no token mixer and `forward`
            # would silently skip attention for it.
            raise ValueError(f'Unsupported layer type {self.layer_type!r} for layer {layer_idx}, '
                             "expected 'linear_attention' or 'full_attention'.")

        # build MLP
        self.mlp = Qwen3_5MLP(config,
                              intermediate_size=config.intermediate_size,
                              dtype=dtype,
                              device=device,
                              prefix=add_prefix('mlp', prefix))

        # build input layer norm
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            prefix=add_prefix('input_layernorm', prefix),
        )

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: List[torch.FloatTensor],
        residual: torch.Tensor | None,
        attn_metadata: Any,
        gated_delta_meta: GatedDeltaMeta,
        all_routed_experts: torch.Tensor | None = None,
    ):
        """Run one decoder layer.

        Args:
            hidden_states: layer input.
            rotary_pos_emb: ``(cos, sin)``; only used by full-attention layers.
            past_key_value: this layer's cache entry (state-cache pair for
                linear-attention layers, KV cache for full-attention layers).
            residual: running residual, or ``None`` for the first layer.
            attn_metadata: only used by full-attention layers.
            gated_delta_meta: only used by linear-attention layers.
            all_routed_experts: forwarded to the MLP.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # norm called with the residual returns (normed, updated residual)
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        if self.layer_type == 'linear_attention':
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                past_key_value=past_key_value,
                gated_delta_meta=gated_delta_meta,
            )
        elif self.layer_type == 'full_attention':
            hidden_states = self.self_attn(
                hidden_states=hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                attn_metadata=attn_metadata,
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states, all_routed_experts=all_routed_experts)

        return hidden_states, residual


class Qwen3_5TextRotaryEmbedding(nn.Module):
    """Interleaved multimodal rotary embedding (MRoPE) for the Qwen3.5 text model.

    Builds ``cos``/``sin`` tables from three-axis (T, H, W) position ids and
    merges the per-axis frequencies with ``apply_interleaved_mrope``.
    """
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PretrainedConfig, device=None):
        super().__init__()
        rope_scaling = get_rope_parameters(config)
        assert rope_scaling is not None, 'RoPE scaling parameters must be provided in the config for Qwen3.5 models.'
        self.rope_type = rope_scaling.get('rope_type', 'default')

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = (self.compute_default_rope_parameters
                             if self.rope_type == 'default' else ROPE_INIT_FUNCTIONS[self.rope_type])

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

        # per-axis (T, H, W) frequency section sizes
        self.mrope_section = rope_scaling.get('mrope_section', [11, 11, 10])

    @staticmethod
    def compute_default_rope_parameters(
        config: PretrainedConfig | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple['torch.Tensor', float]:
        """Compute inverse frequencies for plain (unscaled) RoPE.

        Args:
            config: model configuration; supplies ``rope_theta``, the optional
                ``partial_rotary_factor`` and the head dimension.
            device: device on which the frequencies are created.
            seq_len: accepted for interface compatibility; unused here.

        Returns:
            ``(inv_freq, attention_factor)`` where
            ``inv_freq = 1 / rope_theta ** (arange(0, rotary_dim, 2) / rotary_dim)``
            and ``attention_factor`` is always ``1.0``.
        """
        rope_parameters = get_rope_parameters(config)
        base = rope_parameters['rope_theta']
        partial_rotary_factor = rope_parameters.get('partial_rotary_factor', 1.0)
        head_dim = getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads
        rotary_dim = int(head_dim * partial_rotary_factor)

        even_indices = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (base**(even_indices / rotary_dim))
        attention_factor = 1.0  # plain RoPE applies no post-scaling to cos/sin
        return inv_freq, attention_factor

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Merge per-axis rotary frequencies into one interleaved layout.

        Reorganizes the layout from chunked [TTT...HHH...WWW] to interleaved
        [THWTHW...TT]. Starting from the T-axis frequencies, every third slot
        from offset 1 (below ``3 * mrope_section[1]``) is taken from the H axis
        and every third slot from offset 2 (below ``3 * mrope_section[2]``) from
        the W axis.

        Args:
            freqs: per-axis frequencies indexed ``freqs[axis, ...]`` with axis
                0/1/2 = T/H/W; in ``forward`` this is
                ``(3, bs, seq_len, len(inv_freq))``.
            mrope_section: per-axis section sizes, length 3.

        Returns:
            ``freqs[0]`` with the H/W slots written into it in place.
        """
        merged = freqs[0]  # view of the T axis, updated in place below
        for axis in (1, 2):  # H, W
            slots = slice(axis, mrope_section[axis] * 3, 3)
            merged[..., slots] = freqs[axis, ..., slots]
        return merged

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype``.

        ``position_ids`` is either ``(bs, positions)`` (shared by T/H/W) or
        ``(3, bs, positions)`` (separate per axis, as for Qwen3-VL grids). The
        outputs have shape ``(bs, positions, 2 * len(inv_freq))``.
        """
        if position_ids.ndim == 2:
            # same positions for all three axes
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        batch = position_ids.shape[1]
        inv_freq = self.inv_freq[None, None, :, None].float().expand(3, batch, -1, 1)
        positions = position_ids[:, :, None, :].float()  # (3, bs, 1, positions)

        # (3, bs, len(inv_freq), positions) -> (3, bs, positions, len(inv_freq))
        freqs = (inv_freq @ positions).transpose(2, 3)
        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Qwen3_5TextModel(nn.Module):
    """Qwen3.5 text decoder: token embedding, hybrid decoder layers, final norm."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # one decoder layer per configured layer (linear or full attention per config.layer_types)
        self.layers = nn.ModuleList(
            Qwen3_5DecoderLayer(config,
                                layer_idx,
                                dtype=dtype,
                                device=device,
                                prefix=add_prefix(f'layers.{layer_idx}', prefix))
            for layer_idx in range(config.num_hidden_layers))

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any,
        state_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        mrope_position_ids: torch.Tensor | None = None,
        all_routed_experts: torch.Tensor | None = None,
    ):
        """Run the decoder stack (rewrite of LlamaModel.forward).

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``mrope_position_ids`` (three-axis ids), when given, replace
        ``position_ids`` for the rotary tables. ``past_key_values`` holds one
        cache entry per decoder layer.

        Returns:
            Final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        if mrope_position_ids is not None:
            rope_positions = mrope_position_ids.unsqueeze(1)  # add a batch axis of size 1
        else:
            rope_positions = position_ids
        cos, sin = self.rotary_emb(hidden_states, rope_positions)
        # keep batch row 0 (inputs are presumably packed into a single row)
        rotary_pos_emb = (cos[0], sin[0])

        # shared metadata for all linear-attention layers of this step
        gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids,
                                          attn_metadata)

        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
                gated_delta_meta=gated_delta_meta,
                all_routed_experts=all_routed_experts,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Qwen3_5Model(nn.Module):
    """Qwen3.5 backbone: vision tower (``visual``) plus text decoder (``language_model``)."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()

        self.visual = Qwen3_5VisionModel(config.vision_config,
                                         dtype=dtype,
                                         device=device,
                                         prefix=add_prefix('visual', prefix))
        self.language_model = Qwen3_5TextModel(config.text_config,
                                               dtype=dtype,
                                               device=device,
                                               prefix=add_prefix('language_model', prefix))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any,
        state_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        mrope_position_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        vis_cu_seqlens: torch.Tensor | None = None,
        vis_pos_emb: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        pos_embeds: torch.Tensor | None = None,
        grid_thw: torch.Tensor | None = None,
        all_routed_experts: torch.Tensor | None = None,
    ):
        """Embed the inputs and run the text decoder; returns hidden states.

        When ``inputs_embeds`` is not given, tokens are embedded. If
        ``pixel_values`` are also given, the vision tower's outputs are
        scattered into the positions selected by ``image_mask``. Precomputed
        ``inputs_embeds`` are passed through untouched.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                target_dtype = inputs_embeds.dtype
                rotary = (vis_pos_emb[0].to(target_dtype), vis_pos_emb[1].to(target_dtype))
                image_embeds = self.visual(pixel_values.to(target_dtype),
                                           cu_seqlens=vis_cu_seqlens,
                                           rotary_pos_emb=rotary,
                                           pos_embeds=pos_embeds)

                # per-image chunk sizes after spatial merging; torch.split raises if they do not add up
                merge_area = self.visual.spatial_merge_size**2
                chunk_sizes = (grid_thw.prod(-1) // merge_area).tolist()
                chunks = torch.split(image_embeds, chunk_sizes)
                image_embeds = torch.cat(chunks, dim=0).to(inputs_embeds.device, target_dtype)

                # overwrite the placeholder-token embeddings with vision features
                scatter_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_embeds)

        return self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            state_ids=state_ids,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            all_routed_experts=all_routed_experts,
        )

    def get_input_embeddings(self):
        """Return the text model's token embedding module."""
        return self.language_model.get_input_embeddings()


class Qwen3_5ForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen3.5 vision-language model for causal generation.

    Wraps ``Qwen3_5Model`` (vision tower + hybrid text decoder) and an LM head.
    It also implements the engine hooks: input preparation, weight loading,
    CUDA-graph buffers and mrope model-meta bookkeeping.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # build preprocessor
        self.input_processor = Qwen3_5InputProcessor(self.config)

        # build model
        self.model = Qwen3_5Model(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
        # build lm_head
        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
                                          config.text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)
        # dense model: there are no routed experts to report
        self.enable_return_routed_experts = False

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any,
        state_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        mrope_position_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        vis_cu_seqlens: torch.Tensor | None = None,
        vis_pos_emb: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        pos_embeds: torch.Tensor | None = None,
        grid_thw: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the backbone.

        Returns:
            The final hidden states. When ``enable_return_routed_experts`` is
            set, returns a dict with ``hidden_states`` and
            ``all_routed_experts`` instead.
        """
        all_routed_experts = None
        if self.enable_return_routed_experts:
            config = self.config.text_config
            num_tokens = input_ids.size(1)
            all_routed_experts = position_ids.new_empty(
                (num_tokens, config.num_hidden_layers, config.num_experts_per_tok), dtype=torch.uint16)

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            state_ids=state_ids,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_mask=image_mask,
            pos_embeds=pos_embeds,
            grid_thw=grid_thw,
            all_routed_experts=all_routed_experts,
        )
        if all_routed_experts is None:
            return hidden_states
        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor | None = None,
        context: StepContext | None = None,
    ):
        """Assemble ``forward`` kwargs from the step context.

        Linear-attention state caches and full-attention KV caches are
        interleaved into one per-layer list following
        ``text_config.layer_types``. Vision inputs (pixel values, grid sizes,
        rotary tables, cumulative patch counts) are gathered from the batch's
        multimodal data.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # per-layer caches: state-cache pairs for linear attention, KV caches for full attention
        state_caches = [cache.transpose(0, 1) for cache in context.state_caches]
        state_caches = list(zip(state_caches[0], state_caches[1]))
        past_key_values = list(past_key_values)
        new_past_key_values = []
        for layer_type in self.config.text_config.layer_types:
            if layer_type == 'linear_attention':
                new_past_key_values.append(state_caches.pop(0))
            elif layer_type == 'full_attention':
                new_past_key_values.append(past_key_values.pop(0))

        # vlm inputs
        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_mask = None
        grid_thw = None
        pos_embeds = None
        if context.input_multimodals is not None:
            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]
            # flatten batch
            mm_inputs = [item for sublist in mm_inputs for item in sublist]

            if len(mm_inputs) > 0:
                modality = mm_inputs[0].modality
                pixel_values = torch.cat([inp.data for inp in mm_inputs])

                # NOTE(review): the placeholder token is picked from the first item's modality only;
                # a batch mixing images and videos would get an incomplete mask - confirm this cannot happen.
                image_token_id = mm_inputs[0].meta.get('image_token_id')
                video_token_id = mm_inputs[0].meta.get('video_token_id')
                mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id
                image_mask = (input_ids == mm_token_id)

                grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()
                vis_pos_emb = self.model.visual.rot_pos_emb(grid_thw)
                pos_embeds = self.model.visual.fast_pos_embed_interpolate(grid_thw)
                # cumulative per-frame patch counts (h * w repeated t times per item)
                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                         grid_thw[:, 0]).to(pixel_values.device)
                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                vis_pos_emb = vis_pos_emb.repeat(1, 2)
                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

        mrope_position_ids = getattr(context, 'mrope_position_ids', None)

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=new_past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            state_ids=context.state_offsets,
            # vl inputs
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_mask=image_mask,
            grid_thw=grid_thw,
            pos_embeds=pos_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        Separate q/k/v, gate/up and z/b/a projections are loaded as shards of
        their fused parameters. The fused vision ``qkv`` weights are split and
        loaded shard by shard. RMSNorm weights listed in ``rms_norm_keys`` are
        shifted by ``+1`` when loaded. MTP weights, rotary caches, tied
        ``lm_head`` weights and layers beyond ``text_config.num_hidden_layers``
        are skipped.
        """
        import re

        layer_id_pattern = re.compile(r'\.layers\.(\d+)\.')
        num_text_layers = self.config.text_config.num_hidden_layers

        def _skip_layer(name):
            """Skip layers beyond ``num_hidden_layers``.

            The layer count may be reduced in the config, e.g. to debug the
            model with fewer GPUs.
            """
            if '.layers.' not in name:
                return False
            match = layer_id_pattern.search(name)
            if match is None:
                # '.layers.' without a numeric index: not a decoder layer weight
                return False
            return int(match.group(1)) >= num_text_layers

        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
            ('.in_proj_zba', '.in_proj_z', 'z'),
            ('.in_proj_zba', '.in_proj_b', 'b'),
            ('.in_proj_zba', '.in_proj_a', 'a'),
        ]

        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:

            if _skip_layer(name):
                continue

            if 'mtp.' in name:
                continue
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    # vl attention
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    for rms_norm_key in rms_norm_keys:
                        if rms_norm_key in name and 'weight' in name:
                            loaded_weight = loaded_weight + 1
                            break
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs.

        Adds a ``(3, max_tokens)`` mrope position buffer (only when mrope ids
        are given) and a ``state_ids`` buffer filled with ``-1``.
        """

        max_batchs = graph_meta.max_batchs
        device = graph_meta.device
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)

        state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)
        input_buffers['state_ids'] = state_ids
        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs):
        """Fill cudagraph buffers from forward inputs.

        ``state_ids`` slots past the real batch are reset to ``-1``.
        """
        input_buffers = graph_meta.input_buffers
        new_inputs = super().fill_buffers_cudagraph(graph_meta, *args, **kwargs)
        state_ids = kwargs['state_ids']
        input_buffers['state_ids'].fill_(-1)
        input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)
        new_inputs['state_ids'] = input_buffers['state_ids']

        input_ids = kwargs.get('input_ids')
        num_tokens = input_ids.size(-1)
        new_batch_size = graph_meta.max_batchs

        is_decoding = graph_meta.is_decoding
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
            if is_decoding:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
            else:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']

        return new_inputs

    def _get_model_metas(self, context: StepContext):
        """Return one meta dict per sequence; missing entries default to ``{'mrope_delta': 0}``."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # a fresh dict per sequence: `[d] * n` would alias one dict across the whole batch
            return [dict(mrope_delta=0) for _ in range(batch_size)]
        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]

    def _update_model_meta_decoding(self, context: StepContext):
        """Set ``context.mrope_position_ids = position_ids + mrope_delta`` (same value on all 3 axes)."""
        model_metas = self._get_model_metas(context)
        position_ids = context.position_ids

        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
        mrope_deltas_cpu = torch.tensor(mrope_deltas, device='cpu')
        if (mrope_deltas_cpu == mrope_deltas_cpu[0]).all():
            mrope_deltas = position_ids.new_full((len(mrope_deltas), ), mrope_deltas[0])
        else:
            mrope_deltas = position_ids.new_tensor(mrope_deltas)
        mrope_position_ids = position_ids + mrope_deltas[None]
        mrope_position_ids = mrope_position_ids.expand(3, -1)

        context.mrope_position_ids = mrope_position_ids
        return model_metas

    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
        """Return ``(3, t * h' * w')`` (t, row, col) ids, with ``h' = h // 2`` and ``w' = w // 2``."""
        t, h, w = grid_thw
        h //= 2
        w //= 2
        stride = torch.tensor([h * w, w, 1], device=device)[:, None]
        size = torch.tensor([t, h, w], device=device)[:, None]
        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
        pos_ids = pos_ids // stride % size
        return pos_ids

    def _update_model_meta_prefilling(self, context: StepContext):
        """Compute mrope position ids for a prefill step and each sequence's new mrope delta.

        Text tokens keep ``position + mrope_delta`` on all three axes. Image
        tokens get (t, row, col) grid ids, and later tokens are shifted back by
        the number of saved positions. Videos are treated frame by frame like
        images.
        """
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_multimodals = [None] * len(model_metas)
        position_ids = context.position_ids
        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
        mrope_position_ids = []
        new_model_metas = []
        seq_offset = 0  # start of the current sequence within the packed token row
        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):
            seq_start = seq_offset
            seq_offset += pos_ids.size(0)

            mm_data_list = []
            if input_mm is not None:
                mm_data_list.extend(input_mm.get('mm_data', []))

            if model_meta is None or 'mrope_delta' not in model_meta:
                mrope_delta = 0
            else:
                mrope_delta = model_meta['mrope_delta']

            pos_start = pos_ids[0].item()
            mrope_pos_ids = pos_ids + mrope_delta
            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()

            for mm_data in mm_data_list:
                if mm_data.modality == Modality.IMAGE:
                    grid_thw = mm_data.meta['grid_thw'][0].tolist()
                    _, h, w = grid_thw
                    h //= 2
                    w //= 2
                    # the image spans max(h, w) positions instead of its token count
                    num_pad = mm_data.end - mm_data.start - max(h, w)
                    mrope_delta -= num_pad
                    fill_start = mm_data.start - pos_start
                    fill_end = mm_data.end - pos_start
                    img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)
                    img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
                    mrope_pos_ids[:, fill_end:] -= num_pad
                    mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids
                elif mm_data.modality == Modality.VIDEO:
                    video_token_id = self.config.video_token_id
                    grid_thw = mm_data.meta['grid_thw']

                    grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0)
                    grid_thw[:, 0] = 1

                    position_ids_list = []
                    # only this sequence's tokens: the packed row holds every sequence of the batch
                    input_tokens = context.input_ids[0, seq_start:seq_offset].tolist()

                    st = 0
                    # treat each frame separately as a single image
                    for video_idx in range(grid_thw.shape[0]):
                        # text before video. e.g. <0.3 seconds><|vision_start|> ...
                        ed_video = input_tokens.index(video_token_id, st)
                        ed = ed_video
                        text_len = ed - st
                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0
                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx
                        position_ids_list.append(text_pos_ids)

                        # video frame.  ... <|video_end|>
                        t, h, w = (
                            grid_thw[video_idx][0],
                            grid_thw[video_idx][1] // 2,
                            grid_thw[video_idx][2] // 2,
                        )
                        video_pos_ids = self._get_multimodal_pos_ids(grid_thw[video_idx], pos_ids.device)
                        position_ids_list.append(video_pos_ids + text_len + st_idx)

                        st = ed + t * h * w

                    # text after video, <|vision_end|> ...
                    if st < len(input_tokens):
                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0
                        text_len = len(input_tokens) - st
                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx
                        position_ids_list.append(text_pos_ids)

                    mrope_pos_ids = torch.cat(position_ids_list, dim=1).reshape(3, -1)
                    # store a python int (like the image path) rather than a device tensor
                    mrope_delta = int(mrope_pos_ids.max()) + 1 - pos_ids.size(0)
                    mrope_pos_ids += pos_start  # add back the original position offset

            mrope_position_ids.append(mrope_pos_ids)
            new_model_metas.append(dict(mrope_delta=mrope_delta))

        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
        context.mrope_position_ids = mrope_position_ids

        return new_model_metas

    def update_model_metas(self, past_key_values: List[List[torch.Tensor]], inputs_embeds: torch.Tensor | None,
                           context: StepContext):
        """Update model meta (mrope deltas) and set ``context.mrope_position_ids``."""
        if context.is_decoding:
            return self._update_model_meta_decoding(context)
        else:
            return self._update_model_meta_prefilling(context)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


================================================
FILE: lmdeploy/pytorch/models/qwen3_5_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, Iterable, List, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.nn import RMSNorm
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix, get_build_model_context
from .qwen3_5 import (Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5ForConditionalGeneration, Qwen3_5GatedDeltaNet,
                      Qwen3_5MLP, Qwen3_5Model, Qwen3_5TextModel, Qwen3_5TextRotaryEmbedding)
from .qwen3_5 import Qwen3_5VisionModel as Qwen3_5MoeVisionModel
from .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5MoeInputProcessor


class Qwen3_5MoeTopKRouter(nn.Module):

    def __init__(self, config, dtype: torch.dtype | None = None, device: torch.device | None = None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, dtype=dtype, device=device))

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
        router_top_value = router_top_value.to(router_logits.dtype)
        router_scores = router_top_value
        return router_logits, router_scores, router_indices


class Qwen3_5MoeSparseMoeBlock(nn.Module):
    """Sparse MoE block: routed experts plus a sigmoid-gated shared expert.

    Neither the fused experts nor the shared expert reduce their own outputs
    (both are built with ``all_reduce=False``). When ``dp == 1`` and
    ``moe_tp > 1``, one all-reduce is issued on the combined output at the
    end of ``forward``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        # router: per-token top-k expert ids and (already renormalized) weights
        self.gate = Qwen3_5MoeTopKRouter(config, dtype=dtype, device=device)

        # routed experts; renormalization is done by the router, reduction in forward()
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            all_reduce=False,
            layer_idx=layer_idx,
            prefix=add_prefix('experts', prefix),
        )

        # shared expert and its scalar sigmoid gate
        self.shared_expert = Qwen3_5MLP(
            config=config,
            intermediate_size=config.shared_expert_intermediate_size,
            dtype=dtype,
            device=device,
            is_tp=True,
            all_reduce=False,
            prefix=add_prefix('shared_expert', prefix),
        )
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)

        # a single deferred all-reduce is needed only for pure tensor-parallel MoE
        dist_cfg = get_dist_manager().current_context().dist_config
        self._all_reduce = dist_cfg.dp == 1 and dist_cfg.moe_tp > 1

    def forward(self, hidden_states: torch.Tensor, all_routed_experts: torch.Tensor | None = None):
        """Mix routed and shared expert outputs.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` activations.
            all_routed_experts: optional buffer; when given, this layer's
                top-k expert ids are written to ``[:, layer_idx, :]``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden)``.
        """
        bsz, seq_len, hidden_dim = hidden_states.shape
        flat_states = hidden_states.reshape(-1, hidden_dim)

        _, topk_weights, topk_ids = self.gate(flat_states)
        if all_routed_experts is not None:
            all_routed_experts[:, self.layer_idx, :] = topk_ids

        mixed = self.experts(flat_states, topk_weights, topk_ids)

        shared_out = self.shared_expert(flat_states)
        shared_out = self.shared_expert_gate(flat_states).sigmoid() * shared_out
        mixed += shared_out

        mixed = mixed.reshape(bsz, seq_len, -1)
        if self._all_reduce:
            dist.all_reduce(mixed)
        return mixed


class Qwen3_5MoeDecoderLayer(Qwen3_5DecoderLayer):
    """Qwen3.5 decoder layer whose MLP is a sparse MoE block.

    Only construction is overridden; ``nn.Module.__init__`` is called
    directly, bypassing ``Qwen3_5DecoderLayer.__init__``. The token mixer is
    chosen per layer from ``config.layer_types``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        prefix: str = '',
    ):
        nn.Module.__init__(self)
        self.layer_idx = layer_idx
        quantization_config = getattr(config, 'quantization_config', None)

        # build attention layer: gated delta-net (linear) or full attention
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == 'linear_attention':
            self.linear_attn = Qwen3_5GatedDeltaNet(config,
                                                    layer_idx,
                                                    dtype=dtype,
                                                    device=device,
                                                    prefix=add_prefix('linear_attn', prefix))
        elif self.layer_type == 'full_attention':
            self.self_attn = Qwen3_5Attention(config,
                                              layer_idx,
                                              dtype=dtype,
                                              device=device,
                                              prefix=add_prefix('self_attn', prefix))
        else:
            # fail at build time instead of with an AttributeError in forward
            raise ValueError(f'Unsupported layer type {self.layer_type!r} for layer {layer_idx}; '
                             "expected 'linear_attention' or 'full_attention'.")

        # build MLP
        self.mlp = Qwen3_5MoeSparseMoeBlock(config,
                                            layer_idx,
                                            dtype=dtype,
                                            device=device,
                                            prefix=add_prefix('mlp', prefix))

        # build input layer norm
        self.input_layernorm = RMSNorm(
            config.hidden_size,
            config.rms_norm_eps,
            quant_config=quantization_config,
            dtype=dtype,
            device=device,
            prefix=add_prefix('input_layernorm', prefix),
        )

        # build attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                dtype=dtype,
                                                device=device,
                                                prefix=add_prefix('post_attention_layernorm', prefix))


class Qwen3_5MoeTextModel(Qwen3_5TextModel):
    """Text backbone of Qwen3.5-MoE: embeddings, MoE decoder layers, final norm
    and rotary embedding.

    Only construction is overridden; ``nn.Module.__init__`` is called
    directly, bypassing ``Qwen3_5TextModel.__init__``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        nn.Module.__init__(self)
        self.config = config
        # not every text config defines pad_token_id; fall back to no padding index
        self.padding_idx = getattr(config, 'pad_token_id', None)
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        # build all decode layers (one per config.num_hidden_layers)
        self.layers = nn.ModuleList([
            Qwen3_5MoeDecoderLayer(config,
                                   layer_idx,
                                   dtype=dtype,
                                   device=device,
                                   prefix=add_prefix(f'layers.{layer_idx}', prefix))
            for layer_idx in range(self.config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size,
                            config.rms_norm_eps,
                            dtype=dtype,
                            device=device,
                            prefix=add_prefix('norm', prefix))

        # build rotary embedding
        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)


class Qwen3_5MoeModel(Qwen3_5Model):
    """Qwen3.5-MoE multimodal backbone: a vision tower plus the MoE text model.

    Only construction is overridden; ``nn.Module.__init__`` is called
    directly, bypassing ``Qwen3_5Model.__init__``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        nn.Module.__init__(self)

        visual_prefix = add_prefix('visual', prefix)
        text_prefix = add_prefix('language_model', prefix)

        # vision tower first, then the language model (keeps parameter registration order)
        self.visual = Qwen3_5MoeVisionModel(config.vision_config, dtype=dtype, device=device, prefix=visual_prefix)
        self.language_model = Qwen3_5MoeTextModel(config.text_config, dtype=dtype, device=device, prefix=text_prefix)


class Qwen3_5MoeForConditionalGeneration(Qwen3_5ForConditionalGeneration):
    """Qwen3.5-MoE multimodal causal LM.

    Builds the MoE variant of the Qwen3.5 backbone and loads checkpoints whose
    expert weights are stored either fused across experts
    (``experts.gate_up_proj`` / ``experts.down_proj``) or as one tensor per
    expert (``experts.<id>.gate_proj`` / ``up_proj`` / ``down_proj``). No
    ``forward`` is defined here; it comes from the parent class.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None,
                 prefix: str = ''):
        nn.Module.__init__(self)
        self.config = config
        self.ctx_mgr = ctx_mgr

        # build preprocessor
        self.input_processor = Qwen3_5MoeInputProcessor(self.config)

        # build model
        self.model = Qwen3_5MoeModel(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
        # build lm_head
        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
                                          config.text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)
        # for router replay
        bm_ctx = get_build_model_context()
        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one per-expert checkpoint tensor (``...experts.<id>.<proj>...``).

        ``expert_params_mapping`` holds ``(param_name, weight_name, expert_id,
        shard_id)`` tuples. The first entry whose ``weight_name`` occurs in
        ``name`` picks the fused parameter (``weight_name`` replaced by
        ``param_name``), the expert slot and the shard. A name that matches no
        entry is loaded into the parameter of the same name.
        """
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            return
        load_weight(params_dict[name], loaded_weight)

    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Load a checkpoint tensor that stacks all experts along dim 0.

        ``experts.gate_up_proj``: each expert slice is split in half along its
        first dim into the ``gate`` and ``up`` shards of ``experts.gate_up.weight``.
        ``experts.down_proj``: each expert slice goes to the ``down`` shard of
        ``experts.down.weight``. Names containing neither are ignored.
        """
        num_experts = self.config.text_config.num_experts
        fused_gateup_name = 'gate_up_proj'
        fused_down_name = 'down_proj'
        if fused_gateup_name in name:
            # the target parameter does not depend on the expert: look it up once
            param = params_dict[name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up.weight')]
            for expert_id in range(num_experts):
                w1, w3 = loaded_weight[expert_id].chunk(2, 0)
                load_weight(param, w1, expert_id=expert_id, shard_id='gate')
                load_weight(param, w3, expert_id=expert_id, shard_id='up')

        elif fused_down_name in name:
            param = params_dict[name.replace(f'experts.{fused_down_name}', 'experts.down.weight')]
            for expert_id in range(num_experts):
                load_weight(param, loaded_weight[expert_id], expert_id=expert_id, shard_id='down')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        The following are skipped:

        - decoder-layer tensors whose layer index is
          ``>= text_config.num_hidden_layers``;
        - ``mtp.`` tensors;
        - rotary caches;
        - ``lm_head.weight`` when embeddings are tied.

        Routed-expert tensors are dispatched by layout: per-expert
        (``experts.<id>.``) or fused across experts. Stacked q/k/v, gate/up
        and ``in_proj_z/b/a`` tensors are loaded into their packed parameters.
        Vision ``.qkv.`` tensors are split with the parameter's
        ``weight_spliter``. Weights of the listed RMSNorm modules are loaded
        as ``loaded_weight + 1``; presumably the checkpoint stores them as an
        offset from one — confirm against the HF implementation.
        """
        import re

        num_text_layers = self.config.text_config.num_hidden_layers
        layer_id_pattern = re.compile(r'\.layers\.(\d+)\.')
        # per-expert checkpoint names look like ``...mlp.experts.<id>.<proj>...``
        per_expert_pattern = re.compile(r'\.experts\.\d+\.')

        def _skip_layer(name):
            """Return True for decoder-layer tensors beyond ``num_hidden_layers``.

            Lets a checkpoint load into a model built with fewer layers (e.g.
            for debugging on fewer GPUs).
            """
            if '.layers.' not in name:
                return False
            match = layer_id_pattern.search(name)
            if match is None:
                return False
            return int(match.group(1)) >= num_text_layers

        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
            ('.in_proj_zba', '.in_proj_z', 'z'),
            ('.in_proj_zba', '.in_proj_b', 'b'),
            ('.in_proj_zba', '.in_proj_a', 'a'),
        ]

        # expert map for per-expert checkpoint layouts
        num_experts = self.config.text_config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:

            if _skip_layer(name):
                continue

            if 'mtp.' in name:
                continue
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            if '.experts' in name and '.shared_expert' not in name:
                if per_expert_pattern.search(name):
                    self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping)
                else:
                    self._load_weight_fused_experts(name, loaded_weight, params_dict)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    # vl attention
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    for rms_norm_key in rms_norm_keys:
                        if rms_norm_key in name and 'weight' in name:
                            loaded_weight = loaded_weight + 1
                            break
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/qwen3_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.eplb import EPLBManager
from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix, get_build_model_context
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class Qwen3MoeAttention(nn.Module):
    """Qwen3-MoE self-attention.

    Pipeline: fused QKV projection, per-head RMSNorm on queries and keys,
    rotary embedding, cached attention via ``Attention``, then a row-wise
    (tensor-parallel) output projection.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_q_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        model_dim = config.hidden_size
        head_dim = getattr(config, 'head_dim', model_dim // n_q_heads)
        kv_replicas = getattr(config, 'num_replicate_key_value_heads', 1)
        use_bias = config.attention_bias  # Qwen3 checkpoints normally set this to False

        # fused q/k/v projection
        self.qkv_proj = build_qkv_proj(
            model_dim,
            num_q_heads=n_q_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=use_bias,
            quant_config=quant_cfg,
            num_replicate_kv_heads=kv_replicas,
            dtype=dtype,
            device=device,
            prefix=add_prefix('qkv_proj', prefix),
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_q_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=getattr(config, 'sliding_window', None),
        )

        # output projection back to the model dimension
        self.o_proj = build_rowwise_linear(
            n_q_heads * head_dim,
            model_dim,
            bias=use_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('o_proj', prefix),
        )

        # per-head RMSNorm applied to queries and keys before RoPE
        self.q_norm = RMSNorm(
            head_dim,
            config.rms_norm_eps,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            prefix=add_prefix('q_norm', prefix),
        )
        self.k_norm = RMSNorm(
            head_dim,
            config.rms_norm_eps,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            prefix=add_prefix('k_norm', prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Tuple[torch.Tensor] | None = None,
        attn_metadata: Any = None,
    ):
        """Compute attention output for ``hidden_states``.

        ``rotary_pos_emb`` is a ``(cos, sin)`` pair. ``past_key_value[0]`` and
        ``[1]`` are the key/value caches; when it holds more than two entries,
        entries 2 and 3 are forwarded as ``k_scales_zeros`` / ``v_scales_zeros``.
        The result has the leading dims of ``hidden_states``.
        """
        qkv = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        k_cache, v_cache = past_key_value[0], past_key_value[1]
        has_scales = len(past_key_value) != 2
        k_sz = past_key_value[2] if has_scales else None
        v_sz = past_key_value[3] if has_scales else None

        out = self.attn_fwd(
            q,
            k,
            v,
            k_cache,
            v_cache,
            attn_metadata,
            k_scales_zeros=k_sz,
            v_scales_zeros=v_sz,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class Qwen3MoeMLP(nn.Module):
    """Dense SwiGLU MLP: merged gate/up projection, SiLU-and-multiply, down projection.

    ``intermediate_size`` defaults to ``config.intermediate_size``.
    ``is_tp`` and ``all_reduce`` are forwarded to the linear builders.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True,
                 prefix: str = ''):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        inter_dim = config.intermediate_size if intermediate_size is None else intermediate_size
        model_dim = config.hidden_size

        self.gate_up_proj = build_merged_colwise_linear(
            model_dim,
            [inter_dim, inter_dim],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=is_tp,
            prefix=add_prefix('gate_up_proj', prefix),
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_rowwise_linear(
            inter_dim,
            model_dim,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=all_reduce,
            prefix=add_prefix('down_proj', prefix),
        )

    def forward(self, x):
        """Apply down(silu(gate(x)) * up(x))."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Routed-expert MoE block (no shared expert).

    A replicated linear gate produces router logits, ``SoftmaxTopK`` selects
    ``top_k`` experts per token, and a fused MoE applies them. With EPLB
    enabled, logical expert ids are remapped to physical ids before dispatch.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.renormalize = self.norm_topk_prob

        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )
        self.softmax_topk = SoftmaxTopK(
            self.top_k,
            n_groups=getattr(config, 'router_n_groups', -1),
        )

        dist_ctx = get_dist_manager().current_context()
        # Resolved once at build time so forward() always agrees with whether
        # ``eplb_dispatch_info`` exists (and avoids a context lookup per step).
        self.enable_eplb = dist_ctx.dist_config.enable_eplb
        if self.enable_eplb:
            self.eplb_dispatch_info = EPLBManager.get_dispatch_info(
                ep_rank=dist_ctx.ep_rank,
                layer_idx=layer_idx,
            )
            # experts are built over physical (possibly replicated) slots
            self.num_experts = EPLBManager.num_physical_experts()
        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            all_reduce=True,
            layer_idx=layer_idx,
            prefix=add_prefix('experts', prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        all_routed_experts: torch.Tensor = None,
    ):
        """Route ``(batch, seq_len, hidden)`` activations through the experts.

        When ``all_routed_experts`` is given, this layer's logical top-k
        expert ids (before any EPLB remapping) are written to
        ``[:, layer_idx, :]``. Returns a tensor of shape
        ``(batch, seq_len, hidden)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        topk_weights, topk_ids = self.softmax_topk(router_logits)
        if all_routed_experts is not None:
            all_routed_experts[:, self.layer_idx, :] = topk_ids
        if self.enable_eplb:
            topk_ids = EPLBManager.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)
        out_states = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
        )

        out_states = out_states.reshape(batch_size, sequence_length, -1)
        return out_states


class Qwen3MoeDecoderLayer(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The MLP is a sparse MoE block when the layer is not listed in
    ``config.mlp_only_layers``, ``config.num_experts > 0`` and
    ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``.
    Otherwise it is a dense ``Qwen3MoeMLP``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        dtype: torch.dtype = None,
        device: torch.device = None,
        prefix: str = '',
    ):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)
        mlp_prefix = add_prefix('mlp', prefix)

        self.self_attn = Qwen3MoeAttention(config, dtype=dtype, device=device, prefix=add_prefix('self_attn', prefix))

        # evaluated left to right, so the modulo is reached only for MoE-eligible layers
        use_moe = (layer_idx not in config.mlp_only_layers and config.num_experts > 0
                   and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device, prefix=mlp_prefix)
        else:
            self.mlp = Qwen3MoeMLP(config,
                                   intermediate_size=config.intermediate_size,
                                   dtype=dtype,
                                   device=device,
                                   prefix=mlp_prefix)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device,
                                       prefix=add_prefix('input_layernorm', prefix))

        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                config.rms_norm_eps,
                                                dtype=dtype,
                                                device=device,
                                                prefix=add_prefix('post_attention_layernorm', prefix))

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: List[torch.FloatTensor] | None,
        residual: torch.Tensor | None = None,
        attn_metadata: Any = None,
        all_routed_experts: torch.Tensor = None,
    ):
        """Run one decoder layer and return ``(hidden_states, residual)``.

        On the first layer (``residual is None``) the input itself becomes the
        residual. Later layers use the norm's fused ``(x, residual)`` form.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=normed,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed, all_routed_experts=all_routed_experts)
        return mlp_out, residual


class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: embeddings, decoder layers, final norm and RoPE.

    When EPLB is enabled, the global EPLB metadata is initialized before the
    layers are built.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        self.padding_idx = getattr(config, 'pad_token_id', None)
        self.vocab_size = config.vocab_size
        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        # must run before the layers are built: MoE blocks query the EPLB manager
        if get_dist_manager().current_context().dist_config.enable_eplb:
            ep_world_size, _ = get_ep_world_rank()
            EPLBManager.init_global_eplb_metadata(
                ep_size=ep_world_size,
                num_routed_experts=config.num_experts,
                num_hidden_layers=config.num_hidden_layers,
            )

        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(
                Qwen3MoeDecoderLayer(config,
                                     layer_idx,
                                     dtype=dtype,
                                     device=device,
                                     prefix=add_prefix(f'layers.{layer_idx}', prefix)))
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = RMSNorm(config.hidden_size,
                            config.rms_norm_eps,
                            dtype=dtype,
                            device=device,
                            prefix=add_prefix('norm', prefix))

        self.rotary_emb = build_rotary_embedding_from_config(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: List[torch.FloatTensor] | None = None,
        attn_metadata: Any = None,
        inputs_embeds: torch.FloatTensor | None = None,
        all_routed_experts: torch.Tensor = None,
    ):
        """Embed (unless ``inputs_embeds`` is given), run all layers, apply the final norm.

        ``past_key_values[i]`` is passed to layer ``i``. Returns the
        normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
                all_routed_experts=all_routed_experts,
            )

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class Qwen3MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen3-MoE causal LM: ``Qwen3MoeModel`` decoder stack plus ``lm_head``.

    When the build-model context enables ``enable_return_routed_experts``,
    ``forward`` also allocates a buffer that the decoder layers fill with the
    expert ids chosen per token and layer (router replay).
    """

    # checkpoint projections that are fused into one packed module here
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        prefix: str = '',
    ):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # build model
        self.model = Qwen3MoeModel(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
        # build lm_head
        self.lm_head = self.build_lm_head(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=dtype,
            device=device,
        )
        # for router replay
        bm_ctx = get_build_model_context()
        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the decoder stack.

        Returns:
            The final hidden states. When router replay is enabled, it returns
            ``dict(hidden_states=..., all_routed_experts=...)`` instead, where
            ``all_routed_experts`` has shape
            ``(num_tokens, num_hidden_layers, num_experts_per_tok)`` and dtype
            ``torch.uint16``.
        """
        # router replay
        all_routed_experts = None
        if self.enable_return_routed_experts:
            # tokens are laid out along dim 1 of the inputs
            if inputs_embeds is not None:
                num_tokens = inputs_embeds.size(1)
            else:
                num_tokens = input_ids.size(1)
            all_routed_experts = position_ids.new_empty(
                (num_tokens, self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.uint16)

        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            all_routed_experts=all_routed_experts,
        )
        if all_routed_experts is None:
            return hidden_states
        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor | None = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments of ``forward`` from a step context.

        If the context carries input embeddings (e.g. vision features), they are
        scattered into the token embeddings at ``input_embedding_indexing``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one expert weight, either per-expert or pre-fused over all experts."""
        # load fused weights
        if any(k in name for k in ('fused_w1w3', 'fused_w2')):
            return self._load_weight_fused_experts(name, loaded_weight, params_dict)

        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Split a weight stacked over all experts along dim 0 and load each expert's slice.

        ``fused_w1w3`` holds, per expert, a contiguous chunk whose first half is
        the gate (w1) rows and whose second half is the up (w3) rows.
        ``fused_w2`` holds one contiguous down-projection chunk per expert.
        """
        num_experts = self.config.num_experts
        fused_gateup_name = 'fused_w1w3'
        fused_down_name = 'fused_w2'
        if fused_gateup_name in name:
            chunk_size = loaded_weight.shape[0] // num_experts
            half_size = chunk_size // 2
            # the destination parameter is the same for every expert: resolve it once
            param = params_dict[name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up')]
            for expert_id in range(num_experts):
                start = chunk_size * expert_id
                w1 = loaded_weight.narrow(dim=0, start=start, length=half_size)
                w3 = loaded_weight.narrow(dim=0, start=start + half_size, length=half_size)
                load_weight(param, w1, expert_id=expert_id, shard_id='gate')
                load_weight(param, w3, expert_id=expert_id, shard_id='up')

        elif fused_down_name in name:
            chunk_size = loaded_weight.shape[0] // num_experts
            param = params_dict[name.replace(f'experts.{fused_down_name}', 'experts.down')]
            for expert_id in range(num_experts):
                w2 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size)
                load_weight(param, w2, expert_id=expert_id, shard_id='down')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into the packed/fused parameters."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # expert map: per-expert checkpoint names -> fused expert parameters
        num_experts = self.config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # rotary caches are recomputed, not loaded
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/qwen3_next.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

import lmdeploy.pytorch.distributed as dist
import lmdeploy.pytorch.nn.gated_delta as gated_delta_util
from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight

from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class Qwen3NextGatedDeltaNet(nn.Module):
    """Gated DeltaNet (linear attention) layer of Qwen3-Next.

    Forward pipeline: fused input projections (q/k/v/z and b/a), a depthwise
    causal conv1d over the concatenated q/k/v channels, the gated delta rule
    recurrence (``GatedDelta``), a gated RMSNorm driven by ``z``, and the
    output projection. Conv and recurrent states are read from
    ``past_key_value`` via ``gated_delta_util.load_state``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        # value heads per key head; q/k are repeated by this factor in forward
        self.kv_ratio = self.num_v_heads // self.num_k_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.layer_norm_epsilon = config.rms_norm_eps

        # depthwise (groups == channels) causal conv over the [q, k, v] channels
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = CausalConv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            kernel_size=self.conv_kernel_size,
            split=[self.key_dim, self.key_dim, self.value_dim],
            bias=False,
            groups=self.conv_dim,
            dtype=dtype,
            device=device,
        )

        # projection of the input hidden states: q, k, v, z in one matmul and b, a in another
        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = build_colwise_linear(self.hidden_size,
                                                 projection_size_qkvz,
                                                 bias=False,
                                                 dtype=dtype,
                                                 device=device,
                                                 is_tp=True)
        self.in_proj_ba = build_colwise_linear(self.hidden_size,
                                               projection_size_ba,
                                               bias=False,
                                               dtype=dtype,
                                               device=device,
                                               is_tp=True)

        # per-head decay (A_log) and time-step bias (dt_bias) for this TP rank's heads;
        # real values are copied in by weight_loader_a_dt
        self.make_params(self.num_v_heads, device=device)
        # cache for -exp(A_log) in float32, filled lazily by get_A_log_exp
        self.A_log_exp = None

        self.norm = build_rmsnorm_gated(self.head_v_dim,
                                        eps=self.layer_norm_epsilon,
                                        activation=self.activation,
                                        dtype=dtype,
                                        device=device)
        self.out_proj = build_o_proj(self.value_dim,
                                     self.hidden_size,
                                     bias=False,
                                     dtype=dtype,
                                     device=device,
                                     is_tp=True)

        self.gated_delta = GatedDelta()

    def get_A_log_exp(self):
        """Return ``-exp(A_log)`` computed in float32; computed on the first call and cached.

        NOTE(review): the cache is never reset, so if ``A_log`` is reloaded after
        the first forward the cached value goes stale. Confirm whether weight
        updates after warmup must be supported for this layer.
        """
        if self.A_log_exp is None:
            self.A_log_exp = -self.A_log.float().exp()

        return self.A_log_exp

    def make_params(self, num_v_heads: int, device: torch.device | None):
        """Register ``A_log`` and ``dt_bias`` sized for this rank's ``num_v_heads // tp`` heads.

        Both start from uninitialized memory (``torch.empty``), so their initial
        contents are placeholders. They get the loader ``weight_loader_a_dt``,
        which is expected to overwrite them during weight loading.
        """
        tp, _ = get_tp_world_rank()
        num_v_heads = num_v_heads // tp
        A = torch.empty(num_v_heads, device=device)
        dt_bias = torch.empty(num_v_heads, device=device)

        self.register_parameter('A_log', nn.Parameter(torch.log(A)))
        self.register_parameter('dt_bias', nn.Parameter(dt_bias))
        self.A_log.weight_loader = self.weight_loader_a_dt
        self.dt_bias.weight_loader = self.weight_loader_a_dt

    def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``A_log``/``dt_bias``.

        The full checkpoint tensor is split evenly along dim 0 over the TP ranks.
        """
        tp, rank = get_tp_world_rank()
        loaded_weight = loaded_weight.chunk(tp, dim=0)[rank]
        default_weight_loader(param, loaded_weight)

    def fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor):
        """Derives `query`, `key` and `value` tensors from `mixed_qkvz` and
        `mixed_ba`.

        ``mixed_qkvz`` is grouped per key head as
        ``[q (head_k_dim), k (head_k_dim), v (kv_ratio*head_v_dim), z (kv_ratio*head_v_dim)]``.
        It is regrouped into ``mixed_qkv = [all q | all k | all v]`` on the last
        dim, which is the channel order the conv expects.
        ``mixed_ba`` is grouped per key head as ``[b (kv_ratio), a (kv_ratio)]``.

        Returns:
            ``(mixed_qkv, z, b, a)``. ``z`` has shape ``(..., heads, head_v_dim)``,
            ``b`` is already passed through sigmoid, and ``a`` is cast to float32.
        """
        # qkvz
        split_arg_list_qkvz = [
            self.head_k_dim * 2,
            (self.kv_ratio * self.head_v_dim),
            (self.kv_ratio * self.head_v_dim),
        ]
        mixed_qkvz = mixed_qkvz.unflatten(-1, (-1, sum(split_arg_list_qkvz)))
        qk, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)
        # (..., groups, 2, hk) -> (..., 2, groups, hk) -> [q of all groups, k of all groups]
        qk = qk.unflatten(-1, (2, self.head_k_dim))
        qk = qk.transpose(-3, -2).flatten(-3, -1)
        value = value.flatten(-2, -1)
        mixed_qkv = torch.cat((qk, value), dim=-1)
        # [..., ng, np/ng * hn] -> [..., np, hn]
        z = z.reshape(*z.shape[:-2], -1, self.head_v_dim)

        # chunk_ba
        mixed_ba = mixed_ba.unflatten(-1, (-1, 2 * self.kv_ratio))
        b, a = mixed_ba.chunk(2, -1)
        # do sigmoid and float here to prevent contiguous kernel
        b = b.sigmoid().flatten(-2, -1)
        a = a.float().flatten(-2, -1)
        return mixed_qkv, z, b, a

    def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):
        """Load states from cache."""
        return gated_delta_util.load_state(past_key_value=past_key_value, gated_delta_meta=gated_delta_meta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor, torch.Tensor],
        gated_delta_meta: GatedDeltaMeta,
    ):
        """Apply the gated delta-rule layer and return the output projection.

        The conv and recurrent states are obtained through ``_load_state``. The
        updated states returned by ``conv1d``/``gated_delta`` are bound to local
        names only; whether they are written back in place is up to those
        helpers (not visible here).
        """

        # load states
        conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta)

        # inputs proj
        projected_states_qkvz = self.in_proj_qkvz(hidden_states)
        projected_states_ba = self.in_proj_ba(hidden_states)
        mixed_qkv, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)

        mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta)

        # infer the TP degree from how much of the full q/k/v width this rank holds
        tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // tp,
                self.key_dim // tp,
                self.value_dim // tp,
            ],
            dim=-1,
        )
        query = query.unflatten(-1, (-1, self.head_k_dim))
        key = key.unflatten(-1, (-1, self.head_k_dim))
        value = value.unflatten(-1, (-1, self.head_v_dim))

        beta = b
        # get_A_log_exp works in float32 (in fp16, exp(A_log) could overflow); `a` is already float32
        g = self.get_A_log_exp() * F.softplus(a + self.dt_bias)
        if self.kv_ratio > 1:
            # expand key heads so every value head has its own q/k head
            query = query.repeat_interleave(self.kv_ratio, dim=-2)
            key = key.repeat_interleave(self.kv_ratio, dim=-2)

        core_attn_out, recurrent_state = self.gated_delta(
            query,
            key,
            value,
            g=g,
            beta=beta,
            recurrent_state=recurrent_state,
            gated_delta_meta=gated_delta_meta,
        )

        # gated norm runs on a 2-D (rows, last_dim) view; then restore z's shape and merge heads
        z_shape_og = z.shape
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)

        output = self.out_proj(core_attn_out)
        return output


class Qwen3NextAttention(nn.Module):
    """Gated full-attention layer of Qwen3-Next.

    The query projection is built with twice the usual number of heads. Each
    head's ``2 * head_dim`` slice is split into the query and an output gate,
    and the attention output is multiplied by ``sigmoid(gate)`` before
    ``o_proj``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden_size // n_heads)
        self.head_dim = head_dim
        replicate_kv = getattr(config, 'num_replicate_key_value_heads', 1)

        # packed [query|gate], key, value projection (query heads doubled to carry the gate)
        self.qkv_proj = build_qkv_proj(
            hidden_size,
            num_q_heads=n_heads * 2,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=config.attention_bias,
            quant_config=quant_cfg,
            num_replicate_kv_heads=replicate_kv,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
        )

        self.o_proj = build_o_proj(n_heads * head_dim,
                                   hidden_size,
                                   bias=config.attention_bias,
                                   quant_config=quant_cfg,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

        # per-head RMSNorm applied to queries and keys
        norm_kwargs = dict(quant_config=quant_cfg, dtype=dtype, device=device)
        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, **norm_kwargs)
        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run gated self-attention and project back to the hidden size.

        ``past_key_value`` is indexed as ``(k_cache, v_cache)`` or
        ``(k_cache, v_cache, k_scales_zeros, v_scales_zeros)``.
        """
        # project, then collapse all leading dims into one token dim
        packed = self.qkv_proj(hidden_states).flatten(0, -2)
        query_and_gate, key_states, value_states = self.qkv_proj.split_qkv(packed)
        # each doubled head holds [query | gate] side by side
        query_and_gate = query_and_gate.view(*query_and_gate.shape[:-2], -1, 2 * self.head_dim)
        query_states, gate = query_and_gate.chunk(2, dim=-1)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = rotary_pos_emb
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin, inplace=True)

        has_scales = len(past_key_value) != 2
        attn_output = self.attn_fwd(
            query_states,
            key_states,
            value_states,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )

        token_shape = hidden_states.shape[:-1]
        attn_output = attn_output.reshape(*token_shape, -1)
        gate = gate.reshape(*token_shape, -1)
        return self.o_proj(attn_output * gate.sigmoid())


class Qwen3NextMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection.

    ``intermediate_size`` defaults to ``config.intermediate_size``. ``is_tp``
    and ``all_reduce`` are forwarded to the linear builders, e.g. so the MoE
    block's shared expert can defer its reduction.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int = None,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = True,
                 all_reduce: bool = True):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        inner_dim = config.intermediate_size if intermediate_size is None else intermediate_size
        model_dim = config.hidden_size

        self.gate_up_proj = build_merged_colwise_linear(
            model_dim,
            [inner_dim, inner_dim],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=is_tp,
        )
        self.act_fn = SiluAndMul(inplace=True)
        self.down_proj = build_rowwise_linear(
            inner_dim,
            model_dim,
            bias=False,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            all_reduce=all_reduce,
        )

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class Qwen3NextSparseMoeBlock(nn.Module):
    """Mixture-of-experts block with a sigmoid-gated shared expert.

    Routed experts and the shared expert are built with their own reduction
    disabled. Their summed output is all-reduced once here when
    ``dp == 1`` and ``moe_tp > 1``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        self.layer_idx = layer_idx
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.renormalize = self.norm_topk_prob

        # router: replicated (non-TP) projection to one logit per expert
        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )
        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=getattr(config, 'router_n_groups', -1))

        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            all_reduce=False,
            layer_idx=layer_idx,
        )

        self.shared_expert = Qwen3NextMLP(
            config=config,
            intermediate_size=config.shared_expert_intermediate_size,
            dtype=dtype,
            device=device,
            is_tp=True,
            all_reduce=False,
        )
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)

        # decide once whether forward must reduce across ranks
        dist_config = get_dist_manager().current_context().dist_config
        dp = dist_config.dp
        moe_tp = dist_config.moe_tp
        self._all_reduce = dp == 1 and moe_tp > 1

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens of a ``(batch, seq, hidden)`` input and add the gated shared expert."""
        bsz, seq_len, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        topk_weights, topk_ids = self.softmax_topk(self.gate(tokens))
        routed = self.experts(tokens, topk_weights, topk_ids)

        shared = self.shared_expert(tokens)
        shared = self.shared_expert_gate(tokens).sigmoid() * shared

        routed += shared
        result = routed.reshape(bsz, seq_len, -1)

        if self._all_reduce:
            dist.all_reduce(result)
        return result


class Qwen3NextDecoderLayer(nn.Module):
    """One Qwen3-Next decoder layer.

    The token mixer is chosen by ``config.layer_types[layer_idx]``: a gated
    DeltaNet for ``'linear_attention'`` or gated full attention for
    ``'full_attention'``. The FFN is a sparse MoE block on MoE layers and a
    dense MLP otherwise. Pre-norm with a running residual stream.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        # token mixer
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == 'linear_attention':
            self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx, dtype=dtype, device=device)
        elif self.layer_type == 'full_attention':
            self.self_attn = Qwen3NextAttention(config, dtype=dtype, device=device)

        # feed-forward: MoE unless the layer is forced dense or not on the sparse step
        use_moe = (layer_idx not in config.mlp_only_layers and config.num_experts > 0
                   and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = Qwen3NextSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       config.rms_norm_eps,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor],
        attn_metadata: Any,
        gated_delta_meta: GatedDeltaMeta,
    ):
        """Return ``(hidden_states, residual)`` after the token mixer and the FFN."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        kind = self.layer_type
        if kind == 'linear_attention':
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                past_key_value=past_key_value,
                gated_delta_meta=gated_delta_meta,
            )
        elif kind == 'full_attention':
            hidden_states = self.self_attn(
                hidden_states=hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_value,
                attn_metadata=attn_metadata,
            )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual


class Qwen3NextModel(nn.Module):
    """Qwen3-Next decoder stack.

    Consists of the token embedding, hybrid decoder layers (linear or full
    attention, per ``config.layer_types``) and the final RMSNorm.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = build_embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=dtype,
            device=device,
        )

        # build all decoder layers
        self.layers = nn.ModuleList([
            Qwen3NextDecoderLayer(config, layer_idx, dtype=dtype, device=device)
            for layer_idx in range(self.config.num_hidden_layers)
        ])

        # build norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

        # build rotary embedding on the same device as the other submodules
        self.rotary_emb = build_rotary_embedding_from_config(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        past_key_values: List[torch.FloatTensor],
        attn_metadata: Any,
        state_ids: torch.Tensor,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run all decoder layers and return the final normalized hidden states.

        ``past_key_values[idx]`` is handed to layer ``idx``. It is the KV cache
        for full-attention layers and the conv/recurrent state for
        linear-attention layers. ``state_ids`` is forwarded to
        ``GatedDeltaMeta``.
        """

        # token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # rotary embedding
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos, sin = cos[0], sin[0]
        rotary_pos_emb = (cos, sin)

        # step metadata shared by every gated-delta (linear attention) layer
        gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids,
                                          attn_metadata)

        # decoding
        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[idx],
                residual=residual,
                attn_metadata=attn_metadata,
                gated_delta_meta=gated_delta_meta,
            )

        # norm
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class Qwen3NextForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen3-Next causal LM: a ``Qwen3NextModel`` backbone plus an LM head.

    Besides the forward pass, this class:
      * builds the per-layer cache list expected by the backbone (linear
        attention state caches interleaved with full-attention KV caches,
        ordered by ``config.layer_types``);
      * manages the extra ``state_ids`` CUDA-graph buffer;
      * loads HF checkpoint weights.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = Qwen3NextModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        state_ids: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return its final hidden states.

        No LM-head projection happens here; ``self.lm_head`` is not applied in
        this method. Extra keyword arguments are accepted and ignored.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            state_ids=state_ids,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from the step context.

        The returned ``past_key_values`` has one entry per ``config.layer_types``
        item. A ``'linear_attention'`` layer gets the next
        ``(state_caches[0][i], state_caches[1][i])`` pair; each state cache is
        sliced after swapping its first two dims. A ``'full_attention'`` layer
        gets the next element of the incoming ``past_key_values``. Vision
        embeddings in the context, if any, are written into ``inputs_embeds``.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # make past_key_values: consume both cache sources in order with
        # iterators (avoids the quadratic cost of repeated list.pop(0)).
        state_caches = [cache.transpose(0, 1) for cache in context.state_caches]
        linear_caches = iter(zip(state_caches[0], state_caches[1]))
        full_caches = iter(past_key_values)
        new_past_key_values = []
        for layer_type in self.config.layer_types:
            if layer_type == 'linear_attention':
                new_past_key_values.append(next(linear_caches))
            elif layer_type == 'full_attention':
                new_past_key_values.append(next(full_caches))

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=new_past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            state_ids=context.state_offsets,
        )

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs.

        Adds a ``state_ids`` buffer of length ``graph_meta.max_batchs``,
        initialised to -1, on top of the buffers built by the mixin.
        """
        max_batchs = graph_meta.max_batchs
        device = graph_meta.device

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)
        input_buffers['state_ids'] = state_ids

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs.

        Resets the ``state_ids`` buffer to -1, copies the current
        ``state_ids`` into its prefix and passes the buffer on to forward.
        """
        input_buffers = graph_meta.input_buffers

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        state_ids = kwargs['state_ids']
        input_buffers['state_ids'].fill_(-1)
        input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)
        new_inputs['state_ids'] = input_buffers['state_ids']

        return new_inputs

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one MoE expert weight into its fused parameter.

        The first ``expert_params_mapping`` entry whose checkpoint name occurs
        in ``name`` decides the target parameter, expert id and shard. If no
        entry matches, the weight is loaded directly under ``name``.
        """
        # load fused weights
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF checkpoint weights into this model.

        Skipped weights:
          * weights of layers with index >= ``config.num_hidden_layers``;
          * ``mtp.`` weights and rotary-embedding caches;
          * ``lm_head.weight`` when embeddings are tied.

        Routing:
          * stacked q/k/v and gate/up projections go to their fused parameters;
          * per-expert MoE weights go to the fused expert parameters;
          * weights matched by ``rms_norm_keys`` are shifted by +1 before loading.
        """
        import re

        # Compiled once instead of re-scanning a fresh pattern for every weight.
        layer_id_pattern = re.compile(r'\.layers\.(\d+)\.')
        num_hidden_layers = self.config.num_hidden_layers

        def _skip_layer(name: str) -> bool:
            """Return True for weights of layers beyond ``num_hidden_layers``.

            The layer count may be reduced to debug the model with fewer GPUs.
            Names without a ``.layers.<idx>.`` component are never skipped;
            previously such names raised IndexError.
            """
            match = layer_id_pattern.search(name)
            if match is None:
                return False
            return int(match.group(1)) >= num_hidden_layers

        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # expert map
        num_experts = self.config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        # Weights of these norms get +1 before loading. The checkpoint presumably
        # stores them as offsets from 1 (zero-centered) — TODO confirm.
        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:

            if _skip_layer(name):
                continue

            if 'mtp.' in name:
                continue
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name and '.shared_expert' not in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    for rms_norm_key in rms_norm_keys:
                        if rms_norm_key in name and 'weight' in name:
                            loaded_weight = loaded_weight + 1
                            break
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/qwen3_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from functools import lru_cache
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalData
from lmdeploy.pytorch.nn import LayerNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from lmdeploy.vl.constants import Modality

from .patch import add_prefix
from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding
from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention
from .qwen3 import Qwen3model
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
from .utils.model import DeployModelMixinV1, vlm_model


class Qwen3VLTextRotaryEmbedding(nn.Module):
    """Interleaved multimodal rotary embedding (MRoPE) for the Qwen3VL text model.

    Builds ``cos``/``sin`` tables from (T, H, W) position ids. The frequencies of
    the three axes are interleaved according to ``self.mrope_section``.
    """
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PretrainedConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling is not None:
            self.rope_type = rope_scaling.get('rope_type', 'default')
        else:
            self.rope_type = 'default'

        self._pack_for_trans5(config)

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

        # `rope_scaling` may be missing or None; that case is already handled for
        # `rope_type` above. Dereferencing it unconditionally raised AttributeError,
        # so fall back to the default section instead. It is re-read here because
        # the steps above run in between.
        rope_scaling = getattr(config, 'rope_scaling', None) or {}
        self.mrope_section = rope_scaling.get('mrope_section', [24, 20, 20])

    def _pack_for_trans5(self, config):
        """Map the 'default' rope type onto 'linear' when transformers lacks it.

        Only acts when ``rope_type`` is 'default' and ``ROPE_INIT_FUNCTIONS`` has
        no 'default' entry. It then switches to 'linear' and sets ``factor = 1.0``
        in the dict returned by ``get_rope_parameters(config)`` if it is missing.
        Presumably that dict is the config's own, so the init function sees the
        factor — TODO confirm.
        """
        if self.rope_type == 'default' and 'default' not in ROPE_INIT_FUNCTIONS:
            # transformers 5 has removed default in ROPE_INIT_FUNCTIONS
            self.rope_type = 'linear'
            rope_parameters = get_rope_parameters(config)
            if 'factor' not in rope_parameters:
                rope_parameters['factor'] = 1.0

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary frequencies.

        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHW...TT], preserving frequency continuity.

        Args:
            freqs: (3, bs, seq_len, head_dim // 2) frequencies for the T, H and W axes.
            mrope_section: (3,) per-axis section sizes. Entries 1 and 2 bound how
                many interleaved H and W slots are taken.

        Returns:
            (bs, seq_len, head_dim // 2) interleaved frequencies. This is
            ``freqs[0]`` itself, which is modified in place.
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype``.

        2-D ``position_ids`` are broadcast to all three axes. The products are
        computed in float32 with autocast disabled.
        """
        # In contrast to other models, Qwen3VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != 'mps' else 'cpu'
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Qwen3VLTextModel(Qwen3model):
    """Language decoder of Qwen3VL.

    This is not a pure text model: with DeepStack, visual features are added to
    the hidden states produced by the first few decoder layers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__(config=config, dtype=dtype, device=device, prefix=prefix)

        # Use the Qwen3VL rotary embedding (interleaved MRoPE) for `rotary_emb`.
        # TODO: zhouxinyu, add triton kernel for interleaved mrope
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: List[torch.FloatTensor] | None = None,
        attn_metadata: Any = None,
        inputs_embeds: torch.FloatTensor | None = None,
        mrope_position_ids: torch.LongTensor = None,
        # args for deepstack
        visual_pos_masks: torch.Tensor | None = None,
        deepstack_visual_embeds: List[torch.Tensor] | None = None,
    ):
        """Decode and return the final normalized hidden states.

        ``mrope_position_ids`` (when given) get a new axis at dim 1 and replace
        ``position_ids`` for the rotary embedding. When
        ``deepstack_visual_embeds`` is given, entry ``i`` is added after decoder
        layer ``i`` at the positions selected by ``visual_pos_masks``.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        if mrope_position_ids is None:
            rope_positions = position_ids
        else:
            rope_positions = mrope_position_ids.unsqueeze(1)
        cos, sin = self.rotary_emb(hidden_states, rope_positions)
        rotary_pos_emb = (cos[0], sin[0])

        num_deepstack = 0 if deepstack_visual_embeds is None else len(deepstack_visual_embeds)

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

            if layer_idx < num_deepstack:
                # Fold the residual back in before injecting visual features,
                # then restart the residual chain for the next layer.
                hidden_states = self._deepstack_process(
                    hidden_states + residual,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )
                residual = None

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
                           visual_embeds: torch.Tensor):
        """Add ``visual_embeds`` to ``hidden_states`` (in place) where ``visual_pos_masks`` is set."""
        mask = visual_pos_masks.to(hidden_states.device)
        embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        scattered = torch.zeros_like(hidden_states).masked_scatter_(mask, embeds)
        hidden_states += scattered
        return hidden_states


class Qwen3VLVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a Conv3d whose kernel equals its stride.

    Each input row is one (in_channels, temporal_patch_size, patch_size,
    patch_size) patch. The output has one ``hidden_size`` vector per patch.
    """

    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Non-overlapping patches: kernel and stride are identical.
        patch_kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=True,
            dtype=dtype,
            device=device,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project patches to ``(num_patches, embed_dim)`` in the conv weight dtype."""
        patches = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)


class Qwen3VLVisionMLP(nn.Module):
    """Two-layer vision MLP: column-parallel fc1, activation, row-parallel fc2."""

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        from transformers.activations import ACT2FN
        in_features = config.hidden_size
        hidden_features = config.intermediate_size
        quant_config = getattr(config, 'quantization_config', None)

        # hidden_size -> intermediate_size (column parallel)
        self.linear_fc1 = build_colwise_linear(
            in_features,
            hidden_features,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            is_tp=True,
            prefix=add_prefix('linear_fc1', prefix),
        )

        # activation selected by config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        # intermediate_size -> hidden_size (row parallel)
        self.linear_fc2 = build_rowwise_linear(
            hidden_features,
            in_features,
            bias=True,
            dtype=dtype,
            device=device,
            quant_config=quant_config,
            is_tp=True,
            prefix=add_prefix('linear_fc2', prefix),
        )

    def forward(self, x):
        """Apply fc1, the activation, then fc2."""
        hidden = self.linear_fc1(x)
        hidden = self.act(hidden)
        return self.linear_fc2(hidden)


class Qwen3VLVisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention and MLP, each with a residual add."""

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        dtype: torch.dtype = None,
        device: torch.device = None,
        prefix: str = '',
    ):
        super().__init__()
        self.layer_idx = layer_idx
        norm_kwargs = dict(eps=1e-6, dtype=dtype, device=device)
        self.norm1 = LayerNorm(config.hidden_size, **norm_kwargs)
        self.norm2 = LayerNorm(config.hidden_size, **norm_kwargs)

        self.attn = Qwen3VLVisionAttention(config, dtype=dtype, device=device, prefix=add_prefix('attn', prefix))

        self.mlp = Qwen3VLVisionMLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))

    def forward(self,
                hidden_states: torch.Tensor,
                cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor | None = None) -> torch.Tensor:
        """Return ``h + attn(norm1(h))``, followed by ``+ mlp(norm2(.))``."""
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class Qwen3VLVisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighbouring patches and project to ``out_hidden_size``.

    With ``use_postshuffle_norm`` the LayerNorm runs on the concatenated
    (merged) features; otherwise it runs per patch, before concatenation.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 use_postshuffle_norm=False,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = '') -> None:
        super().__init__()
        merged_size = config.hidden_size * (config.spatial_merge_size**2)
        self.hidden_size = merged_size
        self.use_postshuffle_norm = use_postshuffle_norm

        norm_size = merged_size if use_postshuffle_norm else config.hidden_size
        self.norm = LayerNorm(norm_size, eps=1e-6, dtype=dtype, device=device)

        self.linear_fc1 = build_colwise_linear(
            merged_size,
            merged_size,
            bias=True,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('linear_fc1', prefix),
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = build_rowwise_linear(
            merged_size,
            config.out_hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            is_tp=True,
            prefix=add_prefix('linear_fc2', prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, regroup into merged rows and apply fc1 -> GELU -> fc2."""
        if self.use_postshuffle_norm:
            normed = self.norm(x.view(-1, self.hidden_size))
        else:
            normed = self.norm(x)
        merged = normed.view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))


@vlm_model
class Qwen3VLVisionModel(nn.Module):
    """Vision transformer of Qwen3VL.

    ``forward`` applies the patch embedding, adds the given position embeddings,
    runs the ``config.depth`` vision blocks and merges patches with
    ``self.merger``. When the config defines ``deepstack_visual_indexes``, the
    outputs of those blocks also pass through ``deepstack_merger_list``. Those
    results are returned as DeepStack features.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device)

        # Learned position table; `fast_pos_embed_interpolate` resamples it to each image grid.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device)
        # Side length of the square position grid. Assumes num_position_embeddings
        # is a perfect square (int() truncates otherwise).
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device)

        self.blocks = nn.ModuleList([
            Qwen3VLVisionBlock(config,
                               layer_idx,
                               dtype=dtype,
                               device=device,
                               prefix=add_prefix(f'blocks.{layer_idx}', prefix)) for layer_idx in range(config.depth)
        ])
        self.merger = Qwen3VLVisionPatchMerger(config=config,
                                               use_postshuffle_norm=False,
                                               dtype=dtype,
                                               device=device,
                                               prefix=add_prefix('merger', prefix))

        # DeepStack: one extra merger per listed block index (only when the config defines them).
        if hasattr(config, 'deepstack_visual_indexes'):
            self.deepstack_visual_indexes = config.deepstack_visual_indexes
            self.deepstack_merger_list = nn.ModuleList([
                Qwen3VLVisionPatchMerger(config=config,
                                         use_postshuffle_norm=True,
                                         dtype=dtype,
                                         device=device,
                                         prefix=add_prefix(f'deepstack_merger_list.{dvi}', prefix))
                for dvi in range(len(config.deepstack_visual_indexes))
            ])

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return an ``(h * w, 2)`` tensor of (row, col) ids for an ``h x w`` patch grid.

        The ids are reordered so that each ``spatial_merge_size x
        spatial_merge_size`` window is contiguous. ``h`` and ``w`` must be
        divisible by ``spatial_merge_size`` for the reshapes below. Results are
        cached, so callers should not modify the returned tensor in place.
        """
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size

        # Row index of every patch, regrouped into merge windows.
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        # Column index of every patch, regrouped the same way.
        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Rotary position embedding.

        ``grid_thw`` is iterated as ``(t, h, w)`` rows, one per image or video.
        The per-grid (row, col) ids are repeated ``t`` times and concatenated.
        Each id then selects rows of a rotary table sized by the largest h/w,
        and the (row, col) lookups are flattened together per patch.
        """
        pos_ids = []

        for t, h, w in grid_thw:
            base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size)
            pos_ids.append(base if t == 1 else base.repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)

        return rotary_pos_emb

    # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474
    def fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor:
        """Bilinearly resample the learned position table to each ``(t, h, w)`` grid.

        For each grid, the ``num_grid_per_side x num_grid_per_side`` table is
        sampled at ``h x w`` evenly spaced points and blended from the 4
        surrounding table entries. The result is reordered into merge windows
        (as in ``rot_pos_ids``) and repeated ``t`` times. Outputs of all grids
        are concatenated along dim 0 into ``(total_patches, hidden_dim)``.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim
        device = self.pos_embed.weight.device

        outputs = []
        for t, h, w in grid_thw:
            # Fractional sample coordinates in table space.
            h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device)
            w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device)

            # Neighbouring integer table coordinates (ceil clamped to the table edge).
            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij')
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij')

            # original computation of weights
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # we reuse w11 here to avoid duplicate
            # dh_grid * dw_grid computation
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - w01

            # Flat table indices of the 4 corners, ordered (00, 01, 10, 11) to match the weights.
            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
            h_grid_idx = h_grid * num_grid_per_side

            indices = (h_grid_idx + w_grid).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device)

            # Weighted sum of the 4 corner embeddings.
            embeds = self.pos_embed(indices)
            embeds *= weights
            combined = embeds.sum(dim=0)

            # Reorder into merge windows, then repeat for every temporal slice.
            combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,
                pos_embeds: torch.Tensor) -> torch.Tensor:
        """forward.

        Returns a tuple: the merged hidden states, and a list of DeepStack
        features (one per index in ``deepstack_visual_indexes``; empty when the
        config defines none). ``cu_seqlens`` gets a leading 0 before it is
        passed to the blocks.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states + pos_embeds
        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
            if hasattr(self, 'deepstack_visual_indexes') and layer_num in self.deepstack_visual_indexes:
                deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[deepstack_merge_idx](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)

        return hidden_states, deepstack_feature_lists


class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Qwen3-VL model: vision tower + text backbone + lm head.

    ``forward`` returns the hidden states of the text backbone; mrope position ids
    are prepared in ``update_model_metas`` and vision inputs in
    ``prepare_inputs_for_generation``.
    """

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        prefix: str = '',
    ):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr

        # build preprocessor
        self.input_processor = Qwen3VLInputProcessor(self.config)

        # build vision model
        self.visual = Qwen3VLVisionModel(
            config.vision_config,
            dtype=dtype,
            device=device,
            prefix=add_prefix('visual', prefix),
        )

        # build text model through an overridable hook so subclasses can provide a
        # different backbone without first constructing (and discarding) this one
        self.language_model = self._build_language_model(config.text_config,
                                                         dtype=dtype,
                                                         device=device,
                                                         prefix=prefix)

        # build lm_head
        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
                                          config.text_config.vocab_size,
                                          bias=False,
                                          dtype=dtype,
                                          device=device)

    def _build_language_model(self,
                              text_config: PretrainedConfig,
                              dtype: torch.dtype = None,
                              device: torch.device = None,
                              prefix: str = ''):
        """Build the text backbone registered under the ``language_model`` prefix.

        Args:
            text_config: ``config.text_config`` of the multimodal model.
            dtype: parameter dtype.
            device: parameter device.
            prefix: prefix of the whole model; ``language_model`` is appended here.
        """
        return Qwen3VLTextModel(text_config,
                                dtype=dtype,
                                device=device,
                                prefix=add_prefix('language_model', prefix))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        mrope_position_ids: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        pos_embeds: torch.Tensor = None,
        grid_thw: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward; returns the hidden states of the text backbone.

        When ``inputs_embeds`` is not given, tokens are embedded and, if
        ``pixel_values`` is present, vision embeddings are scattered into the
        positions selected by ``image_mask``; the vision tower's deepstack features
        are forwarded to the text backbone together with that mask.
        """

        visual_pos_masks = None
        deepstack_visual_embeds = None
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))

                # get image embeds and deepstack visual embeds
                image_embeds, deepstack_visual_embeds = self.visual(pixel_values,
                                                                    cu_seqlens=vis_cu_seqlens,
                                                                    rotary_pos_emb=vis_pos_emb,
                                                                    pos_embeds=pos_embeds)

                # split per sample and re-join: torch.split raises if the expected
                # per-sample token counts do not add up to the produced embeddings
                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
                image_embeds = torch.split(image_embeds, split_sizes)
                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)

                # mask and scatter to create final input embeddings
                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)

                visual_pos_masks = expanded_image_mask

        hidden_states = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            # args for deepstack
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
        )
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.language_model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor | None = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments of ``forward`` from the step context.

        All multimodal items of the batch are flattened; the modality of the first
        item decides which placeholder token id builds ``image_mask``.
        """

        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        pixel_values = None
        vis_cu_seqlens = None
        vis_pos_emb = None
        image_mask = None
        grid_thw = None
        pos_embeds = None
        if context.input_multimodals is not None:
            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]
            # flatten batch
            mm_inputs = [item for sublist in mm_inputs for item in sublist]

            if len(mm_inputs) > 0:
                modality = mm_inputs[0].modality
                pixel_values = torch.cat([inp.data for inp in mm_inputs])

                image_token_id = mm_inputs[0].meta.get('image_token_id')
                video_token_id = mm_inputs[0].meta.get('video_token_id')
                mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id
                image_mask = (input_ids == mm_token_id)

                grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()
                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
                pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw)
                # one attention segment of h * w patches per temporal frame
                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                                         grid_thw[:, 0]).to(pixel_values.device)
                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
                vis_pos_emb = vis_pos_emb.repeat(1, 2)
                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

        mrope_position_ids = getattr(context, 'mrope_position_ids', None)

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            pixel_values=pixel_values,
            vis_cu_seqlens=vis_cu_seqlens,
            vis_pos_emb=vis_pos_emb,
            image_mask=image_mask,
            grid_thw=grid_thw,
            pos_embeds=pos_embeds,
        )

    @classmethod
    def rename_weight(cls, name: str) -> str:
        """Map checkpoint weight names onto this module's parameter names."""
        if name.startswith('model.language_model.'):
            return 'language_model.' + name[len('model.language_model.'):]
        elif name.startswith('model.visual.'):
            return 'visual.' + name[len('model.visual.'):]
        elif name.startswith('model.'):
            return name[len('model.'):]
        return name

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights, packing q/k/v and gate/up shards into fused parameters."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                if '.qkv.' in name:
                    # already-fused qkv weight: split it with the parameter's own splitter
                    param = params_dict[name]
                    q, k, v = param.weight_spliter(loaded_weight)
                    load_weight(param, q, shard_id='q')
                    load_weight(param, k, shard_id='k')
                    load_weight(param, v, shard_id='v')
                else:
                    param = params_dict[name]
                    load_weight(param, loaded_weight)

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Make cudagraph buffers from forward inputs, adding a (3, max_tokens) mrope buffer."""
        max_tokens = graph_meta.max_tokens

        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)

        return input_buffers

    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
        """Fill cudagraph buffers from forward inputs, including mrope position ids."""

        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)

        input_ids = kwargs.get('input_ids')
        num_tokens = input_ids.size(-1)
        new_batch_size = graph_meta.max_batchs

        is_decoding = graph_meta.is_decoding
        input_buffers = graph_meta.input_buffers
        mrope_position_ids = kwargs.get('mrope_position_ids', None)
        if mrope_position_ids is not None:
            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
            if is_decoding:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
            else:
                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']

        return new_inputs

    def _get_model_metas(self, context: StepContext):
        """Return one meta dict per sequence, defaulting to ``mrope_delta=0``."""
        model_metas = context.model_metas
        if model_metas is None:
            batch_size = context.q_seqlens.numel()
            # one independent dict per sequence (``[d] * n`` would alias a single dict)
            return [dict(mrope_delta=0) for _ in range(batch_size)]
        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]

    def _update_model_meta_decoding(self, context: StepContext):
        """Shift decoding positions by each sequence's ``mrope_delta`` for all 3 mrope rows."""
        model_metas = self._get_model_metas(context)
        position_ids = context.position_ids

        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
        mrope_deltas = position_ids.new_tensor(mrope_deltas)
        mrope_position_ids = position_ids + mrope_deltas[None]
        mrope_position_ids = mrope_position_ids.expand(3, -1)

        context.mrope_position_ids = mrope_position_ids
        return model_metas

    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
        """Return (3, t*h*w) mrope ids (t, h, w indices) for one merged vision grid.

        ``h`` and ``w`` are divided by the vision spatial merge size before the ids
        are enumerated in t-major, then h, then w order.
        """
        t, h, w = grid_thw
        merge_size = self.visual.spatial_merge_size
        h //= merge_size
        w //= merge_size
        stride = torch.tensor([h * w, w, 1], device=device)[:, None]
        size = torch.tensor([t, h, w], device=device)[:, None]
        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
        pos_ids = pos_ids // stride % size
        return pos_ids

    def _update_model_meta_prefilling(self, context: StepContext):
        """Compute mrope position ids for a prefill step.

        Sets ``context.mrope_position_ids`` (concatenated over the batch) and returns
        new model metas carrying each sequence's ``mrope_delta`` as an ``int``.
        """
        model_metas = self._get_model_metas(context)
        input_multimodals = context.input_multimodals
        if input_multimodals is None:
            input_multimodals = [None] * len(model_metas)
        position_ids = context.position_ids
        q_seqlens = context.q_seqlens.tolist()
        batched_pos_ids = position_ids[0].split(q_seqlens)
        # token ids of each sequence; the flattened batch stores them back to back
        batched_input_ids = context.input_ids[0].split(q_seqlens)
        merge_size = self.visual.spatial_merge_size
        mrope_position_ids = []
        new_model_metas = []
        for pos_ids, seq_input_ids, model_meta, input_mm in zip(batched_pos_ids, batched_input_ids, model_metas,
                                                                input_multimodals):
            mm_data_list = []
            if input_mm is not None:
                mm_data_list.extend(input_mm.get('mm_data', []))

            if model_meta is None or 'mrope_delta' not in model_meta:
                mrope_delta = 0
            else:
                mrope_delta = model_meta['mrope_delta']

            pos_start = pos_ids[0].item()
            mrope_pos_ids = pos_ids + mrope_delta
            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()

            for mm_data in mm_data_list:
                if mm_data.modality == Modality.IMAGE:
                    grid_thw = mm_data.meta['grid_thw'][0].tolist()
                    _, h, w = grid_thw
                    h //= merge_size
                    w //= merge_size
                    # text after the image advances by max(h, w) instead of the token count
                    num_pad = mm_data.end - mm_data.start - max(h, w)
                    mrope_delta -= num_pad
                    fill_start = mm_data.start - pos_start
                    fill_end = mm_data.end - pos_start
                    img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)
                    img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
                    mrope_pos_ids[:, fill_end:] -= num_pad
                    mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids
                elif mm_data.modality == Modality.VIDEO:
                    video_token_id = self.config.video_token_id
                    grid_thw = mm_data.meta['grid_thw']

                    grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0)
                    grid_thw[:, 0] = 1

                    position_ids_list = []
                    # only this sequence's tokens, not the whole flattened batch
                    input_tokens = seq_input_ids.tolist()

                    st = 0
                    # treat each frame separately as a single image
                    for video_idx in range(grid_thw.shape[0]):
                        # text before video. e.g. <0.3 seconds><|vision_start|> ...
                        ed_video = input_tokens.index(video_token_id, st)
                        ed = ed_video
                        text_len = ed - st
                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0
                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx
                        position_ids_list.append(text_pos_ids)

                        # video frame.  ... <|video_end|>
                        t, h, w = (
                            grid_thw[video_idx][0],
                            grid_thw[video_idx][1] // merge_size,
                            grid_thw[video_idx][2] // merge_size,
                        )
                        video_pos_ids = self._get_multimodal_pos_ids(grid_thw[video_idx], pos_ids.device)
                        position_ids_list.append(video_pos_ids + text_len + st_idx)

                        st = ed + t * h * w

                    # text after video, <|vision_end|> ...
                    if st < len(input_tokens):
                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0
                        text_len = len(input_tokens) - st
                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx
                        position_ids_list.append(text_pos_ids)

                    mrope_pos_ids = torch.cat(position_ids_list, dim=1).reshape(3, -1)
                    # store a plain int, matching the image branch and the decoding path
                    mrope_delta = int(mrope_pos_ids.max().item()) + 1 - pos_ids.size(0)
                    mrope_pos_ids += pos_start  # add back the original position offset

            mrope_position_ids.append(mrope_pos_ids)
            new_model_metas.append(dict(mrope_delta=mrope_delta))

        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
        context.mrope_position_ids = mrope_position_ids

        return new_model_metas

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: torch.Tensor | None = None,
                           context: StepContext = None):
        """Update model meta (and ``context.mrope_position_ids``) for this step."""
        if context.is_decoding:
            return self._update_model_meta_decoding(context)
        else:
            return self._update_model_meta_prefilling(context)

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor."""
        return self.input_processor


class Qwen3VLInputProcessor(BaseModelInputProcessor):
    """Qwen3-VL input processor.

    Converts per-item multimodal dicts into ``MultiModalData`` records, which are
    returned under the ``mm_data`` key of ``input_multimodals``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = config

    def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:
        """Make image MultiModalData covering ``image_tokens`` tokens starting at ``offset``."""
        pixel_values = input_mm['pixel_values']
        image_grid_thw = input_mm['image_grid_thw']
        offset = input_mm['offset']
        start = offset
        image_token_id = input_mm['image_token_id']
        num_pad = input_mm['image_tokens']
        if isinstance(num_pad, torch.Tensor):
            num_pad = num_pad.item()

        mm_data = MultiModalData(modality=Modality.IMAGE,
                                 data=pixel_values,
                                 start=start,
                                 end=start + num_pad,
                                 meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
        return mm_data

    def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:
        """Make video MultiModalData covering ``video_tokens`` tokens starting at ``offset``."""
        pixel_values_videos = input_mm['pixel_values_videos']
        video_grid_thw = input_mm['video_grid_thw']
        offset = input_mm['offset']
        start = offset
        video_token_id = input_mm['video_token_id']
        num_pad = input_mm['video_tokens']
        if isinstance(num_pad, torch.Tensor):
            num_pad = num_pad.item()

        mm_data = MultiModalData(modality=Modality.VIDEO,
                                 data=pixel_values_videos,
                                 start=start,
                                 end=start + num_pad,
                                 meta=dict(
                                     grid_thw=video_grid_thw,
                                     video_token_id=video_token_id,
                                 ))
        return mm_data

    def preprocess_input(self,
                         input_ids: List[int],
                         input_multimodals: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Returns the inputs unchanged (as a tuple) when there is no multimodal data,
        otherwise a ``PreprocessInputResult`` with ``mm_data`` records.

        Raises:
            ValueError: if an item has a modality other than image or video.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            return input_ids, input_multimodals

        input_mm_data = []
        for input_mm in input_multimodals:
            modality = input_mm.get('modality')
            if modality == Modality.IMAGE:
                mm_data = self._make_image_mm_data(input_mm)
            elif modality == Modality.VIDEO:
                mm_data = self._make_video_mm_data(input_mm)
            else:
                # previously this re-appended the prior item's data (or hit an unbound local)
                raise ValueError(f'Unsupported modality for Qwen3-VL: {modality}')
            input_mm_data.append(mm_data)

        result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data))

        return result


================================================
FILE: lmdeploy/pytorch/models/qwen3_vl_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix
from .qwen3_moe import Qwen3MoeModel
from .qwen3_vl import Qwen3VLForConditionalGeneration
from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding


class Qwen3VLMoeTextModel(Qwen3MoeModel):
    """MoE text backbone of Qwen3-VL-MoE.

    Not a pure text-only model: DeepStack adds visual features into the hidden
    states of the first ``len(deepstack_visual_embeds)`` decoder layers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 prefix: str = ''):
        super().__init__(config=config, dtype=dtype, device=device, prefix=prefix)

        # swap in the Qwen3-VL rotary embedding (fed with mrope ids in forward when given)
        # TODO: zhouxinyu, add triton kernel for interleaved mrope
        self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mrope_position_ids: torch.LongTensor = None,
        # args for deepstack
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
    ):
        """Run the decoder stack and return the final normed hidden states.

        ``mrope_position_ids`` take precedence over ``position_ids`` for the rotary
        embedding. After decoder layer ``i`` (for ``i < len(deepstack_visual_embeds)``)
        the residual is folded in and ``deepstack_visual_embeds[i]`` is added at the
        positions selected by ``visual_pos_masks``.
        """
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        if mrope_position_ids is not None:
            rope_ids = mrope_position_ids.unsqueeze(1)
        else:
            rope_ids = position_ids
        cos, sin = self.rotary_emb(hidden_states, rope_ids)
        rotary_pos_emb = (cos[0], sin[0])

        num_deepstack = 0 if deepstack_visual_embeds is None else len(deepstack_visual_embeds)
        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_idx],
                residual=residual,
                attn_metadata=attn_metadata,
            )

            if layer_idx < num_deepstack:
                # fold the residual in so the visual features land on the full hidden state
                merged = hidden_states + residual
                hidden_states = self._deepstack_process(merged, visual_pos_masks, deepstack_visual_embeds[layer_idx])
                residual = None

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
                           visual_embeds: torch.Tensor):
        """Add ``visual_embeds`` into ``hidden_states`` (in place) where ``visual_pos_masks`` is set."""
        target_device = hidden_states.device
        mask = visual_pos_masks.to(target_device)
        embeds = visual_embeds.to(target_device, hidden_states.dtype)
        scattered = torch.zeros_like(hidden_states).masked_scatter_(mask, embeds)
        hidden_states += scattered
        return hidden_states


class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
    """Qwen3-VL with a mixture-of-experts text backbone."""

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        prefix: str = '',
    ):
        super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device, prefix=prefix)

        # If the parent constructor built its own (dense) text backbone, replace it.
        # When the parent builds through ``_build_language_model`` the MoE backbone
        # already exists and is not constructed a second time.
        if not isinstance(self.language_model, Qwen3VLMoeTextModel):
            self.language_model = self._build_language_model(config.text_config,
                                                             dtype=dtype,
                                                             device=device,
                                                             prefix=prefix)

    def _build_language_model(self,
                              text_config: PretrainedConfig,
                              dtype: torch.dtype = None,
                              device: torch.device = None,
                              prefix: str = ''):
        """Build the MoE text backbone under the ``language_model`` prefix.

        Args:
            text_config: ``config.text_config`` of the multimodal model.
            dtype: parameter dtype.
            device: parameter device.
            prefix: prefix of the whole model; ``language_model`` is appended here.
        """
        return Qwen3VLMoeTextModel(text_config,
                                   dtype=dtype,
                                   device=device,
                                   prefix=add_prefix('language_model', prefix))

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one per-expert weight (``.experts.<id>.<proj>``) into the fused expert param."""

        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            param = params_dict[name]
            load_weight(param, loaded_weight)

    # modify from vllm qwen3vlmoe fused expert loading
    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                                   fused_expert_params_mapping: List):
        """Load a checkpoint tensor holding all experts at once.

        The last two dims are transposed before loading; for ``gate_up`` the
        transposed tensor is split in half along dim -2 into gate and up shards.
        """
        num_experts = self.config.text_config.num_experts

        for (param_name, weight_name) in fused_expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]

            loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
            if 'gate_up' in name:
                w1, w3 = loaded_weight.chunk(2, dim=-2)
                for expert_id in range(num_experts):
                    load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate')
                    load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up')
            elif 'down' in name:
                w2 = loaded_weight
                for expert_id in range(num_experts):
                    load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights, routing expert tensors (per-expert or fused) separately."""
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # expert mapping
        num_experts = self.config.text_config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            # (param_name, weight_name, expert_id, shard_id)
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        # fused expert mapping
        fused_expert_params_mapping = [
            # (param_name, weight_name)
            ('.experts.gate_up.weight', '.experts.gate_up_proj'),
            ('.experts.down.weight', '.experts.down_proj'),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name)
                if is_fused_expert:
                    self._load_weight_fused_experts(name,
                                                    loaded_weight,
                                                    params_dict,
                                                    fused_expert_params_mapping=fused_expert_params_mapping)
                else:
                    self._load_weight_experts(name,
                                              loaded_weight,
                                              params_dict,
                                              expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    if '.qkv.' in name:
                        # already-fused qkv weight: split it with the parameter's own splitter
                        param = params_dict[name]
                        q, k, v = param.weight_spliter(loaded_weight)
                        load_weight(param, q, shard_id='q')
                        load_weight(param, k, shard_id='k')
                        load_weight(param, v, shard_id='v')
                    else:
                        param = params_dict[name]
                        load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/sdar.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class SDARAttention(nn.Module):
    """SDAR self-attention.

    Uses a packed q/k/v projection, per-head RMSNorm on queries and keys, rotary
    position embedding, and an ``Attention`` backend configured with
    ``block_sparse_size=config.dllm_block_length``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_q_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        model_dim = config.hidden_size
        per_head = getattr(config, 'head_dim', model_dim // n_q_heads)
        kv_replicas = getattr(config, 'num_replicate_key_value_heads', 1)

        # Fused q/k/v projection. Qwen3-style checkpoints ship attention_bias=False.
        self.qkv_proj = build_qkv_proj(model_dim,
                                       num_q_heads=n_q_heads,
                                       num_kv_heads=n_kv_heads,
                                       head_size=per_head,
                                       bias=config.attention_bias,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=kv_replicas)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # The diffusion-LM block length becomes the backend's sparse block size.
        # It is written onto the config by the causal-LM wrapper before this module is built.
        sparse_block = config.dllm_block_length
        self.attn_fwd = Attention(
            n_q_heads,
            per_head,
            num_kv_heads=n_kv_heads,
            v_head_size=per_head,
            sliding_window=config.sliding_window,
            block_sparse_size=sparse_block,
        )

        self.o_proj = build_o_proj(n_q_heads * per_head,
                                   model_dim,
                                   bias=config.attention_bias,
                                   quant_config=quant_cfg,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

        # Per-head normalization applied to q and k before RoPE.
        self.q_norm = RMSNorm(per_head, config.rms_norm_eps, dtype=dtype, device=device)
        self.k_norm = RMSNorm(per_head, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run attention over ``hidden_states`` and return the ``o_proj`` output.

        ``past_key_value`` holds (k_cache, v_cache) and optionally
        (k_scales_zeros, v_scales_zeros) as entries 2 and 3.
        """
        packed = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(packed)

        q = self.q_norm(q)
        k = self.k_norm(k)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin)

        # A cache tuple longer than two carries scale/zero tensors for k and v.
        has_scales = len(past_key_value) != 2
        k_cache, v_cache = past_key_value[0], past_key_value[1]
        out = self.attn_fwd(
            q,
            k,
            v,
            k_cache,
            v_cache,
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class SDARMLP(nn.Module):
    """Dense gated feed-forward block.

    A fused gate/up projection feeds ``SiluAndMul``, followed by a down projection
    back to ``hidden_size``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        model_dim = config.hidden_size
        inner_dim = config.intermediate_size

        self.gate_up_proj = build_gateup_linear(
            model_dim,
            [inner_dim, inner_dim],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(inner_dim,
                                           model_dim,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class SDARDecoderLayer(nn.Module):
    """One SDAR transformer layer with a pre-norm attention and MLP sub-block.

    The running residual is threaded through the two-argument form of ``RMSNorm``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        def make_norm():
            # Both layer norms are built with identical settings.
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        self.self_attn = SDARAttention(config, dtype=dtype, device=device)
        self.mlp = SDARMLP(config, dtype=dtype, device=device)
        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Return ``(mlp_output, residual)`` for this layer.

        When ``residual`` is None (first layer), the raw input seeds the residual.
        Otherwise the norm's two-argument form returns ``(normalized, updated_residual)``;
        it presumably fuses the residual add (confirm against ``RMSNorm``).
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=normed,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class SDARModel(nn.Module):
    """SDAR backbone: token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(self.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        layer_stack = [
            SDARDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(layer_stack)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed the inputs and run every decoder layer.

        Embedding is skipped when ``inputs_embeds`` is given. Returns the final
        normalized hidden states. ``past_key_values[i]`` is the cache for layer ``i``.
        """
        hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        # Keep only index 0 of the leading dim (presumably a batch dim of 1 - confirm).
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class SDARForCausalLM(nn.Module, CudaGraphMixin):
    """SDAR causal LM: an ``SDARModel`` backbone plus an ``lm_head`` projection.

    ``forward`` returns final hidden states; logits are computed separately by
    ``get_logits``.
    """

    # Fused module name -> checkpoint sub-module names whose weights are packed into it.
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        """Build the backbone and lm_head.

        Args:
            config: model config. It is mutated in place: ``dllm_block_length`` is set
                from ``ctx_mgr.build_ctx.dllm_config.block_length`` before the backbone
                is built, because the attention layers read it.
            ctx_mgr: step context manager; stored as ``self.ctx_mgr``.
            dtype: parameter dtype.
            device: parameter device.
        """
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # must happen before SDARModel is constructed (attention reads config.dllm_block_length)
        config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length
        # build model
        self.model = SDARModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return final hidden states (not logits).

        Extra ``kwargs`` are accepted and ignored. Use ``get_logits`` to project
        the returned hidden states to vocabulary logits.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states to vocabulary logits with ``lm_head``."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Tie ``lm_head`` to the input embedding when ``config.tie_word_embeddings`` is set.

        Presumably invoked after weights are loaded; confirm against the caller.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a ``StepContext``.

        If the context carries input embeddings, token embeddings are computed
        (unless ``inputs_embeds`` is supplied). The context embeddings are then
        written in place at ``context.input_embedding_indexing`` along dim 1.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # in-place scatter; also mutates a caller-supplied inputs_embeds
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id. Every other
        tensor is loaded by exact name. Unknown names raise ``KeyError`` from
        ``params_dict``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # rotary tables are rebuilt from config (build_rotary_embedding_from_config); skip saved copies
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # tied head shares embed_tokens' weight (see update_weights), so skip a stored lm_head
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                # for-else: no stacked mapping matched, so the tensor maps 1:1 onto a parameter
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/sdar_moe.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
                                        build_rowwise_linear)
from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class SDARMoeAttention(nn.Module):
    """SDAR-MoE self-attention.

    Uses a packed q/k/v projection, per-head RMSNorm on queries and keys, rotary
    position embedding, and an ``Attention`` backend configured with
    ``block_sparse_size=config.dllm_block_length``.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_q_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        model_dim = config.hidden_size
        per_head = getattr(config, 'head_dim', model_dim // n_q_heads)
        kv_replicas = getattr(config, 'num_replicate_key_value_heads', 1)

        # Fused q/k/v projection. Qwen3-style checkpoints ship attention_bias=False.
        self.qkv_proj = build_qkv_proj(model_dim,
                                       num_q_heads=n_q_heads,
                                       num_kv_heads=n_kv_heads,
                                       head_size=per_head,
                                       bias=config.attention_bias,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       device=device,
                                       num_replicate_kv_heads=kv_replicas)

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # The diffusion-LM block length becomes the backend's sparse block size.
        # It is written onto the config by the causal-LM wrapper before this module is built.
        sparse_block = config.dllm_block_length
        self.attn_fwd = Attention(
            n_q_heads,
            per_head,
            num_kv_heads=n_kv_heads,
            v_head_size=per_head,
            sliding_window=config.sliding_window,
            block_sparse_size=sparse_block,
        )

        self.o_proj = build_o_proj(n_q_heads * per_head,
                                   model_dim,
                                   bias=config.attention_bias,
                                   quant_config=quant_cfg,
                                   dtype=dtype,
                                   device=device,
                                   is_tp=True)

        # Per-head normalization applied to q and k before RoPE.
        self.q_norm = RMSNorm(per_head, config.rms_norm_eps, dtype=dtype, device=device)
        self.k_norm = RMSNorm(per_head, config.rms_norm_eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Run attention over ``hidden_states`` and return the ``o_proj`` output.

        ``past_key_value`` holds (k_cache, v_cache) and optionally
        (k_scales_zeros, v_scales_zeros) as entries 2 and 3.
        """
        packed = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(packed)

        q = self.q_norm(q)
        k = self.k_norm(k)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin)

        # A cache tuple longer than two carries scale/zero tensors for k and v.
        has_scales = len(past_key_value) != 2
        k_cache, v_cache = past_key_value[0], past_key_value[1]
        out = self.attn_fwd(
            q,
            k,
            v,
            k_cache,
            v_cache,
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        out = out.reshape(*hidden_states.shape[:-1], -1)
        return self.o_proj(out)


class SDARMoeMLP(nn.Module):
    """Dense gated feed-forward block with an explicit ``intermediate_size``.

    Used for layers that do not get a routed MoE block.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 intermediate_size: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        model_dim = config.hidden_size

        self.gate_up_proj = build_gateup_linear(
            model_dim,
            [intermediate_size, intermediate_size],
            bias=False,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            is_tp=True,
        )

        self.act_fn = SiluAndMul(inplace=True)

        self.down_proj = build_down_linear(intermediate_size,
                                           model_dim,
                                           bias=False,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection to ``x``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))


class SDARMoeSparseMoeBlock(nn.Module):
    """Routed mixture-of-experts feed-forward block.

    A linear router scores ``num_experts`` experts per token, ``SoftmaxTopK``
    keeps ``top_k`` of them, and a fused MoE module combines their outputs.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        quant_cfg = getattr(config, 'quantization_config', None)
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.moe_intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.renormalize = config.norm_topk_prob

        # Router projection; built with is_tp=False (not tensor-parallel split).
        self.gate = build_rowwise_linear(
            self.hidden_dim,
            self.num_experts,
            bias=False,
            dtype=dtype,
            device=device,
            is_tp=False,
        )

        router_groups = getattr(config, 'router_n_groups', -1)
        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=router_groups)

        self.experts = build_fused_moe(
            self.hidden_dim,
            self.ffn_dim,
            self.num_experts,
            top_k=self.top_k,
            renormalize=self.renormalize,
            dtype=dtype,
            device=device,
            quant_config=quant_cfg,
            all_reduce=True,
            layer_idx=layer_idx,
        )

    def forward(self, hidden_states: torch.Tensor):
        """Route each token of a 3-D ``hidden_states`` and return the expert mix.

        The output has the same leading two dims as the input.
        """
        batch, seq_len, dim = hidden_states.shape
        tokens = hidden_states.view(-1, dim)
        topk_weights, topk_ids = self.softmax_topk(self.gate(tokens))
        mixed = self.experts(tokens, topk_weights, topk_ids)
        return mixed.reshape(batch, seq_len, -1)


class SDARMoeDecoderLayer(nn.Module):
    """One SDAR-MoE transformer layer.

    The layer uses a routed MoE block when the layer is not in
    ``config.mlp_only_layers``, experts exist, and
    ``(layer_idx + 1) % decoder_sparse_step == 0``. Otherwise it uses a dense MLP.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        quant_cfg = getattr(config, 'quantization_config', None)

        def make_norm():
            # Both layer norms are built with identical settings.
            return RMSNorm(config.hidden_size,
                           config.rms_norm_eps,
                           quant_config=quant_cfg,
                           dtype=dtype,
                           device=device)

        self.self_attn = SDARMoeAttention(config, dtype=dtype, device=device)

        use_moe = (layer_idx not in config.mlp_only_layers and config.num_experts > 0
                   and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = SDARMoeSparseMoeBlock(config, layer_idx, dtype=dtype, device=device)
        else:
            self.mlp = SDARMoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)

        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Return ``(mlp_output, residual)`` for this layer.

        When ``residual`` is None (first layer), the raw input seeds the residual.
        Otherwise the norm's two-argument form returns ``(normalized, updated_residual)``.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=normed,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class SDARMoeModel(nn.Module):
    """SDAR-MoE backbone: token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(self.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        layer_stack = [
            SDARMoeDecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(layer_stack)

        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Embed the inputs and run every decoder layer.

        Embedding is skipped when ``inputs_embeds`` is given. Returns the final
        normalized hidden states. ``past_key_values[i]`` is the cache for layer ``i``.
        """
        hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        # Keep only index 0 of the leading dim (presumably a batch dim of 1 - confirm).
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens


class SDARMoeForCausalLM(nn.Module, CudaGraphMixin):
    """SDAR-MoE causal LM: an ``SDARMoeModel`` backbone plus an ``lm_head`` projection.

    ``forward`` returns final hidden states; logits are computed separately by
    ``get_logits``.
    """

    # Fused module name -> checkpoint sub-module names whose weights are packed into it.
    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
        'gate_up_proj': [
            'gate_proj',
            'up_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        """Build the backbone and lm_head.

        Args:
            config: model config. It is mutated in place: ``dllm_block_length`` is set
                from ``ctx_mgr.build_ctx.dllm_config.block_length`` before the backbone
                is built, because the attention layers read it.
            ctx_mgr: step context manager; stored as ``self.ctx_mgr``.
            dtype: parameter dtype.
            device: parameter device.
        """
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # must happen before SDARMoeModel is constructed (attention reads config.dllm_block_length)
        config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length
        # build model
        self.model = SDARMoeModel(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Run the backbone and return final hidden states (not logits).

        Extra ``kwargs`` are accepted and ignored. Use ``get_logits`` to project
        the returned hidden states to vocabulary logits.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states to vocabulary logits with ``lm_head``."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Tie ``lm_head`` to the input embedding when ``config.tie_word_embeddings`` is set.

        Presumably invoked after weights are loaded; confirm against the caller.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Build the keyword arguments for ``forward`` from a ``StepContext``.

        If the context carries input embeddings, token embeddings are computed
        (unless ``inputs_embeds`` is supplied). The context embeddings are then
        written in place at ``context.input_embedding_indexing`` along dim 1.
        """
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # in-place scatter; also mutates a caller-supplied inputs_embeds
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load one expert-related checkpoint tensor.

        Args:
            name: checkpoint tensor name (contains ``'.experts'``).
            loaded_weight: tensor to load.
            params_dict: ``dict(self.named_parameters())``.
            expert_params_mapping: tuples ``(param_name, weight_name, expert_id, shard_id)``.
                A match renames the tensor onto the fused expert parameter and loads it
                with that expert id and shard id.

        Names that match no mapping entry are loaded by exact name.
        """
        # load fused weights
        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            break
        else:
            # for-else: no per-expert pattern matched (no break)
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Per-expert gate/up/down tensors go into the fused ``experts.gate_up`` /
        ``experts.down`` parameters, indexed by expert id. Separate q/k/v and
        dense gate/up tensors go into the fused ``qkv_proj`` / ``gate_up_proj``.
        Everything else is loaded by exact name. Unknown names raise ``KeyError``.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
            ('.gate_up_proj', '.gate_proj', 0),
            ('.gate_up_proj', '.up_proj', 1),
        ]

        # expert map: checkpoint '.experts.{i}.{gate,up,down}_proj' -> fused param + (expert_id, shard)
        # the trailing '.' before the projection name keeps '.experts.1.' from matching '.experts.10.'
        num_experts = self.config.num_experts
        expert_params_mapping = []
        for exp_id in range(num_experts):
            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
            expert_params_mapping += [gate_param, up_param, down_param]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # rotary tables are rebuilt from config (build_rotary_embedding_from_config); skip saved copies
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            # tied head shares embed_tokens' weight (see update_weights), so skip a stored lm_head
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
                for (param_name, weight_name, shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    load_weight(param, loaded_weight, shard_id=shard_id)
                    break
                else:
                    # for-else: no stacked mapping matched, so the tensor maps 1:1 onto a parameter
                    param = params_dict[name]
                    load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/siglip.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math
from typing import Iterable, Set, Tuple, Union

import torch
from torch import nn
from transformers import SiglipVisionConfig

from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight


class SiglipVisionEmbeddings(nn.Module):
    """Patch and position embeddings for the SigLIP vision tower.

    A strided convolution (kernel == stride == ``patch_size``) cuts the image
    into non-overlapping patches and projects each one to ``hidden_size``
    channels; a learned position embedding is then added per patch.
    """

    def __init__(self,
                 config: SiglipVisionConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # kernel_size == stride, so every pixel belongs to exactly one patch
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding='valid',
            dtype=dtype,
            device=device,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim, dtype=dtype, device=device)
        # non-persistent buffer: rebuilt here rather than stored in checkpoints
        default_ids = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer('position_ids', default_ids, persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Resize the learned (square) position grid to a new image size.

        This lets the model run on resolutions other than the one it was
        trained at. It is also adapted to support torch.jit tracing and has no
        class-token embedding.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """  # noqa

        num_patches = embeddings.shape[1]
        pos_table = self.position_embedding.weight
        num_positions = pos_table.shape[0]

        # Fast path: same number of patches and a square image means the table
        # can be used as-is. While tracing we always interpolate so the exported
        # graph stays valid for dynamic input shapes.
        same_grid = num_patches == num_positions and height == width
        if same_grid and not torch.jit.is_tracing():
            return self.position_embedding(self.position_ids)

        dim = embeddings.shape[-1]
        side = int(math.sqrt(num_positions))
        target_hw = (height // self.patch_size, width // self.patch_size)

        # (positions, dim) -> (1, dim, side, side) so it can be resized like an image
        grid = pos_table.unsqueeze(0).reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=target_hw, mode='bicubic', align_corners=False)
        # back to (1, new_h * new_w, dim)
        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """Embed 4-D ``pixel_values`` (N, C, H, W) into (N, num_patches, embed_dim)."""
        _, _, height, width = pixel_values.shape
        conv_dtype = self.patch_embedding.weight.dtype
        patch_maps = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # (N, embed_dim, grid_h, grid_w) -> (N, grid_h * grid_w, embed_dim)
        tokens = patch_maps.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(tokens, height, width)
        else:
            positions = self.position_embedding(self.position_ids)
        return tokens + positions


class SiglipAttention(nn.Module):
    """Unmasked multi-head self-attention used by the SigLIP encoder.

    Q, K and V come from a single fused projection; attention is computed by
    ``torch.nn.functional.scaled_dot_product_attention`` without a mask.
    """

    def __init__(self,
                 config: SiglipVisionConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs) -> None:
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        quant_cfg = getattr(config, 'quantization_config', None)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.scale = self.head_dim**-0.5
        # read from the config but not applied in forward()
        self.dropout = config.attention_dropout

        # K/V use the same head count as Q, so the fused output splits into
        # three equal chunks in forward()
        self.qkv_proj = build_qkv_proj(self.embed_dim,
                                       num_q_heads=self.num_heads,
                                       num_kv_heads=self.num_heads,
                                       head_size=self.head_dim,
                                       quant_config=quant_cfg,
                                       dtype=dtype,
                                       bias=True,
                                       device=device)

        self.out_proj = build_rowwise_linear(self.embed_dim,
                                             self.embed_dim,
                                             bias=True,
                                             quant_config=quant_cfg,
                                             dtype=dtype,
                                             device=device,
                                             is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel."""
        bsz, seq_len, _ = hidden_states.size()

        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, heads * head_dim) -> (B, heads, T, head_dim); the head count
            # is inferred, presumably the per-rank count under tensor parallelism
            return t.view(bsz, seq_len, -1, self.head_dim).transpose(1, 2)

        q, k, v = (_split_heads(part) for part in self.qkv_proj(hidden_states).chunk(3, dim=-1))

        context = nn.functional.scaled_dot_product_attention(q, k, v, scale=self.scale)
        # (B, heads, T, head_dim) -> (B, T, heads * head_dim)
        merged = context.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        return self.out_proj(merged), None


class SiglipMLP(nn.Module):
    """Feed-forward block: fc1 -> activation (``config.hidden_act``) -> fc2."""

    def __init__(self,
                 config: SiglipVisionConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs) -> None:
        super().__init__()
        from transformers.activations import ACT2FN
        self.config = config
        self.ctx_mgr = ctx_mgr
        self.activation_fn = ACT2FN[config.hidden_act]
        quant_cfg = getattr(config, 'quantization_config', None)

        # both projections share bias/quantization/placement settings
        shared = dict(bias=True, quant_config=quant_cfg, dtype=dtype, device=device, is_tp=True)
        # column-parallel up-projection, then row-parallel down-projection
        self.fc1 = build_colwise_linear(config.hidden_size, config.intermediate_size, **shared)
        self.fc2 = build_rowwise_linear(config.intermediate_size, config.hidden_size, **shared)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """forward."""
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))


class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block: LN -> attention -> add, LN -> MLP -> add."""

    def __init__(self,
                 config: SiglipVisionConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        eps = config.layer_norm_eps

        # registration order kept: self_attn, layer_norm1, mlp, layer_norm2
        self.self_attn = SiglipAttention(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps, dtype=dtype, device=device)
        self.mlp = SiglipMLP(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Apply attention and MLP sub-blocks, each with a residual add."""
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out, None


class SiglipEncoder(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` SigLIP encoder layers."""

    def __init__(self,
                 config: SiglipVisionConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 **kwargs) -> None:
        super().__init__()

        self.config = config
        blocks = [
            SiglipEncoderLayer(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        **kwargs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every layer in order and return the final hidden states.

        Extra keyword arguments are accepted and ignored.
        """
        out = inputs_embeds
        for block in self.layers:
            out, _ = block(out)
        return out


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned query (``probe``) attends over the encoder tokens; the
    attended vector then goes through LayerNorm + MLP with a residual add, and
    the output for that one query is returned.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        **kwargs,
    ) -> None:
        super().__init__()

        # Honour dtype/device like every other submodule of the tower. Without
        # them, probe and attention would be created with torch's global defaults
        # (typically float32 on CPU) and would not match the encoder output.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size, dtype=dtype, device=device))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size,
                                                     config.num_attention_heads,
                                                     batch_first=True,
                                                     dtype=dtype,
                                                     device=device)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)
        self.mlp = SiglipMLP(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool batch-first tokens (batch, seq, hidden) into (batch, hidden)."""
        batch_size = hidden_state.shape[0]
        # one copy of the learned query per batch element: (batch, 1, hidden)
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        # drop the length-1 query axis
        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision backbone: embeddings -> encoder stack -> final LayerNorm."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.config = config
        hidden = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config, ctx_mgr=ctx_mgr, device=device, dtype=dtype)
        self.encoder = SiglipEncoder(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

        expected_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > expected_layers:
            raise ValueError(f'The original encoder only has {expected_layers} '
                             f'layers, but you requested {built_layers} layers.')

        self.post_layernorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps, dtype=dtype, device=device)

        # configs lacking `vision_use_head` get the pooling head by default.
        # NOTE: the head is constructed here but forward() does not apply it.
        self.use_head = getattr(config, 'vision_use_head', True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        """Return the post-LayerNorm encoder features for ``pixel_values``."""
        tokens = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoded = self.encoder(inputs_embeds=tokens)
        return self.post_layernorm(encoded)


class SiglipVisionModel(nn.Module):
    """SigLIP vision tower wrapper with HF-checkpoint weight loading."""
    config_class = SiglipVisionConfig
    main_input_name = 'pixel_values'

    def __init__(
        self,
        config: SiglipVisionConfig,
        ctx_mgr: StepContextManager,
        dtype: torch.dtype = None,
        device: torch.device = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.vision_model = SiglipVisionTransformer(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Run the vision tower and return its post-LayerNorm hidden states."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load ``(name, tensor)`` pairs from a HF-style checkpoint.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` parameter with the matching shard id. Tensors for
        encoder layers beyond those built are skipped.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('qkv_proj', 'q_proj', 'q'),
            ('qkv_proj', 'k_proj', 'k'),
            ('qkv_proj', 'v_proj', 'v'),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (name.startswith('vision_model.post_layernorm') and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith('vision_model.encoder.layers'):
                # 'vision_model.encoder.layers.<idx>....' -> <idx>
                layer_idx = int(name.split('.')[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                # Go through load_weight, as the non-stacked branch below does,
                # instead of reading param.weight_loader directly (that raises
                # AttributeError when the attribute is absent).
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


================================================
FILE: lmdeploy/pytorch/models/starcoder2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin


class Starcoder2Attention(nn.Module):
    """Rewrite module of Starcoder2Attention.

    Fused QKV projection, rotary embedding on Q/K, paged attention through
    ``Attention`` (optionally sliding-window), then the output projection.
    """

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quant_cfg = getattr(config, 'quantization_config', None)
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        hidden = config.hidden_size
        head_dim = getattr(config, 'head_dim', hidden // n_heads)

        # packed qkv; K/V may use fewer heads than Q (num_key_value_heads)
        self.qkv_proj = build_qkv_proj(
            hidden,
            num_q_heads=n_heads,
            num_kv_heads=n_kv_heads,
            head_size=head_dim,
            bias=config.use_bias,
            quant_config=quant_cfg,
            dtype=dtype,
            device=device,
        )

        self.apply_rotary_pos_emb = ApplyRotaryEmb()

        # `sliding_window` is optional in the config; None is passed through
        window = getattr(config, 'sliding_window', None)
        self.attn_fwd = Attention(
            n_heads,
            head_dim,
            num_kv_heads=n_kv_heads,
            v_head_size=head_dim,
            sliding_window=window,
        )

        self.o_proj = build_rowwise_linear(n_heads * head_dim,
                                           hidden,
                                           bias=config.use_bias,
                                           quant_config=quant_cfg,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_metadata: Any = None,
    ):
        """Rewrite of LlamaAttention.forward."""
        # project, then collapse all leading dims into one token axis
        fused = self.qkv_proj(hidden_states).flatten(0, -2)
        q, k, v = self.qkv_proj.split_qkv(fused)

        cos, sin = rotary_pos_emb
        q, k = self.apply_rotary_pos_emb(q, k, cos, sin, inplace=True)

        # past_key_value is (k_cache, v_cache) or, with two extra entries,
        # (k_cache, v_cache, k_scales_zeros, v_scales_zeros) — presumably for a
        # quantized KV cache
        has_scales = len(past_key_value) != 2
        attn_out = self.attn_fwd(
            q,
            k,
            v,
            past_key_value[0],
            past_key_value[1],
            attn_metadata,
            k_scales_zeros=past_key_value[2] if has_scales else None,
            v_scales_zeros=past_key_value[3] if has_scales else None,
            inplace=True,
        )
        attn_out = attn_out.reshape(*hidden_states.shape[:-1], -1)

        return self.o_proj(attn_out)


class Starcoder2MLP(nn.Module):
    """mlp.

    c_fc (column-parallel) -> GELU -> c_proj (row-parallel). There is no
    separate gate projection.
    """

    # HF activation names this module can reproduce, mapped to the
    # ``approximate`` argument of ``nn.GELU``: 'gelu_pytorch_tanh' and
    # 'gelu_new' are the tanh approximation, 'gelu' is the exact erf form.
    _GELU_APPROXIMATION = {
        'gelu_pytorch_tanh': 'tanh',
        'gelu_new': 'tanh',
        'gelu': 'none',
    }

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        # up projection
        self.c_fc = build_colwise_linear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
            is_tp=True,
        )

        # activation: default to 'gelu_pytorch_tanh' when the config leaves it
        # unset. Validate explicitly so an unsupported name fails loudly instead
        # of silently running a different nonlinearity.
        hidden_act = config.hidden_act
        if hidden_act is None:
            hidden_act = 'gelu_pytorch_tanh'
        approximate = self._GELU_APPROXIMATION.get(hidden_act)
        if approximate is None:
            raise ValueError(f'Unsupported hidden_act for Starcoder2MLP: {hidden_act!r}; '
                             f'expected one of {sorted(self._GELU_APPROXIMATION)}.')
        self.act_fn = nn.GELU(approximate=approximate)

        # down
        self.c_proj = build_rowwise_linear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=config.use_bias,
                                           quant_config=quantization_config,
                                           dtype=dtype,
                                           device=device,
                                           is_tp=True)

    def forward(self, x):
        """forward."""
        hidden = self.c_fc(x)
        act = self.act_fn(hidden)
        return self.c_proj(act)


class Starcoder2DecoderLayer(nn.Module):
    """Decoder layer: LayerNorm -> self-attention, LayerNorm -> MLP.

    The residual stream is threaded through the layer norms, which take and
    return ``(hidden_states, residual)`` when a residual is supplied.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 layer_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.layer_idx = layer_idx
        eps = config.norm_epsilon
        hidden = config.hidden_size

        # registration order kept: self_attn, mlp, input_layernorm, post_attention_layernorm
        self.self_attn = Starcoder2Attention(config, dtype=dtype, device=device)
        self.mlp = Starcoder2MLP(config, dtype=dtype, device=device)
        self.input_layernorm = LayerNorm(hidden, eps=eps, dtype=dtype, device=device)
        self.post_attention_layernorm = LayerNorm(hidden, eps=eps, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
        past_key_value: Optional[List[torch.FloatTensor]],
        residual: Optional[torch.Tensor] = None,
        attn_metadata: Any = None,
    ):
        """Return ``(mlp_output, residual)`` for the next layer."""
        # The first layer starts the residual stream from its input; later
        # layers hand the incoming residual to the norm alongside the hidden states.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            rotary_pos_emb=rotary_pos_emb,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class Starcoder2Model(nn.Module):
    """Token embedding, stacked decoder layers and the final LayerNorm."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx,
                                         dtype=dtype,
                                         device=device)

        decoder_layers = [
            Starcoder2DecoderLayer(config, i, dtype=dtype, device=device) for i in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = LayerNorm(config.hidden_size, eps=config.norm_epsilon, dtype=dtype, device=device)

        self.rotary_emb = build_rotary_embedding_from_config(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attn_metadata: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Rewrite of LlamaModel.forward."""
        # precomputed embeddings take priority over token ids
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # rotary tables are computed once and shared by every layer; only the
        # first entry along dim 0 is used
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        rotary_pos_emb = (cos[0], sin[0])

        residual = None
        for layer_no, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                past_key_value=past_key_values[layer_no],
                residual=residual,
                attn_metadata=attn_metadata,
            )

        # final norm also consumes the residual stream
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.embed_tokens


class Starcoder2ForCausalLM(nn.Module, CudaGraphMixin):
    """ModelForCausalLM."""

    packed_modules_mapping = {
        'qkv_proj': [
            'q_proj',
            'k_proj',
            'v_proj',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build model
        self.model = Starcoder2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.lm_head = build_rowwise_linear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            dtype=dtype,
                                            device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return final hidden states.

        Use ``get_logits`` to project the returned hidden states onto the
        vocabulary.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.lm_head(hidden_states)

    def update_weights(self):
        """Update weights.

        Share the embedding matrix with ``lm_head`` only when the config asks
        for tied embeddings. ``load_weights`` skips ``lm_head.weight`` in
        exactly that case, so an untied checkpoint keeps the head it loaded
        instead of having it overwritten by the embeddings.
        """
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input."""
        # get input_ids, position_ids and attention metadatas
        input_ids = context.input_ids
        position_ids = context.position_ids
        attn_metadata = context.attn_metadata

        # process vision embeddings: scatter them over the token embeddings
        # at the positions given by the context's indexing
        vision_embeddings = context.input_embeddings
        vision_embedding_indexing = context.input_embedding_indexing
        if vision_embeddings is not None and len(vision_embeddings) > 0:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)

        # inputs of forward
        return dict(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights.

        Separate q/k/v checkpoint tensors are routed into the fused
        ``qkv_proj`` with the matching shard id. Rotary caches are skipped,
        and ``lm_head.weight`` is skipped when embeddings are tied.
        """
        # modify from vllm
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ('.qkv_proj', '.q_proj', 'q'),
            ('.qkv_proj', '.k_proj', 'k'),
            ('.qkv_proj', '.v_proj', 'v'),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'rotary_emb.inv_freq' in name:
                continue
            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                load_weight(param, loaded_weight, shard_id=shard_id)
                break
            else:
                param = params_dict[name]
                load_weight(param, loaded_weight)


================================================
FILE: lmdeploy/pytorch/models/utils/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/models/utils/cudagraph.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from torch import Tensor
from torch.profiler import record_function

from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager

BuffType = Dict[str, Tensor]


def _get_meta_flashattn(
        batch_size: int,
        max_seqlen_q: int,
        max_seqlen_k: int,
        num_heads_q: int,
        num_heads_kv: int,
        headdim: int,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
        num_splits=0,
):
    """Get scheduler metadata for flash attn.

    Thin pass-through to ``flash_attn_interface.get_scheduler_metadata``. The
    import is deferred to call time so that importing this module does not
    require ``flash_attn_interface`` to be installed.
    """
    from flash_attn_interface import get_scheduler_metadata

    shape_args = (batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, cache_seqlens)
    return get_scheduler_metadata(
        *shape_args,
        qkv_dtype=qkv_dtype,
        headdim_v=headdim_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k_new=cu_seqlens_k_new,
        page_size=page_size,
        causal=causal,
        window_size=window_size,
        num_splits=num_splits,
    )


def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n.

    Any ``n <= 1`` yields 1 (``2**0``). ``int.bit_length`` makes the result
    exact for arbitrarily large ``n``; fixed-width bit smearing would silently
    break above 2**64.
    """
    from operator import index

    # accept any integer-like value (e.g. numpy ints), reject floats
    n = index(n)
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


@dataclass
class CudaGraphMeta:
    """Meta info of cudagraph.

    Fields are accepted positionally in declaration order; do not reorder.
    """
    max_batchs: int
    max_tokens: int
    num_blocks: int
    # flag; previously annotated ``int`` although it is used as a boolean
    is_decoding: bool
    device: torch.device
    # name -> tensor mappings; None until populated
    input_buffers: Optional[BuffType] = None
    output_buffers: Optional[BuffType] = None
    vocab_size: int = 1
    use_mla_fp8_cache: bool = False
    use_flash_mla: bool = False
    mla_index_topk: Optional[int] = None
    # used as max_seqlen_q when building flash-attn scheduler metadata
    decode_query_len: int = 1
    use_fa3_decoding: bool = False


class CudaGraphMixin:
    """Mixin class to support cudagraph.

    The hooks below let a model run on static cudagraph buffers:

    * ``support_cuda_graph`` decides whether the current step may use a graph.
    * ``make_output_buffers`` normalizes the model output into a buffer dict.
    * ``make_buffers_cudagraph`` allocates the static input buffers.
    * ``fill_buffers_cudagraph`` copies the current step's inputs into those
      buffers and returns the kwargs to call the model with.
    * ``update_context_cudagraph`` points the step context at the buffers.
    * ``get_outputs_cudagraph`` slices outputs back to the real token count.
    """

    def support_cuda_graph(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Return True if the model supports cudagraph for this step.

        The default implementation enables cudagraph only for decoding steps,
        i.e. it returns ``attn_metadata.is_decoding``.
        """
        return attn_metadata.is_decoding

    def make_output_buffers(self, output):
        """Make output buffers.

        A bare tensor is wrapped as ``{'hidden_states': output}``; a dict is
        used as-is.
        """
        if isinstance(output, torch.Tensor):
            output_buffers = dict(hidden_states=output)
        else:
            assert isinstance(output, Dict)
            output_buffers = output
        return output_buffers

    def update_meta_flashattn(self, graph_meta: CudaGraphMeta, block_size: int, max_seqlen_k: int,
                              cache_seqlens: torch.Tensor):
        """Update meta flashattn.

        Computes FA3 scheduler metadata for a batch padded to
        ``graph_meta.max_batchs`` sequences of ``graph_meta.decode_query_len``
        query tokens each. Head counts, head dim, dtype and sliding window are
        read from the current step context's ``model_config``.

        Args:
            graph_meta: cudagraph meta of the graph being built or replayed.
            block_size: KV cache page size (passed as ``page_size``).
            max_seqlen_k: maximum key length in the batch.
            cache_seqlens: per-sequence KV lengths; converted to int32.

        Returns:
            The metadata returned by ``_get_meta_flashattn``.

        Raises:
            TypeError: if ``model_config.sliding_window`` is neither ``None``
                nor an ``int``.
        """
        ctx_mgr = get_step_ctx_manager()
        step_ctx = ctx_mgr.current_context()
        model_config = step_ctx.model_config
        batch_size = graph_meta.max_batchs
        max_seqlen_q = graph_meta.decode_query_len
        sliding_window = model_config.sliding_window
        num_attention_heads = model_config.num_attention_heads
        num_key_value_heads = model_config.num_key_value_heads
        headdim = model_config.head_dim
        torch_dtype = model_config.dtype
        if sliding_window is None:
            # -1 means an infinite context window on that side.
            window_size = (-1, -1)
        elif isinstance(sliding_window, int):
            window_size = (sliding_window, sliding_window)
        else:
            # Any other type would otherwise leave window_size unbound.
            raise TypeError(f'Unsupported sliding_window type for flash attention: {type(sliding_window)}.')
        cache_seqlens = cache_seqlens.to(torch.int32)
        scheduler_metadata = _get_meta_flashattn(
            batch_size=batch_size,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            num_heads_q=num_attention_heads,
            num_heads_kv=num_key_value_heads,
            headdim=headdim,
            cache_seqlens=cache_seqlens,
            qkv_dtype=torch_dtype,
            page_size=block_size,
            window_size=window_size,
        )
        return scheduler_metadata

    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: List, **kwargs) -> BuffType:
        """Make cudagraph buffers from forward inputs.

        Allocates on ``graph_meta.device``:

        * ``input_ids`` (random ids in ``[0, vocab_size)``) and
          ``position_ids``: ``(1, max_tokens)`` int64.
        * ``block_offsets``: ``(max_batchs, num_blocks)`` int32.
        * ``qkv_lens``: ``(3, max_batchs)`` int32, exposed through the row
          views ``q_start_loc``, ``q_seqlens``, ``kv_seqlens`` and
          ``qkv_seqlens`` (rows 1-2).
        * ``local_adapter_ids``: ``(max_batchs,)`` int64.
        * ``cu_seqlens``: ``(2, max_batchs + 1)`` int32, with row views
          ``cu_seqlens_q`` and ``cu_seqlens_k``.
        * flash_mla ``tile_scheduler_metadata`` / ``num_splits`` when
          ``graph_meta.use_flash_mla``; otherwise FA3 ``scheduler_metadata``
          when ``graph_meta.use_fa3_decoding``.

        Because the named lens entries are views, writing into ``qkv_lens`` or
        ``cu_seqlens`` updates them in place.
        """
        max_batches = graph_meta.max_batchs
        max_tokens = graph_meta.max_tokens
        num_blocks = graph_meta.num_blocks
        device = graph_meta.device
        decode_query_len = graph_meta.decode_query_len

        input_buffers: BuffType = dict()
        input_buffers['input_ids'] = torch.randint(0,
                                                   graph_meta.vocab_size, (1, max_tokens),
                                                   dtype=torch.int64,
                                                   device=device)
        input_buffers['position_ids'] = torch.zeros((1, max_tokens), dtype=torch.int64, device=device)

        # flash_mla requires block_offsets and kv_lens int32
        input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device)
        input_buffers['qkv_lens'] = torch.zeros(3, max_batches, dtype=torch.int32, device=device)

        # Row views into qkv_lens so a single copy fills all three.
        input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0]
        input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1]
        input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2]
        input_buffers['qkv_seqlens'] = input_buffers['qkv_lens'][1:]
        input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device)

        # Column 0 stays zero; columns 1.. receive cumulative sums on fill.
        input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, dtype=torch.int32, device=device)
        input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0]
        input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1]

        if graph_meta.use_flash_mla is True:
            import flash_mla

            # create buffers for flash mla
            num_attention_heads = self.config.num_attention_heads
            index_topk = graph_meta.mla_index_topk
            num_heads_q = None if index_topk is None else num_attention_heads
            input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata(
                torch.ones(max_batches, dtype=torch.int32, device=device),
                num_attention_heads * decode_query_len,
                num_heads_k=1,
                num_heads_q=num_heads_q,
                is_fp8_kvcache=graph_meta.use_mla_fp8_cache,
                topk=index_topk)

        # use fa3 decode kernel for spec decode
        elif graph_meta.use_fa3_decoding is True:
            # NOTE(review): assumes the cache layout puts block_size on dim 1 — confirm.
            block_size = past_key_values[0][0].size(1)
            # max_seqlen_k here only sizes the buffer; fill_buffers_cudagraph
            # recomputes it with the real max_kv_seqlen and asserts equal shape.
            input_buffers['scheduler_metadata'] = self.update_meta_flashattn(graph_meta,
                                                                             block_size=block_size,
                                                                             max_seqlen_k=decode_query_len,
                                                                             cache_seqlens=input_buffers['kv_seqlens'])

        return input_buffers

    @record_function('fill_buffers_cudagraph')
    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor,
                               past_key_values: List, attn_metadata: Any, inputs_embeds: Tensor,
                               **kwargs) -> Dict[str, Tensor]:
        """Fill cudagraph buffers from forward inputs.

        Copies the current step's inputs into ``graph_meta.input_buffers``,
        rebinds the lens/offsets/scheduler fields of ``attn_metadata`` to
        those buffers (``attn_metadata`` is modified in place), and returns
        the kwargs the graph should be called with: ``past_key_values``,
        ``attn_metadata``, ``input_ids``, ``position_ids``, optionally
        ``inputs_embeds``, plus any extra ``kwargs``.
        """

        block_offsets: Tensor = attn_metadata.block_offsets
        q_start_loc: Tensor = attn_metadata.q_start_loc
        q_seqlens: Tensor = attn_metadata.q_seqlens
        kv_seqlens: Tensor = attn_metadata.kv_seqlens
        input_buffers: BuffType = graph_meta.input_buffers

        batch_size, num_blocks = block_offsets.size()
        num_tokens = input_ids.size(-1)
        decode_query_len = graph_meta.decode_query_len
        # fill buffer; the padding tail of input_ids gets fresh random ids
        input_buffers['input_ids'].random_(0, graph_meta.vocab_size)
        input_buffers['input_ids'][:, :num_tokens] = input_ids
        input_buffers['position_ids'][:, :num_tokens] = position_ids
        input_buffers['block_offsets'][:batch_size, :num_blocks] = block_offsets

        # Padded batch slots end up with q_start_loc = kv_seqlen = 0 and
        # q_seqlen = max_tokens // max_batchs; real slots are overwritten.
        qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))
        input_buffers['qkv_lens'].zero_()
        input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)
        input_buffers['qkv_lens'][:, :batch_size] = qkv
        input_buffers['cu_seqlens'][:, 1:] = input_buffers['qkv_seqlens'].cumsum(1)
        if inputs_embeds is not None:
            emb_size = inputs_embeds.size(-1)
            if 'inputs_embeds' not in input_buffers:
                # allocated lazily because the embedding size is only known here
                max_num_tokens = input_buffers['input_ids'].size(-1)
                input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(1, max_num_tokens, emb_size)
            input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds

        # create inputs
        new_batch_size = input_buffers['block_offsets'].size(0)
        attn_metadata.block_offsets = input_buffers['block_offsets']
        attn_metadata.q_start_loc = input_buffers['q_start_loc']
        attn_metadata.q_seqlens = input_buffers['q_seqlens']
        attn_metadata.kv_seqlens = input_buffers['kv_seqlens']
        attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q']
        attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k']

        if graph_meta.use_flash_mla is True:
            import flash_mla
            num_attention_heads = self.config.num_attention_heads
            index_topk = graph_meta.mla_index_topk
            num_heads_q = None if index_topk is None else num_attention_heads
            tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
                attn_metadata.kv_seqlens.to(torch.int32),
                num_attention_heads * decode_query_len,
                num_heads_k=1,
                num_heads_q=num_heads_q,
                is_fp8_kvcache=graph_meta.use_mla_fp8_cache,
                topk=index_topk)
            # here we use copy_ instead of = to avoid using new allocated mem for cuda graph
            input_buffers['tile_scheduler_metadata'].copy_(tile_scheduler_metadata)
            input_buffers['num_splits'][:new_batch_size + 1].copy_(num_splits[:new_batch_size + 1])
            attn_metadata.tile_scheduler_metadata = input_buffers['tile_scheduler_metadata']
            attn_metadata.num_splits = input_buffers['num_splits']

        # use fa3 decode kernel for spec decode
        elif graph_meta.use_fa3_decoding is True:
            block_size = past_key_values[0][0].size(1)
            scheduler_metadata = self.update_meta_flashattn(
                graph_meta,
                block_size=block_size,
                max_seqlen_k=attn_metadata.max_kv_seqlen,
                cache_seqlens=input_buffers['kv_seqlens'],
            )
            # the captured buffer must be reusable in place
            assert scheduler_metadata.shape == input_buffers['scheduler_metadata'].shape
            input_buffers['scheduler_metadata'].copy_(scheduler_metadata)
            attn_metadata.scheduler_metadata = input_buffers['scheduler_metadata']

        new_inputs = dict(
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
        )

        new_inputs['input_ids'] = input_buffers['input_ids']
        new_inputs['position_ids'] = input_buffers['position_ids']

        if inputs_embeds is not None:
            new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']

        new_inputs.update(kwargs)
        return new_inputs

    def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepContext):
        """Update step context with input buffers.

        Copies ``context.local_adapter_ids`` (if any) into the static buffer,
        zeroing the buffer first unless it already is that tensor, and rebinds
        the context's lens fields to the cudagraph buffers.
        """
        input_buffers = graph_meta.input_buffers
        local_adapter_ids = context.local_adapter_ids
        if local_adapter_ids is not None:
            if input_buffers['local_adapter_ids'].data_ptr() != local_adapter_ids.data_ptr():
                input_buffers['local_adapter_ids'].fill_(0)
            batch_size = local_adapter_ids.size(0)
            input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids
            context.local_adapter_ids = input_buffers['local_adapter_ids']
        context.q_seqlens = input_buffers['q_seqlens']
        context.kv_seqlens = input_buffers['kv_seqlens']
        context.q_start_loc = input_buffers['q_start_loc']

    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs):
        """Get outputs from buffers.

        Slices ``hidden_states`` to the real number of tokens
        (``input_ids.size(-1)``). ``all_routed_experts``, when present, is
        sliced on dim 0 and cloned so it does not alias the graph buffer.
        """
        num_tokens = input_ids.size(-1)
        outputs = dict()
        outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens]
        if output_buffers.get('all_routed_experts', None) is not None:
            outputs['all_routed_experts'] = output_buffers['all_routed_experts'][:num_tokens, ...].clone()
        return outputs


================================================
FILE: lmdeploy/pytorch/models/utils/micro_batch.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch


def enable_micro_batch(param_name, index=-1):
    """Decorator factory to enable micro-batch computation.

    The wrapped method's input is taken positionally from ``args[index]``
    when ``index != -1`` and enough positional arguments were given, and
    otherwise from the keyword argument ``param_name``. If that input is a
    list, the method runs once per element and the per-element results are
    returned as a list; otherwise the method runs once, unchanged.
    """

    def decorator(func):

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            positional = index != -1 and len(args) > index
            inputs = args[index] if positional else kwargs.get(param_name, None)

            if not isinstance(inputs, list):
                # Not micro-batched: plain forward computation.
                return func(self, *args, **kwargs)

            outputs = []
            for micro_input in inputs:
                if positional:
                    args = (*args[:index], micro_input, *args[index + 1:])
                else:
                    kwargs[param_name] = micro_input
                outputs.append(func(self, *args, **kwargs))
            return outputs

        return wrapper

    return decorator


def split_batch(func, param_name, index=-1, num_splits=2):
    """Decorator to split along the 0th dimension into a specified number of
    chunks.

    The tensor argument is taken from ``args[index]`` when ``index != -1``
    and enough positional arguments were given, otherwise from the keyword
    ``param_name``. It is replaced by ``list(torch.chunk(x, num_splits, 0))``,
    ``func`` is called, and its list of results is concatenated along dim 0.
    When the argument is ``None``, ``func`` is called unchanged.

    ``functools.wraps`` preserves ``func``'s name, docstring and
    ``__wrapped__``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        positional = index != -1 and len(args) > index
        inputs = args[index] if positional else kwargs.get(param_name, None)

        if inputs is None:
            return func(*args, **kwargs)

        # torch.chunk may return fewer than num_splits chunks for small inputs.
        split_inputs = list(torch.chunk(inputs, num_splits, dim=0))
        if positional:
            args = args[0:index] + (split_inputs, ) + args[index + 1:]
        else:
            kwargs[param_name] = split_inputs

        results = func(*args, **kwargs)
        return torch.cat(results, dim=0)

    return wrapper


================================================
FILE: lmdeploy/pytorch/models/utils/model.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Iterable, List, Optional, Tuple

import torch

from lmdeploy.pytorch.config import QuantizationConfig
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, StepContext
from lmdeploy.pytorch.models.patch import get_build_model_context
from lmdeploy.pytorch.nn.embedding import ParallelEmbedding
from lmdeploy.pytorch.nn.linear import build_rowwise_linear


class BaseModelMetaProcessor:
    """Model meta processor base class.

    Every hook here is a pass-through that hands back ``inputs`` unchanged;
    subclasses override only the hooks they need.
    """

    def update_inputs(self, inputs: ModelInputs, device: torch.device) -> ModelInputs:
        """Hook to update model inputs; the base version returns them as-is."""
        return inputs

    def update_delta(self, inputs: ModelInputs, delta: ModelInputsDelta) -> ModelInputs:
        """Hook to apply a delta to model inputs; the base version ignores it."""
        return inputs

    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
        """Hook to merge ``other`` into ``inputs``; the base version ignores it."""
        return inputs


class DeployModelMixin:
    """Base hooks a deployable model implements or inherits."""

    def forward(self, *args, **kwargs):
        """Forward of model; must be overridden."""
        raise NotImplementedError('Not Implemented')

    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: Optional[torch.Tensor] = None,
        context: StepContext = None,
    ):
        """Prepare input; must be overridden."""
        raise NotImplementedError('Not Implemented')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights; must be overridden."""
        raise NotImplementedError('Not Implemented')

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output (identity by default)."""
        return hidden_states

    @classmethod
    def rename_weight(cls, name: str) -> str:
        """Rename weight (identity by default)."""
        return name

    def update_weights(self):
        """Update weights (no-op by default)."""
        pass

    def update_model_metas(self,
                           past_key_values: List[List[torch.Tensor]],
                           inputs_embeds: Optional[torch.Tensor] = None,
                           context: StepContext = None):
        """Update model meta (returns None by default)."""
        return None

    def get_input_processor(self) -> BaseModelInputProcessor:
        """Get input processor (None by default)."""
        return None

    def get_modelmeta_processor(self) -> BaseModelMetaProcessor:
        """Get model meta preprocessor (a pass-through by default)."""
        return BaseModelMetaProcessor()

    @classmethod
    def update_quant_config(cls, quant_config: QuantizationConfig):
        """Update quant config.

        Maps ``quant_config.ignored_layers`` through ``cls.rename_weight`` and
        adds the names of the modules presumably built from them:

        * ``.q_proj`` -> ``.qkv_proj``;
        * ``.gate_proj`` (not under ``.experts``) -> ``.gate_up_proj``;
        * ``.gate_proj`` / ``.down_proj`` under ``.experts`` -> the
          ``...experts`` prefix;
        * ``.down_proj`` not under ``.experts`` keeps its own name.

        The result is de-duplicated, keeping first-seen order. The previous
        implementation collected additions in a ``set``, which made the order
        vary with the string hash seed, and re-added dense ``down_proj`` names
        verbatim.

        Args:
            quant_config: config updated in place; ``None`` is a no-op. A
                ``None`` ``ignored_layers`` is treated as empty.

        Returns:
            The same ``quant_config`` object, or ``None`` when it was ``None``.
        """
        if quant_config is None:
            return
        ignored_layers = [cls.rename_weight(name) for name in (quant_config.ignored_layers or [])]

        added_ignore_layers = []
        for layer_name in ignored_layers:
            if '.q_proj' in layer_name:
                added_ignore_layers.append(layer_name.replace('.q_proj', '.qkv_proj'))
            elif '.gate_proj' in layer_name or '.down_proj' in layer_name:
                if '.experts' in layer_name:
                    # expert weights are grouped under a single '...experts' name
                    added_ignore_layers.append(layer_name.split('.experts', 1)[0] + '.experts')
                elif '.gate_proj' in layer_name:
                    added_ignore_layers.append(layer_name.replace('.gate_proj', '.gate_up_proj'))
                # a dense down_proj keeps its name and is already in ignored_layers

        # dict.fromkeys drops duplicates while keeping first-seen order
        quant_config.ignored_layers = list(dict.fromkeys(ignored_layers + added_ignore_layers))

        return quant_config


class DeployModelMixinV1(DeployModelMixin):
    """Deploy mixin for models that own an ``lm_head`` projection."""

    def get_logits(self, hidden_states: torch.Tensor):
        """Project hidden states to logits with ``self.lm_head``.

        Hidden states are cast to the head weight's dtype first when the two
        differ (the head may be built in fp32, see ``build_lm_head``).
        """
        target_dtype = self.lm_head.weight.dtype
        head_inputs = hidden_states if hidden_states.dtype == target_dtype else hidden_states.to(dtype=target_dtype)
        return self.lm_head(head_inputs)

    def get_input_embeddings(self):
        """Return the input embedding module; must be overridden."""
        raise NotImplementedError('Not Implemented')

    def update_weights(self):
        """Tie ``lm_head.weight`` to the input embedding weight when
        ``config.tie_word_embeddings`` is truthy."""
        tied = getattr(self.config, 'tie_word_embeddings', False)
        if not tied:
            return
        self.lm_head.weight = self.get_input_embeddings().weight

    def build_lm_head(self,
                      hidden_size: int,
                      vocab_size: int,
                      bias: bool = False,
                      dtype: Optional[torch.dtype] = None,
                      device: Optional[torch.device] = None,
                      **kwargs):
        """Build the LM head as a rowwise linear layer.

        The head is created in fp32 when the build-model context requests
        ``fp32_lm_head``; otherwise ``dtype`` is used.
        """
        fp32_head = get_build_model_context().fp32_lm_head
        return build_rowwise_linear(
            hidden_size,
            vocab_size,
            bias,
            dtype=torch.float32 if fp32_head else dtype,
            device=device,
            **kwargs,
        )


def vlm_model(vlm_cls):
    """Decorate a vision sub-module class of a VLM.

    Returns a factory with ``vlm_cls``'s metadata (via ``functools.wraps``).
    Calling it builds ``vlm_cls(*args, **kwargs)``, unless the current
    build-model context sets ``disable_vision_encoder``; in that case it
    returns a ``torch.nn.Identity`` placeholder tagged with
    ``_is_dummy_mod = True``.

    Raises:
        ValueError: if ``vlm_cls`` is not a subclass of ``torch.nn.Module``,
            including when it is not a class at all.
    """
    # issubclass() raises TypeError for non-class arguments; check for a class
    # first so that callers get the intended ValueError.
    if not (isinstance(vlm_cls, type) and issubclass(vlm_cls, torch.nn.Module)):
        raise ValueError('Only subclasses of nn.Module can be decorated with @vlm_model.')

    @functools.wraps(vlm_cls)
    def wrapper(*args, **kwargs):
        bm_ctx = get_build_model_context()
        disable_vision_encoder = bm_ctx.disable_vision_encoder
        if disable_vision_encoder:
            mod = torch.nn.Identity()
            mod._is_dummy_mod = True
            return mod
        else:
            return vlm_cls(*args, **kwargs)

    return wrapper


def build_embedding(vocab_size: int,
                    hidden_size: int,
                    padding_idx: int,
                    dtype: torch.dtype = None,
                    device: torch.device = None,
                    is_tp: bool = False,
                    **kwargs):
    """Build a ``ParallelEmbedding``.

    The embedding is forced to fp32 only when the build-model context asks
    for an fp32 lm_head *and* word embeddings are tied, i.e. when the weight
    is shared with the head. Otherwise ``force_dtype`` is ``None``.
    """
    bm_ctx = get_build_model_context()
    shares_fp32_head = bm_ctx.fp32_lm_head and bm_ctx.tie_word_embeddings

    return ParallelEmbedding(
        vocab_size,
        hidden_size,
        padding_idx,
        dtype=dtype,
        device=device,
        is_tp=is_tp,
        force_dtype=torch.float32 if shares_fp32_head else None,
        **kwargs,
    )


================================================
FILE: lmdeploy/pytorch/models/whisper.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py

import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.nn import LayerNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear


class WhisperAttention(nn.Module):
    """Multi-headed self-attention ('Attention Is All You Need') with a
    packed qkv projection."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: PretrainedConfig = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
    ) -> None:
        super().__init__()
        self.config = config
        quant_config = getattr(config, 'quantization_config', None)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}'
                             f' and `num_heads`: {num_heads}).')
        self.scaling = self.head_dim**-0.5

        # packed q/k/v projection
        # TODO, zhouxinyu, hf whisper hard-code k_proj bias = False, may double check
        self.qkv_proj = build_qkv_proj(self.embed_dim,
                                       num_q_heads=self.num_heads,
                                       num_kv_heads=self.num_heads,
                                       head_size=self.head_dim,
                                       bias=bias,
                                       quant_config=quant_config,
                                       dtype=dtype,
                                       device=device)

        # output projection
        self.out_proj = build_rowwise_linear(self.embed_dim,
                                             self.embed_dim,
                                             bias=bias,
                                             quant_config=quant_config,
                                             dtype=dtype,
                                             device=device,
                                             is_tp=True)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run self-attention; ``attention_mask`` is passed to SDPA as ``attn_mask``."""
        q, k, v = self.qkv_proj.split_qkv(self.qkv_proj(hidden_states))

        # NOTE(review): presumably split_qkv yields (batch, seq, heads, head_dim);
        # swap to (batch, heads, seq, head_dim) for SDPA.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # q is pre-scaled, so SDPA itself is told to use scale=1.0
        q = q * self.scaling
        attn = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, scale=1.0)

        attn = attn.transpose(1, 2).flatten(-2, -1)
        return self.out_proj(attn)


class WhisperEncoderLayer(nn.Module):
    """Pre-norm Whisper encoder layer: self-attention then a two-layer MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:
        super().__init__()
        self.config = config
        quant_config = getattr(config, 'quantization_config', None)
        ffn_dim = config.encoder_ffn_dim

        self.act = ACT2FN[config.activation_function]
        self.embed_dim = config.d_model

        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            dtype=dtype,
            device=device,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, dtype=dtype, device=device)
        self.fc1 = build_colwise_linear(
            self.embed_dim,
            ffn_dim,
            bias=True,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
        )
        self.fc2 = build_rowwise_linear(
            ffn_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            dtype=dtype,
            device=device,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run one encoder layer.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        """
        # attention sub-block: norm -> attention -> residual add
        attn_out = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + attn_out

        # feed-forward sub-block: norm -> fc1 -> act -> fc2 -> residual add
        ffn_out = self.fc2(self.act(self.fc1(self.final_layer_norm(hidden_states))))
        return hidden_states + ffn_out


================================================
FILE: lmdeploy/pytorch/multimodal/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .data_type import MultiModalData

# Public API of the multimodal package.
__all__ = ['MultiModalData']


================================================
FILE: lmdeploy/pytorch/multimodal/data_type.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Union

from torch import Tensor

from lmdeploy.vl.constants import Modality

# Either a single tensor or a list of tensors.
NestedTensor = Union[Tensor, List[Tensor]]


@dataclass
class MultiModalData:
    """One multimodal item (by default an image) with its metadata.

    ``start``/``end`` are integer positions (presumably into the token
    sequence — confirm with callers); ``end`` defaults to ``start``.
    """
    data: NestedTensor
    start: int
    end: int = None
    meta: Dict[str, Any] = None

    modality: Modality = Modality.IMAGE

    def __post_init__(self):
        if self.end is None:
            self.end = self.start

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a new item with its tensors moved to ``device``.

        ``data`` (a tensor or a list of tensors) is moved. In ``meta``,
        tensors are moved, values with a ``to_device`` method are converted
        through it, and everything else is copied by reference. All other
        fields are carried over unchanged.
        """
        move_kwargs = dict(device=device, non_blocking=non_blocking)

        out_dict = {f.name: getattr(self, f.name) for f in fields(self) if f.name not in ('data', 'meta')}

        if isinstance(self.data, Tensor):
            out_dict['data'] = self.data.to(**move_kwargs)
        else:
            out_dict['data'] = [item.to(**move_kwargs) for item in self.data]

        moved_meta = None
        if self.meta is not None:
            moved_meta = {}
            for key, value in self.meta.items():
                if isinstance(value, Tensor):
                    value = value.to(**move_kwargs)
                elif hasattr(value, 'to_device'):
                    value = value.to_device(**move_kwargs)
                moved_meta[key] = value
        out_dict['meta'] = moved_meta

        return MultiModalData(**out_dict)


# Lists of MultiModalData grouped by a string key (presumably a modality name — TODO confirm).
MultiModalInputs = Dict[str, List[MultiModalData]]


================================================
FILE: lmdeploy/pytorch/nn/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# attention module is modified from:
# https://github.com/vllm-project/vllm/blob/main/vllm/attention/
from .activation import GeluAndMul, SiluAndMul  # noqa: F401
from .attention import Attention, FlashAttention  # noqa: F401
from .embedding import ParallelEmbedding  # noqa: F401
from .norm import LayerNorm, RMSNorm  # noqa: F401
from .rotary_embedding import ApplyRotaryEmb  # noqa: F401
from .rotary_embedding import RopeType  # noqa: F401
from .rotary_embedding import YarnParameters  # noqa: F401
from .rotary_embedding import build_rotary_embedding  # noqa: F401
from .rotary_embedding import build_rotary_embedding_from_config  # noqa: F401
from .rotary_embedding import build_rotary_params  # noqa: F401


================================================
FILE: lmdeploy/pytorch/nn/activation.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor, nn

from ..backends import OpType, get_backend


class SiluAndMul(nn.Module):
    """SiLU activation fused with an elementwise multiply, dispatched to
    the backend's ``SiluAndMul`` implementation."""

    def __init__(self, inplace: bool = True):
        super().__init__()
        self.impl = get_backend().get_layer_impl_builder(OpType.SiluAndMul).build(inplace)

    def forward(self, x: Tensor):
        """Apply the backend implementation to ``x``."""
        return self.impl.forward(x)


class GeluAndMul(nn.Module):
    """GELU activation fused with an elementwise multiply, dispatched to
    the backend's ``GeluAndMul`` implementation.

    ``approximate`` is forwarded to the backend builder.
    """

    def __init__(self, approximate: str = 'none'):
        super().__init__()
        self.impl = get_backend().get_layer_impl_builder(OpType.GeluAndMul).build(approximate)

    def forward(self, x: Tensor):
        """Apply the backend implementation to ``x``."""
        return self.impl.forward(x)


================================================
FILE: lmdeploy/pytorch/nn/attention.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from lmdeploy.pytorch.distributed import get_tp_world_rank

from ..backends import OpType, get_backend
from ..backends.attention import AttentionMetadata
from .utils import get_distribute_size


def _update_num_heads(num_heads: int, num_kv_heads: int):
    """Return this rank's share of ``(num_heads, num_kv_heads)`` for the
    'attn' tensor-parallel group."""
    world_size, rank = get_tp_world_rank('attn')
    return tuple(get_distribute_size(count, world_size, rank) for count in (num_heads, num_kv_heads))


class Attention(nn.Module):
    """Attention layer backed by the backend's ``PagedAttention``
    implementation, with heads split across the 'attn' TP group."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_size: int = None,
        alibi: bool = False,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        causal: bool = True,
        use_flash_mla: bool = False,
        learnable_sink: bool = False,
        block_sparse_size: int = 1,
        **kwargs,
    ):
        super().__init__()
        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        v_head_size = head_size if v_head_size is None else v_head_size
        # Global head count, kept because alibi slopes are built over all heads.
        self.origin_num_heads = num_heads
        num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)
        self.num_heads = num_heads

        impl_builder = get_backend().get_layer_impl_builder(OpType.PagedAttention)
        self.impl = impl_builder.build(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_size=v_head_size,
            alibi=alibi,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            causal=causal,
            use_flash_mla=use_flash_mla,
            learnable_sink=learnable_sink,
            block_sparse_size=block_sparse_size,
            **kwargs,
        )

        # Alibi slopes need the rank and the device, so they are built on the first forward.
        self.alibi_ready = not alibi

    def _lazy_init(self, device):
        """Build and install this rank's alibi slopes on first use."""
        if self.alibi_ready:
            return
        _, rank = get_tp_world_rank('attn')
        first_head = self.num_heads * rank
        slopes = self.impl.make_alibi_slopes(first_head,
                                             first_head + self.num_heads,
                                             self.origin_num_heads,
                                             alibi_scale=1,
                                             dtype=torch.float32,
                                             device=device)
        self.impl.set_alibi_slopes(slopes)
        self.alibi_ready = True

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        k_scales_zeros: torch.Tensor = None,
        v_scales_zeros: torch.Tensor = None,
        s_aux: torch.Tensor = None,
        nsa_indices: torch.Tensor = None,
        inplace: bool = True,
    ) -> torch.Tensor:
        """Run paged attention through the backend implementation.

        ``nsa_indices`` and ``s_aux`` are forwarded only when given; ``s_aux``
        is passed under the keyword ``learnable_sink``.
        """
        self._lazy_init(query.device)

        optional_args = (('nsa_indices', nsa_indices), ('learnable_sink', s_aux))
        extra = {name: arg for name, arg in optional_args if arg is not None}
        return self.impl.forward(
            query,
            key,
            value,
            k_cache,
            v_cache,
            attn_metadata=attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            inplace=inplace,
            **extra,
        )

    @staticmethod
    def update_meta_flashmla(attn_metadata: AttentionMetadata, num_attention_heads):
        """Delegate flash-MLA metadata update to the active backend."""
        get_backend().update_meta_flashmla(attn_metadata, num_attention_heads)


class FlashAttention(nn.Module):
    """Variable-length flash attention without a paged KV cache.

    A thin wrapper that builds the backend ``OpType.FlashAttention``
    implementation and fills in default sequence metadata on each call.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logit_softcapping: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        # KV heads default to the query head count; V head dim defaults to head_dim.
        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        v_head_dim = head_dim if v_head_dim is None else v_head_dim
        num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)

        builder = get_backend().get_layer_impl_builder(OpType.FlashAttention)
        self.impl = builder.build(
            num_heads=num_heads,
            head_dim=head_dim,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_dim=v_head_dim,
            causal=causal,
            sliding_window=sliding_window,
            logit_softcapping=logit_softcapping,
            **kwargs,
        )

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                q_start_loc: torch.Tensor,
                q_seqlens: torch.Tensor,
                kv_start_loc: torch.Tensor = None,
                kv_seqlens: torch.Tensor = None,
                max_q_seqlen: int = None) -> torch.Tensor:
        """Run flash attention over packed sequences.

        Args:
            query, key, value: packed inputs; the last two dims of ``query``
                are treated as (heads, head_dim) when deriving ``max_q_seqlen``.
            q_start_loc, q_seqlens: per-sequence query offsets and lengths.
            kv_start_loc, kv_seqlens: per-sequence KV offsets and lengths; if
                BOTH are omitted the query metadata is reused (self-attention).
                Supplying only one of them raises ``AssertionError``.
            max_q_seqlen: defaults to the product of all but the last two
                dims of ``query`` (an upper bound on any single sequence length).
        """
        if max_q_seqlen is None:
            per_token = query.size(-1) * query.size(-2)
            max_q_seqlen = query.numel() // per_token

        if kv_start_loc is None and kv_seqlens is None:
            kv_start_loc, kv_seqlens = q_start_loc, q_seqlens

        assert kv_start_loc is not None
        assert kv_seqlens is not None

        return self.impl.forward(
            query,
            key,
            value,
            q_start_loc=q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            max_q_seqlen=max_q_seqlen,
        )


================================================
FILE: lmdeploy/pytorch/nn/embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_dist_group, get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

# Default multiple that the vocabulary is rounded up to (via pad_vocab_size)
# before a tensor-parallel embedding splits it evenly across ranks.
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``.

    Example: ``pad_vocab_size(100, 64) == 128``; exact multiples are unchanged.
    """
    num_blocks = (vocab_size + pad_to - 1) // pad_to
    return num_blocks * pad_to


class ParallelEmbedding(nn.Module):
    """Token embedding whose vocabulary rows can be sharded across a TP group.

    When ``is_tp`` is set and the TP size of ``layer_type`` is > 1, the vocab
    is padded to a multiple of ``padding_size`` and split evenly, so this rank
    stores rows ``[start_index, end_index)`` and the backend embedding impl is
    called with ``all_reduce=True`` on ``tp_group``. Otherwise every rank
    stores the full, unpadded vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        padding_idx: int,
        dtype: torch.dtype = None,
        device: torch.device = None,
        is_tp: bool = False,
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
        layer_type: str = 'attn',
        force_dtype: torch.dtype = None,
    ):
        """
        Args:
            vocab_size: number of rows in the checkpoint embedding.
            hidden_size: embedding width.
            padding_idx: row zeroed after loading; negative values count from
                the end of the vocab; ``None`` disables it.
            dtype: output dtype of ``forward`` (and the weight dtype unless
                ``force_dtype`` is given).
            device: weight device (``create_weight`` falls back to 'cuda').
            is_tp: shard the vocab across the TP group of ``layer_type``.
            padding_size: multiple the vocab is padded to before sharding.
            layer_type: distributed layer type used to look up TP size/rank/group.
            force_dtype: storage dtype of the weight, overriding ``dtype``.
        """
        super().__init__()
        # Assigned after Module.__init__ so the Module attribute machinery is set up.
        self.dist_ctx = get_dist_manager().current_context()

        self.is_tp = is_tp
        self.vocab_size = vocab_size
        self.padding_size = padding_size
        if padding_idx is not None:
            if padding_idx < 0:
                padding_idx = vocab_size + padding_idx
            assert padding_idx >= 0 and padding_idx < vocab_size
        self.padding_idx = padding_idx

        dist_cfg = get_dist_manager().current_config()
        _, self.rank = get_tp_world_rank(layer_type)
        self.tp, _ = dist_cfg.get_tp_by_layer(layer_type)

        dist_group = get_dist_group(layer_type=layer_type)
        self.tp_group = dist_group.gpu_group

        if is_tp and self.tp > 1:
            # Per-rank row count of the padded vocabulary.
            self.vocab_size_padded = pad_vocab_size(self.vocab_size, self.padding_size)
            assert self.vocab_size_padded % self.tp == 0, \
                f'vocab_size_padded({self.vocab_size_padded}) must be divisible by tp({self.tp})'
            self.vocab_size_padded = self.vocab_size_padded // self.tp
        else:
            self.vocab_size_padded = self.vocab_size

        self.out_dtype = dtype
        self.start_index = self.rank * self.vocab_size_padded
        self.end_index = (self.rank + 1) * self.vocab_size_padded
        weight_dtype = force_dtype or dtype
        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, weight_dtype, device))
        self.weight.weight_loader = self.weight_loader

        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.Embedding)
        self.impl = builder.build(self.start_index, self.end_index)

        self.all_reduce = self.is_tp and self.tp > 1

    @staticmethod
    def create_weight(vocab_size: int, hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):
        """Create a zero-initialised, non-trainable ``(vocab_size, hidden_size)`` weight.

        ``dtype`` defaults to ``torch.float16`` and ``device`` to ``'cuda'``.
        """
        if dtype is None:
            dtype = torch.float16
        if device is None:
            device = 'cuda'
        weight = torch.nn.Parameter(torch.zeros((vocab_size, hidden_size), dtype=dtype, device=device),
                                    requires_grad=False)
        return weight

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row shard of the full checkpoint embedding into ``param``.

        Rows of ``param`` with no checkpoint counterpart (vocab padding) keep
        their zero initialisation.
        """
        num_rows = loaded_weight.shape[0]
        # Checkpoint rows inside this rank's [start_index, end_index) range.
        # With heavy padding (small vocab, large tp) a trailing rank may own only
        # padding rows, which yields a non-positive size that ``narrow`` rejects.
        shard_size = min(self.end_index, num_rows) - self.start_index
        if shard_size <= 0:
            return

        # Slice first so only this rank's rows are moved to the parameter's device.
        shard = loaded_weight.narrow(0, self.start_index, shard_size).to(param.device)
        param[:shard_size].data.copy_(shard)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load the checkpoint embedding and zero the ``padding_idx`` row if it is local."""
        if not self.all_reduce:
            default_weight_loader(param, loaded_weight)
            if self.padding_idx is not None:
                self.weight[self.padding_idx] = 0
        else:
            self._weight_loader_tp_rowwise(param, loaded_weight)
            # padding_idx is a global row id; only the owning rank zeroes it.
            if (self.padding_idx is not None and self.padding_idx >= self.start_index
                    and self.padding_idx < self.end_index):
                self.weight[self.padding_idx - self.start_index] = 0

    def forward(self, x: torch.Tensor):
        """Look up embeddings for ids ``x`` and cast to ``out_dtype`` when it is set."""
        embeddings = self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)
        if self.out_dtype is not None and embeddings.dtype != self.out_dtype:
            embeddings = embeddings.to(dtype=self.out_dtype)
        return embeddings


================================================
FILE: lmdeploy/pytorch/nn/eplb.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch


class EPLBDispatchInfo:
    """Opaque wrapper around a per-(ep_rank, layer) dlblas dispatch-info object.

    Built by ``EPLBManager.get_dispatch_info`` and unwrapped again by
    ``EPLBManager.topk_ids_logical_to_physical``.
    """

    def __init__(self, info) -> None:
        # The wrapped object is stored as-is; no validation or copying.
        self.info = info


class EPLBManager:
    """Process-wide facade over dlblas expert-parallel load balancing (EPLB).

    ``init_global_eplb_metadata`` must run first: it lazily imports
    ``dlblas.layers.moe.eplb`` and stores the module on ``EPLBManager.eplb``.
    The other accessors raise ``RuntimeError`` if called before that.
    """
    # The ``dlblas.layers.moe.eplb`` module once initialised, else None.
    eplb = None

    @classmethod
    def _get_eplb(cls):
        """Return the initialised eplb module or raise a descriptive error."""
        eplb = EPLBManager.eplb
        if eplb is None:
            raise RuntimeError('EPLBManager is not initialized; call init_global_eplb_metadata() first.')
        return eplb

    @classmethod
    def init_global_eplb_metadata(cls, ep_size: int, num_routed_experts: int, num_hidden_layers: int):
        """Import dlblas eplb and initialise its global metadata (requires ep_size > 1)."""
        assert ep_size > 1, 'eplb requires ep_size > 1'
        from dlblas.layers.moe import eplb
        EPLBManager.eplb = eplb
        eplb.init_global_eplb_metadata(ep_size=ep_size,
                                       num_routed_experts=num_routed_experts,
                                       num_hidden_layers=num_hidden_layers)

    @classmethod
    def num_physical_experts(cls) -> int:
        """Number of physical experts reported by the global eplb metadata."""
        return cls._get_eplb().get_global_eplb_metadata().num_physical_experts()

    @classmethod
    def topk_ids_logical_to_physical(cls, topk_ids: torch.Tensor, eplb_dispatch_info: 'EPLBDispatchInfo'):
        """Map logical expert ids to physical ones using the wrapped dispatch info."""
        return cls._get_eplb().topk_ids_logical_to_physical(topk_ids=topk_ids, info=eplb_dispatch_info.info)

    @classmethod
    def get_dispatch_info(cls, ep_rank, layer_idx) -> 'EPLBDispatchInfo':
        """Create dlblas dispatch info for ``(ep_rank, layer_idx)`` wrapped in EPLBDispatchInfo."""
        info = cls._get_eplb().EPLBDispatchInfo.init_new(ep_rank=ep_rank, layer_idx=layer_idx)
        return EPLBDispatchInfo(info)


================================================
FILE: lmdeploy/pytorch/nn/gated_delta.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Sequence, Tuple

import torch
from torch import nn
from torch.profiler import record_function

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader
from lmdeploy.utils import get_logger

# Package logger; used in this module to report a failed best-effort patch of
# the fla gated-norm kernel autotuning.
logger = get_logger('lmdeploy')


def _drop_nb_autotune_key():
    """Best-effort: remove 'NB' from fla's layer_norm_gated_fwd_kernel autotune keys."""
    try:
        from fla.modules.fused_norm_gate import layer_norm_gated_fwd_kernel
        autotune_keys = layer_norm_gated_fwd_kernel.fn.keys
        if 'NB' in autotune_keys:
            autotune_keys.remove('NB')
    except Exception:
        # Deliberately non-fatal: fla internals may differ between versions.
        logger.debug('patch layer_norm_gated_fwd_kernel autotuning failed.')


def build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs):
    """Build fla's ``FusedRMSNormGated`` after patching its kernel autotune keys.

    Removing 'NB' from the autotune key list is meant to avoid unwanted
    kernel specialisation (per the original author's note).
    """
    # TODO: use a custom kernel instead of the fla implementation.
    from fla.modules import FusedRMSNormGated
    _drop_nb_autotune_key()
    return FusedRMSNormGated(hidden_size, eps=eps, **kwargs)


class GatedDeltaMeta:
    """Per-forward metadata shared by the gated-delta conv and recurrent kernels."""

    def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tensor, attn_metadata: Any):
        """
        Args:
            num_tokens: total number of packed tokens.
            conv_kernel_size: width of the causal conv window.
            state_ids: cache slot per sequence; negative means "no valid state".
            attn_metadata: provides ``is_decoding``, ``cu_seqlens_q`` and ``q_seqlens``.
        """
        cu_seqlens = attn_metadata.cu_seqlens_q
        dev = cu_seqlens.device

        self.num_tokens = num_tokens
        self.is_decoding = attn_metadata.is_decoding
        self.cu_seqlens = cu_seqlens

        # seq_idx: (1, num_tokens), the owning sequence index of every packed token.
        q_lens = attn_metadata.q_seqlens
        seq_ids = torch.arange(0, q_lens.numel(), dtype=torch.int32, device=dev)
        self.seq_idx = torch.repeat_interleave(seq_ids, q_lens, output_size=num_tokens).unsqueeze(0)

        # conv_idx: (batch, conv_kernel_size), packed positions of each sequence's
        # last conv_kernel_size tokens, clamped at 0.
        offsets = torch.arange(-conv_kernel_size, 0, device=dev)
        seq_ends = cu_seqlens[1:].unsqueeze(-1)
        self.conv_idx = (seq_ends + offsets.unsqueeze(0)).clamp_min(0)

        self.conv_state_indices = state_ids.to(torch.int32)
        # Slot 0 is treated as a dummy state shared by every invalid (negative) id.
        self.valid_state = state_ids >= 0
        self.state_ids = state_ids.clamp(0)


class CausalConv1dFunc:
    """Causal 1-D convolution dispatcher over the backend ``CausalConv1d`` impl.

    Prefill uses ``conv1d_func`` (full-sequence kernel plus conv-state refresh);
    decoding uses ``conv1d_update`` (step kernel that updates ``conv_state`` via
    ``conv_state_indices``).
    """

    def __init__(self, activation: str = 'silu'):
        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.CausalConv1d)
        impl = builder.build()
        self.causal_conv1d_fn = impl.conv1d_fn
        self.causal_conv1d_update = impl.update_fn
        self.activation = activation

    def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor,
                    gated_delta_meta: 'GatedDeltaMeta'):
        """Run the causal conv over a packed prefill batch and refresh ``conv_state``.

        Args:
            x: 3-D packed input whose last two dims are (tokens, channels); its
                leading dim must broadcast to the batch (presumably 1 — confirm).
            weight: conv weight; a 3-D ``(C, 1, K)`` weight is squeezed to ``(C, K)``.
            bias: optional conv bias.
            conv_state: cache indexed along dim 0 by ``state_ids``; each slot
                holds (channels, kernel_size) values.
            gated_delta_meta: supplies ``seq_idx``, ``conv_idx`` and ``state_ids``.

        Returns:
            ``(out, conv_state)``: ``out`` has the same layout as ``x``;
            ``conv_state`` is updated in place and returned.
        """
        seq_idx = gated_delta_meta.seq_idx
        conv_idx = gated_delta_meta.conv_idx
        state_ids = gated_delta_meta.state_ids

        assert x.dim() == 3
        x = x.transpose(-2, -1)
        if weight.dim() == 3:
            assert weight.size(1) == 1
            weight = weight[:, 0]

        # Refresh the conv state: gather each sequence's last kernel_size input
        # columns and scatter them into that sequence's cache slot.
        # TODO: find efficient way to fill conv state without gather + scatter
        final_state = conv_state.index_select(0, state_ids)
        # One row per sequence. This must NOT be conv_state.size(0) (the number
        # of cache slots): gather needs the expanded input to be at least as
        # large as the index along dim 0, and invalid sequences all share dummy
        # slot 0, so there can be more sequences than slots.
        batch_size = conv_idx.size(0)
        conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1)
        torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=final_state)
        conv_state = conv_state.index_copy_(0, state_ids, final_state)

        out = self.causal_conv1d_fn(
            x,
            weight,
            bias,
            seq_idx,
            return_final_states=False,
            activation=self.activation,
        )

        out = out.transpose(-2, -1)
        return out, conv_state

    def conv1d_update(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        conv_state: torch.Tensor,
        conv_state_indices: torch.Tensor,
    ):
        """Single decoding step: drop ``x``'s leading dim, run the update kernel, restore it."""
        if weight.dim() == 3:
            assert weight.size(1) == 1
            weight = weight[:, 0]
        out = self.causal_conv1d_update(x[0],
                                        conv_state,
                                        weight,
                                        bias,
                                        activation=self.activation,
                                        conv_state_indices=conv_state_indices)
        return out[None], conv_state

    @record_function('causal_conv1d')
    def __call__(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        conv_state: torch.Tensor,
        gated_delta_meta: 'GatedDeltaMeta',
    ):
        """Dispatch to the decode or prefill path based on ``gated_delta_meta.is_decoding``."""
        if gated_delta_meta.is_decoding:
            conv_state_indices = gated_delta_meta.conv_state_indices
            return self.conv1d_update(x, weight, bias, conv_state, conv_state_indices)
        return self.conv1d_func(x, weight, bias, conv_state, gated_delta_meta=gated_delta_meta)


class GatedDelta:
    """Gated delta-rule linear-attention dispatcher.

    Uses the backend ``GatedDeltaRule`` impl: the chunked kernel for prefill
    and the fused recurrent kernel for decoding.
    """

    def __init__(self, use_qk_l2norm_in_kernel: bool = True):
        impl_builder = get_backend().get_layer_impl_builder(OpType.GatedDeltaRule)
        self.impl = impl_builder.build()
        self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        recurrent_state: torch.Tensor,
        gated_delta_meta: GatedDeltaMeta,
    ):
        """Apply the delta rule; returns ``(core_attn_out, last_recurrent_state)``."""
        decoding = gated_delta_meta.is_decoding
        cu_seqlens = gated_delta_meta.cu_seqlens
        slot_ids = gated_delta_meta.state_ids

        if decoding:
            # (1, seqlen, ...) -> (seqlen, 1, ...): each token is fed as its own
            # length-1 step with its own state slot.
            out, final_state = self.impl.fused_recurrent_gated_delta_rule(
                query[0, :, None],
                key[0, :, None],
                value[0, :, None],
                g=g[0, :, None],
                beta=beta[0, :, None],
                initial_state=recurrent_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,
                state_indices=slot_ids,
            )
            # (seqlen, 1, ...) -> (1, seqlen, ...)
            return out[None, :, 0], final_state

        out, final_state = self.impl.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            state_indices=slot_ids,
            output_final_state=True,
            use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,
            cu_seqlens=cu_seqlens,
        )
        return out, final_state


class CausalConv1d(nn.Module):
    """Tensor-parallel causal conv1d over a fused (q, k, v) channel layout.

    Channels and groups are divided by the TP size; ``weight_loader`` shards
    each of the three ``split`` sections separately so every rank keeps
    matching q/k/v slices.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | Tuple[int],
        split: Sequence[int],
        groups: int = 1,
        bias: bool = True,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        world_size, rank = get_tp_world_rank()
        self.tp = world_size
        self.rank = rank
        assert len(split) == 3
        self.split = split

        weight, w_bias = self.make_weight(
            in_channels // world_size,
            out_channels // world_size,
            kernel_size=kernel_size,
            groups=groups // world_size,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.register_weight(weight, w_bias)
        self.causal_conv1d_func = CausalConv1dFunc(activation='silu')

    @staticmethod
    def make_weight(
        in_channels: int,
        out_channels: int,
        kernel_size: int | Tuple[int],
        groups: int = 1,
        bias: bool = True,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Allocate uninitialised ``(out, in // groups, k)`` weight and optional ``(out,)`` bias."""
        ksize = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
        weight = torch.empty((out_channels, in_channels // groups, ksize), device=device, dtype=dtype)
        w_bias = torch.empty((out_channels, ), device=device, dtype=dtype) if bias else None
        return weight, w_bias

    def register_weight(self, weight: torch.Tensor, w_bias: torch.Tensor | None = None):
        """Register ``weight``/``bias`` parameters and attach the sharding loader."""
        self.register_parameter('weight', nn.Parameter(weight))
        self.weight.weight_loader = self.weight_loader
        if w_bias is None:
            self.register_parameter('bias', None)
            return
        self.register_parameter('bias', nn.Parameter(w_bias))
        self.bias.weight_loader = self.weight_loader

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Split into q/k/v along dim 0, keep this rank's chunk of each, and reassemble."""
        q_part, k_part, v_part = loaded_weight.split(self.split, dim=0)
        local_parts = [part.chunk(self.tp, dim=0)[self.rank] for part in (q_part, k_part, v_part)]
        default_weight_loader(param, torch.cat(local_parts, dim=0))

    def forward(self, x: torch.Tensor, conv_state: torch.Tensor, gated_delta_meta: GatedDeltaMeta):
        """Apply the causal conv (prefill or decode) with this module's weight and bias."""
        return self.causal_conv1d_func(x, self.weight, self.bias, conv_state, gated_delta_meta=gated_delta_meta)


@record_function('gated_delta_load_state')
def load_state(past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):
    """Return the first two cached state tensors from ``past_key_value``.

    ``gated_delta_meta`` is currently unused. The result keeps the container
    type of ``past_key_value`` (slicing semantics).
    """
    num_states = 2
    return past_key_value[0:num_states]


================================================
FILE: lmdeploy/pytorch/nn/linear/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import torch
from torch import nn

from lmdeploy.pytorch.config import TPMode
from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.models.patch import get_build_model_context

from .awq import AwqLinear, MergedAwqLinear, QKVAwqLinear
from .blocked_fp8 import BlockedF8Linear, MergedBlockedF8Linear, QKVBlockedF8Linear
from .default import BaseLinear, MergedBaseLinear, QKVBaseLinear
from .lora import LoRA  # noqa: F401
from .w8a8 import MergedW8A8Linear, QKVW8A8Linear, W8A8Linear


def build_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    colwise: bool = True,
    is_tp: bool = False,
    quant_config: Dict = None,
    all_reduce: bool = True,
    tp_align_size: int = 1,
    dp_gather: bool = False,
    layer_type: str = 'attn',
    prefix: str = '',
) -> nn.Module:
    """Build a plain or quantized linear layer.

    When ``quant_config`` is given, the config actually used is the one from
    the build-model context, and its method for ``prefix`` selects the class:
    None -> BaseLinear, 'awq' -> AwqLinear, 'smooth_quant' -> W8A8Linear,
    'fp8' -> BlockedF8Linear. Any other method raises ``RuntimeError``.
    ``all_reduce`` is forced off when ``is_tp`` is falsy.
    """
    if layer_type is None:
        layer_type = 'attn'
    if not is_tp:
        all_reduce = False

    quant_method = None
    if quant_config is not None:
        quant_config = get_build_model_context().quant_config
        quant_method = quant_config.get_quant_method(prefix)

    if dp_gather and quant_method is not None:
        assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}')

    if quant_method is None:
        return BaseLinear(
            in_features,
            out_features,
            bias=bias,
            dtype=dtype,
            device=device,
            colwise=colwise,
            is_tp=is_tp,
            all_reduce=all_reduce,
            tp_align_size=tp_align_size,
            dp_gather=dp_gather,
            layer_type=layer_type,
        )
    if quant_method == 'awq':
        return AwqLinear(
            in_features,
            out_features,
            w_bit=quant_config.bits,
            group_size=quant_config.group_size,
            bias=bias,
            device=device,
            colwise=colwise,
            is_tp=is_tp,
            all_reduce=all_reduce,
            layer_type=layer_type,
        )
    if quant_method == 'smooth_quant':
        return W8A8Linear(
            in_features,
            out_features,
            bias=bias,
            dtype=dtype,
            device=device,
            colwise=colwise,
            is_tp=is_tp,
            all_reduce=all_reduce,
            quant_dtype=quant_config.quant_dtype,
            layer_type=layer_type,
        )
    if quant_method == 'fp8':
        return BlockedF8Linear(
            in_features,
            out_features,
            bias=bias,
            fp8_dtype=quant_config.quant_dtype,
            scale_fmt=quant_config.scale_fmt,
            dtype=dtype,
            device=device,
            colwise=colwise,
            is_tp=is_tp,
            all_reduce=all_reduce,
            dp_gather=dp_gather,
            layer_type=layer_type,
        )
    raise RuntimeError(f'Unsupported quant method: {quant_method}')


def build_colwise_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    is_tp: bool = False,
    tp_align_size: int = 1,
    quant_config: Dict = None,
    dp_disable_tp: bool = False,
    dp_gather: bool = False,
    check_dist: bool = True,
    layer_type: str = 'attn',
    prefix: str = '',
) -> nn.Module:
    """Build a column-parallel linear layer (never all-reduces).

    With ``check_dist``, TP is turned off when the layer's TP size is 1 or
    when ``dp_disable_tp`` is set and DP > 1; ``dp_gather`` is kept only for a
    TP-split layer in ``TPMode.DP_TP``.
    """
    if check_dist:
        dist_config = get_dist_manager().current_config()
        tp, tp_mode = dist_config.get_tp_by_layer(layer_type)

        if tp <= 1 or (dp_disable_tp and dist_config.dp > 1):
            is_tp = False
        if not (is_tp and tp_mode == TPMode.DP_TP):
            dp_gather = False

    return build_linear(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        dtype=dtype,
        device=device,
        colwise=True,
        is_tp=is_tp,
        quant_config=quant_config,
        all_reduce=False,
        tp_align_size=tp_align_size,
        dp_gather=dp_gather,
        layer_type=layer_type,
        prefix=prefix,
    )


def build_rowwise_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    is_tp: bool = False,
    tp_align_size: int = 1,
    quant_config: Dict = None,
    all_reduce: bool = True,
    dp_disable_tp: bool = False,
    check_dist: bool = True,
    layer_type: str = 'attn',
    prefix: str = '',
) -> nn.Module:
    """Build a row-parallel linear layer.

    With ``check_dist``, TP is turned off when the layer's TP size is 1 or
    when ``dp_disable_tp`` is set and DP > 1.
    """
    if check_dist:
        dist_config = get_dist_manager().current_config()
        tp, _ = dist_config.get_tp_by_layer(layer_type)
        if tp <= 1 or (dp_disable_tp and dist_config.dp > 1):
            is_tp = False

    return build_linear(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        dtype=dtype,
        device=device,
        colwise=False,
        is_tp=is_tp,
        quant_config=quant_config,
        all_reduce=all_reduce,
        tp_align_size=tp_align_size,
        layer_type=layer_type,
        prefix=prefix,
    )


def build_merged_colwise_linear(
    in_features: int,
    all_out_features: List[int],
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    quant_config: Dict = None,
    is_tp: bool = True,
    out_names: List[Any] = None,
    dp_gather: bool = False,
    check_dist: bool = True,
    layer_type: str = 'attn',
    prefix: str = '',
):
    """Build a column-parallel linear that fuses several outputs (e.g. gate/up).

    With ``check_dist``, TP is dropped when the layer's TP world size is 1, and
    ``dp_gather`` is kept only for a TP-split layer in ``TPMode.DP_TP``. This
    matches ``build_colwise_linear`` and ``build_gateup_linear``.
    The quantization class is chosen like ``build_linear``: None / 'awq' /
    'smooth_quant' / 'fp8'; anything else raises ``RuntimeError``.
    """
    if check_dist:
        if is_tp:
            is_tp = get_tp_world_rank(layer_type)[0] > 1
        if dp_gather:
            # Gathering over DP ranks is only meaningful for a TP-split layer in DP_TP mode.
            _, tp_mode = get_dist_manager().current_config().get_tp_by_layer(layer_type)
            dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False

    quant_method = None
    if quant_config is not None:
        # The config actually used comes from the build-model context.
        quant_config = get_build_model_context().quant_config
        quant_method = quant_config.get_quant_method(prefix)
    if dp_gather and quant_method is not None:
        assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}')

    if quant_method is None:
        return MergedBaseLinear(in_features=in_features,
                                all_out_features=all_out_features,
                                bias=bias,
                                dtype=dtype,
                                device=device,
                                is_tp=is_tp,
                                out_names=out_names,
                                dp_gather=dp_gather,
                                layer_type=layer_type)

    if quant_method == 'awq':
        return MergedAwqLinear(
            in_features,
            all_out_features=all_out_features,
            w_bit=quant_config.bits,
            group_size=quant_config.group_size,
            bias=bias,
            device=device,
            is_tp=is_tp,
            layer_type=layer_type,
        )
    if quant_method == 'smooth_quant':
        return MergedW8A8Linear(in_features=in_features,
                                all_out_features=all_out_features,
                                bias=bias,
                                dtype=dtype,
                                device=device,
                                is_tp=is_tp,
                                out_names=out_names,
                                quant_dtype=quant_config.quant_dtype,
                                layer_type=layer_type)
    elif quant_method == 'fp8':
        return MergedBlockedF8Linear(
            in_features=in_features,
            all_out_features=all_out_features,
            bias=bias,
            fp8_dtype=quant_config.quant_dtype,
            scale_fmt=quant_config.scale_fmt,
            dtype=dtype,
            device=device,
            is_tp=is_tp,
            out_names=out_names,
            dp_gather=dp_gather,
            layer_type=layer_type,
        )
    else:
        raise RuntimeError(f'Unsupported quant method: {quant_method}')


def build_qkv_proj(in_features: int,
                   num_q_heads: int,
                   num_kv_heads: int,
                   head_size: int,
                   head_size_v: int = None,
                   bias: bool = False,
                   quant_config: Dict = None,
                   dtype: Optional[torch.dtype] = None,
                   device: Optional[torch.device] = None,
                   is_tp: bool = True,
                   num_replicate_kv_heads: int = 1,
                   prefix: str = ''):
    """Build the fused QKV projection.

    TP is dropped when the attention TP size is 1, and ``head_size_v``
    defaults to ``head_size``. The quantization class is chosen like
    ``build_linear`` (None / 'awq' / 'smooth_quant' / 'fp8'); any other
    method raises ``RuntimeError``.
    """
    if get_dist_manager().current_config().attn_tp <= 1:
        is_tp = False

    quant_method = None
    if quant_config is not None:
        quant_config = get_build_model_context().quant_config
        quant_method = quant_config.get_quant_method(prefix)

    head_size_v = head_size if head_size_v is None else head_size_v

    if quant_method is None:
        return QKVBaseLinear(in_features=in_features,
                             num_q_heads=num_q_heads,
                             num_kv_heads=num_kv_heads,
                             head_size=head_size,
                             head_size_v=head_size_v,
                             bias=bias,
                             dtype=dtype,
                             device=device,
                             is_tp=is_tp,
                             num_replicate_kv_heads=num_replicate_kv_heads)
    if quant_method == 'awq':
        return QKVAwqLinear(in_features=in_features,
                            num_q_heads=num_q_heads,
                            num_kv_heads=num_kv_heads,
                            head_size=head_size,
                            head_size_v=head_size_v,
                            w_bit=quant_config.bits,
                            group_size=quant_config.group_size,
                            bias=bias,
                            device=device,
                            is_tp=is_tp,
                            num_replicate_kv_heads=num_replicate_kv_heads)
    if quant_method == 'smooth_quant':
        return QKVW8A8Linear(in_features=in_features,
                             num_q_heads=num_q_heads,
                             num_kv_heads=num_kv_heads,
                             head_size=head_size,
                             head_size_v=head_size_v,
                             bias=bias,
                             dtype=dtype,
                             device=device,
                             is_tp=is_tp,
                             num_replicate_kv_heads=num_replicate_kv_heads,
                             quant_dtype=quant_config.quant_dtype)
    if quant_method == 'fp8':
        return QKVBlockedF8Linear(in_features=in_features,
                                  num_q_heads=num_q_heads,
                                  num_kv_heads=num_kv_heads,
                                  head_size=head_size,
                                  head_size_v=head_size_v,
                                  bias=bias,
                                  fp8_dtype=quant_config.quant_dtype,
                                  scale_fmt=quant_config.scale_fmt,
                                  dtype=dtype,
                                  device=device,
                                  is_tp=is_tp,
                                  dp_gather=False,
                                  num_replicate_kv_heads=num_replicate_kv_heads)
    raise RuntimeError(f'Unsupported quant method: {quant_method}')


def build_o_proj(
    in_features: int,
    out_features: int,
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    is_tp: bool = False,
    tp_align_size: int = 1,
    quant_config: Dict = None,
    all_reduce: bool = True,
    prefix: str = '',
) -> nn.Module:
    """Build the attention output projection as a row-parallel 'attn' linear.

    TP is dropped when the attention TP size is 1; the distributed check in
    ``build_rowwise_linear`` is skipped because it is done here.
    """
    if get_dist_manager().current_config().attn_tp <= 1:
        is_tp = False

    return build_rowwise_linear(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        dtype=dtype,
        device=device,
        is_tp=is_tp,
        tp_align_size=tp_align_size,
        quant_config=quant_config,
        all_reduce=all_reduce,
        check_dist=False,
        layer_type='attn',
        prefix=prefix,
    )


def build_gateup_linear(
    in_features: int,
    all_out_features: List[int],
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    quant_config: Dict = None,
    is_tp: bool = True,
    out_names: List[Any] = None,
    dp_gather: bool = True,
    prefix: str = '',
):
    """Build the fused MLP gate/up projection as a merged column-parallel 'mlp' linear.

    TP is dropped when the MLP TP size is 1; ``dp_gather`` is kept only for a
    TP-split layer in ``TPMode.DP_TP``.
    """
    tp, tp_mode = get_dist_manager().current_config().get_tp_by_layer('mlp')
    if tp <= 1:
        is_tp = False
    if not (is_tp and tp_mode == TPMode.DP_TP):
        dp_gather = False

    return build_merged_colwise_linear(
        in_features=in_features,
        all_out_features=all_out_features,
        bias=bias,
        dtype=dtype,
        device=device,
        quant_config=quant_config,
        is_tp=is_tp,
        out_names=out_names,
        dp_gather=dp_gather,
        check_dist=False,
        layer_type='mlp',
        prefix=prefix,
    )


def build_down_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    is_tp: bool = False,
    tp_align_size: int = 1,
    quant_config: Dict = None,
    all_reduce: bool = True,
    prefix: str = '',
) -> nn.Module:
    """Build the down projection of an MLP.

    The layer is created by ``build_rowwise_linear`` with ``layer_type='mlp'``.
    ``is_tp`` is forced to ``False`` unless the MLP tp size (``mlp_tp``) of the
    current distributed config is greater than 1.

    Returns:
        nn.Module: The row-wise linear layer.
    """
    mlp_tp_size = get_dist_manager().current_config().mlp_tp
    if mlp_tp_size <= 1:
        # no MLP tensor parallelism configured
        is_tp = False

    linear_kwargs = dict(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        dtype=dtype,
        device=device,
        is_tp=is_tp,
        tp_align_size=tp_align_size,
        quant_config=quant_config,
        all_reduce=all_reduce,
        check_dist=False,
        layer_type='mlp',
        prefix=prefix,
    )
    return build_rowwise_linear(**linear_kwargs)


================================================
FILE: lmdeploy/pytorch/nn/linear/awq.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

from ..utils import chunk_aligned, get_distribute_size
from .base import LinearBase
from .utils import QKVMixin, check_qkv_split_layout


class AwqLinear(LinearBase):
    """W4a16 (AWQ) linear layer.

    Parameters created by ``create_weights`` (``elem_per_int = 32 // w_bit``):

    * ``qweight``: int32, ``(in_features, out_features // elem_per_int)``
    * ``scales``: ``(in_features // group_size, out_features)``
    * ``qzeros``: int32, ``(in_features // group_size, out_features // elem_per_int)``
    * ``bias`` (optional): ``(out_features, )``

    When tensor parallel, ``in_features``/``out_features`` are the local
    (per-rank) sizes computed by ``_get_io_features``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        w_bit: int,
        group_size: int,
        bias: bool,
        device: Optional[torch.device] = None,
        colwise: bool = True,
        is_tp: bool = False,
        all_reduce: bool = True,
        layer_type: str = 'attn',
    ):
        super().__init__(dtype=torch.float16,
                         device=device,
                         colwise=colwise,
                         is_tp=is_tp,
                         all_reduce=all_reduce,
                         layer_type=layer_type)
        if self.is_tp:
            in_features, out_features = self._get_io_features(in_features, out_features, w_bit, group_size, colwise)
        qweight, scales, qzeros, bias = self.create_weights(in_features, out_features, w_bit, group_size, bias,
                                                            self.dtype, self.device)
        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW4A16)
        self.impl = impl_builder.build(in_features,
                                       out_features,
                                       w_bit,
                                       group_size,
                                       bias is not None,
                                       dtype=scales.dtype)
        self.register_all_parameters(qweight, scales, qzeros, bias)

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size
        self.elem_per_int = 32 // w_bit

    def setup_loaders(self):
        """Setup weight loaders and tag every parameter with its ``_weight_type``."""
        self.qweight.weight_loader = self.weight_loader
        self.qweight._weight_type = 'qweight'
        self.scales.weight_loader = self.weight_loader
        self.scales._weight_type = 'scales'
        self.qzeros.weight_loader = self.weight_loader
        self.qzeros._weight_type = 'qzeros'
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader
            self.bias._weight_type = 'bias'

    def register_all_parameters(self,
                                qweight: torch.Tensor,
                                scales: torch.Tensor,
                                qzeros: torch.Tensor,
                                bias: Optional[torch.Tensor] = None):
        """Register all tensors as frozen parameters and attach their loaders."""
        qweight = torch.nn.Parameter(qweight, requires_grad=False)
        scales = torch.nn.Parameter(scales, requires_grad=False)
        qzeros = torch.nn.Parameter(qzeros, requires_grad=False)
        if bias is not None:
            bias = torch.nn.Parameter(bias, requires_grad=False)
        self.register_parameter('qweight', qweight)
        self.register_parameter('scales', scales)
        self.register_parameter('qzeros', qzeros)
        self.register_parameter('bias', bias)
        self.setup_loaders()

    def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool):
        """Return the local (in, out) features of this tp rank.

        The sharded dimension is aligned to ``max(32 // w_bit, group_size)`` so
        that packed ints and quantization groups are never split.
        """
        align = max(32 // w_bit, group_size)
        world_size, rank = self.get_tp_world_rank()
        if colwise:
            out_features = get_distribute_size(out_features, world_size, rank, align=align)
        else:
            in_features = get_distribute_size(in_features, world_size, rank, align=align)
        return in_features, out_features

    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Weight loader for colwise linear.

        Column-wise sharding splits the output dimension. ``bias`` and
        ``scales`` keep one column per output feature, while ``qweight`` and
        ``qzeros`` pack ``elem_per_int`` output features into each int32, so
        their alignment is divided by ``elem_per_int``.
        """
        align = max(self.elem_per_int, self.group_size)
        if loaded_weight.dim() == 1:
            # bias
            weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]
            return default_weight_loader(param, weight)

        # Dispatch on the tag set in ``setup_loaders``. Comparing the *full*
        # loaded shape against the *local* ``out_features`` is ambiguous, e.g.
        # when tp == elem_per_int a packed qweight has the same width as the
        # local out_features. The shape test is kept only as a fallback.
        weight_type = getattr(param, '_weight_type', None)
        if weight_type is not None:
            is_scales = weight_type == 'scales'
        else:
            is_scales = loaded_weight.size(1) == self.out_features
        if not is_scales:
            # qweight / qzeros: output features are packed into int32
            align = align // self.elem_per_int
        weight = chunk_aligned(loaded_weight, world_size, 1, align)[rank]
        return default_weight_loader(param, weight)

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Weight loader for rowwise linear.

        Row-wise sharding splits the input dimension (dim 0). ``qweight`` has
        one row per input feature, while ``scales`` and ``qzeros`` have one
        row per quantization group, so their alignment is divided by
        ``group_size``.
        """
        if loaded_weight.dim() == 1:
            # bias: rowwise partial outputs are reduced (summed) across the tp
            # group, so only rank 0 may carry the bias; other ranks load zeros.
            if rank != 0:
                loaded_weight = torch.zeros_like(loaded_weight)
            return default_weight_loader(param, loaded_weight)

        align = max(self.elem_per_int, self.group_size)
        # Same reasoning as the colwise loader: the loaded shape is the full
        # one while ``in_features`` is local, so prefer the explicit tag.
        weight_type = getattr(param, '_weight_type', None)
        if weight_type is not None:
            is_qweight = weight_type == 'qweight'
        else:
            is_qweight = loaded_weight.size(0) == self.in_features
        if not is_qweight:
            # scales / qzeros: one row per quantization group
            align = align // self.group_size
        weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]
        return default_weight_loader(param, weight)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``, sharding it when tensor parallel."""
        if not self.is_tp:
            return default_weight_loader(param, loaded_weight)

        world_size, rank = self.get_tp_world_rank()
        if self.colwise:
            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)
        else:
            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)

    def create_weights(self, in_features: int, out_features: int, w_bit: int, group_size: int, bias: bool,
                       dtype: torch.dtype, device: torch.device):
        """Create uninitialized qweight, scales, qzeros and (optional) bias tensors."""
        assert in_features % group_size == 0
        elem_per_int = 32 // w_bit
        assert out_features % elem_per_int == 0

        grouped_in_feats = in_features // group_size
        quant_out_feats = out_features // elem_per_int
        qweight = torch.empty((in_features, quant_out_feats), dtype=torch.int32, device=device)
        scales = torch.empty((grouped_in_feats, out_features), dtype=dtype, device=device)
        qzeros = torch.empty((grouped_in_feats, quant_out_feats), dtype=torch.int32, device=device)
        if bias:
            bias = torch.empty((out_features, ), dtype=dtype, device=device)
        else:
            bias = None
        return qweight, scales, qzeros, bias

    def update_weights(self):
        """Let the backend impl post-process the weights, then re-register them."""
        qweight, scales, qzeros, bias = self.impl.update_weights(self.qweight, self.scales, self.qzeros, self.bias)
        self.register_all_parameters(qweight, scales, qzeros, bias)

    def _forward_default(self, x, all_reduce, tp_sizes):
        """Default forward implement (``tp_sizes`` is unused here)."""
        return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce, group=self.tp_group)


class MergedAwqLinear(AwqLinear):
    """Merged awq linear.

    Several column-wise AWQ projections fused into one layer. Each
    sub-projection is identified by an entry of ``out_names`` and is loaded
    independently through ``weight_loader(param, loaded_weight, shard_id)``.
    """

    def __init__(self,
                 in_features: int,
                 all_out_features: List[int],
                 w_bit: int,
                 group_size: int,
                 bias: bool,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 out_names: Optional[List[int]] = None,
                 layer_type: str = 'attn'):
        # tp args are needed by ``_update_all_out_features`` before the base
        # constructor runs; the base ``init_tp_args`` call then becomes a no-op.
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)

        # Unsharded section sizes used by the spliters: scales/bias keep one
        # column per feature, qweight/qzeros pack ``32 // w_bit`` per int32.
        self.split_section_s = all_out_features
        elem_per_int = 32 // w_bit
        self.split_section_wz = [size // elem_per_int for size in all_out_features]

        # Local (per-rank) size of every sub-projection.
        all_out_features = self._update_all_out_features(all_out_features, w_bit, group_size)
        self.all_out_features = all_out_features
        if out_names is None:
            out_names = list(range(len(self.all_out_features)))
        assert len(out_names) == len(self.all_out_features)
        self.out_names_map = {name: idx for idx, name in enumerate(out_names)}
        out_features = sum(all_out_features)
        super().__init__(in_features,
                         out_features,
                         w_bit,
                         group_size,
                         bias,
                         device,
                         colwise=True,
                         is_tp=is_tp,
                         layer_type=layer_type)
        self.setup_loaders()

    def setup_loaders(self):
        """Setup weight loaders, spliters and ``_weight_type`` tags."""
        self.qweight.weight_loader = self.weight_loader
        self.qweight.weight_spliter = self.weight_spliter_wz
        self.qweight._weight_type = 'qweight'
        self.scales.weight_loader = self.weight_loader
        self.scales.weight_spliter = self.weight_spliter_s
        self.scales._weight_type = 'scales'
        self.qzeros.weight_loader = self.weight_loader
        self.qzeros.weight_spliter = self.weight_spliter_wz
        self.qzeros._weight_type = 'qzeros'
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader
            self.bias.weight_spliter = self.weight_spliter_s
            self.bias._weight_type = 'bias'

    def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool):
        """Get io features.

        Output features are already per-rank (see ``_update_all_out_features``),
        so no further sharding happens here.
        """
        return in_features, out_features

    def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int):
        """Return the local size of every sub-projection for this tp rank."""
        world_size, rank = self.get_tp_world_rank()
        align = max(32 // w_bit, group_size)
        return [get_distribute_size(out_feat, world_size, rank, align) for out_feat in all_out_features]

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Load the ``shard_id`` sub-projection into its slice of ``param``."""
        world_size, rank = self.get_tp_world_rank()
        shard_idx = self.out_names_map[shard_id]
        if loaded_weight.dim() == 1:
            # bias
            align = max(self.elem_per_int, self.group_size)
            param_w = param.data.split(self.all_out_features, 0)[shard_idx]
            weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]
            param_w.copy_(weight)
            # Done: without this return the bias was copied a second time below.
            return

        if param._weight_type in ['scales', 'bias']:
            # scales: one column per output feature
            align = max(self.elem_per_int, self.group_size)
            param_w = param.data.split(self.all_out_features, -1)[shard_idx]
        else:
            # qweight or qzeros: output features packed into int32
            align = max(self.elem_per_int, self.group_size) // self.elem_per_int
            quanted_out_feats = [feat // self.elem_per_int for feat in self.all_out_features]
            param_w = param.data.split(quanted_out_feats, 1)[shard_idx]

        weight = chunk_aligned(loaded_weight, world_size, -1, align)[rank]
        param_w.copy_(weight)

    def weight_spliter_wz(self, loaded_weight: torch.Tensor):
        """Split a merged qweight/qzeros tensor into per-projection tensors (dim 1)."""
        return loaded_weight.split(self.split_section_wz, dim=1)

    def weight_spliter_s(self, loaded_weight: torch.Tensor):
        """Split a merged scales/bias tensor into per-projection tensors (last dim)."""
        return loaded_weight.split(self.split_section_s, dim=-1)


class QKVAwqLinear(MergedAwqLinear, QKVMixin):
    """AWQ (w4a16) linear holding the fused q/k/v projections.

    Shards are addressed as ``'q'``, ``'k'`` and ``'v'``. When
    ``num_replicate_kv_heads > 1``, the k/v shards are split over
    ``world_size // num_replicate_kv_heads`` chunks, so consecutive ranks
    share (replicate) the same k/v chunk.
    """

    def __init__(self,
                 in_features: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 w_bit: int,
                 group_size: int,
                 bias: bool = False,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 num_replicate_kv_heads: int = 1):
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')
        QKVMixin.__init__(self,
                          num_q_heads=num_q_heads,
                          num_kv_heads=num_kv_heads,
                          head_size=head_size,
                          head_size_v=head_size_v,
                          num_replicate_kv_heads=num_replicate_kv_heads,
                          is_tp=is_tp,
                          tp=self.tp,
                          tp_rank=self.tp_rank)

        # scales/bias sections keep one column per feature; qweight/qzeros
        # sections are ``32 // w_bit`` times narrower (packed int32).
        pack_factor = 32 // w_bit
        self.qkv_split_section_s = self.qkv_split_section
        self.qkv_split_section_wz = [section // pack_factor for section in self.qkv_split_section_s]
        super().__init__(in_features,
                         self.get_qkv_out_feautures(),
                         w_bit=w_bit,
                         group_size=group_size,
                         bias=bias,
                         device=device,
                         is_tp=is_tp,
                         out_names=('q', 'k', 'v'),
                         layer_type='attn')

    def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int):
        """Out features from ``get_qkv_out_feautures`` are used as-is."""
        return all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's chunk of the ``shard_id`` projection into ``param``."""
        world_size, rank = self.get_tp_world_rank()
        shard_idx = self.out_names_map[shard_id]

        replicate = self.num_replicate_kv_heads
        if replicate > 1 and shard_id in ['k', 'v']:
            # k/v are duplicated when tp_size > num_kv_heads
            num_chunks, chunk_idx = world_size // replicate, rank // replicate
        else:
            num_chunks, chunk_idx = world_size, rank

        full_align = max(self.elem_per_int, self.group_size)

        if loaded_weight.dim() == 1:
            # bias
            target = param.data.split(self.all_out_features, 0)[shard_idx]
            target.copy_(chunk_aligned(loaded_weight, num_chunks, 0, full_align)[chunk_idx])
            return

        if param._weight_type in ['scales', 'bias']:
            align = full_align
            target = param.data.split(self.all_out_features, -1)[shard_idx]
        else:
            # qweight / qzeros are packed along the output dimension
            align = full_align // self.elem_per_int
            packed_sections = [feat // self.elem_per_int for feat in self.all_out_features]
            target = param.data.split(packed_sections, 1)[shard_idx]

        target.copy_(chunk_aligned(loaded_weight, num_chunks, -1, align)[chunk_idx])

    def _split_hgd_layout(self, loaded_weight: torch.Tensor, head_dim: int):
        """Split a head-grouped (``hgd``) qkv tensor along dim 1 into q, k, v.

        Dim 1 is viewed as ``(num_kv_heads, heads_per_group, head_dim)``; in
        each group the last two heads are k and v, the rest belong to q.
        """
        assert self.head_size == self.head_size_v
        num_kv_heads = self.qkv_split_section_s[-1] // self.head_size
        grouped = loaded_weight.unflatten(1, (num_kv_heads, -1, head_dim))
        q = grouped[:, :, :-2].flatten(1, 3)
        k = grouped[:, :, -2].flatten(1, 2)
        v = grouped[:, :, -1].flatten(1, 2)
        return q, k, v

    def weight_spliter_wz(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a packed qweight/qzeros qkv tensor into q, k, v."""
        check_qkv_split_layout(layout)
        if layout == 'default':
            return loaded_weight.split(self.qkv_split_section_wz, dim=1)
        if layout == 'hgd':
            return self._split_hgd_layout(loaded_weight, self.head_size // self.elem_per_int)
        raise RuntimeError(f'Unsupported layout: {layout}')

    def weight_spliter_s(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a scales/bias qkv tensor into q, k, v."""
        check_qkv_split_layout(layout)
        if layout == 'default':
            return loaded_weight.split(self.qkv_split_section_s, dim=-1)
        if layout == 'hgd':
            return self._split_hgd_layout(loaded_weight, self.head_size)
        raise RuntimeError(f'Unsupported layout: {layout}')

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a LoRA-B qkv tensor into q, k, v along dim 0."""
        return loaded_weight.split(self.qkv_split_section_s, dim=0)


================================================
FILE: lmdeploy/pytorch/nn/linear/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
from torch import nn

from lmdeploy.pytorch.config import TPMode
from lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_group, get_dist_manager, get_tp_world_rank,
                                          reduce_scatter_by_tp_sizes)
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager

from .utils import update_tp_args


class LinearForwardDPTP:
    """Chunked, overlapped linear forward used in ``TPMode.DP_TP``.

    Local tokens are processed in rounds. Each round asynchronously gathers a
    slice of the hidden states over ``gather_group``, runs ``gemm_func`` on
    the gathered slice and asynchronously reduce-scatters the result back
    into the caller's output buffer. The gather of the next round is issued
    before the gemm of the current round, so communication can overlap
    computation.
    """

    def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192):
        """Linear forward dp tp.

        Args:
            gemm_func: Called as ``gemm_func(hidden_states)`` on each gathered
                slice (see ``_gemm_and_reduce_scatter``).
            max_tokens_per_round: Base per-round token budget; it is scaled by
                ``attn_tp / mlp_tp / 2`` below.
        """
        self.gemm_func = gemm_func
        self.dist_ctx = get_dist_manager().current_context()
        self.dist_config = self.dist_ctx.dist_config
        self.tp = self.dist_config.mlp_tp
        self.attn_tp = self.dist_config.attn_tp

        tp_group = self.dist_ctx.mlp_tp_group
        self.rank = tp_group.rank
        # Index of this rank's entry in ``tp_sizes``; presumably one entry per
        # group of ``attn_tp`` ranks — NOTE(review): confirm.
        self.gather_rank = self.rank // self.attn_tp
        self.gather_group = tp_group.gpu_gather_group
        self.tp_group = tp_group.gpu_group
        # NOTE(review): the trailing ``// 2`` presumably budgets for two rounds
        # in flight (current + prefetched); confirm.
        self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2

    def all_gather(self, hidden_states: torch.Tensor, tp_sizes: List[int]):
        """Asynchronously gather ``hidden_states`` over ``gather_group``.

        Returns the gathered tensor and the async work handle; the handle is
        waited on in ``_gemm_and_reduce_scatter`` before the tensor is used.
        """
        # NOTE(review): ``dist`` is ``torch.distributed`` in this module, which is
        # not known to provide ``gather_by_tp_sizes``; the ``gather_by_tp_sizes``
        # imported from ``lmdeploy.pytorch.distributed`` may have been intended.
        # Verify it accepts ``async_op=True`` before switching.
        hidden_states, handle = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
        return hidden_states, handle

    def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):
        """Asynchronously reduce-scatter ``hidden_states`` into ``out_states``.

        ``hidden_states`` is split along dim -2 by ``tp_sizes``; each segment is
        repeated ``attn_tp`` times to form the per-rank input list, and this
        rank's entry is replaced by ``out_states`` (pre-filled with its own
        segment). Returns ``out_states`` and the async work handle.
        """
        hidden_states_list = list(hidden_states.split(tp_sizes, -2))
        cur_out_states = hidden_states_list[self.gather_rank]
        out_states.copy_(cur_out_states)
        hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)]
        hidden_states_list[self.rank] = out_states
        handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True)
        return out_states, handle

    def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, output_states: torch.Tensor, tp_sizes: List[int],
                                 handle: dist.Work):
        """Wait for the gather of this round, run the gemm, then reduce-scatter."""
        handle.wait()
        cur_out = self.gemm_func(hidden_states)
        return self.reduce_scatter(cur_out, output_states, tp_sizes)

    def forward(self, hidden_states: torch.Tensor):
        """forward.

        Returns a tensor shaped like ``hidden_states`` that is filled by the
        reduce-scatter of every round; all handles are waited on before
        returning.
        """

        def __slice_tensor(tensor: torch.Tensor, slice_size: int):
            """Split ``tensor`` into its first ``slice_size`` rows and the rest."""
            cur_tensor = tensor[:slice_size]
            tensor = tensor[slice_size:]
            return cur_tensor, tensor

        def __slice_and_gather():
            """Take the next round of tokens and start gathering it.

            Consumes up to ``max_tokens_per_round`` tokens per entry of
            ``tp_sizes`` and advances ``hidden_states``/``output_states`` past
            this rank's slice (mutates the enclosing variables).
            """
            nonlocal hidden_states, tp_sizes, output_states
            cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round)
            tp_sizes -= cur_tp_sizes
            cur_tp_sizes = cur_tp_sizes.tolist()

            slice_size = cur_tp_sizes[self.gather_rank]
            cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size)
            cur_output, output_states = __slice_tensor(output_states, slice_size)

            # all gather
            cur_hidden_states, handle = self.all_gather(cur_hidden_states, cur_tp_sizes)
            return dict(hidden_states=cur_hidden_states, output_states=cur_output, handle=handle, tp_sizes=cur_tp_sizes)

        step_ctx = get_step_ctx_manager().current_context()
        tp_sizes = step_ctx.dp_meta.moe_tp_sizes
        tp_sizes = torch.tensor(tp_sizes)
        max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round)

        # ``output_states`` is re-sliced each round; keep the full buffer to return.
        output_states = torch.empty_like(hidden_states)
        return_states = output_states

        # pre
        cur_inputs = __slice_and_gather()
        handles = []

        # main loop: start gathering round i + 1 before computing round i
        while tp_sizes.sum() > 0:
            next_inputs = __slice_and_gather()
            _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
            handles.append(handle)
            cur_inputs = next_inputs

        # post
        _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
        handles.append(handle)
        for handle in handles:
            handle.wait()
        return return_states


class LinearBase(nn.Module):
    """Base class for linear layers.

    Handles tensor-parallel bookkeeping (``init_tp_args``), LoRA adapters and
    the dispatch between the default forward and the ``TPMode.DP_TP``
    forward. Subclasses implement ``_forward_default`` and ``update_weights``.
    """

    def __init__(
        self,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        colwise: bool = True,
        is_tp: bool = False,
        all_reduce: bool = True,
        tp_align_size: int = 1,
        dp_gather: bool = False,
        layer_type: str = 'attn',
    ):
        super().__init__()
        self.init_tp_args(is_tp, all_reduce, colwise, layer_type)
        self.colwise = colwise
        self.tp_align_size = tp_align_size
        self.dp_gather = dp_gather
        if device is None:
            device = torch.device('cpu')
        if dtype is None:
            dtype = torch.float16
        self.device = device
        self.dtype = dtype
        self.layer_type = layer_type

        self.lora_adapters = nn.ModuleDict()

    def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str):
        """Initialize tp attributes once.

        Merged subclasses call this before the base constructor; later calls are
        no-ops thanks to the ``_tp_args_initialized`` flag.
        """
        if getattr(self, '_tp_args_initialized', False):
            return
        is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise, layer_type=layer_type)
        self.is_tp = is_tp
        self.all_reduce = all_reduce
        if is_tp:
            dist_cfg = get_dist_manager().current_config()
            _, rank = get_tp_world_rank(layer_type)
            tp, tp_mode = dist_cfg.get_tp_by_layer(layer_type)
            self.tp_rank = rank
            self.tp = tp
            self.tp_mode = tp_mode
            dist_group = get_dist_group(layer_type=layer_type)
            self.tp_group = dist_group.gpu_group
            self.gather_group = dist_group.gpu_gather_group
        else:
            self.tp_rank = 0
            self.tp = 1
            self.tp_mode = TPMode.DEFAULT
            self.tp_group = None
            self.gather_group = None

        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:

            def _gemm_func(x):
                # ``LinearForwardDPTP`` stores this function as a plain attribute
                # and calls it with the hidden states only, so ``self`` must be
                # captured from the enclosing scope rather than taken as a
                # parameter (an unbound ``(self, x)`` function raised TypeError).
                out = self._forward_default(x, False, None)

                for lora_adapter in self.lora_adapters.values():
                    out = lora_adapter(x, out)
                return out

            self.linear_dptp_forward = LinearForwardDPTP(_gemm_func)

        self._tp_args_initialized = True

    def get_tp_world_rank(self):
        """Return ``(tp, tp_rank)`` of this layer."""
        assert hasattr(self, 'tp') and hasattr(self, 'tp_rank'), 'Please run init_tp_args first.'
        return self.tp, self.tp_rank

    def update_weights(self):
        """Update weights."""
        raise NotImplementedError('This method should be implemented in subclasses.')

    def _forward_default(self, x, all_reduce: bool, tp_sizes: List[int]):
        """Default forward implement."""
        raise NotImplementedError('This method should be implemented in subclasses.')

    def _forward_lora(self, x, tp_sizes: Optional[List[int]] = None):
        """Forward with LoRA.

        The base output is computed without reduction so that every adapter
        adds its contribution first; the reduction (all-reduce, or
        reduce-scatter by ``tp_sizes`` in DP_TP mode) is applied afterwards.
        """
        out = self._forward_default(x, False, tp_sizes)

        for lora_adapter in self.lora_adapters.values():
            out = lora_adapter(x, out)
        if self.all_reduce:
            if self.tp_mode == TPMode.DP_TP:
                out = reduce_scatter_by_tp_sizes(out, self.tp_rank, tp_sizes, group=self.tp_group)
            else:
                dist.all_reduce(out, group=self.tp_group)
        return out

    def _forward_dp_tp(self, x):
        """Forward in ``TPMode.DP_TP``.

        Layers that both gather their input and reduce their output use the
        chunked, overlapped ``LinearForwardDPTP`` path.
        """
        if self.dp_gather and self.all_reduce:
            return self.linear_dptp_forward.forward(x)

        step_ctx = get_step_ctx_manager().current_context()
        dp_meta = step_ctx.dp_meta
        tp_sizes = dp_meta.tp_sizes

        if self.dp_gather:
            x = gather_by_tp_sizes(x, tp_sizes, group=self.gather_group)

        if len(self.lora_adapters) == 0:
            return self._forward_default(x, self.all_reduce, tp_sizes)
        else:
            return self._forward_lora(x, tp_sizes)

    def forward(self, x):
        """Forward of linear layer."""
        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:
            return self._forward_dp_tp(x)

        if len(self.lora_adapters) == 0:
            return self._forward_default(x, self.all_reduce, None)
        else:
            return self._forward_lora(x)


================================================
FILE: lmdeploy/pytorch/nn/linear/blocked_fp8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.config import TPMode
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

from ..quant_utils import quant_blocked_fp8
from ..utils import div_up, get_distribute_size
from .base import LinearBase
from .utils import QKVMixin, check_qkv_split_layout


class BlockedF8Linear(LinearBase):
    """Blocked fp8 linear.

    Holds an fp8 ``weight`` of shape ``(out_features, in_features)`` and a
    float32 ``weight_scale_inv`` with one entry per ``block_size x block_size``
    block (``div_up(out, 128), div_up(in, 128)``), plus an optional bias.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale_fmt: Optional[str] = None,
        colwise: bool = True,
        is_tp: bool = False,
        all_reduce: bool = True,
        dp_gather: bool = False,
        layer_type: str = 'attn',
    ):
        super().__init__(dtype=dtype,
                         device=device,
                         colwise=colwise,
                         is_tp=is_tp,
                         all_reduce=all_reduce,
                         dp_gather=dp_gather,
                         layer_type=layer_type)
        self.block_size = 128
        self.fp8_dtype = fp8_dtype
        self.scale_fmt = scale_fmt
        if self.is_tp:
            in_features, out_features = self._get_io_features(in_features, out_features, colwise)
        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearBlockedF8)
        # ``bias`` is still the boolean flag here (the tensor is created below),
        # so ``bias is not None`` would always be True; use the flag's truth
        # value, matching how ``create_weights`` interprets it.
        self.impl = impl_builder.build(in_features,
                                       out_features,
                                       block_size=128,
                                       bias=bool(bias),
                                       dtype=self.dtype)
        self.impl.set_scale_fmt(scale_fmt)
        weight, weight_scale_inv, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)
        self.register_all_parameters(weight, weight_scale_inv, bias)

        self.in_features = in_features
        self.out_features = out_features

    def setup_loaders(self):
        """Setup weight loaders (``weight`` quantizes non-fp8 inputs on load)."""
        self.weight.weight_loader = self.weight_loader_with_quant
        self.weight_scale_inv.weight_loader = self.weight_loader
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader

    def register_all_parameters(self,
                                weight: torch.Tensor,
                                weight_scale_inv: torch.Tensor,
                                bias: Optional[torch.Tensor] = None):
        """Register all tensors as frozen parameters and attach their loaders."""
        weight = torch.nn.Parameter(weight, requires_grad=False)
        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
        if bias is not None:
            bias = torch.nn.Parameter(bias, requires_grad=False)
        self.register_parameter('weight', weight)
        self.register_parameter('weight_scale_inv', weight_scale_inv)
        self.register_parameter('bias', bias)
        self.setup_loaders()

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return the local (in, out) features of this tp rank."""
        world_size, rank = self.get_tp_world_rank()
        if colwise:
            out_features = get_distribute_size(out_features, world_size, rank)
        else:
            in_features = get_distribute_size(in_features, world_size, rank)
        return in_features, out_features

    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Weight loader for colwise linear: shard dim 0 (output features)."""
        weight = loaded_weight.chunk(world_size, 0)[rank]
        return default_weight_loader(param, weight)

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Weight loader for rowwise linear: shard dim 1 (input features)."""
        if loaded_weight.dim() == 2:
            loaded_weight = loaded_weight.to(param.device)
            weight = loaded_weight.chunk(world_size, 1)[rank]
            return default_weight_loader(param, weight)
        else:
            # bias: partial outputs are reduced across ranks, so only rank 0
            # keeps the bias and the other ranks load zeros.
            if rank != 0:
                loaded_weight = torch.zeros_like(loaded_weight)
            return default_weight_loader(param, loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``, sharding it when tensor parallel."""
        if not self.is_tp:
            return default_weight_loader(param, loaded_weight)

        world_size, rank = self.get_tp_world_rank()
        if self.colwise:
            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)
        else:
            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)

    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Weight loader that quantizes a non-fp8 weight into ``weight`` and ``weight_scale_inv``."""
        if loaded_weight.dtype != param.dtype:
            # quant loaded weight
            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),
                                                        param.dtype,
                                                        self.block_size,
                                                        scale_fmt=self.scale_fmt)
            self.weight_loader(self.weight, quanted_weight)
            self.weight_loader(self.weight_scale_inv, scaling)
        else:
            return self.weight_loader(param, loaded_weight)

    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):
        """Create uninitialized weight, block scales and (optional) bias tensors."""
        weight = torch.empty((out_features, in_features), dtype=self.fp8_dtype, device=device)
        weight_scale_inv = torch.empty((div_up(out_features, self.block_size), div_up(in_features, self.block_size)),
                                       dtype=torch.float32,
                                       device=device)
        if bias:
            bias = torch.empty((out_features, ), dtype=dtype, device=device)
        else:
            bias = None
        return weight, weight_scale_inv, bias

    def update_weights(self):
        """Let the backend impl post-process the weights, then re-register them."""
        weight, weight_scale_inv, bias = self.impl.update_weights(self.weight, self.weight_scale_inv, self.bias)
        self.register_all_parameters(weight, weight_scale_inv, bias)

    def _forward_default(self, x, all_reduce, tp_sizes):
        """Default forward implement; DP_TP mode also passes rank and scatter sizes."""
        if self.tp_mode == TPMode.DP_TP:
            rank = self.tp_rank
            return self.impl.forward(x,
                                     self.weight,
                                     self.weight_scale_inv,
                                     self.bias,
                                     all_reduce,
                                     group=self.tp_group,
                                     rank=rank,
                                     scatter_size=tp_sizes)
        else:
            return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, group=self.tp_group)


class MergedBlockedF8Linear(BlockedF8Linear):
    """Blocked fp8 linear whose output concatenates several independently loaded sections.

    The sections (named by ``out_names``) are stored back to back along dim 0 of ``weight``.
    ``weight_scale_inv`` uses the same layout with every section length divided by
    ``block_size``. Sections flagged in ``replicate`` are kept whole on every tp rank; the other
    sections are sharded across the tp group.
    """

    def __init__(self,
                 in_features: int,
                 all_out_features: List[int],
                 bias: bool,
                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,
                 scale_fmt: Optional[str] = None,
                 replicate: Optional[List[bool]] = None,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 out_names: Optional[List[int]] = None,
                 dp_gather: bool = False,
                 layer_type: str = 'attn'):
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)
        if replicate is None:
            replicate = tuple(False for _ in all_out_features)
        self.block_size = 128
        # unsharded section sizes: used to split fused checkpoint tensors
        self.split_section = all_out_features
        self.scale_split_section = [section // self.block_size for section in self.split_section]
        # per-rank section sizes: used to address each section inside the local parameters
        all_out_features = self._update_all_out_features(all_out_features, replicate)
        self.all_out_features = all_out_features
        self.replicate = replicate
        if out_names is None:
            out_names = torch.arange(len(self.all_out_features)).tolist()
        assert len(out_names) == len(self.all_out_features)
        self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names))
        out_features = sum(all_out_features)
        super().__init__(in_features,
                         out_features,
                         bias,
                         dtype,
                         device,
                         fp8_dtype=fp8_dtype,
                         scale_fmt=scale_fmt,
                         colwise=True,
                         is_tp=is_tp,
                         dp_gather=dp_gather,
                         layer_type=layer_type)
        self.setup_loaders()

    def setup_loaders(self):
        """Attach loaders, splitters and weight-type tags to weight, scales and bias.

        The fp8 weight uses ``weight_loader_with_quant`` so that unquantized checkpoint tensors
        are quantized while loading; scales and bias use the plain ``weight_loader``.
        """
        self.weight.weight_loader = self.weight_loader_with_quant
        self.weight.weight_spliter = self.weight_spliter
        self.weight._weight_type = 'qweight'
        self.weight_scale_inv.weight_loader = self.weight_loader
        self.weight_scale_inv.weight_spliter = self.weight_spliter
        self.weight_scale_inv._weight_type = 'scales'
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader
            self.bias.weight_spliter = self.weight_spliter
            self.bias._weight_type = 'bias'

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return io features unchanged; ``out_features`` is already the per-rank total."""
        return in_features, out_features

    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):
        """Compute the per-rank size of every output section.

        Replicated sections keep their full size; the others are distributed across the tp group
        with ``get_distribute_size``. Exactly one entry is produced per input section so the
        result stays aligned with ``out_names`` and ``replicate``.
        """
        world_size, rank = self.get_tp_world_rank()
        new_all_out_features = []
        for out_feat, rep in zip(all_out_features, replicate):
            if rep:
                # replicated: every rank stores the whole section
                new_all_out_features.append(out_feat)
            else:
                new_all_out_features.append(get_distribute_size(out_feat, world_size, rank))
        return new_all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy section ``shard_id`` of a checkpoint tensor into ``param``.

        2-D tensors not in ``fp8_dtype`` are treated as block scales. They are cast to float32
        and located with section sizes divided by ``block_size``. Non-replicated sections are
        chunked along dim 0 and only this rank's chunk is copied.
        """
        world_size, rank = self.get_tp_world_rank()
        shard_idx = self.out_names_map[shard_id]
        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
            loaded_weight = loaded_weight.to(torch.float32)
            all_out_features = [feats // self.block_size for feats in self.all_out_features]
            param_w = param.data.split(all_out_features, 0)[shard_idx]
        else:
            param_w = param.data.split(self.all_out_features, 0)[shard_idx]
        if not self.replicate[shard_idx]:
            loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
        param_w.copy_(loaded_weight)

    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Load one weight section, quantizing it first when its dtype differs from ``param``.

        On the quantization path both the fp8 weight and its scales are written.
        """
        if loaded_weight.dtype != param.dtype:
            # quant loaded weight
            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),
                                                        param.dtype,
                                                        self.block_size,
                                                        scale_fmt=self.scale_fmt)
            self.weight_loader(self.weight, quanted_weight, shard_id)
            self.weight_loader(self.weight_scale_inv, scaling, shard_id)
        else:
            return self.weight_loader(param, loaded_weight, shard_id)

    def weight_spliter(self, loaded_weight: torch.Tensor):
        """Split a fused checkpoint tensor into its sections along dim 0.

        A tensor is treated as block scales only if it is 2-D, not in ``fp8_dtype``, and its
        leading dim matches the scale layout; it is then split by ``scale_split_section``.
        Everything else is split by ``split_section``: fp8 weights, unquantized weights that
        ``weight_loader_with_quant`` quantizes later, and biases.
        """
        is_scale = (loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype
                    and loaded_weight.size(0) == sum(self.scale_split_section))
        if is_scale:
            return loaded_weight.split(self.scale_split_section, dim=0)
        return loaded_weight.split(self.split_section, dim=0)

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a fused LoRA-B tensor by the unsharded section sizes."""
        return loaded_weight.split(self.split_section, dim=0)


class QKVBlockedF8Linear(MergedBlockedF8Linear, QKVMixin):
    """Fused q/k/v projection stored as blocked fp8 weights.

    Shards are addressed as ``'q'``, ``'k'`` and ``'v'`` and are sliced out of the full
    checkpoint tensor per head group. Rank ``r`` takes the q slice of rank ``r`` and the k/v slice
    of group ``r // num_replicate_kv_heads``, so k/v heads are duplicated across ranks when the tp
    size exceeds the number of kv heads.
    """

    def __init__(self,
                 in_features: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 bias: bool = False,
                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,
                 scale_fmt: Optional[str] = None,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 dp_gather: bool = False,
                 num_replicate_kv_heads: int = 1):
        self.block_size = 128
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')
        QKVMixin.__init__(self,
                          num_q_heads=num_q_heads,
                          num_kv_heads=num_kv_heads,
                          head_size=head_size,
                          head_size_v=head_size_v,
                          num_replicate_kv_heads=num_replicate_kv_heads,
                          is_tp=is_tp,
                          tp=self.tp,
                          tp_rank=self.tp_rank)

        all_out_features = self.get_qkv_out_feautures()
        out_names = ('q', 'k', 'v')
        super().__init__(in_features,
                         all_out_features,
                         dtype=dtype,
                         fp8_dtype=fp8_dtype,
                         scale_fmt=scale_fmt,
                         bias=bias,
                         device=device,
                         is_tp=is_tp,
                         out_names=out_names,
                         dp_gather=dp_gather,
                         layer_type='attn')

    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):
        """Return sections unchanged: they come from the already tp-distributed head counts."""
        return all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's slice of the ``shard_id`` ('q', 'k' or 'v') checkpoint tensor into ``param``.

        For the scale parameter, lengths are divided by ``block_size`` and the tensor is cast to
        float32.
        """
        _, rank = self.get_tp_world_rank()
        shard_idx = self.out_names_map[shard_id]

        num_head = self.num_q_heads if shard_id == 'q' \
            else self.num_kv_heads
        head_dim = self.head_size if shard_id in ['q', 'k'] \
            else self.head_size_v
        # update to duplicate k/v for tp_size > num_kv_heads
        rank_idx = rank if shard_id == 'q' \
            else rank // self.num_replicate_kv_heads
        sec_len = num_head * head_dim
        all_out_features = self.all_out_features
        if param._weight_type == 'scales':
            loaded_weight = loaded_weight.to(torch.float32)
            all_out_features = [sec // self.block_size for sec in all_out_features]
            sec_len = sec_len // self.block_size

        sec_start = rank_idx * sec_len

        loaded_weight = loaded_weight.narrow(dim=0, start=sec_start, length=sec_len)
        param_w = param.data.split(all_out_features, 0)[shard_idx]
        param_w.copy_(loaded_weight)

    # ``weight_loader_with_quant`` is inherited from MergedBlockedF8Linear. It calls
    # ``self.weight_loader``, which resolves to the qkv-aware loader above.

    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a fused qkv checkpoint tensor into q, k and v along dim 0.

        Only the ``'default'`` layout is supported. A tensor counts as block scales only if it is
        2-D, not in ``fp8_dtype``, and has a leading dim matching the scale layout; it is then
        split with section sizes divided by ``block_size``. Unquantized weights (quantized later
        by ``weight_loader_with_quant``) are split by the full qkv sections.
        """
        check_qkv_split_layout(layout)
        assert layout == 'default', f'QKVBlockedF8Linear only supports the default layout, got: {layout}'
        qkv_split_section = self.qkv_split_section
        scale_split_section = [sec // self.block_size for sec in qkv_split_section]
        is_scale = (loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype
                    and loaded_weight.size(0) == sum(scale_split_section))
        if is_scale:
            qkv_split_section = scale_split_section
        return loaded_weight.split(qkv_split_section, dim=0)


================================================
FILE: lmdeploy/pytorch/nn/linear/default.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.config import TPMode
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

from ..utils import chunk_aligned, get_distribute_size
from .base import LinearBase
from .utils import QKVMixin, check_qkv_split_layout


class BaseLinear(LinearBase):
    """Dense (unquantized) linear layer with optional tensor parallelism.

    Column-wise layers shard ``out_features`` and row-wise layers shard ``in_features``. Both
    use ``get_distribute_size`` with ``tp_align_size`` alignment.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        colwise: bool = True,
        is_tp: bool = False,
        all_reduce: bool = True,
        tp_align_size: int = 1,
        dp_gather: bool = False,
        layer_type: str = 'attn',
    ):
        super().__init__(dtype=dtype,
                         device=device,
                         colwise=colwise,
                         is_tp=is_tp,
                         all_reduce=all_reduce,
                         tp_align_size=tp_align_size,
                         dp_gather=dp_gather,
                         layer_type=layer_type)
        if self.is_tp:
            in_features, out_features = self._get_io_features(in_features, out_features, colwise)
        impl_builder = get_backend().get_layer_impl_builder(OpType.Linear)
        # ``bias`` is a bool flag; ``bias is not None`` would always be True
        self.impl = impl_builder.build(in_features, out_features, bool(bias), dtype=self.dtype)
        weight, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)
        self.register_all_parameters(weight, bias)

        self.in_features = in_features
        self.out_features = out_features

    def setup_loaders(self):
        """Attach ``weight_loader`` to the weight and (if present) the bias."""
        self.weight.weight_loader = self.weight_loader
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader

    def register_all_parameters(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Register ``weight``/``bias`` as frozen parameters and re-attach their loaders."""
        weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            bias = torch.nn.Parameter(bias, requires_grad=False)
        self.register_parameter('weight', weight)
        self.register_parameter('bias', bias)
        self.setup_loaders()

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return the per-rank ``(in_features, out_features)`` for tensor parallelism."""
        world_size, rank = self.get_tp_world_rank()
        if colwise:
            out_features = get_distribute_size(out_features, world_size, rank, align=self.tp_align_size)
        else:
            in_features = get_distribute_size(in_features, world_size, rank, align=self.tp_align_size)
        return in_features, out_features

    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Load this rank's aligned chunk along dim 0 (weight rows / bias entries)."""
        weight = chunk_aligned(loaded_weight, world_size, 0, self.tp_align_size)[rank]
        return default_weight_loader(param, weight)

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Load a row-wise shard: 2-D weights are chunked along dim 1.

        Non-2-D tensors (bias) are loaded only on rank 0; other ranks load zeros, presumably so
        the row-parallel reduction adds the bias exactly once.
        """
        if loaded_weight.dim() == 2:
            loaded_weight = loaded_weight.to(param.device)
            weight = chunk_aligned(loaded_weight, world_size, 1, self.tp_align_size)[rank]
            return default_weight_loader(param, weight)
        else:
            # bias
            if rank != 0:
                loaded_weight = torch.zeros_like(loaded_weight)
            return default_weight_loader(param, loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``, sharding it when tensor parallelism is on."""
        if not self.is_tp:
            return default_weight_loader(param, loaded_weight)

        world_size, rank = self.get_tp_world_rank()
        if self.colwise:
            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)
        else:
            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)

    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):
        """Allocate uninitialized weight ``(out_features, in_features)`` and optional bias."""
        weight = torch.empty((out_features, in_features), dtype=dtype, device=device)
        if bias:
            bias = torch.empty((out_features, ), dtype=dtype, device=device)
        else:
            bias = None
        return weight, bias

    def update_weights(self):
        """Let the backend impl post-process the parameters, then re-register them."""
        weight, bias = self.impl.update_weights(self.weight, self.bias)
        self.register_all_parameters(weight, bias)

    def _forward_default(self, x, all_reduce, tp_sizes):
        """Run the backend linear; ``DP_TP`` mode also passes rank and ``tp_sizes``."""
        if self.tp_mode == TPMode.DP_TP:
            rank = self.tp_rank
            return self.impl.forward(x,
                                     self.weight,
                                     self.bias,
                                     all_reduce,
                                     group=self.tp_group,
                                     rank=rank,
                                     scatter_size=tp_sizes)
        else:
            return self.impl.forward(x, self.weight, self.bias, all_reduce, group=self.tp_group)


class MergedBaseLinear(BaseLinear):
    """Column-parallel linear whose output concatenates several separately loaded sections.

    ``split_section`` keeps the unsharded section sizes (used to split fused checkpoint
    tensors). ``all_out_features`` keeps the per-rank sizes used to address each section inside
    the local ``weight``/``bias``. Sections are looked up by name through ``out_names_map``.
    """

    def __init__(self,
                 in_features: int,
                 all_out_features: List[int],
                 bias: bool,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 out_names: Optional[List[int]] = None,
                 dp_gather: bool = False,
                 layer_type: str = 'attn'):
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)
        self.split_section = all_out_features
        local_sections = self._update_all_out_features(all_out_features)
        self.all_out_features = local_sections
        names = list(range(len(local_sections))) if out_names is None else out_names
        assert len(names) == len(local_sections)
        self.out_names_map = {name: idx for idx, name in enumerate(names)}
        super().__init__(in_features,
                         sum(local_sections),
                         bias,
                         dtype,
                         device,
                         colwise=True,
                         is_tp=is_tp,
                         dp_gather=dp_gather,
                         layer_type=layer_type)
        self.setup_loaders()

    def setup_loaders(self):
        """Attach the sectioned loader and the fused-tensor splitter to weight and bias."""
        targets = [self.weight] if self.bias is None else [self.weight, self.bias]
        for target in targets:
            target.weight_loader = self.weight_loader
            target.weight_spliter = self.weight_spliter

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return io features unchanged; each section is already sharded individually."""
        return (in_features, out_features)

    def _update_all_out_features(self, all_out_features: List[int]):
        """Return the per-rank size of each section."""
        world_size, rank = self.get_tp_world_rank()
        return [get_distribute_size(section, world_size, rank) for section in all_out_features]

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's dim-0 chunk of section ``shard_id`` into ``param``."""
        world_size, rank = self.get_tp_world_rank()
        section_idx = self.out_names_map[shard_id]
        sections = param.data.split(self.all_out_features, 0)
        rank_chunk = loaded_weight.chunk(world_size, 0)[rank]
        sections[section_idx].copy_(rank_chunk)

    def weight_spliter(self, loaded_weight: torch.Tensor):
        """Split a fused checkpoint tensor by the unsharded section sizes."""
        return torch.split(loaded_weight, self.split_section, dim=0)

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a fused LoRA-B tensor by the unsharded section sizes."""
        return torch.split(loaded_weight, self.split_section, dim=0)


class QKVBaseLinear(MergedBaseLinear, QKVMixin):
    """Fused q/k/v projection with per-head tensor-parallel sharding.

    Shards are addressed as ``'q'``, ``'k'`` and ``'v'``. When ``num_replicate_kv_heads > 1``,
    groups of ``num_replicate_kv_heads`` consecutive tp ranks load the same k/v slice.
    """

    def __init__(self,
                 in_features: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 bias: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 num_replicate_kv_heads: int = 1):
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')
        head_config = dict(num_q_heads=num_q_heads,
                           num_kv_heads=num_kv_heads,
                           head_size=head_size,
                           head_size_v=head_size_v,
                           num_replicate_kv_heads=num_replicate_kv_heads)
        QKVMixin.__init__(self, **head_config, is_tp=is_tp, tp=self.tp, tp_rank=self.tp_rank)

        local_sections = self.get_qkv_out_feautures()
        super().__init__(in_features,
                         local_sections,
                         bias=bias,
                         dtype=dtype,
                         device=device,
                         is_tp=is_tp,
                         out_names=('q', 'k', 'v'),
                         layer_type='attn')

    def _update_all_out_features(self, all_out_features: List[int]):
        """Return sections unchanged; they come from the already tp-distributed head counts."""
        return all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's head-aligned slice of shard ``shard_id`` into ``param``."""
        world_size, rank = self.get_tp_world_rank()
        shard_idx = self.out_names_map[shard_id]
        target = param.data.split(self.all_out_features, 0)[shard_idx]

        replicas = self.num_replicate_kv_heads
        if replicas > 1 and shard_id in ['k', 'v']:
            # several ranks share one k/v slice when tp size exceeds the kv head count
            num_chunks, chunk_idx = world_size // replicas, rank // replicas
        else:
            num_chunks, chunk_idx = world_size, rank

        align_by_shard = {0: self.head_size, 1: self.head_size, 2: self.head_size_v}
        align = align_by_shard.get(shard_idx)
        if align is not None:
            loaded_weight = chunk_aligned(loaded_weight, num_chunks, 0, align)[chunk_idx]
        target.copy_(loaded_weight)

    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a fused qkv checkpoint tensor into ``(q, k, v)``.

        ``'default'``: q, k and v are concatenated along dim 0.
        ``'hgd'``: dim 0 is grouped per kv head. Each group holds its q heads, then one k head,
        then one v head (requires ``head_size == head_size_v``).
        """
        check_qkv_split_layout(layout)
        if layout == 'default':
            return loaded_weight.split(self.qkv_split_section, dim=0)
        if layout != 'hgd':
            raise RuntimeError(f'Unsupported layout: {layout}')

        assert self.head_size == self.head_size_v
        num_groups = self.qkv_split_section[-1] // self.head_size
        grouped = loaded_weight.unflatten(0, (num_groups, -1, self.head_size))
        q = grouped[:, :-2].flatten(0, 2)
        k = grouped[:, -2].flatten(0, 1)
        v = grouped[:, -1].flatten(0, 1)
        return q, k, v

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a fused LoRA-B tensor by the unsharded q/k/v section sizes."""
        return torch.split(loaded_weight, self.qkv_split_section, dim=0)


================================================
FILE: lmdeploy/pytorch/nn/linear/lora.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any

import torch
from torch import nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.backends.lora import AdapterInfo
from lmdeploy.pytorch.distributed import get_tp_world_rank


class LoRA(nn.Module):
    """LoRA layer holding the packed A/B matrices of all adapters.

    Adapter ``i`` owns rows ``[rank_offsets[i], rank_offsets[i] + ranks[i])`` of both ``lora_A``
    and ``lora_B``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ranks: torch.Tensor,
                 scalings: torch.Tensor,
                 lora_a: torch.Tensor,
                 lora_b: torch.Tensor,
                 base_slice: slice,
                 ctx_mgr: Any = None,
                 colwise: bool = True,
                 is_tp: bool = True,
                 lora_b_spliter: Any = None):
        super().__init__()
        self.adapter_info = AdapterInfo(
            in_features=in_features,
            out_features=out_features,
            ranks=ranks,
            scalings=scalings,
            base_slice=base_slice,
        )
        self.impl = get_backend().get_layer_impl_builder(OpType.LoRA).build()

        self.register_parameter('lora_A', nn.Parameter(lora_a, requires_grad=False))
        self.register_parameter('lora_B', nn.Parameter(lora_b, requires_grad=False))
        self.lora_A.weight_loader = self.weight_loader_A
        self.lora_B.weight_loader = self.weight_loader_B
        self.is_tp = is_tp
        self.ctx_mgr = ctx_mgr
        self.colwise = colwise
        self.lora_b_spliter = lora_b_spliter

    def forward(self, x, base_output=None):
        """Apply the packed adapters through the backend LoRA impl."""
        return self.impl.forward(x,
                                 self.lora_A,
                                 self.lora_B,
                                 base_output,
                                 self.adapter_info,
                                 ctx_mgr=self.ctx_mgr,
                                 colwise=self.colwise,
                                 is_tp=self.is_tp)

    def _adapter_rows(self, param: nn.Parameter, adapter_id: int):
        """Return the view of ``param`` rows owned by adapter ``adapter_id``."""
        lora_rank = self.adapter_info.ranks[adapter_id].item()
        row_start = self.adapter_info.rank_offsets[adapter_id].item()
        return param.data[row_start:row_start + lora_rank]

    def weight_loader_A(self, param: nn.Parameter, loaded_weight: torch.Tensor, adapter_id: int):
        """Load adapter ``adapter_id``'s A matrix; row-wise tp layers keep a dim-1 chunk."""
        target = self._adapter_rows(param, adapter_id)
        if self.is_tp and not self.colwise:
            world_size, tp_rank = get_tp_world_rank()
            loaded_weight = loaded_weight.to(target.device).chunk(world_size, dim=1)[tp_rank]
        target.copy_(loaded_weight)

    def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor, adapter_id: int):
        """Load adapter ``adapter_id``'s B matrix (stored transposed).

        Column-wise tp layers keep a dim-0 chunk. If ``lora_b_spliter`` is set, each fused
        section is chunked separately and the chunks are concatenated.
        """
        target = self._adapter_rows(param, adapter_id)
        if self.is_tp and self.colwise:
            world_size, tp_rank = get_tp_world_rank()
            if self.lora_b_spliter is None:
                loaded_weight = loaded_weight.chunk(world_size, dim=0)[tp_rank]
            else:
                rank_parts = [part.chunk(world_size, dim=0)[tp_rank] for part in self.lora_b_spliter(loaded_weight)]
                loaded_weight = torch.cat(rank_parts, dim=0)
        target.copy_(loaded_weight.t())


================================================
FILE: lmdeploy/pytorch/nn/linear/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.utils import get_logger

from ..utils import get_distribute_size

logger = get_logger('lmdeploy')

QKV_SPLIT_LAYOUTS = ['default', 'hgd']


def check_qkv_split_layout(layout: str):
    """Raise ``RuntimeError`` unless ``layout`` is one of ``QKV_SPLIT_LAYOUTS``."""
    if layout in QKV_SPLIT_LAYOUTS:
        return
    raise RuntimeError(f'Expect qkv split layout in {QKV_SPLIT_LAYOUTS}, '
                       f'but get: {layout}')


def update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str = 'attn'):
    """Resolve the effective tensor-parallel flags of a linear layer.

    Args:
        is_tp: requested tensor parallelism; only kept if the tp world of ``layer_type`` has
            more than one rank.
        all_reduce: requested all-reduce; only kept for row-wise (non-colwise) tp layers.
        colwise: whether the layer is column-parallel.
        layer_type: layer type passed to ``get_tp_world_rank``.

    Returns:
        tuple: ``(is_tp, all_reduce)`` after resolution.
    """
    if is_tp:
        tp_world_size, _ = get_tp_world_rank(layer_type)
        is_tp = tp_world_size > 1

    keep_all_reduce = is_tp and not colwise
    return is_tp, (all_reduce if keep_all_reduce else False)


class QKVMixin:
    """Head bookkeeping shared by the fused q/k/v linear layers.

    ``qkv_split_section`` holds the q/k/v section sizes of the full checkpoint tensor, with
    replicated kv heads counted once. ``num_q_heads``/``num_kv_heads`` are the per-rank head
    counts when tensor parallelism is on.
    """

    def __init__(self,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 num_replicate_kv_heads: int = 1,
                 is_tp: bool = False,
                 tp: int = 1,
                 tp_rank: int = 0):
        full_sections = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v,
                                                   num_replicate_kv_heads)
        local_q_heads, local_kv_heads = self._update_num_heads(is_tp, tp, tp_rank, num_q_heads, num_kv_heads)
        self.num_q_heads = local_q_heads
        self.num_kv_heads = local_kv_heads
        self.head_size = head_size
        self.head_size_v = head_size_v
        self.num_replicate_kv_heads = num_replicate_kv_heads
        self.qkv_split_section = full_sections

    def get_qkv_out_feautures(self):
        """Return the ``(q, k, v)`` output sizes for the current (per-rank) head counts."""
        return self._get_qkv_out_features(self.num_q_heads, self.num_kv_heads, self.head_size, self.head_size_v)

    def _get_qkv_out_features(self,
                              num_q_heads: int,
                              num_kv_heads: int,
                              head_size: int,
                              head_size_v: int,
                              num_replicate_kv_heads: int = 1):
        """Return ``(q, k, v)`` section sizes, counting each replicated kv head once."""
        unique_kv_heads = num_kv_heads // num_replicate_kv_heads
        q_features = num_q_heads * head_size
        k_features = unique_kv_heads * head_size
        v_features = unique_kv_heads * head_size_v
        return (q_features, k_features, v_features)

    def _update_num_heads(self, is_tp: bool, tp: int, tp_rank: int, num_q_heads: int, num_kv_heads: int):
        """Return the ``(q, kv)`` head counts owned by ``tp_rank`` (unchanged without tp)."""
        if is_tp:
            local_q = get_distribute_size(num_q_heads, tp, tp_rank)
            local_kv = get_distribute_size(num_kv_heads, tp, tp_rank)
            return local_q, local_kv
        return num_q_heads, num_kv_heads

    def split_qkv(self, x: torch.Tensor):
        """Split the last dim of ``x`` into per-head ``q``, ``k`` and ``v`` tensors.

        Relies on ``self.all_out_features``, which the consuming linear class provides.
        """
        q, k, v = x.split(self.all_out_features, dim=-1)
        return (q.unflatten(-1, (self.num_q_heads, self.head_size)),
                k.unflatten(-1, (self.num_kv_heads, self.head_size)),
                v.unflatten(-1, (self.num_kv_heads, self.head_size_v)))


================================================
FILE: lmdeploy/pytorch/nn/linear/w8a8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

from ..utils import get_distribute_size
from .base import LinearBase
from .utils import QKVMixin, check_qkv_split_layout


class W8A8Linear(LinearBase):
    """Linear layer with quantized weights (``quant_dtype``, int8 by default).

    ``weight`` has shape ``(out_features, in_features)`` in ``quant_dtype``. ``scale`` is float32
    with shape ``(out_features, 1)``, i.e. one scale per output row.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 colwise: bool = True,
                 is_tp: bool = False,
                 all_reduce: bool = True,
                 quant_dtype: Optional[torch.dtype] = torch.int8,
                 layer_type: str = 'attn'):
        # NOTE(review): ``dtype`` is not forwarded; LinearBase always receives float16 here.
        # Confirm this is intended (e.g. a kernel restriction) before changing it.
        super().__init__(dtype=torch.float16,
                         device=device,
                         colwise=colwise,
                         is_tp=is_tp,
                         all_reduce=all_reduce,
                         layer_type=layer_type)
        if self.is_tp:
            in_features, out_features = self._get_io_features(in_features, out_features, colwise)
        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8)
        self.quant_dtype = quant_dtype
        # ``bias`` is a bool flag; ``bias is not None`` would always be True
        self.impl = impl_builder.build(in_features,
                                       out_features,
                                       bool(bias),
                                       dtype=self.dtype,
                                       quant_dtype=quant_dtype)
        weight, scale, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)
        self.register_all_parameters(weight, scale, bias)

        self.in_features = in_features
        self.out_features = out_features

    def setup_loaders(self):
        """Attach ``weight_loader`` to weight, scale and (if present) bias."""
        self.weight.weight_loader = self.weight_loader
        self.scale.weight_loader = self.weight_loader
        if self.bias is not None:
            self.bias.weight_loader = self.weight_loader

    def register_all_parameters(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):
        """Register weight/scale/bias as frozen parameters and re-attach their loaders."""
        weight = torch.nn.Parameter(weight, requires_grad=False)
        scale = torch.nn.Parameter(scale, requires_grad=False)
        if bias is not None:
            bias = torch.nn.Parameter(bias, requires_grad=False)
        self.register_parameter('weight', weight)
        self.register_parameter('scale', scale)
        self.register_parameter('bias', bias)
        self.setup_loaders()

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return the per-rank ``(in_features, out_features)`` for tensor parallelism."""
        world_size, rank = self.get_tp_world_rank()
        if colwise:
            out_features = get_distribute_size(out_features, world_size, rank)
        else:
            in_features = get_distribute_size(in_features, world_size, rank)
        return in_features, out_features

    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Load this rank's chunk along dim 0 (weight rows, scale rows, bias entries)."""
        weight = loaded_weight.chunk(world_size, 0)[rank]
        return default_weight_loader(param, weight)

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """Load a row-wise shard.

        - Quantized 2-D weights are chunked along dim 1.
        - ``(out, 1)`` scales are loaded whole.
        - Bias is loaded only on rank 0; other ranks load zeros.
        """
        if loaded_weight.dim() == 2 and param.dtype in (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2):
            loaded_weight = loaded_weight.to(param.device)
            weight = loaded_weight.chunk(world_size, 1)[rank]
            return default_weight_loader(param, weight)
        elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:
            # scaling
            return default_weight_loader(param, loaded_weight)
        else:
            # bias
            if rank != 0:
                loaded_weight = torch.zeros_like(loaded_weight)
            return default_weight_loader(param, loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``, sharding it when tensor parallelism is on."""
        if not self.is_tp:
            return default_weight_loader(param, loaded_weight)

        world_size, rank = self.get_tp_world_rank()
        if self.colwise:
            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)
        else:
            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)

    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):
        """Allocate uninitialized quantized weight, float32 per-row scale and optional bias."""
        weight = torch.empty((out_features, in_features), dtype=self.quant_dtype, device=device)
        scale = torch.empty((out_features, 1), dtype=torch.float32, device=device)
        if bias:
            bias = torch.empty((out_features, ), dtype=dtype, device=device)
        else:
            bias = None
        return weight, scale, bias

    def update_weights(self):
        """Let the backend impl post-process the parameters, then re-register them."""
        weight, scale, bias = self.impl.update_weights(self.weight, self.scale, self.bias)
        self.register_all_parameters(weight, scale, bias)

    def _forward_default(self, x, all_reduce, tp_sizes):
        """Run the backend w8a8 linear (``tp_sizes`` is not used by this layer)."""
        return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce, group=self.tp_group)


class MergedW8A8Linear(W8A8Linear):
    """Column-parallel w8a8 linear that packs several output projections in one weight.

    ``all_out_features`` gives the output size of each sub-projection as it
    appears in the checkpoint. Each sub-projection is loaded separately through
    ``weight_loader`` and addressed by its entry in ``out_names``, which
    defaults to ``0 .. n-1``.
    """

    def __init__(self,
                 in_features: int,
                 all_out_features: List[int],
                 bias: bool,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 out_names: Optional[List[int]] = None,
                 quant_dtype: torch.dtype = torch.int8,
                 layer_type: str = 'attn'):
        # tp args are needed before the per-rank sizes can be computed
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)
        # keep the checkpoint-level sections for weight_spliter
        self.split_section = all_out_features
        self.all_out_features = self._update_all_out_features(all_out_features)
        if out_names is None:
            out_names = list(range(len(self.all_out_features)))
        assert len(out_names) == len(self.all_out_features)
        self.out_names_map = {name: idx for idx, name in enumerate(out_names)}
        super().__init__(in_features,
                         sum(self.all_out_features),
                         bias,
                         dtype,
                         device,
                         colwise=True,
                         is_tp=is_tp,
                         quant_dtype=quant_dtype,
                         layer_type=layer_type)
        self.setup_loaders()

    def setup_loaders(self):
        """Attach this layer's ``weight_loader`` / ``weight_spliter`` to each parameter."""
        params = [self.weight, self.scale]
        if self.bias is not None:
            params.append(self.bias)
        for param in params:
            param.weight_loader = self.weight_loader
            param.weight_spliter = self.weight_spliter

    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
        """Return the io features untouched (they are already per-rank)."""
        return in_features, out_features

    def _update_all_out_features(self, all_out_features: List[int]):
        """Map each sub-projection size to this rank's share via ``get_distribute_size``."""
        world_size, rank = self.get_tp_world_rank()
        return [get_distribute_size(feat, world_size, rank) for feat in all_out_features]

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's dim-0 chunk of sub-projection ``shard_id`` into ``param``.

        NOTE(review): the rank share here comes from ``torch.chunk`` while the
        parameter section sizes come from ``get_distribute_size``; confirm the
        two agree when a size is not divisible by the tp world size.
        """
        world_size, rank = self.get_tp_world_rank()
        section = self.out_names_map[shard_id]
        target = param.data.split(self.all_out_features, 0)[section]
        target.copy_(loaded_weight.chunk(world_size, 0)[rank])

    def weight_spliter(self, loaded_weight: torch.Tensor):
        """Split a fused checkpoint tensor into its sub-projections along dim 0."""
        return loaded_weight.split(self.split_section, dim=0)

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a fused LoRA-B tensor into its sub-projections along dim 0."""
        return loaded_weight.split(self.split_section, dim=0)


class QKVW8A8Linear(MergedW8A8Linear, QKVMixin):
    """W8A8 linear fusing the q, k and v projections into a single weight.

    Head bookkeeping comes from ``QKVMixin``. The shards are named ``'q'``,
    ``'k'`` and ``'v'``. K/V rows are looked up with
    ``rank // num_replicate_kv_heads``, so consecutive tp ranks can share the
    same kv heads when there are more ranks than kv heads.
    """

    def __init__(self,
                 in_features: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 bias: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 num_replicate_kv_heads: int = 1,
                 quant_dtype: torch.dtype = torch.int8):
        # QKVMixin needs self.tp / self.tp_rank, so set up tp args first
        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')
        QKVMixin.__init__(self,
                          num_q_heads=num_q_heads,
                          num_kv_heads=num_kv_heads,
                          head_size=head_size,
                          head_size_v=head_size_v,
                          num_replicate_kv_heads=num_replicate_kv_heads,
                          is_tp=is_tp,
                          tp=self.tp,
                          tp_rank=self.tp_rank)
        super().__init__(in_features,
                         self.get_qkv_out_feautures(),
                         bias=bias,
                         dtype=dtype,
                         device=device,
                         is_tp=is_tp,
                         out_names=('q', 'k', 'v'),
                         quant_dtype=quant_dtype,
                         layer_type='attn')

    def _update_all_out_features(self, all_out_features: List[int]):
        """Return the q/k/v sizes unchanged (no per-rank redistribution, unlike the base class)."""
        return all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's rows of the ``shard_id`` ('q'/'k'/'v') weight into ``param``."""
        _, rank = self.get_tp_world_rank()
        section = self.out_names_map[shard_id]
        target = param.data.split(self.all_out_features, 0)[section]
        if shard_id == 'q':
            n_heads = self.num_q_heads
            dim = self.head_size
            shard_rank = rank
        else:
            n_heads = self.num_kv_heads
            dim = self.head_size if shard_id == 'k' else self.head_size_v
            # duplicate k/v for tp_size > num_kv_heads: ranks in a replica group share a slice
            shard_rank = rank // self.num_replicate_kv_heads
        rows = n_heads * dim
        target.copy_(loaded_weight.narrow(dim=0, start=shard_rank * rows, length=rows))

    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a fused qkv checkpoint tensor into ``(q, k, v)``.

        ``'default'``: contiguous q | k | v sections along dim 0.
        ``'hgd'``: rows grouped per kv head, each group holding its q heads
        followed by one k head and one v head (requires head_size == head_size_v).
        """
        check_qkv_split_layout(layout)
        if layout == 'default':
            return loaded_weight.split(self.qkv_split_section, dim=0)
        if layout != 'hgd':
            raise RuntimeError(f'Unsupported layout: {layout}')
        assert self.head_size == self.head_size_v
        num_groups = self.qkv_split_section[-1] // self.head_size
        grouped = loaded_weight.unflatten(0, (num_groups, -1, self.head_size))
        q = grouped[:, :-2].flatten(0, 2)
        k = grouped[:, -2].flatten(0, 1)
        v = grouped[:, -1].flatten(0, 1)
        return q, k, v

    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
        """Split a fused LoRA-B tensor into q/k/v sections along dim 0."""
        return loaded_weight.split(self.qkv_split_section, dim=0)


================================================
FILE: lmdeploy/pytorch/nn/moe/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, Optional

import torch

from lmdeploy.pytorch.models.patch import get_build_model_context

from .base import MoeType, SoftmaxTopK  # noqa: F401


def build_fused_moe(
    hidden_dim: int,
    ffn_dim: int,
    num_experts: int,
    top_k: int,
    bias: bool = False,
    renormalize: bool = False,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    all_reduce: bool = True,
    enable_ep: bool = False,
    quant_config: Dict = None,
    layer_idx: int = 0,
    act_func: Callable = None,
    prefix: str = '',
):
    """Build a fused MoE layer whose implementation depends on the quant method.

    When ``quant_config`` is not None, the quant config stored in the build
    model context is queried for the method that applies to ``prefix``. The
    argument itself only acts as an on/off switch.

    Returns:
        ``FusedMoE`` (no quantization), ``FusedMoEW8A8`` (``'smooth_quant'``)
        or ``FusedMoEBlockedF8`` (``'fp8'``).

    Raises:
        RuntimeError: if the quant method is not supported.
    """
    quant_method = None
    if quant_config is not None:
        # the context's config is authoritative, not the passed-in object
        quant_config = get_build_model_context().quant_config
        quant_method = quant_config.get_quant_method(prefix)

    # arguments every implementation accepts
    shared_kwargs = dict(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_dim,
        num_experts=num_experts,
        top_k=top_k,
        renormalize=renormalize,
        dtype=dtype,
        device=device,
        all_reduce=all_reduce,
    )

    if quant_method is None:
        from .default import FusedMoE
        return FusedMoE(bias=bias, layer_idx=layer_idx, act_func=act_func, **shared_kwargs)

    if quant_method == 'smooth_quant':
        assert not bias, 'Quant model does not support bias for now.'
        assert act_func is None, ('Quant model does not support activation function for now.')
        from .w8a8 import FusedMoEW8A8
        return FusedMoEW8A8(quant_dtype=quant_config.quant_dtype, **shared_kwargs)

    if quant_method == 'fp8':
        from .blocked_fp8 import FusedMoEBlockedF8
        return FusedMoEBlockedF8(bias=bias,
                                 fp8_dtype=quant_config.quant_dtype,
                                 scale_fmt=quant_config.scale_fmt,
                                 layer_idx=layer_idx,
                                 act_func=act_func,
                                 **shared_kwargs)

    raise RuntimeError(f'Unsupported quant method: {quant_method}')


================================================
FILE: lmdeploy/pytorch/nn/moe/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.config import TPMode
from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager


class MoeType(Enum):
    """Batch exec type.

    Carried as ``moe_type`` in the MoE dispatch/gemm/combine state dicts.
    """
    # NOTE(review): the DS* members presumably select async dispatch paths for
    # decode / prefill batches in some backends; confirm against their users.
    Default = auto()
    DSAsyncDecode = auto()
    DSAsyncPrefill = auto()


class SoftmaxTopK(nn.Module):
    """Softmax + top-k routing op delegated to the active backend.

    The kernel is whatever ``OpType.SoftmaxTopK`` layer implementation the
    current backend provides; its exact output contract is defined there.
    """

    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):
        super().__init__()
        self.top_k = top_k
        builder = get_backend().get_layer_impl_builder(OpType.SoftmaxTopK)
        self.impl = builder.build(top_k, dim, n_groups=n_groups)

    def forward(self, x: torch.Tensor):
        """Run the backend implementation on ``x``."""
        impl = self.impl
        return impl.forward(x)


def update_dims(hidden_dim: int, ffn_dim: int):
    """Return ``(hidden_dim, ffn_dim // moe_tp_world_size)``.

    ``hidden_dim`` passes through unchanged; ``ffn_dim`` must be divisible by
    the moe tp world size (asserted).
    """
    tp_size, _ = get_tp_world_rank('moe')
    assert ffn_dim % tp_size == 0
    per_rank_ffn = ffn_dim // tp_size
    return hidden_dim, per_rank_ffn


def split_size(size: int, world_size: int, align: int):
    """Partition ``size`` into ``world_size`` chunks that are multiples of ``align``.

    The size is counted in whole units of ``align``; when the units do not
    divide evenly, the leading ranks get one extra unit each. A remainder of
    ``size`` smaller than ``align`` is dropped, so the result sums to ``size``
    only when ``size % align == 0``.
    """
    units = size // align
    per_rank, extra = divmod(units, world_size)
    return [(per_rank + (1 if idx < extra else 0)) * align for idx in range(world_size)]


def moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None):
    """Gather routed MoE inputs across the moe tp group in DP_TP mode.

    Inputs are returned unchanged when moe tp is 1 or the tp mode is DEFAULT.
    In DP_TP mode, each tensor is gathered with ``dist.gather_by_tp_sizes``
    using ``dp_meta.moe_tp_sizes`` from the current step context.

    Raises:
        RuntimeError: for any other tp mode.
    """
    cfg = get_dist_manager().current_config()
    if cfg.moe_tp == 1:
        return hidden_states, topk_weights, topk_ids

    mode = cfg.moe_tp_mode
    if mode == TPMode.DEFAULT:
        return hidden_states, topk_weights, topk_ids
    if mode != TPMode.DP_TP:
        raise RuntimeError('Not supported.')

    sizes = get_step_ctx_manager().current_context().dp_meta.moe_tp_sizes
    gathered_states = dist.gather_by_tp_sizes(hidden_states, sizes, group=group)
    gathered_weights = dist.gather_by_tp_sizes(topk_weights, sizes, group=group)
    gathered_ids = dist.gather_by_tp_sizes(topk_ids, sizes, group=group)
    return gathered_states, gathered_weights, gathered_ids


def moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None):
    """Reduce a MoE output across the moe tp group.

    No-op when moe tp is 1. DEFAULT mode all-reduces ``ret`` in place and
    returns it. DP_TP mode reduce-scatters by ``dp_meta.moe_tp_sizes``, so
    ``rank`` receives its own token slice.

    Raises:
        RuntimeError: for any other tp mode.
    """
    if get_dist_manager().current_config().moe_tp == 1:
        return ret

    if tp_mode == TPMode.DP_TP:
        sizes = get_step_ctx_manager().current_context().dp_meta.moe_tp_sizes
        return dist.reduce_scatter_by_tp_sizes(ret, rank, sizes, group=group)
    if tp_mode != TPMode.DEFAULT:
        raise RuntimeError('Not supported.')

    dist.all_reduce(ret, group=group)
    return ret


class MoEForwardDPTP:
    """Chunked MoE forward for the DP_TP tensor-parallel mode.

    Local tokens are all-gathered over the gather group, passed through
    ``gemm_func``, and the result is reduce-scattered back over the moe tp
    group. Work proceeds in rounds capped at ``max_tokens_per_round`` tokens
    per gather rank. The async all-gather of round i+1 is issued before the
    gemm of round i, so communication can overlap compute.
    """

    def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192):
        """MoE forward dp tp.

        Args:
            gemm_func: called as ``gemm_func(hidden_states, topk_weights, topk_ids)``
                on gathered tokens. Its result is split along dim -2 by the
                per-gather-rank token counts in ``reduce_scatter``.
            max_tokens_per_round: per-round token budget before the
                ``attn_tp / moe_tp`` scaling applied below.
        """
        self.gemm_func = gemm_func
        self.dist_ctx = get_dist_manager().current_context()
        self.dist_config = self.dist_ctx.dist_config
        self.tp = self.dist_config.moe_tp
        self.attn_tp = self.dist_config.attn_tp

        tp_group = self.dist_ctx.moe_tp_group
        self.rank = tp_group.rank
        # consecutive blocks of attn_tp ranks share one gather slot
        self.gather_rank = self.rank // self.attn_tp
        self.gather_group = tp_group.gpu_gather_group
        self.tp_group = tp_group.gpu_group
        self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp

    def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                   tp_sizes: List[int]):
        """All gather.

        Launches three async gathers over ``gather_group``. Returns the
        gathered tensors plus their work handles; callers must wait on the
        handles before reading the tensors.
        """
        hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
        topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)
        topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)
        return hidden_states, topk_weights, topk_ids, (h0, h1, h2)

    def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):
        """Reduce scatter.

        Async reduce-scatter of the gemm output into ``out_states`` over
        ``tp_group``. Returns ``(out_states, handle)``.
        """
        # one slice per gather rank along the token dim (-2)
        hidden_states_list = list(hidden_states.split(tp_sizes, -2))
        cur_out_states = hidden_states_list[self.gather_rank]
        out_states.copy_(cur_out_states)
        # repeat each slice attn_tp times so there is one input per tp_group member
        # (presumably len == moe tp group size — confirm)
        hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)]
        # this rank's own input is the output buffer, which already holds the same data
        hidden_states_list[self.rank] = out_states
        handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True)
        return out_states, handle

    def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                                 output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]):
        """Gemm and reduce scatter.

        Waits for the pending all-gather handles, runs ``gemm_func``, then
        starts the async reduce-scatter into ``output_states``.
        """
        for handle in handles:
            handle.wait()
        cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids)
        return self.reduce_scatter(cur_out, output_states, tp_sizes)

    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor):
        """forward.

        Returns a tensor shaped like ``hidden_states``. It is filled slice by
        slice by the per-round reduce-scatters, all of which have completed on
        return.
        """

        def __slice_tensor(tensor: torch.Tensor, slice_size: int):
            """Slice tensor: split off the first ``slice_size`` rows, return (head, rest)."""
            cur_tensor = tensor[:slice_size]
            tensor = tensor[slice_size:]
            return cur_tensor, tensor

        def __slice_and_gather():
            """Slice and gather.

            Takes the next round's slice of the local inputs and output buffer,
            decrements the remaining per-rank counts, and launches the async
            all-gather. Returns kwargs for ``_gemm_and_reduce_scatter``.
            """
            nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states
            # per-gather-rank token count for this round, capped by the budget
            cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round)
            tp_sizes -= cur_tp_sizes
            cur_tp_sizes = cur_tp_sizes.tolist()

            slice_size = cur_tp_sizes[self.gather_rank]
            cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size)
            cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size)
            cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size)
            cur_output, output_states = __slice_tensor(output_states, slice_size)

            # all gather
            cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather(
                cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes)
            return dict(hidden_states=cur_hidden_states,
                        topk_weights=cur_topk_weights,
                        topk_ids=cur_topk_ids,
                        output_states=cur_output,
                        handles=handles,
                        tp_sizes=cur_tp_sizes)

        step_ctx = get_step_ctx_manager().current_context()
        # remaining token counts, one entry per gather rank
        tp_sizes = step_ctx.dp_meta.moe_tp_sizes
        tp_sizes = torch.tensor(tp_sizes)
        max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round)

        # output_states is re-sliced each round; keep a handle on the full buffer
        output_states = torch.empty_like(hidden_states)
        return_states = output_states

        # pre
        cur_inputs = __slice_and_gather()

        out_handles = []
        # main loop: launch the gather of the next round before computing the current one
        while tp_sizes.sum() > 0:
            next_inputs = __slice_and_gather()
            _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
            out_handles.append(handle)
            cur_inputs = next_inputs

        # post
        _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
        out_handles.append(handle)
        for handle in out_handles:
            handle.wait()
        return return_states


def _renormalize(topk_weights: torch.Tensor, renormalize: bool):
    """Optionally rescale ``topk_weights`` to sum to 1 over the last dim.

    The result is always contiguous. When no rescale happens and the input is
    already contiguous, the input tensor itself is returned.
    """
    weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) if renormalize else topk_weights
    return weights if weights.is_contiguous() else weights.contiguous()


@dataclass
class DispatchInputs:
    """Dispatch inputs.

    Typed view of the state dict exchanged between MoE stages; ``to_dict`` /
    ``from_dict`` convert between the two forms.
    """
    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_idx: torch.LongTensor
    moe_type: MoeType = MoeType.Default

    @classmethod
    def from_dict(cls, input: Dict):
        """Build an instance from a state dict.

        Args:
            input: mapping that must contain ``'hidden_states'``,
                ``'topk_weights'`` and ``'topk_idx'``. ``'moe_type'`` is
                optional and defaults to ``MoeType.Default``.

        Raises:
            AssertionError: if a required key is missing.
        """
        # Check each key individually: ``[...] in input`` would test the list
        # itself as a dict key, which raises TypeError (lists are unhashable).
        required = ('hidden_states', 'topk_weights', 'topk_idx')
        missing = [key for key in required if key not in input]
        assert not missing, f'Missing required keys: {missing}'
        moe_type = input.get('moe_type', MoeType.Default)
        return cls(
            hidden_states=input['hidden_states'],
            topk_weights=input['topk_weights'],
            topk_idx=input['topk_idx'],
            moe_type=moe_type,
        )

    def to_dict(self) -> Dict:
        """Return the fields as a state dict (the inverse of ``from_dict``)."""
        return {
            'hidden_states': self.hidden_states,
            'topk_weights': self.topk_weights,
            'topk_idx': self.topk_idx,
            'moe_type': self.moe_type,
        }


class FusedMoEBase(nn.Module):
    """Common scaffolding for fused MoE layers.

    Subclasses provide the ``dispatch`` / ``gemm`` / ``combine`` stages (plus
    ``before_dispatch`` / ``wait`` where used). This base chains the stages into
    a default forward. When moe tp > 1 in DP_TP mode, it routes instead through
    a chunked ``MoEForwardDPTP`` path.
    """

    def __init__(self, tp: int, tp_mode: TPMode, do_renormalize: bool):
        super().__init__()
        self.tp = tp
        self.tp_mode = tp_mode
        self.do_renormalize = do_renormalize

    def init_dist_args(self, all_reduce: bool):
        """Read the moe tp/ep layout from the current dist context and store it on ``self``."""
        ctx = get_dist_manager().current_context()
        cfg = ctx.dist_config
        _, mode = cfg.get_tp_by_layer('moe')
        world, rank = get_tp_world_rank('moe')

        self.ep = cfg.ep
        self.tp = world
        self.tp_rank = rank
        self.tp_mode = mode
        # all-reduce only makes sense when the layer is actually split
        self.all_reduce = all_reduce if world > 1 else False
        self.tp_group = ctx.moe_tp_group.gpu_group
        self.gather_group = ctx.moe_tp_group.gpu_gather_group

        if self.tp <= 1 or self.tp_mode != TPMode.DP_TP:
            self._forward_dptp = None
            return

        def _dptp_gemm(hidden_states, topk_weights, topk_ids):
            state = dict(
                hidden_states=hidden_states,
                topk_weights=topk_weights,
                topk_idx=topk_ids,
                moe_type=MoeType.Default,
            )
            return self.gemm(state)['hidden_states']

        self._forward_dptp = MoEForwardDPTP(_dptp_gemm)

    def before_dispatch(self, state: DispatchInputs):
        """Before dispatch (subclass hook)."""
        raise NotImplementedError

    def dispatch(self, state: Dict):
        """Dispatch stage (subclass hook)."""
        raise NotImplementedError

    def gemm(self, state: Dict):
        """Gemm stage (subclass hook)."""
        raise NotImplementedError

    def combine(self, state: Dict):
        """Combine stage (subclass hook)."""
        raise NotImplementedError

    def wait(self, state: Dict):
        """Wait stage (subclass hook)."""
        raise NotImplementedError

    @property
    def forward_dptp(self) -> MoEForwardDPTP:
        """The DP_TP forward helper, or ``None`` when DP_TP is not active."""
        return self._forward_dptp

    def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor):
        """Run dispatch -> gemm -> combine and return the combined hidden states."""
        recv_state = self.dispatch({
            'hidden_states': hidden_states,
            'topk_idx': topk_idx,
            'topk_weights': topk_weights,
            'moe_type': MoeType.Default,
        })
        return self.combine(self.gemm(recv_state))['hidden_states']

    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor):
        """Use the chunked DP_TP path when moe tp > 1 in DP_TP mode, else the default path."""
        use_dptp = self.tp > 1 and self.tp_mode == TPMode.DP_TP
        if not use_dptp:
            return self.forward_default(hidden_states, topk_weights, topk_idx)
        return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx)

    def renormalize(self, topk_weights):
        """Apply ``_renormalize`` using this layer's ``do_renormalize`` flag."""
        return _renormalize(topk_weights, self.do_renormalize)


================================================
FILE: lmdeploy/pytorch/nn/moe/blocked_fp8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank

from ..quant_utils import quant_blocked_fp8
from ..utils import div_up
from .base import DispatchInputs, FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce
from .base import split_size as _split_size
from .default import LinearWeights


class LinearWeightsBlockedF8(LinearWeights):
    """Fused moe linear blocked fp8 weights.

    Extends ``LinearWeights`` with a float32 ``weight_scale_inv`` parameter of
    shape ``(num_experts, div_up(out_features, block_size),
    div_up(in_features, block_size))``, i.e. one scale per weight block. It
    also installs loaders that shard (tp) or select (ep) both the weight and
    its scales. Checkpoint weights whose dtype differs from the parameter dtype
    are quantized with ``quant_blocked_fp8`` while loading.

    NOTE(review): ``self.ep``, ``self.expert_list``, ``self.expert_map`` and
    ``self.half_out`` are presumably set by ``LinearWeights`` — confirm there.
    """

    def __init__(self,
                 num_experts: int,
                 in_features: int,
                 out_features: int,
                 weight_type: str,
                 block_size: int,
                 dtype: torch.dtype,
                 device: torch.device,
                 bias: bool = False,
                 expert_list: List[int] = None,
                 scale_fmt: Optional[str] = None):
        super().__init__(num_experts=num_experts,
                         in_features=in_features,
                         out_features=out_features,
                         weight_type=weight_type,
                         dtype=dtype,
                         device=device,
                         bias=bias,
                         expert_list=expert_list)
        # forwarded to quant_blocked_fp8 when quantizing non-fp8 checkpoints
        self.scale_fmt = scale_fmt
        self.block_size = block_size
        weight_scale_inv = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)),
                                       dtype=torch.float32,
                                       device=device)
        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
        self.register_parameter('weight_scale_inv', weight_scale_inv)

        # ``_base_weight_loader`` is the layout-specific loader without quantization;
        # the public ``weight_loader`` becomes the quantizing wrapper at the end.
        if self.ep:
            # reuse the loader already attached to self.weight; it must be read
            # here, before it is overwritten below
            self.weight._base_weight_loader = self.weight.weight_loader
            self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep
        else:
            self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8
            self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp
        self.weight.weight_loader = self.weight_loader_with_quant

    def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor):
        """Update weight.

        Replaces ``weight`` through the base class and re-registers
        ``weight_scale_inv``, carrying the old scale parameter's
        ``weight_loader`` over to the new one.
        """
        super().update_weight(weight=weight)
        weight_loader = self.weight_scale_inv.weight_loader
        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
        weight_scale_inv.weight_loader = weight_loader
        self.register_parameter('weight_scale_inv', weight_scale_inv)

    def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                               shard_id: str):
        """Load a scale tensor for ``expert_id`` under expert parallelism.

        Experts not in ``self.expert_list`` are ignored. Otherwise the scale is
        written through ``weight_loader_scale_tp`` to every index in
        ``self.expert_map[expert_id]`` (presumably this rank's local slot(s)
        for that expert — confirm in LinearWeights).
        """
        expert_list = self.expert_list
        if expert_id not in expert_list:
            return
        expert_ids = self.expert_map[expert_id]
        # NOTE: rebinds expert_id; the loop values are the mapped indices
        for expert_id in expert_ids:
            self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id)

    def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int):
        """Chunk with align.

        Returns ``rank``'s piece of ``weight`` along ``dim``, using
        ``split_size(weight.size(dim), world_size, align)``. Pieces are
        multiples of ``align`` and leading ranks take any extra unit.
        """
        split_size = _split_size(weight.size(dim), world_size, align)
        return weight.split(split_size, dim=dim)[rank]

    def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                                     shard_id: str):
        """Weight loader.

        Tensor-parallel loader for one expert's ``gate`` / ``up`` / ``down``
        tensor:

        - ``gate`` / ``up``: chunked along dim 0 (``block_size`` aligned) into
          the first / second ``half_out`` rows of the fused gate-up parameter.
        - ``down``: chunked along dim 1.
        - 1-D ``down`` tensor (bias): kept on rank 0 and zeroed on other ranks.

        Raises:
            RuntimeError: for an unknown ``shard_id``.
        """
        world_size, rank = get_tp_world_rank('moe')
        if shard_id == 'gate':
            param_data = param.data[expert_id, :self.half_out]
            weight = self._chunk_weight_tp(loaded_weight,
                                           dim=0,
                                           world_size=world_size,
                                           rank=rank,
                                           align=self.block_size)
        elif shard_id == 'up':
            param_data = param.data[expert_id, self.half_out:]
            weight = self._chunk_weight_tp(loaded_weight,
                                           dim=0,
                                           world_size=world_size,
                                           rank=rank,
                                           align=self.block_size)
        elif shard_id == 'down':
            param_data = param.data[expert_id]
            # weight is not contiguous, chunk and copy in cpu is slow
            weight = loaded_weight.to(param_data.device)
            if weight.dim() > 1:
                weight = self._chunk_weight_tp(weight, dim=1, world_size=world_size, rank=rank, align=self.block_size)
            elif weight.dim() == 1 and rank != 0:
                # bias with rank>0 should be 0
                weight = torch.zeros_like(weight)
        else:
            raise RuntimeError(f'Unknown shard_id: {shard_id}')
        param_data.copy_(weight)

    def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                               shard_id: str):
        """Weight loader scale tp.

        Same sharding scheme as ``weight_loader_tp_blocked_fp8`` but expressed
        in scale blocks. The gate/up boundary is ``half_out // block_size`` and
        chunks use ``align=1``.

        Raises:
            RuntimeError: for an unknown ``shard_id``.
        """
        world_size, rank = get_tp_world_rank('moe')
        block_size = self.block_size
        # NOTE(review): floor division assumes half_out is a multiple of block_size
        half_out = self.half_out // block_size
        if shard_id == 'gate':
            param_data = param.data[expert_id, :half_out]
            weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1)
        elif shard_id == 'up':
            param_data = param.data[expert_id, half_out:]
            weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1)
        elif shard_id == 'down':
            param_data = param.data[expert_id]
            loaded_weight = loaded_weight.to(param_data.device)
            weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1)
        else:
            raise RuntimeError(f'Unknown shard_id: {shard_id}')
        param_data.copy_(weight)

    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                                 shard_id: str):
        """Weight load with quant.

        If ``loaded_weight`` already has ``param``'s dtype, it goes straight to
        ``_base_weight_loader``. Otherwise it is moved to ``param``'s device and
        quantized with ``quant_blocked_fp8`` (target ``param.dtype``, block
        ``self.block_size``, ``self.scale_fmt``). The quantized weight is then
        loaded through the weight's base loader and the scales through
        ``weight_scale_inv.weight_loader``.
        """
        if loaded_weight.dtype != param.dtype:
            # quant loaded weight
            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),
                                                        param.dtype,
                                                        self.block_size,
                                                        scale_fmt=self.scale_fmt)
            self.weight._base_weight_loader(self.weight, quanted_weight, expert_id, shard_id)
            self.weight_scale_inv.weight_loader(self.weight_scale_inv, scaling, expert_id, shard_id)
        else:
            return self.weight._base_weight_loader(param, loaded_weight, expert_id, shard_id)


class FusedMoEBlockedF8(FusedMoEBase):
    """Fused moe blocked f8.

    Expert weights are held by ``LinearWeightsBlockedF8`` as ``fp8_dtype`` tensors
    with block scales (``weight_scale_inv``, block size ``self.block_size`` = 128).

    The forward pass is split into ``before_dispatch`` -> ``dispatch`` -> ``gemm``
    -> ``combine``, with ``wait`` used to synchronize pending async work. Every
    stage branches on ``state['moe_type']``:

    * ``MoeType.DSAsyncPrefill`` / ``MoeType.DSAsyncDecode``: tokens are exchanged
      through the object returned by ``impl.fusedmoe_build`` (built with
      ``low_latency_mode=False`` / ``True`` respectively); presumably this is the
      expert-parallel communication path (the impl receives ``ep_gpu_group``).
    * any other value: the default path gathers inputs with
      ``moe_gather_inputs``, runs ``impl.forward`` and, if ``self.all_reduce``,
      reduces the output with ``moe_reduce``.
    """

    def __init__(self,
                 hidden_dim: int,
                 ffn_dim: int,
                 num_experts: int,
                 top_k: int,
                 bias: bool = False,
                 renormalize: bool = False,
                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,
                 scale_fmt: Optional[str] = None,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 all_reduce: bool = True,
                 layer_idx: int = 0,
                 act_func: Callable = None):
        """Create the impl and the blocked-fp8 gate_up / down weights.

        Args:
            hidden_dim: model hidden size.
            ffn_dim: expert intermediate size; split across TP ranks (aligned to
                ``block_size``) only when expert parallelism is disabled.
            num_experts: total number of routed experts; replaced by the local
                expert count when expert parallelism is enabled.
            top_k: number of experts selected per token.
            bias: whether gate_up / down carry a bias.
            renormalize: forwarded to the impl and as ``do_renormalize`` to the base.
            fp8_dtype: storage dtype of the expert weights.
            scale_fmt: scale format forwarded to the impl and the weight holders.
            dtype: output dtype of the impl; defaults to ``torch.float16``.
            device: device of the weights; defaults to CPU.
            all_reduce: passed to ``init_dist_args``; ``combine`` checks
                ``self.all_reduce`` on the default path.
            layer_idx: layer index forwarded to the impl.
            act_func: optional custom gate/up activation used by the default path;
                its presence sets ``custom_gateup_act`` on the impl.
        """

        device = device or torch.device('cpu')
        dtype = dtype or torch.float16
        # init distributed tp arguments
        self.block_size = 128
        self.init_dist_args(all_reduce)
        self.scale_fmt = scale_fmt

        super().__init__(
            tp=self.tp,
            tp_mode=self.tp_mode,
            do_renormalize=renormalize,
        )

        dist_ctx = get_dist_manager().current_context()
        self.ep_size, rank = get_ep_world_rank()
        impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8)
        self.impl = impl_builder.build(top_k,
                                       num_experts,
                                       hidden_dim,
                                       renormalize,
                                       block_size=self.block_size,
                                       ep_size=self.ep_size,
                                       ep_group=dist_ctx.ep_gpu_group,
                                       out_dtype=dtype,
                                       layer_idx=layer_idx,
                                       custom_gateup_act=act_func is not None)
        self.impl.set_scale_fmt(scale_fmt)

        if self.ep_size > 1:
            # expert parallel: this rank stores only the experts chosen by the impl
            expert_list = self.impl.ep_expert_list(self.ep_size, rank)
            num_experts = len(expert_list)
        else:
            # tensor parallel: shard ffn_dim on block boundaries so block scales split cleanly
            hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size)
            expert_list = None
        self.expert_list = expert_list

        # create weights
        self.gate_up = LinearWeightsBlockedF8(num_experts,
                                              hidden_dim,
                                              ffn_dim * 2,
                                              weight_type='gate_up',
                                              block_size=self.block_size,
                                              dtype=fp8_dtype,
                                              device=device,
                                              bias=bias,
                                              expert_list=expert_list,
                                              scale_fmt=scale_fmt)
        self.down = LinearWeightsBlockedF8(num_experts,
                                           ffn_dim,
                                           hidden_dim,
                                           weight_type='down',
                                           block_size=self.block_size,
                                           dtype=fp8_dtype,
                                           device=device,
                                           bias=bias,
                                           expert_list=expert_list,
                                           scale_fmt=scale_fmt)

        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.dtype = dtype
        self.device = device
        self.act_func = act_func

    @staticmethod
    def _update_args(hidden_dim: int, ffn_dim: int, align: int):
        """Return ``hidden_dim`` unchanged and this MoE-TP rank's share of ``ffn_dim``.

        The share is ``_split_size(ffn_dim, world_size, align)[rank]``; presumably
        ``_split_size`` yields per-rank sizes that are multiples of ``align`` —
        NOTE(review): confirm against its definition.
        """
        world_size, rank = get_tp_world_rank('moe')
        split_size = _split_size(ffn_dim, world_size, align)
        ffn_dim = split_size[rank]
        return hidden_dim, ffn_dim

    def update_weights(self):
        """Update weights.

        Passes the current weights and scales through ``impl.update_weights`` and
        re-registers the returned tensors on ``gate_up`` and ``down``.
        """
        (gate_up_weights, down_weights, gate_up_scale,
         down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.weight_scale_inv,
                                                self.down.weight_scale_inv)
        self.gate_up.update_weight(gate_up_weights, gate_up_scale)
        self.down.update_weight(down_weights, down_scale)

    def before_dispatch(self, state: DispatchInputs):
        """Before dispatch.

        Converts ``state`` to a dict if needed. For ``DSAsyncPrefill`` it also
        builds the fusedmoe object, optionally pre-quantizes ``hidden_states`` (when
        the object exposes ``per_token_group_quant_fp8``) and captures the event
        that ``dispatch`` waits on.
        """
        if not isinstance(state, Dict):
            state = state.to_dict()

        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = self.fusedmoe_build(low_latency_mode=False)
            state['fusedmoe'] = fusedmoe
            if hasattr(fusedmoe, 'per_token_group_quant_fp8'):
                state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states'])
            previous_event = fusedmoe.capture()
            state['previous_event'] = previous_event
        return state

    def dispatch(self, state: Dict):
        """Send tokens to their experts and return the received-side state.

        ``DSAsyncPrefill`` dispatches asynchronously and records an ``event``.
        ``DSAsyncDecode`` builds a low-latency fusedmoe object, dispatches with
        ``use_fp8=True`` and records a ``hook`` (``use_event`` is hard-coded to False).
        Any other type gathers inputs over ``self.gather_group``.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = state['fusedmoe']
            previous_event = state['previous_event']
            (
                recv_hidden_states,
                recv_topk_idx,
                recv_topk_weights,
                recv_tokens_per_expert,
                handle,
                event,
            ) = fusedmoe.dispatch_async(state['hidden_states'],
                                        state['topk_idx'],
                                        state['topk_weights'],
                                        previous_event=previous_event,
                                        async_finish=True)
            recv_state = {
                'fusedmoe': fusedmoe,
                'recv_hidden_states': recv_hidden_states,
                'recv_topk_idx': recv_topk_idx,
                'recv_topk_weights': recv_topk_weights,
                'recv_tokens_per_expert': recv_tokens_per_expert,
                'handle': handle,
                'event': event,
                'num_experts': self.num_experts,
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            fusedmoe = self.fusedmoe_build(low_latency_mode=True)
            use_event = False
            (recv_hidden_states, recv_expert_count, handle, event,
             hook) = fusedmoe.dispatch_async(state['hidden_states'],
                                             state['topk_idx'],
                                             use_fp8=True,
                                             async_finish=use_event)
            recv_state = {
                'fusedmoe': fusedmoe,
                'recv_hidden_states': recv_hidden_states,
                'recv_expert_count': recv_expert_count,
                'topk_idx': state['topk_idx'],
                'topk_weights': state['topk_weights'],
                'raw_hidden_shape': state['raw_hidden_shape'],
                'handle': handle,
                'moe_type': state['moe_type']
            }
            # exactly one of 'event' / 'hook' is stored; wait() checks them in that order
            if use_event:
                recv_state['event'] = event
            else:
                recv_state['hook'] = hook
        else:  # MoeType.Default
            hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'],
                                                                      state['topk_weights'],
                                                                      state['topk_idx'],
                                                                      group=self.gather_group)
            recv_state = {
                'hidden_states': hidden_states,
                'topk_idx': topk_idx,
                'topk_weights': topk_weights,
                'moe_type': state['moe_type']
            }
        return recv_state

    def gemm(self, state: Dict):
        """Run the expert computation on the dispatched tokens.

        Returns a new state whose ``hidden_states`` hold the expert outputs, plus
        the fields ``combine`` needs for the async paths.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            # recv_hidden_states may be a tuple (presumably quantized data + scales);
            # skip the GEMM when this rank received no tokens
            if (state['recv_hidden_states'][0]
                    if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0:
                state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,
                                                                                 self.gate_up.weight_scale_inv,
                                                                                 self.down.weight,
                                                                                 self.down.weight_scale_inv)
            gemm_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': state['recv_hidden_states'],
                'handle': state['handle'],
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,
                                                                             self.gate_up.weight_scale_inv,
                                                                             self.down.weight,
                                                                             self.down.weight_scale_inv)
            gemm_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': state['recv_hidden_states'],
                'topk_idx': state['topk_idx'],
                'topk_weights': state['topk_weights'],
                'handle': state['handle'],
                'moe_type': state['moe_type']
            }
        else:  # MoeType.Default
            if self.gate_up.weight.numel() == 0:
                # current rank get no expert chunk
                # create a zero tensor with the same shape as hidden_states
                gemm_state = {'hidden_states': torch.zeros_like(state['hidden_states']), 'moe_type': state['moe_type']}
            else:
                # default fused moe
                hidden_states = self.impl.forward(state['hidden_states'],
                                                  state['topk_weights'],
                                                  state['topk_idx'],
                                                  self.gate_up.weight,
                                                  self.gate_up.weight_scale_inv,
                                                  self.down.weight,
                                                  self.down.weight_scale_inv,
                                                  gate_up_bias=self.gate_up.bias,
                                                  down_bias=self.down.bias,
                                                  expert_list=self.expert_list,
                                                  act_func=self.act_func)
                gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']}
        return gemm_state

    def combine(self, state: Dict):
        """Bring expert outputs back to their source tokens.

        The async paths call ``combine_async`` and record an ``event`` (prefill) or
        a ``hook`` (decode). The default path all-reduces over ``self.tp_group``
        when ``self.all_reduce`` is set.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = state['fusedmoe']
            previous_event = fusedmoe.capture()
            out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'],
                                                              state['handle'],
                                                              previous_event=previous_event,
                                                              async_finish=True)
            out_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': out_hidden_states,
                'event': event,
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            fusedmoe = state['fusedmoe']
            use_event = False
            out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'],
                                                                    state['topk_idx'],
                                                                    state['topk_weights'],
                                                                    state['handle'],
                                                                    async_finish=use_event)
            out_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': out_hidden_states,
                'moe_type': state['moe_type']
            }
            if use_event:
                out_state['event'] = event
            else:
                out_state['hook'] = hook
        else:  # MoeType.Default
            if self.all_reduce:
                state['hidden_states'] = moe_reduce(state['hidden_states'],
                                                    rank=self.tp_rank,
                                                    tp_mode=self.tp_mode,
                                                    group=self.tp_group)
            out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']}
        return out_state

    def wait(self, state):
        """Block on the pending async work recorded in ``state``.

        Waits on ``state['event']`` via the fusedmoe object if present, otherwise
        calls ``state['hook']``. Returns True if something was waited on, else False.
        """
        if state.get('event', None) is not None:
            state['fusedmoe'].wait(state['event'])
            return True
        elif state.get('hook', None) is not None:
            state['hook']()
            return True
        else:
            return False

    def fusedmoe_build(self, low_latency_mode: bool = False):
        """Return the impl's fusedmoe communication object for the given mode."""
        return self.impl.fusedmoe_build(low_latency_mode)


================================================
FILE: lmdeploy/pytorch/nn/moe/default.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, Dict, List, Optional

import torch
from torch import nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank

from .base import DispatchInputs, FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce, update_dims


class LinearWeights(nn.Module):
    """Stacked per-expert linear weights for a fused MoE layer.

    ``weight`` has shape ``(num_experts, out_features, in_features)``. The optional
    ``bias`` has shape ``(num_experts, out_features)``. For fused gate/up weights the
    ``gate`` shard fills the first ``out_features // 2`` output rows and ``up`` fills
    the rest.

    With ``expert_list`` (expert parallel), local slot ``i`` stores global expert
    ``expert_list[i]`` and loaders ignore experts outside that list. Without it
    (tensor parallel), every expert is stored and checkpoint tensors are sliced for
    the current MoE-TP rank.
    """

    def __init__(self,
                 num_experts: int,
                 in_features: int,
                 out_features: int,
                 weight_type: str,
                 dtype: torch.dtype,
                 device: torch.device,
                 bias: bool = False,
                 expert_list: Optional[List[int]] = None):
        super().__init__()
        weight_data = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device)
        self.register_parameter('weight', torch.nn.Parameter(weight_data, requires_grad=False))

        if not bias:
            self.bias = None
        else:
            bias_data = torch.empty((num_experts, out_features), dtype=dtype, device=device)
            self.register_parameter('bias', torch.nn.Parameter(bias_data, requires_grad=False))

        self.ep = expert_list is not None
        self.expert_list = expert_list
        self.weight_type = weight_type
        self.half_out = out_features // 2

        self.setup_weight_loader()

    def setup_weight_loader(self):
        """Attach the EP or TP loader to ``weight`` (and to ``bias`` when present)."""
        if self.expert_list is None:
            loader = self.weight_loader_tp
        else:
            # global expert id -> local slot indices (an id may occupy several slots)
            self.expert_map = defaultdict(list)
            for slot, expert in enumerate(self.expert_list):
                self.expert_map[expert].append(slot)
            loader = self.weight_loader_ep

        self.weight.weight_loader = loader
        if self.bias is not None:
            self.bias.weight_loader = loader

    def update_weight(self, weight: torch.Tensor):
        """Replace ``weight`` with a new tensor, keeping the attached weight loader."""
        previous_loader = self.weight.weight_loader
        new_param = torch.nn.Parameter(weight, requires_grad=False)
        new_param.weight_loader = previous_loader
        self.register_parameter('weight', new_param)

    def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str):
        """Copy this TP rank's slice of a checkpoint tensor into expert ``expert_id``.

        ``gate``/``up`` are chunked along dim 0 into their half of the output rows.
        ``down`` is moved to the parameter's device first. A matrix is then chunked
        along dim 1. A 1-D tensor (bias) is kept as-is on rank 0 and replaced by zeros
        on every other rank.

        Raises:
            RuntimeError: if ``shard_id`` is not ``gate``, ``up`` or ``down``.
        """
        world_size, rank = get_tp_world_rank('moe')
        if shard_id in ('gate', 'up'):
            rows = slice(None, self.half_out) if shard_id == 'gate' else slice(self.half_out, None)
            dst = param.data[expert_id, rows]
            src = loaded_weight.chunk(world_size, dim=0)[rank]
        elif shard_id == 'down':
            dst = param.data[expert_id]
            # the dim-1 slice is non-contiguous; chunking + copying it on CPU is slow
            src = loaded_weight.to(dst.device)
            if src.dim() > 1:
                src = src.chunk(world_size, dim=1)[rank]
            elif src.dim() == 1 and rank != 0:
                # only rank 0 keeps the bias (presumably because rank outputs are summed)
                src = torch.zeros_like(src)
        else:
            raise RuntimeError(f'Unknown shard_id: {shard_id}')
        dst.copy_(src)

    def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str):
        """Copy ``loaded_weight`` unsliced into every local slot holding ``expert_id``.

        Experts not in ``expert_list`` are silently skipped.

        Raises:
            RuntimeError: if ``shard_id`` is unknown and the expert is local.
        """
        if expert_id not in self.expert_list:
            return

        for slot in self.expert_map[expert_id]:
            if shard_id == 'down':
                dst = param.data[slot]
            elif shard_id == 'gate':
                dst = param.data[slot, :self.half_out]
            elif shard_id == 'up':
                dst = param.data[slot, self.half_out:]
            else:
                raise RuntimeError(f'Unknown shard_id: {shard_id}')
            dst.copy_(loaded_weight)


class FusedMoE(FusedMoEBase):
    """Fused MoE with unquantized expert weights.

    The forward pass is split into ``before_dispatch`` -> ``dispatch`` -> ``gemm``
    -> ``combine``, with ``wait`` used to synchronize pending async work. Each stage
    branches on ``state['moe_type']``:

    * ``MoeType.DSAsyncPrefill`` / ``MoeType.DSAsyncDecode``: tokens are exchanged
      through the object returned by ``impl.fusedmoe_build`` (``low_latency_mode``
      False / True respectively).
    * ``MoeType.Default``: inputs are gathered with ``moe_gather_inputs``, the impl
      runs the experts, and the output is optionally reduced with ``moe_reduce``.
    * any other value raises ``NotImplementedError`` in ``dispatch``/``combine``.
    """

    def __init__(self,
                 hidden_dim: int,
                 ffn_dim: int,
                 num_experts: int,
                 top_k: int,
                 bias: bool = False,
                 renormalize: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 all_reduce: bool = True,
                 layer_idx: int = 0,
                 act_func: Callable = None):
        """Create the impl and the gate_up / down weights.

        Args:
            hidden_dim: model hidden size.
            ffn_dim: expert intermediate size; adjusted by ``update_dims`` when expert
                parallelism is disabled.
            num_experts: total routed experts; replaced by the local expert count
                when expert parallelism is enabled.
            top_k: experts selected per token.
            bias: whether gate_up / down carry a bias.
            renormalize: forwarded to the impl and as ``do_renormalize`` to the base.
            dtype: weight dtype; defaults to ``torch.float16``.
            device: weight device; defaults to CPU.
            all_reduce: passed to ``init_dist_args``; ``combine`` checks
                ``self.all_reduce`` on the default path.
            layer_idx: layer index forwarded to the impl.
            act_func: optional activation forwarded to ``impl.forward``.
        """

        device = device or torch.device('cpu')
        dtype = dtype or torch.float16
        # init distributed tp arguments
        self.init_dist_args(all_reduce)

        super().__init__(
            tp=self.tp,
            tp_mode=self.tp_mode,
            do_renormalize=renormalize,
        )

        # create implementation
        dist_ctx = get_dist_manager().current_context()
        self.ep_size, rank = get_ep_world_rank()
        impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE)
        self.impl = impl_builder.build(
            top_k,
            num_experts,
            renormalize,
            hidden_dim=hidden_dim,
            ep_size=self.ep_size,
            ep_group=dist_ctx.ep_gpu_group,
            layer_idx=layer_idx,
        )

        # create weights
        if self.ep_size > 1:
            # expert parallel: this rank stores only the experts chosen by the impl
            expert_list = self.impl.ep_expert_list(self.ep_size, rank)
            num_experts = len(expert_list)
        else:
            hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim)
            expert_list = None
        self.expert_list = expert_list
        self.gate_up = LinearWeights(num_experts,
                                     hidden_dim,
                                     ffn_dim * 2,
                                     weight_type='gate_up',
                                     dtype=dtype,
                                     device=device,
                                     bias=bias,
                                     expert_list=expert_list)
        self.down = LinearWeights(
            num_experts,
            ffn_dim,
            hidden_dim,
            weight_type='down',
            dtype=dtype,
            device=device,
            bias=bias,
            expert_list=expert_list,
        )

        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.dtype = dtype
        self.device = device
        self.act_func = act_func

    def update_weights(self):
        """Update weights.

        Passes the current weights through ``impl.update_weights`` and re-registers
        the returned tensors on ``gate_up`` and ``down``.
        """
        gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight)
        self.gate_up.update_weight(gate_up_weights)
        self.down.update_weight(down_weights)

    def before_dispatch(self, state: DispatchInputs):
        """Before dispatch.

        Converts ``state`` to a dict if needed. For ``DSAsyncPrefill`` it also builds
        the fusedmoe object and captures the event that ``dispatch`` waits on.
        """
        if not isinstance(state, Dict):
            state = state.to_dict()

        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = self.fusedmoe_build(low_latency_mode=False)
            state['fusedmoe'] = fusedmoe
            previous_event = fusedmoe.capture()
            state['previous_event'] = previous_event
        return state

    def dispatch(self, state: Dict):
        """Send tokens to their experts and return the received-side state.

        Raises:
            NotImplementedError: for an unsupported ``moe_type``.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = state['fusedmoe']
            previous_event = state['previous_event']
            (
                recv_hidden_states,
                recv_topk_idx,
                recv_topk_weights,
                recv_tokens_per_expert,
                handle,
                event,
            ) = fusedmoe.dispatch_async(state['hidden_states'],
                                        state['topk_idx'],
                                        state['topk_weights'],
                                        previous_event=previous_event,
                                        async_finish=True)
            recv_state = {
                'fusedmoe': fusedmoe,
                'recv_hidden_states': recv_hidden_states,
                'recv_topk_idx': recv_topk_idx,
                'recv_topk_weights': recv_topk_weights,
                'recv_tokens_per_expert': recv_tokens_per_expert,
                'handle': handle,
                'event': event,
                'num_experts': self.num_experts,
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            fusedmoe = self.fusedmoe_build(low_latency_mode=True)
            use_event = False
            (recv_hidden_states, recv_expert_count, handle, event,
             hook) = fusedmoe.dispatch_async(state['hidden_states'],
                                             state['topk_idx'],
                                             use_fp8=False,
                                             async_finish=use_event)
            recv_state = {
                'fusedmoe': fusedmoe,
                'recv_hidden_states': recv_hidden_states,
                'recv_expert_count': recv_expert_count,
                'topk_idx': state['topk_idx'],
                'topk_weights': state['topk_weights'],
                'raw_hidden_shape': state['raw_hidden_shape'],
                'handle': handle,
                'moe_type': state['moe_type']
            }
            # exactly one of 'event' / 'hook' is stored; wait() checks them in that order
            if use_event:
                recv_state['event'] = event
            else:
                recv_state['hook'] = hook
        elif moe_type == MoeType.Default:
            hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'],
                                                                      state['topk_weights'],
                                                                      state['topk_idx'],
                                                                      group=self.gather_group)
            recv_state = {
                'hidden_states': hidden_states,
                'topk_idx': topk_idx,
                'topk_weights': topk_weights,
                'moe_type': moe_type
            }
        else:
            raise NotImplementedError(f'Not supported moe type: {moe_type}')
        return recv_state

    def _deepep_fusedmoe_forward(self, state: Dict):
        """Run ``fusedmoe_forward`` of the async (DeepEP-style) path on ``state``.

        ``LinearWeights`` registers no ``weight_scale_inv``. The previous code read that
        attribute directly and always raised ``AttributeError`` on the async paths.
        ``None`` is now passed in the scale slots, while a subclass that does register
        scales still has them forwarded. NOTE(review): assumes the backend's
        ``fusedmoe_forward`` treats ``None`` scales as unquantized weights — confirm.
        """
        gate_up_scale = getattr(self.gate_up, 'weight_scale_inv', None)
        down_scale = getattr(self.down, 'weight_scale_inv', None)
        return state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight, gate_up_scale, self.down.weight,
                                                  down_scale)

    def gemm(self, state: Dict):
        """Run the expert computation on the dispatched tokens.

        Returns a new state whose ``hidden_states`` hold the expert outputs, plus the
        fields ``combine`` needs on the async paths.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            recv = state['recv_hidden_states']
            # skip the expert GEMM when this rank received no tokens
            if (recv[0] if isinstance(recv, tuple) else recv).shape[0] > 0:
                state['recv_hidden_states'] = self._deepep_fusedmoe_forward(state)
            gemm_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': state['recv_hidden_states'],
                'handle': state['handle'],
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            state['recv_hidden_states'] = self._deepep_fusedmoe_forward(state)
            gemm_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': state['recv_hidden_states'],
                'topk_idx': state['topk_idx'],
                'topk_weights': state['topk_weights'],
                'handle': state['handle'],
                'moe_type': state['moe_type']
            }
        else:
            hidden_states = state['hidden_states']
            topk_weights = state['topk_weights']
            topk_ids = state['topk_idx']

            hidden_states = self.impl.forward(hidden_states,
                                              topk_weights,
                                              topk_ids,
                                              self.gate_up.weight,
                                              self.down.weight,
                                              self.gate_up.bias,
                                              self.down.bias,
                                              self.expert_list,
                                              act_func=self.act_func)
            gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']}
        return gemm_state

    def combine(self, state: Dict):
        """Bring expert outputs back to their source tokens.

        Raises:
            NotImplementedError: for an unsupported ``moe_type``.
        """
        moe_type = state['moe_type']
        if moe_type == MoeType.DSAsyncPrefill:
            fusedmoe = state['fusedmoe']
            previous_event = fusedmoe.capture()
            out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'],
                                                              state['handle'],
                                                              previous_event=previous_event,
                                                              async_finish=True)
            out_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': out_hidden_states,
                'event': event,
                'moe_type': state['moe_type']
            }
        elif moe_type == MoeType.DSAsyncDecode:
            fusedmoe = state['fusedmoe']
            use_event = False
            out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'],
                                                                    state['topk_idx'],
                                                                    state['topk_weights'],
                                                                    state['handle'],
                                                                    async_finish=use_event)
            out_state = {
                'fusedmoe': state['fusedmoe'],
                'hidden_states': out_hidden_states,
                'moe_type': state['moe_type']
            }
            if use_event:
                out_state['event'] = event
            else:
                out_state['hook'] = hook
        elif moe_type == MoeType.Default:
            if self.all_reduce:
                state['hidden_states'] = moe_reduce(state['hidden_states'],
                                                    rank=self.tp_rank,
                                                    tp_mode=self.tp_mode,
                                                    group=self.tp_group)
            out_state = {'hidden_states': state['hidden_states'], 'moe_type': moe_type}
        else:
            raise NotImplementedError(f'Not supported moe type: {moe_type}')
        return out_state

    def wait(self, state: Dict):
        """Block on pending async work recorded in ``state``.

        Returns True if an ``event`` or ``hook`` was waited on, else False.
        """
        if state.get('event', None) is not None:
            state['fusedmoe'].wait(state['event'])
            return True
        elif state.get('hook', None) is not None:
            state['hook']()
            return True
        else:
            return False

    def fusedmoe_build(self, low_latency_mode: bool = False):
        """Return the impl's fusedmoe communication object for the given mode."""
        return self.impl.fusedmoe_build(low_latency_mode)


================================================
FILE: lmdeploy/pytorch/nn/moe/route.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.backends import OpType, get_backend


class NoauxTCRouter(torch.nn.Module):
    """Expert router backed by the backend's ``OpType.RouterNoauxTC`` implementation.

    All constructor arguments are forwarded verbatim, as keywords, to the backend
    impl builder. This module only holds the built implementation and delegates
    ``forward`` to it.
    """

    def __init__(
        self,
        scoring_func: str,
        top_k: int,
        n_group: int,
        topk_group: int,
        n_routed_experts: int,
        routed_scaling_factor: float,
        renormalize: bool = True,
        router_n_groups: int = -1,
    ):
        super().__init__()

        builder = get_backend().get_layer_impl_builder(OpType.RouterNoauxTC)
        router_config = dict(
            scoring_func=scoring_func,
            top_k=top_k,
            n_group=n_group,
            topk_group=topk_group,
            n_routed_experts=n_routed_experts,
            routed_scaling_factor=routed_scaling_factor,
            renormalize=renormalize,
            router_n_groups=router_n_groups,
        )
        self.impl = builder.build(**router_config)

    def forward(self, router_logits: torch.Tensor,
                e_score_correction_bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return whatever the backend router produces for these logits and bias."""
        backend_router = self.impl
        return backend_router.forward(router_logits, e_score_correction_bias)


================================================
FILE: lmdeploy/pytorch/nn/moe/w8a8.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_tp_world_rank

from .base import FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce, update_dims
from .default import LinearWeights


class LinearWeightsW8A8(LinearWeights):
    """Fused moe linear w8a8 weights.

    The quantized weight storage is created by :class:`LinearWeights` with
    ``dtype=quant_dtype``; this subclass adds a float32 ``scale`` parameter of
    shape ``(num_experts, out_features, 1)`` and a weight loader for it.
    """

    def __init__(self,
                 num_experts: int,
                 in_features: int,
                 out_features: int,
                 weight_type: str,
                 device: torch.device,
                 expert_list: List[int] = None,
                 quant_dtype: torch.dtype = torch.int8):
        super().__init__(
            num_experts=num_experts,
            in_features=in_features,
            out_features=out_features,
            weight_type=weight_type,
            dtype=quant_dtype,
            device=device,
            expert_list=expert_list,
        )
        scale_shape = (num_experts, out_features, 1)
        scale_storage = torch.empty(scale_shape, dtype=torch.float32, device=device)
        self.register_parameter('scale', torch.nn.Parameter(scale_storage, requires_grad=False))

        # the loader depends on the parallel mode set up by the base class
        self.scale.weight_loader = self.weight_loader_ep if self.ep else self.weight_loader_scale_tp

    def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
        """Replace weight and scale, keeping the scale's existing weight loader."""
        super().update_weight(weight=weight)
        kept_loader = self.scale.weight_loader
        replacement = torch.nn.Parameter(scale, requires_grad=False)
        replacement.weight_loader = kept_loader
        self.register_parameter('scale', replacement)

    def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,
                               shard_id: str):
        """Load one expert's scale shard under MoE tensor parallelism.

        ``gate`` fills the first ``half_out`` rows and ``up`` the remaining
        rows, each taking this rank's chunk (dim 0) of ``loaded_weight``;
        ``down`` copies ``loaded_weight`` whole, without TP splitting.

        Raises:
            RuntimeError: for any other ``shard_id``.
        """
        world_size, rank = get_tp_world_rank('moe')
        if shard_id == 'down':
            target = param.data[expert_id]
            source = loaded_weight
        elif shard_id in ('gate', 'up'):
            half = self.half_out
            rows = slice(None, half) if shard_id == 'gate' else slice(half, None)
            target = param.data[expert_id, rows]
            source = loaded_weight.chunk(world_size, dim=0)[rank]
        else:
            raise RuntimeError(f'Unknown shard_id: {shard_id}')
        target.copy_(source.to(param.dtype))


class FusedMoEW8A8(FusedMoEBase):
    """Fused moe w8a8.

    Tensor-parallel only: construction raises ``RuntimeError`` when EP mode is
    active. Expert weights are :class:`LinearWeightsW8A8` (quantized weight
    plus per-row scale); the computation itself is delegated to the backend
    ``FusedMoEW8A8`` implementation.
    """

    def __init__(self,
                 hidden_dim: int,
                 ffn_dim: int,
                 num_experts: int,
                 top_k: int,
                 renormalize: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 quant_dtype: Optional[torch.dtype] = torch.int8,
                 device: Optional[torch.device] = None,
                 all_reduce: bool = True):

        device = device or torch.device('cpu')
        dtype = dtype or torch.float16
        # distributed attributes (tp, ep, tp_mode, ...) are needed before the
        # base-class constructor, which receives some of them
        self.init_dist_args(all_reduce)

        if self.ep > 1:
            raise RuntimeError('FusedMoEW8A8 does not support EP mode now.')

        super().__init__(tp=self.tp, tp_mode=self.tp_mode, do_renormalize=renormalize)

        moe_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEW8A8)
        self.impl = moe_builder.build(top_k, num_experts, renormalize, dtype, quant_dtype=quant_dtype)

        hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim)
        # expert_list is always None here (EP mode was rejected above)
        self.expert_list = None
        common = dict(device=device, expert_list=self.expert_list, quant_dtype=quant_dtype)
        self.gate_up = LinearWeightsW8A8(num_experts, hidden_dim, ffn_dim * 2, weight_type='gate_up', **common)
        self.down = LinearWeightsW8A8(num_experts, ffn_dim, hidden_dim, weight_type='down', **common)

        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.dtype = dtype
        self.device = device
        self.all_reduce = all_reduce

    def update_weights(self):
        """Let the backend transform the loaded weights/scales, then store them."""
        gate_up, down = self.gate_up, self.down
        new_gate_up_w, new_down_w, new_gate_up_s, new_down_s = self.impl.update_weights(
            gate_up.weight, down.weight, gate_up.scale, down.scale)
        gate_up.update_weight(new_gate_up_w, new_gate_up_s)
        down.update_weight(new_down_w, new_down_s)

    def dispatch(self, state: Dict):
        """Gather routed inputs across ``self.gather_group``.

        Only ``MoeType.Default`` is supported; other types raise
        ``NotImplementedError``.
        """
        moe_type = state['moe_type']
        if moe_type != MoeType.Default:
            raise NotImplementedError(f'Not supported moe type: {moe_type}')

        gathered = moe_gather_inputs(state['hidden_states'],
                                     state['topk_weights'],
                                     state['topk_idx'],
                                     group=self.gather_group)
        hidden_states, topk_weights, topk_idx = gathered
        return {
            'hidden_states': hidden_states,
            'topk_idx': topk_idx,
            'topk_weights': topk_weights,
            'moe_type': moe_type,
        }

    def gemm(self, state: Dict):
        """Run the backend fused expert computation on a dispatched state."""
        gate_up, down = self.gate_up, self.down
        out = self.impl.forward(state['hidden_states'], state['topk_weights'], state['topk_idx'], gate_up.weight,
                                gate_up.scale, down.weight, down.scale, self.expert_list)
        return dict(hidden_states=out, moe_type=state['moe_type'])

    def combine(self, state: Dict):
        """Optionally all-reduce the expert outputs over ``self.tp_group``.

        Note: when ``all_reduce`` is set, ``state['hidden_states']`` is
        replaced in place with the reduced tensor.
        """
        moe_type = state['moe_type']
        if moe_type != MoeType.Default:
            raise NotImplementedError(f'Not supported moe type: {moe_type}')

        if self.all_reduce:
            state['hidden_states'] = moe_reduce(state['hidden_states'],
                                                rank=self.tp_rank,
                                                tp_mode=self.tp_mode,
                                                group=self.tp_group)
        return {'hidden_states': state['hidden_states'], 'moe_type': moe_type}

    def wait(self, state: Dict):
        """Not supported for this layer; always raises ``NotImplementedError``."""
        raise NotImplementedError


================================================
FILE: lmdeploy/pytorch/nn/multinomial_sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..backends import OpType, get_backend


def multinomial_sampling(scores: torch.Tensor,
                         seeds: torch.LongTensor,
                         offsets: torch.LongTensor,
                         indices: torch.Tensor = None):
    """Multinomial sampling op.

    Builds the backend ``MultinomialSampling`` implementation on every call
    and forwards all arguments to it unchanged.
    """
    builder = get_backend().get_layer_impl_builder(OpType.MultinomialSampling)
    sampler = builder.build()
    return sampler.forward(scores, seeds, offsets, indices)


================================================
FILE: lmdeploy/pytorch/nn/norm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch
from torch import nn

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.models.patch import get_build_model_context

from ..backends import OpType, get_backend
from .utils import chunk_aligned, get_distribute_size


class RMSNorm(nn.Module):
    """RMS Norm with add residual."""

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        quant_config: Dict | None = None,
        tp: bool = False,
        align: int = 1,
        prefix: str = '',
    ):
        super().__init__()
        backend = get_backend()

        quant_method = None
        if quant_config is not None:
            quant_config = get_build_model_context().quant_config
            quant_method = quant_config.get_quant_method(prefix)

        w8a8_flag = quant_method == 'smooth_quant'

        if w8a8_flag:
            builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)
        else:
            builder = backend.get_layer_impl_builder(OpType.RMSNorm)

        if tp:
            world_size, rank = get_tp_world_rank('attn')
            hidden_size = get_distribute_size(hidden_size, world_size, rank, align=align)

        self.register_parameter('weight', self.create_weight(hidden_size, dtype, device))
        if w8a8_flag:
            self.impl = builder.build(hidden_size, eps, quant_dtype=quant_config.quant_dtype)
        else:
            self.impl = builder.build(hidden_size, eps)

        if tp:
            self.weight.weight_loader = self.weight_loader
        self.align = align

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Weight loader."""
        world_size, rank = get_tp_world_rank('attn')
        loaded_weight = chunk_aligned(loaded_weight, world_size, 0, self.align)[rank]
        param.copy_(loaded_weight)

    @staticmethod
    def create_weight(hidden_size: int, dtype: torch.dtype | None = None, device: torch.device | None = None):
        """Create weight."""
        if dtype is None:
            dtype = torch.float16
        if device is None:
            device = 'cuda'
        weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)
        return weight

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        """forward."""
        return self.impl.forward(x, self.weight, residual)


class LayerNorm(nn.Module):
    """Layer Norm with add residual."""

    def __init__(self,
                 hidden_size: int,
                 eps: float = 1e-6,
                 bias: bool = True,
                 dtype: torch.dtype | None = None,
                 device: torch.device | None = None):
        super().__init__()
        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.LayerNorm)
        weight, bias = self.create_weight(hidden_size, bias, dtype, device)
        self.register_parameter('weight', weight)
        self.register_parameter('bias', bias)
        self.impl = builder.build(hidden_size, eps)

    @staticmethod
    def create_weight(hidden_size: int,
                      bias: bool = True,
                      dtype: torch.dtype | None = None,
                      device: torch.device | None = None):
        """Create weight."""
        if dtype is None:
            dtype = torch.float16
        if device is None:
            device = 'cuda'
        weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)
        if bias:
            bias = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)
        else:
            bias = None

        return weight, bias

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        """forward."""
        return self.impl.forward(x, self.weight, self.bias, residual)


================================================
FILE: lmdeploy/pytorch/nn/nsa.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor, nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.backends.attention import AttentionMetadata
from lmdeploy.pytorch.backends.nsa import NSAIndexMeta
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager


class IndexerTopKFP8(nn.Module):
    """NSA top-k indexer over an FP8 key cache.

    Thin wrapper around the backend ``NSAIndexFP8`` implementation: it turns
    the attention metadata into an :class:`NSAIndexMeta` and forwards the call.
    """

    def __init__(self, topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1):
        super().__init__()
        builder = get_backend().get_layer_impl_builder(OpType.NSAIndexFP8)
        self.index_impl = builder.build(topk, softmax_scale, block_size, fill)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        weights: Tensor,
        k_cache: Tensor,
        k_s_cache: Tensor,
        attn_metadata: AttentionMetadata = None,
    ):
        """forward.

        Note: despite its ``None`` default, ``attn_metadata`` is dereferenced
        unconditionally and must be supplied.
        """
        cache_config = get_step_ctx_manager().current_context().cache_config
        cache_capacity = cache_config.block_size * cache_config.num_gpu_blocks

        # as many q rows as there are sequences (kv_seqlens entries) is
        # handled as decoding regardless of the metadata flag
        one_row_per_seq = q.size(0) == attn_metadata.kv_seqlens.size(0)
        decoding = True if one_row_per_seq else attn_metadata.is_decoding

        if decoding:
            # max_kv_seqlen is pinned to the full cache capacity so that
            # cudagraph can be enabled
            q_len_bound, kv_len_bound = 1, cache_capacity
        else:
            q_len_bound, kv_len_bound = q.size(0), attn_metadata.kv_flatten_size

        index_meta = NSAIndexMeta(cu_seqlen_q=attn_metadata.cu_seqlens_q,
                                  q_seqlens=attn_metadata.q_seqlens,
                                  k_seqlens=attn_metadata.kv_seqlens,
                                  block_offset=attn_metadata.block_offsets,
                                  max_q_seqlen=q_len_bound,
                                  max_kv_seqlen=kv_len_bound)
        return self.index_impl.forward(q, k, weights, k_cache, k_s_cache, meta=index_meta)


================================================
FILE: lmdeploy/pytorch/nn/quant_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8  # noqa: F401


================================================
FILE: lmdeploy/pytorch/nn/rotary_embedding.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import math

import torch
from torch import Tensor, nn
from transformers import PretrainedConfig

from ..backends import OpType, get_backend
from ..backends.rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
                                         YarnParameters)


def get_rope_parameters(config: PretrainedConfig):
    """Try get rope parameters from config.

    Configs that define ``rope_parameters`` (transformers v5) return it;
    otherwise ``rope_scaling`` is returned, or ``None`` when that is absent.
    """
    if not hasattr(config, 'rope_parameters'):
        return getattr(config, 'rope_scaling', None)
    # for transformers v5
    return config.rope_parameters


def _get_default_rope_parameters(config: PretrainedConfig):
    """Get default rope parameters (``config`` is not consulted)."""
    default_params = {'emb_type': RopeType.Default, 'scaling_factor': 1.0}
    return default_params


def _get_linear_scaling_rope_parameters(config: PretrainedConfig):
    """Get linear rope parameters; requires a ``factor`` entry."""
    factor = get_rope_parameters(config=config)['factor']
    return {'emb_type': RopeType.LinearScaling, 'scaling_factor': factor}


def _get_dynamic_ntk_parameters(config: PretrainedConfig):
    """Get dynamic ntk parameters; requires a ``factor`` entry."""
    factor = get_rope_parameters(config=config)['factor']
    return {'emb_type': RopeType.DynamicNTKScaling, 'scaling_factor': factor}


def _get_yarn_parameters(config: PretrainedConfig):
    """Get yarn parameters.

    Reads ``factor`` (required) and the optional ``beta_fast``, ``beta_slow``,
    ``mscale``, ``mscale_all_dim``, ``truncate``, ``attention_factor`` and
    ``original_max_position_embeddings`` entries. Missing entries fall back to
    the :class:`YarnParameters` defaults.
    """

    def _mscale_of(scale, mscale=1):
        # no correction when the scale does not extend the context
        if scale > 1:
            return 0.1 * mscale * math.log(scale) + 1.0
        return 1.0

    scaling = get_rope_parameters(config=config)
    factor = scaling['factor']
    yarn = YarnParameters()
    yarn.beta_fast = scaling.get('beta_fast', yarn.beta_fast)
    yarn.beta_slow = scaling.get('beta_slow', yarn.beta_slow)
    mscale = scaling.get('mscale', yarn.mscale)
    mscale_all_dim = scaling.get('mscale_all_dim', yarn.mscale_all_dim)
    truncate = scaling.get('truncate', yarn.truncate)

    # an explicit attention_factor wins; otherwise derive it from mscale values
    if 'attention_factor' in scaling:
        attention_factor = scaling.get('attention_factor')
    elif mscale_all_dim and mscale:
        attention_factor = float(_mscale_of(factor, mscale) / _mscale_of(factor, mscale_all_dim))
    else:
        attention_factor = _mscale_of(factor)

    yarn.attention_factor = attention_factor
    yarn.mscale = mscale
    yarn.mscale_all_dim = mscale_all_dim
    yarn.truncate = truncate

    result = dict(emb_type=RopeType.Yarn, scaling_factor=factor, yarn_params=yarn)
    if 'original_max_position_embeddings' in scaling:
        result['max_position_embeddings'] = scaling['original_max_position_embeddings']
    return result


def _get_longrope_parameters(config: PretrainedConfig):
    """Get longrope parameters.

    ``original_max_position_embeddings`` is resolved in this order: the rope
    dict entry, then ``config.original_max_position_embeddings``, then
    ``config.max_position_embeddings``.
    """
    scaling = get_rope_parameters(config=config)
    scaling_factor = scaling.get('factor', 1.0)
    long_factor, short_factor = scaling['long_factor'], scaling['short_factor']
    config_fallback = getattr(config, 'original_max_position_embeddings', config.max_position_embeddings)
    orig_max_pos = scaling.get('original_max_position_embeddings', config_fallback)
    longrope = LongRoPEScalingParameters(
        long_factor=long_factor,
        short_factor=short_factor,
        original_max_position_embeddings=orig_max_pos,
    )
    return {
        'emb_type': RopeType.LongRoPEScaling,
        'scaling_factor': scaling_factor,
        'longrope_params': longrope,
    }


def _get_llama3_parameters(config: PretrainedConfig):
    """Get llama rope parameters.

    ``factor``, ``low_freq_factor`` and ``high_freq_factor`` are required;
    ``original_max_position_embeddings`` falls back to the
    :class:`Llama3Parameters` default.
    """
    scaling = get_rope_parameters(config=config)
    llama3 = Llama3Parameters()
    factor = scaling['factor']
    for key in ('low_freq_factor', 'high_freq_factor'):
        setattr(llama3, key, scaling[key])
    llama3.original_max_position_embeddings = scaling.get('original_max_position_embeddings',
                                                          llama3.original_max_position_embeddings)
    return {'emb_type': RopeType.Llama3, 'scaling_factor': factor, 'llama3_params': llama3}


def _get_fope_parameters(config: PretrainedConfig):
    """Get fope parameters.

    FoPE is detected by a ``fope_sep_head`` or ``fope_num_inv_freq`` key in
    ``config.rope_scaling``.

    Returns:
        dict: ``{'fope_params': FopeParameters}`` when FoPE is enabled,
        otherwise an empty dict.
    """
    # ``rope_scaling`` may be missing or explicitly None (e.g. configs that
    # carry ``rope_parameters`` instead); both mean "no FoPE" rather than a
    # TypeError from ``key in None``.
    rope_scaling = getattr(config, 'rope_scaling', None) or dict()
    fope_keys = ('fope_sep_head', 'fope_num_inv_freq')
    is_fope = any(key in rope_scaling for key in fope_keys)
    if not is_fope:
        return dict()

    params = FopeParameters()
    rope_scaling = get_rope_parameters(config=config)
    params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))
    params.num_key_value_heads = config.num_key_value_heads
    params.fope_sep_head = rope_scaling['fope_sep_head']
    return dict(fope_params=params)


def build_rotary_params(config: PretrainedConfig):
    """Get scaling_factor rotary params, and emb_type.

    Always returns a dict with ``emb_type`` (``RopeType.Default`` unless the
    rope config selects another type), plus any type-specific entries, fope
    parameters and ``partial_rotary_factor`` when the config provides them.
    """
    params = dict(emb_type=RopeType.Default)
    # cannot access config.rope_scaling when the model is "Qwen/Qwen2-Math-RM-72B"
    rope_scaling = get_rope_parameters(config=config)
    if rope_scaling is not None:
        builders = {
            'default': _get_default_rope_parameters,
            'linear': _get_linear_scaling_rope_parameters,
            'dynamic': _get_dynamic_ntk_parameters,
            'yarn': _get_yarn_parameters,
            'longrope': _get_longrope_parameters,
            'su': _get_longrope_parameters,
            'llama3': _get_llama3_parameters,
        }
        # BC: "rope_type" was originally "type"
        rope_type = rope_scaling.get('rope_type', rope_scaling.get('type', 'default'))
        # a 'fope' rope type uses the default rope; fope-specific entries are
        # added by _get_fope_parameters below
        if rope_type == 'fope':
            rope_type = 'default'
        params.update(builders[rope_type](config))
        params.update(_get_fope_parameters(config))

    partial_rotary_factor = getattr(config, 'partial_rotary_factor', None)
    if partial_rotary_factor is not None:
        params['partial_rotary_factor'] = partial_rotary_factor

    return params


def build_rotary_embedding(dim: int,
                           max_position_embeddings: int = 2048,
                           base: int = 10000,
                           scaling_factor: float = 1.0,
                           yarn_params: YarnParameters = None,
                           longrope_params: LongRoPEScalingParameters = None,
                           llama3_params: Llama3Parameters = None,
                           fope_params: FopeParameters = None,
                           emb_type: RopeType = RopeType.Default,
                           partial_rotary_factor: float = None,
                           device: torch.device = None) -> nn.Module:
    """Build rotary embedding op.

    ``dim`` is scaled by ``partial_rotary_factor`` when given. When
    ``fope_params`` is provided, the ``inv_freq`` of the plain backend impl is
    stored in it and a :class:`FopeRotaryEmbedding` is returned instead.
    """
    builder = get_backend().get_layer_impl_builder(OpType.RotaryEmbedding)

    rope_dim = dim if partial_rotary_factor is None else int(dim * partial_rotary_factor)
    impl = builder.build(rope_dim,
                         max_position_embeddings,
                         base,
                         scaling_factor,
                         yarn_params=yarn_params,
                         longrope_params=longrope_params,
                         llama3_params=llama3_params,
                         emb_type=emb_type)

    if fope_params is None:
        return impl

    fope_params.inv_freq = impl.inv_freq
    return FopeRotaryEmbedding(rope_dim, max_position_embeddings, scaling_factor, fope_params, device)


def get_rope_theta(config: PretrainedConfig, default: int = 10000) -> int:
    """Get rope theta from config.

    Uses ``rope_parameters['rope_theta']`` for transformers v5 style configs.
    When ``rope_parameters`` is absent or ``None``, it uses
    ``config.rope_theta``. ``default`` is returned when the value is missing.
    """
    rope_parameters = getattr(config, 'rope_parameters', None)
    if rope_parameters is not None:
        # for transformers v5
        return rope_parameters.get('rope_theta', default)
    # build_rotary_params tolerates rope_parameters being None; do the same
    # here instead of failing with AttributeError on ``None.get``
    return getattr(config, 'rope_theta', default)


def build_rotary_embedding_from_config(config: PretrainedConfig, device: torch.device = None) -> nn.Module:
    """Build rotary embedding op from config.

    The rope dim is ``config.head_dim``, or ``hidden_size //
    num_attention_heads`` when that is missing/None. Entries from
    :func:`build_rotary_params` override the defaults collected here.
    """
    head_dim = getattr(config, 'head_dim', None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads

    rope_params = {
        # always replaced by build_rotary_params, which returns an emb_type
        'emb_type': RopeType.LinearScaling,
        'dim': head_dim,
        'max_position_embeddings': config.max_position_embeddings,
        'base': get_rope_theta(config, default=10000),
    }
    rope_params.update(build_rotary_params(config))
    return build_rotary_embedding(**rope_params, device=device)


class ApplyRotaryEmb(nn.Module):
    """Apply rotary embedding through the backend ``ApplyRotaryEmb`` op."""

    def __init__(self):
        super().__init__()
        self.impl = get_backend().get_layer_impl_builder(OpType.ApplyRotaryEmb).build()

    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """forward.

        ``cos``/``sin`` may have at most 3 dims. With 3 dims (fope), they are
        flattened over their first two dims. ``query``/``key`` (which must be
        3-D) are viewed to that many rows before the backend call and then
        viewed back to their original shapes.
        """
        assert cos.dim() <= 3 and sin.dim() <= 3

        if cos.dim() != 3:
            q_out, k_out = self.impl.forward(query, key, cos, sin, inplace)
            return q_out, k_out

        # for fope
        assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
        q_shape, k_shape = query.shape, key.shape
        cos, sin = cos.flatten(0, 1), sin.flatten(0, 1)
        rows = cos.size(0)
        q_flat = query.view(rows, -1, query.size(-1))
        k_flat = key.view(rows, -1, key.size(-1))
        q_out, k_out = self.impl.forward(q_flat, k_flat, cos, sin, inplace)
        return q_out.view(q_shape), k_out.view(k_shape)


class FopeRotaryEmbedding(nn.Module):
    """Fope rotary embedding.

    Holds ``cos_coef`` / ``sin_coef`` parameters of shape
    ``(num_key_value_heads, D, D)``, where ``D`` is the last dim of the
    backend impl's ``inv_freq``. ``num_key_value_heads`` is the per-rank count
    after dividing by the attention TP size. Both coefficients are passed to
    the backend rotary impl built with ``RopeType.Fope``.
    """

    def __init__(self,
                 dim: int,
                 max_position_embeddings: int,
                 attention_scaling: float,
                 params: FopeParameters,
                 device: torch.device = None):
        super().__init__()

        num_key_value_heads, tp = self.update_num_kv_heads(params.num_key_value_heads)
        self.tp = tp
        params.num_key_value_heads = num_key_value_heads

        # build impl
        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding)
        self.impl = builder.build(dim,
                                  max_position_embeddings=max_position_embeddings,
                                  scaling_factor=attention_scaling,
                                  fope_params=params,
                                  emb_type=RopeType.Fope)

        # setup params
        inv_freq = self.impl.inv_freq
        self.input_dim = inv_freq.shape[-1]
        self.output_dim = inv_freq.shape[-1]
        self.cos_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim, device=device),
                                     requires_grad=False)
        self.sin_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim, device=device),
                                     requires_grad=False)
        # NOTE(review): tp is the attention TP size (presumably >= 1), so this
        # is effectively always true; with tp == 1 the loader copies the whole
        # tensor.
        if self.tp:
            self.cos_coef.weight_loader = self.weight_loader
            self.sin_coef.weight_loader = self.weight_loader

    @staticmethod
    def update_num_kv_heads(num_key_value_heads: int):
        """Return ``(per-rank kv heads, attention tp size)``.

        With ``attn_tp > 1`` the heads are divided by the TP size, keeping at
        least one head per rank.
        """
        from lmdeploy.pytorch.distributed import get_dist_manager
        dist_mgr = get_dist_manager()
        dist_ctx = dist_mgr.current_context()
        tp = dist_ctx.dist_config.attn_tp
        if tp > 1:
            num_key_value_heads = max(1, num_key_value_heads // tp)
        return num_key_value_heads, tp

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along dim 0, the kv-head dim) of a coefficient.

        The attention TP group is used so the slice matches the head count
        computed by :meth:`update_num_kv_heads`, which divides by ``attn_tp``.
        When there are fewer heads than ranks, each head is replicated over a
        consecutive group of ``world_size // num_heads`` ranks.
        """
        from lmdeploy.pytorch.distributed import get_tp_world_rank
        world_size, rank = get_tp_world_rank('attn')
        num_key_value_heads = loaded_weight.size(0)

        if num_key_value_heads < world_size:
            n_replicate = world_size // num_key_value_heads
            world_size = num_key_value_heads
            rank = rank // n_replicate

        loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]
        param.copy_(loaded_weight)

    def forward(self, x: Tensor, position_ids: Tensor):
        """forward."""
        return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef)


================================================
FILE: lmdeploy/pytorch/nn/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def div_up(a: int, b: int):
    """Return ``a`` divided by ``b``, rounded up (for positive ``b``)."""
    padded = a + b - 1
    return padded // b


def get_distribute_size(feature_size: int, world_size: int, rank: int, align: int = 1):
    """Return the number of features owned by ``rank``.

    ``feature_size`` is split into ``feature_size // align`` units that are
    spread as evenly as possible over ``world_size`` ranks. The first
    ``units % world_size`` ranks each receive one extra unit.
    """
    assert feature_size % align == 0
    base_units, leftover = divmod(feature_size // align, world_size)
    rank_units = base_units + (1 if rank < leftover else 0)
    return rank_units * align


def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):
    """Split ``weight`` along ``dim`` into ``chunks`` pieces sized in multiples of ``align``.

    The ``size // align`` units are spread as evenly as possible, with the
    first ``units % chunks`` pieces getting one extra unit. This is the same
    rule as ``get_distribute_size``, so piece ``rank`` always matches the
    parameter size computed for that rank.

    Exactly ``chunks`` tensors are returned (some may be empty), so indexing
    by rank is always safe. ``torch.chunk`` was previously used for
    ``align == 1``. It rounds sizes up (10 into 4 gives 3, 3, 3, 1) and can
    return fewer than ``chunks`` pieces, which disagreed with
    ``get_distribute_size``.
    """
    size = weight.size(dim)
    assert size % align == 0
    units_per_chunk, remain = divmod(size // align, chunks)
    sections = [(units_per_chunk + int(c < remain)) * align for c in range(chunks)]
    return weight.split(sections, dim=dim)


================================================
FILE: lmdeploy/pytorch/paging/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .scheduler import Scheduler

__all__ = ['Scheduler']


================================================
FILE: lmdeploy/pytorch/paging/block_manager/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ...config import CacheConfig
from .base_block_manager import BaseBlockManager
from .default_block_manager import DefaultBlockManager
from .window_block_manager import WindowBlockManager


def build_block_manager(cache_config: CacheConfig) -> BaseBlockManager:
    """Build block manager.

    A non-negative ``cache_config.window_size`` selects
    :class:`WindowBlockManager`; a negative one selects
    :class:`DefaultBlockManager`.

    Args:
        cache_config (CacheConfig):  cache_config.
    """
    gpu_blocks = cache_config.num_gpu_blocks
    cpu_blocks = cache_config.num_cpu_blocks
    window = cache_config.window_size
    reserved = cache_config.num_reserved_gpu_blocks

    if window >= 0:
        return WindowBlockManager(gpu_blocks, cpu_blocks, window_size=window, num_gpu_reserved=reserved)
    return DefaultBlockManager(gpu_blocks, cpu_blocks, num_gpu_reserved=reserved)


================================================
FILE: lmdeploy/pytorch/paging/block_manager/base_block_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Dict

import numpy as np

from ...messages import SchedulerSequence


class LogicalMemory:
    """Logical memory blocks.

    Per-logical-block bookkeeping kept in three zero-initialised int64 arrays
    of length ``num_blocks``: ``phy_map`` (physical id), ``ref_count`` and
    ``access_time``.
    """

    def __init__(self, num_blocks: int) -> None:
        self._num_blocks = num_blocks

        shape = (num_blocks, )
        self.phy_map: np.ndarray = np.zeros(shape, dtype=np.int64)
        self.ref_count: np.ndarray = np.zeros(shape, dtype=np.int64)
        self.access_time: np.ndarray = np.zeros(shape, dtype=np.int64)

    def get_physical_blocks(self, logical_address: np.ndarray):
        """Map logical block ids to physical ids.

        An empty ndarray short-circuits to an empty int64 array. Such an input
        may carry a non-integer dtype, which numpy would reject as an index.
        """
        is_empty_array = isinstance(logical_address, np.ndarray) and len(logical_address) == 0
        if is_empty_array:
            return np.empty((0, ), dtype=np.int64)
        return self.phy_map[logical_address]

    def num_blocks(self):
        """Return the number of logical blocks."""
        return self._num_blocks


class PhysicalAllocator:
    """The physical block allocator.

    The allocator won't allocate real memory. It is used to support block manager.

    Block ids ``offset .. offset + num_blocks - 1`` live in ``_free_blocks``.
    Entries from index ``num_blocks - _free_count`` onwards are the currently
    free ids. ``allocate`` takes ids from the front of that region, and
    ``free`` pushes ids back immediately before it.
    """

    def __init__(self, num_blocks: int, offset: int = 0):
        self._num_blocks = num_blocks
        self._offset = offset

        self._free_blocks = np.arange(num_blocks, dtype=np.int64) + offset
        self._free_count = num_blocks

    def allocate(self, num_blocks: int):
        """Allocate ``num_blocks`` block ids from the pool.

        Returns:
            np.ndarray: an independent copy of the allocated ids.

        Raises:
            MemoryError: if fewer than ``num_blocks`` blocks are free.
        """
        if self.get_num_free_blocks() < num_blocks:
            raise MemoryError('No enough free memory blocks.')
        num_used = self._num_blocks - self._free_count
        # copy: a view would be silently overwritten when blocks are later
        # freed back into this region of ``_free_blocks``
        blocks = self._free_blocks[num_used:num_used + num_blocks].copy()
        self._free_count -= num_blocks
        return blocks

    def free(self, blocks: np.ndarray):
        """Return ``blocks`` to the pool and return them unchanged.

        No validation is done; freeing ids that are not currently allocated
        corrupts the pool.
        """
        freed_blocks = blocks
        num_freed_blocks = len(freed_blocks)
        if num_freed_blocks > 0:
            num_used = self._num_blocks - self._free_count
            self._free_blocks[num_used - num_freed_blocks:num_used] = freed_blocks
            self._free_count += num_freed_blocks
        return freed_blocks

    def get_num_free_blocks(self):
        """Get numbers of free blocks."""
        return self._free_count


class LogicalAllocator:
    """The logical block allocator.

    Hands out logical block ids and maps each one to a physical id taken from
    the gpu or cpu ``PhysicalAllocator``. Physical ids share one space: gpu ids
    are ``[num_gpu_reserved, num_gpu_blocks)`` and cpu ids are
    ``[num_gpu_blocks, num_gpu_blocks + num_cpu_blocks)``.
    Logical blocks are reference counted. They return to the pool, together
    with their physical block, once the count reaches 0.
    """

    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int, num_gpu_reserved: int = 0) -> None:
        self._log_mem = LogicalMemory(num_cpu_blocks + num_gpu_blocks)

        # cpu physical ids start right after all gpu ids (reserved ones included)
        self._cpu_mem_offset = num_gpu_blocks
        num_gpu_blocks -= num_gpu_reserved
        # the first ``num_gpu_reserved`` gpu ids are never handed out
        self._gpu_allocator = PhysicalAllocator(num_gpu_blocks, num_gpu_reserved)
        self._cpu_allocator = PhysicalAllocator(num_cpu_blocks, self._cpu_mem_offset)

        num_blocks = self._log_mem.num_blocks()
        self._num_blocks = num_blocks
        # explicit int64 so allocated ids match the dtype of the empty-allocation result
        self._free_blocks = np.arange(num_blocks, dtype=np.int64)
        self._free_count = num_blocks

    def get_phy_allocator(self, device: str):
        """Get the physical allocator of ``device`` ('gpu' or 'cpu').

        Raises:
            ValueError: for any other device name.
        """
        if device == 'gpu':
            return self._gpu_allocator
        elif device == 'cpu':
            return self._cpu_allocator
        else:
            raise ValueError(f'Unsupported device: {device}')

    def allocate(self, num_blocks: int, device: str = 'gpu'):
        """Allocate ``num_blocks`` logical blocks backed by ``device`` memory.

        New blocks start with a reference count of 1 and a fresh access time.

        Returns:
            np.ndarray: the new logical block ids (a copy owned by the caller).

        Raises:
            MemoryError: if either the logical or the physical pool is short.
        """
        if num_blocks == 0:
            return np.empty((0, ), dtype=np.int64)
        phy_allocator = self.get_phy_allocator(device)
        logical_enable = self.get_num_free_blocks() >= num_blocks
        physical_enable = phy_allocator.get_num_free_blocks() >= num_blocks
        if logical_enable and physical_enable:
            num_used = self._num_blocks - self._free_count
            blocks = self._free_blocks[num_used:num_used + num_blocks]
            phy_blocks = phy_allocator.allocate(num_blocks)
            self._log_mem.phy_map.put(blocks, phy_blocks)
            self._log_mem.ref_count.put(blocks, 1)
            self.update_access_time(blocks)
            self._free_count -= num_blocks
            return blocks.copy()
        else:
            raise MemoryError('No enough free memory blocks.')

    def free(self, blocks: np.ndarray):
        """Drop one reference from each of ``blocks``.

        Blocks whose count reaches 0 go back to the logical pool, and their
        physical blocks go back to the owning gpu/cpu allocator.
        """

        self.add_ref_count(blocks, -1)
        self.update_access_time(blocks)
        ref_count = self.get_ref_count(blocks)
        freed_blocks = blocks[ref_count == 0]
        num_freed_blocks = len(freed_blocks)
        if num_freed_blocks <= 0:
            return

        # free logical
        num_used = self._num_blocks - self._free_count
        self._free_blocks[num_used - num_freed_blocks:num_used] = freed_blocks
        self._free_count += num_freed_blocks

        # free physical, split by the device the id belongs to
        phy_blocks = self.get_physical_blocks(freed_blocks)

        cpu_blocks = phy_blocks[phy_blocks >= self._cpu_mem_offset]
        gpu_blocks = phy_blocks[phy_blocks < self._cpu_mem_offset]
        if len(cpu_blocks) > 0:
            self._cpu_allocator.free(cpu_blocks)
        if len(gpu_blocks) > 0:
            self._gpu_allocator.free(gpu_blocks)

    def get_num_free_blocks(self):
        """Get numbers of free logical blocks."""
        return self._free_count

    def get_physical_blocks(self, blocks: np.ndarray):
        """Get physical address."""
        return self._log_mem.get_physical_blocks(blocks)

    def get_ref_count(self, blocks: np.ndarray):
        """Get ref count."""
        return self._log_mem.ref_count[blocks]

    def add_ref_count(self, blocks: np.ndarray, value: np.ndarray):
        """Add ``value`` to the ref count of ``blocks`` (duplicate ids accumulate)."""
        np.add.at(self._log_mem.ref_count, blocks, value)

    def get_access_time(self, blocks: np.ndarray):
        """Get access time."""
        return self._log_mem.access_time[blocks]

    def update_access_time(self, blocks: np.ndarray):
        """Stamp ``blocks`` with the current time.

        The stamp is integer nanoseconds from ``time.perf_counter_ns``, which
        fits the int64 ``access_time`` table without loss. A float
        ``perf_counter()`` value would be truncated to whole seconds on
        assignment, making LRU order within one second arbitrary.
        """
        now = time.perf_counter_ns()
        self._log_mem.access_time[blocks] = now

    def cpu_mem_offset(self):
        """Get cpu mem offset in unified physical memory."""
        return self._cpu_mem_offset

    def count_cpu_blocks(self, blocks: np.ndarray):
        """Count cpu blocks."""
        phy_blocks = self.get_physical_blocks(blocks)
        return np.count_nonzero(phy_blocks >= self.cpu_mem_offset())

    def count_gpu_blocks(self, blocks: np.ndarray):
        """Count gpu blocks."""
        phy_blocks = self.get_physical_blocks(blocks)
        return np.count_nonzero(phy_blocks < self.cpu_mem_offset())

    def update_phy_map(self, log_blocks: np.ndarray, phy_blocks: np.ndarray):
        """Point ``log_blocks`` at ``phy_blocks`` (must have equal length)."""
        assert len(phy_blocks) == len(log_blocks)
        self._log_mem.phy_map.put(log_blocks, phy_blocks)

    def on_device(self, blocks: np.ndarray, device: str):
        """Whether ``blocks`` live on ``device``; empty input gives False."""
        if len(blocks) == 0:
            return False

        # TODO: check all blocks (only the first block is inspected)
        cpu_mem_offset = self.cpu_mem_offset()

        phy_blocks = self.get_physical_blocks(blocks[:1])
        if phy_blocks[0] < cpu_mem_offset:
            phy_device = 'gpu'
        else:
            phy_device = 'cpu'
        return device == phy_device


# Type alias used in annotations for block tables (numpy arrays of block ids).
BlockTable = np.ndarray


class BaseBlockManager:
    """Abstract base for block managers.

    Owns a ``LogicalAllocator`` spanning gpu and cpu blocks and provides the
    helpers shared by concrete managers. Subclasses implement allocation,
    freeing and swapping.

    Args:
        num_gpu_blocks (int): number of gpu blocks.
        num_cpu_blocks (int): number of cpu blocks.
        num_gpu_reserved (int): leading gpu block ids withheld from allocation.
    """

    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, num_gpu_reserved: int = 0) -> None:
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks, num_gpu_reserved)

        self.block_tables: Dict[int, BlockTable] = {}

    # ---- interface to be implemented by subclasses ----

    @classmethod
    def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
        """Number of additional blocks ``obj`` needs."""
        raise NotImplementedError('Not implemented.')

    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Whether enough blocks are available to allocate for ``msg``."""
        raise NotImplementedError('Not implemented.')

    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Allocate the blocks ``msg`` needs."""
        raise NotImplementedError('Not implemented.')

    def free(self, msg: SchedulerSequence):
        """Release every block held by ``msg``."""
        raise NotImplementedError('Not implemented.')

    def try_swap_out(self, msg: SchedulerSequence):
        """Attempt to move ``msg`` out to cpu memory."""
        raise NotImplementedError('Not implemented.')

    def try_swap_in(self, msg: SchedulerSequence):
        """Attempt to move ``msg`` back into gpu memory."""
        raise NotImplementedError('Not implemented.')

    # ---- shared helpers ----

    def get_block_table(self, msg: SchedulerSequence):
        """Return the physical block ids that back ``msg``.

        Args:
            msg (SchedulerSequence): The msg to get block table.
        """
        real_blocks = msg.logical_blocks.get_real_blocks()
        return self.allocator.get_physical_blocks(real_blocks)

    def allocate(self, data: SchedulerSequence, prealloc_size: int = 0):
        """Thin alias forwarding to ``allocate_msg``."""
        return self.allocate_msg(data, prealloc_size)

    def get_num_free_gpu_blocks(self) -> int:
        """Free block count of the gpu physical allocator."""
        gpu_allocator = self.allocator.get_phy_allocator('gpu')
        return gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        """Free block count of the cpu physical allocator."""
        cpu_allocator = self.allocator.get_phy_allocator('cpu')
        return cpu_allocator.get_num_free_blocks()

    def on_device(self, msg: SchedulerSequence, device: str):
        """Whether the blocks of ``msg`` are on ``device`` (delegates to the allocator)."""
        real_blocks = msg.logical_blocks.get_real_blocks()
        return self.allocator.on_device(real_blocks, device)


================================================
FILE: lmdeploy/pytorch/paging/block_manager/default_block_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import numpy as np

from ...messages import SchedulerSequence
from .base_block_manager import BaseBlockManager


def _div_up(x, n):
    """Perform div up."""
    return (x + n - 1) // n


# Type alias for block tables (numpy arrays of block ids).
BlockTable = np.ndarray


class DefaultBlockManager(BaseBlockManager):
    """Manage the usage of blocks, generate block tables.

    New blocks are always taken from the gpu pool. A whole sequence can be
    moved between gpu and cpu with ``try_swap_out`` / ``try_swap_in``.

    Args:
        num_gpu_blocks (int): number of gpu blocks.
        num_cpu_blocks (int): number of cpu blocks.
    """

    @classmethod
    def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
        """Get num required blocks.

        The number of additional blocks ``obj`` needs to hold
        ``obj.num_all_ids + prealloc_size`` tokens, given the logical blocks it
        already owns. Never negative.
        """
        num_tokens = obj.num_all_ids + prealloc_size

        num_all_blocks = _div_up(num_tokens, obj.block_size)
        return max(0, num_all_blocks - len(obj.logical_blocks))

    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Return if physical block can be allocated for given message."""
        num_required_blocks = self.num_required_blocks(msg, prealloc_size)
        num_free_phy = self.get_num_free_gpu_blocks()
        return num_required_blocks <= num_free_phy

    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Allocate the missing gpu blocks of ``msg`` and append them to its
        logical blocks.

        Raises:
            MemoryError: if the gpu pool is short (from the allocator).
        """
        logical_blocks = msg.logical_blocks
        num_required_blocks = self.num_required_blocks(msg, prealloc_size)
        if num_required_blocks > 0:
            blocks = self.allocator.allocate(num_required_blocks, 'gpu')
            logical_blocks.append(blocks)

    def free(self, msg: SchedulerSequence):
        """Free all blocks allocated for the sequence and reset its logical blocks."""
        self.allocator.free(msg.logical_blocks.get_real_blocks())
        msg.logical_blocks.reset()

    def _try_swap(self, msg: SchedulerSequence, src_device: str, dst_device: str):
        """Move every block of ``msg`` from ``src_device`` to ``dst_device``.

        Returns:
            Tuple[bool, dict]: ``(False, {})`` when the swap is not possible.
            Otherwise ``(True, swap_map)``, where ``swap_map`` maps source block
            to destination block using device-local ids (cpu ids shifted down
            by ``cpu_mem_offset``).
        """
        allocator = self.allocator
        real_blocks = msg.logical_blocks.get_real_blocks()
        if len(real_blocks) == 0:
            return False, dict()

        cpu_mem_offset = allocator.cpu_mem_offset()
        phy_blocks = allocator.get_physical_blocks(real_blocks)

        # we only support all blocks of a sequence on same device,
        # so the first block decides where the sequence lives
        first_on_gpu = phy_blocks[0] < cpu_mem_offset
        if first_on_gpu != (src_device == 'gpu'):
            return False, dict()

        src_allocator = allocator.get_phy_allocator(src_device)
        dst_allocator = allocator.get_phy_allocator(dst_device)

        # no free blocks on the destination device
        if dst_allocator.get_num_free_blocks() < len(phy_blocks):
            return False, dict()

        # don't swap sequence with multiple reference
        ref_count = allocator.get_ref_count(real_blocks)
        if np.count_nonzero(ref_count != 1) > 0:
            return False, dict()

        new_blocks = dst_allocator.allocate(len(real_blocks))
        if dst_device == 'cpu':
            swap_map = dict(zip(phy_blocks, new_blocks - cpu_mem_offset))
        else:
            swap_map = dict(zip(phy_blocks - cpu_mem_offset, new_blocks))

        src_allocator.free(phy_blocks)
        allocator.update_phy_map(real_blocks, new_blocks)
        return True, swap_map

    def try_swap_out(self, msg: SchedulerSequence):
        """Try to move ``msg`` from gpu to cpu; returns ``(success, swap_map)``."""
        return self._try_swap(msg, 'gpu', 'cpu')

    def try_swap_in(self, msg: SchedulerSequence):
        """Try to move ``msg`` from cpu to gpu; returns ``(success, swap_map)``."""
        return self._try_swap(msg, 'cpu', 'gpu')


================================================
FILE: lmdeploy/pytorch/paging/block_manager/window_block_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from ...block import LogicalTokenBlocks
from ...messages import SchedulerSequence
from .default_block_manager import DefaultBlockManager

# Type alias for block tables (numpy arrays of block ids).
BlockTable = np.ndarray


def _num_blocks_to_drop(seq: SchedulerSequence, window_size: int):
    """Count held blocks that lie entirely outside the attention window.

    Returns 0 while the history still fits in ``window_size`` tokens.
    Otherwise returns the number of held logical blocks beyond those spanning
    the last ``window_size`` history tokens (never negative).
    """
    num_history = seq.num_history_ids
    if num_history <= window_size:
        return 0
    block_size = seq.block_size
    held_blocks = len(seq.logical_blocks)
    # block index of the first and the last token inside the window
    first_win_block = (num_history - window_size) // block_size
    last_win_block = (num_history - 1) // block_size
    window_blocks = last_win_block - first_win_block + 1
    return max(0, held_blocks - window_blocks)


class WindowBlockManager(DefaultBlockManager):
    """Manage the usage of blocks for sliding-window attention.

    Blocks that fall completely outside the window are cut from the front of
    a sequence's logical blocks. They are either reused for new tokens or
    freed.

    Args:
        num_gpu_blocks (int): number of gpu blocks.
        num_cpu_blocks (int): number of cpu blocks.
        window_size (int): sliding window size in tokens, must be > 0.
        num_gpu_reserved (int): leading gpu block ids withheld from allocation.
    """

    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, window_size: int, num_gpu_reserved: int = 0):
        super().__init__(num_gpu_blocks, num_cpu_blocks, num_gpu_reserved)
        assert window_size > 0, ('expect window size > 0, '
                                 f'but get window_size = {window_size}')
        self.window_size = window_size

    def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):
        """Get num required blocks.

        Once the history exceeds the window, the blocks already dropped
        (``obj.num_ignored_history`` tokens) no longer need storage.
        """
        num_required = super().num_required_blocks(obj, prealloc_size)

        # blocks is not enough
        if obj.num_history_ids <= self.window_size:
            return num_required

        # The parent already clamps at 0, so subtracting the dropped blocks
        # can go negative; a negative requirement means "nothing needed".
        return max(0, num_required - obj.num_ignored_history // obj.block_size)

    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Return if blocks can be provided for ``msg``.

        Blocks about to be dropped from the window count as available, since
        ``allocate_msg`` reuses or frees them.
        """
        num_drop_blocks = _num_blocks_to_drop(msg, self.window_size)
        num_required_blocks = self.num_required_blocks(msg, prealloc_size)
        num_free_phy = self.get_num_free_gpu_blocks()
        return num_required_blocks <= num_free_phy + num_drop_blocks

    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):
        """Allocate blocks for ``msg``, recycling blocks that left the window.

        1. Blocks outside the window are cut from the front of
           ``msg.logical_blocks``.
        2. Up to ``num_required_blocks`` of them are appended back to serve
           new tokens.
        3. Any remaining shortfall is allocated from the gpu pool.
        4. Leftover dropped blocks are freed.

        Reusing ``min(dropped, required)`` blocks means a fresh allocation only
        happens when ``required > dropped``, and then needs exactly
        ``required - dropped`` free blocks, which is what ``can_allocate`` checks.
        """
        logical_blocks = msg.logical_blocks

        num_drop_blocks = _num_blocks_to_drop(msg, self.window_size)
        num_required_blocks = self.num_required_blocks(msg, prealloc_size)
        msg.num_ignored_history += num_drop_blocks * msg.block_size

        # cut the blocks that left the window from the front
        droped_blocks = None
        if num_drop_blocks > 0:
            remain_blocks = logical_blocks[num_drop_blocks:]
            droped_blocks = logical_blocks[:num_drop_blocks]
            logical_blocks = LogicalTokenBlocks(remain_blocks)
            msg.logical_blocks = logical_blocks

        # reuse as many dropped blocks as the new tokens need
        if num_required_blocks > 0 and droped_blocks is not None:
            num_reused = min(num_drop_blocks, num_required_blocks)
            logical_blocks.append(droped_blocks[:num_reused])
            droped_blocks = droped_blocks[num_reused:]
            num_required_blocks -= num_reused

        # allocate whatever is still missing
        if num_required_blocks > 0:
            blocks = self.allocator.allocate(num_required_blocks, 'gpu')
            logical_blocks.append(blocks)

        # drop unused blocks
        if droped_blocks is not None and len(droped_blocks) > 0:
            self.allocator.free(droped_blocks)


================================================
FILE: lmdeploy/pytorch/paging/block_trie.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import heapq
from dataclasses import dataclass
from typing import Dict, Set

import numpy as np

from lmdeploy.pytorch.messages import SchedulerSequence

from ..config import CacheConfig
from .block_manager import BaseBlockManager


@dataclass
class PrefixCacheStats:
    """Counters for prefix-cache lookups: tokens queried vs. tokens served from cache."""
    num_query_tokens: int = 0
    num_hit_tokens: int = 0

    def reset(self):
        """Zero both counters."""
        self.num_query_tokens = self.num_hit_tokens = 0

    def hit_rate(self):
        """Fraction of queried tokens that hit the cache; 0.0 when nothing was queried."""
        if self.num_query_tokens <= 0:
            return 0.0
        hits = float(self.num_hit_tokens)
        return hits / self.num_query_tokens


class Node:
    """A single cached block in the block trie.

    Args:
        hash_key (int): hash of the block's tokens; the key under which the
            node is stored in its parent's ``children``.
        block (int): logical block id holding the cached data.
        tokens (np.ndarray): token ids covered by this block.
        num_matched (int): number of tokens from the root up to and including
            this block.
    """

    def __init__(self, hash_key: int, block: int, tokens: np.ndarray, num_matched: int = 0):
        self.hash_key = hash_key
        self.block = block
        self.tokens = tokens
        self.num_matched = num_matched
        self.children: Dict[int, 'Node'] = dict()
        self._parent: 'Node' = None

    @property
    def parent(self):
        """The parent node, or None for roots and detached nodes."""
        return self._parent

    @parent.setter
    def parent(self, val: 'Node'):
        # unlink from the current parent (KeyError if the link is inconsistent)
        if self._parent is not None:
            del self._parent.children[self.hash_key]
        # link into the new parent's children map
        if val is not None:
            val.children[self.hash_key] = self
        self._parent = val

    # Nodes sit in heap tuples ``(access_time, node)``. These make ties on
    # access_time comparable; they impose no real ordering.
    def __lt__(self, other):
        return True

    def __le__(self, other):
        return True


class BlockTrie:
    """Block trie for prefix caching.

    Each root (one per adapter name) heads a trie whose nodes are full
    ``block_size`` token chunks. A node owns one reference to its logical
    block. ``leaves`` holds every non-root node without children; these are
    the eviction candidates.
    """

    def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager):
        self.block_manager = block_manager
        self.cache_config = cache_config
        self.allocator = self.block_manager.allocator
        self.block_size = cache_config.block_size
        self.enable = self.cache_config.enable_prefix_caching

        # caches with different adapter should not be shared.
        self._roots: Dict[str, Node] = dict()
        self.leaves: Set[Node] = set()
        self.stats = PrefixCacheStats()

    def hit_rate(self):
        """Get hit rate."""
        return self.stats.hit_rate()

    def get_root(self, adapter_name: str):
        """Get (creating on first use) the root for ``adapter_name``."""
        if adapter_name not in self._roots:
            self._roots[adapter_name] = Node(-1, -1, None)
        return self._roots[adapter_name]

    def match(self, seq: SchedulerSequence):
        """Extend ``seq`` with cached blocks whose tokens match its history.

        Walks down from ``seq.logical_blocks.last_shared_node`` (or the
        adapter's root) one ``block_size`` chunk at a time. The walk stops
        while at least one valid token remains unmatched. Matched blocks get a
        fresh access time and one more reference, are appended to
        ``seq.logical_blocks``, and ``seq.set_step`` moves past them.
        """
        if not self.enable:
            return

        block_size = self.block_size
        matched_blocks = []

        logical_blocks = seq.logical_blocks
        curr: Node = getattr(logical_blocks, 'last_shared_node', None)
        if curr is None:
            curr = self.get_root(seq.adapter_name)
        init_num_matched = curr.num_matched
        num_matched = curr.num_matched

        def __match_success(node: Node):
            nonlocal curr, num_matched
            matched_blocks.append(node.block)
            curr = node
            num_matched += block_size

        while num_matched + block_size < seq.num_valid_ids:
            curr_tokens = seq.history_cache[num_matched:num_matched + block_size]

            key = hash(('random', tuple(curr_tokens)))
            if key not in curr.children:
                break

            child = curr.children[key]
            # guard against hash collisions
            if not np.array_equal(curr_tokens, child.tokens):
                break

            __match_success(child)

        if len(matched_blocks) > 0:
            matched_blocks = np.array(matched_blocks)
            self.allocator.update_access_time(matched_blocks)
            self.allocator.add_ref_count(matched_blocks, 1)
            seq.logical_blocks.append(matched_blocks)
            seq.set_step(num_matched)

        # record prefix hit
        self.stats.num_query_tokens += seq.num_all_ids - init_num_matched
        self.stats.num_hit_tokens += num_matched - init_num_matched

        seq.logical_blocks.last_shared_node = curr

    def allocate(self, seq: SchedulerSequence):
        """Insert the full blocks of ``seq`` into the trie.

        For each full chunk of ``seq.history_cache`` past the last shared node:
        - if an identical chunk is already cached, the sequence switches to the
          cached block and its own copy is freed;
        - otherwise a new node is created around the sequence's block.
        Every block on the newly walked path gets one more reference.
        """
        if not self.enable:
            return

        block_size = self.block_size
        logical_blocks = seq.logical_blocks
        node: Node = getattr(logical_blocks, 'last_shared_node', None)
        if node is None:
            node = self.get_root(seq.adapter_name)
            logical_blocks.last_shared_node = node

        num_matched = node.num_matched
        num_valid_ids = seq.num_valid_ids

        if num_matched + block_size > num_valid_ids:
            return

        # the node may gain children below; it is re-added at the end if not
        if len(node.children) == 0 and node.parent is not None:
            self.leaves.remove(node)

        block_id = num_matched // block_size
        blocks = []
        free_blocks = []
        while num_matched + block_size <= num_valid_ids:
            curr_tokens = seq.history_cache[num_matched:num_matched + block_size]

            block = logical_blocks[block_id]

            hash_key = hash(('random', tuple(curr_tokens)))
            parent = node
            if hash_key in parent.children:
                child = parent.children[hash_key]
                if not np.array_equal(curr_tokens, child.tokens):
                    break
                node = child
                free_blocks.append(block)
                logical_blocks[block_id] = node.block
            else:
                node = Node(hash_key=hash_key, block=block, tokens=curr_tokens, num_matched=num_matched + block_size)
                node.parent = parent
            blocks.append(node.block)
            num_matched += block_size
            block_id += 1

        logical_blocks.last_shared_node = node
        if node.parent is not None and len(node.children) == 0:
            # ignore root
            self.leaves.add(node)
        if len(blocks) > 0:
            self.allocator.add_ref_count(np.array(blocks), 1)
        if len(free_blocks) > 0:
            self.allocator.free(np.array(free_blocks))

    def evict(self, max_num_blocks: int):
        """Evict up to ``max_num_blocks`` least recently accessed cached blocks.

        Only leaves whose block is referenced by the trie alone (ref count 1)
        are candidates. When a parent becomes childless it is registered as a
        leaf, and it becomes a candidate too if its ref count is 1.

        Returns:
            int: number of evicted blocks (0 when disabled, nothing is
            evictable, or ``max_num_blocks <= 0``).
        """
        if not self.enable:
            return 0
        # Nothing to do; also avoids freeing ``np.array([])``, which is
        # float64 and cannot be used as an index array.
        if max_num_blocks <= 0 or len(self.leaves) == 0:
            return 0

        def _pop_leaf(heap, evicted_blocks):
            """Detach the least recently accessed candidate and return its parent."""
            _, leaf = heapq.heappop(heap)
            evicted_blocks.append(leaf.block)
            parent = leaf.parent
            leaf.parent = None
            self.leaves.remove(leaf)
            return parent

        def _push_leaf(heap, node):
            """Register ``node`` as a leaf; queue it if the trie holds its only reference."""
            self.leaves.add(node)
            if self.allocator.get_ref_count(node.block) == 1:
                access_time = self.allocator.get_access_time(node.block)
                heapq.heappush(heap, (access_time, node))

        leaves = list(self.leaves)

        # filter ref-cnt == 1 (trie own one block ref)
        leave_blocks = np.array([leaf.block for leaf in leaves], dtype=np.int64)
        ref_cnt = self.allocator.get_ref_count(leave_blocks)
        indices = (ref_cnt == 1).nonzero()[0]
        if len(indices) == 0:
            return 0

        # min-heap on access time -> least recently used first
        access_times = self.allocator.get_access_time(leave_blocks)
        heap = [(access_times[i], leaves[i]) for i in indices]
        heapq.heapify(heap)

        evicted_blocks = []
        while len(heap) > 0 and len(evicted_blocks) < max_num_blocks:
            parent = _pop_leaf(heap, evicted_blocks)
            if parent.parent is None:
                # ignore root
                continue
            if len(parent.children) == 0:
                _push_leaf(heap, parent)

        self.allocator.free(np.array(evicted_blocks, dtype=np.int64))

        return len(evicted_blocks)


================================================
FILE: lmdeploy/pytorch/paging/eviction_helper/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.utils import get_logger

# Logger used to warn about deprecated eviction types.
logger = get_logger('lmdeploy')


def build_eviction_helper(scheduler, eviction_type: str):
    """Build the eviction helper selected by ``eviction_type``.

    ``'copy'`` is deprecated and mapped to ``'recompute'`` with a warning.

    Raises:
        TypeError: for any other unknown eviction type.
    """
    if eviction_type == 'copy':
        logger.warning('`copy` eviction has been deprecated, '
                       'use `recompute` instead.')
        eviction_type = 'recompute'

    if eviction_type != 'recompute':
        raise TypeError(f'Unknown eviction type: {eviction_type}')

    # imported lazily, only when actually requested
    from .recompute_eviction_helper import RecomputeEvictionHelper
    return RecomputeEvictionHelper(scheduler)


================================================
FILE: lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from ...messages import SchedulerSequence
from ..scheduler import Scheduler

# Type alias for a list of scheduler sequences.
SeqList = List[SchedulerSequence]


class BaseEvictionHelper:
    """Interface for strategies that reclaim resources from other sequences.

    Stores the scheduler and the parts of it subclasses work with: block
    manager, prefix-cache trie, state manager and cache config.
    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
        self.cache_config = scheduler.cache_config
        self.state_manager = scheduler.state_manager
        self.block_trie = scheduler.block_trie
        self.block_manager = scheduler.block_manager

    def need_swap_in(self, seq: SchedulerSequence):
        """Whether ``seq`` must be swapped in before running."""
        raise NotImplementedError('Not implemented.')

    def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
        """Free resources from ``evictable_seqs`` so ``seq`` can be scheduled."""
        raise NotImplementedError('Not implemented.')


================================================
FILE: lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from ...messages import SchedulerSequence
from ..scheduler import Scheduler
from .base_eviction_helper import BaseEvictionHelper


class RecomputeEvictionHelper(BaseEvictionHelper):
    """Recompute eviction.

    Makes room for a sequence by freeing victim sequences (presumably they
    are recomputed when rescheduled, hence the name). If that is not enough,
    blocks held by the prefix trie are evicted too.
    """

    def __init__(self, scheduler: Scheduler):
        super().__init__(scheduler)

        # models with recurrent states (non-empty ``states_shapes``) also
        # need a free state slot, not only kv blocks
        has_states = len(self.cache_config.states_shapes) != 0
        self.evict_for_seq = self._evict_for_ssm if has_states else self._evict_for_seq_default

    def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence],
                               prealloc_size: int):
        """Free gpu blocks for ``seq``; returns whether enough are now free.

        Victims are popped from the front of ``evictable_seqs`` (the list is
        mutated).
        """
        manager = self.block_manager
        trie = self.block_trie
        needed = manager.num_required_blocks(seq, prealloc_size)

        def _shortfall():
            return needed - manager.get_num_free_gpu_blocks()

        if _shortfall() <= 0:
            return True

        success = False
        while evictable_seqs:
            victim = evictable_seqs.pop(0)

            # nothing to reclaim from a sequence without blocks
            if victim.num_blocks == 0:
                continue

            victim.state.free()
            missing = _shortfall()
            if missing <= 0:
                success = True
                break

            trie.evict(missing)
            if _shortfall() <= 0:
                success = True
                break

        # also covers an empty ``evictable_seqs``: fall back to the trie alone
        missing = _shortfall()
        if missing > 0:
            trie.evict(missing)
            if needed <= manager.get_num_free_gpu_blocks():
                success = True

        return success

    def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):
        """Free gpu blocks and a state slot for ``seq``; returns whether both are available.

        Victims are popped from the front of ``evictable_seqs`` (the list is
        mutated).
        """
        manager = self.block_manager
        states = self.state_manager
        trie = self.block_trie
        needed = manager.num_required_blocks(seq, prealloc_size)
        has_free_state = states.get_num_free() > 0

        def _shortfall():
            return needed - manager.get_num_free_gpu_blocks()

        if has_free_state and _shortfall() <= 0:
            return True

        success = False
        while evictable_seqs:
            victim = evictable_seqs.pop(0)

            # nothing to reclaim: no blocks and (presumably) no state slot
            if victim.num_blocks == 0 and victim.logical_state < 0:
                continue

            # free sequence
            victim.state.free()
            has_free_state = True
            missing = _shortfall()
            if missing <= 0:
                success = True
                break

            # clear cached prefix
            trie.evict(missing)
            if _shortfall() <= 0:
                success = True
                break

        if not has_free_state:
            return False

        # also covers an empty ``evictable_seqs``: fall back to the trie alone
        missing = _shortfall()
        if missing > 0:
            trie.evict(missing)
            if needed <= manager.get_num_free_gpu_blocks():
                success = True

        return success


================================================
FILE: lmdeploy/pytorch/paging/scheduler.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm

from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List

from torch.profiler import record_function

from lmdeploy.messages import EventType, ScheduleMetrics
from lmdeploy.utils import get_logger

from ..config import CacheConfig, SchedulerConfig
from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta
from .block_manager import build_block_manager
from .block_trie import BlockTrie
from .eviction_helper import build_eviction_helper
from .state_manager import build_state_manager

logger = get_logger('lmdeploy')

# Type aliases used in scheduler annotations:
# ``MapType`` is an int -> int map (used for the swap/copy maps below).
MapType = Dict[int, int]
SeqList = List[SchedulerSequence]


@dataclass
class SchedulerOutput:
    """Output of schedule.

    Attributes:
        running: sequences scheduled for this step.
        swap_in_map: int -> int block map; presumably the cpu -> gpu map from
            the block manager's ``try_swap_in`` (TODO confirm).
        swap_out_map: int -> int block map; presumably the gpu -> cpu map from
            the block manager's ``try_swap_out`` (TODO confirm).
        copy_map: int -> int block map of blocks to copy (semantics not visible here).
    """

    running: SeqList
    swap_in_map: MapType
    swap_out_map: MapType
    copy_map: MapType


class Scheduler:
    """Tools to schedule next step.

    Sequences move between ``MessageStatus`` buckets (tracked by ``SequenceManager``) through
    their state objects. This class decides which sequences receive KV blocks (and, for SSM
    models, state slots) for the next forward pass. When memory is short, it evicts hanging
    or waiting sequences.

    Args:
        scheduler_config (SchedulerConfig): The config of scheduler.
        cache_config (CacheConfig): The config of cache info.
        seq_meta (SequenceMeta): Meta shared by all sequences. When omitted, one is built
            from ``cache_config.block_size``.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        seq_meta: SequenceMeta = None,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.sessions: Dict[int, SchedulerSession] = OrderedDict()

        # For Disaggregation
        self.locked_sessions: Dict[int, SchedulerSession] = OrderedDict()

        self.block_manager = build_block_manager(cache_config)
        self.block_trie = BlockTrie(self.cache_config, self.block_manager)
        self.state_manager = build_state_manager(self.cache_config)
        # Models with recurrent state shapes also need a state slot per sequence.
        self.is_ssm = len(self.cache_config.states_shapes) > 0

        self.eviction_helper = build_eviction_helper(self, self.scheduler_config.eviction_type)

        seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size)
        self.seq_meta = seq_meta
        self.seq_manager = SequenceManager(seq_meta)

    @staticmethod
    def create_status_list_property(status: MessageStatus):
        """Create a read-only property listing all sequences with ``status``."""

        def _get_status_list(self):
            seq_map = self.seq_manager.get_sequences(status)
            return list(seq_map.values())

        return property(_get_status_list)

    @staticmethod
    def create_num_status_method(status: MessageStatus):
        """Create a method returning the number of sequences with ``status``."""

        def _num_status(self):
            return self.seq_manager.num_sequences(status)

        return _num_status

    @staticmethod
    def create_has_status_method(status: MessageStatus):
        """Create a method returning whether any sequence has ``status``."""

        def _has_status(self):
            return self.seq_manager.num_sequences(status) > 0

        return _has_status

    # status list properties
    waiting = create_status_list_property(MessageStatus.WAITING)
    ready = create_status_list_property(MessageStatus.READY)
    hanging = create_status_list_property(MessageStatus.STOPPED)
    running = create_status_list_property(MessageStatus.RUNNING)
    migration_waiting = create_status_list_property(MessageStatus.MIGRATION_WAITING)
    migration_done = create_status_list_property(MessageStatus.MIGRATION_DONE)

    # num status methods
    num_waiting = create_num_status_method(MessageStatus.WAITING)
    num_ready = create_num_status_method(MessageStatus.READY)
    num_running = create_num_status_method(MessageStatus.RUNNING)
    num_migration_waiting = create_num_status_method(MessageStatus.MIGRATION_WAITING)
    num_migration_done = create_num_status_method(MessageStatus.MIGRATION_DONE)

    # has status methods
    has_waiting = create_has_status_method(MessageStatus.WAITING)
    has_ready = create_has_status_method(MessageStatus.READY)
    has_migration_waiting = create_has_status_method(MessageStatus.MIGRATION_WAITING)
    has_migration_done = create_has_status_method(MessageStatus.MIGRATION_DONE)

    def add_session(self, session_id: int):
        """Add new session.

        Args:
            session_id (int): New session id.
        """
        assert session_id not in self.sessions
        session = SchedulerSession(session_id, seq_manager=self.seq_manager, scheduler=self)
        self.sessions[session_id] = session
        return session

    def _collect_evictable(self, waiting: SeqList) -> SeqList:
        """Return eviction candidates in the order they should be tried.

        The candidates are hanging (stopped) sequences in reverse order, followed by
        ``waiting`` in reverse order.
        """
        from itertools import chain
        return list(chain(reversed(self.hanging), reversed(waiting)))

    def _schedule_migration(self):
        """Activate migration-waiting sequences, oldest first, allocating blocks for each.

        Stops at the batch limit or at the first sequence that cannot be made to fit.

        Returns:
            SeqList: The sequences activated in this call.
        """
        migration_ready: SeqList = []
        migrating_token_count = 0

        def _to_running(seq: SchedulerSequence):
            """To running."""
            seq.state.activate()
            migration_ready.append(seq)
            nonlocal migrating_token_count
            migrating_token_count += seq.num_token_ids

        def _evict_for_seq(seq: SchedulerSequence, waiting):
            """Evict until can append."""
            evictable = self._collect_evictable(waiting)
            return self.eviction_helper.evict_for_seq(seq, evictable, 0)

        def _reorder_migrating():
            """Reorder waiting."""
            return sorted(self.migration_waiting, key=lambda seq: seq.arrive_time)

        migration_waiting = _reorder_migrating()

        max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running()
        while len(migration_waiting) > 0 and len(migration_ready) < max_batches:
            seq = migration_waiting.pop(0)
            # Match the sequence being scheduled, as the prefill path does. The previous code
            # passed the remaining `migration_waiting` list, which is not a sequence.
            self.block_trie.match(seq)
            if not _evict_for_seq(seq, migration_waiting):
                break

            # allocate session memory
            self.block_manager.allocate(seq)
            _to_running(seq)

        return migration_ready

    @record_function('schedule_prefill')
    def _schedule_prefill(self, prealloc_size: int = 0):
        """Schedule for prefilling.

        Waiting sequences are taken oldest-first. Scheduling stops at the batch limit, when
        the token budget ``max_prefill_token_num`` would be exceeded (after the first sequence),
        or when memory cannot be freed for the next sequence.
        """

        max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running()
        eviction_helper = self.eviction_helper
        swap_out_map: MapType = dict()
        swap_in_map: MapType = dict()
        copy_map: MapType = dict()
        running: SeqList = []
        token_count = 0

        def _to_running(seq: SchedulerSequence):
            """To running."""
            seq.state.activate()
            running.append(seq)
            nonlocal token_count
            token_count += seq.num_token_ids

        def _evict_for_seq(seq: SchedulerSequence, waiting):
            """Evict until can append."""
            evictable = self._collect_evictable(waiting)
            return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)

        def _reorder_waiting():
            """Reorder waiting."""
            return sorted(self.waiting, key=lambda seq: seq.arrive_time)

        num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING)
        if (len(running) >= max_batches or num_waiting == 0):
            return running, swap_in_map, swap_out_map, copy_map

        waiting = _reorder_waiting()
        while len(waiting) > 0 and len(running) < max_batches:
            seq = waiting.pop(0)

            if (len(running) > 0 and token_count + seq.num_token_ids > self.cache_config.max_prefill_token_num):
                break

            self.block_trie.match(seq)

            if not _evict_for_seq(seq, waiting):
                break

            # allocate session memory
            self.block_manager.allocate(seq, prealloc_size)
            self.block_trie.allocate(seq)
            if self.is_ssm:
                self.state_manager.allocate(seq)
            _to_running(seq)

            seq.record_event(EventType.SCHEDULED)

        return running, swap_in_map, swap_out_map, copy_map

    @record_function('schedule_decoding')
    def _schedule_decoding(self, prealloc_size: int = 0):
        """Schedule decoding.

        Ready sequences are visited oldest-first. If one cannot get the blocks it needs, the
        most recently arrived ready sequences are preempted (evicted) one at a time. If it
        still does not fit, the sequence itself is evicted.
        """

        def _reorder_running():
            """Reorder running."""
            return sorted(self.ready, key=lambda seq: seq.arrive_time)

        running = _reorder_running()
        assert len(running) != 0

        eviction_helper = self.eviction_helper
        swap_out_map: MapType = dict()
        swap_in_map: MapType = dict()
        copy_map: MapType = dict()

        def _evict_for_seq(seq: SchedulerSequence, num_required_blocks: int):
            """Evict until can append."""
            # Exactly as many free blocks as required is already enough (this also covers
            # num_required_blocks == 0), so `<=` rather than `<`.
            if num_required_blocks <= self.block_manager.get_num_free_gpu_blocks():
                return True
            evictable = self._collect_evictable(self.waiting)
            return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)

        # 1. running
        while len(running) > 0:
            # token + n
            seq = running.pop(0)
            num_required_blocks = self.block_manager.num_required_blocks(seq, prealloc_size)
            assert seq.num_blocks + num_required_blocks <= self.block_manager.num_gpu_blocks, (
                'Sequence requires more blocks than total gpu blocks.')

            while not _evict_for_seq(seq, num_required_blocks):
                if len(running) == 0:
                    break
                # preempt the newest remaining ready sequence
                seq_preempted = running.pop(-1)
                seq_preempted.state.evict()

            if self.block_manager.get_num_free_gpu_blocks() < num_required_blocks:
                seq.state.evict()
                continue

            self.block_manager.allocate(seq, prealloc_size)
            self.block_trie.allocate(seq)

        return self.ready[:self.scheduler_config.max_batches], swap_in_map, swap_out_map, copy_map

    def schedule(self, is_prefill: bool, prealloc_size: int = 0):
        """Schedule inputs for next steps."""
        if is_prefill:
            output = self._schedule_prefill(prealloc_size)
        else:
            output = self._schedule_decoding(prealloc_size)
        running, swap_in_map, swap_out_map, copy_map = output

        return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)

    @record_function('schedule_running')
    def schedule_running(self, running: SeqList, num_decode_tokens: int = 1, prealloc_size: int = 1):
        """Schedule running sequences.

        Allocates blocks for the next decoding tokens of each running sequence, visiting them
        in reverse order. A sequence that is not RUNNING, or whose blocks cannot be allocated,
        is marked invalid. In the allocation-failure case, the sequence is also evicted.

        Returns:
            List[bool]: One flag per entry of ``running``, in the original order.
        """
        assert len(running) > 0
        eviction_helper = self.eviction_helper

        valid_mask = [True for _ in running]

        # loop over reverse running; valid_mask is indexed in reversed order and flipped back at the end
        rev_running = reversed(running)
        for idx, seq in enumerate(rev_running):
            if seq.status != MessageStatus.RUNNING:
                valid_mask[idx] = False
                continue

            num_required_blocks = self.block_manager.num_required_blocks(seq, num_decode_tokens)

            if num_required_blocks == 0:
                continue

            # NOTE(review): unlike the prefill/decoding paths, candidates here are not reversed.
            if eviction_helper.evict_for_seq(seq, self.hanging + self.waiting, prealloc_size):
                self.block_manager.allocate(seq, prealloc_size)
                self.block_trie.allocate(seq)
                continue

            # running to ready
            seq.state.deactivate()
            # ready to waiting
            seq.state.evict()
            valid_mask[idx] = False
        valid_mask = list(reversed(valid_mask))
        return valid_mask

    def stop_session(self, session_id: int):
        """Stop session.

        Args:
            session_id (int): The session id.
        """
        assert session_id in self.sessions
        session = self.sessions[session_id]
        for seq in session.sequences.values():
            seq.state.stop()

    def end_session(self, session_id: int):
        """End session.

        Args:
            session_id (int): The session id.
        """
        if self.seq_meta.sampling_strategy is not None:
            self.seq_meta.sampling_strategy.on_session_end(session_id)
        session = self.sessions[session_id]
        seqs = list(session.sequences.values())
        for seq in seqs:
            # stop session so it won't get scheduled again
            seq.state.stop()
            session.remove_sequence(seq)
        self.sessions.pop(session_id)

    def has_unfinished(self):
        """Check if there are any unfinished message."""
        return self.has_ready() or self.has_waiting() or self.has_migration_done()

    def get_block_tables(self, seqs: SeqList):
        """Get block table of the sequences."""
        return [self.block_manager.get_block_table(seq) for seq in seqs]

    def evict_seqs(self, running: SeqList):
        """Evict running sequences."""
        for seq in running:
            seq.state.evict()

    def activate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY):
        """Activate every sequence in ``running`` whose status equals ``filter_status``."""
        for seq in running:
            if seq.status == filter_status:
                seq.state.activate()

    def deactivate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING):
        """Deactivate every sequence in ``running`` whose status equals ``filter_status``."""
        for seq in running:
            if seq.status == filter_status:
                seq.state.deactivate()

    @contextmanager
    def seqs_activation(self, running: SeqList):
        """Context manager to activate and deactivate sequences."""
        self.activate_seqs(running, MessageStatus.READY)
        try:
            yield running
        finally:
            self.deactivate_seqs(running, MessageStatus.RUNNING)

    def activate_migration_seqs(self, running: SeqList):
        """Lock running sequence."""
        return self.activate_seqs(running, filter_status=MessageStatus.MIGRATION_READY)

    def deactivate_migration_seqs(self, running: SeqList):
        """Unlock running migration."""
        return self.deactivate_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING)

    @contextmanager
    def seqs_migration_activation(self, running: SeqList):
        """Context manager to activate and deactivate sequences."""
        self.activate_migration_seqs(running)
        try:
            yield running
        finally:
            self.deactivate_migration_seqs(running)

    def collect_migration_done(self):
        """Activate every sequence whose migration has finished."""
        for seq in self.migration_done:
            seq.state.activate()

    @property
    def schedule_metrics(self):
        """Snapshot of scheduler load and cache usage."""
        return ScheduleMetrics(
            active_seqs=self.num_running(),
            waiting_seqs=self.num_waiting() + self.num_ready(),
            total_blocks=self.block_manager.num_gpu_blocks,
            free_blocks=self.block_manager.get_num_free_gpu_blocks(),
            prefix_cache_hit_rate=self.block_trie.hit_rate(),
        )


================================================
FILE: lmdeploy/pytorch/paging/seq_states/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .states import StateBase, build_seq_state  # noqa: F401


================================================
FILE: lmdeploy/pytorch/paging/seq_states/states.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from lmdeploy.pytorch.messages import MessageStatus, SchedulerSequence

if TYPE_CHECKING:
    from lmdeploy.pytorch.paging import Scheduler


def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'):
    """Release the KV blocks and state slot held by ``seq``, then reset its step to 0.

    Each resource is released only if the sequence currently holds it.
    """
    manager_of_blocks = scheduler.block_manager
    if seq.num_blocks > 0:
        manager_of_blocks.free(seq)

    manager_of_states = scheduler.state_manager
    if seq.logical_state >= 0:
        manager_of_states.free(seq)

    seq.set_step(0)


class StateBase:
    """Base class of the per-sequence scheduling state machine.

    Each concrete subclass sets ``status`` to a ``MessageStatus`` and is automatically
    registered in ``_registry``, so ``build`` can create the right state object. Transition
    methods raise ``NotImplementedError`` unless a subclass overrides them. ``stop`` and
    ``free`` work from every state.
    """
    status = None
    # status -> state class; one dict shared by all subclasses, filled in __init_subclass__.
    _registry = dict()

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Test against None explicitly. A truthiness test would silently skip any status
        # member that happens to be falsy (e.g. an IntEnum member with value 0).
        if cls.status is not None:
            cls._registry[cls.status] = cls

    @classmethod
    def build(cls, scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> 'StateBase':
        """Build sequence state.

        Raises:
            NotImplementedError: If no state class is registered for ``status``.
        """
        if status not in cls._registry:
            raise NotImplementedError(f'Unsupported status {status} for building seq state.')
        return cls._registry[status](seq, scheduler)

    def __init__(self, seq: SchedulerSequence, scheduler: 'Scheduler'):
        self.seq = seq
        self.scheduler = scheduler

    def to_state(self, new_state):
        """Transition to ``new_state``.

        Updates the sequence's bucket in the scheduler's sequence manager, then attaches a
        fresh instance of ``new_state`` to the sequence.
        """
        self.scheduler.seq_manager.update_sequence_status(self.seq, new_state.status)
        self.seq.set_state(new_state(self.seq, self.scheduler))

    def evict(self):
        """Evict the state."""
        raise NotImplementedError(f'evict not implemented for state {self.status}')

    def activate(self):
        """Activate the state."""
        raise NotImplementedError(f'activate not implemented for state {self.status}')

    def deactivate(self):
        """Deactivate the state."""
        raise NotImplementedError(f'deactivate not implemented for state {self.status}')

    def finish(self):
        """Finish the state."""
        raise NotImplementedError(f'finish not implemented for state {self.status}')

    def stop(self):
        """Stop the state (any state -> STOPPED)."""
        self.to_state(StoppedState)

    def free(self):
        """Free the blocks and state slot held by the sequence."""
        _free_seq(self.seq, self.scheduler)


class WaitingState(StateBase):
    """A sequence queued for scheduling."""
    status = MessageStatus.WAITING

    def activate(self):
        """WAITING -> READY, after checking that the required resources are held."""
        seq = self.seq
        scheduler = self.scheduler
        required = scheduler.block_manager.num_required_blocks(seq)
        assert seq.num_blocks >= required
        if scheduler.is_ssm:
            # SSM models must also own a state slot before they can run.
            assert seq.logical_state >= 0
        self.to_state(ReadyState)

    def evict(self):
        """Evicting a waiting sequence re-enters WAITING."""
        self.to_state(WaitingState)


class ReadyState(StateBase):
    """A sequence holding its resources and eligible for the next step."""
    status = MessageStatus.READY

    def activate(self):
        """READY -> RUNNING."""
        next_state = RunningState
        self.to_state(next_state)

    def evict(self):
        """READY -> WAITING."""
        next_state = WaitingState
        self.to_state(next_state)


class StoppedState(StateBase):
    """A sequence that has been stopped (the scheduler's "hanging" bucket)."""
    status = MessageStatus.STOPPED

    def activate(self):
        """STOPPED -> WAITING; only valid for a sequence that has tokens."""
        has_tokens = self.seq.num_token_ids > 0
        assert has_tokens
        self.to_state(WaitingState)

    def evict(self):
        """Evicting a stopped sequence re-enters STOPPED."""
        self.to_state(StoppedState)


class RunningState(StateBase):
    """A sequence taking part in the current forward pass."""
    status = MessageStatus.RUNNING

    def deactivate(self):
        """RUNNING -> READY."""
        self.to_state(ReadyState)

    def finish(self):
        """RUNNING -> TO_BE_MIGRATED if the cache must be preserved, else RUNNING -> STOPPED."""
        next_state = ToBeMigratedState if self.seq.preserve_cache else StoppedState
        self.to_state(next_state)


class ToBeMigratedState(StateBase):
    """A finished sequence whose cache is kept so it can be migrated."""
    status = MessageStatus.TO_BE_MIGRATED

    def finish(self):
        """TO_BE_MIGRATED -> STOPPED."""
        next_state = StoppedState
        self.to_state(next_state)


class MigrationWaitingState(StateBase):
    """A sequence queued for migration scheduling."""
    status = MessageStatus.MIGRATION_WAITING

    def activate(self):
        """MIGRATION_WAITING -> MIGRATION_READY."""
        next_state = MigrationReadyState
        self.to_state(next_state)

    def evict(self):
        """Evicting re-enters MIGRATION_WAITING."""
        next_state = MigrationWaitingState
        self.to_state(next_state)


class MigrationReadyState(StateBase):
    """A sequence ready to start migrating."""
    status = MessageStatus.MIGRATION_READY

    def activate(self):
        """MIGRATION_READY -> MIGRATION_RUNNING."""
        next_state = MigrationRunningState
        self.to_state(next_state)

    def evict(self):
        """MIGRATION_READY -> MIGRATION_WAITING."""
        next_state = MigrationWaitingState
        self.to_state(next_state)


class MigrationDoneState(StateBase):
    """A sequence whose migration has completed."""
    status = MessageStatus.MIGRATION_DONE

    def activate(self):
        """MIGRATION_DONE -> WAITING."""
        next_state = WaitingState
        self.to_state(next_state)

    def finish(self):
        """MIGRATION_DONE -> WAITING (same target as ``activate``)."""
        next_state = WaitingState
        self.to_state(next_state)


class MigrationRunningState(StateBase):
    """A sequence whose migration is in progress."""
    status = MessageStatus.MIGRATION_RUNNING

    def deactivate(self):
        """MIGRATION_RUNNING -> MIGRATION_DONE."""
        next_state = MigrationDoneState
        self.to_state(next_state)

    def finish(self):
        """MIGRATION_RUNNING -> MIGRATION_DONE (same target as ``deactivate``)."""
        next_state = MigrationDoneState
        self.to_state(next_state)


def build_seq_state(scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> StateBase:
    """Create the state object registered for ``status`` and bind it to ``seq``.

    Raises ``NotImplementedError`` (from ``StateBase.build``) for an unregistered status.
    """
    return StateBase.build(scheduler=scheduler, seq=seq, status=status)


================================================
FILE: lmdeploy/pytorch/paging/state_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from lmdeploy.pytorch.config import CacheConfig
from lmdeploy.pytorch.messages import SchedulerSequence


class StateAllocator:
    """Allocator for integer state ids in ``[offset, offset + num_states)``.

    Free ids live in the last ``_free_count`` slots of ``_free_states``. ``allocate`` takes the
    first id of that tail, and ``free`` writes the returned id just in front of it. As a
    result, the most recently freed id is the next one handed out.
    """

    def __init__(self, num_states: int, offset: int = 0):
        self.num_states = num_states
        self._free_states = np.arange(offset, offset + num_states, dtype=np.int64)
        self._free_count = num_states

    def allocate(self) -> int:
        """Allocate one state id.

        Returns:
            int: The allocated id, as a plain Python ``int`` rather than a numpy scalar.

        Raises:
            RuntimeError: If no state is free.
        """
        if self.get_num_free() == 0:
            raise RuntimeError('No free states.')
        # Convert the numpy int64 to int so callers storing it (e.g. seq.logical_state)
        # get a regular Python int (hashable/JSON-serialisable like any other id).
        alloc_id = int(self._free_states[-self._free_count])
        self._free_count -= 1
        return alloc_id

    def free(self, state_id: int):
        """Return ``state_id`` to the pool.

        Raises:
            RuntimeError: If every state is already free.
        """
        if self._free_count >= self.num_states:
            raise RuntimeError('All states are free.')
        self._free_count += 1
        self._free_states[-self._free_count] = state_id

    def get_num_free(self):
        """Number of currently free states."""
        return self._free_count


class StateManager:
    """Assign and release per-sequence state slots stored in ``seq.logical_state``.

    A sequence owns a slot when ``logical_state >= 0``. Releasing a slot sets it back to -1.
    """

    def __init__(self, num_states: int, num_reserved: int = 0):
        # A missing state count falls back to a single slot; ids start after the reserved ones.
        capacity = 1 if num_states is None else num_states
        self.allocator = StateAllocator(capacity, offset=num_reserved)

    def is_allocated(self, seq: SchedulerSequence):
        """Whether ``seq`` currently owns a state slot."""
        return seq.logical_state >= 0

    def allocate(self, seq: SchedulerSequence):
        """Give ``seq`` a state slot unless it already has one."""
        if not self.is_allocated(seq):
            seq.logical_state = self.allocator.allocate()
        return None

    def free(self, seq: SchedulerSequence):
        """Release the slot owned by ``seq``, if any."""
        if self.is_allocated(seq):
            self.allocator.free(seq.logical_state)
            seq.logical_state = -1
        return None

    def get_num_free(self):
        """Number of free state slots."""
        return self.allocator.get_num_free()


def build_state_manager(cache_config: CacheConfig) -> StateManager:
    """Create the ``StateManager`` for ``cache_config.num_state_caches`` state slots."""
    # Unlike KV blocks, one state id is always held back for system use. It is passed as
    # the allocator offset, so id 0 is never handed out.
    return StateManager(cache_config.num_state_caches, num_reserved=1)


================================================
FILE: lmdeploy/pytorch/ray.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
import time
from typing import Dict, List

import ray
from ray.util.placement_group import PlacementGroup

from lmdeploy.pytorch.devices import get_device_manager
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')
# Upper bound, in seconds, on how long to wait for a Ray placement group to become ready.
PG_WAIT_TIMEOUT = 1800


def get_device_str(device_type: str = None) -> str:
    """Translate an lmdeploy device type into the Ray resource name used for it.

    When ``device_type`` is falsy, the device type of the current device context is used.

    Raises:
        ValueError: For a device type without a Ray resource mapping.
    """
    device_type = device_type or get_device_manager().current_context().device_type
    ray_resource_names = {
        'cuda': 'GPU',
        'maca': 'GPU',
        'ascend': 'NPU',
        'camb': 'MLU',
    }
    if device_type not in ray_resource_names:
        raise ValueError(f'Unsupported device type: {device_type}')
    return ray_resource_names[device_type]


def get_resource_kwargs(device_str: str, resource_used: float = 0.01) -> Dict[str, float]:
    """Build the Ray resource keyword arguments requesting ``resource_used`` of a device.

    GPUs go through Ray's ``num_gpus``. NPUs are requested as a custom resource named after
    the device.

    Raises:
        ValueError: For any other device string.
    """
    if device_str == 'GPU':
        return {'num_gpus': resource_used}
    if device_str == 'NPU':
        return {'resources': {device_str: resource_used}}
    raise ValueError(f'Unsupported device type: {device_str}')


def _wait_until_pg_ready(current_placement_group: PlacementGroup):
    """Wait until a placement group is ready.

    Polls the placement group's readiness, logging an informative message after each
    unsuccessful poll. The poll interval doubles each time (exponential backoff), but every
    poll is capped at the time left before ``PG_WAIT_TIMEOUT``, so the total wait does not
    overshoot the deadline.

    Raises:
        ValueError: If the placement group is not ready within ``PG_WAIT_TIMEOUT`` seconds.
    """
    # adapted from vLLM
    placement_group_specs = current_placement_group.bundle_specs

    start = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.time() - start)
        if remaining <= 0:
            break
        # Without the cap, the doubling interval could push the total wait well past
        # PG_WAIT_TIMEOUT (10 + 20 + ... + 1280 s).
        ready, _ = ray.wait([pg_ready_ref], timeout=min(wait_interval, remaining))
        if len(ready) > 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            'Waiting for creating a placement group of specs for '
            '%d seconds. specs=%s. Check '
            '`ray status` to see if you have enough resources,'
            ' and make sure the IP addresses used by ray cluster'
            ' are consistent across nodes'
            ' if you are running on a multi-node.', int(time.time() - start), placement_group_specs)

    try:
        # timeout=0: raise immediately if the group is still not ready.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError('Cannot provide a placement group of '
                         f'{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See '
                         '`ray status` to make sure the cluster has enough resources.') from None


def _get_obj_store_memory(dp: int = 1):
    """Compute the Ray object store size in bytes.

    The size is a fraction of total system memory, taken from the environment variable
    ``RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION`` (default 0.3). It is capped by
    ``RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES`` (default 80e9). With data parallelism
    (``dp > 1``), the result is divided by ``min(8, dp)``.
    """
    import psutil
    proportion = float(os.getenv('RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION', '0.3'))
    max_bytes_env = os.getenv('RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES', None)
    max_bytes = 80 * (10**9) if max_bytes_env is None else int(max_bytes_env)

    system_total = psutil.virtual_memory().total
    store_bytes = min(max_bytes, int(system_total * proportion))
    if dp > 1:
        store_bytes //= min(8, dp)
    return store_bytes


def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1, device_type: str = 'cuda'):
    """Connect to (or start) Ray and return the placement group workers should use.

    Returns:
        tuple: ``(placement_group, owned_pg)``. ``owned_pg`` is True when this call created
        the placement group; it is False when the caller already runs inside one.
    """
    # adapted from vLLM
    if not ray.is_initialized():
        try:
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_cpus=world_size,
                     object_store_memory=_get_obj_store_memory(dp=dp))
        except ValueError as err:
            # Ray refuses resource arguments when attaching to an existing cluster; retry without them.
            existing_cluster_msg = ('When connecting to an existing cluster, '
                                    'num_cpus and num_gpus must not be provided.')
            if err.args and err.args[0] == existing_cluster_msg:
                ray.init(address=ray_address, ignore_reinit_error=True)
            else:
                raise

    device_str = get_device_str(device_type)

    # Reuse the placement group we are already running inside, if any.
    existing_pg = ray.util.get_current_placement_group()
    if existing_pg:
        return existing_pg, False

    num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
    if world_size > num_devices_in_cluster:
        logger.warning(
            'The number of required %ss exceeds the total '
            'number of available %ss in the placement group.', device_str, device_str)

    # One bundle per worker, each holding a single device.
    placement_group_specs: List[Dict[str, float]] = [{device_str: 1.0} for _ in range(world_size)]

    # Pin at least one bundle to the local node. This helps multi-node DP keep each
    # dp_rank process's workers co-located with the node where the process is launched.
    current_ip = ray.util.get_node_ip_address()
    placement_group_specs[0][f'node:{current_ip}'] = 0.001

    # By default, Ray packs resources as much as possible.
    new_pg = ray.util.placement_group(placement_group_specs, strategy='PACK')
    _wait_until_pg_ready(new_pg)
    assert new_pg is not None
    return new_pg, True


class RayContext:
    """Hold the Ray placement group used by an engine and tear Ray down on ``shutdown``."""

    def __init__(self, world_size: int, ray_address: str = None, dp: int = 1, device_type: str = 'cuda'):
        """Initialise Ray if needed and obtain (or create) the placement group."""
        self.placement_group, self.owned_pg = init_ray_cluster(world_size=world_size,
                                                               ray_address=ray_address,
                                                               dp=dp,
                                                               device_type=device_type)

    def get_placement_group(self):
        """Return the placement group workers are scheduled into."""
        return self.placement_group

    def shutdown(self):
        """Remove the placement group if this context created it, then shut Ray down.

        Errors raised by ``ray.shutdown`` are logged rather than propagated.
        """
        if self.owned_pg:
            ray.util.remove_placement_group(self.placement_group)
            logger.debug('RayContext placement group removed.')

        if not ray.is_initialized():
            logger.debug('Ray is not initialized, skipping shutdown.')
            return

        try:
            ray.shutdown()
            logger.debug('Ray shutdown.')
        except Exception:
            logger.exception('Error during Ray shutdown.')


================================================
FILE: lmdeploy/pytorch/spec_decode/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from ..config import BackendConfig, SpecDecodeConfig
from ..distributed import DistContext


def build_spec_agent(specdecode_config: SpecDecodeConfig,
                     backend_config: BackendConfig,
                     dist_ctx: DistContext,
                     inputs_strategy,
                     agent_strategy,
                     device: str = 'cuda'):
    """Build spec agent.

    A real ``SpecModelAgent`` is created only on ranks where ``rank % attn_tp == 0`` and
    only when a speculative-decoding config is given. Every other case gets the no-op
    ``BaseSpecModelAgent``.
    """
    on_attn_tp_root = dist_ctx.rank % dist_ctx.dist_config.attn_tp == 0
    if not on_attn_tp_root or specdecode_config is None:
        from .base import BaseSpecModelAgent
        return BaseSpecModelAgent()

    from .spec_agent import SpecModelAgent
    return SpecModelAgent(specdecode_config, backend_config, inputs_strategy, agent_strategy, device=device)


# Public API of the spec_decode package.
__all__ = ['build_spec_agent']


================================================
FILE: lmdeploy/pytorch/spec_decode/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict

import torch

from ..config import CacheConfig, ModelConfig
from ..engine.logits_process import SamplingInputs
from ..model_inputs import ModelInputs
from ..strategies.base.model_agent import ExtraInputs


class BaseSpecModelAgent:
    """No-op speculative-decoding agent; subclasses provide a real draft model.

    Every build/forward hook here does nothing. ``update_main_model_outputs`` still splits
    the main model's outputs so callers can use one code path whether speculation is on or off.
    """

    def __init__(self, enable: bool = False):
        self._enabled = enable

    def is_enabled(self):
        """Whether speculative decoding is active for this agent."""
        return self._enabled

    def set_cache_config(self, cache_config: CacheConfig):
        """Accept the cache config (ignored here)."""
        return None

    def set_model_config(self, model_config: ModelConfig):
        """Accept the model config (ignored here)."""
        return None

    def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None):
        """Build the draft model (nothing to build here)."""
        return None

    def build_graph_runner(self):
        """Build the draft graph runner (nothing to build here)."""
        return None

    def build_cache_engine(self, cache_stream: torch.cuda.Stream):
        """Build the draft cache engine (nothing to build here)."""
        return None

    async def async_model_forward(self, next_token_ids: torch.Tensor, model_inputs: ModelInputs,
                                  extra_inputs: ExtraInputs, sampling_inputs: SamplingInputs):
        """Run the draft model; the base agent returns ``extra_inputs`` unchanged."""
        return extra_inputs

    def warmup(self, max_batches: int, target_model_config: ModelConfig):
        """Warm up the draft model (no-op)."""
        return None

    def reset_graph_runner(self):
        """Reset graph runner (no-op)."""
        return None

    def update_main_model_outputs(self, output: Dict[str, torch.Tensor], model_inputs: ModelInputs):
        """Split the main model's output dict into ``(hidden_states, output)``.

        When disabled, ``hidden_states`` is popped out of ``output``. When enabled, it stays in
        ``output``; for prefill, the returned tensor keeps only the last token of each
        sequence. Any ``aux_hidden_states`` then replace ``output['hidden_states']``.
        """
        if self.is_enabled():
            hidden_states = output['hidden_states']
            if not model_inputs.is_decoding:
                # assumes prefill tokens of all sequences are packed along dim 1 — verify
                last_token_idx = model_inputs.seq_length.cumsum(0) - 1
                hidden_states = hidden_states[:, last_token_idx]
            if 'aux_hidden_states' in output:
                output['hidden_states'] = output.pop('aux_hidden_states')
            return hidden_states, output

        return output.pop('hidden_states'), output


================================================
FILE: lmdeploy/pytorch/spec_decode/proposers/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .deepseek_mtp import DeepseekMTP  # noqa F401
from .eagle import Eagle  # noqa F401
from .eagle3 import Eagle3  # noqa F401


================================================
FILE: lmdeploy/pytorch/spec_decode/proposers/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import torch
from mmengine import Registry
from torch.profiler import record_function

from lmdeploy.utils import get_logger

from ...config import ModelConfig, SpecDecodeConfig
from ...engine.cache_engine import CacheEngine
from ...model_inputs import ModelInputs, step_ctx_manager
from ...models.patch import build_patched_model, update_custom_module_map
from ...strategies.base.model_agent import ExtraInputs
from ...weight_loader.model_weight_loader import load_model_weights

SPEC_PROPOSERS = Registry('spec_proposers')

logger = get_logger('lmdeploy')


@torch.inference_mode()
def draft_model_forward(
    model: torch.nn.Module,
    inputs: ModelInputs,
    model_config: Optional[ModelConfig] = None,
    cache_engine: Optional[CacheEngine] = None,
):
    """Perform one forward pass of the draft model.

    Args:
        model: Draft model (or a wrapper of it) exposing ``ctx_mgr``,
            ``update_model_metas`` and ``prepare_inputs_for_generation``.
        inputs: Inputs of this step.
        model_config: Config of the draft model.
        cache_engine: KV cache engine of the draft model. When ``None``, neither
            KV caches nor a cache config are passed to the step context.

    Returns:
        dict: The model outputs. A non-dict output is wrapped as
        ``{'hidden_states': output}``; ``model_metas`` is always added.
    """
    stream = torch.cuda.current_stream()
    with torch.cuda.stream(stream), step_ctx_manager(model.ctx_mgr):
        ctx_mgr = model.ctx_mgr
        # ``cache_engine`` is optional: do not dereference it when missing
        if cache_engine is None:
            kv_caches = None
            cache_config = None
        else:
            kv_caches = cache_engine.gpu_cache
            cache_config = cache_engine.cache_config
        context = ctx_mgr.build_context(
            inputs=inputs,
            model_config=model_config,
            cache_config=cache_config,
            kv_caches=kv_caches,
        )
        with ctx_mgr.context(context):
            model_metas = model.update_model_metas(
                past_key_values=kv_caches,
                context=context,
            )
            input_dict = model.prepare_inputs_for_generation(
                past_key_values=kv_caches,
                context=context,
            )
            outputs = model(**input_dict)
            if not isinstance(outputs, dict):
                outputs = dict(hidden_states=outputs)
            outputs.update(dict(model_metas=model_metas))
    return outputs


class BaseSpecProposer:
    """Base class of speculative-decoding draft proposers.

    Holds the draft model built from ``specdecode_config`` and provides the
    shared forward, logits and input-update helpers. Subclasses implement
    :meth:`get_outputs`.
    """

    def __init__(self, specdecode_config: SpecDecodeConfig, device: torch.device = None):
        self.specdecode_config = specdecode_config
        self.model = None
        self.device = device
        self.lm_head = None
        # ``build_model`` accepts a missing config, so do not dereference it here
        # either; without a config there is nothing to speculate.
        self.num_speculative_tokens = (0 if specdecode_config is None else specdecode_config.num_speculative_tokens)
        self.target_model = None

    def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None):
        """Build the draft model and, unless ``empty_init``, load its weights.

        Does nothing when no ``specdecode_config`` was given.

        Args:
            empty_init: Skip weight loading when True.
            target_model: Target model, kept for logits fallback in
                :meth:`get_logits`.
            build_model_ctx: Forwarded to ``build_patched_model``.
        """
        if self.specdecode_config is None:
            return
        model_path = self.specdecode_config.model
        model_config = self.specdecode_config.model_config
        custom_module_map = model_config.custom_module_map
        if custom_module_map is not None:
            update_custom_module_map(custom_module_map)
        logger.debug('build draft model')
        patched_model = build_patched_model(
            model_config,
            device=self.device,
            build_model_ctx=build_model_ctx,
        )
        if not empty_init:
            # only log when weights are actually loaded
            logger.debug('loading weights for draft model.')
            load_model_weights(patched_model, model_path, device=self.device)
        self.model = patched_model
        self.target_model = target_model

    def get_outputs(self,
                    model_outputs: Dict[str, torch.Tensor],
                    model_inputs: ModelInputs,
                    extra_inputs: ExtraInputs = None):
        """Get outputs.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    @record_function('draft_model_forward')
    def _forward(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None):
        """Run one draft model forward with the draft model config."""
        return draft_model_forward(
            self.model,
            model_inputs,
            model_config=self.specdecode_config.model_config,
            cache_engine=cache_engine,
        )

    def update_inputs_decoding(self, model_inputs: ModelInputs, extra_inputs: ExtraInputs, next_input_ids: torch.Tensor,
                               target_hidden_states: torch.Tensor, model_metas: List[Any]):
        """Turn ``model_inputs`` into one-token-per-sequence decoding inputs in place.

        The current tokens (minus ``extra_inputs.num_rejected_tokens`` when set)
        become history, and every sequence consumes one token from
        ``next_input_ids`` next. ``max_kv_seqlen`` and ``sum_kv_seqlen`` grow by
        one token per sequence and are not reduced for rejected tokens.

        Returns:
            ModelInputs: The same, mutated ``model_inputs`` object.
        """
        model_inputs.is_decoding = True
        batch_size = model_inputs.seq_length.size(0)
        model_inputs.input_ids = next_input_ids
        model_inputs.max_q_seqlen = 1
        model_inputs.max_kv_seqlen += 1
        model_inputs.sum_kv_seqlen += model_inputs.seq_length.numel()
        model_inputs.history_lengths += model_inputs.seq_length
        if extra_inputs.num_rejected_tokens is not None:
            model_inputs.history_lengths -= extra_inputs.num_rejected_tokens
        model_inputs.seq_length = model_inputs.seq_length.new_ones(batch_size)
        model_inputs.target_position_ids = model_inputs.history_lengths.unsqueeze(0).clone()
        model_inputs.model_metas = model_metas
        model_inputs.target_hidden_states = target_hidden_states
        return model_inputs

    @record_function('draft_get_logits')
    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits with the draft model, falling back to the target model.

        ``self.model`` may have been replaced by a non-``nn.Module`` wrapper
        (e.g. a graph runner); its ``.model`` attribute is used in that case.
        """
        draft_model = self.model
        if not isinstance(draft_model, torch.nn.Module):
            draft_model = draft_model.model

        if hasattr(draft_model, 'get_logits'):
            logits = draft_model.get_logits(hidden_states)
        else:
            logits = self.target_model.get_logits(hidden_states)
        return logits

    def get_target_hidden_size(self, model_config: ModelConfig):
        """Get target hidden size (the target model's ``hidden_size``)."""
        return model_config.hidden_size


def build_specdecode_proposer(specdecode_config: SpecDecodeConfig, device: str = 'cuda'):
    """Instantiate the proposer registered under ``specdecode_config.method``.

    Raises:
        ValueError: If no proposer is registered for the method.
    """
    method = specdecode_config.method
    registered = SPEC_PROPOSERS.module_dict
    if method not in registered:
        raise ValueError(f'{method} not found in {registered.keys()}')
    proposer_cls = registered[method]
    return proposer_cls(specdecode_config, device=device)


================================================
FILE: lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch

from lmdeploy.utils import get_logger

from ...model_inputs import ModelInputs
from ...strategies.ar_spec.model_agent import ARSpecExtraInputs
from .base import SPEC_PROPOSERS, BaseSpecProposer

logger = get_logger('lmdeploy')


@SPEC_PROPOSERS.register_module(name='deepseek_mtp')
class DeepseekMTP(BaseSpecProposer):
    """Greedy draft proposer registered as ``'deepseek_mtp'``."""

    def get_outputs(self,
                    model_outputs: Dict[str, torch.Tensor],
                    model_inputs: ModelInputs,
                    extra_inputs: ARSpecExtraInputs = None):
        """Pick greedy draft tokens from the draft model outputs.

        When ``extra_inputs.last_token_indices`` is set, the hidden states are
        first reduced to one position per sequence (the final position for a
        single-sequence prefill).

        Returns:
            tuple: ``(draft_token_ids, model_metas, hidden_states)``.
        """
        hidden_states = model_outputs['hidden_states']
        model_metas = model_outputs['model_metas']
        last_token_loc = None if extra_inputs is None else extra_inputs.last_token_indices
        if last_token_loc is not None:
            single_prefill = (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1
            hidden_states = hidden_states[:, -1:] if single_prefill else hidden_states[:, last_token_loc]

        logits = self.get_logits(hidden_states)[0]
        return logits.argmax(dim=-1, keepdim=True), model_metas, hidden_states


================================================
FILE: lmdeploy/pytorch/spec_decode/proposers/eagle.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .base import SPEC_PROPOSERS
from .deepseek_mtp import DeepseekMTP


@SPEC_PROPOSERS.register_module(name='eagle')
class Eagle(DeepseekMTP):
    """Proposer registered as ``'eagle'``; inherits DeepseekMTP behavior unchanged."""


================================================
FILE: lmdeploy/pytorch/spec_decode/proposers/eagle3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch

from lmdeploy.utils import get_logger

from ...config import ModelConfig
from ...model_inputs import ModelInputs
from ...strategies.base.model_agent import ExtraInputs
from .base import SPEC_PROPOSERS
from .deepseek_mtp import DeepseekMTP

logger = get_logger('lmdeploy')


@SPEC_PROPOSERS.register_module(name='eagle3')
class Eagle3(DeepseekMTP):
    """Proposer registered as ``'eagle3'``.

    Draft token ids are mapped through ``draft_id_to_target_id`` and the
    pre-norm hidden states are fed back as target hidden states.
    """

    def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None):
        """Build the draft model and, if needed, borrow the target's embeddings.

        When ``self.model.include_embed_tokens`` is false, the draft model's
        ``embed_tokens`` is replaced by ``target_model.get_input_embeddings()``.
        """
        super().build_model(empty_init, target_model=target_model, build_model_ctx=build_model_ctx)
        draft_model = self.model
        self.draft_id_to_target_id = draft_model.draft_id_to_target_id
        if draft_model.include_embed_tokens:
            return
        logger.info('Using embed_tokens from target model.')
        del draft_model.model.embed_tokens
        draft_model.model.embed_tokens = target_model.get_input_embeddings()

    def get_target_hidden_size(self, model_config: ModelConfig):
        """Size of the target hidden states consumed by the draft model.

        Three times ``target_hidden_size`` from the draft hf config (falling back
        to its ``hidden_size``); presumably three concatenated target hidden
        states. ``model_config`` is not used.
        """
        draft_hf_config = self.specdecode_config.model_config.hf_config
        per_state_size = getattr(draft_hf_config, 'target_hidden_size', draft_hf_config.hidden_size)
        return 3 * per_state_size

    def get_outputs(self,
                    model_outputs: Dict[str, torch.Tensor],
                    model_inputs: ModelInputs,
                    extra_inputs: ExtraInputs = None):
        """Pick greedy draft tokens and map them to target-model token ids.

        Returns:
            tuple: ``(draft_token_ids, model_metas, hidden_states_prenorm)``.
        """
        hidden_states = model_outputs['hidden_states']
        prenorm = model_outputs['hidden_states_prenorm']
        model_metas = model_outputs['model_metas']
        last_token_loc = None if extra_inputs is None else extra_inputs.last_token_indices
        if last_token_loc is not None:
            if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1:
                # single-sequence prefill: keep only the final position
                index = slice(-1, None)
            else:
                index = last_token_loc
            hidden_states = hidden_states[:, index]
            prenorm = prenorm[:, index]

        logits = self.get_logits(hidden_states)[0]
        # map draft-model token ids to target-model token ids
        draft_token_ids = self.draft_id_to_target_id[logits.argmax(dim=-1, keepdim=True)]
        return draft_token_ids, model_metas, prenorm


================================================
FILE: lmdeploy/pytorch/spec_decode/reject_sampler.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from typing import Optional

import torch
from torch import LongTensor, Tensor, nn
from torch.profiler import record_function


class SamplePolicy(enum.Enum):
    """Policy for accepting or rejecting draft tokens."""

    # always take the target model's argmax token
    ALL_GREEDY = enum.auto()


class RejectionSampler(nn.Module):
    """``nn.Module`` wrapper around :func:`rejection_sample`.

    Args:
        sample_policy (SamplePolicy): Policy forwarded to every call; only
            ``SamplePolicy.ALL_GREEDY`` is currently supported.
    """

    def __init__(self, sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY):
        super().__init__()
        self.sample_policy = sample_policy

    def forward(
        self,
        target_logits: Tensor,
        draft_token_ids: LongTensor,
        bonus_token_ids: LongTensor,
        draft_probs: Optional[Tensor] = None,
    ):
        """Verify draft tokens against the target model.

        Args:
            target_logits (Tensor): The logits of target model in shape of [batch_size, num_spec_tokens, vocab_size].
            draft_token_ids (LongTensor): The input draft tokens in shape of [batch_size, num_spec_tokens].
            bonus_token_ids (LongTensor): The bonus token ids in shape of [batch_size, 1].
            draft_probs (Tensor): The probability of draft model in shape of [batch_size, num_spec_tokens, vocab_size].
                Default to ``None``.

        Returns:
            tuple: ``(output_token_ids, num_rejected_tokens, last_token_ids)``.
        """
        output_token_ids, num_rejected_tokens, last_token_ids = rejection_sample(
            target_logits,
            draft_token_ids,
            bonus_token_ids,
            # forward the configured policy instead of silently using the default
            sample_policy=self.sample_policy,
            draft_probs=draft_probs,
        )
        return output_token_ids, num_rejected_tokens, last_token_ids


@record_function('rejection_sample')
def rejection_sample(
    target_probs: Tensor,
    draft_token_ids: LongTensor,
    bonus_token_ids: LongTensor,
    sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY,
    draft_probs: Optional[Tensor] = None,
):
    """Rejection-sample draft tokens against the target model.

    Args:
        target_probs (Tensor): Target model scores for each draft position,
            [batch_size, num_spec_tokens, vocab_size]. Only the argmax over the
            last dim is used, so logits work as well as probabilities.
        draft_token_ids (LongTensor): Draft tokens, [batch_size, num_spec_tokens].
        bonus_token_ids (LongTensor): Bonus tokens, [batch_size, 1].
        sample_policy (SamplePolicy): Must be ``SamplePolicy.ALL_GREEDY``.
        draft_probs (Tensor, optional): Draft model probabilities; must be
            contiguous if given. Not used by the greedy policy.

    Returns:
        tuple: ``(output_token_ids, num_rejected_tokens, last_token_ids)`` as
        produced by ``greedy_reject_sampler``.
    """
    assert draft_probs is None or draft_probs.is_contiguous()
    assert sample_policy == SamplePolicy.ALL_GREEDY, 'only support all greedy sampling policy'

    target_argmax_tokens = target_probs.argmax(dim=-1)
    return greedy_reject_sampler(draft_token_ids, target_argmax_tokens, bonus_token_ids)


def greedy_reject_sampler(draft_token_ids, target_token_ids, bonus_token_ids):
    """Greedily verify draft tokens against the target model's tokens.

    For each row, the longest prefix on which draft and target agree is
    accepted. The target token at the first disagreement is kept as the
    correction; if every draft token is accepted, the bonus token is
    appended instead. Positions that are not kept are filled with ``-1``.

    Args:
        draft_token_ids: (batch_size, num_spec_tokens)
        target_token_ids: (batch_size, num_spec_tokens)
        bonus_token_ids: (batch_size, 1)
    Returns:
        output_token_ids: (batch_size, num_spec_tokens + 1)
        num_rejected_tokens: (batch_size, )
        last_token_ids: (batch_size, ) last kept token of each row
    """
    batch_size, num_spec_tokens = draft_token_ids.shape
    device = draft_token_ids.device

    # 1 while draft and target still agree, 0 from the first mismatch onwards
    prefix_match = (draft_token_ids == target_token_ids).int().cumprod(dim=1)
    num_accepted = prefix_match.sum(dim=1)
    num_rejected_tokens = num_spec_tokens - num_accepted

    # accepted prefix plus the correcting target token at index ``num_accepted``
    positions = torch.arange(num_spec_tokens, device=device)
    keep_mask = positions[None, :] <= num_accepted[:, None]
    keep_token_ids = torch.where(keep_mask, target_token_ids, -1)

    # the bonus token survives only when every draft token was accepted
    all_accepted = (num_accepted == num_spec_tokens)[:, None]
    keep_bonus_ids = torch.where(all_accepted, bonus_token_ids, -1)
    output_token_ids = torch.cat([keep_token_ids, keep_bonus_ids], dim=1)

    # the last kept token always sits at index ``num_accepted``
    rows = torch.arange(batch_size, device=device)
    last_token_ids = output_token_ids[rows, num_accepted]
    return output_token_ids, num_rejected_tokens, last_token_ids


================================================
FILE: lmdeploy/pytorch/spec_decode/spec_agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import asyncio

import torch

from lmdeploy.utils import get_logger

from ..backends import get_backend
from ..config import BackendConfig, CacheConfig, ModelConfig, SpecDecodeConfig
from ..engine.cache_engine import CacheEngine
from ..engine.logits_process import SamplingInputs
from ..model_inputs import ModelInputs
from ..strategies.ar_spec.model_agent import ARSpecExtraInputs
from ..strategies.base.model_agent import ExtraInputs
from .base import BaseSpecModelAgent
from .proposers.base import build_specdecode_proposer
from .reject_sampler import RejectionSampler

logger = get_logger('lmdeploy')


class SpecModelAgent(BaseSpecModelAgent):
    """Speculative model agent.

    Owns the draft proposer, the draft KV cache engine and the rejection
    sampler. After each target model step it verifies the previous draft
    tokens and runs the draft model to propose ``num_speculative_tokens`` new
    draft tokens per sequence.
    """

    def __init__(
        self,
        specdecode_config: SpecDecodeConfig,
        backend_config: BackendConfig,
        inputs_strategy,
        agent_strategy,
        device: str = 'cuda',
    ):
        super().__init__(enable=True)

        self.backend_config = backend_config
        self.device = device
        self.cache_engine = None
        self.inputs_strategy = inputs_strategy
        self.agent_strategy = agent_strategy
        self.rejection_sampler = RejectionSampler()
        self.proposer = build_specdecode_proposer(specdecode_config, device=device)
        self.method = specdecode_config.method
        self.model_config = specdecode_config.model_config
        self.cache_config = specdecode_config.cache_config
        self.num_spec_tokens = specdecode_config.num_speculative_tokens

    def set_cache_config(self, cache_config: CacheConfig):
        """Set the draft model cache config."""
        self.cache_config = cache_config

    def set_model_config(self, model_config: ModelConfig):
        """Set the draft model config."""
        self.model_config = model_config

    def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None):
        """Build the draft model through the proposer."""
        self.proposer.build_model(empty_init, target_model=target_model, build_model_ctx=build_model_ctx)

    def build_graph_runner(self):
        """Replace the proposer's model with the backend graph runner wrapping it."""
        backend = get_backend()
        self.proposer.model = backend.build_graph_runner(self.proposer.model,
                                                         model_config=self.model_config,
                                                         cache_config=self.cache_config,
                                                         backend_config=self.backend_config,
                                                         device=self.device)

    def build_cache_engine(self, cache_stream: torch.cuda.Stream):
        """Build the draft model cache engine (rank 0, world size 1).

        Skipped when no cache config is set.
        """
        if self.cache_config is not None:
            self.cache_engine = CacheEngine(self.cache_config,
                                            self.model_config,
                                            rank=0,
                                            tp_rank=0,
                                            world_size=1,
                                            cache_stream=cache_stream)

    def _rejection_sampling(self, next_token_ids, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs):
        """Verify draft tokens and build the draft model inputs for this step.

        For decoding inputs, ``model_inputs.input_ids`` holds
        ``num_spec_tokens + 1`` tokens per sequence, of which the last
        ``num_spec_tokens`` are checked against ``extra_inputs.target_logits``.
        For non-decoding inputs, ``next_token_ids`` is used as-is and nothing is
        rejected.

        The draft input ids are the target input ids shifted left by one, with
        the last accepted position of each sequence overwritten by the new token.

        Returns:
            tuple: ``(new_model_inputs, new_extra_inputs)``.
        """
        num_rejected_tokens = torch.zeros_like(model_inputs.seq_length)
        bonus_token_ids = output_token_ids = next_token_ids.unsqueeze(-1)
        last_token_indices = model_inputs.seq_length.cumsum(0) - 1
        if model_inputs.is_decoding:
            # only do rejection sample for decoding with draft tokens
            input_draft_token_ids = model_inputs.input_ids.squeeze(0).unflatten(0, (-1, self.num_spec_tokens + 1))[:,
                                                                                                                   1:]
            output_token_ids, num_rejected_tokens, next_token_ids = self.rejection_sampler(
                extra_inputs.target_logits,
                input_draft_token_ids,
                bonus_token_ids,
            )
            # move each sequence's last index back over its rejected tokens
            last_token_indices = last_token_indices - num_rejected_tokens

        # create new inputs
        input_ids = model_inputs.input_ids.clone()
        seq_length = model_inputs.seq_length
        # offset by 1 token
        input_ids[:, :-1] = model_inputs.input_ids[:, 1:]
        # update next tokens
        input_ids[:, last_token_indices] = next_token_ids
        # use new inputs
        new_model_inputs = ModelInputs(
            input_ids=input_ids,
            seq_length=seq_length,
            max_kv_seqlen=model_inputs.max_kv_seqlen,
            max_q_seqlen=model_inputs.max_q_seqlen,
            sum_kv_seqlen=model_inputs.sum_kv_seqlen,
            history_lengths=model_inputs.history_lengths.clone(),
            block_offsets=model_inputs.block_offsets,
            num_ignored_history=model_inputs.num_ignored_history,
            is_decoding=model_inputs.is_decoding,
            target_hidden_states=extra_inputs.target_hidden_states,
            target_position_ids=extra_inputs.target_position_ids,
        )
        new_extra_inputs = ARSpecExtraInputs(
            next_token_ids=next_token_ids,
            last_token_indices=last_token_indices,
            num_rejected_tokens=num_rejected_tokens,
            output_token_ids=output_token_ids,
        )
        return new_model_inputs, new_extra_inputs

    def _forward_impl(self, inputs: ModelInputs):
        """Run one draft model forward on the draft cache engine."""
        output = self.proposer._forward(inputs, cache_engine=self.cache_engine)
        return output

    async def _async_forward(self, inputs: ModelInputs):
        """Model forward, then yield control to the event loop once.

        Args:
            inputs (ModelInputs): The draft model inputs.
        """
        output = self._forward_impl(inputs)
        await asyncio.sleep(0)
        return output

    async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecExtraInputs,
                                   sampling_inputs: SamplingInputs):
        """Run the draft model ``num_spec_tokens`` times to propose draft tokens.

        The first step consumes ``inputs``; each following step feeds back the
        previous draft token (one token per sequence). When ``inputs.is_chunk``
        is set, a zero tensor shaped like ``inputs.input_ids`` is returned
        instead.

        Args:
            inputs (ModelInputs): The draft model inputs; mutated across steps.
            extra_inputs (ARSpecExtraInputs): Extra inputs of the draft step.
            sampling_inputs (SamplingInputs): Currently unused.

        Returns:
            Tensor: Draft token ids concatenated along the last dim.
        """
        outputs = await self._async_forward(inputs)
        if inputs.is_chunk:
            return torch.zeros_like(inputs.input_ids)

        loop_count = self.num_spec_tokens - 1
        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs, extra_inputs)
        draft_tokens_li = [draft_token_ids]
        if loop_count > 0:
            # set last_token_indices to None for decoding
            # NOTE: this mutates ``extra_inputs``, which async_model_forward
            # also returns to its caller.
            extra_inputs.last_token_indices = None
            inputs = self.proposer.update_inputs_decoding(inputs, extra_inputs, draft_token_ids.transpose(0, 1),
                                                          target_hidden_states, model_metas)
            for loop_idx in range(loop_count):
                outputs = await self._async_forward(inputs)
                draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
                draft_tokens_li.append(draft_token_ids)
                # no need to prepare inputs after the last draft step
                if loop_idx < loop_count - 1:
                    step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0))
                    inputs.step(draft_token_ids.transpose(0, 1), step_seqlens)
                    inputs.model_metas = model_metas
                    inputs.target_hidden_states = target_hidden_states
                    if inputs.target_position_ids is not None:
                        inputs.target_position_ids += 1

        output_draft_ids = torch.cat(draft_tokens_li, dim=-1)
        return output_draft_ids

    async def async_model_forward(
        self,
        next_token_ids: torch.Tensor,
        model_inputs: ModelInputs,
        extra_inputs: ExtraInputs,
        sampling_inputs: SamplingInputs,
    ):
        """Verify the previous drafts, then propose new draft tokens.

        Returns:
            ARSpecExtraInputs: Rejection-sampling results with
            ``output_draft_token_ids`` set to the new draft tokens.
        """
        draft_model_inputs, draft_extra_inputs = self._rejection_sampling(next_token_ids, model_inputs, extra_inputs)
        next_draft_ids = await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs)
        draft_extra_inputs.output_draft_token_ids = next_draft_ids
        return draft_extra_inputs

    def warmup(self, max_batches: int, target_model_config: ModelConfig):
        """Warm up the draft model.

        Runs one prefill with ``max_batches`` sequences. Then, for every capture
        batch size of the draft model (largest first), it runs one decoding
        step with ``num_spec_tokens + 1`` tokens per sequence and one with a
        single token per sequence. Dummy inputs are created on ``self.device``
        (previously hard-coded to ``'cuda'``).
        """
        target_hidden_size = self.proposer.get_target_hidden_size(target_model_config)

        # warmup prefill
        inputs = self.inputs_strategy.make_dummy(max_batches,
                                                 is_decoding=False,
                                                 device=self.device,
                                                 vocab_size=self.model_config.vocab_size,
                                                 target_hidden_size=target_hidden_size,
                                                 target_dtype=self.model_config.dtype)

        self._forward_impl(inputs)

        capture_batch_sizes = self.proposer.model.get_capture_batch_sizes()
        capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)

        for batch_size in capture_batch_sizes:
            # decode with num_spec_tokens + 1 per seq
            inputs = self.inputs_strategy.make_dummy(
                batch_size,
                is_decoding=True,
                device=self.device,
                vocab_size=self.model_config.vocab_size,
                max_q_seqlen=self.num_spec_tokens + 1,
                target_hidden_size=target_hidden_size,
                target_dtype=self.model_config.dtype,
            )
            self._forward_impl(inputs)
            # decode 1 token per sequence; in the draft loop the target hidden
            # states come from the draft model's own outputs, hence the draft
            # hidden size here
            inputs = self.inputs_strategy.make_dummy(
                batch_size,
                is_decoding=True,
                device=self.device,
                vocab_size=self.model_config.vocab_size,
                max_q_seqlen=1,
                target_hidden_size=self.model_config.hidden_size,
                target_dtype=self.model_config.dtype,
            )
            self._forward_impl(inputs)

    def reset_graph_runner(self):
        """Reset the draft graph runner if it supports ``reset``."""
        if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'):
            self.proposer.model.reset()


================================================
FILE: lmdeploy/pytorch/strategies/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import MiscConfig, ModelConfig, SpecDecodeConfig


def build_strategy_factory(model_config: ModelConfig,
                           misc_config: MiscConfig,
                           specdecode_config: SpecDecodeConfig = None):
    """Create the strategy factory matching ``model_config.model_paradigm``.

    Supported paradigms are ``'ar'``, ``'dllm'`` and ``'ar_spec'``; the last
    one requires ``specdecode_config``.

    Raises:
        RuntimeError: For an unknown paradigm.
    """
    paradigm = model_config.model_paradigm

    if paradigm == 'ar':
        from .ar import ARStrategyFactory
        return ARStrategyFactory(model_config=model_config)

    if paradigm == 'dllm':
        from .dllm import DLLMStrategyFactory
        return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config)

    if paradigm == 'ar_spec':
        from .ar_spec import ARSpecStrategyFactory
        assert specdecode_config is not None, 'specdecode_config must be provided for ar_spec model'
        return ARSpecStrategyFactory(model_config=model_config, specdecode_config=specdecode_config)

    raise RuntimeError(f'Unsupported model paradigm: {paradigm}')


================================================
FILE: lmdeploy/pytorch/strategies/ar/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from lmdeploy.pytorch.config import ModelConfig
from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy

if TYPE_CHECKING:
    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy
    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

from ..base import StrategyFactoryBase


class ARStrategyFactory(StrategyFactoryBase):
    """Builds the strategies used by auto-regressive (``'ar'``) models."""

    def __init__(self, model_config: ModelConfig):
        """Keep the model config used when building strategies."""
        self.model_config = model_config

    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
        """Return the AR cudagraph strategy."""
        from .cudagraph import ARCudagraphStrategy
        strategy = ARCudagraphStrategy()
        return strategy

    def build_sampling_strategy(self) -> 'SamplingStrategy':
        """Return the AR sampling strategy, padding with ``bos_token_id`` (0 if unset)."""
        from .sampling import ARSamplingStrategy
        bos_token_id = self.model_config.bos_token_id
        if bos_token_id is None:
            return ARSamplingStrategy(0)
        return ARSamplingStrategy(bos_token_id)

    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
        """Return the AR model inputs strategy."""
        from .model_inputs import ARModelInputsStrategy
        strategy = ARModelInputsStrategy()
        return strategy

    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
        """Return the AR model agent strategy."""
        from .model_agent import ARModelAgentStrategy
        strategy = ARModelAgentStrategy()
        return strategy

    def build_engine_strategy(self, cache_config: 'CacheConfig',
                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
        """Return the AR engine strategy for the given configs."""
        from .engine import AREngineStrategy
        strategy = AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config)
        return strategy

    def build_sequence_strategy(self) -> SequenceStrategy:
        """Return the AR sequence strategy."""
        from .sequence import ARSequenceStrategy
        strategy = ARSequenceStrategy()
        return strategy


================================================
FILE: lmdeploy/pytorch/strategies/ar/cudagraph.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..base.cudagraph import CudagraphStrategy


class ARCudagraphStrategy(CudagraphStrategy):
    """Cudagraph strategy for auto-regressive models."""

    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:
        """Maximum number of tokens of a graph: equal to ``batch_size``.

        ``origin_batch_size`` and ``num_tokens`` are ignored.
        """
        max_tokens = batch_size
        return max_tokens


================================================
FILE: lmdeploy/pytorch/strategies/ar/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

from ..base.engine import EngineStrategy


class AREngineStrategy(EngineStrategy):
    """Engine strategy for auto-regressive models."""

    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.scheduler_config = scheduler_config

    def get_prealloc_size(self, is_decoding: bool):
        """Prealloc size: ``prefill_interval`` when decoding, otherwise 0."""
        if not is_decoding:
            return 0
        return self.scheduler_config.prefill_interval

    def get_num_loops(self, is_decoding: bool) -> int:
        """Number of loops: ``prefill_interval`` when decoding, otherwise 1."""
        if not is_decoding:
            return 1
        return self.scheduler_config.prefill_interval

    def get_num_decode_tokens(self) -> int:
        """Number of decode tokens; always 1."""
        return 1


================================================
FILE: lmdeploy/pytorch/strategies/ar/model_agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.profiler import record_function

from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria

SeqList = List[SchedulerSequence]


def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen: int,
                                   model_metas) -> ModelInputs:
    """Build the inputs of the next decoding step from the current inputs.

    The tokens of the current step become history, and every sequence
    consumes ``max_q_seqlen`` new tokens from ``input_ids``.

    Args:
        inputs (ModelInputs): Inputs of the step that was just run.
        input_ids (Tensor): Token ids of the next step; a 1-D tensor gets a
            leading dim of size 1.
        max_q_seqlen (int): Number of new tokens per sequence in the next step.
        model_metas: Model metas to attach to the next step.

    Returns:
        ModelInputs: Decoding inputs of the next step.
    """
    if input_ids.dim() == 1:
        input_ids = input_ids[None, :]
    state_offsets = inputs.state_offsets
    if state_offsets is not None:
        state_offsets = state_offsets.clone()
    num_seqs = inputs.seq_length.numel()
    return ModelInputs(
        input_ids=input_ids,
        seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),
        history_lengths=inputs.history_lengths + inputs.seq_length,
        block_offsets=inputs.block_offsets,
        is_decoding=True,
        num_ignored_history=inputs.num_ignored_history.clone(),
        max_q_seqlen=max_q_seqlen,
        max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,
        # every sequence grows by the *next* step's ``max_q_seqlen`` tokens,
        # matching ``max_kv_seqlen`` above (the previous step's q length was
        # used before, which is wrong whenever the two differ)
        sum_kv_seqlen=inputs.sum_kv_seqlen + num_seqs * max_q_seqlen,
        local_adapter_ids=inputs.local_adapter_ids,
        model_metas=model_metas,
        state_offsets=state_offsets,
    )


@dataclass
class ARExtraInputs(ExtraInputs):
    """Extra inputs of the AR strategy; adds no fields to ``ExtraInputs``."""


@dataclass
class ARExtraOutputs(ExtraOutputs):
    """Extra outputs of the AR strategy; adds no fields to ``ExtraOutputs``."""


@dataclass
class ARStoppingCriteria(StoppingCriteria):
    """Stopping criteria for auto-regressive decoding.

    ``num_appendable_ids`` holds, per sequence, how many more tokens may be
    appended.
    """
    num_appendable_ids: torch.Tensor

    def clone(self):
        """Return a new criteria object sharing the same tensor."""
        return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids)

    def merge(self, other: 'ARStoppingCriteria'):
        """Concatenate the sequences of ``self`` and ``other`` along dim 0."""
        merged = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0)
        return ARStoppingCriteria(num_appendable_ids=merged)

    def update(self, delta: ModelInputsDelta):
        """Keep only the sequences selected by ``delta.indices``."""
        selected = self.num_appendable_ids[delta.indices]
        return ARStoppingCriteria(num_appendable_ids=selected)

    @record_function('stopping_criteria')
    def step(self,
             token_ids: torch.Tensor,
             stop_words: torch.Tensor,
             inputs: Optional[ModelInputs] = None,
             extra_inputs: Optional[ARExtraInputs] = None):
        """Consume one token per sequence and decide which sequences stop.

        A sequence stops when its budget reaches zero or, if ``stop_words`` is
        given, when its token matches one of the stop words (which also caps
        its remaining budget at 0).

        Returns:
            tuple: ``(stopped, stop_pos, new_criteria)``; ``stop_pos`` is all zeros.
        """
        remaining = self.num_appendable_ids - 1
        stopped = remaining <= 0
        stop_pos = torch.zeros_like(remaining)
        if stop_words is not None:
            hit_stop_word = (token_ids[:, None] == stop_words).any(1)
            stopped = stopped | hit_stop_word
            remaining = torch.where(hit_stop_word, remaining.clamp_max(0), remaining)

        # build a new object rather than updating ``self`` in place
        return stopped, stop_pos, ARStoppingCriteria(num_appendable_ids=remaining)


class ARModelAgentStrategy(ModelAgentStrategy):
    """Model agent strategy for plain autoregressive decoding."""

    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
        """Keep only the output row of the last token of every sequence."""
        num_seqs = len(seq_length)
        if num_seqs == 1:
            # a single sequence: its last token is the last row
            return inputs[-1:]
        if num_seqs == inputs.size(0):
            # exactly one row per sequence: nothing to drop
            return inputs
        return inputs[seq_length.cumsum(-1) - 1]

    def slice_extra_inputs(self, extra_inputs: ARExtraInputs, model_inputs: ModelInputs,
                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ARExtraInputs:
        """Return ``extra_inputs`` unchanged; the AR strategy has nothing to slice."""
        return extra_inputs

    @record_function('step_sampling_inputs')
    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, **kwargs):
        """Advance the per-step sampling state after ``next_token_ids`` were sampled."""
        sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1

        if sampling_inputs.random_offsets is not None:
            # the offsets feed the RNG used by multinomial sampling; bump them
            # so the next step draws fresh random numbers
            sampling_inputs.random_offsets += 1

        history = sampling_inputs.all_ids
        if history is not None:
            sampling_inputs.all_ids = torch.cat([history, next_token_ids[:, None]], 1)

        return sampling_inputs

    def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria:
        """Build stopping criteria from each sequence's remaining new-token budget."""
        budgets = torch.tensor([seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs])
        return ARStoppingCriteria(num_appendable_ids=budgets)

    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:
        """Return an empty ``ARExtraInputs``."""
        return ARExtraInputs()

    def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs:
        """Return an empty ``ARExtraOutputs``."""
        return ARExtraOutputs()

    def update_prefill_for_next_step(
        self,
        model_inputs: 'ModelInputs',
        extra_inputs: ARExtraInputs,
        next_token_ids: torch.Tensor,
        model_metas: Any,
        extra_outputs: ARExtraOutputs,
    ) -> Tuple['ModelInputs', ARExtraInputs]:
        """Turn finished prefill inputs into decoding inputs with one query token per sequence."""
        next_inputs = get_model_inputs_next_decoding(model_inputs,
                                                     next_token_ids,
                                                     max_q_seqlen=1,
                                                     model_metas=model_metas)
        return next_inputs, extra_inputs

    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
                                      extra_inputs: ARExtraInputs, **kwargs):
        """Advance decoding inputs in place by their current seq lengths."""
        model_inputs.model_metas = model_metas
        model_inputs.step(next_token_ids, model_inputs.seq_length)
        return model_inputs, extra_inputs

    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
                      extra_inputs: ARExtraInputs):
        """Pass the sampled tokens through unchanged."""
        return next_token_ids, extra_inputs

    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: DistContext):
        """Broadcast ``next_token_ids`` from rank 0 of the attention TP group.

        The broadcast is launched asynchronously on entry and waited on when the
        ``with`` body finishes normally. ``extra_inputs`` is not broadcast.
        """
        group = dist_ctx.attn_tp_group.gpu_group
        src_rank = dist.get_global_rank(group, 0)
        work = dist.broadcast(next_token_ids, src=src_rank, group=group, async_op=True)
        yield
        work.wait()


================================================
FILE: lmdeploy/pytorch/strategies/ar/model_inputs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import numpy as np
import torch
from torch.profiler import record_function

from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs


def merge_model_inputs(inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
    """Concatenate two decoding batches into one ``ModelInputs``.

    Rows of ``inputs`` come first, followed by rows of ``other``. Block offset
    tables are right-padded with zeros to a common width. ``local_adapter_ids``
    is concatenated only when both sides have it, otherwise it is taken from
    ``inputs``. ``model_metas`` is merged only when both sides have it, otherwise
    it is ``None``. ``state_offsets`` is concatenated whenever ``inputs`` has it.
    """

    def _pad_block_offsets(offsets: torch.Tensor, width: int):
        """Right-pad ``offsets`` with zero columns up to ``width`` columns."""
        missing = width - offsets.size(1)
        if missing <= 0:
            return offsets
        filler = offsets.new_zeros((offsets.size(0), missing))
        return torch.cat([offsets, filler], dim=1)

    assert inputs.is_decoding and other.is_decoding, 'Only support merge in decoding.'

    width = max(inputs.block_offsets.size(1), other.block_offsets.size(1))
    block_offsets = torch.cat(
        [_pad_block_offsets(inputs.block_offsets, width),
         _pad_block_offsets(other.block_offsets, width)],
        dim=0,
    )

    # lora adapter ids
    local_adapter_ids = inputs.local_adapter_ids
    if local_adapter_ids is not None and other.local_adapter_ids is not None:
        local_adapter_ids = torch.cat([local_adapter_ids, other.local_adapter_ids], dim=0)

    # model metas for vl models
    both_have_metas = inputs.model_metas is not None and other.model_metas is not None
    model_metas = inputs.model_metas + other.model_metas if both_have_metas else None

    # ssm
    state_offsets = (None if inputs.state_offsets is None else torch.cat(
        [inputs.state_offsets, other.state_offsets], dim=0))

    return ModelInputs(
        input_ids=torch.cat([inputs.input_ids, other.input_ids], dim=-1),
        seq_length=torch.cat([inputs.seq_length, other.seq_length], dim=0),
        history_lengths=torch.cat([inputs.history_lengths, other.history_lengths], dim=0),
        block_offsets=block_offsets,
        is_decoding=inputs.is_decoding,
        num_ignored_history=torch.cat([inputs.num_ignored_history, other.num_ignored_history], dim=0),
        max_q_seqlen=max(inputs.max_q_seqlen, other.max_q_seqlen),
        max_kv_seqlen=max(inputs.max_kv_seqlen, other.max_kv_seqlen),
        sum_kv_seqlen=inputs.sum_kv_seqlen + other.sum_kv_seqlen,
        local_adapter_ids=local_adapter_ids,
        model_metas=model_metas,
        state_offsets=state_offsets,
    )


class ARModelInputsStrategy(ModelInputsStrategy):
    """Model inputs strategy for plain autoregressive decoding."""

    def make_dummy(self,
                   batch_size: int,
                   is_decoding: bool,
                   device: str = 'cpu',
                   dummy_block_id: int = 0,
                   vocab_size: int = 1) -> ModelInputs:
        """Create dummy model inputs with one query token per sequence."""
        return make_dummy_inputs(batch_size,
                                 max_q_seqlen=1,
                                 is_decoding=is_decoding,
                                 device=device,
                                 dummy_block_id=dummy_block_id,
                                 vocab_size=vocab_size)

    @record_function('ModelInputs.merge')
    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
        """Merge two decoding batches (rows of ``inputs`` first)."""
        return merge_model_inputs(inputs, other)

    @staticmethod
    def index_select(inputs: ModelInputs,
                     indices: torch.Tensor,
                     indice_cpu: np.ndarray = None,
                     block_offsets: torch.Tensor = None,
                     max_q_seqlen: Optional[int] = None,
                     max_kv_seqlen: Optional[int] = None,
                     sum_kv_seqlen: Optional[int] = None,
                     num_ignored_history: Optional[torch.Tensor] = None):
        """Select a subset of rows of decoding ``inputs``.

        Args:
            inputs: decoding inputs to select from.
            indices: row indices to keep. Callers are assumed never to reorder
                rows, so a selection of the full length keeps every row.
            indice_cpu: host copy of ``indices``, used to filter ``model_metas``.
                When omitted, ``model_metas`` is passed through unfiltered.
            block_offsets: precomputed block offsets; indexed from ``inputs``
                when ``None``.
            max_q_seqlen: replacement value; any falsy value (``None`` or 0)
                keeps the value from ``inputs``.
            max_kv_seqlen: same fallback rule as ``max_q_seqlen``.
            sum_kv_seqlen: same fallback rule as ``max_q_seqlen``.
            num_ignored_history: precomputed value; indexed from ``inputs``
                when ``None``.

        Returns:
            A new decoding ``ModelInputs`` holding the selected rows.
        """
        assert inputs.is_decoding, 'Only support index_select in decoding.'

        if len(indices) == len(inputs.seq_length):
            # we will not change the order of indices,
            # so same length means every row is kept.
            # ``slice(None)`` (not ``Ellipsis``) because ``input_ids`` is indexed
            # as ``[..., indices]`` and a second ``...`` raises IndexError.
            indices = slice(None)

        # required inputs
        input_ids = inputs.input_ids[..., indices]
        seq_length = inputs.seq_length[indices]
        history_lengths = inputs.history_lengths[indices]
        if block_offsets is None:
            block_offsets = inputs.block_offsets[indices]
        if num_ignored_history is None:
            num_ignored_history = inputs.num_ignored_history[indices]
        max_q_seqlen = max_q_seqlen or inputs.max_q_seqlen
        max_kv_seqlen = max_kv_seqlen or inputs.max_kv_seqlen
        sum_kv_seqlen = sum_kv_seqlen or inputs.sum_kv_seqlen

        def _select_optional(tensor):
            """Index an optional per-row tensor, passing ``None`` through."""
            return None if tensor is None else tensor[indices]

        # model metas for vl models
        model_metas = inputs.model_metas
        if model_metas is not None and indice_cpu is not None:
            model_metas = [model_metas[i] for i in indice_cpu]

        # return new inputs
        return ModelInputs(
            input_ids=input_ids,
            seq_length=seq_length,
            history_lengths=history_lengths,
            block_offsets=block_offsets,
            is_decoding=inputs.is_decoding,
            num_ignored_history=num_ignored_history,
            max_q_seqlen=max_q_seqlen,
            max_kv_seqlen=max_kv_seqlen,
            sum_kv_seqlen=sum_kv_seqlen,
            # lora adapter ids
            local_adapter_ids=_select_optional(inputs.local_adapter_ids),
            model_metas=model_metas,
            # for ssm
            state_offsets=_select_optional(inputs.state_offsets),
            # spec decoding
            target_hidden_states=_select_optional(inputs.target_hidden_states),
            target_position_ids=_select_optional(inputs.target_position_ids),
        )

    @record_function('ModelInputs.update_inputs')
    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:
        """Apply ``delta`` (row selection plus precomputed fields) to decoding ``inputs``."""
        assert inputs.is_decoding, 'Only support update_delta in decoding.'
        return self.index_select(
            inputs=inputs,
            indices=delta.indices,
            indice_cpu=delta.indice_cpu,
            block_offsets=delta.block_offsets,
            max_q_seqlen=delta.max_q_seqlen,
            max_kv_seqlen=delta.max_kv_seqlen,
            sum_kv_seqlen=delta.sum_kv_seqlen,
            num_ignored_history=delta.num_ignored_history,
        )


================================================
FILE: lmdeploy/pytorch/strategies/ar/sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
import torch
from torch.profiler import record_function

from lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputsDelta

from ..base.sampling import SamplingStrategy

SeqList = list[SchedulerSequence]


def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs):
    """Collect the valid token ids of every sequence into one padded matrix.

    Only built when at least one sequence has custom logits processors;
    otherwise ``None`` is returned.

    Returns:
        ``int64`` tensor of shape ``(len(seqs), max_len)``. Each row holds that
        sequence's ``valid_ids`` right-aligned, preceded by ``pad_id``.
    """
    if not any(sampling_inputs.logits_processors):
        return None

    width = max(seq.num_valid_ids for seq in seqs)
    output = torch.full((len(seqs), width), pad_id, dtype=torch.int64)
    # each ``row`` is a view into ``output``, so writes land in the result
    for row, seq in zip(output, seqs):
        count = seq.num_valid_ids
        if count:
            row[-count:] = torch.from_numpy(seq.valid_ids)
    return output


def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None:
    """Collect the generated tokens of every sequence into one padded array.

    Only built when a repetition penalty or n-gram repetition control is
    active; otherwise ``None`` is returned.

    Returns:
        ``int64`` array of shape ``(len(seqs), max_new)``. Each row holds that
        sequence's ``generated_ids`` right-aligned, preceded by ``pad_id``.
    """
    needed = (sampling_inputs.repetition_penalty is not None or sampling_inputs.max_repetition_ngram_size != 0)
    if not needed:
        return None

    width = max(seq.num_new_tokens for seq in seqs)
    output = np.full((len(seqs), width), pad_id, dtype=np.int64)
    # each ``row`` is a view into ``output``, so writes land in the result
    for row, seq in zip(output, seqs):
        count = seq.num_new_tokens
        if count:
            row[-count:] = seq.generated_ids
    return output


def _get_num_ignore_eos(seqs: SeqList):
    """Return per-sequence ``min_new_tokens - num_new_tokens`` (may be negative)."""
    return torch.tensor([s.sampling_param.min_new_tokens - s.num_new_tokens for s in seqs])


class ARSamplingStrategy(SamplingStrategy):
    """Sampling strategy for autoregressive models."""

    def __init__(self, pad_token_id: int) -> None:
        pad_token_id = 0 if pad_token_id is None else pad_token_id
        # fill value for the padded ``all_ids`` / generated-id matrices
        self.pad_token_id = pad_token_id
        # session ids queued by ``on_session_end``; handed over (and reset) by
        # the next ``make_sampling_inputs`` call
        self.session_to_cleanup = []

    @record_function('make_sampling_inputs')
    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
        """Create batched sampling inputs from the sequences.

        Per-sequence sampling parameters are gathered into tensors. A tensor
        is replaced by ``None`` when its values make the corresponding step a
        no-op (all temperatures 1.0, all repetition penalties 1.0, all top-p
        1.0, all min-p 0.0, identical top-k, ...). For greedy batches
        (``max_top_k == 1``) top-k/top-p/min-p and random seeds are dropped.
        """
        batch_size = len(seqs)
        temperature = [None] * batch_size
        repetition_penalty = [None] * batch_size
        top_k = [None] * batch_size
        top_p = [None] * batch_size
        min_p = [None] * batch_size
        bad_words = [None] * batch_size
        stop_words = [None] * batch_size
        # One fallback seed shared by the batch; sequences with an explicit
        # ``random_seed`` override it below. The dtype is pinned because the
        # platform default integer may be 32-bit (e.g. Windows with NumPy < 2),
        # where an upper bound of 0xffffffff is out of range.
        random_seeds = [int(np.random.randint(0xffffffff, dtype=np.int64))] * batch_size
        random_offsets = [None] * batch_size
        response_formats = [None] * batch_size
        logits_processors = [None] * batch_size
        num_logprobs = [None] * batch_size
        session_to_cleanup = self.session_to_cleanup
        self.session_to_cleanup = []
        repetition_ngram_sizes = [None] * batch_size
        repetition_ngram_thresholds = [None] * batch_size

        def __gather_params():
            """Fill the per-sequence parameter lists above."""
            for idx, seq in enumerate(seqs):
                param = seq.sampling_param
                temperature[idx] = param.temperature
                repetition_penalty[idx] = param.repetition_penalty
                top_k[idx] = max(0, param.top_k)
                top_p[idx] = param.top_p
                min_p[idx] = param.min_p
                random_offsets[idx] = seq.num_valid_ids
                response_formats[idx] = param.response_format
                if param.random_seed is not None:
                    random_seeds[idx] = param.random_seed & 0xffffffff

                bw = param.bad_words
                sw = param.stop_words
                # stop words are banned until min_new_tokens is reached
                # (unless eos is ignored anyway)
                if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens):
                    bw = bw + sw
                bad_words[idx] = bw
                stop_words[idx] = sw
                logits_processors[idx] = param.logits_processors
                num_logprobs[idx] = param.num_logprobs
                repetition_ngram_sizes[idx] = param.repetition_ngram_size
                repetition_ngram_thresholds[idx] = param.repetition_ngram_threshold

        def __get_topp(top_p):
            """Return (top_p tensor or None when all are 1.0, min top_p)."""
            min_top_p = min(top_p)
            if min_top_p == 1.0:
                top_p = None
            else:
                top_p = torch.tensor(top_p)
            return top_p, min_top_p

        def __get_minp(min_p):
            """Return the min_p tensor, or None when every min_p is 0.0."""
            max_min_p = max(min_p)
            if max_min_p == 0.0:
                min_p = None
            else:
                min_p = torch.Tensor(min_p)
            return min_p

        def __get_bad_words(bad_words):
            """Pack word lists into a -1 padded (batch, max_len) tensor plus validity mask.

            Returns ``(None, None)`` when every list is empty.
            """
            max_bw_len = max(len(bw) for bw in bad_words)
            if max_bw_len == 0:
                return None, None
            if all(len(bw) == max_bw_len for bw in bad_words):
                ret = torch.tensor(bad_words)
                mask = torch.ones_like(ret, dtype=bool)
                return ret, mask
            ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
            for idx, bw in enumerate(bad_words):
                bw_len = len(bw)
                if bw_len == 0:
                    continue
                bw = ret.new_tensor(bw)
                ret[idx, :bw_len] = bw

            mask = ret >= 0
            return ret, mask

        __gather_params()

        if all(rp == 1.0 for rp in repetition_penalty):
            repetition_penalty = None
        else:
            repetition_penalty = torch.tensor(repetition_penalty)

        temperature = torch.tensor(temperature)
        if (temperature == 1.0).all():
            # skip temperature processing if all temperature are 1.0
            temperature = None

        bad_words, bad_mask = __get_bad_words(bad_words)
        stop_words, stop_mask = __get_bad_words(stop_words)

        max_top_k = max(top_k)
        if min(top_k) <= 0:
            max_top_k = 0
        if max_top_k == 1:
            # greedy batch: no filtering and no randomness needed
            top_k = None
            top_p, min_top_p = None, 1.0
            min_p = None
            random_seeds = None
        else:
            top_k = torch.tensor(top_k)
            if (top_k == max_top_k).all():
                # we would perform max_top_k before top_k
                # if all top_k are same, we do not need to filter topk again
                top_k = None
            top_p, min_top_p = __get_topp(top_p)
            min_p = __get_minp(min_p)
            random_seeds = torch.tensor(random_seeds)
        random_offsets = torch.tensor(random_offsets)

        max_num_logprobs = max(num_logprobs)

        session_ctx = [{
            'session_id': seq.session.session_id,
            'seq_id': seq.seq_id,
        } for seq in seqs]

        # repetition ngram
        max_repetition_ngram_size = max(repetition_ngram_sizes)
        if max_repetition_ngram_size == 0:
            repetition_ngram_sizes = None
            repetition_ngram_thresholds = None
        else:
            repetition_ngram_sizes = torch.tensor(repetition_ngram_sizes)
            repetition_ngram_thresholds = torch.tensor(repetition_ngram_thresholds)
            repetition_ngram_same_n = (repetition_ngram_sizes == max_repetition_ngram_size).all().item()
            if repetition_ngram_same_n:
                repetition_ngram_sizes = None

        sampling_input = SamplingInputs(
            temperature=temperature,
            bad_words=bad_words,
            bad_mask=bad_mask,
            stop_words=stop_words,
            stop_mask=stop_mask,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            random_seeds=random_seeds,
            random_offsets=random_offsets,
            response_formats=tuple(response_formats),
            max_top_k=max_top_k,
            min_top_p=min_top_p,
            logits_processors=logits_processors,
            max_num_logprobs=max_num_logprobs,
            batch_size=batch_size,
            session_ctx=session_ctx,
            session_to_cleanup=session_to_cleanup,
            repetition_ngram_size=repetition_ngram_sizes,
            repetition_ngram_threshold=repetition_ngram_thresholds,
            max_repetition_ngram_size=max_repetition_ngram_size,
        )

        pad_token_id = self.pad_token_id
        sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
        sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input)
        sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
        return sampling_input

    def on_session_end(self, session_id: int):
        """Queue ``session_id`` to be reported by the next sampling inputs."""
        self.session_to_cleanup.append(session_id)

    def merge_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        other: 'SamplingInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Concatenate two sampling deltas along the batch dimension.

        ``all_ids`` rows stay right-aligned (left-padded with
        ``pad_token_id``). This matches the layout built by ``_gather_all_ids``
        and extended on the right by ``step_sampling_delta``. A side without
        ``all_ids`` contributes rows made entirely of padding.
        """
        num_ignore_eos = torch.cat([sampling_delta.num_ignore_eos, other.num_ignore_eos], 0)

        offsets0 = sampling_delta.random_offsets
        offsets1 = other.random_offsets
        if offsets0 is None and offsets1 is None:
            random_offsets = None
        else:
            random_offsets = torch.cat([offsets0, offsets1], 0)

        batch_size = num_ignore_eos.size(0)
        all_ids0 = sampling_delta.all_ids
        all_ids1 = other.all_ids
        if all_ids0 is None and all_ids1 is None:
            all_ids = None
        else:
            max_len0 = 0 if all_ids0 is None else all_ids0.size(1)
            max_len1 = 0 if all_ids1 is None else all_ids1.size(1)
            max_len = max(max_len0, max_len1)
            all_ids = torch.full((batch_size, max_len),
                                 self.pad_token_id,
                                 dtype=torch.int64,
                                 device=num_ignore_eos.device)
            if all_ids0 is not None:
                bs0 = all_ids0.size(0)
                # right-align: shorter histories get their padding in front
                all_ids[:bs0, max_len - max_len0:] = all_ids0
            if all_ids1 is not None:
                bs1 = all_ids1.size(0)
                # explicit start row: ``-bs1`` would select every row when bs1 == 0
                all_ids[batch_size - bs1:, max_len - max_len1:] = all_ids1

        return SamplingInputsDelta(
            num_ignore_eos=num_ignore_eos,
            random_offsets=random_offsets,
            all_ids=all_ids,
        )

    def step_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        next_token_ids: torch.Tensor,
        **kwargs,
    ) -> 'SamplingInputsDelta':
        """Advance the delta by one decoding step (mutates and returns it)."""
        sampling_delta.num_ignore_eos = sampling_delta.num_ignore_eos - 1
        if sampling_delta.random_offsets is not None:
            # random offset is used to generate random numbers for multinomial sampling
            # so we need to increase it by 1 at each step
            sampling_delta.random_offsets += 1

        all_ids = sampling_delta.all_ids
        if all_ids is not None:
            sampling_delta.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)

        return sampling_delta

    def update_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        delta: 'ModelInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Keep only the rows of ``sampling_delta`` selected by ``delta.indices``."""
        indices = delta.indices
        num_ignore_eos = sampling_delta.num_ignore_eos[indices]
        if sampling_delta.random_offsets is not None:
            random_offsets = sampling_delta.random_offsets[indices]
        else:
            random_offsets = None
        all_ids = sampling_delta.all_ids
        if all_ids is not None:
            all_ids = all_ids[indices]
        return SamplingInputsDelta(
            num_ignore_eos=num_ignore_eos,
            random_offsets=random_offsets,
            all_ids=all_ids,
        )


================================================
FILE: lmdeploy/pytorch/strategies/ar/sequence.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from torch import Tensor

from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,
                                       SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray)
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..base.sequence import SequenceStrategy

SeqList = List[SchedulerSequence]


@dataclass
class SchedulerSequenceDefault(SchedulerSequence):
    """Scheduler sequence used by the plain autoregressive strategy.

    Overrides how new token ids are recorded (``update_token_ids``) and how the
    history / pending-token split is reset (``set_step``). The counters it
    touches (``_num_history_ids``, ``_num_token_ids``, ``num_new_tokens``, ...)
    are defined on ``SchedulerSequence``.
    """

    def update_token_ids(self,
                         token_ids: Tensor,
                         multimodals: MultiModalInputs = None,
                         embeddings: List[InputEmbeddings] = None,
                         model_meta: Dict[str, Any] = None,
                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
                         routed_experts: np.ndarray = None,
                         **kwargs):
        """Update token ids, old token ids will be added to history.

        Args:
            token_ids: new token ids; converted with ``_to_ndarray``.
            multimodals: forwarded to ``_update_multimodals``.
            embeddings: forwarded to ``_update_embeddings``.
            model_meta: when not ``None``, replaces ``self.model_meta``.
            mode: ``UpdateTokenMode.INPUTS`` adds ``token_ids`` to the pending
                token ids and resets ``num_new_tokens``. Any other mode (used
                for PREFILL/DECODE outputs by ``ARSequenceStrategy``) first
                moves the pending ids into history, then makes ``token_ids``
                the new pending ids and counts them as newly generated.
            routed_experts: passed to ``append_routed_experts``.
            **kwargs: ignored.
        """
        # update history image nums
        self._update_embeddings(embeddings)

        # update multimodals
        self._update_multimodals(multimodals)

        token_ids = _to_ndarray(token_ids)

        num_valid = len(token_ids)
        # record cached expert ids
        self.append_routed_experts(routed_experts)

        if mode == UpdateTokenMode.INPUTS:
            self.arrive_time = time.perf_counter()
            # evaluated before _num_token_ids grows; presumably num_all_ids is
            # history + pending ids, making this the index where generated
            # tokens start -- confirm against SchedulerSequence.num_all_ids
            self.output_start_pos = self.num_all_ids + len(token_ids)
            self._num_token_ids += num_valid
            self.num_new_tokens = 0
        else:
            # the previous pending ids become history; the new ids are pending
            self._num_history_ids += self._num_token_ids
            num_token_ids = num_valid
            self._num_token_ids = num_token_ids
            self.num_new_tokens += num_token_ids

        # the raw ids are appended to history_cache in every mode
        self.history_cache.append(token_ids)

        if model_meta is not None:
            self.model_meta = model_meta

    def set_step(self, step: int):
        """Set step.

        Treat the first ``step`` ids as history and the rest as pending token
        ids. When history embeddings exist, ``step`` may first be lowered (never
        raised) by ``history_embeddings.get_step``. Also clamps
        ``num_ignored_history`` to the final step, clears ``model_meta`` and,
        when routed experts are returned, truncates them to at most ``step``
        entries.
        """
        # captured before any counter changes; sizes the pending part below
        num_all_ids = self.num_all_ids
        # update step for vlm
        if len(self.history_embeddings) > 0:
            # presumably moves the step back so an embedding is not split --
            # confirm in history_embeddings.get_step
            new_step, self._num_history_images, self._num_images = \
                self.history_embeddings.get_step(step)
            assert 0 <= new_step <= step
            step = new_step
        self._num_history_ids = step
        self._num_token_ids = num_all_ids - step
        self.num_ignored_history = min(step, self.num_ignored_history)

        self.model_meta = None

        if self.return_routed_experts:
            # chunk long context might not have all routed experts
            if len(self.all_routed_experts) > step:
                self.all_routed_experts.resize(step)


class ARSequenceStrategy(SequenceStrategy):
    """Sequence strategy for plain autoregressive decoding."""

    def make_sequence(self,
                      seq_id: int,
                      session: 'SchedulerSession',
                      sampling_param: 'SamplingParam' = None,
                      adapter_name: str = None,
                      migration_request: Optional[MigrationRequest] = None,
                      resp_cache: bool = False,
                      preserve_cache: bool = False) -> 'SchedulerSequence':
        """Create a ``SchedulerSequenceDefault`` with the given attributes."""
        fields = dict(
            seq_id=seq_id,
            session=session,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
            migration_request=migration_request,
            resp_cache=resp_cache,
            preserve_cache=preserve_cache,
        )
        return SchedulerSequenceDefault(**fields)

    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',
                       delta: 'ModelInputsDelta') -> None:
        """Append the sampled tokens to every running sequence and finish stopped ones.

        Sequence lengths and the prefill/decode mode are read from
        ``model_inputs`` when given, otherwise from ``delta``.
        """
        stop_flags = batched_outputs.stopped.tolist()
        metas = batched_outputs.model_metas
        if metas is None:
            metas = [None] * len(running)
        tokens = batched_outputs.next_token_ids.numpy()

        source = delta if model_inputs is None else model_inputs
        num_tokens = source.seq_length.tolist()
        update_mode = UpdateTokenMode.DECODE if source.is_decoding else UpdateTokenMode.PREFILL

        if batched_outputs.all_routed_experts is None:
            experts_per_seq = [None] * len(num_tokens)
        else:
            chunks = batched_outputs.all_routed_experts.split(num_tokens, dim=0)
            experts_per_seq = [chunk.numpy() for chunk in chunks]

        for token, msg, stop, model_meta, routed_experts in zip(tokens, running, stop_flags, metas, experts_per_seq):
            if msg.status != MessageStatus.RUNNING:
                continue
            msg.update_token_ids(token, model_meta=model_meta, mode=update_mode, routed_experts=routed_experts)
            if stop:
                msg.state.finish()


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from lmdeploy.pytorch.config import ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy

if TYPE_CHECKING:
    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy
    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

from ..base import StrategyFactoryBase


class ARSpecStrategyFactory(StrategyFactoryBase):
    """Factory of strategies for autoregressive decoding with speculative tokens.

    Each builder imports its strategy lazily. Where needed, the number of
    speculative tokens is taken from ``specdecode_config``.
    """

    def __init__(self, model_config: ModelConfig, specdecode_config: SpecDecodeConfig):
        """Store the configs; ``pad_token_id`` falls back to 0 when bos is falsy."""
        self.model_config = model_config
        self.specdecode_config = specdecode_config
        self.pad_token_id = model_config.bos_token_id or 0

    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
        """Build the speculative cudagraph strategy."""
        from .cudagraph import ARSpecCudagraphStrategy
        num_spec = self.specdecode_config.num_speculative_tokens
        return ARSpecCudagraphStrategy(num_spec)

    def build_sampling_strategy(self) -> 'SamplingStrategy':
        """Build the speculative sampling strategy (pads with bos, or 0 if unset)."""
        from .sampling import ARSpecSamplingStrategy
        bos = self.model_config.bos_token_id
        return ARSpecSamplingStrategy(0 if bos is None else bos)

    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
        """Build the speculative model inputs strategy."""
        from .model_inputs import ARSpecModelInputsStrategy
        num_spec = self.specdecode_config.num_speculative_tokens
        return ARSpecModelInputsStrategy(num_spec)

    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
        """Build the speculative model agent strategy."""
        from .model_agent import ARSpecModelAgentStrategy
        num_spec = self.specdecode_config.num_speculative_tokens
        return ARSpecModelAgentStrategy(num_spec)

    def build_engine_strategy(self, cache_config: 'CacheConfig',
                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
        """Build the speculative engine strategy."""
        from .engine import ARSpecEngineStrategy
        num_spec = self.specdecode_config.num_speculative_tokens
        return ARSpecEngineStrategy(cache_config=cache_config,
                                    scheduler_config=scheduler_config,
                                    num_spec_tokens=num_spec)

    def build_sequence_strategy(self) -> SequenceStrategy:
        """Build the speculative sequence strategy."""
        from .sequence import ARSpecSequenceStrategy
        return ARSpecSequenceStrategy()


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/cudagraph.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..base.cudagraph import CudagraphStrategy


class ARSpecCudagraphStrategy(CudagraphStrategy):
    """Cudagraph strategy for AR models with speculative decoding.

    In speculative decoding each sequence contributes ``num_spec_tokens + 1``
    tokens per step, so captured graphs must be sized accordingly.
    """

    def __init__(self, num_spec_tokens: int):
        super().__init__()
        self.num_spec_tokens = num_spec_tokens

    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:
        """Get max tokens.

        Args:
            batch_size: (padded) batch size of the graph to run.
            origin_batch_size: batch size before padding.
            num_tokens: number of input tokens in the current step.

        Returns:
            ``batch_size`` when there is exactly one token per sequence,
            otherwise ``batch_size * (num_spec_tokens + 1)``.
        """
        # one token per sequence (e.g. no draft tokens attached)
        if num_tokens == origin_batch_size:
            return batch_size

        tokens_per_seq = self.num_spec_tokens + 1
        # the check is divisibility by tokens_per_seq, so the message must say so
        assert num_tokens % tokens_per_seq == 0, (
            'The input_ids length must be divisible by num_spec_tokens + 1, '
            f'got num_tokens={num_tokens}, num_spec_tokens={self.num_spec_tokens}.')
        return batch_size * tokens_per_seq


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

from ..base.engine import EngineStrategy


class ARSpecEngineStrategy(EngineStrategy):
    """Engine strategy for AR models running with speculative decoding."""

    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, num_spec_tokens: int) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.num_spec_tokens = num_spec_tokens

    def get_prealloc_size(self, is_decoding: bool):
        """Return the number of token slots to preallocate.

        In prefill, only room for the draft tokens is reserved. In decoding,
        room for ``prefill_interval`` loops, each feeding ``1 + num_spec_tokens``
        tokens, is reserved.
        """
        if not is_decoding:
            return self.num_spec_tokens
        tokens_per_loop = 1 + self.num_spec_tokens
        return self.scheduler_config.prefill_interval * tokens_per_loop

    def get_num_loops(self, is_decoding: bool) -> int:
        """Return how many steps run before rescheduling (1 in prefill)."""
        if not is_decoding:
            return 1
        return self.scheduler_config.prefill_interval

    def get_num_decode_tokens(self) -> int:
        """Return tokens fed per sequence per decoding step: one sampled token plus the drafts."""
        return 1 + self.num_spec_tokens


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/model_agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.profiler import record_function

from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..ar.model_agent import ARStoppingCriteria, get_model_inputs_next_decoding
from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy

SeqList = List[SchedulerSequence]


@dataclass
class ARSpecExtraInputs(ExtraInputs):
    """Extra inputs exchanged between target and draft model in AR speculative decoding."""
    # inputs consumed by the draft model
    target_logits: torch.Tensor = None
    target_hidden_states: torch.Tensor = None
    target_position_ids: torch.Tensor = None
    next_token_ids: torch.LongTensor = None
    last_token_indices: torch.LongTensor = None

    # outputs produced by the draft model
    output_draft_token_ids: torch.Tensor = None
    num_rejected_tokens: torch.Tensor = None
    output_token_ids: torch.Tensor = None

    def __repr__(self):
        parts = (
            f'next_token_ids={self.next_token_ids}',
            f'output_draft_token_ids={self.output_draft_token_ids}',
            f'last_token_indices={self.last_token_indices}',
            f'num_rejected_tokens={self.num_rejected_tokens}',
            f'output_token_ids={self.output_token_ids}',
        )
        return 'ARSpecExtraInputs(' + ', '.join(parts) + ')'

    def broadcast(self, src: int, group, async_op=False):
        """Broadcast draft token ids and rejection counts from ``src``.

        Returns the handle (or result) of the last broadcast issued.
        """
        handle = None
        for tensor in (self.output_draft_token_ids, self.num_rejected_tokens):
            handle = dist.broadcast(tensor, src=src, group=group, async_op=async_op)
        return handle

    def merge(self, other: 'ARSpecExtraInputs'):
        """Merge extra inputs; only ``output_token_ids`` is carried over (concatenated on dim 0)."""
        merged_ids = torch.cat((self.output_token_ids, other.output_token_ids), dim=0)
        return ARSpecExtraInputs(output_token_ids=merged_ids)


@dataclass
class ARSpecExtraOutputs(ExtraOutputs):
    """Extra outputs of AR speculative decoding."""
    # draft tokens handed to the sequence; only emitted on the last loop step
    draft_token_ids: torch.Tensor = None

    def __repr__(self):
        return 'ARSpecExtraOutputs(draft_token_ids={})'.format(self.draft_token_ids)


@dataclass
class ARSpecStoppingCriteria(ARStoppingCriteria):
    """Stopping criteria for AR speculative decoding.

    A step may append several tokens per sequence, so the stop decision is
    made per token position: the first position where either the
    ``max_new_tokens`` budget is exhausted or a stop word appears.
    """
    # per-sequence number of tokens that may still be appended
    num_appendable_ids: torch.Tensor

    def clone(self):
        """clone."""
        return ARSpecStoppingCriteria(num_appendable_ids=self.num_appendable_ids)

    def merge(self, other: 'ARSpecStoppingCriteria'):
        """Merge two stopping criteria by concatenating along the batch dim."""
        new_num_appendable = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0)
        return ARSpecStoppingCriteria(num_appendable_ids=new_num_appendable)

    def update(self, delta: ModelInputsDelta):
        """Keep only the sequences selected by ``delta.indices``."""
        indices = delta.indices
        new_num_appendable = self.num_appendable_ids[indices]
        return ARSpecStoppingCriteria(num_appendable_ids=new_num_appendable)

    @record_function('stopping_criteria')
    def step(self,
             next_token_ids: torch.Tensor,
             stop_words: torch.Tensor,
             inputs: Optional[ModelInputs] = None,
             extra_inputs: Optional[ARSpecExtraInputs] = None):
        """Check whether to stop generation.

        Args:
            next_token_ids: not used; candidate tokens are read from
                ``extra_inputs.output_token_ids`` (-1 marks invalid slots).
            stop_words: stop token ids. A 2-D tensor whose first dim equals
                the batch size is matched per sequence; any other shape is
                treated as one list shared by all sequences.
            inputs: unused.
            extra_inputs: must provide ``output_token_ids``.

        Returns:
            ``(stopped, stop_pos, criteria)``: ``stopped`` is a per-sequence
            bool tensor, ``stop_pos`` the index of the first stopping position
            (-1 when not stopped), and ``criteria`` the updated criteria.
        """
        token_ids = extra_inputs.output_token_ids

        if token_ids.ndim == 1:
            token_ids = token_ids.unsqueeze(-1)
        valid_tokens = token_ids > -1
        # True from the position where the appendable budget is used up
        mask = (self.num_appendable_ids.unsqueeze(-1) - valid_tokens.cumsum(dim=-1)) <= 0
        if stop_words is not None:
            if stop_words.dim() == 2 and stop_words.size(0) == token_ids.size(0):
                # per-sequence stop words: a row only matches its own list
                stop_words_rsp = stop_words.unsqueeze(1)
            else:
                stop_words_rsp = stop_words.reshape(1, 1, -1)
            stop_mask = (token_ids.unsqueeze(-1) == stop_words_rsp).any(-1)
            # invalid (-1) slots must never match (e.g. -1 stop-word padding)
            stop_mask = stop_mask & valid_tokens
            # stop on budget exhaustion OR a stop word; xor would cancel a stop
            # word that lands exactly on the last appendable position
            mask = mask | stop_mask
        # find the index of first `1`,  if not found, would be 0
        index = torch.argmax(mask.int(), dim=-1, keepdim=True)
        # update index of 0 to -1 if not found
        stop_pos = torch.where(index == 0, mask[:, 0:1].int() - 1, index).ravel()
        stopped = stop_pos != -1
        num_valid_tokens = valid_tokens.sum(dim=-1)
        num_appendable_ids = self.num_appendable_ids - num_valid_tokens
        one_ids = torch.clamp_max(num_appendable_ids, 0)
        num_appendable_ids = torch.where(stopped, one_ids, num_appendable_ids)
        return stopped, stop_pos, ARSpecStoppingCriteria(num_appendable_ids=num_appendable_ids)


class ARSpecModelAgentStrategy(ModelAgentStrategy):
    """Model agent strategy for AR models with speculative decoding.

    In decoding every sequence feeds ``num_spec_tokens + 1`` tokens per step:
    the last sampled token followed by ``num_spec_tokens`` draft tokens.
    """

    def __init__(self, num_spec_tokens: int):
        self.num_spec_tokens = num_spec_tokens

    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
        """Select the output row of the last token of every sequence.

        Args:
            inputs: per-token outputs stacked along dim 0.
            seq_length: number of tokens of each sequence in ``inputs``.
        """
        # batch size == 1
        if len(seq_length) == 1:
            return inputs[-1:]

        # already one row per sequence
        if len(seq_length) == inputs.size(0):
            return inputs
        last_idx = seq_length.cumsum(-1) - 1
        return inputs[last_idx]

    def slice_extra_inputs(self, extra_inputs: ARSpecExtraInputs, model_inputs: ModelInputs,
                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ARSpecExtraInputs:
        """Build fresh extra inputs for the draft model from the target model outputs.

        The incoming ``extra_inputs`` is not reused. In decoding, the target
        logits are regrouped per sequence and the last position of each
        sequence is dropped.
        """
        extra_inputs = ARSpecExtraInputs()
        extra_inputs.target_hidden_states = model_outputs.get('hidden_states')
        extra_inputs.target_position_ids = model_outputs.get('position_ids', None)
        if model_inputs.is_decoding:
            batch_size = model_inputs.seq_length.size(0)
            # assumes every sequence contributes the same number of logits rows
            logits = model_outputs['logits'][0]
            extra_inputs.target_logits = logits.unflatten(0, (batch_size, -1))[:, :-1]
        return extra_inputs

    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, **kwargs):
        """Advance sampling inputs by one step.

        Decrements ``num_ignore_eos`` and, when history ids are tracked,
        appends ``next_token_ids`` (one id per sequence) to ``all_ids``.
        """
        sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1

        all_ids = sampling_inputs.all_ids
        if all_ids is not None:
            sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)

        return sampling_inputs

    def make_stopping_criteria(self, seqs: SeqList) -> ARSpecStoppingCriteria:
        """Create stopping criteria from each sequence's remaining token budget."""
        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
        num_appendable = torch.tensor(num_appendable)
        return ARSpecStoppingCriteria(num_appendable_ids=num_appendable)

    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:
        """Create empty extra inputs."""
        return ARSpecExtraInputs()

    def update_extra_inputs(self, extra_inputs: ARSpecExtraInputs, delta: 'ModelInputsDelta') -> ARSpecExtraInputs:
        """Keep only ``output_token_ids`` of the sequences selected by ``delta.indices``."""
        indices = delta.indices
        output_token_ids = extra_inputs.output_token_ids[indices]
        return ARSpecExtraInputs(output_token_ids=output_token_ids)

    def make_extra_outputs(self, extra_inputs: ARSpecExtraInputs) -> ARSpecExtraOutputs:
        """Expose the draft token ids as extra outputs."""
        output = ARSpecExtraOutputs()
        output.draft_token_ids = extra_inputs.output_draft_token_ids
        return output

    def update_prefill_for_next_step(
        self,
        model_inputs: 'ModelInputs',
        extra_inputs: ARSpecExtraInputs,
        next_token_ids: torch.Tensor,
        model_metas: Any,
        extra_outputs: ARSpecExtraOutputs,
    ) -> Tuple['ModelInputs', ARSpecExtraInputs]:
        """Turn a finished prefill into the first decoding step.

        Each sequence's next input is its sampled token followed by its draft
        tokens, flattened into a single row.
        """
        next_token_ids = next_token_ids[:, None]
        next_token_ids = torch.cat([next_token_ids, extra_outputs.draft_token_ids], dim=-1)
        max_q_seqlen = next_token_ids.size(-1)
        next_token_ids = next_token_ids.flatten()[None, :]
        inputs = get_model_inputs_next_decoding(model_inputs,
                                                next_token_ids,
                                                max_q_seqlen=max_q_seqlen,
                                                model_metas=model_metas)
        extra_inputs = ARSpecExtraInputs(output_token_ids=extra_outputs.draft_token_ids)
        return inputs, extra_inputs

    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
                                      extra_inputs: ARSpecExtraInputs, extra_outputs: ARSpecExtraOutputs):
        """Advance decoding inputs in place for the next step.

        The next input per sequence is ``[next_token, *draft_tokens]``; rejected
        draft tokens are subtracted from ``seq_length`` before stepping.
        """
        model_inputs.model_metas = model_metas
        batch_size = model_inputs.seq_length.size(0)

        # update extra inputs
        extra_inputs.output_token_ids = extra_outputs.draft_token_ids

        # update inputs
        # NOTE(review): assumes ModelInputs.step advances history by step_seqlens — confirm
        step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens
        input_ids = next_token_ids.new_empty((batch_size, self.num_spec_tokens + 1))
        input_ids[:, 0] = next_token_ids
        input_ids[:, 1:] = extra_inputs.output_draft_token_ids
        input_ids = input_ids.flatten()[None, :]
        model_inputs.step(input_ids, step_seqlens)
        return model_inputs, extra_inputs

    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
                      extra_inputs: ARSpecExtraInputs):
        """Post sampling (no-op)."""
        return next_token_ids, extra_inputs

    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):
        """Make zero-filled next tokens, draft tokens and rejection counts to receive a broadcast."""
        with torch.inference_mode():
            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
            extra_inputs.output_draft_token_ids = inputs.input_ids.new_zeros((logits.size(0), self.num_spec_tokens))
            extra_inputs.num_rejected_tokens = inputs.input_ids.new_zeros(logits.size(0))
        return next_token_ids, extra_inputs

    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ARSpecExtraInputs,
                             dist_ctx: DistContext):
        """Broadcast next token ids and extra inputs from rank 0 of the attention TP group.

        Only the last broadcast handle is waited on.
        """
        # NOTE(review): relies on collectives on the same group completing in issue order — confirm
        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group
        rank = dist.get_global_rank(tp_gpu_group, 0)
        dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)
        handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True)
        yield
        handle.wait()


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/model_inputs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.profiler import record_function

from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..ar.model_inputs import merge_model_inputs
from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs


class ARSpecModelInputsStrategy(ModelInputsStrategy):
    """Model inputs strategy for AR models with speculative decoding.

    In decoding every sequence holds ``num_spec_tokens + 1`` input tokens.
    """

    def __init__(self, num_spec_tokens: int):
        self.num_spec_tokens = num_spec_tokens

    def make_dummy(
        self,
        batch_size: int,
        is_decoding: bool,
        device: str = 'cpu',
        dummy_block_id: int = 0,
        vocab_size: int = 1,
        max_q_seqlen: int = 1,
        target_hidden_size: int = None,
        target_dtype: torch.dtype = torch.bfloat16,
    ) -> ModelInputs:
        """Create dummy model inputs.

        When ``target_hidden_size`` is given, random target hidden states of
        shape ``(1, batch_size * max_q_seqlen, target_hidden_size)`` are attached.
        """
        inputs = make_dummy_inputs(batch_size,
                                   max_q_seqlen=max_q_seqlen,
                                   is_decoding=is_decoding,
                                   device=device,
                                   dummy_block_id=dummy_block_id,
                                   vocab_size=vocab_size)
        if target_hidden_size is not None:
            inputs.target_hidden_states = torch.randn((1, batch_size * max_q_seqlen, target_hidden_size),
                                                      dtype=target_dtype,
                                                      device=device)
        return inputs

    @record_function('ModelInputs.merge')
    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
        """Merge model inputs."""
        return merge_model_inputs(inputs, other)

    @record_function('ModelInputs.update_inputs')
    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:
        """Select the sequences in ``delta.indices`` from decoding inputs.

        Fields supplied by ``delta`` override the sliced values from ``inputs``.
        """
        assert inputs.is_decoding, 'Only support update_delta in decoding.'
        indices = delta.indices
        indice_cpu = delta.indice_cpu
        block_offsets = delta.block_offsets
        max_q_seqlen = delta.max_q_seqlen
        max_kv_seqlen = delta.max_kv_seqlen
        sum_kv_seqlen = delta.sum_kv_seqlen
        num_ignored_history = delta.num_ignored_history

        # required inputs
        # input ids are a flat row of num_spec_tokens + 1 tokens per sequence;
        # group them per sequence, pick the selected ones, then flatten again
        tokens_per_seq = self.num_spec_tokens + 1
        grouped_ids = inputs.input_ids.reshape(1, -1, tokens_per_seq)
        input_ids = grouped_ids[:, indices].reshape(1, -1)
        seq_length = inputs.seq_length[indices]
        history_lengths = inputs.history_lengths[indices]
        if block_offsets is None:
            block_offsets = inputs.block_offsets[indices]
        if num_ignored_history is None:
            num_ignored_history = inputs.num_ignored_history[indices]
        max_q_seqlen = max_q_seqlen or inputs.max_q_seqlen
        max_kv_seqlen = max_kv_seqlen or inputs.max_kv_seqlen
        sum_kv_seqlen = sum_kv_seqlen or inputs.sum_kv_seqlen

        # lora adapter ids
        local_adapter_ids = inputs.local_adapter_ids
        if local_adapter_ids is not None:
            local_adapter_ids = local_adapter_ids[indices]

        # model metas for vl models
        model_metas = inputs.model_metas
        if model_metas is not None and indice_cpu is not None:
            model_metas = [model_metas[i] for i in indice_cpu]

        # for ssm
        state_offsets = inputs.state_offsets
        if state_offsets is not None:
            state_offsets = state_offsets[indices]

        # return new inputs
        return ModelInputs(
            input_ids=input_ids,
            seq_length=seq_length,
            history_lengths=history_lengths,
            block_offsets=block_offsets,
            is_decoding=inputs.is_decoding,
            num_ignored_history=num_ignored_history,
            max_q_seqlen=max_q_seqlen,
            max_kv_seqlen=max_kv_seqlen,
            sum_kv_seqlen=sum_kv_seqlen,
            local_adapter_ids=local_adapter_ids,
            model_metas=model_metas,
            state_offsets=state_offsets,
        )


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..ar.sampling import ARSamplingStrategy


class ARSpecSamplingStrategy(ARSamplingStrategy):
    """Sampling strategy for AR models with speculative decoding; reuses the plain AR behavior unchanged."""


================================================
FILE: lmdeploy/pytorch/strategies/ar_spec/sequence.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from torch import Tensor

from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,
                                       SchedulerSession, UpdateTokenMode, _to_ndarray)
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..ar.sequence import ARSequenceStrategy, SchedulerSequenceDefault

SeqList = List['SchedulerSequenceARSpec']


@dataclass
class SchedulerSequenceARSpec(SchedulerSequenceDefault):
    """Scheduler sequence for AR speculative decoding.

    ``history_cache`` may end with draft tokens that are not yet verified.
    ``num_valid_ids`` is the length of the ``history_cache`` prefix holding
    accepted tokens; ``num_spec_ids`` counts the draft tokens appended by the
    latest update.
    """

    def __post_init__(self):
        """Post init: set up the speculative-decoding counters."""
        super().__post_init__()
        # number of draft tokens appended by the latest update
        self._num_spec_ids: int = 0
        # number of valid tokens produced by the latest update
        self._num_new_valid: int = 0
        # length of the history_cache prefix made of accepted tokens
        self._num_valid_ids: int = len(self.history_cache)
        self._strategy: ARSpecSequenceStrategy = self._seq_meta.strategy

    @property
    def num_valid_ids(self):
        """Length of the accepted-token prefix of ``history_cache``."""
        return self._num_valid_ids

    @property
    def num_spec_ids(self):
        """Number of draft tokens appended by the latest update."""
        return self._num_spec_ids

    @property
    def generated_ids(self) -> np.ndarray:
        """The last ``num_new_tokens`` accepted tokens (excludes trailing drafts)."""
        end = self.num_valid_ids
        start = end - self.num_new_tokens
        return self.history_cache[start:end]

    def set_stop_pos(self, pos: int):
        """Truncate the latest update at stop position ``pos``.

        ``pos`` indexes the valid tokens of the latest update; valid tokens
        after it and any trailing draft tokens are discarded.
        """
        # number of valid tokens of the latest update that follow ``pos``
        val = self._num_new_valid - pos - 1
        self._num_valid_ids -= val
        self.num_new_tokens -= val
        # NOTE(review): current step is reset to a single token — confirm intent
        self._num_token_ids = 1
        self._num_history_ids -= val

        self._num_spec_ids = 0
        self._num_new_valid = 0
        # drop everything past the accepted prefix (including draft tokens)
        self.history_cache.resize(self.num_valid_ids)

    def _update_token_ids_inputs(self, token_ids: np.ndarray):
        """Append new input tokens and reset the generation counters."""
        num_tokens = len(token_ids)
        self.output_start_pos = self.num_valid_ids + num_tokens
        self._num_valid_ids = self.num_history_ids + num_tokens
        self._num_token_ids = num_tokens
        self.num_new_tokens = 0
        self._num_spec_ids = 0
        self._num_new_valid = 0
        self.history_cache.append(token_ids)

    def _update_token_ids_prefill(self, token_ids: np.ndarray, draft_token_ids: np.ndarray):
        """Update token ids after a prefill step.

        ``token_ids`` (all counted as valid) and ``draft_token_ids`` are
        appended together; only ``token_ids`` count towards ``num_valid_ids``.
        ``draft_token_ids`` must not be None here since ``len`` is taken.
        """
        num_valid = len(token_ids)
        self._num_spec_ids = len(draft_token_ids)
        token_ids = np.concatenate([token_ids, draft_token_ids])
        num_tokens = len(token_ids)
        # the previous step's tokens move into history
        self._num_history_ids += self._num_token_ids
        self._num_token_ids = num_tokens
        self.num_new_tokens += num_valid
        self._num_new_valid = num_valid
        self._num_valid_ids = self.num_history_ids + num_valid
        self.history_cache.append(token_ids)

    def _update_token_ids_decode(self, token_ids: np.ndarray, draft_token_ids: np.ndarray = None):
        """Update token ids after a decoding step.

        Entries of ``token_ids`` equal to -1 are treated as invalid (rejected)
        and skipped. The history is rewound to just before the last valid token,
        which is then re-appended together with any new draft tokens.
        """
        valid_ids = token_ids[token_ids > -1]
        num_valid = len(valid_ids)
        self.num_new_tokens = self.num_new_tokens + num_valid

        self._num_new_valid = num_valid
        self._num_valid_ids += num_valid
        self._num_history_ids = self.num_valid_ids - 1

        # last step has spec ids
        # presumably the accepted drafts already sit in history_cache, so only
        # the last valid token needs to be (re)appended — TODO confirm
        if self.num_spec_ids > 0:
            token_ids = valid_ids[-1:]
        else:
            token_ids = valid_ids

        num_tokens = len(token_ids)

        if draft_token_ids is not None:
            num_tokens = 1 + len(draft_token_ids)
            token_ids = np.concatenate([token_ids, draft_token_ids])
            self._num_spec_ids = len(draft_token_ids)
        else:
            self._num_spec_ids = 0

        self._num_token_ids = num_tokens
        # discard stale (e.g. rejected draft) entries past the history length
        if self.num_history_ids < len(self.history_cache):
            self.history_cache.resize(self.num_history_ids)
        self.history_cache.append(token_ids)

    def update_token_ids(self,
                         token_ids: Tensor,
                         multimodals: MultiModalInputs = None,
                         embeddings: List[InputEmbeddings] = None,
                         model_meta: Dict[str, Any] = None,
                         draft_token_ids: Tensor = None,
                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
                         **kwargs):
        """Update token ids, old token ids will be added to history.

        Dispatches on ``mode``: INPUTS appends new input tokens, PREFILL and
        DECODE (any other mode) append sampled plus draft tokens. Also records
        embeddings, multimodal inputs, arrival time and ``model_meta``.
        """
        # update history image nums
        self._update_embeddings(embeddings)

        # update multimodals
        self._update_multimodals(multimodals)

        self.arrive_time = time.perf_counter()

        token_ids: np.ndarray = _to_ndarray(token_ids)
        if draft_token_ids is not None:
            draft_token_ids = _to_ndarray(draft_token_ids)
        if mode == UpdateTokenMode.INPUTS:
            self._update_token_ids_inputs(token_ids)
        elif mode == UpdateTokenMode.PREFILL:
            self._update_token_ids_prefill(token_ids, draft_token_ids)
        else:
            self._update_token_ids_decode(token_ids, draft_token_ids)
        if model_meta is not None:
            self.model_meta = model_meta


class ARSpecSequenceStrategy(ARSequenceStrategy):
    """Sequence strategy for AR models with speculative decoding."""

    def make_sequence(self,
                      seq_id: int,
                      session: 'SchedulerSession',
                      sampling_param: 'SamplingParam' = None,
                      adapter_name: str = None,
                      migration_request: Optional[MigrationRequest] = None,
                      resp_cache: bool = False,
                      preserve_cache: bool = False) -> 'SchedulerSequenceARSpec':
        """Create a ``SchedulerSequenceARSpec`` from the given arguments."""
        seq_kwargs = dict(
            seq_id=seq_id,
            session=session,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
            migration_request=migration_request,
            resp_cache=resp_cache,
            preserve_cache=preserve_cache,
        )
        return SchedulerSequenceARSpec(**seq_kwargs)

    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',
                       delta: 'ModelInputsDelta', **kwargs) -> None:
        """Apply a step's sampled and draft tokens to the running sequences.

        Sequences that are not RUNNING are skipped; stopped ones are truncated
        at their stop position and finished.
        """
        batch_size = len(running)
        is_decoding = delta.is_decoding if model_inputs is None else model_inputs.is_decoding
        update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL

        stop_flags = batched_outputs.stopped.tolist()
        metas = batched_outputs.model_metas
        if metas is None:
            metas = [None] * batch_size

        token_rows = batched_outputs.next_token_ids.view(batch_size, -1).numpy()
        extra_outputs = batched_outputs.extra_outputs
        has_drafts = extra_outputs is not None and extra_outputs.draft_token_ids is not None
        drafts = extra_outputs.draft_token_ids.numpy() if has_drafts else [None] * batch_size
        stop_positions = batched_outputs.stop_pos.tolist()

        for i, tokens in enumerate(token_rows):
            seq = running[i]
            should_stop = stop_flags[i]
            meta = metas[i]
            if msg_not_running := (seq.status != MessageStatus.RUNNING):
                continue
            # fill token
            seq.update_token_ids(tokens, draft_token_ids=drafts[i], model_meta=meta, mode=update_mode)
            if should_stop:
                seq.set_stop_pos(stop_positions[i])
                seq.state.finish()


================================================
FILE: lmdeploy/pytorch/strategies/base/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

    from .cudagraph import CudagraphStrategy
    from .engine import EngineStrategy
    from .model_agent import ModelAgentStrategy
    from .model_inputs import ModelInputsStrategy
    from .sampling import SamplingStrategy
    from .sequence import SequenceStrategy


class StrategyFactoryBase(ABC):
    """Abstract factory producing the per-component strategies of a model type."""

    @abstractmethod
    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
        """Return the strategy that sizes cudagraph inputs."""

    @abstractmethod
    def build_sampling_strategy(self) -> 'SamplingStrategy':
        """Return the token sampling strategy."""

    @abstractmethod
    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
        """Return the strategy that builds and updates model inputs."""

    @abstractmethod
    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
        """Return the strategy used by the model agent between steps."""

    @abstractmethod
    def build_engine_strategy(self, cache_config: 'CacheConfig',
                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
        """Return the engine strategy for the given cache and scheduler configs."""

    @abstractmethod
    def build_sequence_strategy(self) -> 'SequenceStrategy':
        """Return the strategy that creates and updates scheduler sequences."""


================================================
FILE: lmdeploy/pytorch/strategies/base/cudagraph.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod


class CudagraphStrategy(ABC):
    """Interface deciding how many tokens a captured cudagraph must hold."""

    @abstractmethod
    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:
        """Return the token capacity needed for ``batch_size`` sequences."""


================================================
FILE: lmdeploy/pytorch/strategies/base/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod


class EngineStrategy(ABC):
    """Interface for engine-level scheduling parameters."""

    @abstractmethod
    def get_prealloc_size(self, is_decoding: bool) -> int:
        """Return how many token slots to preallocate."""

    @abstractmethod
    def get_num_loops(self, is_decoding: bool) -> int:
        """Return how many steps to run per scheduling round."""

    @abstractmethod
    def get_num_decode_tokens(self) -> int:
        """Return how many tokens each sequence feeds per decoding step."""

    def get_num_required_tokens(self) -> int:
        """Return the tokens required per sequence; defaults to ``get_num_decode_tokens()``."""
        num_decode_tokens = self.get_num_decode_tokens()
        return num_decode_tokens


================================================
FILE: lmdeploy/pytorch/strategies/base/model_agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import torch

if TYPE_CHECKING:
    from lmdeploy.pytorch.distributed import DistContext
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs
    from lmdeploy.pytorch.messages import SchedulerSequence
    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta
    SeqList = List[SchedulerSequence]


def to_device(self, device: str, non_blocking: bool = False):
    """Return a new instance of dataclass ``self`` with every tensor field moved to ``device``.

    Non-tensor fields are passed through unchanged.
    """

    def _move(value):
        if isinstance(value, torch.Tensor):
            return value.to(device, non_blocking=non_blocking)
        return value

    moved = {f.name: _move(getattr(self, f.name)) for f in fields(self)}
    return type(self)(**moved)


@dataclass
class ExtraInputs(ABC):
    """Base container for strategy-specific extra model inputs.

    Subclasses add dataclass fields; the defaults here are no-ops.
    """

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with all tensor fields moved to ``device``."""
        return to_device(self, device, non_blocking=non_blocking)

    def broadcast(self, src: int, group, async_op=False):
        """Broadcast extra inputs; the base class has nothing to send and returns None."""
        return None

    def merge(self, other: 'ExtraInputs'):
        """Merge extra inputs; the base class keeps ``self`` unchanged."""
        return self


@dataclass
class ExtraOutputs(ABC):
    """Base container for strategy-specific extra model outputs."""

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with all tensor fields moved to ``device``."""
        return to_device(self, device, non_blocking=non_blocking)

    def to_cpu(self):
        """Return a copy with all tensor fields moved to the CPU (blocking)."""
        return self.to_device('cpu', non_blocking=False)

    def to_numpy(self):
        """Return a copy with tensors converted to numpy arrays.

        bfloat16 tensors are kept as tensors; nested values exposing
        ``to_numpy`` are converted recursively.
        """

        def _convert(value):
            if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
                return value.detach().numpy()
            if hasattr(value, 'to_numpy'):
                return value.to_numpy()
            return value

        return type(self)(**{f.name: _convert(getattr(self, f.name)) for f in fields(self)})

    def to_tensor(self):
        """Return a copy with numpy arrays converted back to tensors (recursively for nested values)."""

        def _convert(value):
            if isinstance(value, np.ndarray):
                return torch.from_numpy(value)
            if hasattr(value, 'to_tensor'):
                return value.to_tensor()
            return value

        return type(self)(**{f.name: _convert(getattr(self, f.name)) for f in fields(self)})


@dataclass
class StoppingCriteria(ABC):
    """Abstract per-batch stopping criteria carried across decoding steps."""

    @abstractmethod
    def clone(self) -> 'StoppingCriteria':
        """Return a copy of these criteria."""

    @abstractmethod
    def merge(self, other: 'StoppingCriteria') -> 'StoppingCriteria':
        """Return criteria covering the sequences of both ``self`` and ``other``."""

    @abstractmethod
    def update(self, delta: 'ModelInputsDelta') -> 'StoppingCriteria':
        """Return criteria restricted to the sequences kept by ``delta``."""

    @abstractmethod
    def step(self,
             token_ids: torch.Tensor,
             stop_words: torch.Tensor,
             inputs: Optional['ModelInputs'] = None,
             extra_inputs: Optional[ExtraInputs] = None):
        """Decide which sequences stop after the newly generated tokens."""

    def to_device(self, device: str, non_blocking: bool = False):
        """Return a copy with all tensor fields moved to ``device``."""
        return to_device(self, device, non_blocking=non_blocking)


class ModelAgentStrategy(ABC):
    """Base class for model agent strategies.

    A strategy holds the per-step bookkeeping that differs between generation schemes:
    slicing model outputs, building stopping criteria and extra inputs, and advancing
    ``ModelInputs`` between forward passes.
    """

    @abstractmethod
    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
        """Slice model outputs (e.g. logits) down to the positions used for sampling.

        Args:
            inputs: Flattened per-token outputs of the batch.
            seq_length: Per-sequence query lengths used to locate each sequence's slice.
        """
        pass

    @abstractmethod
    def slice_extra_inputs(self, extra_inputs: ExtraInputs, model_inputs: 'ModelInputs',
                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ExtraInputs:
        """Slice extra inputs so they line up with the output of ``slice_outputs``."""
        pass

    @abstractmethod
    def make_stopping_criteria(self, seqs: 'SeqList') -> StoppingCriteria:
        """Create stopping criteria for the given sequences."""
        pass

    @abstractmethod
    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:
        """Create extra inputs for the given sequences and model inputs."""
        pass

    def update_extra_inputs(self, extra_inputs: ExtraInputs, delta: 'ModelInputsDelta') -> ExtraInputs:
        """Update extra inputs with a model-inputs delta; the default keeps them unchanged."""
        return extra_inputs

    @abstractmethod
    def make_extra_outputs(self, extra_inputs: ExtraInputs) -> ExtraOutputs:
        """Create extra outputs from extra inputs."""
        pass

    @abstractmethod
    def step_sampling_inputs(
        self,
        sampling_inputs: 'SamplingInputs',
        next_token_ids: torch.Tensor,
        extra_inputs: ExtraInputs,
    ):
        """Advance sampling inputs after a step that produced ``next_token_ids``."""
        pass

    @abstractmethod
    def update_prefill_for_next_step(
        self,
        model_inputs: 'ModelInputs',
        extra_inputs: ExtraInputs,
        next_token_ids: torch.Tensor,
        model_metas: Any,
        extra_outputs: ExtraOutputs,
    ) -> Tuple['ModelInputs', ExtraInputs]:
        """Build the model inputs and extra inputs for the step that follows a prefill."""
        pass

    @abstractmethod
    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
                                      extra_inputs: ExtraInputs,
                                      extra_outputs: ExtraOutputs) -> Tuple['ModelInputs', ExtraInputs]:
        """Advance decoding model inputs and extra inputs to the next step."""
        pass

    @abstractmethod
    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
                      extra_inputs: ExtraInputs):
        """Post-process sampled tokens; returns ``(next_token_ids, extra_inputs)`` in implementations."""
        pass

    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):
        """Make a zero-filled next-token tensor (one entry per logits row) as a broadcast receive buffer."""
        with torch.inference_mode():
            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
        return next_token_ids, extra_inputs

    @abstractmethod
    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: 'DistContext'):
        """Broadcast next token ids and extra inputs; implementations are context managers."""


================================================
FILE: lmdeploy/pytorch/strategies/base/model_inputs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch
from torch.profiler import record_function

from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta


@record_function('make_dummy_input')
def make_dummy_inputs(batch_size: int,
                      max_q_seqlen: int,
                      is_decoding: bool,
                      device: str = 'cpu',
                      dummy_block_id: int = 0,
                      vocab_size: int = 1):
    """Build a dummy ``ModelInputs`` batch.

    The batch has ``batch_size`` sequences of ``max_q_seqlen`` random tokens each and no
    history. Every sequence points to the single block ``dummy_block_id``, and its
    ``state_offsets`` are set to -1.
    """
    long_kw = dict(dtype=torch.long, device=device)
    total_tokens = batch_size * max_q_seqlen
    per_seq = (batch_size, )

    return ModelInputs(
        input_ids=torch.randint(0, vocab_size, (1, total_tokens), **long_kw),
        seq_length=torch.full(per_seq, max_q_seqlen, **long_kw),
        history_lengths=torch.zeros(per_seq, **long_kw),
        block_offsets=torch.full((batch_size, 1), dummy_block_id, **long_kw),
        is_decoding=is_decoding,
        num_ignored_history=torch.zeros(per_seq, **long_kw),
        max_q_seqlen=max_q_seqlen,
        # no history, so kv length equals the query length
        max_kv_seqlen=max_q_seqlen,
        sum_kv_seqlen=total_tokens,
        local_adapter_ids=torch.zeros(per_seq, **long_kw),
        is_dummy=True,
        state_offsets=torch.full(per_seq, -1, **long_kw),
    )


class ModelInputsStrategy(ABC):
    """Abstract strategy for creating, merging and re-indexing ``ModelInputs``."""

    @abstractmethod
    def make_dummy(self,
                   batch_size: int,
                   is_decoding: bool,
                   device: str = 'cpu',
                   dummy_block_id: int = 0,
                   vocab_size: int = 1) -> ModelInputs:
        """Return placeholder model inputs for ``batch_size`` sequences."""

    @abstractmethod
    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
        """Return ``inputs`` and ``other`` combined into one batch."""

    @abstractmethod
    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:
        """Return ``inputs`` adjusted by a model-inputs delta."""


================================================
FILE: lmdeploy/pytorch/strategies/base/sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List

import torch

from lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputsDelta

from .model_agent import ExtraInputs

SeqList = List[SchedulerSequence]


class SamplingStrategy(ABC):
    """Abstract strategy that builds and advances sampling inputs/deltas."""

    @abstractmethod
    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
        """Build sampling inputs for ``seqs``."""

    @abstractmethod
    def on_session_end(self, session_id: int) -> None:
        """Hook called when the session ``session_id`` ends."""

    @abstractmethod
    def merge_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        other: 'SamplingInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Return the combination of two sampling deltas."""

    @abstractmethod
    def step_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        next_token_ids: torch.Tensor,
        extra_inputs: 'ExtraInputs',
    ) -> 'SamplingInputsDelta':
        """Return the sampling delta advanced past ``next_token_ids``."""

    @abstractmethod
    def update_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        delta: 'ModelInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Return the sampling delta adjusted by a model-inputs delta."""


================================================
FILE: lmdeploy/pytorch/strategies/base/sequence.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional

from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest

if TYPE_CHECKING:
    from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
    from lmdeploy.pytorch.messages import SamplingParam, SchedulerSequence, SchedulerSession
    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta
    SeqList = List[SchedulerSequence]


class SequenceStrategy(ABC):
    """Abstract strategy for creating scheduler sequences and applying step results to them."""

    @abstractmethod
    def make_sequence(self,
                      seq_id: int,
                      session: 'SchedulerSession',
                      sampling_param: 'SamplingParam' = None,
                      adapter_name: str = None,
                      migration_request: Optional[MigrationRequest] = None,
                      resp_cache: bool = False,
                      preserve_cache: bool = False) -> 'SchedulerSequence':
        """Create the scheduler sequence ``seq_id`` inside ``session``."""

    @abstractmethod
    def update_running(self, running: 'SeqList', batched_outputs: 'BatchedOutputs', model_inputs: 'ModelInputs',
                       delta: 'ModelInputsDelta') -> None:
        """Apply ``batched_outputs`` from a forward step to the ``running`` sequences."""


================================================
FILE: lmdeploy/pytorch/strategies/dllm/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from lmdeploy.pytorch.config import DLLMConfig, ModelConfig
from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy
from lmdeploy.utils import get_logger

if TYPE_CHECKING:
    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy
    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig

from ..base import StrategyFactoryBase

logger = get_logger('lmdeploy')


class DLLMStrategyFactory(StrategyFactoryBase):
    """Strategy factory for DLLM models (block-wise masked / denoising generation).

    All built strategies share one resolved ``dllm_block_length``.
    """

    def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig):
        """Store configs and resolve the effective block length."""
        self.model_config = model_config
        self.dllm_config = dllm_config
        self.dllm_block_length = self._update_dllm_block_length()

    def _update_dllm_block_length(self):
        """Resolve the block length and write it back into both configs.

        Precedence: ``dllm_config.block_length``, then ``model_config.dllm_block_length``,
        then 4 (with a warning). ``denoising_steps`` defaults to the block length when unset.
        """
        dllm_block_length = self.dllm_config.block_length
        if dllm_block_length is None:
            dllm_block_length = self.model_config.dllm_block_length
        if dllm_block_length is None:
            dllm_block_length = 4
            logger.warning('Model does not provide dllm_block_length. '
                           f'Set dllm_block_length={dllm_block_length} as default.')

        assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config'

        self.dllm_config.block_length = dllm_block_length
        self.model_config.dllm_block_length = dllm_block_length

        if self.dllm_config.denoising_steps is None:
            self.dllm_config.denoising_steps = dllm_block_length
        return dllm_block_length

    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
        """Build the DLLM cudagraph strategy."""
        from .cudagraph import DLLMCudagraphStrategy
        strategy_cls = DLLMCudagraphStrategy
        return strategy_cls(block_size=self.dllm_block_length)

    def build_sampling_strategy(self) -> 'SamplingStrategy':
        """Build the DLLM sampling strategy, padding with the BOS token id (or 0 if absent)."""
        from .sampling import DLLMSamplingStrategy
        bos_token_id = self.model_config.bos_token_id
        pad_token_id = 0 if bos_token_id is None else bos_token_id
        return DLLMSamplingStrategy(pad_token_id, self.dllm_block_length)

    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
        """Build the DLLM model-inputs strategy."""
        from .model_inputs import DLLMModelInputsStrategy
        strategy_cls = DLLMModelInputsStrategy
        return strategy_cls(block_size=self.dllm_block_length)

    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
        """Build the DLLM model-agent strategy."""
        from .model_agent import DLLMModelAgentStrategy
        mask_token = self.model_config.dllm_mask_token
        return DLLMModelAgentStrategy(dllm_config=self.dllm_config, dllm_mask_token=mask_token)

    def build_engine_strategy(self, cache_config: 'CacheConfig',
                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
        """Build the DLLM engine strategy."""
        from .engine import DLLMEngineStrategy
        block_length = self.dllm_block_length
        return DLLMEngineStrategy(cache_config=cache_config,
                                  scheduler_config=scheduler_config,
                                  dllm_block_length=block_length)

    def build_sequence_strategy(self) -> SequenceStrategy:
        """Build the DLLM sequence strategy."""
        from .sequence import DLLMSequenceStrategy
        mask_token = self.model_config.dllm_mask_token
        return DLLMSequenceStrategy(block_size=self.dllm_block_length, dllm_mask_token=mask_token)


================================================
FILE: lmdeploy/pytorch/strategies/dllm/cudagraph.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from ..base.cudagraph import CudagraphStrategy


class DLLMCudagraphStrategy(CudagraphStrategy):
    """Cudagraph sizing for DLLM: every sequence in a graph batch holds one block of tokens."""

    def __init__(self, block_size: int) -> None:
        super().__init__()
        self.block_size = block_size

    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:
        """Return the token capacity for a batch of ``batch_size`` sequences.

        ``origin_batch_size`` and ``num_tokens`` come from the base interface and are unused here.
        """
        tokens_per_seq = self.block_size
        return tokens_per_seq * batch_size


================================================
FILE: lmdeploy/pytorch/strategies/dllm/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache

from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
from lmdeploy.utils import get_logger

from ..base.engine import EngineStrategy

logger = get_logger('lmdeploy')


class DLLMEngineStrategy(EngineStrategy):
    """DLLM Engine Strategy.

    Derives scheduling quantities from three values: the cache ``block_size``, the
    scheduler's ``prefill_interval`` and the DLLM block length.
    """

    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, dllm_block_length: int) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.dllm_block_length = dllm_block_length

        self._check()

    def _check(self):
        """Warn when a full batch of DLLM blocks exceeds ``max_prefill_token_num``."""
        max_prefill_token_num = self.cache_config.max_prefill_token_num
        max_batches = self.cache_config.max_batches
        if self.dllm_block_length * max_batches > max_prefill_token_num:
            logger.warning(f'dllm_block_length({self.dllm_block_length}) * max_batch_size ({max_batches}) '
                           f'> max_prefill_token_num ({max_prefill_token_num}). '
                           'This may lead to OOM. Consider to reduce max_batch_size or dllm_block_length.')

    # Caching lives on static helpers keyed by plain ints. ``lru_cache`` placed directly on an
    # instance method keys on ``self`` and keeps every instance (and its configs) alive for the
    # lifetime of the cache (ruff B019). It also returns stale results if the configs change later.
    @staticmethod
    @lru_cache(maxsize=16)
    def _compute_prealloc_size(block_size: int, dllm_block_length: int, prefill_interval: int) -> int:
        """Whole DLLM blocks, capped by ``prefill_interval // 2`` and by how many fit in one cache block."""
        num_blocks = min(prefill_interval // 2, block_size // dllm_block_length)
        return num_blocks * dllm_block_length

    @staticmethod
    @lru_cache(maxsize=16)
    def _compute_num_loops(block_size: int, dllm_block_length: int, prefill_interval: int) -> int:
        """``prefill_interval`` capped at twice the number of DLLM blocks per cache block."""
        max_num_loops = block_size // dllm_block_length * 2
        return min(prefill_interval, max_num_loops)

    def get_prealloc_size(self, is_decoding: bool) -> int:
        """Get prealloc_size: 0 for prefill, otherwise a whole number of DLLM blocks."""
        if not is_decoding:
            return 0
        return self._compute_prealloc_size(self.cache_config.block_size, self.dllm_block_length,
                                           self.scheduler_config.prefill_interval)

    def get_num_loops(self, is_decoding: bool) -> int:
        """Get num_loops: 1 for prefill, otherwise bounded by ``prefill_interval``."""
        if not is_decoding:
            return 1
        return self._compute_num_loops(self.cache_config.block_size, self.dllm_block_length,
                                       self.scheduler_config.prefill_interval)

    def get_num_decode_tokens(self) -> int:
        """Get num_decode_tokens: one DLLM block per decoding sequence."""
        return self.dllm_block_length


================================================
FILE: lmdeploy/pytorch/strategies/dllm/model_agent.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.profiler import record_function

from lmdeploy.pytorch import consts
from lmdeploy.pytorch.config import DLLMConfig
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria
from .unmasking import UnmaskingProcessor

SeqList = List[SchedulerSequence]


def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen,
                                   step_seqlens: torch.Tensor, model_metas) -> ModelInputs:
    """Build the ``ModelInputs`` for the next decoding step.

    Every sequence gets ``max_q_seqlen`` query tokens. History advances by ``step_seqlens``
    where that value is positive, and by ``seq_length - max_q_seqlen`` elsewhere. A 1-D
    ``input_ids`` is promoted to shape ``(1, n)``.
    """
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    history_advance = torch.where(step_seqlens > 0, step_seqlens, inputs.seq_length - max_q_seqlen)
    batch_size = inputs.seq_length.numel()

    return ModelInputs(
        input_ids=input_ids,
        seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),
        history_lengths=inputs.history_lengths + history_advance,
        block_offsets=inputs.block_offsets,
        is_decoding=True,
        num_ignored_history=inputs.num_ignored_history,
        max_q_seqlen=max_q_seqlen,
        max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,
        # NOTE(review): this bound grows by the *previous* max_q_seqlen, while max_kv_seqlen
        # grows by the new one -- confirm this looser upper bound is intended.
        sum_kv_seqlen=inputs.sum_kv_seqlen + batch_size * inputs.max_q_seqlen,
        local_adapter_ids=inputs.local_adapter_ids,
        model_metas=model_metas,
        state_offsets=inputs.state_offsets,
    )


@dataclass
class DLLMExtraInputs(ExtraInputs):
    """DLLM extra inputs: the per-token ``dllm_mask`` state."""
    dllm_mask: torch.Tensor

    def broadcast(self, src: int, group, async_op=False):
        """Broadcast ``dllm_mask`` in place from rank ``src``; returns what ``dist.broadcast`` returns."""
        mask = self.dllm_mask
        return dist.broadcast(mask, src=src, group=group, async_op=async_op)

    def merge(self, other: 'DLLMExtraInputs'):
        """Concatenate ``other``'s mask after this one's along dim 0."""
        merged_mask = torch.cat((self.dllm_mask, other.dllm_mask), dim=0)
        return DLLMExtraInputs(dllm_mask=merged_mask)


@dataclass
class DLLMExtraOutputs(ExtraOutputs):
    """DLLM extra outputs: the ``dllm_mask`` that accompanies the sampled tokens."""
    dllm_mask: torch.Tensor


def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor,
                          stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor,
                          output_start_pos: torch.Tensor, inputs: ModelInputs):
    num_tokens = token_ids.size(0)
    batch_size = num_appendable_ids.size(0)
    block_size = num_tokens // batch_size

    # blocks might contain stop words in prev-round chat
    # these stop words should be ignored
    kv_seqlens = inputs.history_lengths + inputs.seq_length
    ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0)
    ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device)
    ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten()
    token_ids = token_ids.clone()
    token_ids[ignore_mask] = -1

    # find stop words
    sw_stopped = (token_ids[:, None] == stop_words).any(1)
    sw_stopped = sw_stopped.view(batch_size, block_size)
    sw_stop_pos = sw_stopped.int().argmax(1)

    stop_pos = torch.where(stopped, stop_pos, sw_stop_pos)
    sw_stopped = sw_stopped.any(dim=1)
    sw_stopped = sw_stopped & is_unmasked
    stopped = stopped | sw_stopped

    # update num_appendable_ids
    one_ids = torch.clamp_max(num_appendable_ids, 0)
    num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)

    return stopped, stop_pos, num_appendable_ids


@dataclass
class DLLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for block-wise DLLM decoding.

    ``num_appendable_ids`` is the per-sequence remaining token budget. ``output_start_pos``
    is the per-sequence position where the current round's output starts; stop words
    before it are ignored.
    """
    num_appendable_ids: torch.Tensor
    output_start_pos: torch.Tensor

    def clone(self) -> 'DLLMStoppingCriteria':
        """clone (tensors are shared; ``step`` never mutates them in place)."""
        return DLLMStoppingCriteria(num_appendable_ids=self.num_appendable_ids, output_start_pos=self.output_start_pos)

    def merge(self, other: 'DLLMStoppingCriteria') -> 'DLLMStoppingCriteria':
        """Merge two stopping criteria by concatenating along the batch dim."""
        return DLLMStoppingCriteria(num_appendable_ids=torch.cat([self.num_appendable_ids, other.num_appendable_ids],
                                                                 dim=0),
                                    output_start_pos=torch.cat([self.output_start_pos, other.output_start_pos], dim=0))

    def update(self, delta: 'ModelInputsDelta') -> 'DLLMStoppingCriteria':
        """Update stopping criteria by selecting ``delta.indices``."""
        indices = delta.indices
        return DLLMStoppingCriteria(num_appendable_ids=self.num_appendable_ids[indices],
                                    output_start_pos=self.output_start_pos[indices])

    @record_function('stopping_criteria')
    def step(self,
             token_ids: torch.Tensor,
             stop_words: torch.Tensor,
             inputs: Optional[ModelInputs] = None,
             extra_inputs: Optional[DLLMExtraInputs] = None):
        """Check whether to stop generation.

        Returns:
            ``(stopped, stop_pos, new_criteria)``; ``self`` is left unmodified.
        """
        num_appendable_ids = self.num_appendable_ids
        output_start_pos = self.output_start_pos
        num_tokens = token_ids.size(0)
        batch_size = num_appendable_ids.size(0)
        block_size = num_tokens // batch_size

        dllm_mask = extra_inputs.dllm_mask
        dllm_mask = dllm_mask.view(batch_size, block_size)
        is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1)

        # check stop by num_new_tokens.
        # Out-of-place on purpose: an in-place ``-=`` would also change ``self`` and every
        # clone sharing this tensor.
        num_appendable_ids = num_appendable_ids - is_unmasked * block_size
        stopped = num_appendable_ids <= 0
        stop_pos = block_size - 1 + num_appendable_ids

        # check stop words
        if stop_words is not None:
            stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids,
                                                                          stop_words,
                                                                          is_unmasked,
                                                                          stopped,
                                                                          stop_pos,
                                                                          num_appendable_ids,
                                                                          output_start_pos=output_start_pos,
                                                                          inputs=inputs)

        new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos)
        return stopped, stop_pos, new_stopping


class DLLMModelAgentStrategy(ModelAgentStrategy):
    """Model agent strategy for block-wise DLLM decoding.

    ``dllm_mask`` tracks a per-token state compared against ``consts.DLLM_MASKED``,
    ``DLLM_UNMASKED`` and ``DLLM_CACHED``. Decoding steps feed ``block_size`` tokens
    per sequence.
    """

    def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int):
        block_size = dllm_config.block_length
        self.block_size = block_size
        self.dllm_mask_token = dllm_mask_token

        self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config)

    def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor):
        """Update token_ids and dllm_mask for the next step.

        A block whose positions are all ``DLLM_CACHED`` is fully refilled with the mask token
        and reset to ``DLLM_MASKED``, and its ``seqlens`` entry is kept. Other blocks get their
        still-masked positions set to the mask token and a ``seqlens`` entry of 0.
        """
        dllm_mask_token = self.dllm_mask_token
        dllm_block_length = self.block_size

        # reshape to (batch, dllm_block_length)
        next_token_ids = next_token_ids.view(-1, dllm_block_length).clone()
        dllm_mask = dllm_mask.view(-1, dllm_block_length).clone()

        # flags
        is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1)

        is_masked = (dllm_mask == consts.DLLM_MASKED)
        next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token
        dllm_mask[is_cached] = consts.DLLM_MASKED
        seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, )))

        return next_token_ids.flatten(), dllm_mask.flatten(), seqlens

    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
        """Keep the last ``block_size`` rows of every sequence."""
        block_length = self.block_size
        # batch size = 1
        if len(seq_length) == 1:
            return inputs[-block_length:]

        # already exactly one block per sequence
        if len(seq_length) * block_length == inputs.size(0):
            return inputs
        last_idx = seq_length.cumsum(0)
        block_range = torch.arange(-block_length, 0, device=last_idx.device)
        index = (last_idx[:, None] + block_range[None, :]).flatten()
        inputs = inputs[index]
        return inputs

    def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, model_inputs: ModelInputs,
                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> DLLMExtraInputs:
        """Slice extra inputs (the dllm mask) the same way as ``slice_outputs``."""
        dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, model_inputs.seq_length)
        return DLLMExtraInputs(dllm_mask=dllm_mask)

    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor,
                             extra_inputs: DLLMExtraInputs, **kwargs):
        """Step sampling inputs.

        ``num_ignore_eos`` is reduced by ``block_size`` for every fully unmasked block, and
        ``random_offsets`` (if present) is incremented in place.
        """
        dllm_mask = extra_inputs.dllm_mask
        dllm_block_size = self.block_size
        DLLM_UNMASKED = consts.DLLM_UNMASKED
        is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)
        num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size)
        num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
        sampling_inputs.num_ignore_eos = num_ignore_eos.flatten()
        if sampling_inputs.random_offsets is not None:
            # random offset is used to generate random numbers for multinomial sampling
            # so we need to increase it by 1 at each step
            sampling_inputs.random_offsets += 1
        return sampling_inputs

    def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria:
        """Create stopping criteria.

        The budget is ``max_new_tokens - num_new_tokens`` plus ``num_valid_ids % block_size``,
        i.e. it accounts for the tokens already in the current partial block.
        """
        # num_appendable
        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
        num_appendable = torch.tensor(num_appendable)
        block_size = self.block_size
        remain = [seq.num_valid_ids % block_size for seq in seqs]
        num_appendable += torch.tensor(remain)

        # output_start_pos
        pos = [seq.output_start_pos for seq in seqs]
        output_start_pos = torch.tensor(pos)

        return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos)

    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:
        """Create extra inputs by concatenating each sequence's dllm mask."""
        dllm_masks = [seq.dllm_mask for seq in seqs]

        # chunked prefill only require part of the dllm masks
        if model_inputs.is_chunk:
            seqlens = model_inputs.seq_length.tolist()
            dllm_masks = [mask[:length] for mask, length in zip(dllm_masks, seqlens)]

        dllm_masks = torch.as_tensor(np.concatenate(dllm_masks))
        return DLLMExtraInputs(dllm_mask=dllm_masks)

    def update_extra_inputs(self, extra_inputs: DLLMExtraInputs, delta: 'ModelInputsDelta') -> DLLMExtraInputs:
        """Select the blocks listed in ``delta.indices`` from the dllm mask."""
        dllm_mask = extra_inputs.dllm_mask
        dllm_mask = dllm_mask.reshape(-1, self.block_size)

        indices = delta.indices
        dllm_mask = dllm_mask[indices].flatten()

        return DLLMExtraInputs(dllm_mask=dllm_mask)

    def make_extra_outputs(self, extra_inputs: DLLMExtraInputs) -> DLLMExtraOutputs:
        """Create extra outputs carrying the current dllm mask."""
        dllm_mask = extra_inputs.dllm_mask
        return DLLMExtraOutputs(dllm_mask=dllm_mask)

    def update_prefill_for_next_step(
        self,
        model_inputs: 'ModelInputs',
        extra_inputs: DLLMExtraInputs,
        next_token_ids: torch.Tensor,
        model_metas: Any,
        extra_outputs: DLLMExtraOutputs,
    ) -> Tuple['ModelInputs', DLLMExtraInputs]:
        """Build decoding inputs for the step after a prefill, using the mask from ``extra_outputs``."""
        dllm_mask = extra_outputs.dllm_mask
        next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)

        inputs = get_model_inputs_next_decoding(model_inputs,
                                                next_token_ids,
                                                model_metas=model_metas,
                                                max_q_seqlen=self.block_size,
                                                step_seqlens=step_seqlens)
        extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)
        return inputs, extra_inputs

    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
                                      extra_inputs: DLLMExtraInputs, **kwargs):
        """Advance decoding inputs in place via ``model_inputs.step``."""
        model_inputs.model_metas = model_metas
        dllm_mask = extra_inputs.dllm_mask

        next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)
        model_inputs.step(next_token_ids, step_seqlens)

        extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)
        return model_inputs, extra_inputs

    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
                      extra_inputs: DLLMExtraInputs):
        """Run the unmasking processor and store the new mask on ``extra_inputs`` (mutated in place)."""
        dllm_mask = extra_inputs.dllm_mask
        input_ids = inputs.input_ids
        input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length)

        dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask)

        extra_inputs.dllm_mask = dllm_mask
        return next_token_ids, extra_inputs

    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: DLLMExtraInputs):
        """Make a zero-filled next-token buffer (one entry per logits row) for broadcast."""
        with torch.inference_mode():
            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
        return next_token_ids, extra_inputs

    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext):
        """Broadcast next token ids and extra inputs from attn-TP rank 0; wait on exit."""
        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group
        rank = dist.get_global_rank(tp_gpu_group, 0)
        # Keep and wait on every async work handle instead of relying on implicit ordering
        # between the two collectives.
        token_handle = dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)
        extra_handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True)
        yield
        token_handle.wait()
        extra_handle.wait()


================================================
FILE: lmdeploy/pytorch/strategies/dllm/model_inputs.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..ar.model_inputs import merge_model_inputs
from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs


class DLLMModelInputsStrategy(ModelInputsStrategy):
    """Model-inputs strategy for DLLM, where each sequence contributes one block of queries."""

    def __init__(self, block_size: int):
        # number of query tokens per sequence in a DLLM step
        self.block_size = block_size

    def make_dummy(self,
                   batch_size: int,
                   is_decoding: bool,
                   device: str = 'cpu',
                   dummy_block_id: int = 0,
                   vocab_size: int = 1) -> ModelInputs:
        """Build dummy inputs whose maximum query length is one DLLM block."""
        dummy_options = dict(
            max_q_seqlen=self.block_size,
            is_decoding=is_decoding,
            device=device,
            dummy_block_id=dummy_block_id,
            vocab_size=vocab_size,
        )
        return make_dummy_inputs(batch_size, **dummy_options)

    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:
        """Merge two model inputs using the shared autoregressive helper."""
        return merge_model_inputs(inputs, other)

    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:
        """Select the sequences listed in ``delta.indices`` from decoding ``inputs``.

        Values present on ``delta`` override those of ``inputs``. Missing
        (``None``/falsy) values are taken from ``inputs``, indexed where
        per-sequence.
        """
        assert inputs.is_decoding, 'Only support index_select in decoding.'
        idx = delta.indices

        # The reshape assumes every decoding sequence holds exactly block_size
        # tokens, so whole blocks are gathered per kept sequence.
        per_seq_blocks = inputs.input_ids.reshape(1, -1, self.block_size)
        new_input_ids = per_seq_blocks[:, idx].reshape(1, -1)
        new_seq_length = inputs.seq_length[idx]
        new_history_lengths = inputs.history_lengths[idx]

        new_block_offsets = delta.block_offsets
        if new_block_offsets is None:
            new_block_offsets = inputs.block_offsets[idx]
        new_num_ignored = delta.num_ignored_history
        if new_num_ignored is None:
            new_num_ignored = inputs.num_ignored_history[idx]

        # optional per-sequence tensors are indexed only when present
        adapter_ids = inputs.local_adapter_ids
        if adapter_ids is not None:
            adapter_ids = adapter_ids[idx]

        metas = inputs.model_metas
        cpu_idx = delta.indice_cpu
        if metas is not None and cpu_idx is not None:
            metas = [metas[i] for i in cpu_idx]

        ssm_offsets = inputs.state_offsets
        if ssm_offsets is not None:
            ssm_offsets = ssm_offsets[idx]

        return ModelInputs(
            input_ids=new_input_ids,
            seq_length=new_seq_length,
            history_lengths=new_history_lengths,
            block_offsets=new_block_offsets,
            is_decoding=inputs.is_decoding,
            num_ignored_history=new_num_ignored,
            max_q_seqlen=delta.max_q_seqlen or inputs.max_q_seqlen,
            max_kv_seqlen=delta.max_kv_seqlen or inputs.max_kv_seqlen,
            sum_kv_seqlen=delta.sum_kv_seqlen or inputs.sum_kv_seqlen,
            local_adapter_ids=adapter_ids,
            model_metas=metas,
            state_offsets=ssm_offsets,
        )


================================================
FILE: lmdeploy/pytorch/strategies/dllm/sampling.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
from torch.profiler import record_function

from lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputsDelta

from ..ar.sampling import ARSamplingStrategy
from .model_agent import DLLMExtraInputs

SeqList = List[SchedulerSequence]


class DLLMSamplingStrategy(ARSamplingStrategy):
    """Sampling strategy for diffusion LLMs (DLLM).

    Every sequence samples a whole block of ``dllm_block_length`` tokens per
    step. Per-sequence sampling inputs are therefore expanded to one row per
    token, laid out sequence-major: row ``b * dllm_block_length + i`` belongs
    to sequence ``b``.
    """

    def __init__(self, pad_token_id: int, dllm_block_length: int) -> None:
        super().__init__(pad_token_id)
        self.dllm_block_length = dllm_block_length

    @record_function('make_sampling_inputs')
    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
        """Create sampling inputs with one row per DLLM token.

        The autoregressive inputs (one row per sequence) are expanded along
        dim 0. Every sequence's row is repeated ``dllm_block_length`` times
        consecutively. This matches how block tokens are laid out for sampling,
        and how ``generated_ids_cpu``/``response_formats`` are expanded below.
        """
        out = super().make_sampling_inputs(seqs)
        dllm_block_length = self.dllm_block_length

        # repeat tensor
        update_attr_names = [
            'temperature',
            'bad_words',
            'bad_mask',
            'stop_words',
            'stop_mask',
            'repetition_penalty',
            'top_k',
            'top_p',
            'min_p',
            'random_seeds',
            'random_offsets',
            'all_ids',
            'num_ignore_eos',
            'ngram_size',
            'ngram_threshold',
        ]
        for name in update_attr_names:
            attr = getattr(out, name)
            if attr is None:
                continue
            # Sequence-major expansion for tensors of any rank. A batch-major
            # tile (``attr[None].repeat(L, 1)``) would hand tokens of one sequence
            # the parameters of another whenever sequences differ.
            setattr(out, name, attr.repeat_interleave(dllm_block_length, dim=0))

        # update generated_ids_cpu (also sequence-major)
        if out.generated_ids_cpu is not None:
            generated_ids_cpu = out.generated_ids_cpu
            if generated_ids_cpu.shape[1] == 0:
                # reshape(-1, 0) is ambiguous, so repeat empty rows directly
                out.generated_ids_cpu = np.repeat(generated_ids_cpu, dllm_block_length, axis=0)
            else:
                generated_ids_cpu = np.repeat(generated_ids_cpu[:, None], dllm_block_length, axis=1)
                generated_ids_cpu = np.reshape(generated_ids_cpu, (-1, generated_ids_cpu.shape[-1]))
                out.generated_ids_cpu = generated_ids_cpu

        if len(out.response_formats) > 0:
            new_resp_formats = []
            for resp in out.response_formats:
                new_resp_formats += [resp] * dllm_block_length
            out.response_formats = tuple(new_resp_formats)

        out.batch_size *= dllm_block_length

        return out

    def merge_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        other: 'SamplingInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Merge two sampling deltas by concatenating their per-token tensors.

        ``random_offsets`` may be ``None`` (see ``update_sampling_delta``). When
        both sides lack it, the merged delta has none either.
        """
        num_ignore_eos = torch.cat([sampling_delta.num_ignore_eos, other.num_ignore_eos], 0)
        if sampling_delta.random_offsets is None and other.random_offsets is None:
            random_offsets = None
        else:
            random_offsets = torch.cat([sampling_delta.random_offsets, other.random_offsets], 0)

        return SamplingInputsDelta(
            num_ignore_eos=num_ignore_eos,
            random_offsets=random_offsets,
            all_ids=None,
        )

    def update_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        delta: 'ModelInputsDelta',
    ) -> 'SamplingInputsDelta':
        """Keep only the sequences selected by ``delta.indices``.

        Per-token tensors are viewed as ``[num_seqs, dllm_block_length]`` so the
        selection keeps whole blocks.
        """
        indices = delta.indices
        num_ignore_eos = sampling_delta.num_ignore_eos.view(-1, self.dllm_block_length)
        num_ignore_eos = num_ignore_eos[indices].flatten()
        if sampling_delta.random_offsets is not None:
            random_offsets = sampling_delta.random_offsets.view(-1, self.dllm_block_length)
            random_offsets = random_offsets[indices].flatten()
        else:
            random_offsets = None
        return SamplingInputsDelta(
            num_ignore_eos=num_ignore_eos,
            random_offsets=random_offsets,
            all_ids=None,
        )

    def step_sampling_delta(
        self,
        sampling_delta: 'SamplingInputsDelta',
        next_token_ids: torch.Tensor,
        extra_inputs: 'DLLMExtraInputs',
    ) -> 'SamplingInputsDelta':
        """Advance the sampling delta by one step (mutates and returns ``sampling_delta``).

        For every block whose mask is entirely DLLM_UNMASKED, the block's
        ``num_ignore_eos`` entries are reduced by the block length.
        ``random_offsets`` is incremented by one.
        """
        from lmdeploy.pytorch import consts
        dllm_mask = extra_inputs.dllm_mask
        dllm_block_size = self.dllm_block_length
        DLLM_UNMASKED = consts.DLLM_UNMASKED
        is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)
        num_ignore_eos = sampling_delta.num_ignore_eos.view(-1, dllm_block_size)
        num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
        sampling_delta.num_ignore_eos = num_ignore_eos.flatten()
        if sampling_delta.random_offsets is not None:
            # random offset is used to generate random numbers for multinomial sampling
            # so we need to increase it by 1 at each step
            sampling_delta.random_offsets += 1
        return sampling_delta


================================================
FILE: lmdeploy/pytorch/strategies/dllm/sequence.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import numpy as np
from torch import Tensor

from lmdeploy.pytorch import consts
from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,
                                       SchedulerSession, UpdateTokenMode, _to_ndarray)
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta

from ..ar.sequence import SchedulerSequenceDefault
from ..base.sequence import SequenceStrategy

SeqList = List['SchedulerSequenceDLLM']

DLLM_MASKED = consts.DLLM_MASKED
DLLM_UNMASKED = consts.DLLM_UNMASKED
DLLM_CACHED = consts.DLLM_CACHED
DLLM_MASK_DTYPE = np.uint8


class HistoryDLLMMask(HistoryTokenIds):
    """History buffer holding one DLLM mask state per token.

    Reuses the ``HistoryTokenIds`` storage, with ``DLLM_MASK_DTYPE``
    (``np.uint8``) as the default element dtype.
    """

    def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE):
        super().__init__(token_ids=token_ids, dtype=dtype)


@dataclass
class SchedulerSequenceDLLM(SchedulerSequenceDefault):
    """Scheduler sequence for diffusion LLMs (DLLM).

    Alongside the token history, ``history_dllm_mask`` stores one state per
    token (``DLLM_MASKED``, ``DLLM_UNMASKED`` or ``DLLM_CACHED``).

    Tokens are appended in whole blocks of ``dllm_block_length``. Positions
    that are still masked hold ``dllm_mask_token``. ``num_valid_ids`` counts
    the leading ids that are real tokens, i.e. it excludes the mask-token
    padding appended to complete the last block.
    """

    # For dllm
    history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)

    def __post_init__(self):
        """Post init."""
        super().__post_init__()
        # every id present in the history cache at construction counts as valid
        self._num_valid_ids: int = len(self.history_cache)
        self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy

    @property
    def dllm_mask(self):
        """Mask states of the current (non-history) tokens."""
        start = self.num_history_ids
        end = start + self._num_token_ids
        return self.history_dllm_mask[start:end]

    @property
    def num_valid_ids(self):
        """Number of leading ids in the history cache that are real tokens."""
        return self._num_valid_ids

    @property
    def generated_ids(self) -> np.ndarray:
        """The last ``num_new_tokens`` valid ids."""
        end = self.num_valid_ids
        start = end - self.num_new_tokens
        return self.history_cache[start:end]

    @property
    def all_dllm_mask(self):
        """Mask states of the first ``num_all_ids`` ids."""
        return self.history_dllm_mask[:self.num_all_ids]

    @property
    def dllm_block_length(self):
        """Block length configured on the sequence strategy."""
        return self._strategy.block_size

    @property
    def dllm_mask_token(self):
        """Token id written into masked positions."""
        return self._strategy.dllm_mask_token

    def set_stop_pos(self, pos: int):
        """Drop the ``dllm_block_length - pos - 1`` ids following a stop position.

        Both the valid-id count and ``num_new_tokens`` shrink by that amount.
        ``pos`` is presumably the stop index within the last block; confirm
        against the caller's ``stop_pos``.
        """
        dllm_block_length = self.dllm_block_length
        val = dllm_block_length - pos - 1
        self._num_valid_ids -= val
        self.num_new_tokens -= val

    def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
        """Append new input tokens, padded with masked tokens to a block boundary."""
        num_tokens = len(token_ids)
        dllm_block_length = self.dllm_block_length
        dllm_mask_token = self.dllm_mask_token
        new_token_ids = [token_ids]
        new_dllm_mask = [dllm_mask]

        # add uncached tokens in token_ids
        # for example, [cccc cccc uumm], the [uu] in last block is remain valid.
        num_remain_valid = self.num_valid_ids - self.num_history_ids
        if num_remain_valid != 0:
            prev_token_ids = self.valid_ids[-num_remain_valid:]
            prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
            new_token_ids = [prev_token_ids] + new_token_ids
            new_dllm_mask = [prev_dllm_mask] + new_dllm_mask
            # drop everything after the history; it is re-appended below
            self.history_cache.resize(self.num_history_ids)
            self.history_dllm_mask.resize(self.num_history_ids)
            num_tokens += num_remain_valid

        # pad to align with dllm_block_length
        num_pad = (-num_tokens) % dllm_block_length
        if num_pad > 0:
            pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))
            pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))
            new_token_ids += [pad_ids]
            new_dllm_mask += [pad_mask]

        token_ids = np.concatenate(new_token_ids)
        dllm_mask = np.concatenate(new_dllm_mask)

        assert len(token_ids) % dllm_block_length == 0

        self.history_cache.append(token_ids)
        self.history_dllm_mask.append(dllm_mask)
        # NOTE(review): this uses the pre-update ``_num_valid_ids`` plus the padded
        # length (which also contains the re-appended valid ids and padding); it
        # differs from the post-update ``_num_valid_ids`` whenever padding or
        # remaining valid ids exist. Confirm this is the intended output start.
        self.output_start_pos = self._num_valid_ids + len(token_ids)
        self._num_valid_ids = self.num_history_ids + num_tokens  # padding excluded
        self._num_token_ids = len(token_ids)
        self.num_new_tokens = 0

    def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
        """Overwrite the current blocks with one denoising step's tokens and masks."""
        num_tokens = len(token_ids)
        dllm_block_length = self.dllm_block_length
        dllm_mask_token = self.dllm_mask_token
        assert num_tokens % dllm_block_length == 0
        num_history_ids = self.num_history_ids

        # masked positions always hold the mask token (modifies token_ids in place)
        token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token
        self.history_cache[num_history_ids:] = token_ids
        self.history_dllm_mask[num_history_ids:] = dllm_mask

        # check if all blocks are cached
        last_mask = dllm_mask[-dllm_block_length:]
        is_unmasked = np.all(last_mask == DLLM_UNMASKED)
        is_cached = np.all(last_mask == DLLM_CACHED)

        if is_unmasked:
            # last block fully decoded: extend valid ids to the block boundary
            num_new = dllm_block_length - self._num_valid_ids % dllm_block_length
            self._num_valid_ids += num_new
            self.num_new_tokens += num_new

        if is_cached:
            # add new block; current tokens move into the history
            new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))
            new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))
            self.history_cache.append(new_token_ids)
            self.history_dllm_mask.append(new_dllm_mask)
            self._num_history_ids += self._num_token_ids
            self._num_token_ids = dllm_block_length

    def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
        """Move all but the last block of current tokens to history, then decode-update."""
        dllm_block_length = self.dllm_block_length
        num_history_ids = self.num_history_ids

        # fill input cache
        if self.num_token_ids > dllm_block_length:
            # ``end`` is a token COUNT relative to num_history_ids, so the absolute
            # range is [num_history_ids, num_history_ids + end). Slicing up to
            # ``end`` alone is wrong (often empty) once num_history_ids > 0.
            end = self.num_token_ids - dllm_block_length
            self.history_dllm_mask[num_history_ids:num_history_ids + end] = DLLM_CACHED
            self._num_history_ids += end
            self._num_token_ids -= end

        # decoding update
        self._update_token_ids_decode(token_ids, dllm_mask)

    def update_token_ids(self,
                         token_ids: Tensor,
                         multimodals: MultiModalInputs = None,
                         embeddings: List[InputEmbeddings] = None,
                         model_meta: Dict[str, Any] = None,
                         dllm_mask: Tensor = None,
                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
                         **kwargs):
        """Update token ids, old token ids will be added to history.

        ``dllm_mask`` defaults to all-DLLM_UNMASKED. ``mode`` selects between
        appending inputs, prefill update and decode update.
        """
        # update history image nums
        self._update_embeddings(embeddings)

        # update multimodals
        self._update_multimodals(multimodals)

        self.arrive_time = time.perf_counter()

        token_ids: np.ndarray = _to_ndarray(token_ids)
        if dllm_mask is None:
            dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
        dllm_mask: np.ndarray = _to_ndarray(dllm_mask)

        if mode == UpdateTokenMode.INPUTS:
            self._update_token_ids_inputs(token_ids, dllm_mask)
        elif mode == UpdateTokenMode.PREFILL:
            self._update_token_ids_prefill(token_ids, dllm_mask)
        else:
            self._update_token_ids_decode(token_ids, dllm_mask)

        if model_meta is not None:
            self.model_meta = model_meta

    def set_step(self, step: int):
        """Set step, turning CACHED mask states from ``step`` onward back into UNMASKED."""
        # reset dllm mask
        start = min(step, self.num_history_ids)
        end = self.num_history_ids
        if end > start:
            # NOTE(review): slices to the end of the mask (not ``[start:end]``) and
            # relies on the slice being a writable view of the history buffer.
            to_change_mask = self.history_dllm_mask[start:]
            to_change_mask[to_change_mask == DLLM_CACHED] = DLLM_UNMASKED
        super().set_step(step)


class DLLMSequenceStrategy(SequenceStrategy):
    """Sequence strategy that creates and updates DLLM scheduler sequences."""

    def __init__(self, block_size: int, dllm_mask_token: int) -> None:
        # tokens per DLLM block
        self.block_size = block_size
        # token id written into masked positions
        self.dllm_mask_token = dllm_mask_token

    def make_sequence(self,
                      seq_id: int,
                      session: 'SchedulerSession',
                      sampling_param: 'SamplingParam' = None,
                      adapter_name: str = None,
                      migration_request: Optional[MigrationRequest] = None,
                      resp_cache: bool = False,
                      preserve_cache: bool = False) -> 'SchedulerSequenceDLLM':
        """Build a :class:`SchedulerSequenceDLLM` for ``session``."""
        seq_kwargs = dict(
            seq_id=seq_id,
            session=session,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
            migration_request=migration_request,
            resp_cache=resp_cache,
            preserve_cache=preserve_cache,
        )
        return SchedulerSequenceDLLM(**seq_kwargs)

    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',
                       delta: 'ModelInputsDelta', **kwargs) -> None:
        """Write one step of batched DLLM outputs back into ``running``.

        Tokens and masks are split into one row per sequence. Sequences that are
        no longer RUNNING are skipped. Stopped sequences are trimmed at their
        stop position and finished.
        """
        num_seqs = len(running)
        stop_flags = batched_outputs.stopped.tolist()
        metas = batched_outputs.model_metas
        if metas is None:
            metas = [None] * num_seqs
        stop_positions = batched_outputs.stop_pos.tolist()

        # phase comes from the full inputs when available, otherwise from the delta
        is_decoding = delta.is_decoding if model_inputs is None else model_inputs.is_decoding
        mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL

        token_rows = batched_outputs.next_token_ids.view(num_seqs, -1).numpy()
        mask_rows = batched_outputs.extra_outputs.dllm_mask.view(num_seqs, -1).numpy()

        for idx, seq in enumerate(running):
            tokens = token_rows[idx]
            is_stop = stop_flags[idx]
            meta = metas[idx]
            mask = mask_rows[idx]
            if seq.status != MessageStatus.RUNNING:
                continue

            seq.update_token_ids(tokens, dllm_mask=mask, model_meta=meta, mode=mode)
            if not is_stop:
                continue
            seq.set_stop_pos(stop_positions[idx])
            seq.state.finish()


================================================
FILE: lmdeploy/pytorch/strategies/dllm/unmasking.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.profiler import record_function

from lmdeploy.pytorch import consts
from lmdeploy.pytorch.config import DLLMConfig, UnmaskingStrategy

DLLM_MASKED = consts.DLLM_MASKED
DLLM_UNMASKED = consts.DLLM_UNMASKED
DLLM_CACHED = consts.DLLM_CACHED


class UnmaskingProcessor:
    """Decide which masked DLLM tokens become unmasked after a sampling step.

    Masks are flat tensors whose length is a multiple of
    ``dllm_config.block_length``. They are processed per block as
    ``[num_blocks, block_length]``.
    """

    def __init__(self, dllm_config: 'DLLMConfig'):
        self.dllm_config = dllm_config

    def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor):
        """Return the softmax probability of each id in ``token_ids``, flattened."""
        scores = logits.softmax(dim=-1)
        scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten()
        return scores

    def _get_denoise_num(self):
        """Return how many tokens per block to unmask in one step.

        Computed as ``block_length // denoising_steps``, clamped to
        ``[1, block_length]``. A ``None`` ``denoising_steps`` defaults to
        ``block_length``, i.e. one token per step.
        """
        block_size = self.dllm_config.block_length
        denoising_steps = self.dllm_config.denoising_steps
        if denoising_steps is None:
            denoising_steps = block_size
        # use the resolved value: dividing by the raw config field fails on None
        num = block_size // denoising_steps
        num = max(1, min(num, block_size))
        return num

    def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
        """Unmask a fixed number of highest-probability masked tokens per block."""
        block_size = self.dllm_config.block_length
        topk = self._get_denoise_num()
        scores = self._get_scores(logits, token_ids)
        is_masked = dllm_mask == DLLM_MASKED
        # non-masked positions score 0 so top-k favours masked ones
        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))

        scores = scores.view(-1, block_size)
        dllm_mask = dllm_mask.view(-1, block_size)
        _, indices = scores.topk(topk, dim=-1)
        dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED)

        # positions that were not masked keep their original state
        is_masked = is_masked.view_as(dllm_mask)
        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)
        return dllm_mask.flatten()

    def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
        """Unmask masked tokens whose probability reaches ``confidence_threshold``."""
        block_size = self.dllm_config.block_length
        threshold = self.dllm_config.confidence_threshold
        scores = self._get_scores(logits, token_ids)
        is_masked = dllm_mask == DLLM_MASKED
        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))

        scores = scores.view(-1, block_size)
        dllm_mask = dllm_mask.view(-1, block_size)
        # force the best position of each block to pass the threshold, so a block
        # with masked tokens unmasks at least one per step
        _, indices = scores.topk(1, dim=-1)
        scores = scores.scatter(-1, indices, threshold)

        is_masked = is_masked.view_as(dllm_mask)
        is_masked &= scores >= threshold
        # NOTE: writes in place into the (viewed) input mask
        dllm_mask[is_masked] = DLLM_UNMASKED
        return dllm_mask.flatten()

    def sequential(self, dllm_mask: torch.Tensor):
        """Unmask the next ``denoise_num`` positions from the first masked one, per block."""
        block_size = self.dllm_config.block_length
        denoise_num = self._get_denoise_num()
        dllm_mask = dllm_mask.view(-1, block_size)
        is_masked = dllm_mask == DLLM_MASKED

        # get indices: first masked position plus a contiguous range
        indices = is_masked.int().argmax(dim=1)
        ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype)
        indices = indices[:, None] + ranges[None, :]
        # wrapped indices land before the first masked position and are
        # restored by the torch.where below
        indices = indices % block_size

        dllm_unmasked = dllm_mask.clone()
        dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED)
        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)

        return dllm_mask.flatten()

    @record_function('unmasking')
    def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
        """Apply the configured unmasking strategy.

        Returns:
            tuple: ``(dllm_mask, token_ids)``. Non-masked positions keep their id
            from ``input_ids``; masked positions take the sampled ``token_ids``.
            With no strategy configured, both inputs are returned unchanged.

        Raises:
            RuntimeError: for an unsupported strategy.
        """
        strategy = self.dllm_config.unmasking_strategy
        if strategy is None:
            # callers unpack two values, so always return a pair
            return dllm_mask, token_ids

        # reshape to [num_blocks, block_size]
        block_size = self.dllm_config.block_length
        dllm_mask = dllm_mask.unflatten(0, (-1, block_size))

        is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1)
        first_mask = dllm_mask[:, 0]

        # unmasked to cache: blocks fully unmasked in the previous step.
        # NOTE: for a contiguous input, unflatten returns a view, so this also
        # updates the caller's tensor.
        is_block_unmasked = is_same & (first_mask == DLLM_UNMASKED)
        dllm_mask[is_block_unmasked] = DLLM_CACHED

        dllm_mask = dllm_mask.flatten()
        token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids)
        if strategy == UnmaskingStrategy.LOW_CONFIDENCE_STATIC:
            dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask)
        elif strategy == UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC:
            dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask)
        elif strategy == UnmaskingStrategy.SEQUENTIAL:
            dllm_mask = self.sequential(dllm_mask)
        else:
            raise RuntimeError(f'strategy {strategy} not supported.')

        return dllm_mask, token_ids


================================================
FILE: lmdeploy/pytorch/third_party/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/third_party/deep_gemm/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

try:
    import deep_gemm  # noqa: F401
except ImportError:
    logger.exception('DeepGemm is not installed. Please install https://github.com/deepseek-ai/DeepGEMM.')
    # Every import below needs deep_gemm. Re-raise the original error here
    # instead of failing again on the next line with a less informative one.
    raise

from deep_gemm import ceil_div, get_m_alignment_for_contiguous_layout  # noqa: F401, E402

try:
    from deep_gemm import fp8_gemm_nt
except Exception:
    # Older DeepGemm: expose `fp8_gemm_nt` on top of `gemm_fp8_fp8_bf16_nt`.
    from deep_gemm.jit_kernels.gemm import gemm_fp8_fp8_bf16_nt

    @contextmanager
    def _log_jit_build(M: int, N: int, K: int):
        """Temporarily patch ``RuntimeCache`` to warn when a lookup returns ``None``.

        The warning reports the GEMM shape. The original lookup method is
        always restored on exit, even if the wrapped call raises.
        """
        from deep_gemm.jit.runtime import RuntimeCache

        # versions differ in whether the lookup is `get` or `__getitem__`
        if hasattr(RuntimeCache, 'get'):
            func_name = 'get'
        else:
            func_name = '__getitem__'
        origin_func = getattr(RuntimeCache, func_name)

        def __patched_func(self, *args, **kwargs):
            ret = origin_func(self, *args, **kwargs)
            if ret is None:
                logger.warning(f'DeepGemm build : M={M}, N={N}, K={K}. Please waiting.')
            return ret

        setattr(RuntimeCache, func_name, __patched_func)
        try:
            yield
        finally:
            # restore unconditionally so a failing GEMM never leaves the
            # class-level patch installed
            setattr(RuntimeCache, func_name, origin_func)

    def fp8_gemm_nt(a, b, d, c, recipe=None, compiled_dim='nk', disable_ue8m0_cast=False):
        """Compatibility wrapper computing ``d`` from ``a`` and ``b`` via the legacy kernel.

        ``a`` and ``b`` are indexed with ``[0]`` for their shapes (presumably
        ``(tensor, scale)`` pairs). ``c``, ``recipe``, ``compiled_dim`` and
        ``disable_ue8m0_cast`` are accepted for signature compatibility and
        ignored.
        """
        M, K = a[0].shape
        N, _ = b[0].shape
        with _log_jit_build(M, N, K):
            gemm_fp8_fp8_bf16_nt(a, b, d)


# Fallback for DeepGemm versions without `m_grouped_fp8_gemm_nt_contiguous`:
# the same name is defined on top of `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous`.
try:
    from deep_gemm import m_grouped_fp8_gemm_nt_contiguous
except Exception:
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous

    def m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, recipe=None, compiled_dims='nk', disable_ue8m0_cast=False):
        """Forward to the legacy contiguous grouped FP8 GEMM.

        Only the default ``recipe``/``compiled_dims``/``disable_ue8m0_cast``
        values are accepted, because none of them are forwarded to the legacy
        call.
        """
        assert recipe is None
        assert compiled_dims == 'nk'
        assert disable_ue8m0_cast is False
        return m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(a, b, d, m_indices)


# Fallback for DeepGemm versions without `m_grouped_fp8_gemm_nt_masked`:
# the same name is defined on top of `m_grouped_gemm_fp8_fp8_bf16_nt_masked`.
try:
    from deep_gemm import m_grouped_fp8_gemm_nt_masked
except Exception:
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked

    def m_grouped_fp8_gemm_nt_masked(a,
                                     b,
                                     d,
                                     masked_m,
                                     expected_m,
                                     recipe=None,
                                     compiled_dims='nk',
                                     disable_ue8m0_cast=False):
        """Forward to the legacy masked grouped FP8 GEMM.

        Only the default ``recipe``/``compiled_dims``/``disable_ue8m0_cast``
        values are accepted, because none of them are forwarded to the legacy
        call.
        """
        assert recipe is None
        assert compiled_dims == 'nk'
        assert disable_ue8m0_cast is False
        return m_grouped_gemm_fp8_fp8_bf16_nt_masked(a, b, d, masked_m, expected_m)


# Fallback for DeepGemm versions that only provide
# `get_col_major_tma_aligned_tensor`: alias it under the newer name.
try:
    from deep_gemm import get_mn_major_tma_aligned_tensor
except Exception:
    from deep_gemm import get_col_major_tma_aligned_tensor

    def get_mn_major_tma_aligned_tensor(x):
        # pass-through to the legacy helper
        return get_col_major_tma_aligned_tensor(x)


# Optional bf16 grouped GEMM kernels. They are re-exported only when the
# installed DeepGemm provides them; otherwise a warning is logged and the
# names are left undefined in this module.
try:
    from deep_gemm import m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_masked  # noqa: F401
except Exception:
    logger.warning('DeepGemm bf16 grouped gemm kernels are not found. '
                   'Please upgrade DeepGemm to the latest version.')


================================================
FILE: lmdeploy/pytorch/third_party/flash_attn_interface.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import functools

from flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn_interface import flash_attn_with_kvcache as _flash_attn_with_kvcache


@functools.wraps(_flash_attn_varlen_func)
def flash_attn_varlen_func(*args, **kwargs):
    # Normalize the return value across flash_attn_interface versions: the old
    # API returns a tuple whose first item is the attention output.
    # (functools.wraps copies the wrapped docstring, so notes live in comments.)
    result = _flash_attn_varlen_func(*args, **kwargs)
    return result[0] if isinstance(result, tuple) else result


@functools.wraps(_flash_attn_with_kvcache)
def flash_attn_with_kvcache(*args, **kwargs):
    # Thin pass-through kept alongside flash_attn_varlen_func; the underlying
    # result is returned unchanged.
    return _flash_attn_with_kvcache(*args, **kwargs)


================================================
FILE: lmdeploy/pytorch/tools/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import Timer  # noqa: F401


================================================
FILE: lmdeploy/pytorch/tools/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import List


class Timer:
    """Debug timer measuring CPU wall-clock or CUDA-event time in milliseconds."""

    def __init__(self):
        self.duration = None  # elapsed milliseconds, set by toc()
        self.timer_type = None  # 'cpu' or 'cuda', set when the timer starts

    def tic_cpu(self):
        """Start timing with ``time.perf_counter``."""
        self.timer_type = 'cpu'
        import time
        self._start = time.perf_counter()

    def toc_cpu(self):
        """Stop a CPU timer and store the elapsed time in ms."""
        assert self.timer_type == 'cpu'
        import time
        self._end = time.perf_counter()
        self.duration = (self._end - self._start) * 1000
        return self

    def tic_cuda(self):
        """Start timing by recording a CUDA start event."""
        self.timer_type = 'cuda'
        import torch
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._start.record()

    def toc_cuda(self):
        """Stop a CUDA timer; synchronizes the device before reading the elapsed ms."""
        assert self.timer_type == 'cuda'
        import torch
        self._end.record()
        torch.cuda.synchronize()
        self.duration = self._start.elapsed_time(self._end)
        return self

    @classmethod
    def tic(cls, is_cuda: bool = False) -> 'Timer':
        """Create and start a timer; subclasses get an instance of their own type."""
        timer = cls()
        if is_cuda:
            timer.tic_cuda()
        else:
            timer.tic_cpu()
        return timer

    def toc(self):
        """Stop the timer using the backend chosen at start; returns ``self``.

        Raises:
            RuntimeError: if the timer was never started.
        """
        if self.timer_type == 'cpu':
            return self.toc_cpu()
        elif self.timer_type == 'cuda':
            return self.toc_cuda()
        else:
            raise RuntimeError(f'Unknown timer_type: {self.timer_type}')

    @classmethod
    @contextmanager
    def timing(cls, is_cuda: bool = False) -> 'Timer':
        """Context manager yielding a started timer, stopped when the body completes."""
        timer = cls.tic(is_cuda=is_cuda)
        yield timer
        timer.toc()

    @staticmethod
    def format_duration(duration: float, acc: int = 3):
        """Format a millisecond duration as μs, ms or s with ``acc`` decimals."""
        unit = 'ms'
        if duration < 1:
            duration *= 1000
            unit = 'μs'
        elif duration > 1000:
            duration /= 1000
            unit = 's'

        return f'{duration:.{acc}f} {unit}'

    @staticmethod
    def format_flops(flops: float, acc: int = 3):
        """Format a Flop/s value using 1024-based K/M/G/T prefixes."""
        unit = ''
        if flops > (1 << 40):
            flops /= (1 << 40)
            unit = 'T'
        elif flops > (1 << 30):
            flops /= (1 << 30)
            unit = 'G'
        elif flops > (1 << 20):
            flops /= (1 << 20)
            unit = 'M'
        elif flops > (1 << 10):
            flops /= (1 << 10)
            unit = 'K'
        return f'{flops:.{acc}f} {unit}Flop/s'

    @staticmethod
    def formatted_print(out_info: dict, title: str = None):
        """Print ``out_info`` as aligned ``key : value`` lines (widths capped at 10)."""
        if title is not None:
            print(title)
        if not out_info:
            # nothing to align; max() over an empty dict would raise ValueError
            return

        max_key_len = min(10, max(len(k) for k in out_info.keys()))
        max_val_len = min(10, max(len(v) for v in out_info.values()))

        for k, v in out_info.items():
            print(f'{k:>{max_key_len}} : {v:>{max_val_len}}')

    def print(self, flop: int = None, title: str = None):
        """Print the measured duration and, if ``flop`` is given, the achieved Flop/s."""
        if self.duration is None:
            print('Please run Timer.tic() first.')
            return

        out_info = dict()

        formated_dur = self.format_duration(self.duration)
        out_info['Duration'] = f'{formated_dur}'

        if flop is not None:
            # duration is in ms, so scale to per-second
            flops = flop / self.duration * 1000
            formated_flops = self.format_flops(flops)
            out_info['Flops'] = f'{formated_flops}'

        self.formatted_print(out_info, title)

    def toc_print(self, flop: int = None, title: str = None):
        """Stop the timer and print the result."""
        return self.toc().print(flop=flop, title=title)


def visualize_pipe_out(outputs, enable_meta: bool = True):
    """Pretty-print pipeline outputs to the terminal.

    Args:
        outputs: a single ``Response``, ``None``, or an iterable whose items
            are ``Response`` objects or ``None``. ``None`` items print only
            their header.
        enable_meta (bool): when True, print token counts, the finish reason
            and, when present, the shapes/sizes of ``routed_experts``,
            ``logits`` and ``logprobs`` before each text.
    """
    import os

    from lmdeploy.messages import Response

    try:
        from termcolor import colored
    except ImportError:

        # termcolor is optional: fall back to uncoloured text
        def colored(text, color=None, on_color=None, attrs=None):
            return text

    if isinstance(outputs, Response) or outputs is None:
        outputs = [outputs]
    else:
        # materialize so generators / arbitrary iterables support len()
        outputs = list(outputs)

    try:
        term_size = os.get_terminal_size().columns
    except Exception:
        # e.g. stdout is not attached to a terminal (pipes, notebooks)
        term_size = 100

    border_color = 'cyan'
    meta_color = 'light_grey'
    number_color = 'green'
    # column at which every metadata value starts, so all values line up
    meta_label_width = 20

    def _print_title(title: str, color: str = border_color):
        title_text = f' {title} '
        print(colored(f'【{title_text}】', color, attrs=['bold']))

    def _print_section(title: str, content: str, color: str = border_color):
        """Print a coloured section title followed by its content."""
        _print_title(title, color)
        print(content)

    def _meta_line(label: str, value, color: str = number_color) -> str:
        """Format one ``• label:  value`` metadata row with aligned values."""
        text = f'• {label}:'
        # pad outside colored() so ANSI escape codes do not affect alignment
        padding = ' ' * max(1, meta_label_width - len(text))
        return f'{colored(text, meta_color)}{padding}{colored(value, color)}'

    def _print_meta(out: Response):
        """Print the metadata table for a single output."""
        finish_color = 'yellow' if out.finish_reason == 'stop' else 'red'
        meta_content = [
            _meta_line('Input Tokens', out.input_token_len),
            _meta_line('Generated Tokens', out.generate_token_len),
            _meta_line('Finish Reason', out.finish_reason, finish_color),
        ]
        if out.routed_experts is not None:
            meta_content.append(_meta_line('Routed Experts', tuple(out.routed_experts.shape)))
        if out.logits is not None:
            meta_content.append(_meta_line('Logits Shape', tuple(out.logits.shape)))
        if out.logprobs is not None:
            meta_content.append(_meta_line('Logprobs', len(out.logprobs)))

        lines = '\n'.join(meta_content) + '\n'
        _print_section('METADATA', lines, border_color)

    # Main loop
    print(colored('━' * term_size, border_color))

    for idx, out in enumerate(outputs):
        header = f'OUTPUT [{idx + 1}/{len(outputs)}]'
        header_formatted = colored(f'✦ {header}', 'light_magenta', attrs=['bold'])
        print(header_formatted)
        print()

        if out is not None:
            if enable_meta:
                _print_meta(out)

            _print_section('TEXT', out.text, border_color)

        if idx < len(outputs) - 1:  # thin separator between outputs
            print(colored('─' * term_size, border_color, attrs=['dark']))
        else:
            print(colored('━' * term_size, border_color))


def visualize_chat_completions(outputs, enable_meta: bool = True):
    """Visualize OpenAI ``ChatCompletion`` objects via ``visualize_pipe_out``.

    Args:
        outputs: a single ``ChatCompletion`` or an iterable of them. Only the
            first choice of each completion is shown.
        enable_meta (bool): forwarded to ``visualize_pipe_out``.
    """
    from openai.types.chat import ChatCompletion

    from lmdeploy.messages import Response
    if isinstance(outputs, ChatCompletion):
        outputs = [outputs]

    resps = []
    for out in outputs:
        assert isinstance(out, ChatCompletion)
        choice = out.choices[0]
        # ``usage`` is optional in the OpenAI schema; report 0 tokens if absent
        usage = out.usage
        resp = Response(text=choice.message.content or '',
                        input_token_len=usage.prompt_tokens if usage is not None else 0,
                        generate_token_len=usage.completion_tokens if usage is not None else 0,
                        finish_reason=choice.finish_reason)
        resps.append(resp)

    return visualize_pipe_out(resps, enable_meta=enable_meta)


sources = None


def dump_tilelang_source(kernel, path: str = 'sources/tvm_kernels.cu'):
    """Write ``kernel.get_kernel_source()`` to ``path`` once per process.

    After the first successful write the source is cached in the module-level
    ``sources``, and later calls return immediately.

    Args:
        kernel: object exposing ``get_kernel_source()`` (presumably a compiled
            TileLang kernel — confirm against callers).
        path (str): destination file; missing parent directories are created.
    """
    import os

    global sources
    if sources is not None:
        return
    kernel_source = kernel.get_kernel_source()
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w') as f:
        f.write(kernel_source)
    # mark as dumped only after the write succeeded so failures can be retried
    sources = kernel_source


================================================
FILE: lmdeploy/pytorch/transformers/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache

from transformers import AutoConfig

from lmdeploy.utils import get_logger


logger = get_logger('lmdeploy')


@lru_cache()
def register_config(model_type: str):
    """Register a lmdeploy-provided config class with ``AutoConfig``.

    Only ``'deepseek_v32'`` is handled; any other ``model_type`` just emits
    a debug log. The result is cached, so each ``model_type`` is processed at
    most once per process.
    """
    if model_type != 'deepseek_v32':
        logger.debug(f'Can not register config for model_type: {model_type}')
        return
    from lmdeploy.pytorch.transformers.configuration_deepseek_v32 import DeepseekV32Config
    AutoConfig.register(DeepseekV32Config.model_type, DeepseekV32Config)


def config_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
    """Load a HF config, registering lmdeploy config classes on demand.

    ``AutoConfig.from_pretrained`` is tried first. If it raises ``ValueError``
    (some models such as dsv32 provide no auto map for their config), the
    config dict is read with ``PretrainedConfig.get_config_dict`` to find the
    ``model_type``. ``register_config`` is called for that type, and loading
    is retried once. Any error from the retry propagates to the caller.
    """
    try:
        return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
    except ValueError as err:
        logger.debug(f'AutoConfig.from_pretrained failed: {err}, try register config manually.')
        from transformers import PretrainedConfig

        # get_config_dict is called without trust_remote_code; it is restored
        # afterwards for the retry below
        remote_code = kwargs.pop('trust_remote_code', None)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if remote_code is not None:
            kwargs['trust_remote_code'] = remote_code
        register_config(config_dict.get('model_type', None))

    return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)


================================================
FILE: lmdeploy/pytorch/transformers/configuration_deepseek_v32.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config


class DeepseekV32Config(DeepseekV3Config):
    """``DeepseekV3Config`` plus three ``index_*`` hyper-parameters.

    The ``index_*`` values are presumably for the V3.2 attention indexer —
    confirm against the model implementation.

    Args:
        index_head_dim (int): stored as ``self.index_head_dim``.
        index_n_heads (int): stored as ``self.index_n_heads``.
        index_topk (int): stored as ``self.index_topk``.
        **kwargs: forwarded unchanged to ``DeepseekV3Config``.
    """
    model_type = 'deepseek_v32'

    def __init__(self, index_head_dim=128, index_n_heads=64, index_topk=2048, **kwargs):
        super().__init__(**kwargs)
        # assigned after the base initializer has run
        self.index_head_dim, self.index_n_heads, self.index_topk = (index_head_dim, index_n_heads, index_topk)


================================================
FILE: lmdeploy/pytorch/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import asyncio
import inspect
from contextlib import contextmanager
from inspect import Parameter, Signature
from typing import Dict, Generic, Optional, Sequence, TypeVar

import psutil

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def get_gpu_memory(device_id: int = None) -> tuple[int, int]:
    """Return ``(free, total)`` physical memory of a CUDA device in bytes.

    Args:
        device_id (int): CUDA device index; defaults to the current device.

    Returns:
        tuple[int, int]: the result of ``torch.cuda.mem_get_info``.
    """
    import torch
    if device_id is None:
        device_id = torch.cuda.current_device()
    return torch.cuda.mem_get_info(device_id)


def get_cpu_memory() -> int:
    """Total physical memory of this node in bytes, as reported by psutil."""
    vm_stats = psutil.virtual_memory()
    return vm_stats.total


def bind_sigature(input_names: Sequence[str], args: Sequence, kwargs: Dict):
    """Bind ``args`` and ``kwargs`` to the parameter names ``input_names``.

    Each name becomes a positional-or-keyword parameter without a default, so
    arguments are matched exactly as in a call to ``def f(<names>)``.

    Returns:
        dict: mapping of parameter name to bound value.

    Raises:
        TypeError: if the arguments do not fit (too many, duplicated or
            missing values).
        ValueError: if ``input_names`` has duplicates or invalid identifiers.
    """
    kind = inspect.Parameter.POSITIONAL_OR_KEYWORD

    sig = Signature([Parameter(name, kind) for name in input_names])
    bind = sig.bind(*args, **kwargs)
    return bind.arguments


def singleton(cls):
    """Class decorator that lazily creates and then reuses one instance.

    The first call constructs ``cls`` with the given arguments. Every later
    call returns that same object and ignores its arguments.
    """
    import multiprocessing as mp

    from lmdeploy.utils import get_logger
    logger = get_logger('lmdeploy')
    cache = {}

    def get_instance(*args, **kwargs):
        if cls in cache:
            return cache[cls]
        pid = mp.current_process().pid
        logger.debug(f'pid:{pid} - Creating instance of singleton class {cls.__name__}')
        cache[cls] = cls(*args, **kwargs)
        return cache[cls]

    return get_instance


T = TypeVar('T')


class CtxMgrBase(Generic[T]):
    """Hold one "current" context value that can be swapped temporarily.

    Args:
        default: initial context value (``None`` if omitted).
    """

    def __init__(self, default: Optional[T] = None):
        self._context = default

    def current_context(self) -> Optional[T]:
        """Return the active context value."""
        return self._context

    def set_context(self, context: Optional[T]):
        """Replace the active context value."""
        self._context = context

    @contextmanager
    def context(self, context: T):
        """Make ``context`` active for the duration of a ``with`` block.

        Yields this manager. The previously active value is restored on
        exit, even when the body raises.
        """
        previous = self.current_context()
        self.set_context(context)
        try:
            yield self
        finally:
            self.set_context(previous)


# from vllm
def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
    """Register remote-code HF config classes for by-value serialization.

    With trust_remote_code, the config class is typically an instance of a
    custom class imported from the HF modules cache.

    The class will not be
    importable in spawned workers by default (and won't exist at all on
    other nodes), which breaks serialization of the config.
    In this function we tell the cloudpickle serialization library to pass
    instances of these generated classes by value instead of by reference,
    i.e. the class definition is serialized along with its data so that the
    class module does not need to be importable on the receiving end. This
    registration only works if the modules cache has already been
    initialized. Registration failures are logged as a warning, not raised.
    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
    """  # noqa: E501
    if not trust_remote_code:
        return

    try:
        import transformers_modules
    except ImportError:
        logger.debug('Could not import transformers_modules used for remote'
                     ' code. If remote code is not needed remove'
                     ' `--trust-remote-code`.')
        return

    try:
        import cloudpickle
        cloudpickle.register_pickle_by_value(transformers_modules)

        # ray vendors its own version of cloudpickle. A missing ray must only
        # skip this step, not the multiprocessing registration below.
        try:
            import ray
        except ImportError:
            ray = None
        if ray is not None:
            ray.cloudpickle.register_pickle_by_value(transformers_modules)

        # multiprocessing uses pickle to serialize arguments when using spawn
        # Here we get pickle to use cloudpickle to serialize ModelConfig objects
        # that contain instances of the custom config class to avoid
        # serialization problems if the generated module (and model) has a `.`
        # in its name
        import multiprocessing
        import pickle

        from lmdeploy.pytorch.config import ModelConfig

        def _reduce_modelconfig(mc: ModelConfig):
            return (pickle.loads, (cloudpickle.dumps(mc), ))

        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)

    except Exception as e:
        logger.warning(
            'Unable to register remote classes used by'
            ' trust_remote_code with by-value serialization. This may'
            ' lead to a later error. If remote code is not needed'
            ' remove `--trust-remote-code`',
            exc_info=e)


def monkey_patch_hf_modules_cache():
    """Point HF_MODULES_CACHE at a per-process directory.

    When several processes load ``remote_code`` components (tokenizer,
    config, ...) concurrently, they would otherwise read and write the same
    HF_MODULES_CACHE directory and can conflict, e.g. on multi-GPU machines.
    The directory becomes ``<HF_HOME>/modules_pid_<pid>``.

    modified from: https://github.com/InternLM/xtuner/blob/main/xtuner/v1/utils/misc.py
    """
    import os

    import transformers
    from huggingface_hub import constants

    pid = os.getpid()
    modules_cache = os.path.join(constants.HF_HOME, f'modules_pid_{pid}')

    # environment variable for anything that reads it later
    os.environ['HF_MODULES_CACHE'] = modules_cache

    transformers.utils.hub.HF_MODULES_CACHE = modules_cache
    # dynamic_module_utils bound its own HF_MODULES_CACHE name at import time
    # (copied from transformers.utils), so it has to be overwritten separately.
    transformers.dynamic_module_utils.HF_MODULES_CACHE = modules_cache
    transformers.utils.HF_MODULES_CACHE = modules_cache

    logger.info(f'Set HF_MODULES_CACHE to {modules_cache} for current process {pid}')


async def wait_for_async_tasks(tasks: Sequence[asyncio.Task],
                               cancel_pending: bool = True,
                               ignore_cancellederror: bool = True):
    """Wait until all ``tasks`` finish or the first one raises.

    Args:
        tasks: the ``asyncio.Task`` objects to wait on.
        cancel_pending (bool): cancel tasks still running once waiting ends.
        ignore_cancellederror (bool): skip finished tasks whose exception is
            a ``CancelledError`` instead of re-raising it.

    Returns:
        ``(done, pending)`` as produced by ``asyncio.wait``, or ``([], [])``
        for empty input.

    Raises:
        ValueError: if an item is not an ``asyncio.Task``.
        Exception: the first exception found among the finished tasks.
        asyncio.CancelledError: if this coroutine is cancelled; all unfinished
            tasks are cancelled before re-raising.
    """
    if not tasks:
        return [], []

    if any(not isinstance(t, asyncio.Task) for t in tasks):
        raise ValueError('All inputs must be asyncio.Task instances.')

    try:
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

        if cancel_pending:
            for leftover in pending:
                leftover.cancel()

        # surface the first failure among the finished tasks
        for finished in done:
            if finished.cancelled():
                continue
            error = finished.exception()
            if not error:
                continue
            if ignore_cancellederror and isinstance(error, asyncio.CancelledError):
                logger.debug(f'Task <{finished.get_name()}> cancelled.')
                continue
            raise error from None
    except asyncio.CancelledError:
        # we were cancelled ourselves: propagate cancellation to the children
        for t in tasks:
            if not t.done():
                t.cancel()
        raise

    return done, pending


async def cancel_async_tasks(tasks: Sequence[asyncio.Task]):
    """Cancel the unfinished tasks and wait for them to settle.

    Args:
        tasks: a single task or a sequence of tasks. Already-finished tasks
            are ignored.

    Returns:
        list: ``asyncio.gather(..., return_exceptions=True)`` results for the
        tasks that were still running (typically ``CancelledError`` objects).
    """
    candidates = [tasks] if isinstance(tasks, asyncio.Task) else tasks
    running = [t for t in candidates if not t.done()]
    for t in running:
        t.cancel()
    return await asyncio.gather(*running, return_exceptions=True)


================================================
FILE: lmdeploy/pytorch/weight_loader/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/pytorch/weight_loader/model_weight_loader.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import json
import os.path as osp

import numpy as np
import torch
from safetensors.torch import safe_open
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.distributed import get_world_rank
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs):
    """Copy ``loaded_weight`` into ``param``.

    If ``param`` has a ``weight_loader`` attribute, it is called with
    ``(param, loaded_weight, **kwargs)``. Otherwise ``default_weight_loader``
    is used, and no extra ``kwargs`` are allowed.
    """
    if not hasattr(param, 'weight_loader'):
        assert len(kwargs) == 0
        default_weight_loader(param, loaded_weight)
        return
    param.weight_loader(param, loaded_weight, **kwargs)


def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param`` in place.

    Single-element tensors are copied by value regardless of shape. Otherwise
    the sizes must match exactly (checked with ``assert``).
    """
    if param.numel() == 1 and loaded_weight.numel() == 1:
        param.data.fill_(loaded_weight.item())
        return

    assert param.size() == loaded_weight.size(), (f'Attempted to load weight ({loaded_weight.size()}) '
                                                  f'into parameter ({param.size()})')
    param.data.copy_(loaded_weight)


def _get_weight_type(model_path: str, use_safetensors: bool = None):
    """Detect the checkpoint format in ``model_path`` and whether it is sharded.

    Candidate files are probed in priority order: single safetensors file,
    safetensors index, single pytorch file, pytorch index. Safetensors
    candidates are skipped when ``use_safetensors`` is exactly ``False``.

    Returns:
        tuple: ``(weight_type, is_sharded)`` with ``weight_type`` being
        ``'safetensors'`` or ``'pytorch'``.

    Raises:
        RuntimeError: if none of the candidate files exists.
    """
    candidates = []
    if use_safetensors is not False:
        candidates.append((SAFE_WEIGHTS_NAME, 'safetensors', False))
        candidates.append((SAFE_WEIGHTS_INDEX_NAME, 'safetensors', True))
    candidates.append((WEIGHTS_NAME, 'pytorch', False))
    candidates.append((WEIGHTS_INDEX_NAME, 'pytorch', True))

    for file_name, weight_type, is_sharded in candidates:
        if osp.isfile(osp.join(model_path, file_name)):
            return (weight_type, is_sharded)
    raise RuntimeError('Unknown weight type.')


def _get_weight_map(model_path: str, weight_type: str):
    """Read the ``weight_map`` entry from the checkpoint index file.

    Raises:
        RuntimeError: if ``weight_type`` is neither ``'safetensors'`` nor
            ``'pytorch'``.
    """
    index_name = (SAFE_WEIGHTS_INDEX_NAME if weight_type == 'safetensors' else
                  WEIGHTS_INDEX_NAME if weight_type == 'pytorch' else None)
    if index_name is None:
        raise RuntimeError(f'Unsupported weight type: {weight_type}.')

    with open(osp.join(model_path, index_name), mode='r', encoding='utf-8') as f:
        index = json.load(f)
    return index['weight_map']


def _get_weight_path(model_path: str, weight_type: str):
    """Return ``(full_path, file_name)`` of a single-file checkpoint.

    Raises:
        RuntimeError: if ``weight_type`` is neither ``'safetensors'`` nor
            ``'pytorch'``.
    """
    weight_name = (SAFE_WEIGHTS_NAME if weight_type == 'safetensors' else
                   WEIGHTS_NAME if weight_type == 'pytorch' else None)
    if weight_name is None:
        raise RuntimeError('Unknown weight type.')
    return osp.join(model_path, weight_name), weight_name


def _get_safetensors_weights_iterator(file: str, prefix: str):
    """Lazily yield ``(name, tensor)`` pairs from a safetensors file.

    Tensors are read one at a time. ``prefix``, when not ``None``, is
    prepended to each yielded name.
    """
    with safe_open(file, framework='pt') as reader:
        for key in reader.keys():
            tensor = reader.get_tensor(key)
            yield (key if prefix is None else f'{prefix}{key}'), tensor


def _get_pt_weights_iterator(file: str, prefix: str):
    """Yield ``(name, tensor)`` pairs from a pytorch checkpoint file.

    The whole state dict is loaded onto CPU with ``weights_only=True``.
    ``prefix``, when not ``None``, is prepended to each name. The state is
    released and the CUDA cache emptied once iteration ends or the generator
    is closed.
    """
    state = torch.load(file, weights_only=True, map_location='cpu')
    try:
        for key, tensor in state.items():
            yield (key if prefix is None else f'{prefix}{key}'), tensor
    finally:
        del state
        torch.cuda.empty_cache()


class ModelWeightLoader:
    """Model weight loader for sharded weights.

    Detects the checkpoint format under ``model_path`` and feeds the tensors
    of each checkpoint file into ``model.load_weights``.

    Args:
        model_path (str): directory holding the checkpoint files.
        prefix (str): optional string prepended to every weight name.
    """

    def __init__(self, model_path: str, prefix: str = None):
        self.model_path = model_path
        weight_type, is_sharded = _get_weight_type(model_path)

        self._weight_type = weight_type
        self._is_sharded = is_sharded
        self._prefix = prefix
        self._shard_paths = self._get_shard_paths(model_path, is_sharded, weight_type)

    @staticmethod
    def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str):
        """Return a tuple with the path of every checkpoint file to read."""
        if is_sharded:
            weight_map = _get_weight_map(model_path, weight_type)
            # many weights map to the same shard file, so dedupe before joining
            paths = set(weight_map.values())
            return tuple(osp.join(model_path, path) for path in paths)
        else:
            path, _ = _get_weight_path(model_path, weight_type)
            return (path, )

    def _get_weights_iterator(self, path: str):
        """Return a ``(name, tensor)`` iterator over one checkpoint file."""
        if self._weight_type == 'safetensors':
            return _get_safetensors_weights_iterator(path, self._prefix)
        return _get_pt_weights_iterator(path, self._prefix)

    @staticmethod
    def _skip_dummy_iterator(iterator, dummy_prefix: list):
        """Drop weights whose name starts with any prefix in ``dummy_prefix``."""
        prefixes = tuple(dummy_prefix)
        for name, param in iterator:
            if not name.startswith(prefixes):
                yield name, param

    @staticmethod
    def _rename_weights_iterator(iterator, model: torch.nn.Module):
        """Map names through ``model.rename_weight`` when the model defines it."""
        rename_func = getattr(model, 'rename_weight', lambda x: x)
        for name, param in iterator:
            yield rename_func(name), param

    def load_model_weights(
        self,
        model: torch.nn.Module,
        device: torch.device = None,
    ):
        """Stream every checkpoint file into ``model.load_weights``.

        Weights under modules flagged with ``_is_dummy_mod`` are skipped.
        Names are passed through ``model.rename_weight`` when it exists. The
        progress bar is shown on rank 0 only. If ``device`` is given, the
        model is moved there afterwards.
        """
        assert hasattr(model, 'load_weights')
        _, rank = get_world_rank()
        disable_tqdm = rank != 0

        dummy_prefix = [f'{name}.' for name, mod in model.named_modules() if getattr(mod, '_is_dummy_mod', False)]

        # deterministic file order unless random loading is requested
        paths = sorted(self._shard_paths)
        if _envs.random_load_weight:
            np.random.shuffle(paths)
        desc = f'Loading weights from {self._weight_type}'
        for path in tqdm(paths, desc=desc, disable=disable_tqdm):
            weights_iterator = self._get_weights_iterator(path)
            weights_iterator = self._rename_weights_iterator(weights_iterator, model)
            if len(dummy_prefix) > 0:
                weights_iterator = self._skip_dummy_iterator(weights_iterator, dummy_prefix)
            model.load_weights(weights_iterator)
        if device is not None:
            model.to(device)


@torch.inference_mode()
def load_model_weights(model: torch.nn.Module, checkpoint_path: str, prefix: str = None, device: torch.device = None):
    """Load ``checkpoint_path`` into ``model`` under inference mode.

    After loading, the model is switched to eval mode, and ``update_weights()``
    is called on every submodule that defines it.
    """
    ModelWeightLoader(checkpoint_path, prefix=prefix).load_model_weights(model, device=device)
    model.eval()
    for mod in model.modules():
        if hasattr(mod, 'update_weights'):
            mod.update_weights()


================================================
FILE: lmdeploy/serve/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .core import AsyncEngine, VLAsyncEngine
from .managers import Session, SessionManager
from .processors import MultimodalProcessor

__all__ = [
    'AsyncEngine',
    'VLAsyncEngine',
    'SessionManager',
    'Session',
    'MultimodalProcessor',
]


================================================
FILE: lmdeploy/serve/core/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .async_engine import AsyncEngine
from .vl_async_engine import VLAsyncEngine

__all__ = ['AsyncEngine', 'VLAsyncEngine']


================================================
FILE: lmdeploy/serve/core/async_engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import asyncio
import concurrent.futures
import dataclasses
import random
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any, Dict, List, Literal

import torch

from lmdeploy.archs import get_model_arch
from lmdeploy.logger import RequestLogger
from lmdeploy.messages import (EngineOutput, GenerationConfig, PytorchEngineConfig, Response, ResponseType,
                               SpeculativeConfig, TurbomindEngineConfig)
from lmdeploy.metrics.metrics_processor import metrics_processor
from lmdeploy.metrics.stats import IterationStats, RequestStats, SpeculativeDecodingStats
from lmdeploy.model import ChatTemplateConfig, get_chat_template
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
                                                   DistServeInitRequest)
from lmdeploy.serve.managers import Session, SessionManager
from lmdeploy.serve.processors import MultimodalProcessor
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger

from .exceptions import SafeRunException

logger = get_logger('lmdeploy')


@dataclasses.dataclass
class GenOut:
    """Pack all response information together.

    ``to_response`` converts it into the public ``Response`` type.
    """
    response: str
    history_token_len: int
    input_token_len: int
    generate_token_len: int
    finish_reason: Literal['stop', 'length', 'error'] | None = None
    token_ids: List[int] | None = None
    logprobs: List[Dict[int, float]] | None = None
    logits: Any = None
    last_hidden_state: Any = None
    cache_block_ids: List[int] | None = None  # for disaggregation
    routed_experts: Any = None  # for RL router replay

    def to_response(self, index: int = 0) -> Response:
        """Build a ``Response`` from this output.

        ``history_token_len`` and ``cache_block_ids`` are not carried over.
        A falsy ``token_ids`` (``None`` or empty) becomes an empty list.

        Args:
            index: The index position in the batch. Default to 0.
        """
        fields = dict(
            text=self.response,
            generate_token_len=self.generate_token_len,
            input_token_len=self.input_token_len,
            finish_reason=self.finish_reason,
            token_ids=self.token_ids or [],
            logprobs=self.logprobs,
            last_hidden_state=self.last_hidden_state,
            logits=self.logits,
            routed_experts=self.routed_experts,
            index=index,
        )
        return Response(**fields)


# class AsyncEngine(LogitsMixin):
class AsyncEngine:
    """Async inference engine. Maintaining a bunch of tm_model instances.

    Args:
        model_path (str): the path of a model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
            config instance. Default to none.
        chat_template_config (ChatTemplateConfig): chat template configuration.
            Default to None.
        max_log_len (int): Max number of prompt characters or prompt tokens
            being printed in log. Default: Unlimited
    """

    def __init__(self,
                 model_path: str,
                 model_name: str | None = None,
                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',
                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,
                 chat_template_config: ChatTemplateConfig | None = None,
                 max_log_len: int | None = None,
                 speculative_config: SpeculativeConfig | None = None,
                 **kwargs) -> None:
        logger.info(f'input backend={backend}, backend_config={backend_config}')
        logger.info(f'speculative_config={speculative_config}')
        backend_config = backend_config or (TurbomindEngineConfig()
                                            if backend == 'turbomind' else PytorchEngineConfig())
        self.model_name = model_name if model_name else model_path
        self.chat_template = get_chat_template(model_path, chat_template_config)
        self.tokenizer = Tokenizer(model_path)
        self.prompt_processor = MultimodalProcessor(self.tokenizer, self.chat_template)
        self.hf_gen_cfg = get_hf_gen_cfg(model_path)
        self.arch, self.hf_cfg = get_model_arch(model_path)
        self.session_len = (_get_and_verify_max_len(self.hf_cfg, None)
                            if backend_config.session_len is None else backend_config.session_len)
        backend_config.session_len = self.session_len
        if speculative_config is not None and backend == 'turbomind':
            logger.warning('speculative decoding is not supported by turbomind ')
        # build backend engine
        if backend == 'turbomind':
            self.engine = self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)
        elif backend == 'pytorch':
            self.engine = self._build_pytorch(model_path=model_path,
                                              backend_config=backend_config,
                                              speculative_config=speculative_config,
                                              **kwargs)
        else:
            raise ValueError(f'unsupported backend {backend}')
        self.backend_config = self.engine.engine_config
        self.is_sleeping = backend_config.empty_init
        self.sleeping_tags: set[str] = set() if not backend_config.empty_init else {'weights', 'kv_cache'}
        logger.info(f'updated backend_config={self.backend_config}')

        # parameters for member functions
        self.stop_words = _stop_words(self.chat_template.stop_words, self.tokenizer)
        if self.stop_words is not None:
            self.stop_words = self.stop_words[0][0].tolist()
        self.backend = backend
        self.request_logger = RequestLogger(max_log_len)

        self.num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \
            else speculative_config.num_speculative_tokens

        self.session_mgr = SessionManager()
        self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)

        # build stat loggers
        self._build_stat_loggers()
        self.epoch = 0

    def close(self):
        self.session_mgr.clear()
        self.engine.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _build_turbomind(self, model_path: str, backend_config: TurbomindEngineConfig | None = None, **kwargs):
        """Inner build method for turbomind backend."""
        from lmdeploy import turbomind as tm
        return tm.TurboMind.from_pretrained(model_path, engine_config=backend_config, **kwargs)

    def _build_pytorch(self,
                       model_path: str,
                       backend_config: PytorchEngineConfig | None = None,
                       speculative_config: SpeculativeConfig | None = None,
                       **kwargs):
        """Inner build method for pytorch backend."""
        from lmdeploy.pytorch.engine import Engine
        return Engine.from_pretrained(model_path, engine_config=backend_config, speculative_config=speculative_config)

    def _build_stat_loggers(self):
        self.stat_loggers = []

        if getattr(self.backend_config, 'enable_metrics', False):
            from lmdeploy.metrics.loggers import LoggingStatLogger, PrometheusStatLogger

            # currently, metrics in TM engine doesn't support dp
            dp_rank = self.backend_config.dp_rank if self.backend == 'pytorch' else 0

            logger.info(f'enable metrics, with dp: {self.backend_config.dp} dp_rank: {dp_rank}')
            self.stat_loggers = [
                LoggingStatLogger(dp_rank=dp_rank),
                PrometheusStatLogger(model_name=self.model_name, max_model_len=self.session_len, dp_rank=dp_rank)
            ]

            # set stats loggers of metrics processor
            metrics_processor.stat_loggers = self.stat_loggers

    def get_schedule_metrics(self):
        return self.engine.get_schedule_metrics()

    async def do_log_stats(self):
        """Loop through CLI logger and Prometheus logger and output the
        metrics."""
        for stat_logger in self.stat_loggers:
            stat_logger.log()

    async def stop_all_session(self):
        """Stop all running sessions."""
        logger.info('stop all sessions')
        self.epoch += 1
        await self.session_mgr.async_abort_all()

    def sleep(self, level: int = 1):
        """Sleep the model.

        Args:
            level (int): The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. Level 2 sleep will
                discard both the model weights and the kv cache.
        """
        self.engine.sleep(level)
        self.sleeping_tags = {'weights', 'kv_cache'}
        self.is_sleeping = True

    def wakeup(self, tags: List[str] | None = None):
        """Wake up the model.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        tags = tags or list(self.sleeping_tags)
        if any(tag not in self.sleeping_tags for tag in tags):
            logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}')
            return
        self.engine.wakeup(tags)
        # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances
        if self.backend == 'turbomind' and 'kv_cache' in tags:
            self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)
        self.sleeping_tags = self.sleeping_tags - set(tags)
        self.is_sleeping = bool(self.sleeping_tags)

    def _determine_gen_config(self, session, input_ids, gen_config: GenerationConfig | None = None) -> GenerationConfig:
        """Determine the generation configuration."""
        gen_config = deepcopy(gen_config) or GenerationConfig()
        gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
        gen_config.stop_token_ids = gen_config.stop_token_ids or self.stop_words
        gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)
        if not gen_config.do_sample:
            # greedy decode
            gen_config.top_k = 1
            # avoid unnecessary process
            gen_config.temperature = 1.0
            gen_config.repetition_penalty = 1.0
        # set random if it is not set and sequence_start is True
        elif gen_config.random_seed is None and session.step == 0:
            gen_config.random_seed = random.getrandbits(64)
        if gen_config.n > 1:
            logger.warning(f'n({gen_config.n}) > 1 hasn\'t been supported yet. Fallback to 1')
            gen_config.n = 1
        if gen_config.max_new_tokens is None:
            gen_config.max_new_tokens = max(0, self.session_len - session.step - len(input_ids))
        return gen_config

    @asynccontextmanager
    async def safe_run(self, handle, session, **kwargs):
        generator = handle.async_stream_infer(session.session_id, **kwargs)
        try:
            metrics_processor.increase_api_routed_requests()
            yield generator
        except (Exception, asyncio.CancelledError, GeneratorExit) as e:  # noqa
            logger.exception(f'[safe_run] session {session.session_id} exception caught: {e}')
            await session.async_abort()
            if self.backend == 'pytorch':
                await handle.async_end(session.session_id)
            raise SafeRunException(f'Safe run exception for session {session.session_id}') from e
        finally:
            await generator.aclose()
            metrics_processor.decrease_api_routed_requests()

    async def generate(
            self,
            messages,
            session_id: int | Session,
            gen_config: GenerationConfig | None = None,
            tools: List[object] | None = None,
            reasoning_effort: Literal['low', 'medium', 'high'] | None = None,
            stream_response: bool = True,
            sequence_start: bool = True,
            sequence_end: bool = True,  # no interactive mode by default
            step: int = 0,
            do_preprocess: bool = True,
            adapter_name: str | None = None,
            rewind_stop_tokens: bool = False,
            input_ids: List | None = None,
            enable_thinking: bool | None = None,
            chat_template_kwargs: Dict | None = None,
            media_io_kwargs: Dict[str, Any] | None = None,
            mm_processor_kwargs: Dict[str, Any] | None = None,
            **kwargs):
        """Generate responses.

        Args:
            messages (str | List): chat history or prompt
            session_id (int | Session): the session id or instance of Session
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            stream_response (bool): whether return responses streamingly
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        epoch = self.epoch
        if (messages is not None) ^ (input_ids is None):
            raise ValueError('You must specify exactly one of messages or input_ids')
        if isinstance(session_id, Session):
            session = session_id
        elif isinstance(session_id, int):
            session = self.session_mgr.get(session_id, step=step)
        else:
            raise ValueError(f'Invalid session_id: {session_id}. It should be an instance of Session or an integer.')
        session_id = session.session_id
        chat_template_kwargs = chat_template_kwargs or {}
        if enable_thinking is not None:
            logger.warning('enable_thinking is deprecated, use chat_template_kwargs["enable_thinking"] instead')
            if chat_template_kwargs.get('enable_thinking') is None:
                chat_template_kwargs['enable_thinking'] = enable_thinking
            else:
                logger.warning('chat_template_kwargs["enable_thinking"] is already set, '
                               'the value will not be overwritten by enable_thinking')
        if messages:
            prompt = messages
            self.request_logger.log_prompt(session, prompt=prompt)
            prompt_input = await self.prompt_processor.get_prompt_input(prompt=prompt,
                                                                        do_preprocess=do_preprocess,
                                                                        sequence_start=sequence_start,
                                                                        adapter_name=adapter_name,
                                                                        tools=tools,
                                                                        reasoning_effort=reasoning_effort,
                                                                        chat_template_kwargs=chat_template_kwargs,
                                                                        media_io_kwargs=media_io_kwargs,
                                                                        mm_processor_kwargs=mm_processor_kwargs,
                                                                        **kwargs)
            prompt = prompt_input['prompt']
            input_ids = prompt_input['input_ids']
            self.request_logger.log_inputs(session,
                                           prompt=prompt,
                                           prompt_token_ids=input_ids,
                                           gen_config=gen_config,
                                           adapter_name=adapter_name)
        else:
            # TODO(lvhan) VLM doesn't support input_ids as an argument.
            # Figure out a graceful way to handle the invalid input
            prompt_input = dict(input_ids=input_ids)

        gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config)

        if gen_config.max_new_tokens == 0:
            logger.info(f'run out of tokens. session={session_id}.')
            yield GenOut(response='',
                         history_token_len=session.step,
                         input_token_len=len(input_ids),
                         generate_token_len=0,
                         finish_reason='length',
                         token_ids=[])
            if sequence_end is True and sequence_start is False:
                await session.async_close()
            return

        if self.backend_config.enable_prefix_caching and (gen_config.output_last_hidden_state == 'all'
                                                          or gen_config.output_logits == 'all'):
            errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
                      'when prefix caching is ON')
            yield GenOut(response=errmsg,
                         history_token_len=session.step,
                         input_token_len=len(input_ids),
                         generate_token_len=0,
                         finish_reason='error',
                         token_ids=[])
            return
        logger.info(f'session={session_id}, '
                    f'history_tokens={session.step}, '
                    f'input_tokens={len(input_ids)}, '
                    f'max_new_tokens={gen_config.max_new_tokens}, '
                    f'seq_start={sequence_start}, seq_end={sequence_end}, '
                    f'step={step}, prep={do_preprocess}')

        def is_error(status):
            return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL]

        stop_ids = []
        if not gen_config.ignore_eos:
            stop_ids = gen_config.stop_token_ids or []

        metrics_processor.increase_total_requests()
        async with session.request_handle() as handle:
            if epoch != self.epoch:
                logger.info(f'[generate] session {session_id} got aborted before starting inference')
                # TODO(lvhan): metrics_processor.increase_failed_requests('abort')
                metrics_processor.increase_completed_requests()
                yield GenOut(response='',
                             history_token_len=0,
                             input_token_len=len(input_ids),
                             generate_token_len=0,
                             finish_reason='abort',
                             token_ids=[])
                return
            token_ids = input_ids.copy()
            history_len = session.step
            input_len = len(input_ids)
            output_len, gen_len = 0, 0
            state = DetokenizeState(input_len)
            response = ''
            finish_reason = None
            async with self.safe_run(handle,
                                     session=session,
                                     **prompt_input,
                                     gen_config=gen_config,
                                     adapter_name=adapter_name,
                                     stream_output=stream_response,
                                     sequence_start=sequence_start,
                                     sequence_end=sequence_end,
                                     step=history_len) as gen:
                logger.debug(f'[generate] session {session_id} started')
                hit_stop_token = 0
                req_stats = RequestStats(prompt_tokens=input_len)  # per-request stats

                # We use this as default outputs in case the async_stream_infer of the Engine yields empty generator.
                outputs = EngineOutput(ResponseType.INTERNAL_ENGINE_ERROR, [])

                async for outputs in gen:
                    iteration_stats = IterationStats()  # per-iteration stats
                    specdecode_stats = SpeculativeDecodingStats(
                        self.num_spec_token) if self.num_spec_token > 0 else None
                    metrics_processor.queue_update((outputs, req_stats, iteration_stats, specdecode_stats))
                    # decode res
                    if is_error(outputs.status):
                        break

                    output_len = len(outputs.token_ids)
                    if hit_stop_token or output_len == 0:
                        continue

                    # This assumes the engine will stop when stop token is hit
                    if output_len and outputs.token_ids[-1] in stop_ids:
                        hit_stop_token = 1

                    token_ids += outputs.token_ids[:output_len - hit_stop_token]
                    gen_len = len(token_ids) - input_len

                    ids_offset = state.ids_offset
                    response, state = self.tokenizer.detokenize_incrementally(
                        token_ids,
                        state,
                        skip_special_tokens=gen_config.skip_special_tokens,
                        spaces_between_special_tokens=gen_config.spaces_between_special_tokens)
                    res = token_ids[ids_offset:]

                    out = GenOut(response,
                                 history_len,
                                 input_len,
                                 gen_len,
                                 finish_reason,
                                 token_ids=res,
                                 routed_experts=outputs.routed_experts,
                                 cache_block_ids=outputs.cache_block_ids)
                    if outputs.logprobs is not None:
                        out.logprobs = (outputs.logprobs[:-hit_stop_token] if hit_stop_token else outputs.logprobs)
                    if outputs.last_hidden_state is not None:
                        out.last_hidden_state = (outputs.last_hidden_state[:-hit_stop_token]
                                                 if hit_stop_token else outputs.last_hidden_state)
                    if outputs.logits is not None:
                        out.logits = (outputs.logits[:-hit_stop_token] if hit_stop_token else outputs.logits)
                    yield out
                # end of generator loop
                metrics_processor.increase_completed_requests()

                if not is_error(outputs.status):
                    if outputs.status == ResponseType.CANCEL:
                        finish_reason = 'abort'
                    else:
                        finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'

                    # utf-8 char at the end means it's a potential unfinished byte sequence
                    if not response.endswith('�'):
                        # avoid returning the last response twice
                        response = ''
                    token_ids, logits, last_hidden_state, logprobs = [], None, None, None
                    if gen_config.include_stop_str_in_output and finish_reason == 'stop':
                        # return the eos token id (MUST be in a list), eos string, eos token's logits and so on
                        token_ids = outputs.token_ids[-1:]
                        response = self.tokenizer.decode(token_ids, skip_special_tokens=False)
                        logits = outputs.logits[-1:] if outputs.logits is not None else None
                        last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state else None
                        logprobs = outputs.logprobs[-1:] if outputs.logprobs else None
                        gen_len += 1

                    # router replay
                    routed_experts = outputs.routed_experts
                    if routed_experts is not None and not isinstance(routed_experts, str) and (
                            not gen_config.include_stop_str_in_output) and finish_reason == 'stop':
                        routed_experts = routed_experts[:-1]

                    logger.info(f'session {session_id} finished, reason '
                                f'"{finish_reason}", input_tokens '
                                f'{len(input_ids)}, output_tokens {gen_len}')
                    yield GenOut(response,
                                 session.step,
                                 len(input_ids),
                                 gen_len,
                                 finish_reason,
                                 token_ids=token_ids,
                                 logprobs=logprobs,
                                 logits=logits,
                                 last_hidden_state=last_hidden_state,
                                 routed_experts=routed_experts,
                                 cache_block_ids=outputs.cache_block_ids)
                    # Note: We remove the session step update here. Let the caller(e.g., pipeline.chat) take care of it.
                else:
                    logger.error(f'session {session_id} finished, {outputs.status}, '
                                 'reason "error"')
                    yield GenOut(response=f'internal error happened, status code {outputs.status}',
                                 history_token_len=session.step,
                                 input_token_len=len(input_ids),
                                 generate_token_len=0,
                                 finish_reason='error',
                                 token_ids=[])
            # update step
            if sequence_end:
                if self.backend == 'pytorch':
                    # manually end pytorch session
                    # note: Using session.async_abort() here results in deadlock
                    # because it waits for session's _active event to be set, but the event won't be set
                    # until the session is finished, i.e., session.request_handle() context exits.
                    await handle.async_end(session.session_id)
                self.session_mgr.remove(session)
        # if sequence_end:
        #     if self.backend == 'pytorch':
        #         # manually end pytorch session. session cannot be ended until session.request_handle()
        #         # context exits
        #         await session.async_close()
        #     self.session_mgr.remove(session)

    def start_loop(self, loop, use_async_api=False):
        """Start engine loop.

        When using pytorch backend with dp > 1, all dp_rank should receive at least one request before it can start
        processing (warmup). Since pytorch engine will bound to event loop, the pipeline can only choose either the
        synchronous apis(__call__, stream_infer, etc.) or the asynchronous api (generate) during its lifetime.

        The purpose of this function is to allow users to choose whether to use the synchronous interface or the
        asynchronous interface for the pipeline.
        """
        self.session_mgr.attach_event_loop(loop)
        if hasattr(self.engine, 'start_loop'):
            if use_async_api:
                return self.engine.start_loop()
            else:
                fut = concurrent.futures.Future()

                def _start_loop(fut):
                    res = self.engine.start_loop()
                    fut.set_result(res)

                loop.call_soon_threadsafe(_start_loop, fut)
                return fut.result()
        else:
            return True

    """ DistServe Async Engine API Begin """

    def free_cache(self, session_id: int):
        if self.engine.end_session(session_id):
            logger.debug(f'successfully free session {session_id}')
        else:
            logger.warning(f'Invalid Free session {session_id}.')

    def p2p_initialize(self, init_request: DistServeInitRequest):
        return self.engine.p2p_initialize(init_request)

    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
        return self.engine.p2p_connect(conn_request)

    def p2p_drop_connect(self, drop_conn_request: List[DistServeDropConnectionRequest]):
        return self.engine.p2p_drop_connect(drop_conn_request)

    """ DistServe Async Engine API End """

    async def async_get_reward_score(self, input_ids: List) -> List[float]:
        """Async version of get_reward_score."""
        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
        if self.arch not in supported_reward_models:
            raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}')
        assert isinstance(input_ids, List)
        assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
        # Make input_ids a list of token_id list
        input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids

        logits = await self.async_get_logits(input_ids=input_ids)

        logits = [x.squeeze() for x in logits]
        scores = [x[-1].cpu().item() for x in logits]
        return scores

    async def async_get_logits(self,
                               input_ids,
                               sessions: List['Session'] | None = None,
                               sequence_start: bool = True,
                               sequence_end: bool = True) -> List[torch.Tensor]:
        assert input_ids and all(isinstance(_, List) for _ in input_ids)
        assert sessions is None or (len(sessions) == len(input_ids))

        logits = [None] * len(input_ids)

        async def _proc(session, i):
            async with session.request_handle() as handle:
                input_len = len(input_ids[i])
                # TODO(lvhan): Fix the ugly code later on
                max_new_tokens = 1 if self.backend == 'turbomind' else 0
                # The reason to set `top_k=1` is that pt engine crashes at top_k sampling stage
                # when perform inference on a reward model.
                gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1)
                async with self.safe_run(handle,
                                         session=session,
                                         input_ids=input_ids[i],
                                         gen_config=gen_config,
                                         stream_output=False,
                                         sequence_start=sequence_start,
                                         sequence_end=sequence_end,
                                         step=session.step) as gen:
                    async for outputs in gen:
                        pass
                    logits[i] = outputs.logits[:input_len, :]

        create_sessions = False
        if sessions is None:
            create_sessions = True
            sessions = [self.session_mgr.get() for _ in range(len(input_ids))]
        tasks = [_proc(session, i) for i, session in enumerate(sessions)]
        await asyncio.gather(*tasks)
        if sequence_end and self.backend == 'pytorch':
            for session in sessions:
                await session.async_close()
        if sequence_end and create_sessions:
            for session in sessions:
                self.session_mgr.remove(session)
        return logits


================================================
FILE: lmdeploy/serve/core/exceptions.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
"""Exceptions for the serve module."""


class SafeRunException(Exception):
    """Exception raised by safe_run to avoid upper layer handling the original
    exception again.

    This exception wraps the original exception that occurred during safe_run execution.
    """


================================================
FILE: lmdeploy/serve/core/vl_async_engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal

from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig
from lmdeploy.utils import get_logger

from .async_engine import AsyncEngine

logger = get_logger('lmdeploy')


class VLAsyncEngine(AsyncEngine):
    """Visual-language async inference engine.

    Extends ``AsyncEngine`` with an image encoder (``self.vl_encoder``). It also
    replaces the prompt processor with a ``MultimodalProcessor`` so prompts can
    carry visual inputs alongside text.

    Args:
        model_path (str): path or id of the model.
        backend (str): 'turbomind' or 'pytorch'.
        backend_config: engine config for the chosen backend, or None.
            Prefix caching is force-disabled on it, because VL models do not
            support it yet.
        vision_config (VisionConfig | None): config for the image encoder.
        **kwargs: forwarded to ``AsyncEngine.__init__``.

    Raises:
        RuntimeError: if no chat template could be resolved (``model_name`` is
            'base').
    """

    def __init__(self,
                 model_path: str,
                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',
                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,
                 vision_config: VisionConfig | None = None,
                 **kwargs) -> None:
        from lmdeploy.serve.processors import MultimodalProcessor
        from lmdeploy.utils import try_import_deeplink
        from lmdeploy.vl.engine import ImageEncoder

        # `backend_config` defaults to None. Only probe the deeplink device when a
        # config (and therefore a `device_type`) was actually supplied. Previously a
        # None config crashed here with AttributeError.
        # NOTE(review): assumes the default device needs no deeplink import — confirm.
        if backend == 'pytorch' and backend_config is not None:
            try_import_deeplink(backend_config.device_type)
        if backend_config and backend_config.enable_prefix_caching:
            backend_config.enable_prefix_caching = False
            logger.warning('Prefix caching is disabled since LMDeploy hasn\'t support in on VL models yet')
        # The encoder is built before the base engine is initialised.
        self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config)
        super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs)
        # Update prompt_processor to support multimodal processing
        self.prompt_processor = MultimodalProcessor(self.tokenizer,
                                                    self.chat_template,
                                                    vl_encoder=self.vl_encoder,
                                                    backend=backend)
        if self.model_name == 'base':
            raise RuntimeError(
                'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template'  # noqa: E501
            )

    def close(self):
        """Release the image encoder, then the base engine.

        The ``hasattr`` guard makes repeated calls harmless. After the first call
        ``vl_encoder`` no longer exists, so ``super().close()`` is not run again.
        The same applies when ``__init__`` failed before the encoder was
        assigned.
        """
        if hasattr(self, 'vl_encoder'):
            del self.vl_encoder
            super().close()


================================================
FILE: lmdeploy/serve/managers/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .session_manager import Session, SessionManager

# Public API of the managers package.
__all__ = ['Session', 'SessionManager']


================================================
FILE: lmdeploy/serve/managers/session_manager.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import annotations

import asyncio
import itertools
import weakref
from contextlib import asynccontextmanager
from typing import Any, List, Tuple

from lmdeploy.messages import GenerationConfig, Response
from lmdeploy.serve.core.exceptions import SafeRunException
from lmdeploy.utils import get_logger

# Shared 'lmdeploy' logger used by this module.
logger = get_logger('lmdeploy')


class Session:
    """A session tracked by a ``SessionManager``.

    It holds per-session state: prompt, last response, history, generation
    config and step. While a request is running, it also holds the inference
    handle borrowed from the manager's request-handle pool.
    """

    def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs):
        self.session_id = session_id
        self.prompt: Any = None
        self.response: Response | None = None
        self.history: List[Tuple[Any, str]] = []
        self.gen_config: GenerationConfig | None = None
        self.step: int = 0
        # Created when a handle is acquired in `request_handle` and set once the handle
        # is released; `async_close` waits on it for an in-flight request to finish.
        self._active: asyncio.Event | None = None
        self._handle = None  # inference instance
        # Only a weak reference is kept, so the session does not keep its manager alive.
        # Call `self._session_mgr()` to get the manager (None once it was collected).
        # `reset()` replaces the reference itself with None.
        self._session_mgr: weakref.ReferenceType[SessionManager] | None = weakref.ref(session_mgr)
        self.update(**kwargs)

    def update(self, **kwargs):
        """Update the session.

        Only ``prompt``, ``gen_config`` and ``step`` are read from ``kwargs``. Each
        one that is absent keeps its current value; other keys are ignored.
        """
        self.prompt = kwargs.get('prompt', self.prompt)
        self.gen_config = kwargs.get('gen_config', self.gen_config)
        self.step = kwargs.get('step', self.step)

    def __repr__(self) -> str:
        """Return a string representation of the Session object."""
        return (f'Session(session_id={self.session_id}, '
                f'step={self.step}, history_len={len(self.history)}, '
                f'has_response={self.response is not None}, '
                f'has_gen_config={self.gen_config is not None})')

    def __str__(self) -> str:
        """Return a human-readable string representation of the Session."""
        res = f'Session(id={self.session_id}, step={self.step})'
        if self.history:
            res += '\nHistory:\n'
            for user, assistant in self.history:
                # list-valued user turns are rendered with their str() form
                if isinstance(user, list):
                    user = str(user)
                res += f'USER: \n{user}\nASSISTANT: \n{assistant}\n'
        return res

    def reset(self):
        """Reset the session to initial state.

        This method resets all session data (prompt, response, history, etc.) but keeps the session_id.

        NOTE(review): this also sets ``_session_mgr`` to None. Afterwards
        ``abort``/``close`` become no-ops, and ``request_handle`` would fail when
        it calls ``self._session_mgr()``. A reset session therefore cannot run
        new requests unless it is re-attached.
        """
        self.prompt = None
        self.response = None
        self.history = []
        self.gen_config = None
        self.step = 0
        self._active = None
        self._handle = None
        self._session_mgr = None
        logger.debug(f'Session {self.session_id} has been reset.')

    @asynccontextmanager
    async def request_handle(self):
        """Borrow an inference handle from the manager's pool for one request.

        Yields:
            The handle taken from ``request_handle_pool``.

        Raises:
            RuntimeError: if this session already holds a handle.

        Exception handling inside the ``async with`` body:

        * ``SafeRunException`` is suppressed (it was already handled upstream).
        * ``CancelledError``/``GeneratorExit`` are logged, the request is
          cancelled on the handle via ``async_cancel``, and the exception is
          suppressed (not re-raised).
        * Any other ``Exception`` is logged and re-raised.

        In every case the handle is returned to the pool and ``_active`` is set.
        """
        if self._handle is not None:
            raise RuntimeError(f'Session {self.session_id} already has an inference instance.')
        logger.debug(f'[request_handle] session {self.session_id} acquiring an instance')

        hnd_pool = self._session_mgr().request_handle_pool
        # May wait until another request returns a handle to the pool.
        self._handle = await hnd_pool.get()
        self._active = asyncio.Event()
        logger.debug(f'[request_handle] session {self.session_id} acquired an instance')
        try:
            yield self._handle
        except SafeRunException:
            pass
        except (asyncio.CancelledError, GeneratorExit) as e:
            logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}')
            await self._handle.async_cancel(self.session_id)
        except Exception as e:
            logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}')
            raise
        finally:
            logger.debug(f'[request_handle] session {self.session_id} releasing the instance')
            # Return inference instance if it was acquired
            if self._handle is not None:
                hnd_pool.put(self._handle)
                self._handle = None
            # MUST set the signal after releasing the instance to avoid race condition:
            # `async_close` waits on `_active` and then acquires a handle itself.
            self._active.set()

    async def async_abort(self):
        """Abort the session.

        Cancels the running request only if the session currently holds a handle;
        otherwise it only logs.
        """
        logger.info(f'[session] Aborting session {self.session_id}')
        if self._handle is not None:
            await self._handle.async_cancel(self.session_id)

    async def async_close(self):
        """End the session.

        Returns immediately when there is nothing to end (no handle held and
        ``step == 0``). Otherwise it waits for any in-flight request to release
        its handle, then borrows a handle and calls ``async_end``. Errors from
        ``async_end`` are logged, not raised. Finally the session is ``reset()``.
        """
        logger.info(f'[session] Ending session {self.session_id}')
        if self._handle is None and self.step == 0:
            return
        if self._handle is not None:
            await self._active.wait()
        async with self.request_handle() as handle:
            try:
                await handle.async_end(self.session_id)
            except (Exception, asyncio.CancelledError, GeneratorExit) as e:
                logger.exception(f'[async_close] exception caught: {e}')
        self.reset()

    def abort(self):
        """Abort the session in sync mode.

        Runs ``async_abort`` on the manager's event loop and blocks until it finishes.
        A session that was already reset (no manager) is left alone.
        """
        if self._session_mgr is not None:
            self._run(self.async_abort()).result()

    def close(self):
        """End the session in sync mode.

        Runs ``async_close`` on the manager's event loop and blocks until it finishes.
        A session that was already reset (no manager) is left alone.
        """
        if self._session_mgr is not None:
            self._run(self.async_close()).result()

    def _run(self, coro):
        """Submit ``coro`` to the manager's event loop from another thread.

        Returns the ``concurrent.futures.Future`` from
        ``asyncio.run_coroutine_threadsafe``. Blocking on its ``result()`` from
        the loop's own thread would never complete.
        """
        assert self._session_mgr is not None, 'Session manager is not initialized'
        return asyncio.run_coroutine_threadsafe(coro, self._session_mgr().loop)


class RequestHandlePool:
    """Fixed-size pool of reusable request handles.

    The handles are created up front from ``engine.create_instance()``. A
    request borrows one with ``get()`` before inference and gives it back with
    ``put()`` afterwards, so at most ``size`` requests run concurrently.

    The ``asyncio.Queue`` that backs the pool is only built on the first
    ``get()`` call, from inside a coroutine. ``clear()`` drops every handle and
    the queue.

    Args:
        engine (AsyncEngine): engine whose ``create_instance()`` builds a handle.
        size (int): number of handles, typically the engine's max_batch_size.
    """

    def __init__(self, engine, size: int):
        self.size = size
        self.handles = [engine.create_instance() for _ in range(size)]
        # Built lazily by `_ensure_pool`, from within a running coroutine.
        self.pool: asyncio.Queue = None

    def _ensure_pool(self) -> asyncio.Queue:
        """Return the queue, creating it pre-filled with every handle on first use."""
        if self.pool is None:
            queue = asyncio.Queue()
            for handle in self.handles:
                queue.put_nowait(handle)
            self.pool = queue
        return self.pool

    async def get(self):
        """Borrow a handle, waiting until one is available."""
        return await self._ensure_pool().get()

    def put(self, handle):
        """Give a handle back; ignored for None or before the queue exists."""
        if handle is None or self.pool is None:
            return
        self.pool.put_nowait(handle)

    def clear(self):
        """Forget every handle and the queue."""
        self.handles = []
        self.pool = None


class SessionManager:
    """Registry of live ``Session`` objects keyed by session id.

    It also holds the shared request-handle pool and the event loop that the
    sessions' synchronous helpers dispatch onto.
    """

    def __init__(self):
        """Initialize the session manager."""

        self.sessions = {}
        self.session_id_generator = itertools.count(1)
        self.request_handle_pool = None
        self.loop = None

    def _next_free_session_id(self) -> int:
        """Return the next auto-generated id that no live session is using."""
        # Explicitly requested ids share the namespace with generated ones. Skip any
        # that are taken so an auto-assigned session never aliases an existing one.
        session_id = next(self.session_id_generator)
        while session_id in self.sessions:
            session_id = next(self.session_id_generator)
        return session_id

    def get(self, session_id: int | None = None, **kwargs) -> Session:
        """Return the session with ``session_id``, creating it if necessary.

        Args:
            session_id (int | None): id of the session. None (or 0) allocates
                a fresh id that is not used by any live session.
            **kwargs: passed to ``Session.update`` for an existing session, or
                to the ``Session`` constructor for a new one.

        Returns:
            Session: the existing or newly created session.
        """
        if not session_id:
            session_id = self._next_free_session_id()
        if session_id in self.sessions:
            logger.debug(f'[SessionManager] session {session_id} already exists. Updating...')
            session = self.sessions[session_id]
            session.update(**kwargs)
            return session
        logger.info(f'[SessionManager] session {session_id} not found. Creating...')
        session = Session(session_id, self, **kwargs)
        self.sessions[session_id] = session
        return session

    async def async_abort_all(self):
        """Abort all sessions concurrently, then forget them."""
        tasks = []
        for session in list(self.sessions.values()):
            tasks.append(session.async_abort())
        # return_exceptions=True: one failing abort must not stop the others.
        await asyncio.gather(*tasks, return_exceptions=True)
        # "abort all" is designed for async RL. The aborted sessions will be no longer used,
        # so we clear the sessions here.
        self.sessions.clear()

    def has(self, session_id):
        """Return True if a session with ``session_id`` is registered."""
        return session_id in self.sessions

    def remove(self, session: Session):
        """Unregister ``session``; a no-op if it is not registered."""
        self.sessions.pop(session.session_id, None)

    def clear(self):
        """Forget all sessions and restart id generation from 1."""
        self.sessions.clear()
        # reset the session id generator
        self.session_id_generator = itertools.count(1)

    def attach_event_loop(self, loop):
        """Record the event loop that the sessions' sync helpers dispatch onto."""
        self.loop = loop

    def build_request_handle_pool(self, engine, size):
        """Build the request handle's pool."""
        self.request_handle_pool = RequestHandlePool(engine, size)


================================================
FILE: lmdeploy/serve/openai/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/serve/openai/api_client.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Any, Dict, List, Optional, Union

import requests

from lmdeploy.utils import get_logger


def get_model_list(api_url: str, headers: dict = None):
    """Fetch the ids of the models served behind ``api_url``.

    Args:
        api_url (str): full URL of the ``/v1/models`` endpoint.
        headers (dict | None): HTTP headers sent with the request.

    Returns:
        list[str] | None: the model ids, or None when the request failed.
    """
    response = requests.get(api_url, headers=headers)
    logger = get_logger('lmdeploy')
    if not response.ok:
        logger.error(f'Failed to get the model list: {api_url}'
                     f' returns {response.status_code}')
        return None
    if not hasattr(response, 'text'):
        logger.warning('Failed to get the model list.')
        return None
    payload = response.json()
    cards = payload.pop('data', [])
    return [card['id'] for card in cards]


def json_loads(content):
    """Parse ``content`` as JSON, falling back to ``''`` on failure.

    Args:
        content (str | bytes): raw JSON text.

    Returns:
        The decoded object. If ``content`` is not valid JSON or is not a
        str/bytes object, a warning is logged and an empty string is returned.
    """
    try:
        return json.loads(content)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError (bad bytes);
        # TypeError covers non str/bytes input; RecursionError covers absurdly deep
        # nesting. Unlike a bare `except:`, KeyboardInterrupt/SystemExit propagate.
        logger = get_logger('lmdeploy')
        logger.warning(f'weird json content {content}')
        return ''


class APIClient:
    """Client for the OpenAI-compatible HTTP endpoints of an lmdeploy api_server.

    Args:
        api_server_url (str): base address of the api_server, e.g.
            'http://<ip>:<port>'.
        api_key (str | None): api key. Default to None, which means no
            api key will be used.
    """

    def __init__(self, api_server_url: str, api_key: Optional[str] = None, **kwargs):
        self.api_server_url = api_server_url
        self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
        self.completions_v1_url = f'{api_server_url}/v1/completions'
        self.models_v1_url = f'{api_server_url}/v1/models'
        self.encode_v1_url = f'{api_server_url}/v1/encode'
        self._available_models = None
        self.api_key = api_key
        self.headers = {'content-type': 'application/json'}
        if api_key is not None:
            self.headers['Authorization'] = f'Bearer {api_key}'

    @property
    def available_models(self):
        """Model ids served by the api_server.

        A successful (non-None) result is cached; a failed lookup is retried on
        the next access.
        """
        if self._available_models is not None:
            return self._available_models
        self._available_models = get_model_list(self.models_v1_url, headers=self.headers)
        return self._available_models

    @staticmethod
    def _build_payload(params: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a request method's ``locals()`` snapshot into a JSON request body.

        ``self`` and dunder names are dropped. The ``**kwargs`` catch-all is
        merged into the top level, so extra request fields reach the server as
        real fields instead of being nested under a ``"kwargs"`` key.
        """
        payload = {k: v for k, v in params.items() if not k.startswith('__') and k not in ('self', 'kwargs')}
        payload.update(params.get('kwargs') or {})
        return payload

    @staticmethod
    def _iter_json_chunks(response, stream: bool):
        """Yield the JSON objects contained in a (possibly SSE-streamed) response.

        In stream mode the ``data: `` prefix is stripped and the terminating
        ``data: [DONE]`` line is skipped.
        """
        for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
            if not chunk:
                continue
            decoded = chunk.decode('utf-8')
            if stream:
                if decoded == 'data: [DONE]':
                    continue
                if decoded[:6] == 'data: ':
                    decoded = decoded[6:]
            yield json_loads(decoded)

    def encode(self,
               input: Union[str, List[str]],
               do_preprocess: Optional[bool] = False,
               add_bos: Optional[bool] = True):
        """Encode prompts.

        Args:
            input: the prompt to be encoded. In str or List[str] format.
            do_preprocess: whether do preprocess or not. Default to False.
            add_bos: True when it is the beginning of a conversation. False
                when it is not. Default to True.
        Return: (input_ids, length), or (None, None) when the server replies
            with an error or with a body that is not a valid encode result.
        """
        response = requests.post(self.encode_v1_url,
                                 headers=self.headers,
                                 json=dict(input=input, do_preprocess=do_preprocess, add_bos=add_bos),
                                 stream=False)
        output = json_loads(response.text)
        # An error reply (or unparsable body) has no 'input_ids'. Fall back to
        # (None, None) instead of failing with TypeError/KeyError.
        if not isinstance(output, dict) or 'input_ids' not in output:
            return None, None
        return output['input_ids'], output.get('length')

    def chat_completions_v1(
        self,
        model: str,
        messages: Union[str, List[Dict[str, str]]],
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 1.0,
        logprobs: Optional[bool] = False,
        top_logprobs: Optional[int] = 0,
        n: Optional[int] = 1,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = False,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        user: Optional[str] = None,
        repetition_penalty: Optional[float] = 1.0,
        ignore_eos: Optional[bool] = False,
        skip_special_tokens: Optional[bool] = True,
        spaces_between_special_tokens: Optional[bool] = True,
        top_k: int = 40,
        min_new_tokens: Optional[int] = None,
        min_p: float = 0.0,
        logit_bias: Optional[Dict[str, float]] = None,
        stream_options: Optional[Dict] = None,
        **kwargs,
    ):
        """Chat completion v1.

        Args:
            model: model name. Available from self.available_models.
            messages: string prompt or chat history in OpenAI format. Chat
                history example: `[{"role": "user", "content": "hi"}]`.
            temperature (float): to modulate the next token probability
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
            max_completion_tokens (int | None): output token nums. Default to None.
            max_tokens (int | None): output token nums. Default to None.
                Deprecated: Use max_completion_tokens instead.
            stop (str | List[str] | None): To stop generating further
              tokens. Only accept stop words that's encoded to one token index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                around special tokens. The behavior of Fast tokenizers is to have
                this to False. This is setup to True in slow tokenizers.
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            min_new_tokens (int): To generate at least numbers of tokens.
            min_p (float): Minimum token probability, which will be scaled by the
                probability of the most likely token. It must be a value between
                0 and 1. Typical values are in the 0.01-0.2 range, comparably
                selective as setting `top_p` in the 0.99-0.8 range (use the
                opposite of normal `top_p` values)
            logit_bias (Dict): Bias to logits. Only supported in pytorch engine.
            stream_options: Options for streaming response. Only set this when you
                set stream: true.
            **kwargs: additional request fields, merged into the top level of
                the JSON request body.

        Yields:
            json objects in openai formats
        """
        # Must stay the first statement: it snapshots exactly the call arguments.
        pload = self._build_payload(dict(locals()))
        response = requests.post(self.chat_completions_v1_url, headers=self.headers, json=pload, stream=stream)
        yield from self._iter_json_chunks(response, stream)

    def completions_v1(
        self,
        model: str,
        prompt: Union[str, List[Any]],
        suffix: Optional[str] = None,
        temperature: Optional[float] = 0.7,
        n: Optional[int] = 1,
        max_completion_tokens: Optional[int] = 16,
        max_tokens: Optional[int] = 16,
        stream: Optional[bool] = False,
        stop: Optional[Union[str, List[str]]] = None,
        top_p: Optional[float] = 1.0,
        top_k: Optional[int] = 40,
        user: Optional[str] = None,
        # additional argument of lmdeploy
        repetition_penalty: Optional[float] = 1.0,
        ignore_eos: Optional[bool] = False,
        skip_special_tokens: Optional[bool] = True,
        spaces_between_special_tokens: Optional[bool] = True,
        stream_options: Optional[Dict] = None,
        **kwargs,
    ):
        """Completion v1.

        Args:
            model (str): model name. Available from /v1/models.
            prompt (str | List): the input prompt.
            suffix (str): The suffix that comes after a completion of inserted
                text.
            max_completion_tokens (int | None): output token nums. Default to 16.
            max_tokens (int): output token nums
                Deprecated: Use max_completion_tokens instead.
            temperature (float): to modulate the next token probability
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
            stop (str | List[str] | None): To stop generating further
              tokens. Only accept stop words that's encoded to one token index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            user (str): A unique identifier representing your end-user.
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                around special tokens. The behavior of Fast tokenizers is to have
                this to False. This is setup to True in slow tokenizers.
            stream_options: Options for streaming response. Only set this when you
                set stream: true.
            **kwargs: additional request fields, merged into the top level of
                the JSON request body.

        Yields:
            json objects in openai formats
        """
        # Must stay the first statement: it snapshots exactly the call arguments.
        pload = self._build_payload(dict(locals()))
        response = requests.post(self.completions_v1_url, headers=self.headers, json=pload, stream=stream)
        yield from self._iter_json_chunks(response, stream)


================================================
FILE: lmdeploy/serve/openai/api_server.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
import asyncio
import copy
import json
import os
import re
import time
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncGenerator, Literal

import uvicorn
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Mount

from lmdeploy.archs import get_task
from lmdeploy.messages import (GenerationConfig, LogitsProcessor, PytorchEngineConfig, SpeculativeConfig,
                               TurbomindEngineConfig)
from lmdeploy.metrics.metrics_processor import metrics_processor
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.pytorch.disagg.config import DistServeEngineConfig
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
                                                   DistServeDropConnectionRequest, DistServeInitRequest,
                                                   MigrationRequest)
from lmdeploy.serve.core import AsyncEngine
from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
from lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501
from lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice,
                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                            ChatCompletionTokenLogprob, ChatMessage, ChoiceLogprobs, CompletionRequest,
                                            CompletionResponse, CompletionResponseChoice,
                                            CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
                                            EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
                                            GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput, LogProbs,
                                            ModelCard, ModelList, ModelPermission, PoolingRequest, PoolingResponse,
                                            TopLogprob, UpdateParamsRequest, UsageInfo)
from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
from lmdeploy.serve.utils.server_utils import validate_json_request
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.utils import get_logger

# yapf: enable

# Shared 'lmdeploy' logger used by this module.
logger = get_logger('lmdeploy')


class VariableInterface:
    """Holder for the api_server's shared state.

    Route handlers read these class attributes directly, e.g.
    ``VariableInterface.async_engine``. They are presumably populated at server
    start-up (TODO confirm against the launcher code).
    """
    async_engine: AsyncEngine = None
    request_hosts = []
    # following are for registering to proxy server
    proxy_url: str | None = None
    api_server_url: str | None = None
    # following are for reasoning parsers
    reasoning_parser: ReasoningParser | None = None
    # following is for tool parsers
    tool_parser: ToolParser | None = None
    allow_terminate_by_client: bool = False
    enable_abort_handling: bool = False

    @staticmethod
    def get_session(session_id: int):
        """Return the session for ``session_id`` from the engine's session manager.

        Args:
            session_id (int): requested id; ``-1`` asks the manager to allocate
                a fresh id.

        Returns:
            The ``Session`` object returned by ``session_mgr.get`` (not an int).
        """
        session_mgr = VariableInterface.get_session_manager()
        if session_id == -1:
            return session_mgr.get()
        else:
            return session_mgr.get(session_id)

    @staticmethod
    def get_session_manager():
        """Return the session manager of the shared async engine."""
        return VariableInterface.async_engine.session_mgr

    @staticmethod
    def get_engine_config():
        """Return the backend config of the shared async engine."""
        return VariableInterface.async_engine.backend_config


# Router that collects this module's HTTP endpoints.
router = APIRouter()
# Shared instance passed to the per-request-type check functions in `check_request`.
server_context = VariableInterface()


def get_model_list():
    """Return the model names this server accepts.

    The served model's name always comes first. When LoRA (slora) adapters are
    configured, their names follow:
    ``[model_name, adapter_name1, adapter_name2, ...]``.
    """
    engine = VariableInterface.async_engine
    names = [engine.model_name]
    names.extend(getattr(engine.backend_config, 'adapters', None) or [])
    return names


@router.get('/v1/models')
def available_models():
    """List every servable model (base model plus adapters) as model cards."""
    cards = [ModelCard(id=name, root=name, permission=[ModelPermission()]) for name in get_model_list()]
    return ModelList(data=cards)


def create_error_response(status: HTTPStatus, message: str, error_type='invalid_request_error'):
    """Build a JSON error response whose HTTP status mirrors its ``code`` field.

    Args:
        status (HTTPStatus): HTTP status to report.
        message (str): human-readable error message.
        error_type (str): error type placed in the body's ``type`` field.
    """
    code = status.value
    body = ErrorResponse(message=message, type=error_type, code=code).model_dump()
    return JSONResponse(body, status_code=code)


def check_request(request) -> JSONResponse | None:
    """Validate ``request`` before it is dispatched.

    Returns:
        JSONResponse | None: a 404 response if the requested model is not
        served, a 400 response if the request-type specific checker returns an
        error message, otherwise None.
    """
    if hasattr(request, 'model') and request.model not in get_model_list():
        return create_error_response(HTTPStatus.NOT_FOUND, f'The model {request.model!r} does not exist.')

    # Each request type has its own checker, imported inside the branch that needs it.
    if isinstance(request, ChatCompletionRequest):
        from .serving_chat_completion import check_request as check_func
    elif isinstance(request, CompletionRequest):
        from .serving_completion import check_request as check_func
    elif isinstance(request, GenerateReqInput):
        from .serving_generate import check_request as check_func
    else:
        # Other request types need no extra validation: report "no error".
        def check_func(req, server_context):
            return ''

    error_msg = check_func(request, server_context)
    if not error_msg:
        return None
    return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)


def _create_completion_logprobs(tokenizer: Tokenizer,
                                token_ids: list[int] | None = None,
                                logprobs: list[dict[int, float]] | None = None,
                                skip_special_tokens: bool = True,
                                offset: int = 0,
                                all_token_ids: list[int] | None = None,
                                state: DetokenizeState = None,
                                spaces_between_special_tokens: bool = True):
    """Create openai LogProbs for completion.

    Args:
        tokenizer (Tokenizer): tokenizer.
        token_ids (List[int]): output token ids.
        logprobs (List[Dict[int, float]]): the top logprobs for each output
            position.
        skip_special_tokens (bool): Whether or not to remove special tokens
            in the decoding. Default to be True.
        offset (int): text offset.
        all_token_ids (List[int]): the history output token ids. Appended to
            in place when given.
        state (DetokenizeState): tokenizer decode state.
        spaces_between_special_tokens (bool): Whether or not to add spaces
            around special tokens. The behavior of Fast tokenizers is to have
            this to False. This is setup to True in slow tokenizers.

    Returns:
        tuple: ``(logprobs, offset, all_token_ids, state)``. ``logprobs`` is a
        ``LogProbs`` object, or None when there is nothing to process
        (``token_ids`` is None or ``logprobs`` is None/empty). The other three
        items are the running values to pass into the next call. When nothing
        is processed they are returned unchanged, so the caller's text offset,
        token history and decode state are not wiped out.
    """
    if token_ids is None or logprobs is None or len(logprobs) == 0:
        # Nothing to process: hand the running state back untouched instead of
        # replacing it with None.
        return None, offset, all_token_ids, state

    if all_token_ids is None:
        all_token_ids = []
    if state is None:
        state = DetokenizeState()

    out_logprobs = LogProbs()
    out_logprobs.top_logprobs = []
    for token_id, tops in zip(token_ids, logprobs):
        out_logprobs.text_offset.append(offset)
        out_logprobs.token_logprobs.append(tops[token_id])

        res = {}
        out_state = None
        for top_id, prob in tops.items():
            # Decode each candidate from a copy of the current state, so the
            # candidates do not interfere with each other or with the real state.
            response, _state = tokenizer.detokenize_incrementally(
                all_token_ids + [top_id],
                copy.deepcopy(state),
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens)
            res[response] = prob
            if top_id == token_id:
                # Only the sampled token advances the offset and decode state.
                out_state = _state
                offset += len(response)
                out_logprobs.tokens.append(response)

        out_logprobs.top_logprobs.append(res)
        state = out_state
        all_token_ids.append(token_id)

    return out_logprobs, offset, all_token_ids, state


def _create_chat_completion_logprobs(tokenizer: Tokenizer,
                                     token_ids: list[int] | None = None,
                                     logprobs: list[dict[int, float]] | None = None):
    """Build the OpenAI ``ChoiceLogprobs`` payload for a chat completion.

    For every generated position the sampled token fills the entry's
    ``token``/``bytes``/``logprob`` fields, and every other candidate of that
    position is appended to the entry's ``top_logprobs``.

    Args:
        tokenizer (Tokenizer): tokenizer whose
            ``model.model.convert_ids_to_tokens`` maps an id to its raw token.
        token_ids (list[int] | None): generated token ids.
        logprobs (list[dict[int, float]] | None): for each output position, a
            mapping from candidate token id to its log probability.

    Returns:
        ChoiceLogprobs | None: ``None`` if ``token_ids`` or ``logprobs`` is
        missing; otherwise one entry per (token id, candidates) pair, stopping
        at the shorter of the two lists.
    """
    if token_ids is None or logprobs is None:
        return None

    def _token_and_bytes(tid):
        # Raw tokens may come back as bytes (byte-level vocabularies) or str.
        raw = tokenizer.model.model.convert_ids_to_tokens(tid)
        if isinstance(raw, bytes):
            return raw.decode('utf-8', errors='backslashreplace'), list(raw)
        return raw, list(raw.encode())

    content: list[ChatCompletionTokenLogprob] = []
    for sampled_id, candidates in zip(token_ids, logprobs):
        entry = ChatCompletionTokenLogprob(token='', bytes=[], logprob=0.0, top_logprobs=[])
        for cand_id, cand_prob in candidates.items():
            text, raw_bytes = _token_and_bytes(cand_id)
            if cand_id != sampled_id:
                entry.top_logprobs.append(TopLogprob(token=text, bytes=raw_bytes, logprob=cand_prob))
                continue
            entry.token = text
            entry.bytes = raw_bytes
            entry.logprob = cand_prob
        content.append(entry)
    return ChoiceLogprobs(content=content)


@router.get('/health')
async def health() -> Response:
    """Health check endpoint: always answers with an empty HTTP 200."""
    ok_response = Response(status_code=200)
    return ok_response


@router.get('/terminate')
async def terminate():
    """Terminate the server by sending SIGTERM to the current process.

    Only allowed when ``VariableInterface.allow_terminate_by_client`` is set;
    otherwise an HTTP 400 error response is returned and nothing is signalled.
    """
    import signal

    if VariableInterface.allow_terminate_by_client:
        # The signal targets this very process (os.getpid()).
        os.kill(os.getpid(), signal.SIGTERM)
        return Response(status_code=200)
    return create_error_response(
        HTTPStatus.BAD_REQUEST,
        'The server can not be terminated. Please add --allow-terminate-by-client when start the server.')


# modified from https://github.com/vllm-project/vllm/blob/v0.5.4/vllm/entrypoints/openai/logits_processors.py#L51  # noqa
def logit_bias_logits_processor(logit_bias: dict[int, float] | dict[str, float], tokenizer) -> LogitsProcessor:
    """Build a logits processor that adds a fixed bias to selected tokens.

    Args:
        logit_bias: mapping from token id (an int, or a string holding an
            int) to the bias added to that token's logit. Each bias is clamped
            to [-100, 100] per the OpenAI API spec.
        tokenizer: object exposing ``vocab_size``; every token id must lie in
            ``[0, vocab_size)``.

    Returns:
        A callable ``(token_ids, logits) -> logits`` that adds each bias to
        ``logits[token_id]`` (item assignment, so ``logits`` is mutated) and
        returns ``logits``.

    Raises:
        ValueError: if a token id is not integer-like, a bias is not a
            number, or a token id is outside the vocabulary.
    """
    clamped_logit_bias: dict[int, float] = {}
    for raw_token_id, raw_bias in logit_bias.items():
        # Keys may arrive as strings (JSON object keys), so convert explicitly.
        try:
            token_id = int(raw_token_id)
        except (TypeError, ValueError) as exc:
            raise ValueError('Found token_id in logit_bias that is not '
                             'an integer or string representing an integer') from exc
        try:
            bias = float(raw_bias)
        except (TypeError, ValueError) as exc:
            raise ValueError(f'bias {raw_bias!r} for token_id {token_id} in logit_bias '
                             'is not a number') from exc
        # Clamp the bias between -100 and 100 per OpenAI API spec
        clamped_logit_bias[token_id] = min(100.0, max(-100.0, bias))

    # Check if token_id is within the vocab size
    vocab_size = tokenizer.vocab_size
    for token_id in clamped_logit_bias:
        if not 0 <= token_id < vocab_size:
            raise ValueError(f'token_id {token_id} in logit_bias contains '
                             'out-of-vocab token id')

    def _logit_bias_processor(
        logit_bias,
        token_ids,
        logits,
    ):
        for token_id, bias in logit_bias.items():
            logits[token_id] = logits[token_id] + bias
        return logits

    return partial(_logit_bias_processor, clamped_logit_bias)


@router.post('/v1/chat/completions', dependencies=[Depends(validate_json_request)])
async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):
    """Chat completion API similar to OpenAI's API.

    Refer to https://platform.openai.com/docs/api-reference/chat/create
    for the API specification.

    The request should be a JSON object with the following fields:

    - **model**: model name. Available from /v1/models.
    - **messages**: string prompt or chat history in OpenAI format. Chat history example:
      ``[{"role": "user", "content": "hi"}]``.
    - **temperature** (float): to modulate the next token probability
    - **top_p** (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
      are kept for generation.
    - **n** (int): How many chat completion choices to generate for each input
      message. **Only one is supported here**.
    - **stream**: whether to stream the results or not. Defaults to false.
    - **stream_options**: Options for streaming response. Only set this when you
      set stream: true.
    - **max_completion_tokens** (int | None): output token nums. Defaults to None.
    - **max_tokens** (int | None): output token nums. Defaults to None.
      Deprecated: Use max_completion_tokens instead.
    - **repetition_penalty** (float): The parameter for repetition penalty.
      1.0 means no penalty
    - **stop** (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that are encoded to one token index.
    - **response_format** (dict | None): To generate response according to given
      schema. Examples:

      .. code-block:: json

        {
          "type": "json_schema",
          "json_schema":{
            "name": "test",
            "schema":{
              "properties":{
                "name":{"type":"string"}
              },
              "required":["name"],
              "type":"object"
            }
          }
        }

      or ``{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}``
    - **logit_bias** (dict): Bias to logits. Only supported in pytorch engine.
    - **tools** (list): A list of tools the model may call. Currently, only
      internlm2 functions are supported as a tool. Use this to specify a
      list of functions for which the model can generate JSON inputs.
    - **tool_choice** (str | object): Controls which (if any) tool is called by
      the model. `none` means the model will not call any tool and instead
      generates a message. Specifying a particular tool via
      ``{"type": "function", "function": {"name": "my_function"}}``
      forces the model to call that tool. `auto` or `required` will put all
      the tools information to the model.

    Additional arguments supported by LMDeploy:

    - **top_k** (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering
    - **ignore_eos** (bool): indicator for ignoring eos
    - **skip_special_tokens** (bool): Whether or not to remove special tokens
      in the decoding. Defaults to True.
    - **spaces_between_special_tokens** (bool): Whether or not to add spaces
      around special tokens. The behavior of Fast tokenizers is to have
      this to False. This is setup to True in slow tokenizers.
    - **min_new_tokens** (int): To generate at least numbers of tokens.
    - **min_p** (float): Minimum token probability, which will be scaled by the
      probability of the most likely token. It must be a value between
      0 and 1. Typical values are in the 0.01-0.2 range, comparably
      selective as setting `top_p` in the 0.99-0.8 range (use the
      opposite of normal `top_p` values)

    Currently we do not support the following features:

    - **presence_penalty** (replaced with repetition_penalty)
    - **frequency_penalty** (replaced with repetition_penalty)
    """
    error_check_ret = check_request(request)
    if error_check_ret is not None:
        return error_check_ret
    session = VariableInterface.get_session(request.session_id)

    json_request = await raw_request.json()
    migration_request = json_request.pop('migration_request', None)
    with_cache = json_request.pop('with_cache', False)
    preserve_cache = json_request.pop('preserve_cache', False)
    if migration_request:
        migration_request = MigrationRequest.model_validate(migration_request)

    model_name = request.model
    adapter_name = None
    if model_name != VariableInterface.async_engine.model_name:
        adapter_name = model_name  # got an adapter name
    request_id = str(session.session_id)
    created_time = int(time.time())
    gpt_oss_parser = None
    if VariableInterface.async_engine.arch == 'GptOssForCausalLM':
        gpt_oss_parser = GptOssChatParser()

    if isinstance(request.stop, str):
        request.stop = [request.stop]

    gen_logprobs, logits_processors = None, None
    if request.logprobs and request.top_logprobs:
        gen_logprobs = request.top_logprobs
    response_format = None
    if request.response_format and request.response_format.type != 'text':
        response_format = request.response_format.model_dump()

    if request.logit_bias is not None:
        try:
            logits_processors = [
                logit_bias_logits_processor(request.logit_bias, VariableInterface.async_engine.tokenizer.model)
            ]
        except Exception as e:
            return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    random_seed = request.seed if request.seed else None
    max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens)

    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        logprobs=gen_logprobs,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        include_stop_str_in_output=request.include_stop_str_in_output,
        skip_special_tokens=request.skip_special_tokens,
        response_format=response_format,
        logits_processors=logits_processors,
        min_new_tokens=request.min_new_tokens,
        min_p=request.min_p,
        random_seed=random_seed,
        spaces_between_special_tokens=request.spaces_between_special_tokens,
        migration_request=migration_request,
        with_cache=with_cache,
        preserve_cache=preserve_cache,
    )

    tools = None
    if request.tools and request.tool_choice != 'none':
        gen_config.skip_special_tokens = False
        # internlm2 only uses contents inside function regardless of 'type'
        if not isinstance(request.tool_choice, str):
            if gpt_oss_parser:
                tools = [
                    item.model_dump() for item in request.tools
                    if item.function.name == request.tool_choice.function.name
                ]
            else:
                tools = [
                    item.function.model_dump() for item in request.tools
                    if item.function.name == request.tool_choice.function.name
                ]
        else:
            if gpt_oss_parser:
                tools = [item.model_dump() for item in request.tools]
            else:
                tools = [item.function.model_dump() for item in request.tools]
    # text completion for string input
    do_preprocess = False if isinstance(request.messages, str) else request.do_preprocess
    chat_template_kwargs = request.chat_template_kwargs or {}
    if request.enable_thinking is not None:
        logger.warning('`enable_thinking` will be deprecated in the future, '
                       'please use `chat_template_kwargs` instead.')
        if chat_template_kwargs.get('enable_thinking') is None:
            chat_template_kwargs['enable_thinking'] = request.enable_thinking
        else:
            logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.')
    enable_thinking = chat_template_kwargs.get('enable_thinking', None)
    result_generator = VariableInterface.async_engine.generate(
        request.messages,
        session,
        gen_config=gen_config,
        tools=tools,
        reasoning_effort=request.reasoning_effort,
        stream_response=True,  # always use stream to enable batching
        sequence_start=True,
        sequence_end=True,
        do_preprocess=do_preprocess,
        adapter_name=adapter_name,
        chat_template_kwargs=chat_template_kwargs or None,
        media_io_kwargs=request.media_io_kwargs,
        mm_processor_kwargs=request.mm_processor_kwargs)

    def create_stream_response_json(index: int,
                                    delta_message: DeltaMessage,
                                    finish_reason: str | None = None,
                                    logprobs: LogProbs | None = None,
                                    usage: UsageInfo | None = None) -> str:
        """Serialize one streaming chunk to a JSON string."""
        choice_data = ChatCompletionResponseStreamChoice(index=index,
                                                         delta=delta_message,
                                                         finish_reason=finish_reason,
                                                         logprobs=logprobs)
        response = ChatCompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
            usage=usage,
        )
        response_json = response.model_dump_json()

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_text = ''
        current_text = ''
        previous_token_ids = []
        current_token_ids = []
        delta_token_ids = []
        has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None
        streaming_tools = False
        async for res in result_generator:
            logprobs, usage = None, None
            if gen_logprobs and res.logprobs:
                logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids,
                                                            res.logprobs)
            # Only stream chunk `usage` in the final chunk according to OpenAI API spec
            if (res.finish_reason and request.stream_options and request.stream_options.include_usage):
                total_tokens = sum([res.input_token_len, res.generate_token_len])
                usage = UsageInfo(
                    prompt_tokens=res.input_token_len,
                    completion_tokens=res.generate_token_len,
                    total_tokens=total_tokens,
                )

            delta_token_ids = res.token_ids if res.token_ids is not None else []
            if gpt_oss_parser:
                delta_message = gpt_oss_parser.parse_streaming(res.token_ids)
                if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0:
                    res.finish_reason = 'tool_calls'
            else:
                delta_message = DeltaMessage(role='assistant', content=res.response)
                if has_parser:
                    current_text = current_text + res.response
                    current_token_ids = current_token_ids + delta_token_ids
                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
                    if res.finish_reason == 'stop' and streaming_tools is True:
                        res.finish_reason = 'tool_calls'
                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content,
                        previous_token_ids=previous_token_ids,
                        current_token_ids=current_token_ids,
                        delta_token_ids=delta_token_ids,
                        request=request)
                    if tool_delta is not None:
                        delta_message.tool_calls = tool_delta.tool_calls
                        delta_message.content = tool_delta.content
                        if isinstance(tool_delta.tool_calls, list) and len(tool_delta.tool_calls):
                            streaming_tools = True
                elif (request.tool_choice != 'none' and request.tools is not None
                      and VariableInterface.tool_parser is None):
                    logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
                if VariableInterface.reasoning_parser is not None and enable_thinking is not False:
                    reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content or '',
                        previous_token_ids=previous_token_ids,
                        current_token_ids=current_token_ids,
                        delta_token_ids=delta_token_ids)
                    if reasoning_delta is not None:
                        delta_message.reasoning_content = reasoning_delta.reasoning_content
                        delta_message.content = reasoning_delta.content
                if has_parser:
                    previous_text = current_text
                    previous_token_ids = current_token_ids
            if request.return_token_ids:
                delta_message.gen_tokens = delta_token_ids
            response_json = create_stream_response_json(index=0,
                                                        delta_message=delta_message,
                                                        finish_reason=res.finish_reason,
                                                        logprobs=logprobs,
                                                        usage=usage)
            if res.cache_block_ids is not None:
                # `response_json` is already a serialized JSON string and does not
                # support item assignment; decode it, attach the cache fields and
                # encode it again.
                payload = json.loads(response_json)
                payload['cache_block_ids'] = res.cache_block_ids
                payload['remote_token_ids'] = res.token_ids
                response_json = json.dumps(payload)
            yield f'data: {response_json}\n\n'
        yield 'data: [DONE]\n\n'

    # Streaming response
    if request.stream:
        return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')

    # Non-streaming response
    final_logprobs = []
    final_token_ids = []
    final_res = None
    text = ''
    cache_block_ids = []
    remote_token_ids = []
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await session.async_abort()
            return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
        final_res = res
        text += res.response
        if res.token_ids:
            final_token_ids.extend(res.token_ids)
        if res.logprobs:
            final_logprobs.extend(res.logprobs)
        cache_block_ids.append(res.cache_block_ids)
        remote_token_ids.append(res.token_ids)

    # Checked here, before any branch below dereferences `final_res`.
    assert final_res is not None

    if gpt_oss_parser:
        message = gpt_oss_parser.parse_full(final_token_ids)
        if final_res.finish_reason == 'stop' and len(message.tool_calls) > 0:
            final_res.finish_reason = 'tool_calls'
    else:
        tool_calls = None
        reasoning_content = None
        if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
            try:
                tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
                text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
                if isinstance(tool_calls, list) and len(tool_calls):
                    if final_res.finish_reason == 'stop':
                        final_res.finish_reason = 'tool_calls'

            except Exception as e:
                logger.error(f'Failed to parse {text}. Exception: {e}.')
                return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')
        elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
            logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')

        if VariableInterface.reasoning_parser is not None and enable_thinking is not False:
            reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)

        message = ChatMessage(role='assistant',
                              content=text,
                              tool_calls=tool_calls,
                              reasoning_content=reasoning_content)

    logprobs = None
    if gen_logprobs and len(final_logprobs):
        logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids,
                                                    final_logprobs)

    choices = []
    if request.return_token_ids:
        message.gen_tokens = final_token_ids
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        logprobs=logprobs,
        finish_reason=final_res.finish_reason,
    )
    choices.append(choice_data)

    if with_cache:
        cache_block_ids = cache_block_ids[0]
        remote_token_ids = [remote_token_ids[0][-1]]

    total_tokens = sum([final_res.input_token_len, final_res.generate_token_len])
    usage = UsageInfo(
        prompt_tokens=final_res.input_token_len,
        completion_tokens=final_res.generate_token_len,
        total_tokens=total_tokens,
    )
    response = ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    ).model_dump()

    if with_cache:
        response['cache_block_ids'] = cache_block_ids
        response['remote_token_ids'] = remote_token_ids

    return response


@router.post('/v1/completions', dependencies=[Depends(validate_json_request)])
async def completions_v1(request: CompletionRequest, raw_request: Request = None):
    """Completion API similar to OpenAI's API.

    Go to https://platform.openai.com/docs/api-reference/completions/create
    for the API specification.

    The request should be a JSON object with the following fields:

    - **model** (str): model name. Available from /v1/models.
    - **prompt** (str | list[str]): the input prompt, or a list of prompts.
      Each prompt produces one choice whose ``index`` is its position in the
      list.
    - **suffix** (str): The suffix that comes after a completion of inserted text.
    - **max_completion_tokens** (int | None): output token nums. Defaults to None.
    - **max_tokens** (int | None): output token nums. Defaults to 16.
      Deprecated: Use max_completion_tokens instead.
    - **temperature** (float): to modulate the next token probability
    - **top_p** (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
      are kept for generation.
    - **n** (int): How many completion choices to generate for each input
      prompt. **Only one is supported here**.
    - **logprobs**: number of most likely tokens whose log probabilities are
      returned for each output token.
    - **stream**: whether to stream the results or not. Defaults to false.
    - **stream_options**: Options for streaming response. Only set this when you
      set stream: true.
    - **repetition_penalty** (float): The parameter for repetition penalty.
      1.0 means no penalty
    - **user** (str): A unique identifier representing your end-user.
    - **stop** (str | list[str] | None): To stop generating further
      tokens. Only accept stop words that are encoded to one token index.

    Additional arguments supported by LMDeploy:

    - **ignore_eos** (bool): indicator for ignoring eos
    - **skip_special_tokens** (bool): Whether or not to remove special tokens
      in the decoding. Defaults to True.
    - **spaces_between_special_tokens** (bool): Whether or not to add spaces
      around special tokens. The behavior of Fast tokenizers is to have
      this to False. This is setup to True in slow tokenizers.
    - **top_k** (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering
    - **min_p** (float): Minimum token probability, which will be scaled by the
      probability of the most likely token. It must be a value between
      0 and 1. Typical values are in the 0.01-0.2 range, comparably
      selective as setting `top_p` in the 0.99-0.8 range (use the
      opposite of normal `top_p` values)

    Currently we do not support the following features:

    - **presence_penalty** (replaced with repetition_penalty)
    - **frequency_penalty** (replaced with repetition_penalty)
    """
    error_check_ret = check_request(request)
    if error_check_ret is not None:
        return error_check_ret

    json_request = await raw_request.json()
    migration_request = json_request.pop('migration_request', None)
    with_cache = json_request.pop('with_cache', False)
    preserve_cache = json_request.pop('preserve_cache', False)
    if migration_request:
        migration_request = MigrationRequest.model_validate(migration_request)

    model_name = request.model
    adapter_name = None
    if model_name != VariableInterface.async_engine.model_name:
        adapter_name = model_name  # got an adapter name
    request_id = str(request.session_id)
    created_time = int(time.time())
    sessions = []
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]
        sessions.append(VariableInterface.get_session(request.session_id))
    elif isinstance(request.prompt, list):
        # NOTE(review): list prompts use fixed session ids 1..N; confirm that
        # get_session() keeps these isolated across concurrent requests.
        for i in range(len(request.prompt)):
            sessions.append(VariableInterface.get_session(i + 1))
    if isinstance(request.stop, str):
        request.stop = [request.stop]
    random_seed = request.seed if request.seed else None
    max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens)

    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        logprobs=request.logprobs,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens,
        min_p=request.min_p,
        random_seed=random_seed,
        spaces_between_special_tokens=request.spaces_between_special_tokens,
        migration_request=migration_request,
        with_cache=with_cache,
        preserve_cache=preserve_cache,
    )
    generators = []
    for prompt, session in zip(request.prompt, sessions):
        result_generator = VariableInterface.async_engine.generate(
            prompt,
            session,
            gen_config=gen_config,
            stream_response=True,  # always use stream to enable batching
            sequence_start=True,
            sequence_end=True,
            do_preprocess=False,
            adapter_name=adapter_name)
        generators.append(result_generator)

    def create_stream_response_json(index: int,
                                    text: str,
                                    finish_reason: str | None = None,
                                    logprobs: LogProbs | None = None,
                                    gen_tokens: list[int] | None = None,
                                    usage: UsageInfo | None = None) -> dict:
        """Build one streaming chunk as a plain dict (serialized by the caller)."""
        choice_data = CompletionResponseStreamChoice(index=index,
                                                     text=text,
                                                     gen_tokens=gen_tokens,
                                                     finish_reason=finish_reason,
                                                     logprobs=logprobs)
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
            usage=usage,
        )
        response_json = response.model_dump()
        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        # Prompts are streamed one after another. Each chunk carries the index
        # of the prompt it belongs to, matching the non-streaming choice index.
        for index, generator in enumerate(generators):
            offset = 0
            all_token_ids = []
            state = DetokenizeState()
            async for res in generator:
                logprobs = None
                usage = None
                if request.logprobs and res.logprobs:
                    logprobs, offset, all_token_ids, state = _create_completion_logprobs(  # noqa E501
                        VariableInterface.async_engine.tokenizer, res.token_ids, res.logprobs,
                        gen_config.skip_special_tokens, offset, all_token_ids, state,
                        gen_config.spaces_between_special_tokens)
                # Only stream chunk `usage` in the final chunk according to OpenAI API spec
                if (res.finish_reason and request.stream_options and request.stream_options.include_usage):
                    total_tokens = sum([res.input_token_len, res.generate_token_len])
                    usage = UsageInfo(
                        prompt_tokens=res.input_token_len,
                        completion_tokens=res.generate_token_len,
                        total_tokens=total_tokens,
                    )
                gen_tokens = None
                if request.return_token_ids:
                    gen_tokens = res.token_ids or []
                response_json = create_stream_response_json(index=index,
                                                            text=res.response,
                                                            gen_tokens=gen_tokens,
                                                            finish_reason=res.finish_reason,
                                                            logprobs=logprobs,
                                                            usage=usage)
                if res.cache_block_ids is not None:
                    response_json['cache_block_ids'] = res.cache_block_ids
                    response_json['remote_token_ids'] = res.token_ids
                yield f'data: {json.dumps(response_json)}\n\n'
        yield 'data: [DONE]\n\n'

    # Streaming response
    if request.stream:
        return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')

    # Non-streaming response
    usage = UsageInfo()
    choices = [None] * len(generators)
    cache_block_ids = []
    remote_token_ids = []

    async def _inner_call(i, generator, session):
        """Drain one prompt's generator into ``choices[i]`` and add its usage.

        Returns an error response if the client disconnected, otherwise None.
        """
        final_logprobs = []
        final_token_ids = []
        final_res = None
        text = ''
        async for res in generator:
            if await raw_request.is_disconnected():
                # Abort this prompt's own session if the client disconnects.
                await session.async_abort()
                return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
            final_res = res
            text += res.response
            cache_block_ids.append(res.cache_block_ids)
            remote_token_ids.append(res.token_ids)
            if res.token_ids:
                final_token_ids.extend(res.token_ids)
            if res.logprobs:
                final_logprobs.extend(res.logprobs)

        logprobs = None
        if request.logprobs and len(final_logprobs):
            logprobs, _, _, _ = _create_completion_logprobs(
                VariableInterface.async_engine.tokenizer,
                final_token_ids,
                final_logprobs,
                gen_config.skip_special_tokens,
                spaces_between_special_tokens=gen_config.spaces_between_special_tokens)

        assert final_res is not None
        choice_data = CompletionResponseChoice(index=i,
                                               text=text,
                                               finish_reason=final_res.finish_reason,
                                               logprobs=logprobs,
                                               gen_tokens=final_token_ids if request.return_token_ids else None)
        choices[i] = choice_data

        total_tokens = sum([final_res.input_token_len, final_res.generate_token_len])
        usage.prompt_tokens += final_res.input_token_len
        usage.completion_tokens += final_res.generate_token_len
        usage.total_tokens += total_tokens
        return None

    results = await asyncio.gather(
        *[_inner_call(i, generator, session) for i, (generator, session) in enumerate(zip(generators, sessions))])
    # Surface an error returned by any prompt; otherwise its unfilled `None`
    # slot in `choices` would reach CompletionResponse.
    for ret in results:
        if ret is not None:
            return ret

    if with_cache:
        # Reduce only after every prompt has finished. Doing this inside the
        # concurrent `_inner_call`s rebinds the shared lists while other
        # coroutines are still appending to them.
        cache_block_ids = cache_block_ids[0]
        remote_token_ids = [remote_token_ids[0][-1]]

    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    ).model_dump()

    if with_cache:
        response['cache_block_ids'] = cache_block_ids
        response['remote_token_ids'] = remote_token_ids

    return response


@router.post('/generate', dependencies=[Depends(validate_json_request)])
async def generate(request: GenerateReqInput, raw_request: Request = None):
    """Generate output from a raw prompt or token ids (``/generate``).

    When ``image_data`` is given (a URL string, an image_url dict, or a list of
    them), the request is converted into a single OpenAI-style user message that
    combines the text (or ``input_ids``) with the images.

    With ``request.stream`` set, every engine chunk is sent as a server-sent event,
    followed by ``data: [DONE]``. Otherwise the chunks are accumulated into one
    ``GenerateReqOutput``.

    Returns an error response when request validation fails, when the client
    disconnects before a non-streaming generation finishes, or when the engine
    yields no output at all.
    """
    error_check_ret = check_request(request)
    if error_check_ret is not None:
        return error_check_ret
    session = VariableInterface.get_session(request.session_id)

    prompt = request.prompt
    input_ids = request.input_ids
    image_data = request.image_data
    if image_data is not None:
        # convert to openai format
        image_input = []
        if not isinstance(image_data, list):
            image_data = [image_data]
        for img in image_data:
            if isinstance(img, str):
                image_input.append(dict(type='image_url', image_url=dict(url=img)))
            else:
                image_input.append(dict(type='image_url', image_url=img))
        text_input = dict(type='text', text=prompt if prompt else input_ids)
        prompt = [dict(role='user', content=[text_input] + image_input)]
        input_ids = None

    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens,
        do_sample=True,
        logprobs=1 if request.return_logprob else None,
        top_k=request.top_k,
        top_p=request.top_p,
        min_p=request.min_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        stop_token_ids=request.stop_token_ids,
        skip_special_tokens=request.skip_special_tokens,
        spaces_between_special_tokens=request.spaces_between_special_tokens,
        include_stop_str_in_output=request.include_stop_str_in_output,
        return_routed_experts=request.return_routed_experts,
        repetition_ngram_size=request.repetition_ngram_size,
        repetition_ngram_threshold=request.repetition_ngram_threshold,
    )

    result_generator = VariableInterface.async_engine.generate(
        messages=prompt,
        session_id=session,
        input_ids=input_ids,
        gen_config=gen_config,
        stream_response=True,  # always use stream to enable batching
        sequence_start=True,
        sequence_end=True,
        do_preprocess=False,
        media_io_kwargs=request.media_io_kwargs,
        mm_processor_kwargs=request.mm_processor_kwargs)

    def _token_logprobs(res):
        """Return ``(logprob, token_id)`` pairs for the tokens generated in one chunk.

        Each token id is looked up in its per-position logprob mapping.
        """
        if not res.logprobs:
            return []
        return [(tok_logprobs[tok], tok) for tok, tok_logprobs in zip(res.token_ids, res.logprobs)]

    def create_generate_response_json(res, text, output_ids, logprobs, finish_reason, routed_experts=None):
        # only output router experts in last chunk
        routed_experts = None if finish_reason is None else routed_experts
        meta = GenerateReqMetaOutput(finish_reason=dict(type=finish_reason) if finish_reason else None,
                                     output_token_logprobs=logprobs or None,
                                     prompt_tokens=res.input_token_len,
                                     routed_experts=routed_experts,
                                     completion_tokens=res.generate_token_len)

        response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta, routed_experts=routed_experts)
        return response.model_dump_json()

    async def generate_stream_generator():
        async for res in result_generator:
            response_json = create_generate_response_json(res,
                                                          res.response or '',
                                                          res.token_ids,
                                                          _token_logprobs(res),
                                                          res.finish_reason,
                                                          routed_experts=res.routed_experts)
            yield f'data: {response_json}\n\n'
        yield 'data: [DONE]\n\n'

    if request.stream:
        return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')

    # Non-streaming: accumulate every chunk, then build a single response from the
    # accumulated text/tokens and the final chunk's metadata.
    text = ''
    output_ids = []
    logprobs = []
    res = None
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects. The error response is
            # returned from the endpoint itself so that it is actually sent.
            await session.async_abort()
            return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
        text += res.response or ''
        output_ids.extend(res.token_ids or [])
        logprobs.extend(_token_logprobs(res))

    if res is None:
        # The engine produced no chunk at all, so there is no final state to report.
        return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'No output was generated.')

    meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None,
                                 output_token_logprobs=logprobs or None,
                                 prompt_tokens=res.input_token_len,
                                 routed_experts=res.routed_experts,
                                 completion_tokens=res.generate_token_len)
    return GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta)


@router.post('/v1/embeddings', tags=['unsupported'])
async def create_embeddings(request: EmbeddingsRequest, raw_request: Request = None):
    """Reject embedding requests: this endpoint is not implemented.

    Always answers with HTTP 400 and an explanatory message.
    """
    unsupported_msg = 'Unsupported by turbomind.'
    return create_error_response(HTTPStatus.BAD_REQUEST, unsupported_msg)


@router.post('/v1/encode', dependencies=[Depends(validate_json_request)])
async def encode(request: EncodeRequest, raw_request: Request = None):
    """Encode prompts.

    The request should be a JSON object with the following fields:

    - **input**: the prompt to be encoded. In str or list[str] format.
    - **do_preprocess**: whether do preprocess or not. Default to False.
    - **add_bos**: True when it is the beginning of a conversation. False when it
      is not. Default to True.

    A single string yields one id list and its length. A list of strings yields
    one id list and one length per prompt.
    """

    def _tokenize(text: str):
        # Optionally wrap the raw text with the chat template before tokenizing.
        if request.do_preprocess:
            text = VariableInterface.async_engine.chat_template.get_prompt(text, sequence_start=request.add_bos)
        return VariableInterface.async_engine.tokenizer.encode(text, add_bos=request.add_bos)

    if not isinstance(request.input, str):
        batch = [_tokenize(prompt) for prompt in request.input]
        return EncodeResponse(input_ids=batch, length=[len(ids) for ids in batch])

    single = _tokenize(request.input)
    return EncodeResponse(input_ids=single, length=len(single))


@router.post('/pooling', dependencies=[Depends(validate_json_request)])
async def pooling(request: PoolingRequest, raw_request: Request = None):
    """Pooling prompts for reward model.

    In vLLM documentation, https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#pooling-api_1,
    the input format of Pooling API is the same as Embeddings API.

    Go to https://platform.openai.com/docs/api-reference/embeddings/create
    for the Embeddings API specification.

    The request should be a JSON object with the following fields:

    - **model** (str): model name. Available from /v1/models.
    - **input** (list[int] | list[list[int]] | str | list[str]): input text to be embed

    Every input form is normalized to a batch of token-id lists. The type of the
    first list element decides how a list input is interpreted.
    """
    engine = VariableInterface.async_engine

    raw_input = request.input
    model_name = request.model or engine.model_name

    # Normalize all inputs to be a batch (List[List[int]])
    if isinstance(raw_input, str):
        batch_ids = [engine.tokenizer.encode(raw_input)]
    elif not isinstance(raw_input, list):
        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must be a string or a list.')
    elif not raw_input:
        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list cannot be empty.')
    else:
        head = raw_input[0]
        if isinstance(head, str):  # list[str]
            batch_ids = [engine.tokenizer.encode(text) for text in raw_input]
        elif isinstance(head, int):  # list[int]
            batch_ids = [raw_input]
        elif isinstance(head, list):  # list[list[int]]
            batch_ids = raw_input
        else:
            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list contains an invalid type.')

    batch_scores = await engine.async_get_reward_score(batch_ids)
    n_prompt_tokens = sum(map(len, batch_ids))
    usage = UsageInfo(prompt_tokens=n_prompt_tokens, completion_tokens=0, total_tokens=n_prompt_tokens)

    data = [{'index': idx, 'object': 'pooling', 'data': score} for idx, score in enumerate(batch_scores)]
    return PoolingResponse(model=model_name, data=data, usage=usage).model_dump()


@router.post('/update_weights', dependencies=[Depends(validate_json_request)])
def update_params(request: UpdateParamsRequest, raw_request: Request = None):
    """Forward a weight-update request to the underlying engine.

    Replies with ``JSONResponse(content=None)`` once the engine call returns.
    """
    engine = VariableInterface.async_engine.engine
    engine.update_params(request)
    return JSONResponse(content=None)


@router.post('/sleep', dependencies=[Depends(validate_json_request)])
async def sleep(raw_request: Request = None):
    """Put the engine to sleep.

    The sleep level comes from the ``level`` query parameter (default ``'1'``)
    and is passed to the engine as an int.

    Returns HTTP 200 on success. Returns HTTP 400 when ``level`` is not an
    integer, instead of letting ``int()`` raise inside the handler.
    """
    raw_level = raw_request.query_params.get('level', '1')
    try:
        level = int(raw_level)
    except ValueError:
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     f'Invalid sleep level: {raw_level!r}, expected an integer.')
    VariableInterface.async_engine.sleep(level)
    return Response(status_code=200)


@router.post('/wakeup', dependencies=[Depends(validate_json_request)])
async def wakeup(raw_request: Request = None):
    """Wake the engine up.

    Repeated ``tags`` query parameters are forwarded to the engine. When none
    are supplied, the engine receives ``None``.
    """
    requested_tags = raw_request.query_params.getlist('tags')
    VariableInterface.async_engine.wakeup(requested_tags if requested_tags else None)
    return Response(status_code=200)


@router.get('/is_sleeping')
async def is_sleeping():
    """Report the engine's sleeping state as ``{'is_sleeping': <value>}``."""
    sleeping_state = VariableInterface.async_engine.is_sleeping
    return JSONResponse(content=dict(is_sleeping=sleeping_state))


""" PD Disaggregation API Begin """


@router.get('/distserve/engine_info')
async def engine_info():
    """Describe this engine's parallel layout and KV-cache block counts.

    The ``DistServeEngineConfig`` is returned as the string produced by
    ``model_dump_json()``.
    """
    cfg = VariableInterface.async_engine.backend_config

    info = DistServeEngineConfig(
        tp_size=cfg.tp,
        dp_size=cfg.dp,
        pp_size=None,  # pipeline parallel size is not reported
        ep_size=cfg.ep,
        dp_rank=cfg.dp_rank,
        block_size=cfg.block_size,
        num_cpu_blocks=cfg.num_cpu_blocks,
        num_gpu_blocks=cfg.num_gpu_blocks,
    )
    return info.model_dump_json()


@router.post('/distserve/p2p_initialize')
async def p2p_initialize(init_request: DistServeInitRequest):
    """Delegate PD-disaggregation P2P initialization to the engine and return its result."""
    engine = VariableInterface.async_engine
    return engine.p2p_initialize(init_request)


@router.post('/distserve/p2p_connect')
async def p2p_connect(conn_request: DistServeConnectionRequest):
    """Delegate a PD-disaggregation P2P connection request to the engine and return its result."""
    engine = VariableInterface.async_engine
    return engine.p2p_connect(conn_request)


@router.post('/distserve/p2p_drop_connect')
async def p2p_drop_connect(drop_conn_request: DistServeDropConnectionRequest):
    """Delegate dropping a PD-disaggregation P2P connection to the engine and return its result."""
    engine = VariableInterface.async_engine
    return engine.p2p_drop_connect(drop_conn_request)


@router.post('/distserve/free_cache')
async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse:
    """Free the engine cache held for a remote session.

    Acknowledges with a ``{'status': 'SUCCESS'}`` dict.
    """
    VariableInterface.async_engine.free_cache(cache_free_request.remote_session_id)
    return dict(status='SUCCESS')


""" PD Disaggregation API End """


@router.post('/abort_request')
async def abort_request(request: AbortRequest, raw_request: Request = None):
    """Abort an ongoing request, or every session when ``abort_all`` is set.

    Answers 501 unless abort handling was enabled when the server started.
    Otherwise answers 200 after the abort completes.
    """
    if VariableInterface.enable_abort_handling:
        if request.abort_all:
            await VariableInterface.async_engine.stop_all_session()
        else:
            await VariableInterface.get_session(request.session_id).async_abort()
        return Response(status_code=200)

    return Response(
        status_code=501,
        content='This server does not support abort requests. Enable with --enable-abort-handling flag.')


@router.post('/v1/chat/interactive', dependencies=[Depends(validate_json_request)], include_in_schema=False)
async def chat_interactive_v1(request, raw_request: Request = None):
    """Deprecated endpoint.

    Always answers 400, directing clients to /v1/chat/completions.
    """
    deprecation_msg = ('v1/chat/interactive is deprecated, please launch server with --enable-prefix-cache '
                       'and use /v1/chat/completions instead.')
    return create_error_response(HTTPStatus.BAD_REQUEST, deprecation_msg)


def handle_torchrun():
    """Disable mmengine's device-id logging logic on non-zero torchrun local ranks.

    When the ``LOCAL_RANK`` environment variable is greater than 0,
    ``mmengine.logging.logger._get_device_id`` is replaced by a stub that
    returns 0. Otherwise nothing happens.
    """
    if int(os.environ.get('LOCAL_RANK', -1)) <= 0:
        return

    from lmdeploy.vl.model.utils import _set_func

    def dummy_get_device_id():
        return 0

    # the replacement can't be recovered
    _set_func('mmengine.logging.logger._get_device_id', dummy_get_device_id)


@router.on_event('startup')
async def startup_event():
    """Start the engine loop and register this node with the proxy, if one is configured.

    Dummy engines skip registration. Registration failures, including a non-200
    reply from the proxy, are logged rather than raised.
    """
    engine = VariableInterface.async_engine
    engine.start_loop(asyncio.get_running_loop(), use_async_api=True)

    if VariableInterface.proxy_url is None:
        return
    if getattr(engine.engine, 'is_dummy', False):
        logger.info('Dummy node started')
        return

    try:
        import requests

        backend_config = VariableInterface.async_engine.backend_config
        role = backend_config.role.value if hasattr(backend_config, 'role') else 1
        payload = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': role}}
        # NOTE(review): synchronous HTTP call with no timeout inside the event loop.
        resp = requests.post(f'{VariableInterface.proxy_url}/nodes/add',
                             headers={
                                 'accept': 'application/json',
                                 'Content-Type': 'application/json'
                             },
                             json=payload)
        if resp.status_code != 200:
            raise HTTPException(status_code=resp.status_code, detail=resp.text)
    except Exception as e:
        logger.error(f'Service registration failed: {e}')


@router.on_event('shutdown')
async def shutdown_event():
    """Close the engine on server shutdown, if one was created."""
    engine = VariableInterface.async_engine
    if engine is None:
        return
    engine.close()


async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request validation error into a 422 reply.

    The reply contains the validation errors and the offending request body.
    """
    payload = {'detail': exc.errors(), 'body': exc.body}
    return JSONResponse(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=jsonable_encoder(payload))


class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Cap how many requests are processed at the same time.

    Each request must acquire a slot of an ``asyncio.Semaphore`` sized
    ``max_concurrent_requests`` before it is passed downstream. Additional
    requests wait until a slot is released.
    """

    def __init__(self, app: FastAPI, max_concurrent_requests: int):
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)

    async def dispatch(self, request: Request, call_next):
        async with self.semaphore:
            return await call_next(request)


def set_parsers(reasoning_parser: str | None = None, tool_parser: str | None = None):
    """Set tool parser and reasoning parsers.

    Each named parser is looked up in its registry, instantiated with the engine
    tokenizer, and stored on ``VariableInterface``. Passing ``None`` leaves that
    parser unset.

    Raises:
        ValueError: if a name is not registered in the corresponding manager.
    """
    # set reasoning parser
    if reasoning_parser is not None:
        if reasoning_parser not in ReasoningParserManager.module_dict:
            raise ValueError(
                f'The reasoning parser {reasoning_parser} is not in the parser list: {ReasoningParserManager.module_dict.keys()}'  # noqa
            )
        tokenizer = VariableInterface.async_engine.tokenizer
        VariableInterface.reasoning_parser = ReasoningParserManager.get(reasoning_parser)(tokenizer)
    # set tool parser
    if tool_parser is not None:
        if tool_parser not in ToolParserManager.module_dict:
            # The message names the tool parser registry (not the reasoning one).
            raise ValueError(
                f'The tool parser {tool_parser} is not in the parser list: {ToolParserManager.module_dict.keys()}'  # noqa
            )
        tokenizer = VariableInterface.async_engine.tokenizer
        VariableInterface.tool_parser = ToolParserManager.get(tool_parser)(tokenizer)


def mount_metrics(app: FastAPI, backend_config: PytorchEngineConfig | TurbomindEngineConfig):
    """Expose Prometheus metrics at ``/metrics`` when metrics are enabled.

    Does nothing unless ``backend_config.enable_metrics`` is truthy. A missing
    attribute, or ``backend_config=None``, counts as disabled.
    """
    if not getattr(backend_config, 'enable_metrics', False):
        return

    from prometheus_client import REGISTRY, make_asgi_app

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount('/metrics', make_asgi_app(registry=REGISTRY))

    # Workaround for 307 Redirect for /metrics: match both `/metrics` and
    # `/metrics/...` directly. The capture group must be a valid named group
    # called `path`, because the mount reads the remaining path from it;
    # `(?P.*)` without a name does not compile.
    metrics_route.path_regex = re.compile(r'^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)


def create_lifespan_handler(backend_config: PytorchEngineConfig | TurbomindEngineConfig, async_engine: AsyncEngine):
    """Factory function to create a FastAPI lifespan handler.

    When ``backend_config.enable_metrics`` is truthy at startup, the handler
    starts the metrics handler. It also launches a background task that, every
    10 seconds, refreshes schedule metrics and asks the engine to log its stats.

    On exit the task (if any) is cancelled, and the metrics handler is stopped
    unconditionally.
    """

    async def _periodic_stats_logger(interval: float):
        while True:
            await asyncio.sleep(interval)
            # schedule metrics change less often than iteration stats, so refresh them periodically here
            await metrics_processor.update_schedule_stats(async_engine.get_schedule_metrics())
            await async_engine.do_log_stats()

    @asynccontextmanager
    async def lifespan_handler(app: FastAPI):
        log_task = None
        try:
            if getattr(backend_config, 'enable_metrics', False):
                metrics_processor.start_metrics_handler(enable_metrics=True)
                log_task = asyncio.create_task(_periodic_stats_logger(10.))
            yield
        finally:
            if log_task:
                log_task.cancel()
            await metrics_processor.stop_metrics_handler()

    return lifespan_handler


def serve(model_path: str,
          model_name: str | None = None,
          backend: Literal['turbomind', 'pytorch'] = 'turbomind',
          backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None,
          chat_template_config: ChatTemplateConfig | None = None,
          server_name: str = '0.0.0.0',
          server_port: int = 23333,
          allow_origins: list[str] = ['*'],
          allow_credentials: bool = True,
          allow_methods: list[str] = ['*'],
          allow_headers: list[str] = ['*'],
          log_level: str = 'ERROR',
          api_keys: list[str] | str | None = None,
          ssl: bool = False,
          proxy_url: str | None = None,
          max_log_len: int | None = None,
          disable_fastapi_docs: bool = False,
          max_concurrent_requests: int | None = None,
          reasoning_parser: str | None = None,
          tool_call_parser: str | None = None,
          allow_terminate_by_client: bool = False,
          enable_abort_handling: bool = False,
          speculative_config: SpeculativeConfig | None = None,
          **kwargs):
    """Launch the RESTful API server for a model.

    Builds the inference pipeline for ``model_path`` and mounts this module's
    routes on a FastAPI app. Optional CORS, API-key, concurrency-limit and
    Prometheus-metrics layers are added, and the app is served with
    ``uvicorn.run``.

    Args:
        model_path (str): the path of a model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): the name of the served model. It can be accessed
            by the RESTful API `/v1/models`. If it is not specified,
            `model_path` will be adopted
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to None. For a PytorchEngineConfig,
            `enable_mp_engine` is forced on.
        chat_template_config (ChatTemplateConfig): chat template configuration
            Default to None.
        server_name (str): host ip for serving
        server_port (int): server port
        allow_origins (List[str]): a list of allowed origins for CORS. An
            empty value disables the CORS middleware.
        allow_credentials (bool): whether to allow credentials for CORS
        allow_methods (List[str]): a list of allowed HTTP methods for CORS
        allow_headers (List[str]): a list of allowed HTTP headers for CORS
        log_level(str): set log level whose value among [CRITICAL, ERROR,
            WARNING, INFO, DEBUG]
        api_keys (List[str] | str | None): Optional list of API keys. A single
            string is treated as one api key. Empty keys are ignored. Default
            to None, which means no api key applied.
        ssl (bool): Enable SSL. Requires OS Environment variables
            'SSL_KEYFILE' and 'SSL_CERTFILE'.
        proxy_url (str): The proxy url to register the api_server.
        max_log_len (int): Max number of prompt characters or prompt tokens
            being printed in log. Default: Unlimited
        disable_fastapi_docs (bool): Disable the docs, redoc and openapi routes.
        max_concurrent_requests: This refers to the number of concurrent
            requests that the server can handle. The server is designed to
            process the engine’s tasks once the maximum number of concurrent
            requests is reached, regardless of any additional requests sent by
            clients concurrently during that time. Default to None.
        reasoning_parser (str): The reasoning parser name.
        tool_call_parser (str): The tool call parser name.
        allow_terminate_by_client (bool): Allow request from client to terminate server.
        enable_abort_handling (bool): Enable the `/abort_request` endpoint.
        speculative_config (SpeculativeConfig): speculative decoding config
            forwarded to the pipeline. Default to None.
        **kwargs: extra arguments forwarded to the pipeline constructor.
    """
    if os.getenv('TM_LOG_LEVEL') is None:
        os.environ['TM_LOG_LEVEL'] = log_level
    logger.setLevel(log_level)

    VariableInterface.allow_terminate_by_client = allow_terminate_by_client
    VariableInterface.enable_abort_handling = enable_abort_handling

    ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
    if ssl:
        ssl_keyfile = os.environ['SSL_KEYFILE']
        ssl_certfile = os.environ['SSL_CERTFILE']
        http_or_https = 'https'

    handle_torchrun()
    _, pipeline_class = get_task(backend, model_path)
    if isinstance(backend_config, PytorchEngineConfig):
        backend_config.enable_mp_engine = True
        # router replay
        if backend_config.enable_return_routed_experts:
            backend_config.enable_transfer_obj_ref = True
    VariableInterface.async_engine = pipeline_class(model_path=model_path,
                                                    model_name=model_name,
                                                    backend=backend,
                                                    backend_config=backend_config,
                                                    chat_template_config=chat_template_config,
                                                    max_log_len=max_log_len,
                                                    speculative_config=speculative_config,
                                                    **kwargs)
    # set reasoning parser and tool parser
    set_parsers(reasoning_parser, tool_call_parser)

    # create FastAPI lifespan events
    lifespan = create_lifespan_handler(backend_config, VariableInterface.async_engine)

    if disable_fastapi_docs:
        app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None, lifespan=lifespan)
    else:
        app = FastAPI(docs_url='/', lifespan=lifespan)

    app.include_router(router)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    mount_metrics(app, backend_config)

    if allow_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=allow_credentials,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
        )

    if isinstance(api_keys, str):
        # A bare string is one key. Iterating it directly would register every
        # single character as a valid API key.
        api_keys = [api_keys]
    if api_keys is not None and (tokens := [key for key in api_keys if key]):
        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    # set the maximum number of concurrent requests
    if max_concurrent_requests is not None:
        app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)

    if proxy_url is not None:
        VariableInterface.proxy_url = proxy_url
        VariableInterface.api_server_url = f'{http_or_https}://{server_name}:{server_port}'  # noqa
    for _ in range(3):
        print(f'HINT:    Please open \033[93m\033[1m{http_or_https}://'
              f'{server_name}:{server_port}\033[0m in a browser for detailed api'
              ' usage!!!')
    uvicorn.run(app=app,
                host=server_name,
                port=server_port,
                log_level=os.getenv('UVICORN_LOG_LEVEL', 'info').lower(),
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile,
                timeout_keep_alive=int(os.environ.get('UVICORN_TIMEOUT_KEEP_ALIVE', 5)))


if __name__ == '__main__':
    # Expose `serve` as a command-line interface via python-fire.
    from fire import Fire

    Fire(serve)


================================================
FILE: lmdeploy/serve/openai/harmony_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/vllm-project/vllm/blob/v0.10.2rc1/vllm/entrypoints/harmony_utils.py
from typing import List

import shortuuid
from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding

from lmdeploy.serve.openai.protocol import (ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall,
                                            ToolCall)

# Lazily-initialized, module-wide cache for the GPT-OSS harmony encoding.
_harmony_encoding = None


def get_encoding():
    """Return the GPT-OSS harmony encoding, loading it on first use and caching it."""
    global _harmony_encoding
    if _harmony_encoding is not None:
        return _harmony_encoding
    _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def get_streamable_parser_for_assistant() -> 'StreamableParser':
    """Create a new streaming harmony parser for assistant-role output."""
    encoding = get_encoding()
    return StreamableParser(encoding, role=Role.ASSISTANT)


class GptOssChatParser:
    """Convert GPT-OSS (harmony format) output token ids into chat messages.

    Wraps one streaming harmony parser created in ``__init__``. That parser keeps
    its state across calls, so tokens passed to successive ``parse_streaming``
    calls are treated as continuations of the same output.
    NOTE(review): an instance therefore appears to be meant for a single
    response; confirm that callers create one per request.
    """

    def __init__(self) -> None:
        # Stateful harmony parser shared by every parse_* call on this instance.
        self.parser = get_streamable_parser_for_assistant()

    def parse_streaming(self, tokens: List[int]) -> DeltaMessage:
        """Feed a chunk of token ids to the parser and return the resulting delta.

        Content deltas are routed by the parser's current channel:
        ``final`` -> ``content``, ``analysis`` -> ``reasoning_content``, and
        ``commentary`` addressed to a ``functions.<name>`` recipient -> tool-call
        argument deltas. Empty content/reasoning are reported as ``None``;
        ``tool_calls`` is always a list (possibly empty).

        Args:
            tokens: token ids produced since the previous call.

        Returns:
            DeltaMessage: the assistant delta for this chunk.
        """
        parser = self.parser
        delta_message = DeltaMessage(role='assistant')
        content = ''
        reasoning_content = ''
        tool_calls = []
        # Tool call currently being assembled in this chunk, if any.
        delta_tool_call = None
        for token in tokens:
            # Remember the recipient before this token to detect the start of a new tool call.
            prev_recipient = parser.current_recipient
            parser.process(token)
            cur_channel = parser.current_channel
            cur_recipient = parser.current_recipient
            delta_text = parser.last_content_delta or ''
            if cur_channel == 'final':
                content += delta_text
            elif cur_channel == 'analysis':
                reasoning_content += delta_text
            elif cur_channel == 'commentary' and cur_recipient and cur_recipient.startswith('functions.'):
                # Tool-call index = number of function-call messages already in
                # parser.messages (presumably the completed messages -- confirm).
                base_index = 0
                for msg in parser.messages:
                    if msg.channel == 'commentary' and msg.recipient and msg.recipient.startswith('functions.'):
                        base_index += 1
                if prev_recipient != cur_recipient:
                    # Recipient changed: flush the previous call and open a new one
                    # with a fresh id and the function name taken after 'functions.'.
                    # Any content delta of this particular token is not appended.
                    if delta_tool_call is not None:
                        tool_calls.append(delta_tool_call)
                    tool_name = cur_recipient.split('functions.', 1)[1]
                    delta_tool_call = DeltaToolCall(id=f'chatcmpl-tool-{shortuuid.random()}',
                                                    type='function',
                                                    index=base_index,
                                                    function=DeltaFunctionCall(name=tool_name, arguments=''))
                elif delta_text:
                    # Continuing the same tool call. Ensure we don't duplicate the
                    # very first delta string in this chunk. Previously we initialized
                    # with arguments=delta_text and then appended again, causing
                    # duplicated content like "locationlocation".
                    if delta_tool_call is None:
                        # We are in the middle of a tool call carried over from the
                        # previous chunk. Initialize an empty arguments buffer.
                        delta_tool_call = DeltaToolCall(index=base_index, function=DeltaFunctionCall(arguments=''))
                    delta_tool_call.function.arguments += delta_text

        # Flush the tool call still open at the end of this chunk.
        if delta_tool_call:
            tool_calls.append(delta_tool_call)

        delta_message.content = content if content else None
        delta_message.reasoning_content = reasoning_content if reasoning_content else None
        delta_message.tool_calls = tool_calls
        return delta_message

    def parse_full(self, tokens: List[int]) -> ChatMessage:
        """Parse a complete output into a ``ChatMessage``.

        Runs ``parse_streaming`` over all tokens on this instance's stateful parser,
        then converts each tool-call delta into a ``ToolCall``.
        """
        delta_message = self.parse_streaming(tokens)
        tool_calls = []
        for delta_tool_call in delta_message.tool_calls:
            function = FunctionCall(**delta_tool_call.function.model_dump())
            tool_calls.append(ToolCall(id=delta_tool_call.id, type=delta_tool_call.type, function=function))
        chat_message = ChatMessage(role='assistant',
                                   content=delta_message.content,
                                   tool_calls=tool_calls,
                                   reasoning_content=delta_message.reasoning_content)
        return chat_message


================================================
FILE: lmdeploy/serve/openai/launch_server.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import multiprocessing as mp
import os
import random
import signal
import socket
import sys
from typing import List, Union

from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.utils import get_logger

from .api_server import serve

# Module-level logger named 'lmdeploy'; launch_server sets its level from `log_level`.
logger = get_logger('lmdeploy')


def find_available_ports(num: int) -> List[int]:
    """Find ``num`` TCP ports that can currently be bound on 127.0.0.1.

    Candidates are probed in increasing order. The search starts from 3000, and
    each step advances by a random 10-500, so the returned ports are distinct
    and ascending.

    A port counts as available if it can be bound (with SO_REUSEADDR) and
    listened on. The probe socket is closed immediately, so another process
    may still take the port before it is used.

    Args:
        num: number of ports to return.

    Returns:
        List[int]: ``num`` available ports in ascending order.

    Raises:
        RuntimeError: if the valid port range (up to 65535) is exhausted before
            ``num`` ports are found. Without this bound, every later bind would
            fail and the search would never terminate.
    """
    max_port = 65535

    def _is_port_ok(port: int) -> bool:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(('127.0.0.1', port))
                s.listen(1)
                return True
            except OSError:
                return False

    ports = []
    test_port = 3000
    while len(ports) < num:
        test_port += random.randint(10, 500)
        if test_port > max_port:
            raise RuntimeError(f'Only found {len(ports)} of {num} requested available ports: {ports}')
        if _is_port_ok(test_port):
            ports.append(test_port)

    return ports


def get_host_ip():
    """Return the local IPv4 address the OS would use to reach 8.8.8.8.

    A UDP socket is connected to pick a route, and its bound local address is
    read back with ``getsockname()``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_sock:
        udp_sock.connect(('8.8.8.8', 0))
        local_ip = udp_sock.getsockname()[0]
    return local_ip


def _run_server(gpu_ids: List[int], model_path: str, **kwargs):
    """Launch a server process.

    The process first becomes the leader of its own process group, so that the
    whole group can be signalled via its pid. When any ``gpu_ids`` are given,
    CUDA_VISIBLE_DEVICES is restricted to them. Finally ``serve`` runs with the
    remaining arguments.
    """
    visible_devices = ','.join(map(str, gpu_ids))
    os.setpgrp()
    if len(gpu_ids) > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    serve(model_path, **kwargs)


def cleanup_processes(processes: List[mp.Process]):
    """Stop all dp server processes, then exit the launcher with status 0.

    ``_run_server`` calls ``os.setpgrp()``, so a started server leads its own process
    group whose id equals its pid. Signalling that group also reaches processes the server
    spawned, unless they moved to another group. A server that has not reached
    ``os.setpgrp()`` yet has no such group, so ``killpg`` fails with ProcessLookupError.
    Such a server is signalled directly instead; previously it was silently skipped and
    could outlive the launcher.

    Every process first receives SIGTERM. Each one still alive after waiting up to
    15 seconds receives SIGKILL.

    Args:
        processes (List[mp.Process]): Started server processes.
    """

    def _send(process: mp.Process, sig: int):
        """Send ``sig`` to the group led by ``process``, or to ``process`` itself."""
        try:
            os.killpg(process.pid, sig)
        except ProcessLookupError:
            # No such group: either everything in it already exited, or the child has not
            # called os.setpgrp() yet and still belongs to the launcher's group.
            if process.is_alive():
                try:
                    os.kill(process.pid, sig)
                except ProcessLookupError:
                    pass  # exited in the meantime

    for process in processes:
        logger.info(f'Terminating process group {process.pid}')
        _send(process, signal.SIGTERM)

    # Wait for processes to terminate
    for process in processes:
        process.join(timeout=15)
        if process.is_alive():
            logger.warning(f'Process {process.pid} did not terminate gracefully, forcing kill')
            _send(process, signal.SIGKILL)

    logger.info('All processes terminated')
    sys.exit(0)


def launch_server(num_nodes: int,
                  node_rank: int,
                  model_path: str,
                  backend_config: Union[PytorchEngineConfig, TurbomindEngineConfig],
                  proxy_url: str = None,
                  **kwargs):
    """Run this node's share of api_server processes in data-parallel (dp) mode.

    The ``backend_config.dp`` ranks are split evenly over ``num_nodes`` nodes. This node
    spawns ``dp // num_nodes`` server processes. Local server ``idx`` gets global
    ``dp_rank = node_rank * dp_per_node + idx`` and the local GPU ids
    ``[idx * tp_per_dp, (idx + 1) * tp_per_dp)``. Each server is started through
    ``serve`` with ``server_name`` set to this host's IP, a freshly probed port, its own
    copy of the backend config, and ``proxy_url``.

    SIGINT/SIGTERM/SIGQUIT are routed to ``cleanup_processes``, which stops all servers
    and exits. Otherwise this call blocks until every server process exits.

    Args:
        num_nodes (int): Number of nodes; must evenly divide ``backend_config.dp``.
        node_rank (int): Rank of this node, used to offset the dp ranks.
        model_path (str): Model path forwarded to ``serve``.
        backend_config: Engine config. ``dp`` must be > 1, and ``tp`` or ``ep`` must be > 1.
        proxy_url (str): URL of an already running proxy server (required).
        **kwargs: Extra keyword arguments forwarded to every ``serve`` call.
    """
    assert proxy_url is not None, 'Please launch proxy server and pass proxy_url'
    log_level = kwargs.get('log_level', 'ERROR')
    logger.setLevel(log_level)

    mp.set_start_method('spawn', force=True)
    dp = backend_config.dp
    tp = backend_config.tp
    ep = backend_config.ep
    assert dp > 1, f'only support dp > 1, but give dp={dp}'
    assert tp > 1 or ep > 1, f'only support tp > 1 or ep > 1, but given tp={tp} ep={ep}'
    # `dp // num_nodes` below would otherwise silently drop the remaining dp ranks.
    assert dp % num_nodes == 0, f'dp={dp} must be divisible by num_nodes={num_nodes}'

    num_devices = max(dp, tp, ep)
    dp_per_node = dp // num_nodes
    tp_per_dp = num_devices // dp
    http_or_https = 'https' if kwargs.get('ssl', False) else 'http'
    server_name = get_host_ip()
    processes = []

    # Bind signals before spawning, so that a signal arriving during start-up still
    # cleans up the servers started so far. The handlers read `processes` when they
    # run, so they see later appends.
    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(sig, lambda signum, frame: cleanup_processes(processes))

    server_port_li = find_available_ports(dp_per_node)

    for idx in range(dp_per_node):
        backend_config_dp = copy.deepcopy(backend_config)
        backend_config_dp.dp_rank = node_rank * dp_per_node + idx
        gpu_ids_per_dp = list(range(idx * tp_per_dp, (idx + 1) * tp_per_dp))
        server_port = server_port_li[idx]

        cur_server_kwargs = {
            **kwargs,
            'server_name': server_name,
            'server_port': server_port,
            'backend_config': backend_config_dp,
            'proxy_url': proxy_url,
        }
        url = f'{http_or_https}://{server_name}:{server_port}'
        logger.info(f'create server with url={url}')
        logger.info(f'backend_config_dp={backend_config_dp} gpus={gpu_ids_per_dp}')
        proc = mp.Process(target=_run_server, args=(gpu_ids_per_dp, model_path), kwargs=cur_server_kwargs)
        proc.start()
        processes.append(proc)

    for p in processes:
        p.join()


================================================
FILE: lmdeploy/serve/openai/protocol.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import shortuuid
from pydantic import BaseModel, ConfigDict, Field


class ErrorResponse(BaseModel):
    """Error responses."""
    # Error payload; `message`, `type` and `code` must be supplied by the caller.
    message: str
    type: str
    # NOTE(review): presumably an HTTP status code — confirm at the sites that build this.
    code: int
    param: Optional[str] = None
    object: str = 'error'  # constant tag identifying the payload kind


class ModelPermission(BaseModel):
    """Model permissions."""
    # Permission record that can be attached to a ModelCard. Every flag below is a fixed
    # default; nothing in this model enforces them.
    id: str = Field(default_factory=lambda: f'modelperm-{shortuuid.random()}')  # new random id per instance
    object: str = 'model_permission'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds) at instantiation
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = True
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = '*'
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    """Model cards."""
    # Description of one model; instances are collected in ModelList.data.
    id: str
    object: str = 'model'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds) at instantiation
    owned_by: str = 'lmdeploy'
    root: Optional[str] = None
    parent: Optional[str] = None
    # pydantic copies mutable field defaults per instance, so this `[]` literal is not shared.
    permission: List[ModelPermission] = []


class ModelList(BaseModel):
    """Model list consists of model cards."""
    # List envelope: `object` is the constant tag 'list', `data` holds the cards.
    object: str = 'list'
    data: List[ModelCard] = []


class UsageInfo(BaseModel):
    """Usage information."""
    # Token counters; all default to 0. NOTE(review): total_tokens is not derived here —
    # callers are presumably expected to set it to prompt + completion; confirm.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class Function(BaseModel):
    """Function descriptions."""
    # Function definition carried by a Tool in ChatCompletionRequest.tools.
    description: Optional[str] = Field(default=None, examples=[None])
    name: str
    # Arbitrary dict, not validated here — presumably a JSON-Schema of the arguments.
    parameters: Optional[Dict[str, Any]] = None


class Tool(BaseModel):
    """Function wrapper."""
    # NOTE: `type` is a free-form str here, whereas ToolChoice/ToolCall restrict it to
    # Literal['function'].
    type: str = Field(default='function', examples=['function'])
    function: Function


class ToolChoiceFuncName(BaseModel):
    """The name of tool choice function."""
    # Name of the function selected by a ToolChoice.
    name: str


class ToolChoice(BaseModel):
    """The tool choice definition."""
    # Object form of ChatCompletionRequest.tool_choice; it names one function, as an
    # alternative to the 'auto' / 'required' / 'none' string forms.
    function: ToolChoiceFuncName
    type: Literal['function'] = Field(default='function', examples=['function'])


class StreamOptions(BaseModel):
    """The stream options."""
    # Options for streamed responses; only `include_usage` (default False) is defined.
    include_usage: Optional[bool] = False


class JsonSchema(BaseModel):
    # Named JSON schema used by ResponseFormat when type == 'json_schema'.
    name: str
    # description is not used since it depends on model
    description: Optional[str] = None
    # `schema` is a reserved field in Pydantic BaseModel
    # use alias since pydantic does not support the OpenAI key `schema`
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema', examples=[None])
    # strict is not used
    strict: Optional[bool] = False
    # Dump `json_schema` under its alias `schema` by default.
    # NOTE(review): `serialize_by_alias` is a recent ConfigDict option (pydantic 2.11+, verify pin).
    model_config = ConfigDict(serialize_by_alias=True)


class ResponseFormat(BaseModel):
    # Structured-output request: `json_schema` pairs with type 'json_schema' and
    # `regex_schema` with type 'regex_schema'; the pairing is not validated here.
    # regex_schema is extended by lmdeploy to support regex output
    type: Literal['text', 'json_object', 'json_schema', 'regex_schema']
    json_schema: Optional[JsonSchema] = None
    regex_schema: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Chat completion request."""
    model: str

    # Either a plain prompt string or a list of message dicts.
    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
    tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto', examples=['none'])
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    n: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None])
    max_completion_tokens: Optional[int] = Field(
        default=None,
        examples=[None],
        description=('An upper bound for the number of tokens that can be generated for a completion, '
                     'including visible output tokens and reasoning tokens'),
    )
    max_tokens: Optional[int] = Field(
        default=None,
        examples=[None],
        deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',
    )
    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])

    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None
    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])
    # additional argument of lmdeploy
    do_preprocess: Optional[bool] = True
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    top_k: Optional[int] = 40
    seed: Optional[int] = None
    min_new_tokens: Optional[int] = Field(default=None, examples=[None])
    min_p: float = 0.0
    enable_thinking: Optional[bool] = None  # will be deprecated in the future
    return_token_ids: Optional[bool] = False
    include_stop_str_in_output: Optional[bool] = False
    # kwargs for chat template renderer
    # Spelled `Optional[...]` like the sibling fields: the PEP 604 form `dict[...] | None`
    # is evaluated when the class body runs and raises TypeError before Python 3.10.
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=('Additional keyword args to pass to the template renderer. '
                     'Will be accessible by the chat template.'),
    )
    # kwargs for media IO
    media_io_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=('Additional kwargs to pass to the media IO processing, keyed by modality.'),
    )
    # kwargs for hf processor
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=('Additional kwargs to pass to the HF processor'),
    )


class FunctionCall(BaseModel):
    """Function response."""
    # Completed function call inside a ToolCall.
    name: str
    # Stored as a str — presumably JSON-encoded arguments; confirm with the tool parsers.
    arguments: str


class ToolCall(BaseModel):
    """Tool call response."""
    # NOTE(review): ids are prefixed 'chatcmpl-' here but 'chatcmpl-tool-' in DeltaToolCall;
    # check whether the difference is intentional before relying on the prefix.
    id: str = Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')
    type: Literal['function'] = 'function'
    function: FunctionCall


class ExtractedToolCallInformation(BaseModel):
    # Result of parsing tool calls out of a complete model output.
    # modified from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/protocol.py#L1199
    # indicate if tools were called
    tools_called: bool
    # extracted tool calls
    tool_calls: List[ToolCall]
    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


class ChatMessage(BaseModel):
    """Chat messages."""
    # Message in a non-streamed chat choice; every field except `role` is optional.
    role: str
    content: Optional[str] = None
    # Presumably the generated token ids (cf. ChatCompletionRequest.return_token_ids) — confirm.
    gen_tokens: Optional[List[int]] = None
    reasoning_content: Optional[str] = Field(default=None, examples=[None])
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class LogProbs(BaseModel):
    # Logprobs for the completion endpoints (CompletionResponseChoice / CompletionResponseStreamChoice).
    # NOTE(review): the lists look parallel (one entry per token) — not enforced here.
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class TopLogprob(BaseModel):
    # One alternative token and its logprob within ChatCompletionTokenLogprob.top_logprobs.
    token: str
    bytes: Optional[List[int]] = None  # presumably the token's raw bytes — confirm at the producer
    logprob: float


class ChatCompletionTokenLogprob(BaseModel):
    # Logprob entry for one chosen token in a chat choice, plus its top alternatives.
    token: str
    bytes: Optional[List[int]] = None
    logprob: float
    top_logprobs: List[TopLogprob]


class ChoiceLogprobs(BaseModel):
    # Logprobs attached to a chat choice (streamed or not); `content` may be None.
    content: Optional[List[ChatCompletionTokenLogprob]] = None


class ChatCompletionResponseChoice(BaseModel):
    """Chat completion response choices."""
    index: int
    message: ChatMessage
    logprobs: Optional[ChoiceLogprobs] = None
    # Same finish-reason set as every other response choice model in this file.
    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class ChatCompletionResponse(BaseModel):
    """Chat completion response."""
    # id and created are generated per instance when not supplied.
    id: str = Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')
    object: str = 'chat.completion'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds)
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaFunctionCall(BaseModel):
    # Partial function call in a streamed tool-call delta; either piece may be absent.
    name: Optional[str] = None
    arguments: Optional[str] = None


# A streamed tool call delta: only `index` is required; `id` and `type` get defaults and
# `function` may be omitted.
class DeltaToolCall(BaseModel):
    id: str = Field(default_factory=lambda: f'chatcmpl-tool-{shortuuid.random()}')
    type: Literal['function'] = 'function'
    index: int
    function: Optional[DeltaFunctionCall] = None


class DeltaMessage(BaseModel):
    """Delta messages."""
    # Incremental message chunk for streaming; reasoning parsers return these with
    # `content` and/or `reasoning_content` set.
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    gen_tokens: Optional[List[int]] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(BaseModel):
    """Chat completion response stream choice."""
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChoiceLogprobs] = None
    # Defaults to None for chunks emitted before the choice finishes.
    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class ChatCompletionStreamResponse(BaseModel):
    """Chat completion stream response."""
    id: str = Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')
    object: str = 'chat.completion.chunk'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds)
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Optional per chunk; presumably filled when StreamOptions.include_usage is set — confirm.
    usage: Optional[UsageInfo] = None


class CompletionRequest(BaseModel):
    """Completion request."""
    # Request body for plain-text completions. Unlike ChatCompletionRequest, `logprobs`
    # is an int here and `max_tokens` defaults to 16 rather than None.
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    logprobs: Optional[int] = None
    max_completion_tokens: Optional[int] = Field(
        default=None,
        examples=[None],
        description=('An upper bound for the number of tokens that can be generated for a completion, '
                     'including visible output tokens and reasoning tokens'),
    )
    max_tokens: Optional[int] = Field(
        default=16,
        examples=[16],
        deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',
    )
    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])
    top_p: Optional[float] = 1.0
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    top_k: Optional[int] = 40  # for opencompass
    seed: Optional[int] = None
    min_p: float = 0.0
    return_token_ids: Optional[bool] = False


class CompletionResponseChoice(BaseModel):
    """Completion response choices."""
    # One non-streamed completion result; uses the completion-style LogProbs.
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    gen_tokens: Optional[List[int]] = None
    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class CompletionResponse(BaseModel):
    """Completion response."""
    id: str = Field(default_factory=lambda: f'cmpl-{shortuuid.random()}')
    object: str = 'text_completion'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds)
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    """Completion response stream choice."""
    # Streamed counterpart of CompletionResponseChoice; `text` is the newly produced piece.
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    gen_tokens: Optional[List[int]] = None
    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class CompletionStreamResponse(BaseModel):
    """Completion stream response."""
    # NOTE: `object` is 'text_completion' for chunks too (no separate chunk tag, unlike chat).
    id: str = Field(default_factory=lambda: f'cmpl-{shortuuid.random()}')
    object: str = 'text_completion'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds)
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = None


class EmbeddingsRequest(BaseModel):
    """Embedding request."""
    # `model` defaults to None, so it is annotated Optional. The previous `str = None`
    # annotation rejected an explicit `"model": null` and advertised a non-nullable
    # field in the generated schema.
    model: Optional[str] = None
    input: Union[str, List[str]]
    user: Optional[str] = None


class EmbeddingsResponse(BaseModel):
    """Embedding response."""
    object: str = 'list'
    # Untyped entries; the per-item dict layout is decided by the handler that builds it.
    data: List[Dict[str, Any]]
    model: str
    usage: UsageInfo


class PoolingRequest(BaseModel):
    """Pooling request.

    Currently we follow the vLLM API protocol,
    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L1174

    Note that ideally we should reuse the input format of the embedding API. For comparison see
    https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L383
    """
    model: Optional[str] = None
    # Token ids (one sequence or a batch) or text (one string or a batch).
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal['float', 'base64'] = 'float'
    dimensions: Optional[int] = None
    user: Optional[str] = None


class PoolingResponse(BaseModel):
    """Pooling response."""
    id: str = Field(default_factory=lambda: f'pool-{shortuuid.random()}')
    object: str = 'list'
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time (seconds)
    # Defaults to None, so annotated Optional. With the previous `str = None`, dumping the
    # default made pydantic emit an "expected str" serialization warning, and the schema
    # claimed the field is non-nullable.
    model: Optional[str] = None
    data: List[Dict[str, Any]]
    usage: UsageInfo


class EncodeRequest(BaseModel):
    """Encode request."""
    # Text to tokenize: a single string or a batch of strings.
    input: Union[str, List[str]]
    do_preprocess: Optional[bool] = False
    add_bos: Optional[bool] = True


class EncodeResponse(BaseModel):
    """Encode response."""
    # Presumably flat/int for a single-string input and nested/list for a batch — confirm in the handler.
    input_ids: Union[List[int], List[List[int]]]
    length: Union[int, List[int]]


class GenerateResponse(BaseModel):
    """Generate response."""
    text: str
    # Token counters; NOTE(review): exact meaning (per-step vs cumulative) is set by the producer.
    tokens: int
    input_tokens: int
    history_tokens: int
    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class UpdateParamsRequest(BaseModel):
    """Update weights request."""
    # Serialized tensors as one str, a list of str, or a dict — decoding is done by the handler.
    serialized_named_tensors: Union[str, List[str], Dict]
    load_format: Optional[str] = None  # 'flattened_bucket' or None
    # Presumably marks the last chunk of a multi-request update — confirm in the handler.
    finished: bool = False


# One image input: either a str holding a URL or a base64 data URI (e.g. 'data:image/jpeg;base64,...'),
# or a dict of the form {'url': <URL or base64 data URI>, 'options': ...}.
ImageDataInputItem = Union[str, Dict]
# A single image item or a list of image items.
ImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]


# /generate input
# Request body with either a text `prompt` or pre-tokenized `input_ids`; which one wins
# when both are given is decided by the handler, not here.
class GenerateReqInput(BaseModel):
    session_id: Optional[int] = -1
    prompt: Optional[str] = None
    input_ids: Optional[List[int]] = None
    image_data: Optional[ImageDataFormat] = None
    return_logprob: Optional[bool] = None
    max_tokens: int = 128
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    stream: Optional[bool] = False
    # Sampling defaults differ from the chat/completion requests (temperature 1.0, top_k 0).
    temperature: float = 1.0
    repetition_penalty: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    include_stop_str_in_output: Optional[bool] = False
    return_routed_experts: Optional[bool] = False
    repetition_ngram_size: int = 0
    repetition_ngram_threshold: int = 0
    # kwargs for media IO
    media_io_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=('Additional kwargs to pass to the media IO processing, keyed by modality.'),
    )
    # kwargs for hf processor
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=('Additional kwargs to pass to the HF processor'),
    )


class GenerateReqMetaOutput(BaseModel):
    # Metadata attached to a /generate response; every field is optional.
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    finish_reason: Optional[Dict[str, Any]] = None
    output_token_logprobs: Optional[List[tuple[float, int]]] = None  # (logprob, token_id)
    routed_experts: Optional[Union[List[List[List[int]]], str]] = None  # (num_token, num_layer, topk_expert)


# /generate output
# Generated text, its token ids, and the GenerateReqMetaOutput metadata.
class GenerateReqOutput(BaseModel):
    text: str
    output_ids: List[int]
    meta_info: GenerateReqMetaOutput


class AbortRequest(BaseModel):
    # Request to abort one session (by `session_id`) or every request (`abort_all`).
    # Whether to abort all requests
    abort_all: bool = False
    # The finished reason data
    finished_reason: Optional[Dict[str, Any]] = None
    abort_message: Optional[str] = None
    # The session ID to abort. If `abort_all` is True, this field is ignored.
    session_id: Optional[int] = -1


================================================
FILE: lmdeploy/serve/openai/reasoning_parser/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .qwen_qwq_reasoning_parser import QwenQwQReasoningParser
from .reasoning_parser import ReasoningParser, ReasoningParserManager

# Public names re-exported by the reasoning_parser package.
__all__ = [
    'ReasoningParser',
    'ReasoningParserManager',
    'DeepSeekR1ReasoningParser',
    'QwenQwQReasoningParser',
]


================================================
FILE: lmdeploy/serve/openai/reasoning_parser/deepseek_r1_reasoning_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers
import re
from typing import Optional, Sequence, Tuple, Union

from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage

from .reasoning_parser import ReasoningParser, ReasoningParserManager


@ReasoningParserManager.register_module(name='deepseek-r1')
class DeepSeekR1ReasoningParser(ReasoningParser):
    """Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning text. This parser extracts the
    reasoning content from the model output.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        # The markers must be the real tag strings: with empty strings the vocab lookup
        # below fails and every `in` / `find` check degenerates.
        self.think_start_token = '<think>'
        self.think_end_token = '</think>'

        # Kept for backward compatibility; extract_reasoning_content no longer uses it.
        self.reasoning_regex = re.compile(
            rf'{re.escape(self.think_start_token)}(.*?){re.escape(self.think_end_token)}', re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError('The model tokenizer must be passed to the ReasoningParser '
                             'constructor during construction.')

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None or self.think_end_token_id is None):
            raise RuntimeError('DeepSeek R1 reasoning parser could not locate think start/end '
                               'tokens in the tokenizer!')

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        **kwargs,
    ) -> Union[DeltaMessage, None]:
        """Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming.

        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
        about what has previously been parsed and extracted (see constructor)

        Marker presence is decided on token ids. The split position inside ``delta_text`` is found with
        ``str.find`` on the marker text.
        NOTE(review): this assumes the decoded ``delta_text`` contains the literal ``</think>``/``<think>``.
        If the marker was stripped during detokenization, ``find`` returns -1 and the split is off.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1:
            if delta_token_ids[0] == self.think_end_token_id:
                return DeltaMessage(content='')
            elif delta_token_ids[0] == self.think_start_token_id:
                return None

        # Check if <think> is present in previous or delta.
        # Keep compatibility with models that don't generate <think> tokens.
        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta, also need to check for </think>.
            # Because the model may have generated </think> without <think>
            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
            if self.think_end_token_id in delta_token_ids:
                # </think> in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # </think> in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no </think> in previous or delta, reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:
        """Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Args:
            model_output (str): The model-generated string to extract reasoning content from.
            request (ChatCompletionRequest): The request object that was used to generate the model_output.

        Returns:
            reasoning_content (str | None): Text before the first ``</think>``, without an optional ``<think>``
                prefix (and anything preceding it). The whole output if there is no ``</think>``.
            final_output (str | None): Text after the first ``</think>``, or None if that is empty or absent.
        """
        # DeepSeek R1 doesn't generate <think> now.
        # Thus we assume the reasoning content is always at the start.
        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
        if self.think_end_token not in model_output:
            return model_output, None
        # Split at the first end marker instead of slicing by the matched span length. The
        # span-length approach mis-sliced when text preceded <think>, and raised IndexError
        # when a <think> only appeared after </think>.
        reasoning_content, _, final_output = model_output.partition(self.think_end_token)
        _, start_found, after_start = reasoning_content.partition(self.think_start_token)
        if start_found:
            reasoning_content = after_start
        return reasoning_content, final_output or None


================================================
FILE: lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Optional, Sequence, Tuple, Union

from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage

from .reasoning_parser import ReasoningParser, ReasoningParserManager


@ReasoningParserManager.register_module(name=['qwen-qwq', 'intern-s1'])
class QwenQwQReasoningParser(ReasoningParser):
    """Reasoning parser for Qwen QwQ model.

    The Qwen QwQ model uses <think>...</think> tokens to denote reasoning text. This parser extracts the
    reasoning content from the model output. Unlike the DeepSeek R1 parser, marker detection here works
    on decoded text rather than token ids.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        # The markers must be the real tag strings: with empty strings every `in` check
        # is trivially true and all parsing degenerates.
        self.think_start_token = '<think>'
        self.think_end_token = '</think>'

        # Kept for backward compatibility; extract_reasoning_content no longer uses it.
        self.reasoning_regex = re.compile(
            rf'{re.escape(self.think_start_token)}(.*?){re.escape(self.think_end_token)}', re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError('The model tokenizer must be passed to the ReasoningParser '
                             'constructor during construction.')

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        **kwargs,
    ) -> Union[DeltaMessage, None]:
        """Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming.

        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
        about what has previously been parsed and extracted (see constructor)
        """
        # Skip single special tokens
        if delta_text == self.think_end_token or delta_text == self.think_start_token:
            return DeltaMessage(content='')

        # Check if <think> is present in previous or delta.
        # Keep compatibility with models that don't generate <think> tokens.
        if self.think_start_token in previous_text:
            if self.think_end_token in delta_text:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            elif self.think_end_token in previous_text:
                # <think> in previous, </think> in previous,
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token in delta_text:
            if self.think_end_token in delta_text:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta, also need to check for </think>.
            # Because the model may have generated </think> without <think>
            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
            if self.think_end_token in delta_text:
                # </think> in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
            elif self.think_end_token in previous_text:
                # </think> in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no </think> in previous or delta, reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:
        """Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Args:
            model_output (str): The model-generated string to extract reasoning content from.
            request (ChatCompletionRequest): The request object that was used to generate the model_output.

        Returns:
            reasoning_content (str | None): Text before the first ``</think>``, without an optional ``<think>``
                prefix (and anything preceding it), with at most one leading and one trailing newline
                removed. None if there is no ``</think>``.
            final_output (str | None): Text after the first ``</think>`` (the whole output if there is no
                ``</think>``), or None if that is empty.
        """
        # Unlike the DeepSeek R1 parser, an output without an end marker is treated as plain
        # content with no reasoning.
        if self.think_end_token not in model_output:
            # for qwen3 model, the reasoning content is wrapped by <think> </think> xml tags
            return None, model_output
        # Split at the first end marker instead of slicing by the matched span length. The
        # span-length approach mis-sliced when text preceded <think>, and raised IndexError
        # when a <think> only appeared after </think>.
        reasoning_content, _, final_output = model_output.partition(self.think_end_token)
        _, start_found, after_start = reasoning_content.partition(self.think_start_token)
        if start_found:
            reasoning_content = after_start
        # Remove at most one newline on each side of the reasoning text.
        if reasoning_content.startswith('\n'):
            reasoning_content = reasoning_content[1:]
        if reasoning_content.endswith('\n'):
            reasoning_content = reasoning_content[:-1]

        return reasoning_content, final_output or None


================================================
FILE: lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers
from functools import cached_property
from typing import Dict, Optional, Sequence, Tuple, Union

from mmengine import Registry

from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage

ReasoningParserManager = Registry('reasoning_parser', locations=['lmdeploy.serve.openai.reasoning_parser'])


class ReasoningParser:
    """Base class for parsers that separate reasoning from the final answer.

    Subclasses override both extraction methods; the base implementations
    only raise ``NotImplementedError``.
    """

    def __init__(self, tokenizer: object):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        """Mapping from token string to token id, computed once per instance.

        ``get_vocab()`` is used because every tokenizer provides it, whereas
        only ``PreTrainedTokenizerFast`` is guaranteed to expose ``.vocab``.
        """
        tokenizer = self.model_tokenizer
        return tokenizer.get_vocab()

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        **kwargs,
    ) -> Union[DeltaMessage, None]:
        """Extract reasoning from an incomplete (streamed) response.

        Must be implemented by subclasses. It is an instance method because it
        needs state: the current tokens/diffs as well as what has already been
        parsed and extracted in earlier calls.

        Returns:
            DeltaMessage | None: The message to send for this chunk, if any.
        """
        message = 'ReasoningParser.extract_reasoning_content_streaming has not been implemented!'
        raise NotImplementedError(message)

    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:
        """Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses, where the entire model response is
        available before anything is sent to the client.

        Args:
            model_output (str): The model-generated string to extract reasoning content from.
            request (ChatCompletionRequest): The request object that was used to generate the model_output.

        Returns:
            reasoning_content (str | None): The reasoning content.
            final_output (str | None): The content.
        """
        message = 'ReasoningParser.extract_reasoning_content has not been implemented!'
        raise NotImplementedError(message)


================================================
FILE: lmdeploy/serve/openai/serving_chat_completion.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from .protocol import ChatCompletionRequest

if TYPE_CHECKING:
    from .api_server import VariableInterface


def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str:
    """Validate a chat completion request against the current server state.

    Args:
        request (ChatCompletionRequest): The incoming chat completion request.
        server_context (VariableInterface): Supplies the engine configuration
            and the session manager.

    Returns:
        str: A message describing the first problem found, or an empty string
            when the request is acceptable.
    """
    engine_config = server_context.get_engine_config()
    session_manager = server_context.get_session_manager()

    # Logprobs are only validated when every attribute involved exists; an
    # AttributeError (e.g. an engine config without `logprobs_mode`) skips
    # this validation entirely.
    try:
        mode = engine_config.logprobs_mode
        want_logprobs = request.logprobs
        num_top = request.top_logprobs or 0
    except AttributeError:
        pass
    else:
        if mode is None:
            if want_logprobs or num_top > 0:
                return (f'Logprobs({want_logprobs})/top_logprobs({num_top}) requested '
                        'but not enabled logprobs_mode in engine configuration')
        elif num_top < 0 or (not want_logprobs and num_top > 0):
            return (f'Invalid logprobs({want_logprobs})/top_logprobs({num_top}) requested '
                    'when logprobs_mode is enabled in engine configuration.')

    session_id = request.session_id
    if session_manager.has(session_id):
        return f'The session_id {session_id!r} is occupied.'

    # Sampling parameters, checked in a fixed order; the first failure wins.
    n = request.n
    if n <= 0:
        return f'The n {n!r} must be a positive int.'
    top_p = request.top_p
    if not (0 < top_p <= 1):
        return f'The top_p {top_p!r} must be in (0, 1].'
    top_k = request.top_k
    if top_k < 0:
        return f'The top_k {top_k!r} cannot be a negative integer.'
    temperature = request.temperature
    if not (0 <= temperature <= 2):
        return f'The temperature {temperature!r} must be in [0, 2]'

    return ''


================================================
FILE: lmdeploy/serve/openai/serving_completion.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from .protocol import CompletionRequest

if TYPE_CHECKING:
    from .api_server import VariableInterface


def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str:
    """Validate a (legacy) completion request against the current server state.

    Args:
        request (CompletionRequest): The incoming completion request.
        server_context (VariableInterface): Supplies the engine configuration
            and the session manager.

    Returns:
        str: A message describing the first problem found, or an empty string
            when the request is acceptable.
    """
    engine_config = server_context.get_engine_config()
    session_manager = server_context.get_session_manager()

    # Logprobs are only validated when every attribute involved exists; an
    # AttributeError (e.g. an engine config without `logprobs_mode`) skips
    # this validation entirely.
    try:
        mode = engine_config.logprobs_mode
        num_logprobs = request.logprobs or 0
    except AttributeError:
        pass
    else:
        if num_logprobs > 0 and mode is None:
            return f'logprobs({num_logprobs}) requested but not enabled logprobs_mode in engine configuration.'
        if mode is not None and num_logprobs < 0:
            return 'logprobs must be non-negative when logprobs_mode is enabled in engine configuration.'

    session_id = request.session_id
    if session_manager.has(session_id):
        return f'The session_id {session_id!r} is occupied.'

    # Sampling parameters, checked in a fixed order; the first failure wins.
    n = request.n
    if n <= 0:
        return f'The n {n!r} must be a positive int.'
    top_p = request.top_p
    if not (0 < top_p <= 1):
        return f'The top_p {top_p!r} must be in (0, 1].'
    top_k = request.top_k
    if top_k < 0:
        return f'The top_k {top_k!r} cannot be a negative integer.'
    temperature = request.temperature
    if not (0 <= temperature <= 2):
        return f'The temperature {temperature!r} must be in [0, 2]'

    return ''


================================================
FILE: lmdeploy/serve/openai/serving_generate.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING

from .protocol import GenerateReqInput

if TYPE_CHECKING:
    from .api_server import VariableInterface


def check_request(request: GenerateReqInput, server_context: 'VariableInterface') -> str:
    """Validate a generate request against the current server state.

    Args:
        request (GenerateReqInput): The incoming generate request.
        server_context (VariableInterface): Supplies the engine configuration
            and the session manager.

    Returns:
        str: A message describing the first problem found, or an empty string
            when the request is acceptable.
    """
    engine_config = server_context.get_engine_config()
    session_manager = server_context.get_session_manager()

    # Logprob validation is skipped when an attribute involved is missing
    # (AttributeError), e.g. an engine config without `logprobs_mode`.
    try:
        mode = engine_config.logprobs_mode
        want_logprob = request.return_logprob
    except AttributeError:
        pass
    else:
        if mode is None and want_logprob:
            return (f'return_logprob({want_logprob}) requested '
                    'but not enabled logprobs_mode in engine configuration.')

    prompt = request.prompt
    input_ids = request.input_ids
    # Exactly one input form is allowed: reject both-set and neither-set.
    if (prompt is None) == (input_ids is None):
        return 'You must specify exactly one of prompt or input_ids'
    if prompt == '':
        return 'The prompt must not be an empty string'
    if input_ids is not None and len(input_ids) == 0:
        return 'The input_ids must not be an empty list'

    max_tokens = request.max_tokens
    if max_tokens is not None and max_tokens <= 0:
        return f'The max_tokens {max_tokens!r} must be a positive integer.'

    session_id = request.session_id
    if session_manager.has(session_id):
        return f'The session_id {session_id!r} is occupied.'

    # Sampling parameters, checked in a fixed order; the first failure wins.
    top_p = request.top_p
    if not (0 < top_p <= 1):
        return f'The top_p {top_p!r} must be in (0, 1].'
    top_k = request.top_k
    if top_k < 0:
        return f'The top_k {top_k!r} cannot be a negative integer.'
    temperature = request.temperature
    if not (0 <= temperature <= 2):
        return f'The temperature {temperature!r} must be in [0, 2]'

    return ''


================================================
FILE: lmdeploy/serve/openai/tool_parser/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .internlm2_parser import Internlm2ToolParser
from .llama3_parser import Llama3JsonToolParser
from .qwen2d5_parser import Qwen2d5ToolParser
from .qwen3_parser import Qwen3ToolParser
from .qwen3coder_parser import Qwen3CoderToolParser
from .tool_parser import ToolParser, ToolParserManager

# Public API of the tool_parser package: the concrete parser classes plus the
# ToolParser base class and the ToolParserManager registry.
__all__ = [
    'Internlm2ToolParser',
    'Qwen2d5ToolParser',
    'Qwen3ToolParser',
    'Qwen3CoderToolParser',
    'ToolParser',
    'ToolParserManager',
    'Llama3JsonToolParser',
]


================================================
FILE: lmdeploy/serve/openai/tool_parser/internlm2_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers
import json
from typing import Dict, Sequence, Union

import partial_json_parser
import shortuuid
from partial_json_parser.core.options import Allow

from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                                            ExtractedToolCallInformation, FunctionCall, ToolCall)
from lmdeploy.utils import get_logger

from .tool_parser import ToolParser, ToolParserManager
from .utils import extract_intermediate_diff

logger = get_logger('lmdeploy')


@ToolParserManager.register_module(['internlm', 'intern-s1'])
class Internlm2ToolParser(ToolParser):
    """Tool-call parser for InternLM2 / Intern-S1 outputs.

    A tool call is written as ``<|action_start|><|plugin|>`` followed by a JSON
    object holding ``name`` and either ``parameters`` or ``arguments``, and is
    closed by ``<|action_end|>``. Parallel tool calls are not supported: only
    the first call in an output is extracted.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        # Length of the prefix of the streamed text that has already been
        # forwarded to the client as plain content.
        self.position = 0

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called.

        InternLM marks the start and end of a tool call with special tokens,
        so they must survive detokenization for the parser to find them.
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    def get_argments(self, obj):
        """Return the call arguments stored in ``obj``.

        ``parameters`` takes precedence over ``arguments``; ``None`` is
        returned when neither key is present. (The misspelled name is kept
        for backward compatibility.)
        """
        if 'parameters' in obj:
            return obj.get('parameters')
        elif 'arguments' in obj:
            return obj.get('arguments')
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally extract the tool call from a streamed response.

        Text before the action marker is forwarded as content. Afterwards the
        partial JSON is parsed on every chunk: the function name is sent once
        it is complete, then argument increments are sent as they grow.

        Returns:
            A ``DeltaMessage`` with plain content or a tool-call delta, or
            ``None`` when nothing should be sent for this chunk.
        """
        if '<|action_start|>' not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # If the tool call has already been sent, return an empty delta
        # message so that the finish_reason is still sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content='')

        last_pos = self.position
        if '<|action_start|><|plugin|>\n' not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        # partition (not split) so that a second action marker cannot raise
        # "too many values to unpack"; only the first tool call is handled.
        text, _, action = new_delta.partition('<|action_start|><|plugin|>\n')

        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split('<|action_end|>')[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # tool calls are generated in an object in internlm2;
            # parallel tool calls are not supported
            try:
                tool_call_arr: Dict = partial_json_parser.loads(parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get('name')
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type='function',
                                      id=f'chatcmpl-tool-{shortuuid.random()}',
                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append('')
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id])
                cur_arguments = self.get_argments(tool_call_arr)

                # not arguments generated
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error('INVARIANT - impossible to have arguments reset '
                                 'mid-arguments')
                    delta = None
                # first time to get parameters
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)

                    arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)]
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                    ])
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
                # both prev and cur parameters, send the increase parameters
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)

                    argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            # remember this parse (with arguments normalized under the
            # 'arguments' key) so the next chunk can diff against it
            tool_call_arr['arguments'] = self.get_argments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            logger.exception('Error trying to handle streaming tool call.')
            logger.debug('Skipping chunk as a result of tool streaming extraction '
                         'error')
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract the tool call from a complete model response.

        Returns a single tool call when the output contains an action whose
        function name is declared in ``request.tools``. If no tools were
        declared, or the model named an undeclared function, the text before
        the action is returned as plain content with ``tools_called=False``.

        Raises:
            json.JSONDecodeError: If the action payload is not valid JSON.
            KeyError: If the action payload has no ``name``.
        """
        text = model_output
        tools = request.tools
        if '<|action_start|><|plugin|>' in text:
            # partition (not split): a second action marker must not raise;
            # only the first tool call is extracted.
            text, _, action = text.partition('<|action_start|><|plugin|>')
            action = action.split('<|action_end|>')[0]
            action = action[action.find('{'):]
            action_dict = json.loads(action)
            name = action_dict['name']
            parameters = json.dumps(action_dict.get('parameters', action_dict.get('arguments', {})),
                                    ensure_ascii=False)

            if not tools or not any(t.function.name == name for t in tools):
                # Previously this result was built but never returned, so
                # calls to undeclared functions leaked through as tool calls.
                return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)

            tool_calls = [ToolCall(function=FunctionCall(name=name, arguments=parameters))]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=text if len(text) > 0 else None)

        return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)


================================================
FILE: lmdeploy/serve/openai/tool_parser/llama3_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import re
from typing import Dict, List, Sequence, Union

import partial_json_parser
import shortuuid
from partial_json_parser.core.options import Allow

from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                                            ExtractedToolCallInformation, FunctionCall, ToolCall)
from lmdeploy.utils import get_logger

from .tool_parser import ToolParser, ToolParserManager
from .utils import find_common_prefix, is_complete_json, partial_json_loads

logger = get_logger('lmdeploy')


@ToolParserManager.register_module('llama3')
class Llama3JsonToolParser(ToolParser):
    """Tool call parser for Llama 3.1 models.

    Non-streaming extraction understands the ``<function=NAME>{json}</function>``
    reply format. Streaming extraction parses JSON tool calls (optionally
    prefixed by ``<|python_tag|>``, several calls separated by ``; ``).

    Used when --tool-call-parser llama3 is set.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = []  # map what has been streamed for each tool so far to a list
        self.bot_token = '<|python_tag|>'
        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[0]
        self.tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL)
        # <function=NAME>PAYLOAD</function>; group 1 is the name, group 2 the
        # raw argument payload.
        self.function_call_regex = re.compile(r'<function=([^>]+)>(.*?)</function>', re.DOTALL)

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract the tool calls from a complete model response.

        Every ``<function=NAME>{...}</function>`` span becomes one tool call
        whose arguments are the JSON object starting at the first ``{`` of
        the payload. Non-blank text before the first call is returned as
        content. When no call is present, or a payload is not valid JSON, the
        whole output is returned as plain content with ``tools_called=False``.
        """
        matches = list(self.function_call_regex.finditer(model_output))
        if not matches:
            # A plain answer without a tool call is not an error: no logging.
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        try:
            tool_calls: List[ToolCall] = []
            for match in matches:
                name = match.group(1).strip()
                payload = match.group(2)
                arguments = json.loads(payload[payload.find('{'):]) if '{' in payload else {}
                tool_calls.append(
                    ToolCall(type='function',
                             function=FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))))
        except json.JSONDecodeError:
            logger.exception('Error in extracting tool call from response.')
            # return information to just treat the tool call as regular text
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        # any content before the first tool call
        prefix = model_output[:matches[0].start()]
        return ExtractedToolCallInformation(tools_called=True,
                                            tool_calls=tool_calls,
                                            content=prefix if prefix.strip() else None)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally extract JSON tool calls from a streamed response.

        Output that starts with neither ``<|python_tag|>`` nor ``{`` is
        forwarded as plain content. Otherwise the partial JSON is re-parsed on
        every chunk and only the newly available name / argument text is sent.

        Returns:
            A ``DeltaMessage`` with plain content or a tool-call delta, or
            ``None`` when nothing should be sent for this chunk.
        """

        if not (current_text.startswith(self.bot_token) or current_text.startswith('{')):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = len(self.bot_token) if current_text.startswith(self.bot_token) else 0
                while start_idx < len(current_text):
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(is_complete_json(current_text[start_idx:start_idx + end_idx]))
                    start_idx += end_idx + len('; ')
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if 'parameters' in obj:
                        assert 'arguments' not in obj, \
                            'model generated both parameters and arguments'
                        obj['arguments'] = obj['parameters']
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # case -- if no tokens have been streamed for the tool, e.g.
            #   only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            #   -> array has > 0 length AND length has moved past cursor
            elif (len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1):

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get('arguments')
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug('got arguments diff: %s', argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                                              exclude_none=True))
                        ])
                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append('')
                logger.debug('starting on new tool %d', self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get('name')
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type='function',
                                      id=f'chatcmpl-tool-{shortuuid.random()}',
                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get('arguments')
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments')

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:

                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(arguments=argument_diff).model_dump(
                                              exclude_none=True))
                        ])
                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception('Error trying to handle streaming tool call.')
            logger.debug('Skipping chunk as a result of tool streaming extraction '
                         'error')
            return None


================================================
FILE: lmdeploy/serve/openai/tool_parser/qwen2d5_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import re
from typing import Dict, Sequence, Union

import partial_json_parser
import shortuuid
from partial_json_parser.core.options import Allow

from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                                            ExtractedToolCallInformation, FunctionCall, ToolCall)
from lmdeploy.utils import get_logger

from .tool_parser import ToolParser, ToolParserManager
from .utils import extract_intermediate_diff

logger = get_logger('lmdeploy')


@ToolParserManager.register_module(['qwen2d5'])
class Qwen2d5ToolParser(ToolParser):
    """Tool-call parser for Qwen2.5 outputs.

    Each tool call is a JSON object (``name`` plus ``arguments``) wrapped in
    ``<tool_call>`` ... ``</tool_call>`` tags. Non-streaming extraction returns
    every call; streaming extraction only handles the first one.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        # Length of the prefix of the streamed text that has already been
        # forwarded to the client as plain content.
        self.position = 0
        # Tags wrapping each tool call. They must be non-empty: an empty tag
        # is "contained" in every string and breaks both extraction paths.
        self.tool_start_token = '<tool_call>'
        self.tool_end_token = '</tool_call>'
        # Captures the JSON payload between the tags (used with re.DOTALL).
        self.pattern = r'<tool_call>(.*?)</tool_call>'

    def get_argments(self, obj):
        """Return the call arguments stored in ``obj``.

        ``parameters`` takes precedence over ``arguments``; ``None`` is
        returned when neither key is present. (The misspelled name is kept
        for backward compatibility.)
        """
        if 'parameters' in obj:
            return obj.get('parameters')
        elif 'arguments' in obj:
            return obj.get('arguments')
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally extract the first tool call from a streamed response.

        Text before ``<tool_call>`` is forwarded as content. Afterwards the
        partial JSON is parsed on every chunk: the function name is sent once
        it is complete, then argument increments are sent as they grow.

        Returns:
            A ``DeltaMessage`` with plain content or a tool-call delta, or
            ``None`` when nothing should be sent for this chunk.
        """
        if self.tool_start_token not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # If the tool call has already been sent, return an empty delta
        # message so that the finish_reason is still sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content='')

        last_pos = self.position
        if self.tool_start_token not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        # partition (not split) so that a second <tool_call> cannot raise
        # "too many values to unpack"; only the first tool call is streamed.
        text, _, action = new_delta.partition(self.tool_start_token)

        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split(self.tool_end_token)[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # each tool call is a single JSON object; only the first call is
            # streamed (no parallel tool calls)
            try:
                tool_call_arr: Dict = partial_json_parser.loads(parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get('name')
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type='function',
                                      id=f'chatcmpl-tool-{shortuuid.random()}',
                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append('')
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id])
                cur_arguments = self.get_argments(tool_call_arr)

                # not arguments generated
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error('INVARIANT - impossible to have arguments reset '
                                 'mid-arguments')
                    delta = None
                # first time to get parameters
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)

                    arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)]
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                    ])
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
                # both prev and cur parameters, send the increase parameters
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)

                    argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            # remember this parse (with arguments normalized under the
            # 'arguments' key) so the next chunk can diff against it
            tool_call_arr['arguments'] = self.get_argments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            logger.exception('Error trying to handle streaming tool call.')
            logger.debug('Skipping chunk as a result of tool streaming extraction '
                         'error')
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every ``<tool_call>`` block from a complete model response.

        The content returned alongside the calls is the text before the first
        tag, or (when the output starts with a tag) the text after the last
        closing tag. An opening tag without any complete block yields no tool
        calls and returns the output unchanged as content.

        Raises:
            json.JSONDecodeError: If a tool-call payload is not valid JSON.
            KeyError: If a tool-call payload has no ``name``.
        """
        text = model_output
        if self.tool_start_token in text:

            # get every tool_call payload in text
            match_result_list = re.findall(self.pattern, text, re.DOTALL)
            if not match_result_list:
                # e.g. truncated output: no complete call, so report none
                # rather than tools_called=True with an empty list.
                return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)

            tool_calls = []
            for match_result in match_result_list:
                action = json.loads(match_result)
                # Qwen uses 'arguments'; tolerate 'parameters' as a fallback.
                arguments = action.get('arguments', action.get('parameters', {}))
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=action['name'],
                                                   arguments=json.dumps(arguments, ensure_ascii=False))))

            # get text outside of tags
            if not text.startswith(self.tool_start_token):
                text = text[:text.find(self.tool_start_token)]
            elif not text.endswith(self.tool_end_token):
                text = text[text.rfind(self.tool_end_token) + len(self.tool_end_token):]
            else:
                text = ''
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=text if len(text) > 0 else None)

        return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)


================================================
FILE: lmdeploy/serve/openai/tool_parser/qwen3_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import re
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Union

import shortuuid

from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                                            ExtractedToolCallInformation, FunctionCall, ToolCall)
from lmdeploy.utils import get_logger

from .tool_parser import ToolParser, ToolParserManager

logger = get_logger('lmdeploy')


@dataclass
class ParserState:
    """Per-request bookkeeping carried between streaming tool-call extraction calls."""

    # offset into the accumulated output up to which text has already been consumed
    position: int = 0
    # index of the tool call currently being emitted; -1 before the first one
    current_index: int = -1
    # reasoning-parsing flag; not read by the parser in this file
    parsing_reasoning: bool = False

    # id of the tool call currently being emitted; '' while no tool call is open
    id: str = ''

    def reset_tool_call(self):
        """Forget the current tool-call id once its closing tag has been consumed."""
        self.id = ''


@ToolParserManager.register_module(['qwen', 'qwen3'])
class Qwen3ToolParser(ToolParser):
    """Parser for Qwen3 model's tool call format.

    Qwen3 wraps each tool call in ``<tool_call>`` ... ``</tool_call>`` tags whose body is a JSON object carrying the
    function ``name`` and its ``arguments`` (``parameters`` is accepted as an alias), e.g.::

        <tool_call>
        {"name": "get_weather", "arguments": {"city": "Paris"}}
        </tool_call>
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        self.tool_start_token = '<tool_call>'
        self.tool_end_token = '</tool_call>'
        # group 1 captures the JSON body; newlines right after the start tag are skipped
        self.tool_call_pat = re.compile(r'<tool_call>\n*(.*?)</tool_call>', re.DOTALL)

    def get_argments(self, obj):
        """Extract arguments from tool call object, handling different formats.

        Supports both 'parameters' and 'arguments' keys in the tool call object ('parameters' takes precedence when
        both are present). Returns None when neither key exists.
        """
        if 'parameters' in obj:
            return obj.get('parameters')
        elif 'arguments' in obj:
            return obj.get('arguments')
        return None

    def _split(self, parser_state: ParserState, parsing_content: str):
        """Split not-yet-consumed output into ``(text_content, tool_content, has_tool_end)``.

        ``parsing_content`` is the output starting at ``parser_state.position``. The position is advanced past text
        that can be emitted now and past a complete ``<tool_call>...</tool_call>`` span. It stays at the start tag
        while the end tag has not arrived, so the partial tool call is re-scanned on the next chunk. ``tool_content``
        (the text between the tags) is only returned once the end tag is present.
        """
        start_idx = parsing_content.find(self.tool_start_token)
        if start_idx == -1:
            # no tool call pending: everything is plain text
            parser_state.position += len(parsing_content)
            return parsing_content, '', False
        # move to the beginning of tool_start_token
        parser_state.position += start_idx

        # Only search for the end tag after the start tag: a stray end tag in the preceding text would otherwise
        # produce a negative span and move `position` backwards.
        body_start = start_idx + len(self.tool_start_token)
        end_idx = parsing_content.find(self.tool_end_token, body_start)
        if end_idx == -1:
            # position holds until tool_end_token is found
            return parsing_content[:start_idx], '', False
        # move position to the end of tool_end_token
        parser_state.position += (end_idx - start_idx) + len(self.tool_end_token)
        return parsing_content[:start_idx], parsing_content[body_start:end_idx], True

    def _parse_delta_tool_call(self, parser_state: ParserState, tool_content: str) -> Optional[DeltaToolCall]:
        """Parse the JSON body of one complete tool call into a DeltaToolCall.

        Returns None when the body is not a JSON object or carries neither a name nor arguments. The first delta of
        each tool call allocates a new id and index on ``parser_state``.
        """
        try:
            tool_call_arr: Dict = json.loads(tool_content.strip())
        except json.JSONDecodeError:
            logger.debug('cannot parse into JSON yet')
            return None
        if not isinstance(tool_call_arr, dict):
            # e.g. a JSON list or string; `.get` below would raise and break the stream
            logger.debug('tool call body is not a JSON object')
            return None

        fcall = DeltaFunctionCall()
        func_name = tool_call_arr.get('name')
        if func_name:
            fcall.name = func_name
        args = self.get_argments(tool_call_arr)
        # An empty dict is still a valid arguments object: send '{}' (matching the non-streaming path) instead of
        # dropping it, otherwise clients never receive parseable arguments for a no-argument call.
        if isinstance(args, dict):
            fcall.arguments = json.dumps(args, ensure_ascii=False)
        # Return None if no new information to send
        if not fcall.name and not fcall.arguments:
            return None
        if not parser_state.id:
            # A new tool call parsed, allocate a new id & index
            parser_state.id = f'chatcmpl-tool-{shortuuid.random()}'
            parser_state.current_index += 1
        # Create and return the DeltaToolCall object
        return DeltaToolCall(
            id=parser_state.id,
            index=parser_state.current_index,
            function=fcall.model_dump(exclude_none=True),
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls from streaming model output.

        Only ``current_text`` is inspected. Parsing state is stored on ``request`` as ``_tool_parser_state``, so
        each call resumes at the position consumed by the previous one. Plain text is returned as content; a tool
        call is emitted in one delta once its closing tag has arrived.
        """
        parser_state = getattr(request, '_tool_parser_state', None)
        if parser_state is None:
            parser_state = ParserState()
            setattr(request, '_tool_parser_state', parser_state)

        # Split the new content into text and tool content
        text_content, tool_content, has_tool_end = self._split(parser_state, current_text[parser_state.position:])
        delta = DeltaMessage()

        # Add each type of content to the delta message if present
        if text_content:
            delta.content = text_content
        if tool_content:
            # Parse tool content into a DeltaToolCall object
            delta_tool_call = self._parse_delta_tool_call(parser_state, tool_content)
            if delta_tool_call is not None:
                delta.tool_calls = [delta_tool_call]
            if has_tool_end:
                parser_state.reset_tool_call()
        return delta

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output.

        Every ``<tool_call>...</tool_call>`` span is decoded as JSON and removed from the returned content. Arguments
        may be given under ``arguments`` or ``parameters`` (as in the streaming path); a missing or null value is
        reported as ``'{}'``.

        Raises:
            json.JSONDecodeError: if a tool-call body is not valid JSON.
            KeyError: if a decoded tool call has no ``name``.
        """
        text = model_output

        # Extract tool calls (content inside <tool_call> tags)
        buf = []
        scan_pos = 0
        tool_calls = []
        for match in self.tool_call_pat.finditer(text):
            buf.append(text[scan_pos:match.start()])  # Add text before the <tool_call> tag
            scan_pos = match.end()
            action = json.loads(match.group(1))  # Parse the tool call JSON
            args = self.get_argments(action)
            arguments = json.dumps(args if args is not None else {}, ensure_ascii=False)
            tool_calls.append(ToolCall(function=FunctionCall(name=action['name'], arguments=arguments)))
        buf.append(text[scan_pos:])  # Add remaining text ('' when nothing is left)
        text = ''.join(buf)  # Reconstruct text without <tool_call> spans

        return ExtractedToolCallInformation(
            content=text,
            tool_calls=tool_calls,
            tools_called=bool(tool_calls),
        )


================================================
FILE: lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import shortuuid

from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                                            ExtractedToolCallInformation, FunctionCall, ToolCall)
from lmdeploy.utils import get_logger

from .tool_parser import ToolParser, ToolParserManager

logger = get_logger('lmdeploy')


@dataclass
class ParserState:
    """Per-request bookkeeping carried between streaming tool-call extraction calls."""

    # offset into the accumulated output up to which text has already been consumed
    position: int = 0
    # index of the tool call currently being emitted; -1 before the first one
    current_index: int = -1

    # id of the tool call currently being emitted; '' while no tool call is open
    id: str = ''

    def reset_tool_call(self):
        """Forget the current tool-call id once its closing tag has been consumed."""
        self.id = ''


@ToolParserManager.register_module(['qwen3coder'])
class Qwen3CoderToolParser(ToolParser):
    """Parser for Qwen3 Coder model's tool call format.

    Qwen3 Coder describes tool calls purely with XML-like tags for function names and parameters, e.g.::

        <tool_call>
        <function=get_weather>
        <parameter=city>
        Paris
        </parameter>
        </function>
        </tool_call>
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        self.tool_start_token = '<tool_call>'
        self.tool_end_token = '</tool_call>'
        self.func_prefix = '<function='
        self.func_end_token = '</function>'
        self.param_prefix = '<parameter='
        self.param_end_token = '</parameter>'
        # group 1 captures everything between the tool-call tags
        self.tool_call_pat = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

    def _split(self, parser_state: ParserState, parsing_content: str) -> Tuple[str, str, bool]:
        """Split content into tuple: (text_content, tool_content, has_tool_end)

        ``tool_content`` includes the tool-call tags. While the end tag is missing, the partial tool call is returned
        and ``parser_state.position`` stays at its start tag, so it is re-scanned on the next chunk.
        """
        start_idx = parsing_content.find(self.tool_start_token)
        if start_idx == -1:
            parser_state.position += len(parsing_content)
            return parsing_content, '', False
        parser_state.position += start_idx

        # Only search for the end tag after the start tag: a stray end tag in the preceding text would otherwise
        # produce a negative span and move `position` backwards.
        end_idx = parsing_content.find(self.tool_end_token, start_idx + len(self.tool_start_token))
        if end_idx == -1:
            return parsing_content[:start_idx], parsing_content[start_idx:], False

        parser_state.position += (end_idx - start_idx) + len(self.tool_end_token)
        return parsing_content[:start_idx], parsing_content[start_idx:end_idx + len(self.tool_end_token)], True

    def _extract_params(self, content: str) -> Tuple[Optional[str], Dict[str, Any], bool]:
        """Parse (possibly partial) XML tool content into ``(func_name, args_dict, is_func_closed)``.

        ``func_name`` is None until its terminator ('>' or newline) has been seen. A parameter is only included once
        its closing ``</parameter>`` tag is present. Values are decoded as JSON when possible ('null'/'true'/'false'
        case-insensitively); otherwise the stripped raw string is kept.
        """
        content = content.replace(self.tool_start_token, '').replace(self.tool_end_token, '').strip()

        func_name = None
        func_start = content.find(self.func_prefix)
        if func_start != -1:
            name_start = func_start + len(self.func_prefix)
            terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1]
            if terminators:
                func_name = content[name_start:min(terminators)].strip()

        args_dict = {}
        search_idx = 0
        while True:
            param_start = content.find(self.param_prefix, search_idx)
            if param_start == -1:
                break

            name_start = param_start + len(self.param_prefix)
            terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1]
            if not terminators:
                break

            name_end = min(terminators)
            param_name = content[name_start:name_end].strip()

            val_start = name_end + 1
            val_end = content.find(self.param_end_token, val_start)
            if val_end == -1:
                # value not complete yet (streaming); wait for more output
                break

            param_val_str = content[val_start:val_end].strip()

            if param_val_str.lower() == 'null':
                val = None
            elif param_val_str.lower() == 'true':
                val = True
            elif param_val_str.lower() == 'false':
                val = False
            else:
                try:
                    val = json.loads(param_val_str)
                except json.JSONDecodeError:
                    val = param_val_str
            args_dict[param_name] = val
            search_idx = val_end + len(self.param_end_token)

        is_func_closed = self.func_end_token in content
        return func_name, args_dict, is_func_closed

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls incrementally from streaming model output.

        The function name is sent once it is known. Arguments are sent as JSON fragments: ``{`` first, then one
        ``"key": value`` pair per completed parameter, and ``}`` when the function (or, for a named call, the
        enclosing tool call) is closed. Concatenating the fragments yields the full JSON arguments object.
        """
        parser_state = getattr(request, '_tool_parser_state', None)
        if parser_state is None:
            parser_state = ParserState()
            setattr(request, '_tool_parser_state', parser_state)

        text_content, tool_content, has_tool_end = self._split(parser_state, current_text[parser_state.position:])

        delta = DeltaMessage()
        if text_content:
            delta.content = text_content

        if tool_content:
            if not parser_state.id:
                # first chunk of a new tool call: allocate id/index and reset the per-call bookkeeping
                parser_state.id = f'chatcmpl-tool-{shortuuid.random()}'
                parser_state.current_index += 1
                parser_state.has_emitted_name = False
                parser_state.has_emitted_json_start = False
                parser_state.json_closed = False
                parser_state.emitted_params = set()

            func_name, args_dict, is_func_closed = self._extract_params(tool_content)

            fcall_delta = DeltaFunctionCall()
            has_updates = False

            if func_name and not parser_state.has_emitted_name:
                fcall_delta.name = func_name
                parser_state.has_emitted_name = True
                has_updates = True

            # The arguments object is complete at </function>. If the model closes </tool_call> without
            # </function>, still close it for a named call so the concatenated arguments are valid JSON.
            args_closed = is_func_closed or (has_tool_end and parser_state.has_emitted_name)

            json_fragments = []
            if not parser_state.has_emitted_json_start and (args_dict or args_closed):
                json_fragments.append('{')
                parser_state.has_emitted_json_start = True

            for key, value in args_dict.items():
                if key in parser_state.emitted_params:
                    continue
                prefix = ', ' if parser_state.emitted_params else ''
                # serialize the key too, so quotes/backslashes in parameter names cannot corrupt the JSON
                key_json = json.dumps(key, ensure_ascii=False)
                value_json = json.dumps(value, ensure_ascii=False)
                json_fragments.append(f'{prefix}{key_json}: {value_json}')
                parser_state.emitted_params.add(key)

            if args_closed and parser_state.has_emitted_json_start and not parser_state.json_closed:
                json_fragments.append('}')
                parser_state.json_closed = True

            joined_fragments = ''.join(json_fragments)
            if joined_fragments:
                fcall_delta.arguments = joined_fragments
                has_updates = True

            if has_updates:
                delta.tool_calls = [
                    DeltaToolCall(
                        id=parser_state.id,
                        index=parser_state.current_index,
                        function=fcall_delta,
                    )
                ]

        if has_tool_end:
            parser_state.reset_tool_call()
            # Prepare for the next tool call: drop the per-call bookkeeping, re-created when the next one starts
            for attr in ('has_emitted_name', 'has_emitted_json_start', 'json_closed', 'emitted_params'):
                if hasattr(parser_state, attr):
                    delattr(parser_state, attr)

        return delta

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output.

        Every ``<tool_call>...</tool_call>`` span is removed from the returned content. A span whose function name can
        be parsed becomes a ToolCall whose arguments are the JSON-encoded parameter dict (``'{}'`` when empty).
        """
        text = model_output
        buf = []
        scan_pos = 0
        tool_calls = []

        for match in self.tool_call_pat.finditer(text):
            buf.append(text[scan_pos:match.start()])
            scan_pos = match.end()

            func_name, args_dict, _ = self._extract_params(match.group(1))
            if func_name:
                arguments = json.dumps(args_dict, ensure_ascii=False)
                tool_calls.append(ToolCall(function=FunctionCall(name=func_name, arguments=arguments)))

        buf.append(text[scan_pos:])
        text = ''.join(buf)

        return ExtractedToolCallInformation(
            content=text,
            tool_calls=tool_calls,
            tools_called=bool(tool_calls),
        )


================================================
FILE: lmdeploy/serve/openai/tool_parser/tool_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers
from functools import cached_property
from typing import Dict, List, Sequence, Union

from mmengine import Registry

from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')
ToolParserManager = Registry('tool_parser', locations=['lmdeploy.serve.openai.tool_parser'])


class ToolParser:
    """Abstract ToolParser class that should not be used directly.

    Provided properties and methods should be used in derived classes. Subclasses implement
    ``extract_tool_calls`` for complete outputs and ``extract_tool_calls_streaming`` for
    incremental outputs; both raise NotImplementedError here.
    """

    def __init__(self, tokenizer: object):
        # generic streaming bookkeeping available to subclasses
        self.prev_tool_call_arr: List[Dict] = []
        # the index of the tool call that is currently being parsed
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = []

        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        """Token-to-id mapping of the model tokenizer, computed on first access and cached."""
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Hook for subclasses to adjust the request parameters.

        The default implementation returns ``request`` unchanged.
        """
        return request

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model-generated string.

        Must be implemented by subclasses. Used for non-streaming responses where the entire model response is
        available before sending to the client.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!')

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Extract tool calls from an incomplete response while streaming.

        Must be implemented by subclasses. Implementations typically need state between calls (the current
        tokens/diffs and what has already been parsed and extracted; see the constructor).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been '
                                  'implemented!')


================================================
FILE: lmdeploy/serve/openai/tool_parser/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Copied from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/tool_parsers/utils.py

import json
from json import JSONDecodeError, JSONDecoder
from typing import Any, List, Tuple

import partial_json_parser
from partial_json_parser.core.options import Allow


def find_common_prefix(s1: str, s2: str) -> str:
    """Return the longest prefix shared by ``s1`` and ``s2`` ('' if none).

    Argument order does not matter. This is a utility for diffing successive partial-JSON parses from
    partial_json_parser, so that close-quotes, close-brackets and close-braces are not streamed prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap'
    """
    shared = 0
    for left, right in zip(s1, s2):
        if left != right:
            break
        shared += 1
    return s1[:shared]


def find_common_suffix(s1: str, s2: str) -> str:
    """Return the longest shared suffix of ``s1`` and ``s2`` made only of non-alphanumeric characters.

    Argument order does not matter. Scanning from the end stops at the first mismatch or at the first
    alphanumeric character.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    length = 0
    for left, right in zip(reversed(s1), reversed(s2)):
        if left != right or left.isalnum():
            break
        length += 1
    return s1[len(s1) - length:]


def extract_intermediate_diff(curr: str, old: str) -> str:
    """Return the part of ``curr`` between its common prefix and common suffix with ``old``.

    A utility for diffing successive partial-JSON parses from partial_json_parser, so that close-quotes,
    close-brackets and close-braces are not streamed prematurely. Argument order matters: ``curr`` is the new
    partially-parsed JSON and ``old`` is the one from the previous generation. The result is what should be
    streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'
    """

    def _drop_last(haystack: str, needle: str) -> str:
        # remove the right-most occurrence of `needle` (a no-op when it is absent or empty)
        return haystack[::-1].replace(needle[::-1], '', 1)[::-1]

    suffix = find_common_suffix(curr, old)
    prefix = find_common_prefix(curr, _drop_last(old, suffix))

    trimmed = _drop_last(curr, suffix) if suffix else curr
    if not prefix:
        return trimmed
    # replace the prefix only once in case it's mirrored
    return trimmed.replace(prefix, '', 1)


def find_all_indices(string: str, substring: str) -> List[int]:
    """Return every starting index of ``substring`` in ``string``, overlapping matches included.

    Useful for tool call extraction.
    """
    found: List[int] = []
    pos = string.find(substring)
    while pos != -1:
        found.append(pos)
        pos = string.find(substring, pos + 1)
    return found


# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse possibly-partial JSON and return ``(obj, consumed_length)``.

    If partial_json_parser reports trailing "Extra data", fall back to ``JSONDecoder.raw_decode``, which parses the
    first complete JSON value and reports where it ended. Any other JSONDecodeError is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if 'Extra data' not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return parsed, len(input_str)


def is_complete_json(input_str: str) -> bool:
    """Return True when ``input_str`` decodes as a complete JSON document."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    return True


def consume_space(i: int, s: str) -> int:
    """Advance index ``i`` past any run of whitespace in ``s`` and return the new index."""
    end = len(s)
    while i < end:
        if not s[i].isspace():
            break
        i += 1
    return i


================================================
FILE: lmdeploy/serve/processors/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .multimodal import MultimodalProcessor

__all__ = ['MultimodalProcessor']


================================================
FILE: lmdeploy/serve/processors/multimodal.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Any, Dict, List, Literal, Tuple

import PIL

from lmdeploy.model import MODELS, BaseChatTemplate
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.media.connection import load_from_url
from lmdeploy.vl.media.image import ImageMediaIO
from lmdeploy.vl.media.time_series import TimeSeriesMediaIO
from lmdeploy.vl.media.video import VideoMediaIO

logger = get_logger('lmdeploy')


class MultimodalProcessor:
    """Processor for handling prompt preprocessing, message content merging,
    and multimodal processing."""

    def __init__(self,
                 tokenizer: Tokenizer,
                 chat_template: BaseChatTemplate,
                 vl_encoder=None,
                 backend: str | None = None):
        """Store the collaborators used for prompt preprocessing.

        Args:
            tokenizer: Tokenizer instance used to encode prompts.
            chat_template: Chat template instance applied to message lists.
            vl_encoder: Optional ImageEncoder instance; when None, prompts take the text-only path.
            backend: Optional backend name ('turbomind' or 'pytorch') for multimodal processing.
        """
        self.backend = backend
        self.vl_encoder = vl_encoder
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    @staticmethod
    def merge_message_content(msg: Dict) -> Dict:
        """Normalize a message so that its 'content' is always a string.

        Mirrors vLLM's content normalization:
        - missing or None content -> '' (keeps Jinja2 chat templates from failing on None)
        - string content -> the message is returned unchanged (same object)
        - list content -> the 'text' of every text block, joined with '\n'; non-text blocks are dropped

        Args:
            msg: A message dict with 'role' and optionally 'content'.

        Returns:
            A message dict whose 'content' is a string. All other fields (e.g. tool_calls) are preserved; a new dict
            is returned whenever the content had to be changed.

        References:
            https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/chat_utils.py
            (_parse_chat_message_content() and _parse_chat_message_content_parts(): "\n".join(texts))
        """
        content = msg.get('content')
        if content is None:
            return {**msg, 'content': ''}
        if isinstance(content, str):
            return msg
        texts = [block.get('text', '') for block in content if isinstance(block, dict) and block.get('type') == 'text']
        return {**msg, 'content': '\n'.join(texts)}

    @staticmethod
    def _parse_multimodal_item(i: int, in_messages: List[Dict], out_messages: List[Dict], media_io_kwargs: Dict[str,
                                                                                                                Any]):
        """Synchronously parse ``in_messages[i]`` and store the result in ``out_messages[i]``.

        Non-user messages and messages with string content are passed through by reference. For user messages with
        list content, text items are kept and every media item is loaded into ``{'type': modality, 'data': ...}``,
        with its remaining parameters merged in. The input messages are not modified.

        Raises:
            NotImplementedError: for an unknown content item type.
        """
        role = in_messages[i]['role']
        content = in_messages[i]['content']

        if role != 'user' or isinstance(content, str):
            out_messages[i] = in_messages[i]
            return

        assert isinstance(content, list)
        out_message = dict(role=role, content=[])

        for item in content:
            item_type = item.get('type')
            if item_type == 'text':
                out_message['content'].append(item)
                continue

            raw_params = item.get(item_type) or {}
            if isinstance(raw_params, str):
                # tolerate the shorthand form, e.g. {"type": "image_url", "image_url": "https://..."}
                raw_params = {'url': raw_params}
            # Work on a copy: popping from the caller's dict would strip 'url'/'data' from the request messages.
            item_params = dict(raw_params)
            # Pop both source keys: a leftover 'data' would be spread below and overwrite the loaded media.
            url_src = item_params.pop('url', None)
            raw_data = item_params.pop('data', None)
            data_src = url_src or raw_data

            if item_type == 'image_data':
                modality = Modality.IMAGE
                data = data_src
            elif item_type == 'image_url':
                modality = Modality.IMAGE
                img_io = ImageMediaIO(**media_io_kwargs.get('image', {}))
                data = load_from_url(data_src, img_io)
            elif item_type == 'video_url':
                modality = Modality.VIDEO
                vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {}))
                data, metadata = load_from_url(data_src, vid_io)
                item_params['video_metadata'] = metadata
            elif item_type == 'time_series_url':
                modality = Modality.TIME_SERIES
                ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {}))
                data = load_from_url(data_src, ts_io)
            else:
                raise NotImplementedError(f'unknown type: {item_type}')

            out_message['content'].append({'type': modality, 'data': data, **item_params})

        out_messages[i] = out_message

    @staticmethod
    async def async_parse_multimodal_item(messages: List[Dict],
                                          media_io_kwargs: Dict[str, Any] | None = None) -> List[Dict]:
        """Convert user-input multimodal data into GPT4V message format.

        Each message is parsed by ``_parse_multimodal_item`` in the event loop's default executor, and all of them
        are awaited together. A single message dict is accepted and treated as a one-element list.

        Returns:
            The parsed messages, in the same order as the input.
        """
        if isinstance(messages, dict):
            messages = [messages]
        assert isinstance(messages, list)

        out_messages = [None] * len(messages)
        media_io_kwargs = media_io_kwargs or {}
        # inside a coroutine, get_running_loop() is the supported API (get_event_loop() is discouraged here)
        loop = asyncio.get_running_loop()

        await asyncio.gather(*[
            loop.run_in_executor(None, MultimodalProcessor._parse_multimodal_item, idx, messages, out_messages,
                                 media_io_kwargs) for idx in range(len(messages))
        ])
        return out_messages

    async def get_prompt_input(self,
                               prompt: str | List[Dict],
                               do_preprocess: bool,
                               sequence_start: bool,
                               adapter_name: str,
                               tools: List[object] | None = None,
                               reasoning_effort: Literal['low', 'medium', 'high'] | None = None,
                               chat_template_kwargs: Dict | None = None,
                               media_io_kwargs: Dict[str, Any] | None = None,
                               mm_processor_kwargs: Dict[str, Any] | None = None,
                               **kwargs):
        """Process prompt and return prompt string and input_ids.

        Handles both text-only and multimodal prompts. If multimodal input is detected
        and vl_encoder is available, processes images accordingly.

        Args:
            prompt: Input prompt as string or list of message dicts.
            do_preprocess: Whether to apply chat template preprocessing.
            sequence_start: Indicator for starting a sequence.
            adapter_name: Adapter name for selecting chat template.
            tools: Optional list of tools.
            reasoning_effort: Optional reasoning effort level.
            chat_template_kwargs: Optional kwargs for chat template.
            media_io_kwargs: Optional kwargs for media IO operations.
            mm_processor_kwargs: Optional kwargs for multimodal processor.
            **kwargs: Additional keyword arguments.

        Returns:
            Dict with 'prompt' (str) and 'input_ids' (List[int]) keys for text-only,
            or dict with multimodal data for multimodal prompts.
        """
        # Handle string input
        if isinstance(prompt, str):
            return await self._get_text_prompt_input(prompt=prompt,
                                                     do_preprocess=do_preprocess,
                                                     sequence_start=sequence_start,
                                                     adapter_name=adapter_name,
                                                     tools=tools,
                                                     reasoning_effort=reasoning_effort,
                                                     chat_template_kwargs=chat_template_kwargs,
                                                     **kwargs)

        # Handle list input
        elif isinstance(prompt, list):
            # Check if multimodal input exists
            has_multimodal_input = self._has_multimodal_input(prompt)

            # If no multimodal input or no vl_encoder, use text-only processing
            if not has_multimodal_input or self.vl_encoder is None:
                return await self._get_text_prompt_input(prompt=prompt,
                                                         do_preprocess=do_preprocess,
                                                         sequence_start=sequence_start,
                                                         adapter_name=adapter_name,
                                                         tools=tools,
                                                         reasoning_effort=reasoning_effort,
                                                         chat_template_kwargs=chat_template_kwargs,
                                                         **kwargs)

            # Process multimodal input
            return await self._get_multimodal_prompt_input(messages=prompt,
                                                           do_preprocess=do_preprocess,
                                                           sequence_start=sequence_start,
                                                           adapter_name=adapter_name,
                                                           tools=tools,
                                                           chat_template_kwargs=chat_template_kwargs,
                                                           media_io_kwargs=media_io_kwargs,
                                                           mm_processor_kwargs=mm_processor_kwargs,
                                                           **kwargs)
        else:
            raise RuntimeError(f'unsupported prompt type: {type(prompt)}')

    @staticmethod
    def format_prompts(prompts: Any) -> List[Dict]:
        """Format prompts."""
        if not isinstance(prompts, list):
            prompts = [prompts]
        # str or batch of str
        if all(isinstance(prompt, str) for prompt in prompts):
            return prompts
        if (MultimodalProcessor._is_openai_message(prompts)
                or all(MultimodalProcessor._is_openai_message(prompt) for prompt in prompts)):
            return prompts
        if all(MultimodalProcessor._is_str_images_pair(prompt) for prompt in prompts):
            # batch of (prompt, image or [images]) or (image or [images], prompt) ->
            # [[openai_gpt4v_message], [openai_gpt4v_message], ...]
            return [[MultimodalProcessor._re_format_prompt_images_pair(prompt)] for prompt in prompts]
        raise ValueError(f'Unsupported prompts: {prompts}. Only support str, openai message format, '
                         'or (prompt, image or [images]) or (image or [images], prompt) pair.')

    @staticmethod
    def _is_openai_message(message) -> bool:
        """Check if the message conforms to openai message format."""
        return isinstance(message, list) and all(isinstance(msg, dict) for msg in message)

    @staticmethod
    def _is_str_images_pair(message) -> bool:
        """Check if the message is a (prompt, image or [images]) or (image or
        [images], prompt) pair."""
        if not (isinstance(message, tuple) and len(message) == 2):
            return False
        _1, _2 = message
        if MultimodalProcessor._is_image(_1) or MultimodalProcessor._is_image_list(_1):
            _1, _2 = _2, _1
        return isinstance(_1, str) and (MultimodalProcessor._is_image(_2) or MultimodalProcessor._is_image_list(_2))

    @staticmethod
    def _is_image(obj) -> bool:
        """Return True for a ``PIL.Image.Image``, or for a string that starts
        with ``http`` (URL) or ``data:image`` (inline base64 image)."""
        if isinstance(obj, PIL.Image.Image):
            return True
        return isinstance(obj, str) and obj.startswith(('http', 'data:image'))

    @staticmethod
    def _is_image_list(obj) -> bool:
        return isinstance(obj, list) and all(MultimodalProcessor._is_image(img) for img in obj)

    @staticmethod
    def _re_format_prompt_images_pair(prompt: Tuple) -> Dict:
        """Convert a ``(text, image(s))`` or ``(image(s), text)`` pair into a
        single OpenAI-style user message.

        String images (url / path / data URI) are loaded eagerly with
        ``load_image``, so every image becomes an ``image_data`` item holding
        the loaded object. The text item is placed before or after the images,
        mirroring the order of the input pair.

        Raises:
            ValueError: if an image is neither a str nor a PIL.Image.Image.
        """
        from lmdeploy.vl import load_image

        text, visuals = prompt
        text_first = not (MultimodalProcessor._is_image(text) or MultimodalProcessor._is_image_list(text))
        if not text_first:
            text, visuals = visuals, text
        if not isinstance(visuals, list):
            visuals = [visuals]

        image_items = []
        for img in visuals:
            if isinstance(img, str):
                img = load_image(img)
            elif not isinstance(img, PIL.Image.Image):
                raise ValueError('image should be a str(url/path) or PIL.Image.Image')
            image_items.append({'type': 'image_data', 'image_data': {'data': img}})

        text_item = {'type': 'text', 'text': text}
        content = [text_item, *image_items] if text_first else [*image_items, text_item]
        return {'role': 'user', 'content': content}

    def _has_multimodal_input(self, messages: List[Dict]) -> bool:
        """Check if messages contain multimodal input (images)."""
        multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url']
        return any(
            isinstance(message.get('content'), list) and any(
                item.get('type') in multimodal_types for item in message['content']) for message in messages)

    async def _get_text_prompt_input(self,
                                     prompt: str | List[Dict],
                                     do_preprocess: bool,
                                     sequence_start: bool,
                                     adapter_name: str,
                                     tools: List[object] | None = None,
                                     reasoning_effort: Literal['low', 'medium', 'high'] | None = None,
                                     chat_template_kwargs: Dict | None = None,
                                     **kwargs):
        """Render a text-only prompt with a chat template and tokenize it.

        Args:
            prompt: a raw string, or a list of OpenAI-style messages whose
                contents are first flattened with ``merge_message_content``.
            do_preprocess: when False, ``BaseChatTemplate`` is used instead of
                the model's (or adapter's) chat template.
            sequence_start: forwarded to ``messages2prompt`` and used as
                ``add_bos`` when encoding.
            adapter_name: if it is a key of ``MODELS.module_dict``, that
                template is used instead of ``self.chat_template``.
            tools: forwarded to ``messages2prompt``.
            reasoning_effort: forwarded to ``messages2prompt``.
            chat_template_kwargs: extra keyword arguments for ``messages2prompt``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            ``{'prompt': str, 'input_ids': List[int]}``.

        Raises:
            ValueError: if the selected template returns ``None`` for the prompt.
        """
        if isinstance(prompt, list):
            # collapse multimodal-shaped message contents into plain text messages
            prompt = [self.merge_message_content(message) for message in prompt]

        if not do_preprocess:
            template = BaseChatTemplate()
        elif adapter_name in MODELS.module_dict:
            # prefer the template registered under the adapter's name
            template = MODELS.module_dict[adapter_name]()
        else:
            template = self.chat_template

        rendered = template.messages2prompt(prompt,
                                            sequence_start,
                                            tools=tools,
                                            reasoning_effort=reasoning_effort,
                                            **(chat_template_kwargs or {}))
        if rendered is None:
            raise ValueError(
                f'You are using base template to handle chat task. Please specify a `--chat-template` name chosen from `lmdeploy list` if you want to use OpenAI messages input.'  # noqa
            )
        input_ids = self.tokenizer.encode(rendered, add_bos=sequence_start)
        return {'prompt': rendered, 'input_ids': input_ids}

    async def _get_multimodal_prompt_input(self,
                                           messages: List[Dict],
                                           do_preprocess: bool,
                                           sequence_start: bool,
                                           adapter_name: str,
                                           tools: List[object] | None = None,
                                           chat_template_kwargs: Dict | None = None,
                                           media_io_kwargs: Dict[str, Any] | None = None,
                                           mm_processor_kwargs: Dict[str, Any] | None = None,
                                           **kwargs):
        """Process multimodal prompt and return processed data for inference
        engines."""
        chat_template = self.chat_template if do_preprocess else BaseChatTemplate()
        messages = await self.async_parse_multimodal_item(messages, media_io_kwargs)
        results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs)

        if self.backend == 'turbomind':
            # for tm engine, this module perform vision embedding after image
            # preprocessing. It utilizes the hf model's vision embeddings
            # functions and returns the input_ids, input_embeddings,
            # embedding_ranges and so on. All the returned values are passed
            # to tm engine for token generation
            results = await self.vl_encoder.async_infer(results)
            results = await self.vl_encoder.wrap_for_turbomind(messages=results,
                                                               chat_template=chat_template,
                                                               tokenizer=self.tokenizer,
                                                               sequence_start=sequence_start,
                                                               tools=tools,
                                                               chat_template_kwargs=chat_template_kwargs)
        elif self.backend == 'pytorch':
            # for pt engine, this module only conduct the image preprocessing
            # It leaves the vision embedding to the pt engine
            results = await self.vl_encoder.wrap_for_pytorch(messages=results,
                                                             chat_template=chat_template,
                                                             tokenizer=self.tokenizer,
                                                             sequence_start=sequence_start,
                                                             tools=tools,
                                                             chat_template_kwargs=chat_template_kwargs)
        return results


================================================
FILE: lmdeploy/serve/proxy/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/serve/proxy/proxy.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import asyncio
import copy
import json
import os
import os.path as osp
import random
import threading
import time
from collections import deque
from http import HTTPStatus
from typing import Deque, Literal

import aiohttp
import numpy as np
import requests
import uvicorn
from fastapi import BackgroundTasks, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field, field_validator

from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy
from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest
from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool
from lmdeploy.pytorch.disagg.messages import PDConnectionMessage
from lmdeploy.serve.openai.api_server import create_error_response
from lmdeploy.serve.openai.protocol import ModelCard  # noqa: E501
from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission
from lmdeploy.serve.proxy.utils import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, RoutingStrategy, err_msg
from lmdeploy.serve.utils.server_utils import validate_json_request
from lmdeploy.utils import get_logger

from .streaming_response import ProxyStreamingResponse
from .utils import APIServerException

logger = get_logger('lmdeploy')


class Status(BaseModel):
    """Status protocol consists of models' information.

    ``latency`` holds the durations of recently finished requests and is
    capped at ``LATENCY_DEQUE_LEN`` entries. The cap is re-applied whenever
    the field is validated (e.g. a status loaded from the config file with
    ``model_validate_json`` or posted to ``/nodes/add``). Otherwise such a
    status would get a deque without ``maxlen`` that grows with every request.
    """
    role: EngineRole = EngineRole.Hybrid
    models: list[str] = Field(default=[], examples=[[]])
    unfinished: int = 0
    latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]])
    speed: int | None = Field(default=None, examples=[None])

    @field_validator('latency', mode='after')
    @classmethod
    def cap_latency_length(cls, value: Deque) -> Deque:
        """Ensure the latency deque keeps at most ``LATENCY_DEQUE_LEN`` items."""
        if value.maxlen != LATENCY_DEQUE_LEN:
            # keeps the most recent LATENCY_DEQUE_LEN entries
            value = deque(value, maxlen=LATENCY_DEQUE_LEN)
        return value


class Node(BaseModel):
    """Node registration payload: the node's url and, optionally, its status."""
    url: str
    status: Status | None = Field(default=None)


# Seconds between two heart-beat health checks of the registered nodes;
# overridable through the LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION env var.
CONTROLLER_HEART_BEAT_EXPIRATION = int(os.environ.get('LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION', 90))


def heart_beat_controller(proxy_controller):
    """Periodically drop nodes that fail their health check.

    Loops forever (it is the target of the manager's daemon thread). Every
    ``CONTROLLER_HEART_BEAT_EXPIRATION`` seconds it calls
    ``proxy_controller.remove_stale_nodes_by_expiration()``. An error in one
    round is logged and the loop continues. Without this, the first exception
    would end the thread and silently disable health checking for the
    lifetime of the proxy.
    """
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        logger.info('Start heart beat check')
        try:
            proxy_controller.remove_stale_nodes_by_expiration()
        except Exception as e:  # noqa
            logger.error(f'heart beat check failed: {e}')


class NodeManager:
    """Manage all the sub nodes.

    Args:
        config_path (str): the path of the config file. Defaults to
            ``proxy_config.json`` next to this module.
        serving_strategy (str): name of a ``ServingStrategy`` member, e.g.
            ``'Hybrid'``.
        routing_strategy (str): the strategy to dispatch node to handle the requests.
            - **random**: not fully random, but decided by the speed of nodes.
            - **min_expected_latency**: will compute the expected latency to
                process the requests. The sooner of the node, the more requests
                will be dispatched to it.
            - **min_observed_latency**: Based on previous finished requests. The
                sooner they get processed, the more requests will be dispatched
                to.
        migration_protocol (str): name of a ``MigrationProtocol`` member used
            for PD disaggregation.
        link_type (str): name of an ``RDMALinkType`` member.
        with_gdr (bool): forwarded to ``DistServeRDMAConfig``.
        cache_status (bool): whether node statuses are loaded from and
            persisted to ``config_path``.
    """

    def __init__(self,
                 config_path: str | None = None,
                 serving_strategy: str = 'Hybrid',
                 routing_strategy: str = 'min_expected_latency',
                 migration_protocol: str = 'RDMA',
                 link_type: str = 'RoCE',
                 with_gdr: bool = True,
                 cache_status: bool = True) -> None:
        self.nodes = dict()
        self.serving_strategy = ServingStrategy[serving_strategy]
        self.routing_strategy = RoutingStrategy.from_str(routing_strategy)

        self.cache_status = cache_status
        self.latencies = dict()
        self.config_path = osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.json')
        if config_path is not None:
            self.config_path = config_path
        if osp.exists(self.config_path) and self.cache_status:
            with open(self.config_path, 'r') as config_file:
                if os.path.getsize(self.config_path) > 0:
                    logger.info(f'loading node configuration: {self.config_path}')
                    config = json.load(config_file)
                    # each value is a Status serialized with model_dump_json()
                    self.nodes = {
                        node_url: Status.model_validate_json(node_status)
                        for node_url, node_status in config.items()
                    }
        self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True)
        self.heart_beat_thread.start()
        self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT)

        # For PD Disaggregation
        self.migration_protocol = MigrationProtocol[migration_protocol]
        self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type])
        self.pd_connection_pool = PDConnectionPool()
        self.dummy_prefill = False

    def get_nodes(self, role: EngineRole) -> dict[str, Status]:
        """Return a snapshot dict of the nodes whose status has ``role``."""
        items = list(self.nodes.items())
        return {node_url: node_status for (node_url, node_status) in items if node_status.role == role}

    @property
    def hybrid_nodes(self):
        return self.get_nodes(EngineRole.Hybrid)

    @property
    def prefill_nodes(self):
        return self.get_nodes(EngineRole.Prefill)

    @property
    def decode_nodes(self):
        return self.get_nodes(EngineRole.Decode)

    def update_config_file(self):
        """Persist the node table to ``self.config_path``.

        Does nothing unless ``cache_status`` is enabled.
        """
        if not self.cache_status:
            return
        # work on a copy so trimming the latencies does not touch live statuses
        nodes = copy.deepcopy(self.nodes)
        for status in nodes.values():
            status.latency = deque(list(status.latency)[-LATENCY_DEQUE_LEN:])
        with open(self.config_path, 'w') as config_file:  # update cfg yml
            json.dump({node_url: node_status.model_dump_json()
                       for node_url, node_status in nodes.items()},
                      config_file,
                      indent=2)

    def add(self, node_url: str, status: Status | None = None):
        """Add a node to the manager.

        Args:
            node_url (str): A http url. Can be the url generated by
                `lmdeploy serve api_server`.
            status (Status | None): The status of the node, e.g.
                ``Status(models=['internlm-chat-7b'], speed=1)``. The speed can
                be RPM or any other metric, as long as all nodes use the same
                one. If ``status.models`` is non-empty the node is registered
                directly; otherwise its model list is fetched from the node.
        Returns:
            None on success, or the encoded API-timeout error payload if the
            node could not be queried.
        """
        if status is None:
            status = self.nodes.get(node_url, Status())
        if status.models != []:  # force register directly
            self.remove(node_url)
            self.nodes[node_url] = status
            self.update_config_file()
            return
        try:
            from lmdeploy.serve.openai.api_client import APIClient
            client = APIClient(api_server_url=node_url)
            status.models = client.available_models
            self.nodes[node_url] = status
        except requests.exceptions.RequestException as e:  # noqa
            logger.error(f'exception happened when adding node {node_url}, {e}')
            return self.handle_api_timeout(node_url)
        self.update_config_file()

    def remove(self, node_url: str):
        """Remove a node and deregister its PD connections."""
        if node_url in self.nodes.keys():
            self.nodes.pop(node_url)
            self.update_config_file()
            self.pd_connection_pool.dereg_instance(node_url)

    def terminate_node(self, node_url: str):
        """Ask a node to shut down via its ``/terminate`` endpoint and drop it.

        Returns:
            bool: True if the node existed and answered with HTTP 200.
        """
        success = True
        if node_url in self.nodes:
            self.nodes.pop(node_url)
            # keep the PD connection pool consistent with remove()
            self.pd_connection_pool.dereg_instance(node_url)
            headers = {'accept': 'application/json'}
            try:
                response = requests.get(f'{node_url}/terminate', headers=headers)
                if response.status_code != 200:
                    success = False
                    logger.error(f'Failed to terminate node {node_url}, '
                                 f'error_code={response.status_code}, '
                                 f'error_msg={response.text}')
            except Exception as e:  # noqa
                logger.error(f'exception happened when terminating node {node_url}, {e}')
                success = False
        else:
            logger.error(f'terminating node {node_url} failed since it does not exist. '
                         'May try /nodes/status to check the node list')
            success = False
        self.update_config_file()
        return success

    def terminate_all_nodes(self):
        """Terminate all nodes; return True only if every node succeeded."""
        node_url_li = list(self.nodes.keys())
        all_success = True
        for node_url in node_url_li:
            if not self.terminate_node(node_url):
                all_success = False
        return all_success

    def remove_stale_nodes_by_expiration(self):
        """Remove every node whose ``/health`` endpoint fails or is not 200."""
        to_be_deleted = []
        headers = {'accept': 'application/json'}
        for node_url in list(self.nodes.keys()):
            try:
                # a timeout keeps one unresponsive node from hanging the
                # heart-beat thread (and thus every later check) forever
                response = requests.get(f'{node_url}/health',
                                        headers=headers,
                                        timeout=CONTROLLER_HEART_BEAT_EXPIRATION)
                if response.status_code != 200:
                    to_be_deleted.append(node_url)
            except Exception:  # noqa
                to_be_deleted.append(node_url)
        for node_url in to_be_deleted:
            self.remove(node_url)
            logger.info(f'Removed node_url: {node_url} '
                        'due to heart beat expiration')

    @property
    def model_list(self):
        """Supported model list (may contain duplicates across nodes)."""
        model_names = []
        items = list(self.nodes.items())
        for _, status in items:
            model_names.extend(status.models)
        return model_names

    @property
    def status(self):
        """Return the status."""
        return self.nodes

    def get_node_url(self, model_name: str, role: EngineRole = EngineRole.Hybrid):
        """Pick a node serving ``model_name`` according to the routing
        strategy.

        Args:
            model_name (str): the model requested by the client.
            role (EngineRole): only nodes with this role are considered.
        Return:
            A node url, or None when no node of ``role`` serves ``model_name``.
        """
        # Snapshot once: the heart-beat thread may remove nodes concurrently,
        # and re-querying per candidate was both racy and O(n^2).
        nodes = self.get_nodes(role)

        def get_matched_urls():
            urls_with_speeds, speeds, urls_without_speeds = [], [], []
            for node_url, status in nodes.items():
                if model_name in status.models:
                    if status.speed is not None:
                        urls_with_speeds.append(node_url)
                        speeds.append(status.speed)
                    else:
                        urls_without_speeds.append(node_url)
            all_matched_urls = urls_with_speeds + urls_without_speeds
            if len(all_matched_urls) == 0:
                # empty lists (not None) so callers can always unpack the result
                return [], []
            # some nodes does not contain speed
            # we can set them the average speed value
            average_speed = sum(speeds) / len(speeds) if len(speeds) else 1
            all_the_speeds = speeds + [average_speed] * len(urls_without_speeds)
            return all_matched_urls, all_the_speeds

        if self.routing_strategy == RoutingStrategy.RANDOM:
            all_matched_urls, all_the_speeds = get_matched_urls()
            if len(all_matched_urls) == 0:
                return None
            speed_sum = sum(all_the_speeds)
            if speed_sum <= 0:
                # weights cannot be normalized (e.g. every speed is 0): pick uniformly
                return random.choice(all_matched_urls)
            weights = [speed / speed_sum for speed in all_the_speeds]
            index = random.choices(range(len(all_matched_urls)), weights=weights)[0]
            return all_matched_urls[index]
        elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY:
            all_matched_urls, all_the_speeds = get_matched_urls()
            if len(all_matched_urls) == 0:
                return None
            min_latency = float('inf')
            min_index = 0
            # random traverse nodes for low concurrency situation
            all_indexes = list(range(len(all_the_speeds)))
            random.shuffle(all_indexes)
            for index in all_indexes:
                speed = all_the_speeds[index]
                # a zero speed would divide by zero; treat it as "never finishes"
                latency = nodes[all_matched_urls[index]].unfinished / speed if speed else float('inf')
                if min_latency > latency:
                    min_latency = latency
                    min_index = index
            return all_matched_urls[min_index]
        elif self.routing_strategy == RoutingStrategy.MIN_OBSERVED_LATENCY:
            all_matched_urls, latencies = [], []
            for node_url, node_status in nodes.items():
                if model_name in node_status.models:
                    if len(node_status.latency):
                        latencies.append(np.mean(np.array(node_status.latency)))
                    else:
                        latencies.append(float('inf'))
                    all_matched_urls.append(node_url)
            if len(all_matched_urls) == 0:
                return None
            index = int(np.argmin(np.array(latencies)))
            return all_matched_urls[index]
        else:
            raise ValueError(f'Invalid strategy: {self.routing_strategy}')

    async def check_request_model(self, model_name) -> JSONResponse | None:
        """Return None if ``model_name`` is served, else a 404 error response."""
        if model_name in self.model_list:
            return
        ret = create_error_response(HTTPStatus.NOT_FOUND, f'The model {model_name!r} does not exist.')
        return ret

    def handle_unavailable_model(self, model_name):
        """Handle unavailable model.

        Args:
            model_name (str): the model in the request.
        Returns:
            bytes: a newline-terminated JSON error payload.
        """
        logger.warning(f'no model name: {model_name}')
        ret = {
            # use the raw value, as handle_api_timeout does, so json can encode it
            'error_code': ErrorCodes.MODEL_NOT_FOUND.value,
            'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],
        }
        return json.dumps(ret).encode() + b'\n'

    def handle_api_timeout(self, node_url):
        """Handle the api time out; return a newline-terminated JSON payload."""
        logger.warning(f'api timeout: {node_url}')
        ret = {
            'error_code': ErrorCodes.API_TIMEOUT.value,
            'text': err_msg[ErrorCodes.API_TIMEOUT],
        }
        return json.dumps(ret).encode() + b'\n'

    async def stream_generate(self, request: dict, node_url: str, endpoint: str):
        """Return a generator to handle the input request.

        Args:
            request (Dict): the input request.
            node_url (str): the node url.
            endpoint (str): the endpoint. Such as `/v1/chat/completions`.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:
                    async for line in response.content:
                        if line.strip():
                            yield line + b'\n\n'
        # GeneratorExit (consumer closed the stream, e.g. client disconnect) is
        # deliberately not caught: yielding after it raises
        # "async generator ignored GeneratorExit".
        except (Exception, aiohttp.ClientError) as e:  # noqa
            logger.error(f'caught an exception: {e}')
            # report the failure to the client as an API-timeout payload
            yield self.handle_api_timeout(node_url)

    async def generate(self, request: dict, node_url: str, endpoint: str):
        """Return the response of the input request.

        Args:
            request (Dict): the input request.
            node_url (str): the node url.
            endpoint (str): the endpoint. Such as `/v1/chat/completions`.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:
                    return await response.text()
        except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e:  # noqa  # yapf: disable
            logger.error(f'caught an exception: {e}')
            return self.handle_api_timeout(node_url)

    async def forward_raw_request_stream_generate(self, raw_request: Request, node_url: str, endpoint: str):
        """Stream the node's response to ``raw_request`` forwarded verbatim.

        Raises:
            APIServerException: if the node answers with a non-200 status.
        """
        try:
            target_url = node_url.rstrip('/') + endpoint
            headers = self._prepare_headers(raw_request)
            body_bytes = await raw_request.body()
            async with aiohttp.ClientSession() as session:
                async with session.post(target_url, headers=headers, data=body_bytes,
                                        timeout=self.aiotimeout) as response:
                    if response.status != 200:
                        error_body = await response.read()
                        raise APIServerException(status_code=response.status, body=error_body)
                    async for line in response.content:
                        if line.strip():
                            yield line + b'\n\n'
        except APIServerException:
            # raise APIServerException again to be caught by the outer layer
            raise
        # GeneratorExit is deliberately not caught (see stream_generate).
        except (Exception, aiohttp.ClientError) as e:  # noqa
            logger.error(f'caught an exception: {e}')
            yield self.handle_api_timeout(node_url)

    async def forward_raw_request_generate(self, raw_request: Request, node_url: str, endpoint: str):
        """Forward ``raw_request`` verbatim and return the node's response
        text."""
        try:
            target_url = node_url.rstrip('/') + endpoint
            headers = self._prepare_headers(raw_request)
            body_bytes = await raw_request.body()
            async with aiohttp.ClientSession() as session:
                async with session.post(target_url, headers=headers, data=body_bytes,
                                        timeout=self.aiotimeout) as response:
                    return await response.text()
        except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e:  # noqa  # yapf: disable
            logger.error(f'caught an exception: {e}')
            return self.handle_api_timeout(node_url)

    def pre_call(self, node_url):
        """Preprocess before the request get processed.

        Increments the node's ``unfinished`` counter if the node is still
        registered (it may have been removed by the heart-beat thread since it
        was selected), mirroring the guard in ``post_call``.

        Args:
            node_url (str): the node url.
        Returns:
            float: the start time point, ``time.time()``.
        """
        status = self.nodes.get(node_url)
        if status is not None:
            status.unfinished += 1
        return time.time()

    def post_call(self, node_url: str, start: float):
        """Post process after the response finished.

        Args:
            node_url (str): the node url.
            start (float): the start time point. time.time()
        """
        if node_url in self.nodes:
            self.nodes[node_url].unfinished -= 1
            self.nodes[node_url].latency.append(time.time() - start)

    def create_background_tasks(self, url: str, start: float):
        """To create a background task that runs ``post_call``.

        Args:
            url (str): the node url.
            start (float): the start time point. time.time()
        """
        background_tasks = BackgroundTasks()
        background_tasks.add_task(self.post_call, url, start)
        return background_tasks

    def _prepare_headers(self, raw_request: Request) -> dict[str, str]:
        """Build the headers forwarded to a node.

        Copies the client's headers except ``Host`` and any of the
        ``X-Forwarded-*`` headers the proxy sets itself. Names are compared
        case-insensitively, so the client's copy is replaced instead of being
        sent alongside ours. Then it adds ``X-Forwarded-For/Host/Proto``
        describing the original request.
        """
        forwarded = {
            'X-Forwarded-For': raw_request.client.host if raw_request.client else 'unknown',
            'X-Forwarded-Host': raw_request.headers.get('host', ''),
            'X-Forwarded-Proto': raw_request.url.scheme,
        }
        skipped = {'host'} | {name.lower() for name in forwarded}
        headers = {name: value for name, value in raw_request.headers.items() if name.lower() not in skipped}
        headers.update(forwarded)
        return headers


# The proxy application; its interactive API docs are served at '/'.
app = FastAPI(docs_url='/')
# Permissive CORS: any origin, method and header, credentials allowed.
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_methods=['*'],
                   allow_headers=['*'],
                   allow_credentials=True)
# Module-wide node registry shared by all routes below.
node_manager = NodeManager()


@app.get('/v1/models')
def available_models():
    """List the models served by all registered nodes, OpenAI style."""
    cards = [
        ModelCard(id=name, root=name, permission=[ModelPermission()])
        for name in node_manager.model_list
    ]
    return ModelList(data=cards)


@app.get('/nodes/status')
def node_status():
    """Show nodes status.

    Returns the mapping of node url to ``Status``, or ``False`` if it cannot
    be read.
    """
    try:
        return node_manager.status
    except Exception as e:  # noqa
        logger.error(f'failed to get nodes status: {e}')
        return False


@app.post('/nodes/add', dependencies=[Depends(validate_json_request)])
def add_node(node: Node, raw_request: Request = None):
    """Add a node to the manager.

    - **url** (str): A http url. Can be the url generated by
      `lmdeploy serve api_server`.
    - **status** (Dict): The description of the node. An example:
      ``{models: ['internlm-chat-7b'],  speed: 1}``. The speed here can be
      RPM or other metric. All the values of nodes should be the same metric.
    """
    try:
        res = node_manager.add(node.url, node.status)
        if res is not None:
            logger.error(f'add node {node.url} failed, {res}')
            return res
        logger.info(f'add node {node.url} successfully')
        return 'Added successfully'
    except Exception as e:  # noqa
        # previously swallowed silently; keep the response but record the cause
        logger.error(f'add node {node.url} failed with an exception: {e}')
        return 'Failed to add, please check the input url.'


@app.post('/nodes/remove', dependencies=[Depends(validate_json_request)])
def remove_node(node: Node):
    """Remove a node from the manager.

    - **url** (str): url of the node to remove.
    """
    try:
        node_url = node.url
        node_manager.remove(node_url)
        logger.info(f'delete node {node_url} successfully')
        return 'Deleted successfully'
    except Exception as e:  # noqa
        logger.error(f'delete node {node.url} failed: {e}')
        return 'Failed to delete, please check the input url.'


@app.post('/nodes/terminate', dependencies=[Depends(validate_json_request)])
def terminate_node(node: Node):
    """Terminate nodes."""
    try:
        node_url = node.url
        success = node_manager.terminate_node(node_url)
        if not success:
            return f'Failed to terminate node {node_url}'
        return 'Terminated successfully'
    except:  # noqa
        logger.error(f'Terminate node {node_url} failed.')
        return 'Failed to terminate node {node_url}, please check the input url.'


@app.get('/nodes/terminate_all', dependencies=[Depends(validate_json_request)])
def terminate_node_all():
    """Terminate nodes."""
    try:
        success = node_manager.terminate_all_nodes()
        if not success:
            return 'Failed to terminate all nodes'
        return 'All nodes terminated successfully'
    except:  # noqa
        logger.error('Failed to terminate all nodes')
        return 'Failed to terminate all nodes.'


@app.post('/distserve/connection_warmup', dependencies=[Depends(validate_json_request)])
async def connection_warmup():
    await asyncio.gather(*[
        node_manager.pd_connection_pool.connect(
            PDConnectionMessage(
                p_url=p_url,
                d_url=d_url,
                protocol=node_manager.migration_protocol,
                rdma_config=node_manager.rdma_config,
            )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes
    ])
    return JSONResponse({'SUCCESS': True})


@app.post('/distserve/gc', dependencies=[Depends(validate_json_request)])
async def cache_block_gc_to_be_migrated():
    # TODO (JimyMa): add garbage collection of to be migrated request
    raise NotImplementedError


@app.post('/v1/chat/completions', dependencies=[Depends(validate_json_request)])
async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):
    """Completion API similar to OpenAI's API.

    Refer to https://platform.openai.com/docs/api-reference/chat/create
    for the API specification.

    The request should be a JSON object with the following fields:

    - **model**: model name. Available from /v1/models.
    - **messages**: string prompt or chat history in OpenAI format. Chat history
      example: `[{"role": "user", "content": "hi"}]`.
    - **temperature** (float): to modulate the next token probability
    - **top_p** (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
      are kept for generation.
    - **n** (int): How many chat completion choices to generate for each input
      message. **Only support one here**.
    - **stream**: whether to stream the results or not. Default to false.
    - **max_completion_tokens** (int | None): output token nums. Default to None.
    - **max_tokens** (int | None): output token nums. Default to None.
      Deprecated: Use max_completion_tokens instead.
    - **repetition_penalty** (float): The parameter for repetition penalty.
      1.0 means no penalty
    - **stop** (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that's encoded to one token idex.
    - **response_format** (Dict | None): To generate response according to given
      schema. Examples:

      .. code-block:: json

        {
          "type": "json_schema",
          "json_schema":{
            "name": "test",
            "schema":{
              "properties":{
                "name":{"type":"string"}
              },
              "required":["name"],
              "type":"object"
            }
          }
        }

      or
      ``{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}``
    - **logit_bias** (Dict): Bias to logits. Only supported in pytorch engine.
    - **tools** (List): A list of tools the model may call. Currently, only
      internlm2 functions are supported as a tool. Use this to specify a
      list of functions for which the model can generate JSON inputs.
    - **tool_choice** (str | object): Controls which (if any) tool is called by
      the model. `none` means the model will not call any tool and instead
      generates a message. Specifying a particular tool via
      ``{"type": "function", "function": {"name": "my_function"}}``
      forces the model to call that tool. `auto` or `required` will put all
      the tools information to the model.

    Additional arguments supported by LMDeploy:

    - **top_k** (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering
    - **ignore_eos** (bool): indicator for ignoring eos
    - **skip_special_tokens** (bool): Whether or not to remove special tokens
      in the decoding. Default to be True.
    - **min_new_tokens** (int): To generate at least numbers of tokens.
    - **min_p** (float): Minimum token probability, which will be scaled by the
      probability of the most likely token. It must be a value between
      0 and 1. Typical values are in the 0.01-0.2 range, comparably
      selective as setting `top_p` in the 0.99-0.8 range (use the
      opposite of normal `top_p` values)

    Currently we do not support the following features:

    - **presence_penalty** (replaced with repetition_penalty)
    - **frequency_penalty** (replaced with repetition_penalty)
    """
    check_response = await node_manager.check_request_model(request.model)
    if check_response is not None:
        return check_response

    if node_manager.serving_strategy == ServingStrategy.Hybrid:
        node_url = node_manager.get_node_url(request.model)
        if not node_url:
            return node_manager.handle_unavailable_model(request.model)

        logger.info(f'A request is dispatched to {node_url}')
        start = node_manager.pre_call(node_url)
        if request.stream is True:
            response = node_manager.forward_raw_request_stream_generate(raw_request, node_url, '/v1/chat/completions')
            background_task = node_manager.create_background_tasks(node_url, start)
            return ProxyStreamingResponse(response, background=background_task, media_type='text/event-stream')
        else:
            response = await node_manager.forward_raw_request_generate(raw_request, node_url, '/v1/chat/completions')
            node_manager.post_call(node_url, start)
            return JSONResponse(json.loads(response))
    elif node_manager.serving_strategy == ServingStrategy.DistServe:
        request_dict = request.model_dump()

        # Prefill
        prefill_request_dict = copy.deepcopy(request_dict)
        prefill_request_dict['max_tokens'] = 1
        prefill_request_dict['max_completion_tokens'] = 1
        prefill_request_dict['stream'] = False
        prefill_request_dict['with_cache'] = True
        prefill_request_dict['preserve_cache'] = True

        prefill_info = {}
        p_url = 'dummy:dummy'
        if not node_manager.dummy_prefill:
            p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
            if not p_url:
                return node_manager.handle_unavailable_model(request.model)
            logger.info(f'A Prefill request is dispatched to {p_url}')

            start = node_manager.pre_call(p_url)
            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/chat/completions'))
            node_manager.post_call(p_url, start)

        # # Decode
        d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
        if not d_url:
            return node_manager.handle_unavailable_model(request.model)
        logger.info(f'A Decode request is dispatched to {d_url}')

        if not node_manager.dummy_prefill:
            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
                await node_manager.pd_connection_pool.connect(
                    PDConnectionMessage(
                        p_url=p_url,
                        d_url=d_url,
                        protocol=node_manager.migration_protocol,
                        rdma_config=node_manager.rdma_config,
                    ))

        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
        remote_block_ids = prefill_info.get('cache_block_ids') or []
        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0

        request_dict['migration_request'] = MigrationRequest(
            protocol=node_manager.migration_protocol,
            remote_engine_id=p_url,
            remote_session_id=remote_session_id,
            remote_block_ids=remote_block_ids,
            remote_token_id=remote_token_id,
            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')

        start = node_manager.pre_call(d_url)
        if not node_manager.dummy_prefill:
            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
        if request.stream is True:
            response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions')
            background_task = node_manager.create_background_tasks(d_url, start)
            resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')
        else:
            response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')
            node_manager.post_call(d_url, start)
            resp = JSONResponse(json.loads(response))

        if not node_manager.dummy_prefill:
            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])

        return resp

    else:
        raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')


@app.post('/v1/completions', dependencies=[Depends(validate_json_request)])
async def completions_v1(request: CompletionRequest, raw_request: Request = None):
    """Completion API similar to OpenAI's API.

    Go to https://platform.openai.com/docs/api-reference/completions/create
    for the API specification.

    The request should be a JSON object with the following fields:

    - **model** (str): model name. Available from /v1/models.
    - **prompt** (str): the input prompt.
    - **suffix** (str): The suffix that comes after a completion of inserted text.
    - **max_completion_tokens** (int | None): output token nums. Default to None.
    - **max_tokens** (int): output token nums. Default to 16.
      Deprecated: Use max_completion_tokens instead.
    - **temperature** (float): to modulate the next token probability
    - **top_p** (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
      are kept for generation.
    - **n** (int): How many chat completion choices to generate for each input
      message. **Only support one here**.
    - **stream**: whether to stream the results or not. Default to false.
    - **repetition_penalty** (float): The parameter for repetition penalty.
      1.0 means no penalty
    - **user** (str): A unique identifier representing your end-user.
    - **stop** (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that's encoded to one token idex.

    Additional arguments supported by LMDeploy:

    - **ignore_eos** (bool): indicator for ignoring eos
    - **skip_special_tokens** (bool): Whether or not to remove special tokens
      in the decoding. Default to be True.
    - **top_k** (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering

    Currently we do not support the following features:

    - **logprobs** (not supported yet)
    - **presence_penalty** (replaced with repetition_penalty)
    - **frequency_penalty** (replaced with repetition_penalty)
    """
    check_response = await node_manager.check_request_model(request.model)
    if check_response is not None:
        return check_response
    if node_manager.serving_strategy == ServingStrategy.Hybrid:
        node_url = node_manager.get_node_url(request.model)
        if not node_url:
            return node_manager.handle_unavailable_model(request.model)

        logger.info(f'A request is dispatched to {node_url}')
        start = node_manager.pre_call(node_url)
        if request.stream is True:
            response = node_manager.forward_raw_request_stream_generate(raw_request, node_url, '/v1/completions')
            background_task = node_manager.create_background_tasks(node_url, start)
            return ProxyStreamingResponse(response, background=background_task, media_type='text/event-stream')
        else:
            response = await node_manager.forward_raw_request_generate(raw_request, node_url, '/v1/completions')
            node_manager.post_call(node_url, start)
            return JSONResponse(json.loads(response))
    elif node_manager.serving_strategy == ServingStrategy.DistServe:
        request_dict = request.model_dump()

        # Prefill
        prefill_request_dict = copy.deepcopy(request_dict)
        prefill_request_dict['max_tokens'] = 1
        prefill_request_dict['stream'] = False
        prefill_request_dict['with_cache'] = True
        prefill_request_dict['preserve_cache'] = True

        if not node_manager.dummy_prefill:
            try:
                p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
            except Exception as e:
                logger.error(f'error Msg: {str(e)}')
                return {'status': 'Instance sch error, cannot find available p_url'}

            if not p_url:
                return node_manager.handle_unavailable_model(request.model)
            logger.info(f'A Prefill request is dispatched to {p_url}')

            start = node_manager.pre_call(p_url)
            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/completions'))
            node_manager.post_call(p_url, start)
        else:
            p_url = 'dummy:dummy'
            prefill_info = {}

        # Decode
        try:
            d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
        except Exception as e:
            logger.error(f'error Msg: {str(e)}')
            return {'status': 'Instance sch error, cannot find available p_url'}

        if not d_url:
            return node_manager.handle_unavailable_model(request.model)
        logger.info(f'A Decode request is dispatched to {d_url}')

        if not node_manager.dummy_prefill:
            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
                try:
                    await node_manager.pd_connection_pool.connect(
                        PDConnectionMessage(
                            p_url=p_url,
                            d_url=d_url,
                            protocol=node_manager.migration_protocol,
                            rdma_config=node_manager.rdma_config,
                        ))
                except Exception as e:
                    logger.error(f'error Msg: {str(e)}')
                    return {'status': f'Connection error, cannot establish connection {(p_url, d_url)}'}
            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])

        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
        remote_block_ids = prefill_info.get('cache_block_ids') or []
        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0
        request_dict['migration_request'] = MigrationRequest(
            protocol=node_manager.migration_protocol,
            remote_engine_id=p_url,
            remote_session_id=remote_session_id,
            remote_block_ids=remote_block_ids,
            remote_token_id=remote_token_id,
            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')

        start = node_manager.pre_call(d_url)
        if not node_manager.dummy_prefill:
            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
        if request.stream is True:
            response = node_manager.stream_generate(request_dict, d_url, '/v1/completions')
            background_task = node_manager.create_background_tasks(d_url, start)
            resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')
        else:
            response = await node_manager.generate(request_dict, d_url, '/v1/completions')
            node_manager.post_call(d_url, start)
            resp = JSONResponse(json.loads(response))
        if not node_manager.dummy_prefill:
            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info.get('id'))
        return resp
    else:
        raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')


def proxy(server_name: str = '0.0.0.0',
          server_port: int = 8000,
          serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid',
          routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency',
          api_keys: list[str] | str | None = None,
          ssl: bool = False,
          log_level: str = 'INFO',
          disable_cache_status: bool = False,
          link_type: Literal['RoCE', 'IB'] = 'RoCE',
          migration_protocol: Literal['RDMA'] = 'RDMA',
          dummy_prefill: bool = False,
          **kwargs):
    """To launch the proxy server.

    Args:
        server_name (str): the server name of the proxy. Default to '0.0.0.0'.
        server_port (str): the server port. Default to 8000.
        serving_strategy ('Hybrid' | 'DistServe'):  the strategy to serving. Hybrid default.
            DistServe for PD Disaggregation.
        route_strategy ('random' | 'min_expected_latency' | 'min_observed_latency'):
            the strategy to dispatch requests to nodes. Default to
            'min_expected_latency'
        api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
            a single api_key. Default to None, which means no api key applied.
        ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
        log_level (str): Set the log level. Default to INFO.
        disable_cache_status (str): Whether to cache the proxy status to
             proxy_config.yml.
        migration_protocol: migration protocol when PD disaggregation. RDMA default.
    """  # noqa
    node_manager.serving_strategy = ServingStrategy[serving_strategy]
    node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy)
    node_manager.migration_protocol = MigrationProtocol[migration_protocol]
    node_manager.dummy_prefill = dummy_prefill

    node_manager.rdma_config = DistServeRDMAConfig(
        link_type=RDMALinkType[link_type],
        with_gdr=True,
    )
    node_manager.cache_status = not disable_cache_status
    if api_keys is not None and (tokens := [key for key in api_keys if key]):
        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    ssl_keyfile, ssl_certfile = None, None
    if ssl:
        ssl_keyfile = os.environ['SSL_KEYFILE']
        ssl_certfile = os.environ['SSL_CERTFILE']

    logger.setLevel(log_level)
    uvicorn_log_level = os.getenv('UVICORN_LOG_LEVEL', 'info').lower()
    uvicorn.run(app=app,
                host=server_name,
                port=server_port,
                log_level=uvicorn_log_level,
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile)


if __name__ == '__main__':
    import fire

    fire.Fire(proxy)


================================================
FILE: lmdeploy/serve/proxy/streaming_response.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import json

from fastapi.responses import StreamingResponse

from .utils import APIServerException


class ProxyStreamingResponse(StreamingResponse):
    """StreamingResponse that can handle exceptions thrown by the generator."""

    def __init__(self, content, **kwargs):
        super().__init__(content, **kwargs)

    async def stream_response(self, send) -> None:
        iterator = self.body_iterator.__aiter__()
        try:
            # get the first chunk(stream_generate's first yield)
            first_chunk = await iterator.__anext__()

        except APIServerException as e:
            headers = self._convert_headers_to_asgi(e.headers) if e.headers else self.raw_headers
            await send({'type': 'http.response.start', 'status': e.status_code, 'headers': headers})
            await send({
                'type': 'http.response.body',
                'body': e.body,
                'more_body': False,
            })
            return

        # normal case, send the header first
        await send({
            'type': 'http.response.start',
            'status': self.status_code,
            'headers': self.raw_headers,
        })

        # send body with the first chunk
        await send({
            'type': 'http.response.body',
            'body': first_chunk,
            'more_body': True,
        })

        # continue streaming output
        try:
            async for chunk in iterator:
                await send({
                    'type': 'http.response.body',
                    'body': chunk,
                    'more_body': True,
                })
        except Exception:
            error_data = {'error': True, 'status': 500, 'message': 'Internal streaming error'}
            await send({
                'type': 'http.response.body',
                'body': json.dumps(error_data).encode('utf-8'),
                'more_body': False,
            })
            return

        await send({
            'type': 'http.response.body',
            'body': b'',
            'more_body': False,
        })

    def _convert_headers_to_asgi(self, headers: dict) -> list[tuple[bytes, bytes]]:
        """Convert dict headers to ASGI raw header tuples."""
        return [(name.lower().encode('latin-1'), str(value).encode('latin-1')) for name, value in headers.items()]


================================================
FILE: lmdeploy/serve/proxy/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import enum
import os

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

LATENCY_DEQUE_LEN = 15
AIOHTTP_TIMEOUT = os.getenv('AIOHTTP_TIMEOUT', None)
if AIOHTTP_TIMEOUT is not None:
    AIOHTTP_TIMEOUT = int(AIOHTTP_TIMEOUT)
logger.info(f'AIOHTTP_TIMEOUT set to {AIOHTTP_TIMEOUT}. It can be modified before launching the proxy server '
            'through env variable AIOHTTP_TIMEOUT')


class RoutingStrategy(enum.Enum):
    """Strategy to dispatch requests to nodes."""
    RANDOM = enum.auto()
    MIN_EXPECTED_LATENCY = enum.auto()
    MIN_OBSERVED_LATENCY = enum.auto()

    @classmethod
    def from_str(cls, name):
        """Get strategy from string."""
        if name == 'random':
            return cls.RANDOM
        elif name == 'min_expected_latency':
            return cls.MIN_EXPECTED_LATENCY
        elif name == 'min_observed_latency':
            return cls.MIN_OBSERVED_LATENCY
        else:
            raise ValueError(f'Invalid strategy: {name}. Supported: random, '
                             f'min_expected_latency, min_observed_latency.')


class ErrorCodes(enum.Enum):
    """Error codes."""
    MODEL_NOT_FOUND = 10400
    SERVICE_UNAVAILABLE = 10401
    API_TIMEOUT = 10402


err_msg = {
    ErrorCodes.MODEL_NOT_FOUND: 'The request model name does not exist in the model list.',
    ErrorCodes.SERVICE_UNAVAILABLE: 'The service is unavailable now. May retry later.',
    ErrorCodes.API_TIMEOUT: 'Failed to get response after a period of time'
}


class APIServerException(Exception):

    def __init__(self, status_code: int, body: bytes, headers: dict | None = None):
        self.status_code = status_code
        self.body = body
        self.headers = headers or {}
        if 'content-type' not in self.headers:
            self.headers['content-type'] = 'application/json'


================================================
FILE: lmdeploy/serve/utils/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/serve/utils/server_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/server_utils.py
import hashlib
import secrets
from collections.abc import Awaitable

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.datastructures import URL, Headers
from starlette.types import ASGIApp, Receive, Scope, Send


def validate_json_request(raw_request: Request):
    content_type = raw_request.headers.get('content-type', '').lower()
    media_type = content_type.split(';', maxsplit=1)[0]
    if media_type != 'application/json':
        raise RequestValidationError(errors=["Unsupported Media Type: Only 'application/json' is allowed"])


class AuthenticationMiddleware:
    """Pure ASGI middleware that authenticates each request by checking if the
    Authorization Bearer token exists and equals anyof "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        self.api_tokens = [hashlib.sha256(t.encode('utf-8')).digest() for t in tokens]
        # Path prefixes that bypass authentication
        self.skip_prefixes = [
            '/health',  # Health check endpoints
            '/docs',  # Swagger UI documentation
            '/redoc',  # ReDoc documentation
            '/nodes',  # Endpoints about node operation between proxy and api_server
        ]

    def verify_token(self, headers: Headers) -> bool:
        authorization_header_value = headers.get('Authorization')
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(' ')
        if scheme.lower() != 'bearer':
            return False

        param_hash = hashlib.sha256(param.encode('utf-8')).digest()

        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope['type'] not in ('http', 'websocket'):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        if scope['type'] == 'http' and scope['method'] == 'OPTIONS':
            return self.app(scope, receive, send)

        root_path = scope.get('root_path', '')
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if not any(url_path.startswith(path) for path in self.skip_prefixes) and not self.verify_token(headers):
            response = JSONResponse(content={'error': 'Unauthorized'}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


================================================
FILE: lmdeploy/tokenizer.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from collections import deque
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union

from lmdeploy.utils import get_logger

# this file will be copied to triton server, make sure all
# importing are starting from the package root lmdeploy


@dataclass
class DetokenizeState:
    """A state collection of incrementally detekenization.

    Args:
        ids_offset (int): offset to all input ids. In LMDeploy, the output
            ids length is not one by one. It could be random by random.
        prev_tokens (List[str] | None): for incrementally decoding.
            Default to None, which means the first round.
        prefix_offset (int): the start index of tokens to be converted to
            string (prev + new tokens). Default to 0 for the first round.
        read_offset (int): the end index of tokens to be converted to
            string (prev token). Default to 0 for the first round.
    """
    ids_offset: int = 0
    prev_tokens: Optional[List[str]] = None
    prefix_offset: int = 0
    read_offset: int = 0

    def as_tuple(self) -> Tuple:
        """Return a tuple of states."""
        return (self.ids_offset, self.prev_tokens, self.prefix_offset, self.read_offset)


class HuggingFaceTokenizer:
    """A wrapper of transformers' AutoTokenizer.

    Args:
        model_dir (str): the directory of the tokenizer model
    """

    def __init__(self, model_dir: str):
        self._check_transformers_version(model_dir)
        from transformers import AutoTokenizer
        self.logger = get_logger('lmdeploy')
        self.model = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self._prefix_space_tokens = None

        if self.model.eos_token_id is None:
            # Fall back to `generation_config.json` first, then to the
            # `eod_id` attribute of Qwen's remote tokenizer. A generation
            # config that lacks `eos_token_id` no longer raises KeyError; it
            # falls through to the `eod_id` fallback instead.
            eos_token_id = None
            generation_config_file = osp.join(model_dir, 'generation_config.json')
            if osp.exists(generation_config_file):
                with open(generation_config_file, 'r', encoding='utf-8') as f:
                    eos_token_id = json.load(f).get('eos_token_id')
            if eos_token_id is None and hasattr(self.model, 'eod_id'):  # Qwen remote
                eos_token_id = self.model.eod_id
            if eos_token_id is not None:
                self.model.eos_token_id = eos_token_id

        # for stop words
        self._vocab_size_with_added: int = None
        self._maybe_decode_bytes: bool = None
        # TODO maybe lack a constant.py
        # cache of the latest stop-word lookups as (token, indexes) pairs
        self._indexes_tokens_deque = deque(maxlen=10)
        self.max_indexes_num = 5
        # token text -> token id, built lazily by `indexes_containing_token`
        self.token2id = {}

    def _check_transformers_version(self, model_dir: str):
        """Warn when the installed transformers is older than the version
        recorded in the model config (``transformers_version``, or
        ``llm_config.transformers_version`` as a fallback)."""
        import transformers
        from packaging import version

        from lmdeploy.archs import get_model_arch

        logger = get_logger('lmdeploy')

        current_transformers_version = version.parse(transformers.__version__)
        cfg = get_model_arch(model_dir)[1]
        cfg_ver = getattr(cfg, 'transformers_version', None)
        if cfg_ver is None:
            llm_config = getattr(cfg, 'llm_config', None)
            if llm_config:
                cfg_ver = getattr(llm_config, 'transformers_version', None)
        if cfg_ver is None:
            return
        required_transformers_version = version.parse(cfg_ver)
        if current_transformers_version < required_transformers_version:
            logger.warning(
                f'The current version of `transformers` is transformers=={current_transformers_version}, '  # noqa: E501
                f'which is lower than the required version transformers=={required_transformers_version}. '  # noqa: E501
                'Please upgrade to the required version.')

    def get_vocab(self):
        """Get vocab."""
        return self.model.get_vocab()

    @property
    def vocab_size(self):
        """Vocabulary size."""
        return self.model.vocab_size

    @property
    def vocab_size_with_added(self):
        """Vocabulary size with added vocab (computed once, then cached)."""
        if self._vocab_size_with_added is not None:
            return self._vocab_size_with_added
        self._vocab_size_with_added = len(self.model.get_vocab())
        return self._vocab_size_with_added

    @property
    def bos_token_id(self):
        """Begin of the sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """End of the sentence token id."""
        return self.model.eos_token_id

    @property
    def prefix_space_tokens(self):
        """Ids of tokens that start with a space marker ('▁' for str tokens,
        b' ' for bytes tokens); computed once, then cached."""
        if self._prefix_space_tokens is None:
            vocab = self.model.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._prefix_space_tokens = {
                i
                for i, tok in enumerate(vocab) if tok.startswith('▁' if isinstance(tok, str) else b' ')
            }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(self, tokens: List[int], decoded: str):
        """Maybe add prefix space for incremental decoding."""
        if len(tokens) and not decoded.startswith(' ') and\
                tokens[0] in self.prefix_space_tokens:
            return ' ' + decoded
        else:
            return decoded

    @property
    def maybe_decode_bytes(self):
        """Whether self.model.convert_ids_to_tokens returns any non-str value
        (computed once, then cached)."""
        if self._maybe_decode_bytes is None:
            vocab = self.model.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._maybe_decode_bytes = any(not isinstance(tok, str) for tok in vocab)
        return self._maybe_decode_bytes

    def indexes_containing_token(self, token: str):
        """Return all the possible indexes, whose decoding output may contain
        the input token.

        Args:
            token (str): the stop word to look up.
        Returns:
            list[int]: candidate token ids; empty when the token cannot be
            used as a stop word.
        """
        # traversing vocab is time consuming, can not be accelerated with
        # multi threads (computation) or multi process (can't pickle tokenizer)
        # so, we maintain latest 10 stop words and return directly if matched
        query = token
        for _token, _indexes in self._indexes_tokens_deque:
            if query == _token:
                return _indexes

        if not self.token2id:
            # decode is slower than convert_ids_to_tokens
            if self.maybe_decode_bytes:
                for i in range(self.vocab_size):
                    try:
                        self.token2id[self.model.decode(i)] = i
                    except Exception:
                        # some tokens just can't be decoded by `decode`;
                        # best effort: skip them
                        pass
            else:
                self.token2id = {self.model.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        if token == ' ':  # ' ' is special
            token = '▁'
        indexes = [i for _token, i in self.token2id.items() if token in _token]
        if len(indexes) > self.max_indexes_num:
            # multiple id decode to same token
            indexes = [i for i in indexes if self.decode([i]) == token]
            indexes = indexes[:self.max_indexes_num]
            self.logger.warning(f'There are too many(>{self.max_indexes_num}) possible '
                                f'indexes may decoding {token}, we will use {indexes} only')
        # there might be token id that exceeds self.vocab_size
        if len(indexes) == 0:
            indexes = self.encode(token, False)
            if len(indexes) != 1:
                self.logger.warning(f'The token {token}, its length of indexes {indexes} is '
                                    'not 1. Currently, it can not be used as stop words')
                indexes = []
        # cache under the caller's original token: the lookup at the top
        # compares against `query`, so storing the '▁'-substituted form would
        # make every repeated ' ' query miss the cache
        self._indexes_tokens_deque.append((query, indexes))
        return indexes

    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
            add_bos (bool): Whether to add `bos` token id when encoding
                the prompt
            add_special_tokens (bool): Whether or not to add special tokens
                when encoding the prompt
        Returns:
            list[int]: token ids
        """
        encoded = self.model.encode(s, add_special_tokens=add_special_tokens, **kwargs)
        if not add_bos:
            # in the middle of a session
            if len(encoded) and encoded[0] == self.bos_token_id:
                encoded = encoded[1:]
        return encoded

    def decode(self, t: Sequence[int], offset: Optional[int] = None, skip_special_tokens: bool = True):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incrementally decoding. Default to None, which
                means not applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding.
        Returns:
            str: text of decoding tokens
        """
        t = t[offset:]
        out_string = self.model.decode(t, skip_special_tokens=skip_special_tokens)
        if offset:
            self.logger.warning('For incrementally detokenization, please try '
                                'detokenize_incrementally function instead.')
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    @staticmethod
    def _convert_tokens_to_string_with_added_encoders(
        tokenizer,
        output_tokens: List[str],
        skip_special_tokens: bool,
        spaces_between_special_tokens: bool,
    ) -> str:
        """Convert tokens to text, keeping added-vocab tokens verbatim for
        slow tokenizers that have an added vocab."""
        if tokenizer.is_fast or not tokenizer.get_added_vocab():
            return tokenizer.convert_tokens_to_string(output_tokens)
        # Adapted from
        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
        sub_texts = []
        current_sub_text = []
        all_special_tokens = set(tokenizer.all_special_tokens)
        # fetch the added vocab once instead of once per token
        added_vocab = tokenizer.get_added_vocab()
        for token in output_tokens:
            if skip_special_tokens and token in all_special_tokens:
                continue
            if token in added_vocab:
                if current_sub_text:
                    sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                    sub_texts.append(sub_text)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
            sub_texts.append(sub_text)
        if spaces_between_special_tokens:
            return ' '.join(sub_texts)
        else:
            return ''.join(sub_texts)

    # Based on
    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states. It is not modified; a new
                state is returned.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.
        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        tokenizer = self.model
        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
        # convert only the ids that arrived since the previous round
        new_tokens = tokenizer.convert_ids_to_tokens(all_input_ids[ids_offset:],
                                                     skip_special_tokens=skip_special_tokens)
        # `convert_ids_to_tokens` returns None for out-of-range token_id
        new_tokens = new_tokens or []
        new_tokens = [x for x in new_tokens if x is not None] if None in new_tokens else new_tokens
        if prev_tokens is None:
            # This is the first iteration for this sequence.
            # Please notice that in VLLM, indexes are detokenized one by one
            # while in LMDeploy, every turn, the detokenized indexes length
            # can be different.
            prev_tokens = tokenizer.convert_ids_to_tokens(all_input_ids[:ids_offset],
                                                          skip_special_tokens=skip_special_tokens)
            # `convert_ids_to_tokens` returns None for out-of-range token_id
            prev_tokens = prev_tokens or []
            prev_tokens = [x for x in prev_tokens if x is not None] if None in prev_tokens else prev_tokens
            read_offset = len(prev_tokens)
            # NOTE(review): `new_tokens[0]` is a token string while
            # `all_special_ids` presumably holds integer ids, so this test
            # looks like it never matches. Kept unchanged on purpose: with
            # skip_special_tokens=True the special ids are presumably already
            # dropped by `convert_ids_to_tokens` above.
            if skip_special_tokens and new_tokens and new_tokens[0] in tokenizer.all_special_ids:
                read_offset = read_offset + 1  # skip special token

        # Build a new list instead of extending `prev_tokens` in place, so the
        # caller's `state.prev_tokens` is not mutated as a side effect.
        output_tokens = prev_tokens + new_tokens
        prefix_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

        # update state and get final decoded output
        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
            # utf-8 char at the end means it's a potential unfinished byte
            # sequence from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            prefix_offset = read_offset
            read_offset = len(output_tokens)
            new_text = new_text[len(prefix_text):]
        else:
            new_text = ''

        return new_text, DetokenizeState(len(all_input_ids), output_tokens, prefix_offset, read_offset)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts without adding special tokens.

        Args:
            s (str): prompts
        Returns:
            list[int]: token ids
        """
        add_special_tokens = False
        return self.model(s, add_special_tokens=add_special_tokens)


class ChatGLM4Tokenizer(HuggingFaceTokenizer):
    """Tokenizer wrapper for GLM-4 models."""

    def __init__(self, model_path):
        super().__init__(model_path)
        wrapped_pad = self.model._pad

        def _pad_without_padding_side(*args, **kwargs):
            # fix for transformers>4.45.0: strip the `padding_side` kwarg
            # before delegating to the tokenizer's own `_pad`
            kwargs.pop('padding_side', None)
            return wrapped_pad(*args, **kwargs)

        self.model._pad = _pad_without_padding_side

    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):
        """Tokenize a prompt; ``add_special_tokens`` is always forced to
        False, whatever the caller passes."""
        # The ChatGLM4Tokenizer hardcodes `add_special_tokens=False` when
        # tokenizing a prompt. Refer to https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/tokenization_chatglm.py#L227 # noqa E501
        return super().encode(s, add_bos, add_special_tokens=False, **kwargs)


class ChatGLMTokenizer(HuggingFaceTokenizer):
    """Tokenizer wrapper for GLM-2 models."""

    def __init__(self, model_path):
        super().__init__(model_path)
        wrapped_pad = self.model._pad

        def _pad_without_padding_side(*args, **kwargs):
            # fix for transformers>4.45.0: strip the `padding_side` kwarg
            # before delegating to the tokenizer's own `_pad`
            kwargs.pop('padding_side', None)
            return wrapped_pad(*args, **kwargs)

        self.model._pad = _pad_without_padding_side


class GptOssTokenizer(HuggingFaceTokenizer):
    """Tokenizer of GPT-OSS; streamed text is produced by an openai_harmony
    ``StreamableParser``."""

    def __init__(self, model_dir: str):
        super().__init__(model_dir)
        from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding
        harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.role = Role.ASSISTANT
        # factory: each call builds a new assistant-role parser
        self.parser = partial(StreamableParser, harmony_encoding, role=Role.ASSISTANT)

    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Feed the ids added since the last round to the per-request parser.

        A parser is lazily attached to ``state`` as ``state.stream`` on the
        first call. Only content deltas from the ``final`` and ``analysis``
        channels of the assistant role are returned. ``state`` is updated in
        place and also returned. ``skip_special_tokens`` and
        ``spaces_between_special_tokens`` are accepted for interface
        compatibility but not used.
        """
        if not hasattr(state, 'stream'):
            state.stream = self.parser()
        parser = state.stream

        pieces = []
        for token_id in all_input_ids[state.ids_offset:]:
            parser.process(token_id)
            visible = parser.current_channel in ('final', 'analysis') and parser.current_role == self.role
            if visible:
                pieces.append(parser.last_content_delta or '')

        state.ids_offset = len(all_input_ids)
        return ''.join(pieces), state


class Tokenizer:
    """Tokenize prompts or de-tokenize tokens into texts.

    Dispatches to ChatGLM4Tokenizer / ChatGLMTokenizer (selected by the
    `tokenizer_class` in the tokenizer config), GptOssTokenizer (model_type
    'gpt_oss'), or the generic HuggingFaceTokenizer.

    Args:
        model_path (str): the path of the tokenizer model
    """

    def __init__(self, model_path: str):
        from transformers import AutoConfig, PretrainedConfig
        try:
            model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        except Exception:
            # architectures unknown to AutoConfig: read the raw config
            model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
        is_gpt_oss = getattr(model_cfg, 'model_type', '') == 'gpt_oss'
        from transformers.models.auto.tokenization_auto import get_tokenizer_config
        tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True)
        config_tokenizer_class = tokenizer_config.get('tokenizer_class')
        if config_tokenizer_class == 'ChatGLM4Tokenizer':
            self.model = ChatGLM4Tokenizer(model_path)
        elif config_tokenizer_class == 'ChatGLMTokenizer':
            self.model = ChatGLMTokenizer(model_path)
        elif is_gpt_oss:
            self.model = GptOssTokenizer(model_path)
        else:
            self.model = HuggingFaceTokenizer(model_path)
        self.logger = get_logger('lmdeploy')

    @property
    def vocab_size(self):
        """Vocabulary size."""
        return self.model.vocab_size

    @property
    def bos_token_id(self):
        """Begin of the sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """End of the sentence token id."""
        return self.model.eos_token_id

    def get_vocab(self):
        """Get vocab."""
        return self.model.get_vocab()

    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):
        """Tokenize a prompt.

        If the result starts with two bos tokens, one of them is dropped
        and a warning is logged.

        Args:
            s (str): a prompt
            add_bos (bool): Whether to add `bos` token id when encoding
                the prompt
            add_special_tokens (bool): Whether or not to add special tokens
                when encoding the prompt
        Returns:
            list[int]: token ids
        """
        encoded = self.model.encode(s, add_bos, add_special_tokens, **kwargs)
        if encoded[:2] == [self.bos_token_id] * 2:
            self.logger.warning(f'Detected duplicate bos token {self.bos_token_id} in prompt, '
                                'this will likely reduce response quality, one of them will be '
                                'removed')
            encoded = encoded[1:]
        return encoded

    def decode(
        self,
        t: Sequence[int],
        offset: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incrementally decoding. Default to None, which
                means not applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding.
        Returns:
            str: text of decoding tokens
        """
        return self.model.decode(t, offset, skip_special_tokens)

    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.
        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        return self.model.detokenize_incrementally(all_input_ids,
                                                   state=state,
                                                   skip_special_tokens=skip_special_tokens,
                                                   spaces_between_special_tokens=spaces_between_special_tokens)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str): prompts
        Returns:
            list[int]: token ids
        """
        return self.model(s)

    def indexes_containing_token(self, token):
        """Return all the possible indexes, whose decoding output may contain
        the input token.

        Returns an empty list when ``token`` itself encodes to more than one
        id, since such tokens cannot be used as stop words.
        """
        encoded = self.encode(token, add_bos=False)
        if len(encoded) > 1:
            self.logger.warning(f'The token {token}, its length of indexes {encoded} is over '
                                'than 1. Currently, it can not be used as stop words')
            return []
        return self.model.indexes_containing_token(token)


================================================
FILE: lmdeploy/turbomind/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


def bootstrap():
    """Prepare the environment before the turbomind extension is imported.

    On Windows, if a ``lib`` directory exists next to this package (used
    here as the signal that turbomind is present), ``$CUDA_PATH/bin`` is
    registered with ``os.add_dll_directory`` so that DLLs in it can be
    found when the extension is loaded. On other platforms this does
    nothing.

    Raises:
        RuntimeError: on Windows with turbomind present when the
            ``CUDA_PATH`` environment variable is not set.
    """
    import os
    import sys

    pwd = os.path.dirname(__file__)
    has_turbomind = os.path.exists(os.path.join(pwd, '..', 'lib'))
    if os.name == 'nt' and has_turbomind:
        # os.add_dll_directory is only available on Python 3.8+
        if sys.version_info[:2] >= (3, 8):
            CUDA_PATH = os.getenv('CUDA_PATH')
            # explicit check instead of `assert`: `python -O` strips asserts,
            # which would surface as a confusing TypeError from os.path.join
            if CUDA_PATH is None:
                raise RuntimeError('Can not find $env:CUDA_PATH')
            dll_path = os.path.join(CUDA_PATH, 'bin')
            print(f'Add dll path {dll_path}, please note cuda version '
                  'should >= 11.3 when compiled with cuda 11')
            os.add_dll_directory(dll_path)


bootstrap()

from .turbomind import TurboMind, update_parallel_config  # noqa: E402

__all__ = ['TurboMind', 'update_parallel_config']


================================================
FILE: lmdeploy/turbomind/deploy/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/turbomind/deploy/config.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import json
from dataclasses import asdict, field, fields
from typing import List

# use pydantic.dataclasses.dataclass to check data type
from pydantic.dataclasses import dataclass

from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def config_from_dict(cls, env):
    """Initiate an instance of a config class from a dict.

    Only keys of ``env`` that are parameters of ``cls`` and whose values are
    not ``None`` are passed on. ``None`` values are also stripped
    recursively from nested dicts, so the class's own defaults apply to
    them. ``env`` and the dicts nested in it are left unmodified.

    Args:
        cls: the config class to instantiate.
        env (dict): candidate keyword arguments.
    Returns:
        An instance of ``cls``.
    """
    params = inspect.signature(cls).parameters

    def _remove_none(d: dict) -> dict:
        # Build fresh dicts rather than writing filtered copies back into
        # `d`: nested dicts belong to the caller and must not be mutated.
        return {k: _remove_none(v) if isinstance(v, dict) else v for k, v in d.items() if v is not None}

    used = {k: v for k, v in env.items() if k in params and v is not None}
    return cls(**_remove_none(used))


def config_to_dict(config):
    """Export a ModelConfig, AttentionConfig or LoraConfig to a plain dict.

    A falsy ``config`` (e.g. ``None``) yields an empty dict.
    """
    if config:
        assert isinstance(config, (ModelConfig, AttentionConfig, LoraConfig)), \
            f'A dataclass is expected, but got {type(config)}'
        return asdict(config)
    return {}


@dataclass
class ModelConfig:
    """Model-structure settings of a turbomind model.

    Decorated with pydantic's ``dataclass`` (imported at the top of this
    module) so that field data types are checked. Fields whose default is
    ``None`` must be filled in before :meth:`verify` passes.
    """
    model_name: str = ''
    chat_template: str = ''
    model_arch: str = None
    head_num: int = None
    kv_head_num: int = None
    hidden_units: int = None
    vocab_size: int = None
    # Turbomind used to assume token_embedding and lm_head has the same size
    # at vocab dim, i.e. `vocab_size`
    # But in molmo, embedding.shape is [vocab_size + 128, hidden_units]
    # while lm_head shape is [hidden_units, vocab_size].
    # Therefore, we add a new attr "embedding_size" to represent the vocab dim
    # of token_embedding
    embedding_size: int = 0
    num_layer: int = None
    inter_size: List[int] = None
    norm_eps: float = None
    attn_bias: int = 0
    mlp_bias: bool = False
    window_size: List[int] = field(default_factory=list)
    attn_sink: bool = False
    qk_norm: bool = False
    size_per_head: int = 128
    group_size: int = 32
    data_type: str = None
    weight_type: str = None
    expert_weight_type: str = None
    ffn_weight_type: str = None
    session_len: int = None
    attn_tp_size: int = 1
    attn_cp_size: int = 1
    mlp_tp_size: int = 1
    model_format: str = 'hf'
    # NOTE(review): default is an empty tuple although annotated List[int]
    expert_num: List[int] = ()
    expert_router_bias: bool = False
    expert_inter_size: int = 0
    experts_per_token: int = 0
    activation_type: str = ''
    moe_shared_gate: bool = False
    norm_topk_prob: bool = False
    routed_scale: float = 1.0
    topk_group: int = 1
    topk_method: str = 'greedy'
    moe_group_num: int = 1
    scoring_func: str = 'softmax'
    router_n_groups: int = -1
    # MLA
    q_lora_rank: int = 0
    kv_lora_rank: int = 0
    qk_rope_dim: int = 0
    v_head_dim: int = 0
    # Qwen 3.5
    layer_types: List[str] = field(default_factory=list)
    linear_key_head_dim: int = 0
    linear_value_head_dim: int = 0
    linear_conv_kernel_dim: int = 0
    linear_num_key_heads: int = 0
    linear_num_value_heads: int = 0
    attn_output_gate: bool = False
    # Per-layer expert weight type override: layer indices whose
    # MoE experts are unquantized (fp16) despite expert_weight_type=int4.
    # Populated from modules_to_not_convert patterns like 'model.layers.0.'.
    unquantized_expert_layers: List[int] = field(default_factory=list)
    # tuning
    tune_layer_num: int = 1

    def verify(self) -> None:
        """Assert that no attribute of this config is ``None``.

        Raises:
            AssertionError: listing every attribute whose value is ``None``.
        """
        invalid = {}
        for k, v in self.__dict__.items():
            if v is None:
                invalid[k] = v
        assert not invalid, f'incomplete model config: {invalid}'


@dataclass
class RopeParam:
    """Rotary position embedding (RoPE) parameters.

    ``type``, ``base`` and ``dim`` have no defaults and must be supplied.
    Within this module, ``type`` is set to ``'dynamic'`` for
    ``--rope-scaling-factor`` and is otherwise taken from an
    ``hf_overrides['rope_scaling']`` dict (e.g. ``'yarn'``). The meaning of
    the remaining fields presumably follows the HF ``rope_scaling``
    options of the same names — confirm against the consuming code.
    """
    type: str
    base: float
    dim: int
    factor: float = 1.0
    max_position_embeddings: int = None
    attention_factor: float = 1.0
    beta_fast: float = 32
    beta_slow: float = 1
    low_freq_factor: float = None
    high_freq_factor: float = None
    original_max_position_embeddings: int = None
    mrope_section: List[int] = None


@dataclass
class AttentionConfig:
    """Attention settings of a turbomind model.

    ``rope_param`` stays ``None`` unless RoPE parameters are supplied (or
    created by ``TurbomindModelConfig.update_from_engine_config``).
    """
    # NOTE(review): a 0 default presumably means "derive from the model" — confirm
    softmax_scale: float = 0
    cache_block_seq_len: int = 64
    use_logn_attn: int = 0
    max_position_embeddings: int = 0
    rope_param: RopeParam = None


@dataclass
class LoraConfig:
    """LoRA settings of a turbomind model.

    NOTE(review): the empty/zero defaults presumably mean LoRA is disabled
    — confirm against the consuming code.
    """
    lora_policy: str = ''
    lora_r: int = 0
    lora_scale: float = 0.0
    lora_max_wo_r: int = 0
    lora_rank_pattern: str = ''
    lora_scale_pattern: str = ''


@dataclass
class TurbomindModelConfig:
    """Config for turbomind model: model, attention and LoRA sub-configs."""
    model_config: ModelConfig = None
    attention_config: AttentionConfig = None
    lora_config: LoraConfig = None

    def update_from_engine_config(self, config: TurbomindEngineConfig):
        """Update the attributes of this instance with the attributes from
        TurbomindEngineConfig.

        Truthy engine-config values overwrite same-named attributes of
        ``model_config`` and ``attention_config``. Afterwards
        ``hf_overrides['rope_scaling']`` and ``rope_scaling_factor`` (if set)
        rebuild ``attention_config.rope_param``.

        Args:
            config (TurbomindEngineConfig): The turbomind engine config
        """
        if config is None:
            return
        for key, value in asdict(config).items():
            # falsy values (None, 0, False, '', empty containers) are treated
            # as "not set" and never override the model's own settings
            if not value:
                continue

            if hasattr(self.model_config, key):
                setattr(self.model_config, key, value)
            if hasattr(self.attention_config, key):
                setattr(self.attention_config, key, value)

        # update from hf_overrides
        hf_overrides = getattr(config, 'hf_overrides', None)
        if hf_overrides:
            override_params = hf_overrides.get('rope_scaling')
            if override_params:
                rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
                # HF configs name the rope type `rope_type`; older configs use
                # the legacy key `type`, so accept either.
                rope_param.type = override_params.get('rope_type', override_params.get('type', ''))
                if rope_param.type == 'yarn' and 'original_max_position_embeddings' in override_params:
                    original_max = override_params['original_max_position_embeddings']
                    rope_param.factor = self.attention_config.max_position_embeddings / original_max
                    rope_param.max_position_embeddings = original_max
                else:
                    rope_param.factor = override_params.get('factor', 1.0)
                    rope_param.max_position_embeddings = override_params.get('original_max_position_embeddings', None)

                self.attention_config.rope_param = rope_param
            logger.warning(f'Overriding HF config with {hf_overrides}')

        # use dynamic ntk
        if config.rope_scaling_factor:
            # some ut will create empty RopeParam, will check base/dim in src code
            rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
            rope_param.type = 'dynamic'
            rope_param.factor = config.rope_scaling_factor
            rope_param.max_position_embeddings = self.attention_config.max_position_embeddings

            self.attention_config.rope_param = rope_param
            logger.warning(
                '`--rope-scaling-factor` will be removed in a future release. Please instead use `--hf-overrides`.')

    @classmethod
    def from_dict(cls, config: dict = None):
        """Construct TurbomindModelConfig instance from config in a dict.

        Args:
            config (dict): may hold ``model_config``, ``attention_config``
                and ``lora_config`` sub-dicts; missing or ``None`` entries
                are treated as empty. Defaults to an empty dict.
        """
        # `None` default instead of a shared mutable `{}` default
        config = config or {}
        sub_configs = {f.name: config.get(f.name) or {} for f in fields(cls)}

        return cls(model_config=config_from_dict(ModelConfig, sub_configs['model_config']),
                   attention_config=config_from_dict(AttentionConfig, sub_configs['attention_config']),
                   lora_config=config_from_dict(LoraConfig, sub_configs['lora_config']))

    def to_dict(self):
        """Export to a dict."""
        return dict(model_config=config_to_dict(self.model_config),
                    attention_config=config_to_dict(self.attention_config),
                    lora_config=config_to_dict(self.lora_config))

    @property
    def session_len(self):
        return self.model_config.session_len

    @property
    def weight_type(self):
        return self.model_config.weight_type

    @property
    def group_size(self):
        return self.model_config.group_size

    @property
    def vocab_size(self):
        return self.model_config.vocab_size

    def __str__(self):
        return json.dumps(self.to_dict(), indent=2)


================================================
FILE: lmdeploy/turbomind/deploy/converter.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.utils import get_logger

from ...utils import _get_and_verify_max_len, is_bf16_supported
from ..supported_models import SUPPORTED_ARCHS
from .config import TurbomindModelConfig
from .module import Transformer
from .policy import get_input_policy
from .source_model.base import INPUT_MODELS
from .target_model.base import OUTPUT_MODELS, BaseOutputModel

SUPPORTED_FORMATS = ['hf', 'awq', 'gptq', 'fp8', None]
logger = get_logger('lmdeploy')


def get_input_model_registered_name(model_path: str, model_format: str):
    """Return the key under which the input model is stored in INPUT_MODELS.

    The key is looked up in SUPPORTED_ARCHS by the architecture name that
    `get_model_arch` reports for ``model_path``.

    Args:
        model_path (str): the path of the input model
        model_format (str): the format of the model, e.g. 'hf', 'awq' or
            'gptq'. It is not used when choosing the registered name.
    """
    arch_name, *_ = get_model_arch(model_path)
    return SUPPORTED_ARCHS[arch_name]


def _infer_auto_dtype(model_arch: str, model_config, has_bf16: bool) -> str:
    """Resolve ``dtype='auto'`` from the device capability and the model config.

    The device default is 'bfloat16' when supported, else 'float16'. It is
    overridden by the model's own ``dtype`` (or the deprecated ``torch_dtype``)
    when that is bfloat16 or float16, given either as a torch dtype or by name.
    """
    default = 'bfloat16' if has_bf16 else 'float16'
    # prefer `dtype` over deprecated `torch_dtype`
    torch_dtype = getattr(model_config, 'dtype', None)
    if torch_dtype is None:
        torch_dtype = getattr(model_config, 'torch_dtype', None)
    if not torch_dtype and model_arch in ['QWenLMHeadModel', 'GptOssForCausalLM']:
        torch_dtype = torch.bfloat16
    dtype_map = {
        torch.bfloat16: 'bfloat16',
        torch.float16: 'float16',
        # some configs keep the dtype as its string name
        'bfloat16': 'bfloat16',
        'float16': 'float16',
    }
    return dtype_map.get(torch_dtype, default)


def _get_modules_to_not_convert(model_config) -> list:
    """Return ``quantization_config.modules_to_not_convert`` (dict or object form), or []."""
    quant_config = getattr(model_config, 'quantization_config', None)
    if quant_config is None:
        return []
    if isinstance(quant_config, dict):
        return quant_config.get('modules_to_not_convert') or []
    return getattr(quant_config, 'modules_to_not_convert', None) or []


def _get_unquantized_layers(modules_to_not_convert) -> list:
    """Collect indices of whole layers excluded from quantization.

    Entries such as 'model.layers.0.' or 'model.layers.0' mean that ALL weights
    of that layer (including MoE experts) are unquantized.
    """
    import re

    layers = []
    for m in modules_to_not_convert:
        matched = re.match(r'model\.layers\.(\d+)\.?$', m)
        if matched:
            layers.append(int(matched.group(1)))
    return layers


def get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int):
    """Get the registered name of the turbomind model and its configuration
    according to the input model path, format and user-input config. The name
    will be used to access the OUTPUT_MODELS registry.

    Args:
        model_path (str): the path of the input model
        model_format (str): the format of the model, e.g. 'hf', 'awq', 'gptq',
            'compressed-tensors', 'fp8' or 'mxfp4'
        dtype (str): the data type of the model's weights and activations, or
            'auto' to infer it from the device and the model config
        group_size (int): the size of group used by int4 (awq/gptq) models;
            0 or None selects the default of 128

    Returns:
        tuple: (registered output model name, TurbomindModelConfig)
    """
    register_name = 'tm'

    has_bf16 = is_bf16_supported()

    model_arch, model_config = get_model_arch(model_path)

    # infer dtype from device and model config
    if dtype == 'auto':
        dtype = _infer_auto_dtype(model_arch, model_config, has_bf16)

    if dtype == 'bfloat16' and not has_bf16:
        logger.warning('data type fallback to float16 since '
                       'torch.cuda.is_bf16_supported is False')
        dtype = 'float16'

    weight_type = dtype

    config = TurbomindModelConfig.from_dict()

    session_len = _get_and_verify_max_len(model_config, None)

    if model_format in ['awq', 'gptq', 'compressed-tensors']:
        weight_type = 'int4'
        dtype = 'float16'  # force float16 for int4 quantized weights
        # both 0 and None mean "unspecified"
        group_size = group_size or 128
        if model_format == 'compressed-tensors':
            model_format = 'awq'
    elif model_format == 'fp8':
        weight_type = 'fp8'
        group_size = 128
    elif model_format == 'mxfp4':
        weight_type = 'e2m1'
        group_size = 32

    expert_weight_type = weight_type

    # ONLY experts are in mxfp4
    if model_arch == 'GptOssForCausalLM':
        weight_type = dtype

    # Three weight types control allocation for mixed quantization:
    #   weight_type        - attention weights
    #   ffn_weight_type    - dense FFN / shared expert weights
    #   expert_weight_type - MoE routed expert weights
    #
    # The assignment order matters:
    #   1. expert_weight_type = original weight_type (before any overrides)
    #   2. GptOss override:   weight_type -> dtype  (attn + shared experts are fp16)
    #   3. ffn_weight_type  = weight_type           (captures post-GptOss value)
    #   4. Mixed AWQ override: weight_type -> dtype  (only attn becomes fp16)
    #
    #                  weight_type   ffn_weight_type   expert_weight_type
    #  Pure fp16       float16       float16           float16
    #  Full AWQ        int4          int4              int4
    #  Mixed AWQ       float16       int4              int4
    #  GptOss mxfp4    bfloat16      bfloat16          e2m1
    ffn_weight_type = weight_type

    # When attention weights are not quantized (e.g. AWQ with self_attn in
    # modules_to_not_convert), weight_type becomes fp16 for attention.
    # ffn_weight_type and expert_weight_type retain int4.
    if model_format in ['awq', 'gptq'] and weight_type != dtype:
        modules_to_not_convert = _get_modules_to_not_convert(model_config)
        if any('self_attn' in m for m in modules_to_not_convert):
            weight_type = dtype
        if any('shared_expert' in m for m in modules_to_not_convert):
            ffn_weight_type = dtype
        config.model_config.unquantized_expert_layers = _get_unquantized_layers(modules_to_not_convert)

    config.model_config.model_arch = model_arch
    config.model_config.data_type = dtype
    config.model_config.weight_type = weight_type
    config.model_config.expert_weight_type = expert_weight_type
    config.model_config.ffn_weight_type = ffn_weight_type
    config.model_config.model_format = model_format
    config.model_config.group_size = group_size
    config.model_config.session_len = session_len

    return register_name, config


def _parse_quant_config(quant_config: dict, engine_config: TurbomindEngineConfig, group_size: int):
    """Validate a checkpoint's ``quantization_config`` against user input.

    Args:
        quant_config (dict): the model's quantization config
        engine_config (TurbomindEngineConfig): user input engine config; only
            ``model_format`` is read here
        group_size (int): user input group size; a falsy value means "not given"

    Returns:
        tuple: (quant_method, effective group size, mixed_awq), where mixed_awq
        tells whether an AWQ model leaves its self-attention unquantized.
    """
    quant_method = quant_config.get('quant_method')
    # a `"group_size": null` entry is treated like a missing one
    _group_size = int(quant_config.get('group_size') or 0)
    assert engine_config.model_format is None or engine_config.model_format == quant_method, (
        f'mismatched quant method: user input "{engine_config.model_format}" '
        f'vs model quant_config "{quant_method}"')

    mixed_awq = False
    if quant_method == 'awq':
        version = quant_config.get('version')
        assert version == 'gemm', f'unsupported quant config: {quant_config}'
        modules_to_not_convert = quant_config.get('modules_to_not_convert') or []
        mixed_awq = any('self_attn' in name for name in modules_to_not_convert)
    elif quant_method == 'gptq':
        assert not quant_config.get('desc_act', False) and quant_config.get(
            'sym', True), f'unsupported quant config: {quant_config}'
    elif quant_method == 'fp8':
        pass
    elif quant_method == 'mxfp4':
        _group_size = 32
    elif quant_method == 'compressed-tensors':
        _format = quant_config['config_groups']['group_0']['format']
        assert _format == 'pack-quantized', ('compressed-tensors only supports pack-quantized format, '
                                             f'but got {_format}')
        _weights = quant_config['config_groups']['group_0']['weights']
        _group_size = _weights['group_size']
        _num_bits = _weights['num_bits']
        _type = _weights['type']
        assert _num_bits == 4 and _type == 'int', ('pack-quantized requires 4-bit int, '
                                                   f'but got {_num_bits}-bit {_type}')
    else:
        assert 0, f'unsupported quant_config: {quant_config}'

    # Compare with the *effective* group size, i.e. after the method-specific
    # overrides above (mxfp4 is fixed to 32, compressed-tensors reads it from
    # its weights config).
    assert not group_size or group_size == _group_size, (f'mismatched quant group size: user input "{group_size}" '
                                                         f'vs model quant_config "{_group_size}"')
    return quant_method, _group_size, mixed_awq


def get_tm_model(model_path,
                 model_name,
                 chat_template_name,
                 engine_config: TurbomindEngineConfig,
                 group_size: int = None,
                 out_dir: str = None) -> BaseOutputModel:
    """Create turbomind model.

    Args:
        model_path (str): the path of the input model, which is supposed
            to be a local path, or huggingface hub repo_id, or modelscope
            hub repo_id
        model_name (str): user customized model name
        chat_template_name (str): the name of the chat template of
            the input model
        engine_config(TurbomindEngineConfig): user input engine config. Its
            ``model_format`` is overwritten with the checkpoint's quant method
            when the model carries a quantization config.
        group_size(int): refers to the group_size if the input model
            is a w4a16(awq or gptq) quantized model
        out_dir(str): the output directory where to save to turbomind model.
            If it is None, the turbomind model won't be saved
    """
    _, cfg = get_model_arch(model_path)
    quant_config = search_nested_config(cfg.to_dict(), 'quantization_config')
    mixed_awq = False
    if quant_config:
        quant_method, group_size, mixed_awq = _parse_quant_config(quant_config, engine_config, group_size)
        engine_config.model_format = quant_method

    if engine_config.model_format in ['awq', 'gptq', 'compressed-tensors']:
        # Compatible to awq models that are quantized by lmdeploy (<=v0.3.0)
        if not group_size:
            group_size = 128
        assert group_size == 128, (f'model format is "{engine_config.model_format}" '
                                   f'but group_size is {group_size}. Currently, only 128 '
                                   'is supported')

    input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)

    # fp8 requested on a non-quantized checkpoint: quantize on the fly
    fp8_quant = (engine_config.model_format == 'fp8' and not quant_config)
    input_policy = get_input_policy(engine_config.model_format)
    input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path,
                                                     tokenizer_path=model_path,
                                                     input_policy=input_policy,
                                                     fp8_quant=fp8_quant)

    output_model_name, tm_cfg = get_output_model_registered_name_and_config(model_path=model_path,
                                                                            model_format=engine_config.model_format,
                                                                            dtype=engine_config.dtype,
                                                                            group_size=group_size)

    if mixed_awq:
        # Mixed-precision AWQ: attention weights are fp16 (not quantized),
        # but expert weights remain as int4 AWQ for efficient inference.
        tm_cfg.model_config.weight_type = tm_cfg.model_config.data_type
        # expert_weight_type stays as 'int4' (set by get_output_model_registered_name_and_config)

    tm_cfg.model_config.chat_template = chat_template_name
    tm_cfg.model_config.model_name = model_name

    if engine_config.attn_tp_size is not None:
        tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size
    if engine_config.attn_cp_size is not None:
        tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size
    if engine_config.mlp_tp_size is not None:
        tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size

    output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model,
                                                        cfg=tm_cfg,
                                                        model_cls=Transformer,
                                                        out_dir=out_dir)

    return output_model


================================================
FILE: lmdeploy/turbomind/deploy/loader.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import partial
from glob import glob
from queue import Queue
from typing import Iterator, Tuple, Union

import torch
from safetensors import safe_open

# https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/modeling_utils.py#L372
WEIGHT_INDEX_NAME = 'pytorch_model.bin.index.json'
WEIGHT_PATTERN = 'pytorch_model*.bin'
SAFE_WEIGHT_INDEX_NAME = 'model.safetensors.index.json'
SAFE_WEIGHT_PATTERN = 'model*.safetensors'
EXTRA_WEIGHT_PATTERNS = ['*.pt', '*.bin']
EXTRA_SAFE_WEIGHT_PATTERN = '*.safetensors'


class BaseLoader(ABC):
    """Common base for loaders that stream checkpoint tensors grouped by layer.

    Args:
        model_path (str): directory that holds the checkpoint files
        pattern: regex applied with ``re.findall`` by subclasses; the first
            result is parsed as the layer index of a tensor name
        mappings (list): callables applied in order to rename tensor names
            (see ``map_key``)
    """

    def __init__(self, model_path: str, pattern, mappings: list):
        self.model_path = model_path
        self.pattern = pattern
        # number of tensors per layer index; populated by subclasses
        self.item_count = defaultdict(int)
        self.mappings = mappings

    def get_index(self, index_name: str, file_pattern: str) -> Tuple[list, dict]:
        """Get shards and weight map (if possible) for the model.

        Args:
            index_name (str): name of an index json with a ``weight_map``
                entry, relative to ``model_path``; takes precedence if truthy
            file_pattern (str): glob pattern used when there is no index file

        Returns:
            (shards, index): sorted list of shard paths, and the tensor-name ->
            shard-file-name map (empty when globbing ``file_pattern``).

        Raises:
            RuntimeError: if no shard file is found.
        """
        get_path = partial(osp.join, self.model_path)
        if index_name:
            with open(get_path(index_name), 'r') as f:
                index = json.load(f)['weight_map']
            # many tensors share one shard; de-duplicate before building paths
            shards = [get_path(name) for name in set(index.values())]
        else:
            index = {}
            shards = glob(get_path(file_pattern))
        if not shards:
            raise RuntimeError(f'failed to locate weight files for {self.model_path}')
        return sorted(shards), index

    def map_key(self, key: str):
        """Rename ``key`` by applying every mapping in order (identity if none)."""
        if not self.mappings:
            return key
        k = str(key)
        for f in self.mappings:
            k = f(k)
        return k

    @abstractmethod
    def items(self) -> Iterator[Tuple[int, dict]]:
        """Yield ``(layer_idx, {name: tensor})`` pairs; ``-1`` marks non-layer tensors."""
        pass


class SafetensorsLoader(BaseLoader):
    """Stream tensors out of ``*.safetensors`` shards, grouped by layer index."""

    def __init__(self, model_path: str, pattern: str, mappings: list, index_name=None, file_pattern=None):
        super().__init__(model_path, pattern, mappings)
        self.shards, weight_map = self.get_index(index_name, file_pattern)
        if not weight_map:
            # No index json: open each shard and record which file holds each tensor.
            for shard_path in self.shards:
                shard_name = osp.basename(shard_path)
                with safe_open(shard_path, 'pt') as handle:
                    weight_map.update(dict.fromkeys(handle.keys(), shard_name))
        # tensor name -> safetensors file name that owns it
        self.index = weight_map
        # Count tensors per layer so `items` knows when a layer is complete.
        for name in weight_map:
            found = re.findall(self.pattern, name)
            if found:
                self.item_count[int(found[0])] += 1

    def items(self):
        """Yield each layer as soon as all of its tensors are read.

        Non-layer tensors of a shard are yielded together under index -1.
        """
        pending = defaultdict(dict)
        for shard_path in self.shards:
            shard_name = osp.basename(shard_path)
            with safe_open(shard_path, 'pt') as handle:
                leftovers = []
                for name in handle.keys():
                    # Skip tensors the index does not attribute to this shard
                    # (unknown names, or duplicates stored in several files).
                    if self.index.get(name) != shard_name:
                        continue
                    found = re.findall(self.pattern, name)
                    if not found:
                        leftovers.append(name)
                        continue
                    layer = int(found[0])
                    bucket = pending[layer]
                    bucket[self.map_key(name)] = handle.get_tensor(name)
                    if len(bucket) == self.item_count[layer]:
                        yield (layer, pending.pop(layer))
                if leftovers:
                    yield (-1, {name: handle.get_tensor(name) for name in leftovers})
        assert not pending


class PytorchLoader(BaseLoader):
    """Stream tensors out of pickled PyTorch checkpoint shards (``*.bin``/``*.pt``)."""

    def __init__(self, model_path: str, pattern: str, mappings: list, index_name=None, file_pattern=None):
        super().__init__(model_path, pattern, mappings)
        self.shards, index = self.get_index(index_name, file_pattern)
        # per-layer tensor counts are only known when an index file exists
        for k in index.keys():
            match = re.findall(self.pattern, k)
            if match:
                self.item_count[int(match[0])] += 1

    def items(self):
        """Yield ``(layer_idx, params)``, or ``(-1, params)`` for non-layer tensors.

        Layer tensor names are renamed via ``map_key``, as SafetensorsLoader
        does; non-layer tensors keep their original names.
        """
        params = defaultdict(dict)
        for shard in self.shards:
            misc = {}
            tmp = torch.load(shard, map_location='cpu', weights_only=True)
            for k, v in tmp.items():
                match = re.findall(self.pattern, k)
                if not match:
                    misc[k] = v
                else:
                    idx = int(match[0])
                    params[idx][self.map_key(k)] = v
            del tmp
            if misc:
                yield (-1, misc)
                # NOTE(review): empties the dict the consumer received; consumers
                # must not keep a reference to it past this point.
                misc.clear()
            if self.item_count:
                ready = [idx for idx, param in params.items() if len(param) == self.item_count[idx]]
            else:
                # Without an index, hold back the highest layer index seen so far
                # (presumably it may continue in the next shard).
                ready = sorted(params.keys())[:-1]
            for idx in ready:
                yield (idx, params.pop(idx))
        for idx in sorted(params.keys()):
            yield (idx, params.pop(idx))


class StateDictLoader:
    """This loader is used for `update_params`.

    Currently, the item in the queue should be full state dict of a decoder layer or the meta of the model (embedding,
    lm_head, norm). A ``None`` item ends the iteration.
    """

    def __init__(self, queue: Queue, pattern: str, mappings: list):
        self.que = queue
        self.pattern = pattern
        # NOTE: `mappings` is accepted for interface parity with the file
        # loaders but is not applied; queued keys are used as-is.

    def items(self):
        """Yield ``(layer_idx, state_dict)`` per queued item, ``-1`` for non-layer dicts.

        Only the first key of each dict is matched against ``pattern``. An
        empty dict is reported as non-layer data.
        """
        for data in iter(self.que.get, None):
            # If data is state dict of a decoder layer, any key will match the pattern.
            # Otherwise, none of the keys will match the pattern. `match` is
            # recomputed for every item so an empty dict never reuses a stale one.
            first_key = next(iter(data), None)
            match = re.findall(self.pattern, first_key) if first_key is not None else []

            if not match:
                yield (-1, data)
            else:
                idx = int(match[0])
                yield (idx, data)

            torch.cuda.empty_cache()
            self.que.task_done()


def create_loader(model_path: Union[str, Queue], pattern: str, mappings: list) -> BaseLoader:
    """Pick a weight loader for ``model_path``.

    A Queue (used by `update_params`) gets a StateDictLoader. For a directory,
    candidates are probed in priority order: safetensors index, standard
    safetensors shards, pytorch index, standard pytorch shards, any
    ``*.safetensors``, then any ``*.pt`` / ``*.bin``.

    Raises:
        RuntimeError: if no candidate matches.
    """
    loader_args = (model_path, pattern, mappings)

    if isinstance(model_path, Queue):
        return StateDictLoader(*loader_args)

    # (probe, file name or glob, loader class, keyword that receives the name)
    candidates = (
        (osp.exists, SAFE_WEIGHT_INDEX_NAME, SafetensorsLoader, 'index_name'),
        (glob, SAFE_WEIGHT_PATTERN, SafetensorsLoader, 'file_pattern'),
        (osp.exists, WEIGHT_INDEX_NAME, PytorchLoader, 'index_name'),
        (glob, WEIGHT_PATTERN, PytorchLoader, 'file_pattern'),
        # non-standard safetensors model (*.safetensors)
        (glob, EXTRA_SAFE_WEIGHT_PATTERN, SafetensorsLoader, 'file_pattern'),
        # non-standard pytorch model (*.bin, *.pt)
        *((glob, extra, PytorchLoader, 'file_pattern') for extra in EXTRA_WEIGHT_PATTERNS),
    )
    for probe, name, loader_cls, keyword in candidates:
        if probe(osp.join(model_path, name)):
            return loader_cls(*loader_args, **{keyword: name})

    raise RuntimeError(f'Failed to find valid loader for {model_path}')


================================================
FILE: lmdeploy/turbomind/deploy/module.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from functools import partial

import torch

from .parameter import get_params
from .source_model.base import BaseReader
from .target_model.base import BaseOutputModel


def permute_v2(x: torch.Tensor, size_per_head: int = 128):
    """Interleave the two halves of every head along the last dim.

    Within each head, ``[a0..a(h-1), b0..b(h-1)]`` becomes ``[a0, b0, a1, b1, ...]``,
    where ``h = size_per_head // 2``.

    Contract: x.size(-1) is output dims
    """
    assert x.size(-1) > 1

    num_heads = x.size(-1) // size_per_head
    half = size_per_head // 2
    halves = x.view(-1, num_heads, 2, half)
    return halves.transpose(2, 3).reshape(x.shape)


def permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int):
    """Permute only the first rotary_dim elements of each head.

    Used when partial_rotary_factor < 1.0: only the rotary portion needs
    interleaving for TurboMind's RoPE kernel layout; the remaining
    ``size_per_head - rotary_dim`` elements of each head are left in place.
    """
    assert x.size(-1) > 1
    assert rotary_dim % 2 == 0, f'rotary_dim must be even, got {rotary_dim}'
    assert rotary_dim <= size_per_head, f'rotary_dim ({rotary_dim}) must be <= size_per_head ({size_per_head})'
    output_dims = x.size(-1)
    assert output_dims % size_per_head == 0, (f'output_dims ({output_dims}) must be divisible by '
                                              f'size_per_head ({size_per_head})')
    original_shape = x.shape
    n_heads = output_dims // size_per_head
    rows_view = x.unsqueeze(0) if x.dim() == 1 else x
    n_rows = rows_view.size(0)
    heads = rows_view.view(n_rows, n_heads, size_per_head)
    rot_part = heads[:, :, :rotary_dim]
    pass_part = heads[:, :, rotary_dim:]
    # [2, rotary_dim // 2] -> [rotary_dim // 2, 2] within each head
    rot_part = rot_part.view(n_rows, n_heads, 2, rotary_dim // 2).transpose(2, 3).reshape(n_rows, n_heads, rotary_dim)
    return torch.cat([rot_part, pass_part], dim=-1).reshape(original_shape)


def merge_qkv_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int):
    """Concatenate q, k and v per tensor-parallel shard.

    The result's last dim is laid out as ``tp`` consecutive ``[q_i | k_i | v_i]``
    slices. 1-D inputs (e.g. biases) produce a squeezed output.

    Contract: x.size(-1) is output dims
    """
    is_matrix = q.dim() == 2

    def per_shard(t):
        if is_matrix:
            return t.view(t.size(0), tp, -1)
        return t.view(tp, -1)

    merged = torch.cat([per_shard(t) for t in (q, k, v)], dim=-1)
    merged = merged.view(-1, merged.size(-1) * tp)
    if q.dim() == 1:
        merged = merged.squeeze()
    return merged


def merge_qkvg_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, tp: int):
    """Merge Q, K, V, and Gate with gate appended after V.

    Layout per tp-shard: [Q | K | V | Gate]. 1-D inputs (e.g. biases) produce
    a squeezed output.
    """
    is_matrix = q.dim() == 2

    def per_shard(t):
        if is_matrix:
            return t.view(t.size(0), tp, -1)
        return t.view(tp, -1)

    merged = torch.cat([per_shard(t) for t in (q, k, v, gate)], dim=-1)
    merged = merged.view(-1, merged.size(-1) * tp)
    if q.dim() == 1:
        merged = merged.squeeze()
    return merged


def transpose(x):
    """Return ``x.t()``, passing ``None`` through unchanged."""
    if x is None:
        return None
    return x.t()


def pad_out_dims(x: torch.Tensor, dims: int):
    """Zero-pad the last (output) dim of ``x`` at the end, up to ``dims``."""
    extra = dims - x.size(-1)
    assert extra >= 0
    return torch.nn.functional.pad(x, (0, extra), mode='constant', value=0)


def pad_in_dims(x: torch.Tensor, dims: int):
    """Zero-pad the first (input) dim of a 2-D ``x`` at the end, up to ``dims``.

    1-D tensors (e.g. biases) have no input dim and are returned unchanged.
    """
    if x.dim() == 1:
        return x
    extra = dims - x.size(0)
    assert x.dim() == 2
    assert extra >= 0
    return torch.nn.functional.pad(x, (0, 0, 0, extra), mode='constant', value=0)


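# LoRA split policy under tensor parallelism:
#   weights split along out dims (qkv, w1, w3): copy lora_a, split lora_b along out dims
#   weights split along in dims  (o, w2):       split lora_a along in dims, copy lora_b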
def get_lora_flags(kind: str):
    """Return ``(is_lora_a, is_lora_b)`` for a parameter ``kind`` string."""
    is_lora_a = 'lora_a' in kind
    is_lora_b = 'lora_b' in kind
    return is_lora_a, is_lora_b


class Module(ABC):
    """Base class for exporters of one part of a decoder layer.

    Calling an instance forwards all arguments to :meth:`apply`.
    """

    def __init__(self, model: BaseOutputModel):
        self.model = model

    def __call__(self, *args, **kwargs):
        return self.apply(*args, **kwargs)

    @abstractmethod
    def apply(self, idx: int, r: BaseReader):
        """Export the weights of layer ``idx`` read through ``r``."""


class LayerNorm(Module):
    """Export the attention-norm and ffn-norm weights of a decoder layer."""

    def apply(self, i: int, r: BaseReader):
        attn_weight, ffn_weight = r.attn_norm(i), r.ffn_norm(i)
        save = self.model.save_split
        save(attn_weight, f'layers.{i}.attention_norm.weight')
        save(ffn_weight, f'layers.{i}.ffn_norm.weight')


class Ffn(Module):
    """Export the dense FFN weights (w1, w2, w3) of a decoder layer.

    requires:
        r.ffn(i, kind)
    """

    _ffn = 'layers.{0}.feed_forward.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        self.model = model
        self.tp = model.mlp_tp_size
        # inter_sizes in config are padded and may be different from what's
        # in the weights
        self.inter_size = model.model_config.inter_size
        # group_size may be 0 (or unset) for unquantized models; clamp to >= 1
        # so the `inter_size // gs` arithmetic in `_export` stays valid
        self.group_size = max(1, model.model_config.group_size or 1)

    def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=(), **kwargs):
        """Pad, pack and save one kind of w1/w2/w3 tensors.

        Args:
            inter_size (int): (padded) intermediate size to pad the weights to
            fmt (str): name template formatted with (layer idx, weight name, kind)
            idx (int): layer index
            w123: the (w1, w2, w3) tensors; each is transposed before padding
            kind (str): parameter kind; LoRA kinds contain 'lora_a' / 'lora_b'
            pack_fn: callable applied to each padded tensor before saving
            apply_gs: names among 'w1'/'w2'/'w3' whose padded dim is counted in
                groups, i.e. padded to ``inter_size // group_size``
        """
        is_lora_a, is_lora_b = get_lora_flags(kind)
        w1, w2, w3 = map(transpose, w123)

        gs1 = self.group_size if 'w1' in apply_gs else 1
        w1 = pad_out_dims(w1, inter_size // gs1)

        gs3 = self.group_size if 'w3' in apply_gs else 1
        w3 = pad_out_dims(w3, inter_size // gs3)

        gs2 = self.group_size if 'w2' in apply_gs else 1
        w2 = pad_in_dims(w2, inter_size // gs2)

        w1, w2, w3 = map(pack_fn, (w1, w2, w3))
        # w1/w3 are split along out dims, w2 along in dims (see LoRA split policy)
        self.model.save_split(w1, fmt.format(idx, 'w1', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)
        self.model.save_split(w3, fmt.format(idx, 'w3', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)
        self.model.save_split(w2, fmt.format(idx, 'w2', kind), split_dim=0, split_num=self.tp, copy=is_lora_b)

    def apply(self, i: int, r: BaseReader):
        """Export layer ``i``; layers with no (or zero) inter_size are skipped."""
        if i >= len(self.inter_size) or not self.inter_size[i]:
            return
        keys = r.ffn(i, None)

        for e in get_params(keys):
            e(partial(self._export, self.inter_size[i], self._ffn), partial(r.ffn, i), i)


class MoeFfn(Ffn):
    """Export routed experts, the router and the optional shared-expert gate.

    requires:
        r.moe_ffn_expert(e, i, kind)
        r.moe_ffn_gate(i)
        r.moe_ffn_shared_gate(i)
    """

    _moe_ffn_expert = 'layers.{0}.moe_ffn.experts.E.{1}.{2}'
    _moe_ffn_gate = 'layers.{0}.moe_ffn.gate.{1}'
    _moe_ffn_shared_gate = 'layers.{0}.moe_ffn.shared_gate.weight'

    def __init__(self, model: BaseOutputModel):
        super().__init__(model)
        cfg = model.model_config
        self.expert_num = cfg.expert_num
        self.inter_size = cfg.expert_inter_size
        self.shared_gate = cfg.moe_shared_gate

    def apply(self, i: int, r: BaseReader):
        has_experts = i < len(self.expert_num) and self.expert_num[i] != 0
        if not has_experts:
            return

        # Experts form the outer loop so that each expert's full weight set is
        # exported together.
        for expert in range(self.expert_num[i]):
            name_fmt = self._moe_ffn_expert.replace('E', str(expert))
            read_expert = partial(r.moe_ffn_expert, expert, i)
            for param in get_params(r.moe_ffn_expert(), 1):
                param(partial(self._export, self.inter_size, name_fmt), read_expert, i)

        # router
        router_weight = transpose(r.moe_ffn_gate(i, 'weight'))
        self.model.save_split(router_weight, self._moe_ffn_gate.format(i, 'weight'))
        router_bias = r.moe_ffn_gate(i, 'bias')
        if router_bias is not None:
            self.model.save_split(router_bias, self._moe_ffn_gate.format(i, 'bias'))

        # score_correction_bias for noaux_tc routing (GLM 4.7 Flash); only
        # readers that provide `moe_ffn_gate_correction_bias` export it
        read_correction = getattr(r, 'moe_ffn_gate_correction_bias', None)
        correction = read_correction(i) if callable(read_correction) else None
        if correction is not None:
            self.model.save_split(correction, self._moe_ffn_gate.format(i, 'score_correction_bias'))

        if self.shared_gate:
            shared_gate_weight = transpose(r.moe_ffn_shared_gate(i))
            self.model.save_split(shared_gate_weight, self._moe_ffn_shared_gate.format(i))


class Attn(Module):
    """Export the attention weights of a decoder layer.

    Q/K/V (plus the output gate when ``attn_output_gate`` is set) are merged
    into ``w_qkv`` and split along output dims across ``attn_tp_size`` ranks;
    ``wo`` is split along its input dim. Optional q/k norms and attention
    sinks are exported too.

    requires:
        r.attn(i, kind)
    """

    _attn = 'layers.{0}.attention.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        self.model = model
        self.tp = model.attn_tp_size
        self.head_dim = model.model_config.size_per_head
        self.attn_bias = model.model_config.attn_bias
        self.qk_norm = model.model_config.qk_norm
        self.attn_sink = model.model_config.attn_sink
        # group_size may be 0 (or unset) for unquantized models; clamp to >= 1
        self.group_size = max(1, model.model_config.group_size or 1)
        self.attn_output_gate = model.model_config.attn_output_gate
        rope_param = model.attention_config.rope_param
        # only the first `rope_dim` elements of each head are rotary
        self.rope_dim = rope_param.dim if rope_param else self.head_dim
        self.head_num = model.model_config.head_num

    def _permute_qk(self, q, k):
        """Reorder q/k output dims into TurboMind's rotary-embedding layout.

        Only the rotary part of each head is permuted when rope_dim < head_dim.
        """
        if self.rope_dim < self.head_dim:
            q = permute_v2_partial(q, self.head_dim, self.rope_dim)
            k = permute_v2_partial(k, self.head_dim, self.rope_dim)
        else:
            q = permute_v2(q, self.head_dim)
            k = permute_v2(k, self.head_dim)
        return q, k

    def _split_q_gate(self, q):
        """Split interleaved Q+gate tensor into separate Q and gate.

        HF layout: [Q_head0, Gate_head0, Q_head1, Gate_head1, ...]
        Returns: (q_real, gate) each with shape [..., num_heads * head_dim]
        """
        output_dims = q.size(-1)
        head_num = output_dims // (self.head_dim * 2)
        orig_shape = list(q.shape)
        if q.dim() == 1:
            q = q.unsqueeze(0)
        q = q.view(q.size(0), head_num, 2, self.head_dim)
        q_real = q[:, :, 0, :].contiguous()
        gate = q[:, :, 1, :].contiguous()
        new_last_dim = head_num * self.head_dim
        q_real = q_real.reshape(-1, new_last_dim)
        gate = gate.reshape(-1, new_last_dim)
        if len(orig_shape) == 1:
            q_real = q_real.squeeze(0)
            gate = gate.squeeze(0)
        return q_real, gate

    def _reorder_and_merge(self, qkvo, gs: int):
        """Permute q/k for rotary, merge into w_qkv and supply a zero wo bias if needed.

        Returns:
            (qkv, o): the merged qkv(+gate) tensor and the (possibly synthesized) o.
        """
        q, k, v, o = qkvo
        gate = None
        # When attn_output_gate, Q is interleaved [Q0, G0, Q1, G1, ...]
        # Split into separate Q and gate before permuting
        if self.attn_output_gate and q is not None:
            q, gate = self._split_q_gate(q)
        # reorder output dim for tm's rotary embedding layout
        if self.model.permute_qk:
            if gs == 1:
                q, k = self._permute_qk(q, k)
            else:
                # grouped tensors (gs > 1) are not permuted here
                assert gs % self.head_dim == 0
        # Merge QKV with gate appended at end if present
        if gate is not None:
            qkv = merge_qkvg_v2(q, k, v, gate, self.tp)
        else:
            qkv = merge_qkv_v2(q, k, v, self.tp)
        # zero bias for `wo` when `w_qkv` has bias but `wo` doesn't; `wo`
        # outputs hidden_units (same size as `_repeat_kv` uses), which may
        # differ from q's head_num * head_dim
        if o is None and q.dim() == 1:
            o = torch.zeros(self.model.model_config.hidden_units, dtype=q.dtype, device=q.device)
        return qkv, o

    def _repeat_kv(self, qkvo, gs: int, kind: str):
        """Replicate kv heads ``self.model.repeat_kv`` times.

        For bias, a missing ``o`` is replaced by zeros of size hidden_units and
        all four tensors are squeezed.
        """
        q, k, v, o = qkvo
        head_dim = self.model.model_config.size_per_head // gs
        kv_head_num = self.model.model_config.kv_head_num // self.model.repeat_kv
        hidden_dim = self.model.model_config.hidden_units

        def _repeat(x):
            n = self.model.repeat_kv

            x = x.reshape(-1, kv_head_num, head_dim)
            x = x.repeat(1, 1, n)
            x = x.reshape(-1, kv_head_num * n * head_dim)

            return x

        k, v = map(_repeat, (k, v))

        if kind == 'bias':
            if o is None:
                o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
            q, k, v, o = map(torch.squeeze, (q, k, v, o))

        return (q, k, v, o)

    def _export(self, idx: int, qkvo, kind: str, pack_fn, apply_gs=(), **kwargs):
        """Save one kind of attention tensors as ``w_qkv`` and ``wo``.

        LoRA kinds are not supported for attention (asserted).
        """
        if all(x is None for x in qkvo):
            return
        is_lora_a, is_lora_b = get_lora_flags(kind)
        assert not (is_lora_a or is_lora_b)

        qkvo = tuple(map(transpose, qkvo))

        gs = self.group_size if ('w1' in apply_gs) else 1

        if self.model.repeat_kv:
            qkvo = self._repeat_kv(qkvo, gs, kind)

        qkv, o = self._reorder_and_merge(qkvo, gs)

        self.model.save_split(pack_fn(qkv),
                              self._attn.format(idx, 'w_qkv', kind),
                              split_dim=-1,
                              split_num=self.tp,
                              copy=is_lora_a)
        self.model.save_split(pack_fn(o),
                              self._attn.format(idx, 'wo', kind),
                              split_dim=0,
                              split_num=self.tp,
                              copy=is_lora_b)

    def apply(self, i: int, r: BaseReader):
        """Export attention weights, q/k norms and sinks of layer ``i``."""
        for e in get_params(r.attn(i, None), bias=self.attn_bias):
            e(self._export, partial(r.attn, i), i)
        if self.qk_norm:
            q, k = r.qk_norm(i)
            if q is not None and k is not None:
                # norms act on q/k after the rotary layout change, so permute alike
                if self.model.permute_qk:
                    q, k = self._permute_qk(q, k)
                self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1])
                self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1])
        if self.attn_sink:
            sinks = r.attn_sinks(i)
            self.model.save_split(sinks, self._attn.format(i, 'sinks', '')[:-1], split_dim=-1, split_num=self.tp)


class MLA(Module):
    """Export multi-head latent attention (MLA) weights.

    requires:
        r.mla(i, kind)
        r.mla_norm(i)
    """

    _mla = 'layers.{0}.attention.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        self.model = model

    def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs):
        """Export one parameter kind of layer ``idx``'s MLA projections.

        Args:
            idx: layer index.
            xs: ``(q_a, q_b, q, kv_a, kv_b, o)`` as returned by ``r.mla``;
                entries may be None.
            kind: parameter kind, e.g. 'weight'.
            pack_fn: applied to each tensor right before it is saved.
            **kwargs: ignored (e.g. ``apply_gs``).

        Raises:
            ValueError: if q_b / kv_b / o imply inconsistent head dims, or if
                folding is required but the target dims do not equal
                kv_lora_rank, or kv_b's second dim is not kv_lora_rank, or the
                weights are not floating point.
        """
        if all(x is None for x in xs):
            return
        q_a, q_b, q, kv_a, kv_b, o = xs

        cfg = self.model.model_config
        head_num = cfg.head_num
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_dim = cfg.qk_rope_dim
        size_per_head = cfg.size_per_head
        v_head_dim = cfg.v_head_dim

        # ========== MLA Weight Folding for Dimension Mismatch ==========
        # The exported config may use kv_lora_rank-based head dims: the nope dim
        # (size_per_head - qk_rope_dim) and v_head_dim both equal kv_lora_rank.
        # When those differ from the checkpoint's qk_nope / v head dims, fold the
        # kc/vc BMMs of kv_b_proj into q_b_proj/o_proj at conversion time to
        # avoid runtime overhead.
        if kind == 'weight' and kv_lora_rank and q is None and q_b is not None and kv_b is not None and o is not None:
            # Checkpoint head dims. This assumes (out_features, in_features)
            # layouts; the tensors are only transposed after folding.
            orig_q_head_dim = q_b.size(0) // head_num
            orig_qk_nope_dim = orig_q_head_dim - qk_rope_dim
            orig_kv_dim_total = kv_b.size(0) // head_num
            orig_v_head_dim = o.size(1) // head_num
            actual_orig_qk_nope_dim = orig_kv_dim_total - orig_v_head_dim

            if abs(orig_qk_nope_dim - actual_orig_qk_nope_dim) > 1:
                raise ValueError(f'Dimension mismatch: inferred qk_nope from q_b ({orig_qk_nope_dim}) != '
                                 f'inferred from kv_b ({actual_orig_qk_nope_dim})')

            orig_qk_nope_dim = actual_orig_qk_nope_dim
            target_nope_dim = size_per_head - qk_rope_dim
            target_v_head_dim = v_head_dim

            if orig_qk_nope_dim != target_nope_dim or orig_v_head_dim != target_v_head_dim:
                # Folding multiplies weight matrices together, so it needs
                # floating-point weights. The check lives here rather than
                # before the dimension analysis so that layers needing no
                # folding still export when their weights are not floating
                # point. One such case is fp8 checkpoints, whose weights the
                # fp8 input policy passes along as uint8.
                if not (torch.is_floating_point(q_b) and torch.is_floating_point(kv_b) and torch.is_floating_point(o)):
                    raise ValueError('MLA weight folding requires floating-point attention weights.')

                if target_nope_dim != kv_lora_rank or target_v_head_dim != kv_lora_rank:
                    raise ValueError(f'MLA folding expects v_head_dim and nope_dim to equal kv_lora_rank, '
                                     f'got nope={target_nope_dim}, v_head={target_v_head_dim}, rank={kv_lora_rank}')

                if kv_b.size(1) != kv_lora_rank:
                    raise ValueError(f'kv_b_proj second dim must equal kv_lora_rank for MLA folding, '
                                     f'got {kv_b.size(1)} != {kv_lora_rank}')

                # Split kv_b into kc and vc
                kv_b_per_head = kv_b.reshape(head_num, orig_qk_nope_dim + orig_v_head_dim, kv_lora_rank)
                kc_w = kv_b_per_head[:, :orig_qk_nope_dim, :]
                vc_w = kv_b_per_head[:, orig_qk_nope_dim:, :]

                # Fold kc into q_b_proj: each head's nope rows become kv_lora_rank rows.
                q_b_per_head = q_b.reshape(head_num, orig_q_head_dim, q_b.size(1))
                q_nope_w = q_b_per_head[:, :orig_qk_nope_dim, :]
                q_rope_w = q_b_per_head[:, orig_qk_nope_dim:, :]
                q_nope_expanded = torch.bmm(kc_w.transpose(1, 2), q_nope_w)
                q_b_folded = torch.cat([q_nope_expanded, q_rope_w], dim=1)
                q_b = q_b_folded.reshape(head_num * size_per_head, q_b.size(1))

                # Fold vc into o_proj: each head's input columns become kv_lora_rank wide.
                o_per_head = o.reshape(o.size(0), head_num, orig_v_head_dim)
                o_folded = torch.bmm(o_per_head.permute(1, 0, 2), vc_w)
                o = o_folded.permute(1, 0, 2).reshape(o.size(0), head_num * kv_lora_rank)

                # Set kv_b to identity (kc/vc are now absorbed)
                eye = torch.eye(kv_lora_rank, dtype=kv_b.dtype, device=kv_b.device)
                kv_b = torch.cat([eye, eye], dim=0).repeat(head_num, 1)
        # ========== End MLA Weight Folding ==========

        # Transpose after folding
        q_a, q_b, q, kv_a, kv_b, o = map(transpose, (q_a, q_b, q, kv_a, kv_b, o))

        if q is not None:
            q_b = q

        # Zero-pad each head's rows of o at the front, from v_head_dim up to size_per_head.
        if o is not None:
            o = o.reshape(head_num, v_head_dim, -1)
            o = torch.nn.functional.pad(o, (0, 0, size_per_head - v_head_dim, 0, 0, 0))
            o = o.view(head_num * size_per_head, cfg.hidden_units)

        tp = self.model.attn_tp_size

        # Export MLA weights (handle None for folded-away tensors)
        if q_a is not None:
            self.model.save_split(pack_fn(q_a), self._mla.format(idx, 'q_a_proj', kind))
        q_b_name = 'q_proj' if q_a is None else 'q_b_proj'
        if q_b is not None:
            self.model.save_split(pack_fn(q_b), self._mla.format(idx, q_b_name, kind), split_dim=-1, split_num=tp)
        if kv_a is not None:
            self.model.save_split(pack_fn(kv_a), self._mla.format(idx, 'kv_a_proj', kind))
        # if kv_b is not None:
        #     self.model.save_split(pack_fn(kv_b), self._mla.format(idx, 'kv_b_proj', kind), split_dim=-1, split_num=tp)
        if o is not None:
            self.model.save_split(pack_fn(o), self._mla.format(idx, 'wo', kind), split_dim=0, split_num=tp)

    _layernorm = 'layers.{0}.attention.{1}_a_layernorm'

    def apply(self, i: int, r: BaseReader):
        """Export layer ``i``'s MLA projections and its q_a / kv_a layernorms.

        The q_a layernorm is skipped when absent; the kv_a layernorm is always
        saved.
        """
        for f in get_params(r.attn(i, None), bias=False):
            f(self._export, partial(r.mla, i), i)

        q, k = r.mla_norm(i)
        if q is not None:
            self.model.save_split(q, self._layernorm.format(i, 'q'))
        self.model.save_split(k, self._layernorm.format(i, 'kv'))


class LinearAttn(Module):
    """Export the weights of linear-attention layers.

    Only layers whose ``model_config.layer_types[i]`` is 'linear_attention'
    are exported; tensors are sharded across ``attn_tp_size`` ranks.
    """

    _linear_attn = 'layers.{0}.linear_attn.{1}.{2}'
    # Names of the tensors returned by ``r.linear_attn(i, kind)``, in order.
    _linear_attn_params = ('conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log',
                           'dt_bias')

    def __init__(self, model: BaseOutputModel):
        self.model = model
        self.tp = model.attn_tp_size
        config = model.model_config
        self.key_dim = config.linear_num_key_heads * config.linear_key_head_dim
        self.value_dim = config.linear_num_value_heads * config.linear_value_head_dim

    def _tp_interleave_qkv(self, tensor, dim):
        """Reorder a concatenated [Q | K | V] tensor so a plain TP split along
        ``dim`` hands every rank its own Q, K and V slices.

        Along ``dim`` the input is Q(key_dim) | K(key_dim) | V(value_dim).
        Each component is viewed as ``(tp, size // tp)``, the components are
        concatenated per shard, and the two axes are flattened back. The result
        is laid out as [Q_0 K_0 V_0 | Q_1 K_1 V_1 | ...].
        """
        axis = dim if dim >= 0 else tensor.dim() + dim
        components = torch.split(tensor, [self.key_dim, self.key_dim, self.value_dim], dim=axis)

        def split_into_shards(part):
            # Insert a TP axis right before the per-shard slice of ``axis``.
            dims = list(part.shape)
            return part.view(dims[:axis] + [self.tp, dims[axis] // self.tp] + dims[axis + 1:])

        per_shard = torch.cat([split_into_shards(c) for c in components], dim=axis + 1)
        dims = list(per_shard.shape)
        return per_shard.reshape(dims[:axis] + [dims[axis] * dims[axis + 1]] + dims[axis + 2:])

    def _prepare_for_tp(self, name, tensor):
        """Return ``(tensor_to_save, split_dim)`` for linear-attention parameter ``name``."""
        if name == 'conv1d':
            # The first dim of conv1d holds the Q/K/V channels.
            return self._tp_interleave_qkv(tensor, dim=0), 0
        if name in ('A_log', 'dt_bias'):
            # Per-head params: split_dim=-1 avoids the 1-D copy shortcut in save_split.
            return tensor, -1
        if name == 'out_proj':
            return transpose(tensor), 0
        if name == 'in_proj_qkv':
            # After the transpose, the Q/K/V components lie along the last dim.
            return self._tp_interleave_qkv(transpose(tensor), dim=-1), -1
        return transpose(tensor), -1

    def apply(self, i: int, r: BaseReader):
        """Export layer ``i``'s linear-attention tensors and norm, if it is such a layer."""
        layer_types = getattr(self.model.model_config, 'layer_types', [])
        is_linear_layer = i < len(layer_types) and layer_types[i] == 'linear_attention'
        if not is_linear_layer:
            return

        for kind in ('weight', 'bias'):
            tensors = r.linear_attn(i, kind)
            if not tensors:
                continue
            for name, tensor in zip(self._linear_attn_params, tensors):
                if tensor is None:
                    continue
                prepared, split_dim = self._prepare_for_tp(name, tensor)
                self.model.save_split(prepared,
                                      self._linear_attn.format(i, name, kind),
                                      split_dim=split_dim,
                                      split_num=self.tp)

        norm = r.linear_norm(i, 'weight')
        if norm is not None:
            self.model.export_weight(norm, f'layers.{i}.linear_attn.norm.weight')


class Misc(Module):
    """Export the weights that live outside the transformer layers.

    requires:
        r.tok_embeddings()
        r.norm_weight()
        r.output_weight()
    """

    def apply(self, i: int, r: BaseReader):
        """Export the token embedding, the final norm and the output weight.

        The embedding and output weights are zero-padded along their
        second-to-last dim (rows, for 2-D weights) so the configured vocab size
        becomes a multiple of ``attn_tp_size * attn_cp_size``. They are then
        split across that many shards.
        """
        emb = r.tok_embeddings()
        norm_weight = r.norm_weight()
        output_weight = r.output_weight()

        def pad_to_multiple(tensor: torch.Tensor, tp: int):
            vocab_size = self.model.model_config.vocab_size
            if vocab_size % tp == 0:
                return tensor
            pad_rows = (vocab_size + tp - 1) // tp * tp - vocab_size
            return torch.nn.functional.pad(tensor, (0, 0, 0, pad_rows), 'constant', 0)

        shards = self.model.attn_tp_size * self.model.attn_cp_size
        if emb is not None:
            padded_emb = pad_to_multiple(emb, tp=shards)
            self.model.save_split(padded_emb, 'tok_embeddings.weight', split_dim=1, split_num=shards)
        if norm_weight is not None:
            self.model.export_weight(norm_weight, 'norm.weight')
        if output_weight is not None:
            padded_out = pad_to_multiple(output_weight, tp=shards)
            # The output weight is saved transposed.
            self.model.save_split(padded_out.t(), 'output.weight', split_dim=1, split_num=shards)


class Transformer:
    """Chooses the per-layer exporters from the model config and runs them.

    Calling with ``i >= 0`` runs every per-layer module on layer ``i`` and
    returns 1. Calling with a negative ``i`` exports the non-layer weights via
    ``Misc`` and returns None.
    """

    def __init__(self, model: BaseOutputModel):
        self.model = model
        config = model.model_config
        attention = MLA if config.kv_lora_rank else Attn
        module_types = [LayerNorm, attention]
        if getattr(config, 'layer_types', []):
            module_types.append(LinearAttn)
        if config.inter_size:
            module_types.append(Ffn)
        if config.expert_num:
            module_types.append(MoeFfn)
        self.modules = [module_type(model) for module_type in module_types]
        self.misc = Misc(model)

    def __call__(self, i: int, r: BaseReader):
        if i < 0:
            self.misc(i, r)
            return None
        for module in self.modules:
            module(i, r)
        return 1


================================================
FILE: lmdeploy/turbomind/deploy/parameter.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod

import torch


def identity(x):
    """Pack function that leaves its argument untouched (returns the same object)."""
    return x


def to_half(x: torch.Tensor):
    """Cast ``x`` to float16."""
    return x.half()


def to_float(x: torch.Tensor):
    """Cast ``x`` to float32."""
    return x.float()


def to_fp8(x: torch.Tensor):
    assert x.dtype == torch.uint8
    return x.view(dtype=torch.float8_e4m3fn)


def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight 4-bit values along the last dim into int32 words.

    Element ``j`` of each group of eight ends up in bits ``4*j .. 4*j + 3``.
    The last dim must be a multiple of 8. Values are not masked, so they must
    already lie in ``[0, 15]``.
    """
    assert x.dtype == torch.uint8, f'x.dtype: {x.dtype}'
    groups = x.view(*x.shape[:-1], -1, 8)
    packed = torch.zeros(groups.shape[:-1], dtype=torch.int32, device=x.device)
    for j in range(7, -1, -1):
        packed = (packed << 4) | groups[..., j]
    return packed


def generate_zero_point(g):
    """Build constant zero points for weights that come without them.

    For every ``(rows, cols)`` in ``g('weight_shape')``, returns a uint8 tensor
    of shape ``(rows, cols // 128)`` filled with 8. The result is a tuple in
    the same order as the shapes.
    """
    return tuple(torch.full((rows, cols // 128), 8, dtype=torch.uint8) for rows, cols in g('weight_shape'))


class Parameter:
    """Base class for a group of checkpoint tensors that are exported together.

    Subclasses list the key suffixes they handle in ``KEYS``. The first suffix
    is the one whose presence triggers a match.
    """
    KEYS: tuple = ()
    # Backward-compatible alias of the previously misspelled attribute name.
    KEY = KEYS

    @classmethod
    def take(cls, keys: list[str]):
        """Claim every key in ``keys`` that ends with one of ``cls.KEYS``.

        Claimed keys are removed from ``keys`` in place (the list object is
        kept), so parameter types tried later do not see them.

        Returns:
            The claimed keys in their original order. Returns False when
            ``cls.KEYS`` is empty or no key ends with ``cls.KEYS[0]``; ``keys``
            is then left untouched.
        """
        suffixes = tuple(cls.KEYS)
        if not suffixes or not any(k.endswith(suffixes[0]) for k in keys):
            return False
        taken = [k for k in keys if k.endswith(suffixes)]
        # A single rebuild instead of repeated list.remove (quadratic).
        keys[:] = [k for k in keys if not k.endswith(suffixes)]
        return taken

    @abstractmethod
    def __call__(self, f, g, i):
        """Export this parameter group of layer ``i``.

        ``f`` is the export callback and ``g(name)`` fetches a tensor by name.
        """
        raise NotImplementedError


class QuantWeightOnly(Parameter):
    """Weight-only quantized linear: packed ``qweight`` plus ``scales`` and ``qzeros``."""
    KEYS = '.qweight', '.scales', '.qzeros'

    def __call__(self, f, g, i):
        f(i, g('qweight'), 'qweight', pack_u4_row)
        # Scales and zeros both go out as fp16 with apply_gs=['w2'].
        for src_name, dst_kind in (('scales', 'scales'), ('qzeros', 'zeros')):
            f(i, g(src_name), dst_kind, to_half, apply_gs=['w2'])


class WeightScaleInv(Parameter):
    """Block-quantized weight with a ``weight_scale_inv`` companion tensor."""
    KEYS = '.weight_scale_inv', '.weight'

    # TODO: flag any operations crossing the quant blocks as illegal
    def __call__(self, f, g, i):
        scales = g('weight_scale_inv')
        f(i, scales, 'scales', to_float, apply_gs=['w1', 'w3', 'w2'])
        weight = g('weight')
        f(i, weight, 'weight', identity)


class CompressedWeight(Parameter):
    """compressed-tensors style packed weight, with optional zero points.

    When the checkpoint lacks ``weight_zero_point``, constant zero points are
    produced by ``generate_zero_point``.
    """
    KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point'

    def __init__(self, xs):
        zero_point_suffix = self.KEYS[2]
        self.has_zero_point = any(key.endswith(zero_point_suffix) for key in xs)

    def __call__(self, f, g, i):
        f(i, g('weight_packed'), 'qweight', pack_u4_row)
        f(i, g('weight_scale'), 'scales', to_half, apply_gs=['w2'])
        zeros = g('weight_zero_point') if self.has_zero_point else generate_zero_point(g)
        f(i, zeros, 'zeros', to_half, apply_gs=['w2'])


class Mxfp4Weight(Parameter):
    """MXFP4 weight stored as packed ``blocks`` plus ``scales``."""
    KEYS = '.blocks', '.scales'

    def __call__(self, f, g, i):
        blocks = g('blocks')
        f(i, blocks, 'weight', pack_u4_row)
        scales = g('scales')
        f(i, scales, 'scales', identity, apply_gs=['w2'])


class Weight(Parameter):
    """Plain (unquantized) ``.weight`` tensor, exported unchanged."""
    KEYS = ('.weight', )

    def __call__(self, f, g, i):
        weight = g('weight')
        f(i, weight, 'weight', identity)


class Bias(Parameter):
    """``.bias`` tensor, exported unchanged."""
    KEYS = ('.bias', )

    def __call__(self, f, g, i):
        bias = g('bias')
        f(i, bias, 'bias', identity)


class PLora(Parameter):
    """Partial-LoRA A/B weights, exported as ``lora_a`` / ``lora_b`` weights."""
    KEYS = '.Plora_A.weight', '.Plora_B.weight'

    def __call__(self, f, g, i):
        for src_name, dst_kind in (('Plora_A.weight', 'lora_a.weight'), ('Plora_B.weight', 'lora_b.weight')):
            f(i, g(src_name), dst_kind, identity)


def get_params(keys: list[str], bias=0):
    """Build the parameter handlers that match ``keys``.

    Parameter types are tried in a fixed priority order. Each one removes the
    keys it claims from ``keys`` in place, so more specific formats win over
    the plain ``Weight`` fallback. ``Bias`` is only considered when ``bias`` is
    truthy.
    """
    handlers = []
    for param_type in (PLora, QuantWeightOnly, WeightScaleInv, CompressedWeight, Mxfp4Weight, Weight):
        claimed = param_type.take(keys)
        if not claimed:
            continue
        # CompressedWeight needs its claimed keys to detect zero points.
        handlers.append(CompressedWeight(claimed) if param_type is CompressedWeight else param_type())
    if bias and Bias.take(keys):
        handlers.append(Bias())
    return handlers


================================================
FILE: lmdeploy/turbomind/deploy/policy.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch.cuda


def to_cuda(x: torch.Tensor, *_):
    """Default input policy: move ``x`` to the current CUDA device.

    Extra positional arguments (e.g. the parameter kind) are ignored.
    """
    return x.cuda()


def get_u4_slices(x: torch.Tensor, dtype: torch.dtype) -> List[torch.Tensor]:
    """Split each packed element of ``x`` into its 4-bit fields, least significant first.

    int32 inputs yield 8 tensors and uint8 inputs yield 2. Each tensor is cast
    to ``dtype``. Any other input dtype raises KeyError.
    """
    fields = {torch.int32: 8, torch.uint8: 2}[x.dtype]
    return [((x >> (4 * k)) & 15).to(dtype) for k in range(fields)]


def unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor:
    """Unpack AWQ GEMM-packed int32 words into one uint8 value per 4-bit field.

    The nibbles are reordered with (0, 4, 1, 5, 2, 6, 3, 7) before being
    flattened into the last dim, which grows 8x.
    """
    nibbles = get_u4_slices(x, torch.uint8)
    reordered = [nibbles[j] for j in (0, 4, 1, 5, 2, 6, 3, 7)]
    stacked = torch.stack(reordered, dim=-1)
    return stacked.view(*x.shape[:-1], -1)


def process_awq_gemm(x: torch.Tensor, kind: str):
    """Input policy for AWQ checkpoints.

    The tensor is moved to the GPU. int32 tensors are unpacked, and
    qweight/qzeros/scales are transposed.
    """
    on_gpu = x.cuda()
    result = unpack_awq_gemm(on_gpu) if on_gpu.dtype == torch.int32 else on_gpu
    return result.t() if kind in ('qweight', 'qzeros', 'scales') else result


def process_gptq(x: torch.Tensor, kind: str):
    """Input policy for GPTQ checkpoints.

    The tensor is moved to the GPU and int32-packed tensors are unpacked into
    one uint8 per 4-bit field:
      * qweight is unpacked along dim 0;
      * other kinds (qzeros) are unpacked along the last dim, with 1 added to
        every value.
    Finally qweight/qzeros/scales are transposed.
    """
    packed = x.cuda()
    out = packed
    if packed.dtype == torch.int32:
        nibbles = get_u4_slices(packed, torch.uint8)
        if kind == 'qweight':  # (k/8,n)
            out = torch.stack(nibbles, dim=1).view(-1, packed.size(-1))
        else:  # 'qzeros' (k/g,n/8)
            # NOTE(review): the +1 presumably undoes GPTQ's offset zero-point storage.
            out = torch.stack(nibbles, dim=-1).view(packed.size(0), -1) + 1
    if kind in ('qweight', 'qzeros', 'scales'):
        out = out.t()
    return out


def process_mxfp4(x: torch.Tensor, kind: str):
    """Input policy for MXFP4 checkpoints.

    Moves ``x`` to the GPU. For 'blocks', the last two dims are flattened, each
    byte is split into its two 4-bit fields (low nibble first), and the result
    is flattened back into the last dim. All other kinds are returned as-is.
    """
    x = x.cuda()
    if kind != 'blocks':
        return x
    nibbles = get_u4_slices(torch.flatten(x, start_dim=-2), torch.uint8)
    return torch.flatten(torch.stack(nibbles, dim=-1), start_dim=-2)


def process_fp8(x: torch.Tensor, kind: str):
    """Input policy for FP8 checkpoints.

    Moves ``x`` to the GPU, then:
      * float8_e4m3fn tensors are reinterpreted as uint8 (same bits);
      * float32 tensors other than ``weight_scale_inv`` are cast to bfloat16;
      * everything else is returned unchanged. Notably, this keeps the float32
        ``weight_scale_inv`` block scales at full precision.
    """
    x = x.cuda()
    if x.dtype == torch.float8_e4m3fn:
        # some ops (e.g. torch.cat) for fp8 is not implemented in pytorch
        return x.view(dtype=torch.uint8)
    if kind != 'weight_scale_inv' and x.dtype == torch.float:
        return x.to(dtype=torch.bfloat16)
    # The fallback used to cast to bfloat16 as well. That made the
    # 'weight_scale_inv' exclusion above dead and rounded the block scales
    # to bf16 precision.
    return x


def process_compressed_tensor(x: torch.Tensor, kind: str):
    """Input policy for compressed-tensors checkpoints.

    Moves ``x`` to the GPU. int32 'weight_packed' tensors are unpacked along
    the last dim and int32 'weight_zero_point' tensors along dim 0, giving one
    uint8 per 4-bit field. Everything else is returned unchanged.
    """
    packed = x.cuda()
    if packed.dtype != torch.int32:
        return packed
    nibbles = get_u4_slices(packed, torch.uint8)
    if kind == 'weight_packed':  # (out_channels, in_channels // 8)
        return torch.stack(nibbles, dim=-1).view(*packed.shape[:-1], -1)
    if kind == 'weight_zero_point':  # (out_channels // 8, in_channels // group_size)
        return torch.stack(nibbles, dim=1).view(-1, packed.size(-1))
    return packed


def get_input_policy(model_format):
    """Return the tensor-loading policy for ``model_format``.

    Unknown formats fall back to ``to_cuda``.
    """
    policies = (
        ('awq', process_awq_gemm),
        ('gptq', process_gptq),
        ('mxfp4', process_mxfp4),
        ('fp8', process_fp8),
        ('compressed-tensors', process_compressed_tensor),
    )
    for fmt, policy in policies:
        if model_format == fmt:
            return policy
    return to_cuda


================================================
FILE: lmdeploy/turbomind/deploy/source_model/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .baichuan import Baichuan2Model, BaichuanModel  # noqa: F401
from .deepseek2 import DeepSeek2Model  # noqa: F401
from .deepseek_vl import DeepSeekVLModel  # noqa: F401
from .glm4 import Glm4Model  # noqa: F401
from .glm4_moe_lite import Glm4MoeLiteModel  # noqa: F401
from .gpt_oss import GptOssModel  # noqa: F401
from .internlm2 import InternLM2Model  # noqa: F401
from .internvl import InternVLModel  # noqa: F401
from .llama import LlamaModel  # noqa: F401
from .llava import LlavaModel  # noqa: F401
from .minicpmv import MiniCPMVModel  # noqa: F401
from .mixtral import MixtralModel  # noqa: F401
from .molmo import MolmoModel  # noqa: F401
from .qwen import QwenModel  # noqa: F401
from .xcomposer2 import Xcomposer2Model  # noqa: F401


================================================
FILE: lmdeploy/turbomind/deploy/source_model/baichuan.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class BaichuanReader(LlamaReader):
    """Reader for Baichuan checkpoints, whose q/k/v projections are fused in ``W_pack``."""

    def _attn(self, i: int, kind: str):
        """Return ``(q, k, v, o)`` of ``kind`` for layer ``i``; missing entries are None.

        The transformed ``W_pack`` tensor is split into three equal chunks
        along dim 0.
        """
        packed = self.transform(self.params.get(f'model.layers.{i}.self_attn.W_pack.{kind}'), kind)
        if packed is None:
            q = k = v = None
        else:
            q, k, v = torch.split(packed, packed.shape[0] // 3, dim=0)
        o = self.transform(self.params.get(f'model.layers.{i}.self_attn.o_proj.{kind}'), kind)
        return q, k, v, o


@INPUT_MODELS.register_module(name='baichuan')
class BaichuanModel(LlamaModel):
    """Baichuan model: Llama conversion with the fused-``W_pack`` attention reader."""

    Reader = BaichuanReader


class Baichuan2Reader(BaichuanReader):
    """Reader for Baichuan2 checkpoints, which normalize the lm_head weight."""

    def output_weight(self):
        """Return ``lm_head.weight`` moved to the GPU and normalized with
        ``torch.nn.functional.normalize``, or None if absent."""
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
        weight = self.params.get('lm_head.weight', None)
        if weight is None:
            return None
        return torch.nn.functional.normalize(weight.cuda())


@INPUT_MODELS.register_module(name='baichuan2')
class Baichuan2Model(LlamaModel):
    """Baichuan2 model: Llama conversion using the Baichuan2 reader."""

    Reader = Baichuan2Reader


================================================
FILE: lmdeploy/turbomind/deploy/source_model/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Dict, Iterator, Union

import torch
from mmengine import Registry

# Registry of source (input) model classes. Classes register themselves with
# ``@INPUT_MODELS.register_module(name=...)`` (e.g. ``BaichuanModel`` as 'baichuan').
INPUT_MODELS = Registry('source model', locations=['lmdeploy.turbomind.deploy.source_model.base'])


class BaseReader(ABC):
    """Maps TurboMind module parameters onto tensors of a source checkpoint."""

    def __init__(self):
        pass

    def transform(self, x: Union[torch.Tensor, None], kind: str) -> Union[torch.Tensor, None]:
        """Apply :meth:`_transform` to ``x``; ``None`` passes through untouched."""
        if x is None:
            return None
        return self._transform(x, kind)

    @abstractmethod
    def _transform(self, x: torch.Tensor, kind: str):
        """Convert a raw checkpoint tensor ``x`` of parameter kind ``kind``."""


class BaseInputModel(ABC):
    """Abstract source (input) model that exposes model info and weight readers."""

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
        """Remember where the model and tokenizer live.

        Args:
            model_path (str): the path of the model.
            tokenizer_path (str): the path of the tokenizer model.
            **kwargs: accepted for subclass compatibility; ignored here.
        """
        self.model_path, self.tokenizer_path = model_path, tokenizer_path

    @abstractmethod
    def model_info(self) -> Dict:
        """Return the model's configuration info as a dict."""

    @abstractmethod
    def readers(self) -> Iterator[BaseReader]:
        """Yield the readers that expose the model's tensors."""


================================================
FILE: lmdeploy/turbomind/deploy/source_model/deepseek2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os

from ..config import RopeParam
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class DeepSeek2Reader(LlamaReader):
    """Reader for DeepSeek-V2 style checkpoints (MLA attention, MoE FFN).

    Layer 0 is read as a dense FFN (``mlp.*_proj``). Every other layer exposes
    its shared experts through :meth:`ffn` and its routed experts through
    :meth:`moe_ffn_expert`.
    """

    def moe_ffn_gate(self, i, kind):
        """Return layer ``i``'s router (gate) tensor of ``kind``, untransformed."""
        return self.params.get(f'model.layers.{i}.mlp.gate.{kind}')

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Return the transformed (gate, down, up) tensors of expert ``e`` in layer ``i``.

        Without ``kind``, returns ``self.filter(r'experts', i)`` instead.
        """
        if not kind:
            return self.filter(r'experts', i)
        return tuple(
            self.transform(self.params.get(f'model.layers.{i}.mlp.experts.{e}.{key}_proj.{kind}'), kind)
            for key in ('gate', 'down', 'up'))

    def _ffn(self, i: int, kind: str):
        """Get ffn kind for layer i as transformed (gate, down, up) tensors."""
        if not kind:
            # Filter by layer number to get only keys for this specific layer
            pattern = (rf'model\.layers\.{i}\.mlp\.' if i == 0 else rf'model\.layers\.{i}\.mlp\.shared_experts\.')
            return self.filter(pattern, None)

        def param_name(key):
            name = f'model.layers.{i}.mlp.shared_experts.{key}_proj.{kind}'
            # Layer 0 is dense: its projections sit directly under ``mlp``.
            return name.replace('shared_experts.', '') if i == 0 else name

        return tuple(self.transform(self.params.get(param_name(key)), kind) for key in ('gate', 'down', 'up'))

    def ffn(self, i: int, kind: str):
        return self._ffn(i, kind)

    def mla(self, i: int, kind: str):
        """Return transformed (q_a, q_b, q, kv_a_with_mqa, kv_b, o) projections of layer ``i``."""
        if not kind:
            return self.filter(r'self_attn.*proj', i)
        keys = ('q_a_proj', 'q_b_proj', 'q_proj', 'kv_a_proj_with_mqa', 'kv_b_proj', 'o_proj')
        return tuple(
            self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.{key}.{kind}'), kind)
            for key in keys)

    def mla_norm(self, i: int):
        """Return layer ``i``'s (q_a_layernorm, kv_a_layernorm) weights; missing entries are None."""
        return tuple(
            self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.{k}_a_layernorm.weight')  # noqa: E501
            for k in ('q', 'kv'))


def get_yarn_params(rope_scaling: dict):
    """Compute the YaRN attention factor and softmax-scale multiplier.

    Args:
        rope_scaling: rope scaling config. It must contain 'factor'. 'mscale'
            and 'mscale_all_dim' are optional and default to 1 and 0, matching
            the defaults of DeepSeek's HF YaRN rotary embedding.

    Returns:
        ``(attention_factor, softmax_scale)``:
            * ``attention_factor`` is
              ``yarn_mscale(factor, mscale) / yarn_mscale(factor, mscale_all_dim)``.
            * ``softmax_scale`` is ``yarn_mscale(factor, mscale_all_dim) ** 2``
              when ``mscale_all_dim`` is non-zero, else 0.
    """
    scaling_factor = float(rope_scaling['factor'])
    mscale = rope_scaling.get('mscale', 1)
    mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)

    def yarn_get_mscale(scale=1, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    _mscale = float(yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim))

    softmax_scale = 0
    if mscale_all_dim:
        scale = yarn_get_mscale(scaling_factor, mscale_all_dim)
        softmax_scale = scale * scale

    return _mscale, softmax_scale


@INPUT_MODELS.register_module(name='deepseek2')
class DeepSeek2Model(LlamaModel):
    """DeepSeek-V2 style model (MLA attention, MoE FFN) in HF format."""

    Reader = DeepSeek2Reader

    def model_info(self):
        """Extend the Llama model info with MLA, MoE and YaRN settings.

        Layer 0 is configured as a dense FFN layer (``expert_num[0] = 0``,
        ``inter_size[0] = intermediate_size``). Other layers are MoE layers
        whose shared experts take ``n_shared_experts * moe_intermediate_size``.

        Folding applies unless the ``LMDEPLOY_MLA_FOLD`` environment variable
        is '0', 'false' or 'no'. When it applies and
        ``kv_lora_rank != qk_nope_head_dim``, the exported head dims are
        remapped to ``kv_lora_rank``. This lets the kv_b_proj matmuls be folded
        into q_b_proj/o_proj, and the softmax scale stays at
        ``q_head_dim ** -0.5``.
        """
        cfg = self.model_config
        info = super().model_info()
        qk_nope_dim = cfg['qk_nope_head_dim']
        qk_rope_dim = cfg['qk_rope_head_dim']
        kv_lora_rank = cfg['kv_lora_rank']
        q_head_dim = qk_nope_dim + qk_rope_dim
        num_layer = cfg['num_hidden_layers']
        expert_num = cfg['n_routed_experts']
        expert_num = [expert_num] * num_layer
        expert_num[0] = 0
        n_shared_experts = cfg['n_shared_experts']
        expert_inter_size = cfg['moe_intermediate_size']
        experts_per_token = cfg['num_experts_per_tok']
        inter_size = [n_shared_experts * expert_inter_size] * num_layer
        inter_size[0] = cfg['intermediate_size']
        norm_topk_prob = cfg['norm_topk_prob']
        size_per_head = qk_rope_dim + qk_nope_dim
        v_head_dim = cfg['v_head_dim']
        softmax_scale = 0.0
        disable_mla_fold = os.getenv('LMDEPLOY_MLA_FOLD', '1').lower() in ('0', 'false', 'no')
        if kv_lora_rank and kv_lora_rank != qk_nope_dim and not disable_mla_fold:
            # MLA folding: remap to kv_lora_rank-based head dims and fold
            # kc/vc BMMs into q_b_proj/o_proj at conversion time.
            size_per_head = kv_lora_rank + qk_rope_dim
            v_head_dim = kv_lora_rank
            softmax_scale = q_head_dim**(-0.5)
        elif kv_lora_rank and kv_lora_rank != qk_nope_dim:
            softmax_scale = q_head_dim**(-0.5)

        info.update(kv_lora_rank=kv_lora_rank,
                    q_lora_rank=cfg['q_lora_rank'] or 0,
                    qk_rope_dim=qk_rope_dim,
                    v_head_dim=v_head_dim,
                    size_per_head=size_per_head,
                    kv_head_num=1,
                    expert_num=expert_num,
                    expert_inter_size=expert_inter_size,
                    experts_per_token=experts_per_token,
                    inter_size=inter_size,
                    norm_topk_prob=norm_topk_prob,
                    routed_scale=cfg['routed_scaling_factor'],
                    topk_method=cfg['topk_method'],
                    topk_group=cfg['topk_group'],
                    moe_group_num=cfg['n_group'],
                    scoring_func=cfg.get('scoring_func', 'softmax'),
                    tune_layer_num=2)
        if 'router_n_groups' in cfg and cfg['router_n_groups'] > 0:
            info['router_n_groups'] = cfg['router_n_groups']
        rope_param: RopeParam = info['rope_param']
        rope_param.dim = qk_rope_dim
        if 'rope_parameters' in cfg:
            # transformers v5.0.0 aggregates all rope-related parameters into 'rope_parameters'
            rope_scaling = cfg['rope_parameters']
        else:
            rope_scaling = cfg.get('rope_scaling')
        # Legacy configs name the scaling type 'type'; newer transformers
        # configs name it 'rope_type'. Accept either so YaRN is not skipped.
        rope_type = (rope_scaling.get('type') or rope_scaling.get('rope_type')) if rope_scaling else None
        if rope_type == 'yarn':
            attention_factor, yarn_scale = get_yarn_params(rope_scaling)
            yarn_scale *= q_head_dim**(-0.5)
            rope_param.max_position_embeddings = rope_scaling['original_max_position_embeddings']
            rope_param.attention_factor = attention_factor
            info.update(rope_param=rope_param, softmax_scale=yarn_scale)
        elif softmax_scale:
            info.update(softmax_scale=softmax_scale)
        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp

from ..config import RopeParam
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class DeepSeekVLReader(LlamaReader):
    """Reader for DeepSeek-VL checkpoints; language-model weights live under ``language_model.``."""

    attn_layer_prefix = 'language_model.model.layers'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    output_weight_key = 'language_model.lm_head.weight'

    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
        # Only the 'language_config' sub-dict is handed to the Llama reader.
        super().__init__(new_params, unused_params, last_bin, model_cfg['language_config'], **kwargs)

    def attn_norm(self, i: int):
        """Return layer ``i``'s input layernorm weight (KeyError if missing)."""
        key = f'language_model.model.layers.{i}.input_layernorm.weight'
        return self.params[key]

    def ffn_norm(self, i: int):
        """Return layer ``i``'s post-attention layernorm weight (KeyError if missing)."""
        key = f'language_model.model.layers.{i}.post_attention_layernorm.weight'
        return self.params[key]


@INPUT_MODELS.register_module(name='deepseekvl')
class DeepSeekVLModel(LlamaModel):
    """DeepSeekVL model in hf format.

    Hyper-parameters are read from ``<model_path>/config.json``. When the
    config carries a ``language_config`` whose ``model_type`` is ``llama``
    (DeepSeek-VL), that sub-config is used; otherwise the top-level config is
    read directly.
    """

    Reader = DeepSeekVLReader

    def model_info(self):
        """Read model info.

        Fields missing from the config fall back to the literal defaults below
        (e.g. ``hidden_size=4096``, ``vocab_size=102400``).

        Returns:
            dict: ``num_layer``, ``norm_eps``, ``head_num``, ``kv_head_num``,
            ``hidden_units``, ``inter_size``, ``vocab_size``,
            ``max_position_embeddings`` and ``rope_param``.
        """
        params_path = osp.join(self.model_path, 'config.json')
        with open(params_path) as f:
            model_arg = json.load(f)
        language_cfg = model_arg.get('language_config')
        if isinstance(language_cfg, dict) and language_cfg.get('model_type', None) == 'llama':
            model_arg = language_cfg  # deepseek-vl

        num_layer = model_arg['num_hidden_layers']
        hidden_units = model_arg.get('hidden_size', 4096)
        inter_size = model_arg.get('intermediate_size', 11008)
        vocab_size = model_arg.get('vocab_size', 102400)
        norm_eps = model_arg.get('rms_norm_eps', 1e-06)
        attn_head_num = model_arg.get('num_attention_heads', 32)
        # Without an explicit kv head count, use one kv head per attention head.
        kv_head_num = model_arg.get('num_key_value_heads', attn_head_num)
        rope_theta = float(model_arg.get('rope_theta', 10000.0))
        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))

        rope_scaling = model_arg.get('rope_scaling', None)
        scaling_type = 'default'
        scaling_factor = 0.0
        if isinstance(rope_scaling, dict):
            scaling_type = rope_scaling.get('type', 'default')
            # Default to 0.0 (same as when rope_scaling is absent) so that
            # RopeParam.factor is always numeric.
            scaling_factor = rope_scaling.get('factor', 0.0)
        # head_dim may be absent or explicitly null in the config.
        head_dim = model_arg.get('head_dim') or hidden_units // attn_head_num
        rope_param = RopeParam(type=scaling_type,
                               base=rope_theta,
                               dim=head_dim,
                               max_position_embeddings=max_position_embeddings,
                               factor=scaling_factor)

        return dict(num_layer=num_layer,
                    norm_eps=norm_eps,
                    head_num=attn_head_num,
                    kv_head_num=kv_head_num,
                    hidden_units=hidden_units,
                    inter_size=inter_size,
                    vocab_size=vocab_size,
                    max_position_embeddings=max_position_embeddings,
                    rope_param=rope_param)


================================================
FILE: lmdeploy/turbomind/deploy/source_model/glm4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp

import torch

from ..config import RopeParam
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class Glm4Reader(LlamaReader):
    """Weight reader for ChatGLM2/3 and GLM-4 style checkpoints.

    Attention weights are stored in one fused ``query_key_value`` tensor and
    the first MLP projection in one fused ``dense_h_to_4h`` tensor; both are
    split here.
    """

    attn_layer_patten = r'transformer\.encoder\.layers\.([0-9]+).'
    tok_embeddings_key = 'transformer.embedding.word_embeddings.weight'
    norm_weight_key = 'transformer.encoder.final_layernorm.weight'
    output_weight_key = 'transformer.output_layer.weight'

    attn_pattern = r'self_attention'

    def _attn(self, i: int, kind: str):
        """Split the fused qkv of layer ``i`` and fetch its output projection."""
        prefix = f'transformer.encoder.layers.{i}.self_attention'
        fused_qkv = self.transform(self.params[f'{prefix}.query_key_value.{kind}'], kind)
        n_heads = self.model_cfg['num_attention_heads']
        if self.model_cfg.get('multi_query_attention', False):
            n_kv_heads = self.model_cfg['multi_query_group_num']
        else:
            n_kv_heads = n_heads
        head_size = 128  # per-head size assumed by this reader
        sections = [n_heads * head_size, n_kv_heads * head_size, n_kv_heads * head_size]
        q, k, v = torch.split(fused_qkv, sections, dim=0)
        out_proj = self.transform(self.params.get(f'{prefix}.dense.{kind}'), kind)
        if out_proj is None:
            # qkv may carry a bias while the output projection does not: emit zeros.
            out_proj = torch.zeros_like(q)
        return q, k, v, out_proj

    def attn_norm(self, i: int):
        """Return the raw input-layernorm weight of layer ``i``."""
        return self.params[f'transformer.encoder.layers.{i}.input_layernorm.weight']

    def _ffn(self, i: int, kind: str):
        """Return ``(first half of dense_h_to_4h, dense_4h_to_h, second half)``.

        This fills the same slot order as ``LlamaReader._ffn`` (gate, down, up).
        """
        prefix = f'transformer.encoder.layers.{i}.mlp'
        fused = self.transform(self.params[f'{prefix}.dense_h_to_4h.{kind}'], kind)
        first_half, second_half = fused.chunk(2, dim=0)
        down = self.transform(self.params[f'{prefix}.dense_4h_to_h.{kind}'], kind)
        return (first_half, down, second_half)

    def ffn_norm(self, i: int):
        """Return the raw post-attention-layernorm weight of layer ``i``."""
        return self.params[f'transformer.encoder.layers.{i}.post_attention_layernorm.weight']


@INPUT_MODELS.register_module(name='glm4')
class Glm4Model(LlamaModel):
    """Glm2/3/4 model in hf format.

    GLM configs use their own key names (``num_layers``, ``ffn_hidden_size``,
    ``padded_vocab_size``, ...), so the raw ``config.json`` is parsed here.
    """

    Reader = Glm4Reader

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
        super().__init__(model_path, tokenizer_path, **kwargs)
        with open(osp.join(self.model_path, 'config.json')) as f:
            self.config = json.load(f)

    def model_info(self):
        """Read model info from the raw GLM config."""
        cfg = self.config
        hidden_units = cfg.get('hidden_size', None)
        # 'num_layers' wins over 'num_hidden_layers' when both are present.
        num_layer = cfg.get('num_layers', cfg.get('num_hidden_layers', None))
        norm_eps = cfg['layernorm_epsilon']
        # The rotary base is scaled by the optional rope_ratio.
        rope_theta = float(cfg.get('rotary_emb_base', 10000.0)) * float(cfg.get('rope_ratio', 1.0))
        head_num = cfg['num_attention_heads']
        inter_size = cfg['ffn_hidden_size']
        vocab_size = cfg['padded_vocab_size']
        has_qkv_bias = cfg['add_qkv_bias']
        kv_head_num = cfg['multi_query_group_num'] if cfg['multi_query_attention'] else head_num
        seq_length = cfg['seq_length']
        return dict(num_layer=num_layer,
                    norm_eps=norm_eps,
                    head_num=head_num,
                    kv_head_num=kv_head_num,
                    hidden_units=hidden_units,
                    attn_bias=int(has_qkv_bias),
                    inter_size=inter_size,
                    vocab_size=vocab_size,
                    rope_param=RopeParam(type='default', base=rope_theta, dim=64),
                    max_position_embeddings=seq_length,
                    permute_qk=False)  # head layout is same as TM


================================================
FILE: lmdeploy/turbomind/deploy/source_model/glm4_moe_lite.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
"""GLM-4 MoE Lite (e.g. GLM-4.7-Flash) source model for TurboMind.

Architecture: MLA (Multi-head Latent Attention) + MoE with dense first layer.
Weight layout follows HuggingFace checkpoint with model.layers.* (same family as DeepSeek2).
"""

from .base import INPUT_MODELS
from .deepseek2 import DeepSeek2Model, DeepSeek2Reader


class Glm4MoeLiteReader(DeepSeek2Reader):
    """Reader for Glm4MoeLiteForCausalLM checkpoints (e.g. GLM-4.7-Flash).

    Tensor names follow the DeepSeek2 layout (``model.layers.{i}.self_attn.*``
    and ``model.layers.{i}.mlp.*``). In addition, the per-expert
    ``e_score_correction_bias`` used by noaux_tc routing is exposed.
    """

    output_weight_key = 'lm_head.weight'
    norm_weight_key = 'model.norm.weight'
    tok_embeddings_key = 'model.embed_tokens.weight'
    attn_layer_patten = r'model\.layers\.([0-9]+).'
    attn_layer_prefix = 'model.layers'

    def moe_ffn_gate_correction_bias(self, i: int):
        """Return the router's ``e_score_correction_bias`` for layer ``i`` (None if absent)."""
        bias_key = f'{self.attn_layer_prefix}.{i}.mlp.gate.e_score_correction_bias'
        return self.params.get(bias_key)


@INPUT_MODELS.register_module(name='glm4-moe-lite')
class Glm4MoeLiteModel(DeepSeek2Model):
    """GLM-4 MoE Lite (e.g. GLM-4.7-Flash) in HF format.

    MLA + MoE with first_k_dense_replace; config mapping aligned to DeepSeek2.
    """

    Reader = Glm4MoeLiteReader

    def model_info(self):
        """Read model info, filling in GLM-4 MoE Lite routing defaults.

        Routing keys missing from the HF config are written into
        ``self.model_config`` (in place) before ``super().model_info()`` runs.
        The returned info always uses noaux_tc routing with sigmoid scoring.
        ``router_n_groups`` is forwarded only when it is a positive value.
        """
        cfg = self.model_config
        routing_defaults = (('topk_method', 'noaux_tc'), ('topk_group', 1), ('n_group', 1),
                            ('scoring_func', 'sigmoid'))
        for key, value in routing_defaults:
            cfg.setdefault(key, value)

        info = super().model_info()
        # GLM4 MoE Lite uses noaux_tc routing with sigmoid scoring
        info['topk_method'] = 'noaux_tc'
        info['scoring_func'] = 'sigmoid'
        # The key may be present but null; `None > 0` would raise TypeError.
        router_n_groups = cfg.get('router_n_groups')
        if router_n_groups is not None and router_n_groups > 0:
            info['router_n_groups'] = router_n_groups

        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/gpt_oss.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import re

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


def map_experts(str):
    """Rename fused gpt-oss expert tensors to ``<proj>.<suffix>`` form.

    - ``...experts...proj`` gains a ``.weight`` suffix.
    - ``..._bias`` becomes ``.bias``.
    - ``..._blocks`` becomes ``.blocks``.
    - ``..._scales`` becomes ``.scales``.

    Any other name is returned unchanged. The parameter is named ``str`` (which
    shadows the builtin) only to keep the original call signature.
    """
    renamed = str
    rules = (
        (r'(experts.*proj)$', r'\1.weight'),
        (r'(experts.*proj)_bias$', r'\1.bias'),
        (r'(experts.*proj)_blocks$', r'\1.blocks'),
        (r'(experts.*proj)_scales$', r'\1.scales'),
    )
    # Rules are applied in sequence, each on the output of the previous one.
    for pattern, replacement in rules:
        renamed = re.sub(pattern, replacement, renamed)
    return renamed


class GptOssReader(LlamaReader):
    """Weight reader for gpt-oss checkpoints.

    Each expert projection is stored as one tensor indexed by expert id, and
    ``gate_up_proj`` interleaves gate rows (even) with up rows (odd).
    """

    mappings = [map_experts]

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Return expert ``e`` of layer ``i`` as ``(gate, down, up)``.

        With an empty ``kind`` the matching parameter names are returned instead.
        """
        if not kind:
            return self.filter(r'experts', i)

        prefix = f'{self.attn_layer_prefix}.{i}.mlp.experts'

        def expert_tensor(proj):
            tensor = self.params.get(f'{prefix}.{proj}_proj.{kind}')[e]
            if kind == 'weight':  # experts in BF16 models are in M-major
                tensor = tensor.cuda().t()
            return tensor

        gate_up = expert_tensor('gate_up')
        gate = self.transform(gate_up[::2], kind)
        up = self.transform(gate_up[1::2], kind)
        down = self.transform(expert_tensor('down'), kind)
        return (gate, down, up)

    def moe_ffn_gate(self, i, kind):
        """Return the transformed router tensor of layer ``i``."""
        router = self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.router.{kind}')
        return self.transform(router, kind)

    def attn_sinks(self, i):
        """Return the attention-sink tensor of layer ``i`` (None if absent)."""
        return self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.sinks')


@INPUT_MODELS.register_module(name='gpt-oss')
class GptOssModel(LlamaModel):
    """gpt-oss model in hf format.

    This is an MoE model whose per-layer attention window is selected by
    ``layer_types``.
    """

    Reader = GptOssReader

    def model_info(self):
        """Extend the llama model info with gpt-oss MoE / attention settings."""
        cfg = self.model_config
        layer_types = cfg['layer_types']
        window = cfg['sliding_window']
        info = super().model_info()
        # Layers not tagged 'sliding_attention' get a window size of 0.
        per_layer_window = [window if t == 'sliding_attention' else 0 for t in layer_types]
        info.update(attn_bias=int(cfg['attention_bias']),
                    mlp_bias=True,
                    expert_router_bias=True,
                    expert_num=cfg['num_local_experts'],
                    expert_inter_size=cfg['intermediate_size'],
                    experts_per_token=cfg['experts_per_token'],
                    norm_topk_prob=True,
                    inter_size=0,
                    window_size=per_layer_window,
                    attn_sink=True,
                    activation_type='gpt-oss')
        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/internlm2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import re

import torch

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class InternLM2Reader(LlamaReader):
    """InternLM2 model reader.

    InternLM2 stores q/k/v in a single fused ``attention.wqkv`` tensor. For
    every kv head it holds ``gs`` query heads followed by one key head and one
    value head, where ``gs = num_attention_heads // num_key_value_heads``.
    """

    attn_layer_prefix = 'model.layers'
    attn_layer_patten = r'model\.layers\.([0-9]+).'
    tok_embeddings_key = 'model.tok_embeddings.weight'
    norm_weight_key = 'model.norm.weight'
    output_weight_key = 'output.weight'

    attn_pattern = r'attention'
    ffn_pattern = r'feed_forward'

    proj_pattern = 'w'

    def _qkv_layout(self):
        """Return ``(kv_head_num, gs, head_dim)`` describing the fused ``wqkv``.

        ``head_dim`` is taken from an explicit ``head_dim`` entry if present,
        otherwise derived as ``hidden_size // num_attention_heads``.
        """
        head_num = self.model_cfg['num_attention_heads']
        kv_head_num = self.model_cfg['num_key_value_heads']
        gs = head_num // kv_head_num
        head_dim = self.model_cfg.get('head_dim') or self.model_cfg['hidden_size'] // head_num
        return kv_head_num, gs, head_dim

    def filter(self, pattern: str, i: int | None):
        """Return the parameter names matching ``pattern``.

        With ``fp8_quant`` enabled and ``pattern == self.attn_pattern``, the
        fused ``wqkv`` weight of layer ``i`` is additionally block-quantized to
        fp8, q/k/v separately. The quantized tensor is re-packed into
        ``self.params`` in place. The new per-part ``w{q,k,v}.weight_scale_inv``
        keys are stored in ``self.params`` and appended to the result.
        """
        params = [k for k in self.params.keys() if re.search(pattern, k)]
        if not (self.fp8_quant and pattern == self.attn_pattern):
            return params

        from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
        qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wqkv.weight')
        if qkv is None:
            return params

        kv_head_num, gs, head_dim = self._qkv_layout()
        qkv = qkv.view(kv_head_num, gs + 2, head_dim, -1)
        hidden_dim = qkv.shape[-1]
        split_sizes = [gs, 1, 1]
        prefix = f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}'
        quantized = []
        for key, part, split_size in zip(('q', 'k', 'v'), torch.split(qkv, split_sizes, dim=1), split_sizes):
            qweight, scale = quant_blocked_fp8(part.reshape(-1, hidden_dim), torch.float8_e4m3fn, block_size=128)
            quantized.append(qweight.reshape(kv_head_num, split_size, head_dim, -1))
            scale_key = f'{prefix}.w{key}.weight_scale_inv'
            self.params[scale_key] = scale
            params.append(scale_key)

        # Re-pack in the original interleaved layout so `_attn` can split it again.
        self.params[f'{prefix}.wqkv.weight'] = torch.cat(quantized, dim=1).reshape(-1, hidden_dim)
        return params

    def _attn(self, i: int, kind: str):
        """Get q, k, v, o kind for layer i.

        For ``weight_scale_inv`` under fp8, the scales stored under
        ``w{q,k,v,o}.weight_scale_inv`` are returned. Otherwise the fused
        ``wqkv`` tensor is split into q, k and v; these are None when ``wqkv``
        has no entry for ``kind``.
        """
        if self.fp8_quant and kind == 'weight_scale_inv':
            result = []
            for key in ['q', 'k', 'v', 'o']:
                tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}.w{key}.{kind}')
                result.append(self.transform(tensor, kind))
            return (*result, )

        q = k = v = None
        qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wqkv.{kind}')
        qkv = self.transform(qkv, kind)
        if qkv is not None:
            kv_head_num, gs, head_dim = self._qkv_layout()
            qkv = qkv.view(kv_head_num, gs + 2, head_dim, -1)
            hidden_dim = qkv.shape[-1]
            q, k, v = (t.reshape(-1, hidden_dim) for t in torch.split(qkv, [gs, 1, 1], dim=1))
        o = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wo.{kind}')
        o = self.transform(o, kind)
        return (q, k, v, o)

    def attn_norm(self, i: int):
        """Get attn norm for layer i."""
        return self.params[f'{self.attn_layer_prefix}.{i}.attention_norm.weight']

    def _ffn(self, i: int, kind: str):
        """Get ffn kind for layer i as ``(w1, w2, w3)``."""
        if not kind:
            return self.filter(self.ffn_pattern, i)
        result = []
        for key in ['w1', 'w2', 'w3']:
            tensor = self.params[f'{self.attn_layer_prefix}.{i}.feed_forward.{key}.{kind}']
            tensor = self.transform(tensor, kind)
            result.append(tensor)
        return (*result, )

    def ffn_norm(self, i: int):
        """Get ffn norm for layer i."""
        return self.params[f'{self.attn_layer_prefix}.{i}.ffn_norm.weight']


@INPUT_MODELS.register_module(name='internlm2')
class InternLM2Model(LlamaModel):
    """InternLM2 model in hf format.

    Hyper-parameters are parsed by ``LlamaModel.model_info``; only the weight
    reader (fused ``wqkv``, ``feed_forward.w1/w2/w3``) differs.
    """

    Reader = InternLM2Reader


================================================
FILE: lmdeploy/turbomind/deploy/source_model/internvl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .base import INPUT_MODELS
from .gpt_oss import GptOssReader
from .internlm2 import InternLM2Reader
from .llama import LlamaModel, LlamaReader
from .qwen import Qwen3MoeReader, Qwen3Reader


class InternVLReader(LlamaReader):
    """InternVL reader for llama-style language towers (llama / qwen2).

    The keys match ``LlamaReader`` but are nested under ``language_model.``.
    """

    output_weight_key = 'language_model.lm_head.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    attn_layer_prefix = 'language_model.model.layers'


# Unlike the llama-style readers, InternLM2 uses ``tok_embeddings`` (not
# ``embed_tokens``) and ``output`` (not ``lm_head``) in its key names.
class InternVL2Reader(InternLM2Reader):
    """InternVL reader for an InternLM2 language tower under ``language_model.``."""

    output_weight_key = 'language_model.output.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    tok_embeddings_key = 'language_model.model.tok_embeddings.weight'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    attn_layer_prefix = 'language_model.model.layers'


class InternVL3d5Reader(Qwen3Reader):
    """InternVL3.5 reader for a dense Qwen3 language tower under ``language_model.``."""

    output_weight_key = 'language_model.lm_head.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    attn_layer_prefix = 'language_model.model.layers'


class InternVL3d5Qwen3MoEReader(Qwen3MoeReader):
    """InternVL3.5 reader for a Qwen3-MoE language tower under ``language_model.``."""

    output_weight_key = 'language_model.lm_head.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    attn_layer_prefix = 'language_model.model.layers'


class InternVL3d5GptOSSReader(GptOssReader):
    """InternVL3.5 reader for a gpt-oss language tower under ``language_model.``."""

    output_weight_key = 'language_model.lm_head.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    attn_layer_prefix = 'language_model.model.layers'


class InternS1Reader(Qwen3MoeReader):
    """InternS1 reader (Qwen3-MoE language tower).

    Decoder layers sit under ``model.language_model.``, while ``lm_head`` is at
    the top level of the checkpoint.
    """

    output_weight_key = 'lm_head.weight'
    norm_weight_key = 'model.language_model.norm.weight'
    tok_embeddings_key = 'model.language_model.embed_tokens.weight'
    attn_layer_patten = r'model\.language_model\.layers\.([0-9]+).'
    attn_layer_prefix = 'model.language_model.layers'


class InternS1MiniReader(Qwen3Reader):
    """InternS1-mini reader (dense Qwen3 language tower).

    It uses the same key layout as ``InternS1Reader``: layers under
    ``model.language_model.`` and a top-level ``lm_head``.
    """

    output_weight_key = 'lm_head.weight'
    norm_weight_key = 'model.language_model.norm.weight'
    tok_embeddings_key = 'model.language_model.embed_tokens.weight'
    attn_layer_patten = r'model\.language_model\.layers\.([0-9]+).'
    attn_layer_prefix = 'model.language_model.layers'


@INPUT_MODELS.register_module(name='internvl')
class InternVLModel(LlamaModel):
    """InternVL / InternS1 model in hf format.

    The architecture of the language sub-model selects two things: the source
    model whose ``model_info`` is reused, and the reader class that maps
    checkpoint keys.
    """

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
        super().__init__(model_path, tokenizer_path, **kwargs)
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        arch = config.architectures[0]
        if arch in ('InternVLChatModel', 'InternVLForConditionalGeneration'):
            relations = dict(InternLM2ForCausalLM=('internlm2', InternVL2Reader),
                             LlamaForCausalLM=('llama', InternVLReader),
                             Qwen2ForCausalLM=('qwen2', InternVLReader),
                             Qwen3MoeForCausalLM=('qwen3-moe', InternVL3d5Qwen3MoEReader),
                             Qwen3ForCausalLM=('qwen3', InternVL3d5Reader),
                             GptOssForCausalLM=('gpt-oss', InternVL3d5GptOSSReader))
        elif arch == 'InternS1ForConditionalGeneration':
            relations = dict(Qwen3MoeForCausalLM=('qwen3-moe', InternS1Reader),
                             Qwen3ForCausalLM=('qwen3', InternS1MiniReader))
        else:
            # f-string: previously the literal text '{arch}' was reported.
            raise ValueError(f'unsupported model arch {arch}')
        # The language sub-config is stored as either `llm_config` or `text_config`.
        self.llm_config = getattr(config, 'llm_config', None) or getattr(config, 'text_config', None)
        arch = self.llm_config.architectures[0]
        llm_model, self.Reader = relations[arch]
        self.llm_model = INPUT_MODELS.get(llm_model)(model_path=model_path, tokenizer_path=tokenizer_path, **kwargs)

    def model_info(self):
        """Read model info by delegating to the language sub-model."""
        self.llm_model.model_config = self.llm_config.to_dict()
        return self.llm_model.model_info()


================================================
FILE: lmdeploy/turbomind/deploy/source_model/llama.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
import re

import torch

from lmdeploy.archs import get_model_arch

from ..config import RopeParam
from ..loader import create_loader
from .base import INPUT_MODELS, BaseInputModel, BaseReader


class LlamaReader(BaseReader):
    """LlamaReader.

    Reader for HF llama-style checkpoints. Subclasses override the key names
    and regex patterns below to adapt to other layouts.
    """

    attn_layer_prefix = 'model.layers'
    attn_layer_patten = r'model\.layers\.([0-9]+).'
    tok_embeddings_key = 'model.embed_tokens.weight'
    norm_weight_key = 'model.norm.weight'
    output_weight_key = 'lm_head.weight'

    attn_pattern = r'self_attn'
    ffn_pattern = r'mlp'

    proj_pattern = 'proj'
    scale_inv_suffix = '_scale_inv'

    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, policy, fp8_quant=False):
        """
        Args:
            new_params: tensors to read, merged into ``unused_params``.
            unused_params: dict that becomes ``self.params``; it is updated in
                place with ``new_params``.
            last_bin: stored as ``self.last_bin``.
            model_cfg: HF config dict for the language model.
            policy: callable ``(tensor, kind) -> tensor`` used by ``_transform``.
            fp8_quant: when True, projection weights are block-quantized to fp8
                on construction.
        """
        super().__init__()
        self.params = unused_params
        self.params.update(new_params)
        self.last_bin = last_bin
        self.model_cfg = model_cfg
        tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False)
        if tie_word_embeddings:
            # tied models have no separate lm_head tensor; reuse the embeddings
            self.output_weight_key = self.tok_embeddings_key
        self.processor = policy
        self.fp8_quant = fp8_quant
        if self.fp8_quant:
            quant_params = self.quant_weight_fp8()
            self.params.update(quant_params)

    def quant_weight_fp8(self):
        """Block-quantize attention/FFN projection weights to fp8 (e4m3).

        Returns:
            dict: for every matching ``*.weight`` name, the quantized weight under
            the same name, plus its scale (cast to the original dtype) under
            ``<name><scale_inv_suffix>``.
        """
        from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
        pattern_str = fr'({self.attn_pattern}|{self.ffn_pattern}).*{self.proj_pattern}.*\.weight'
        target_pattern = re.compile(pattern_str)

        # InternLM2-style readers quantize the fused wqkv tensor themselves (split
        # per q/k/v) in their `filter`, so it must be skipped here to avoid
        # quantizing it twice. Walk the MRO so that subclasses such as
        # InternVL2Reader are covered too. Matching by name avoids a circular import.
        is_internlm2 = any(cls.__name__ == 'InternLM2Reader' for cls in type(self).__mro__)
        skip_pattern = re.compile(r'wqkv.*\.weight') if is_internlm2 else None

        quant_params = {}
        for name, weight in self.params.items():
            # `search` is unanchored, so also require the name to end in '.weight'
            # (this excludes e.g. '*.weight_scale_inv').
            if target_pattern.search(name) and name.endswith('.weight'):
                if skip_pattern and skip_pattern.search(name):
                    continue
                q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)
                quant_params[name] = q_weight
                quant_params[f'{name}{self.scale_inv_suffix}'] = scale.to(weight.dtype)

        return quant_params

    def filter(self, pattern: str, i: int | None):
        """Return every parameter name matching ``pattern`` (``i`` is unused here)."""
        return [k for k in self.params.keys() if re.search(pattern, k)]

    def tok_embeddings(self):
        """Get embeddings."""
        return self.transform(self.params.get(self.tok_embeddings_key, None), 'weight')

    def norm_weight(self):
        """Get norm."""
        return self.transform(self.params.get(self.norm_weight_key, None), 'weight')

    def output_weight(self):
        """Get output."""
        return self.transform(self.params.get(self.output_weight_key, None), 'weight')

    def _transform(self, x: torch.Tensor, kind: str):
        return self.processor(x, kind)

    def _attn(self, i: int, kind: str):
        """Get q, k, v, o kind for layer i."""
        result = []
        for key in ['q', 'k', 'v', 'o']:
            tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.{key}_proj.{kind}')
            tensor = self.transform(tensor, kind)
            result.append(tensor)
        return (*result, )

    def attn(self, i: int, kind: str):
        """Return attention tensors of layer ``i``, or matching names if ``kind`` is empty."""
        if not kind:
            return self.filter(self.attn_pattern, i)
        return self._attn(i, kind)

    def attn_norm(self, i: int):
        """Get attn norm for layer i."""
        return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.input_layernorm.weight'], 'weight')

    def _ffn(self, i: int, kind: str):
        """Get ffn kind for layer i as ``(gate, down, up)``."""
        if not kind:
            return self.filter(self.ffn_pattern, i)
        result = []
        for key in ['gate', 'down', 'up']:
            tensor = self.params[f'{self.attn_layer_prefix}.{i}.mlp.{key}_proj.{kind}']
            tensor = self.transform(tensor, kind)
            result.append(tensor)
        return (*result, )

    def ffn(self, i: int, kind: str):
        """Return FFN tensors of layer ``i``, or matching names if ``kind`` is empty."""
        if not kind:
            return self.filter(self.ffn_pattern, i)
        return self._ffn(i, kind)

    def ffn_norm(self, i: int):
        """Get ffn norm for layer i."""
        return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.post_attention_layernorm.weight'], 'weight')


@INPUT_MODELS.register_module(name='llama')
class LlamaModel(BaseInputModel):
    """Llama model in hf format.

    Many other source models subclass this one and reuse ``readers`` and
    ``model_info``, overriding ``Reader`` to adapt checkpoint key names.
    """

    Reader = LlamaReader

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
        """
        Args:
            model_path: checkpoint directory.
            tokenizer_path: tokenizer directory.
            **kwargs: ``input_policy`` (passed to readers as ``policy``) and
                ``fp8_quant`` (default False) are read here.
        """
        super().__init__(model_path, tokenizer_path)
        self.policy = kwargs.get('input_policy')
        _, model_config = get_model_arch(model_path)
        # Multimodal configs: keep only the language sub-config.
        if hasattr(model_config, 'text_config'):
            model_config = model_config.text_config
        elif hasattr(model_config, 'llm_config'):
            model_config = model_config.llm_config
        # Normalize to a plain dict when the config object supports it.
        if hasattr(model_config, 'to_dict'):
            self.model_config = model_config.to_dict()
        else:
            self.model_config = model_config
        self.fp8_quant = kwargs.get('fp8_quant', False)

    def readers(self):
        """Yield ``(i, reader)`` for each item produced by the checkpoint loader.

        ``i`` is the key returned by ``loader.items()``; presumably the layer
        index matched by ``Reader.attn_layer_patten`` (TODO confirm against
        ``create_loader``). Each reader gets a fresh empty ``unused_params``
        dict and ``last_bin=False``.
        """
        mappings = getattr(self.Reader, 'mappings', [])
        loader = create_loader(self.model_path, self.Reader.attn_layer_patten, mappings)
        for i, param in loader.items():
            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant)
            yield i, reader
        # Runs only once the generator is fully consumed.
        torch.cuda.empty_cache()

    def model_info(self):
        """Read model info.

        Returns:
            dict: ``size_per_head``, ``num_layer``, ``norm_eps``, ``head_num``,
            ``kv_head_num``, ``hidden_units``, ``inter_size``, ``vocab_size``,
            ``max_position_embeddings`` and ``rope_param``.

        Raises:
            RuntimeError: if ``rope_scaling`` names an unsupported rope type.
        """
        model_arg = self.model_config
        num_layer = model_arg['num_hidden_layers']
        norm_eps = model_arg['rms_norm_eps']
        attn_head_num = model_arg['num_attention_heads']
        vocab_size = model_arg['vocab_size']
        inter_size = model_arg.get('intermediate_size', 0)
        if 'num_key_value_heads' in model_arg:
            kv_head_num = model_arg['num_key_value_heads']
        else:
            kv_head_num = model_arg['num_attention_heads']
        hidden_units = model_arg['hidden_size']
        # head_dim could be none in config
        head_dim = model_arg.get('head_dim', None)
        head_dim = head_dim or hidden_units // attn_head_num
        # compute rope param
        if 'rope_parameters' in model_arg:
            # transformers v5.0.0 aggregates rope settings into rope_parameters
            rope_scaling = model_arg['rope_parameters']
            rope_theta = float(rope_scaling.get('rope_theta', 10000.0))
        else:
            rope_theta = float(model_arg.get('rope_theta', 10000.0))
            rope_scaling = model_arg.get('rope_scaling', None)
        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
        # Start from plain rope; the branches below specialize it.
        rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
        if isinstance(rope_scaling, dict):
            # 'rope_type' is preferred; 'type' is the older key name.
            rope_type = rope_scaling.get('rope_type', '') or rope_scaling.get('type', '')
            if rope_scaling.get('mrope_section') is not None:
                # TODO: treat mrope as an option to the common rope functions
                rope_type = 'mrope'
            scaling_factor = rope_scaling.get('factor', 0.0)
            if rope_type == 'default':
                pass
            elif rope_type == 'dynamic':
                rope_param.type = 'dynamic'
                rope_param.factor = scaling_factor
                rope_param.max_position_embeddings = max_position_embeddings
            elif rope_type == 'linear':
                rope_param.type = 'linear'
                rope_param.factor = scaling_factor
            elif rope_type == 'llama3':
                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
                high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
                original_max_position_embeddings = rope_scaling.get('original_max_position_embeddings', 0)
                rope_param.type = 'llama3'
                rope_param.factor = scaling_factor
                rope_param.low_freq_factor = low_freq_factor
                rope_param.high_freq_factor = high_freq_factor
                rope_param.original_max_position_embeddings = original_max_position_embeddings
            elif rope_type == 'yarn':
                attention_factor = rope_scaling.get('attention_factor', None)
                if attention_factor is None:
                    # Default mscale 0.1 * ln(factor) + 1. It uses the configured
                    # factor, not the ratio possibly recomputed below.
                    # NOTE(review): math.log raises if factor <= 0 (e.g. missing
                    # 'factor'); confirm configs always provide it.
                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
                beta_fast = rope_scaling.get('beta_fast', 32.0)
                beta_slow = rope_scaling.get('beta_slow', 1.0)
                rope_param.type = 'yarn'
                if 'original_max_position_embeddings' in rope_scaling:
                    # The effective factor is the context-length ratio, overriding 'factor'.
                    original_max_position_embeddings = rope_scaling['original_max_position_embeddings']
                    scaling_factor = max_position_embeddings / original_max_position_embeddings
                else:
                    original_max_position_embeddings = max_position_embeddings
                rope_param.factor = scaling_factor
                rope_param.max_position_embeddings = original_max_position_embeddings
                rope_param.attention_factor = attention_factor
                rope_param.beta_fast = beta_fast
                rope_param.beta_slow = beta_slow
            elif rope_type == 'mrope':
                mrope_section = rope_scaling.get('mrope_section')
                rope_param.type = 'mrope'
                rope_param.mrope_section = mrope_section
            else:
                raise RuntimeError(f'Unsupported rope type: {rope_type}')

        return dict(size_per_head=head_dim,
                    num_layer=num_layer,
                    norm_eps=norm_eps,
                    head_num=attn_head_num,
                    kv_head_num=kv_head_num,
                    hidden_units=hidden_units,
                    inter_size=inter_size,
                    vocab_size=vocab_size,
                    max_position_embeddings=max_position_embeddings,
                    rope_param=rope_param)


================================================
FILE: lmdeploy/turbomind/deploy/source_model/llava.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp

from ..config import RopeParam
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class LlavaReader(LlamaReader):
    """LlavaReader for llama model.

    The language-model tensors are nested under ``language_model.``.
    """

    attn_layer_prefix = 'language_model.model.layers'
    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
    tok_embeddings_key = 'language_model.model.embed_tokens.weight'
    norm_weight_key = 'language_model.model.norm.weight'
    output_weight_key = 'language_model.lm_head.weight'

    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, policy, **kwargs):
        """
        Args:
            model_cfg: either the full llava config dict (with a ``text_config``
                entry) or the language config dict itself.
            policy: weight transform policy forwarded to ``LlamaReader``.
            **kwargs: forwarded to ``LlamaReader``. This includes ``fp8_quant``,
                which ``LlamaModel.readers`` always passes and which previously
                raised TypeError here.
        """
        # LlamaModel.__init__ already unwraps `text_config`, in which case there is
        # nothing left to unwrap and the given dict is the language config.
        text_cfg = model_cfg.get('text_config') or model_cfg
        super().__init__(new_params, unused_params, last_bin, text_cfg, policy, **kwargs)


@INPUT_MODELS.register_module(name='llava')
class LlavaModel(LlamaModel):
    """LlavaModel model in hf format.

    Only the language model is converted. Its hyper-parameters are read from
    ``text_config`` in ``config.json``.
    """

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
        super().__init__(model_path, tokenizer_path, **kwargs)
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        config = getattr(config, 'text_config', config)
        arch = config.architectures[0]
        _readers = dict(Qwen2ForCausalLM=LlavaReader, LlamaForCausalLM=LlavaReader)
        if arch not in _readers:
            # Same exception type as the bare dict lookup, with a readable message.
            raise KeyError(f'Unsupported llava language model architecture: {arch}. '
                           f'Supported: {sorted(_readers)}')
        self.Reader = _readers[arch]
        self.arch = arch

    def model_info(self):
        """Read model info for LlavaForConditionalGeneration.

        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf

        Returns:
            dict: language-model hyper-parameters, including a ``RopeParam``.
        """
        params_path = osp.join(self.model_path, 'config.json')
        with open(params_path) as f:
            model_arg = json.load(f)['text_config']

        num_layer = model_arg.get('num_hidden_layers', 32)
        norm_eps = model_arg.get('rms_norm_eps', 1e-6)
        attn_head_num = model_arg.get('num_attention_heads', 32)
        # Without an explicit kv head count, every query head has its own kv head.
        kv_head_num = model_arg.get('num_key_value_heads', attn_head_num)
        rope_theta = float(model_arg.get('rope_theta', 10000.0))
        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
        rope_scaling = model_arg.get('rope_scaling', None)
        scaling_factor = 0.0
        scaling_type = 'default'

        # special for the model: llava-hf/llava-interleave-qwen-7b-hf
        hidden_units = model_arg.get('hidden_size', 4096)
        vocab_size = model_arg.get('vocab_size', 152000)
        intermediate_size = model_arg.get('intermediate_size', 11008)
        # Qwen2 language models get q/k/v bias by default; an explicit
        # ``attn_bias`` entry in the config overrides that default.
        attn_bias = 1 if model_arg['architectures'][0] == 'Qwen2ForCausalLM' else 0
        attn_bias = int(model_arg.get('attn_bias', attn_bias))
        use_logn_attn = int(model_arg.get('use_logn_attn', 0))

        if isinstance(rope_scaling, dict):
            # HF configs name the scaling kind either ``type`` (legacy) or
            # ``rope_type``. Missing keys keep the defaults above instead of
            # becoming empty strings.
            scaling_type = rope_scaling.get('type', rope_scaling.get('rope_type', scaling_type))
            scaling_factor = rope_scaling.get('factor', scaling_factor)

        rope_param = RopeParam(type=scaling_type,
                               base=rope_theta,
                               dim=hidden_units // attn_head_num,
                               max_position_embeddings=max_position_embeddings,
                               factor=scaling_factor)

        return dict(num_layer=num_layer,
                    norm_eps=norm_eps,
                    head_num=attn_head_num,
                    hidden_units=hidden_units,
                    kv_head_num=kv_head_num,
                    rope_param=rope_param,
                    max_position_embeddings=max_position_embeddings,
                    inter_size=intermediate_size,
                    use_logn_attn=use_logn_attn,
                    attn_bias=attn_bias,
                    vocab_size=vocab_size)


================================================
FILE: lmdeploy/turbomind/deploy/source_model/minicpmv.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import json
import os.path as osp

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class MiniCPMVReader(LlamaReader):
    """Reader for the llama-style LLM of a MiniCPM-V checkpoint.

    Every language-model weight lives under the ``llm.`` prefix.
    """

    # Per-layer attention / ffn weights
    attn_layer_prefix = 'llm.model.layers'
    attn_layer_patten = r'llm\.model\.layers\.([0-9]+).'

    # Non-layer weights: embeddings, final norm and output head
    tok_embeddings_key = 'llm.model.embed_tokens.weight'
    norm_weight_key = 'llm.model.norm.weight'
    output_weight_key = 'llm.lm_head.weight'


@INPUT_MODELS.register_module(name='minicpmv')
class MiniCPMVModel(LlamaModel):
    """MiniCPM-V model in hf format (llama-style LLM under ``llm.``)."""
    Reader = MiniCPMVReader

    def model_info(self):
        """Return the llama model info, enabling attention bias for version 2.6."""
        info = super().model_info()
        config_path = osp.join(self.model_path, 'config.json')
        with open(config_path) as f:
            version = str(json.load(f).get('version'))
        if version == '2.6':
            info['attn_bias'] = True
        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/mixtral.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class MixtralReader(LlamaReader):
    """Reader for Mixtral sparse-MoE checkpoints (``block_sparse_moe`` weights)."""

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Get the ``(w1, w2, w3)`` tensors of expert ``e`` in layer ``i``.

        With an empty ``kind``, this returns ``self.filter(r'experts', i)`` instead.
        Tensors are looked up with ``.get`` and passed through ``transform``.
        """
        if not kind:
            return self.filter(r'experts', i)
        # Use the reader's layer prefix (like Qwen2MoeReader) rather than a
        # hard-coded 'model.layers', so subclasses with another prefix work.
        prefix = f'{self.attn_layer_prefix}.{i}.block_sparse_moe.experts.{e}'
        result = []
        for x in ['w1', 'w2', 'w3']:
            tensor = self.params.get(f'{prefix}.{x}.{kind}')
            tensor = self.transform(tensor, kind)
            result.append(tensor)
        return (*result, )

    def moe_ffn_gate(self, i, kind):
        """Get the router (gate) tensor of layer ``i``, returned untransformed."""
        return self.params.get(f'{self.attn_layer_prefix}.{i}.block_sparse_moe.gate.{kind}')


@INPUT_MODELS.register_module(name='mixtral')
class MixtralModel(LlamaModel):
    """Mixtral model in hf format: llama attention plus routed experts only."""

    Reader = MixtralReader

    def model_info(self):
        """Extend the llama model info with the Mixtral MoE settings."""
        cfg = self.model_config
        info = super().model_info()
        info.update(
            expert_num=cfg['num_local_experts'],
            # each expert's hidden size is the config's intermediate_size
            expert_inter_size=cfg['intermediate_size'],
            experts_per_token=cfg['num_experts_per_tok'],
            norm_topk_prob=True,
            # no dense / shared ffn next to the experts
            inter_size=0,
        )
        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/molmo.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp

import torch

from ..config import RopeParam
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class MolmoReader(LlamaReader):
    """Reader for the language-model weights of Molmo checkpoints."""

    attn_layer_prefix = 'model.transformer.blocks'
    attn_layer_patten = r'model\.transformer\.blocks\.([0-9]+).'
    norm_weight_key = 'model.transformer.ln_f.weight'
    output_weight_key = 'model.transformer.ff_out.weight'

    # In molmo, names of attention parameters are "att_proj.bias",
    # "att_proj.weight", "attn_norm.weight", "attn_out.weight", and names
    # of ffn parameters are "ff_norm", "ff_out", "ff_proj", so the patterns
    # are r'att' and r'ff_', respectively.
    attn_pattern = r'att'
    ffn_pattern = r'ff_'

    def tok_embeddings(self):
        """Return the base embedding concatenated with the extra embedding rows.

        Both tensors must be present together (or absent together, in which case
        ``None`` is returned).
        """
        embed1 = self.params.get('model.transformer.wte.embedding', None)
        embed2 = self.params.get('model.transformer.wte.new_embedding', None)
        if embed1 is not None and embed2 is not None:
            return torch.cat((embed1, embed2), dim=0)
        else:
            assert embed1 is None and embed2 is None
            return None

    def attn_norm(self, i: int):
        """Get attn norm for layer i."""
        return self.params[f'{self.attn_layer_prefix}.{i}.attn_norm.weight']

    def _attn(self, i: int, kind: str):
        """Get q, k, v, o kind(weight, bias, qweight) for layer i.

        ``att_proj`` fuses q, k and v along dim 0 with sizes
        ``(hidden_size, kv_heads * head_dim, kv_heads * head_dim)``.

        Args:
            i (int): layer id
            kind (str): can be one of ["weight", "bias", "qweight"]
        """
        q, k, v = (None, ) * 3
        hidden_size = self.model_cfg['hidden_size']
        head_num = self.model_cfg['num_attention_heads']
        kv_head_num = self.model_cfg['num_key_value_heads']
        head_dim = hidden_size // head_num
        assert head_dim == 128
        fused_dims = (hidden_size, kv_head_num * head_dim, kv_head_num * head_dim)
        qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.att_proj.{kind}')
        qkv = self.transform(qkv, kind)
        if qkv is not None:
            q, k, v = qkv.split(fused_dims, dim=0)
        o = self.params.get(f'{self.attn_layer_prefix}.{i}.attn_out.{kind}')
        o = self.transform(o, kind)
        if o is None and q is not None:
            # handle the case when qkv has bias but o doesn't; without the
            # ``q is not None`` guard, a kind missing from both would make
            # torch.zeros_like(None) raise TypeError.
            o = torch.zeros_like(q)
        return (q, k, v, o)

    def _ffn(self, i: int, kind: str):
        """Get ffn kind(weight, qweight) for layer i.

        ``ff_proj`` holds two projections stacked along dim 0. They are split
        in half and returned as ``(up, down, gate)`` using the local names below.
        """
        up_and_gate = self.params[f'{self.attn_layer_prefix}.{i}.ff_proj.{kind}']
        up_and_gate = self.transform(up_and_gate, kind)
        gate, up = up_and_gate.chunk(2, dim=0)
        down = self.params[f'{self.attn_layer_prefix}.{i}.ff_out.{kind}']
        down = self.transform(down, kind)
        return (up, down, gate)

    def ffn_norm(self, i: int):
        """Get ffn norm for layer i."""
        return self.params[f'{self.attn_layer_prefix}.{i}.ff_norm.weight']


@INPUT_MODELS.register_module(name='molmo')
class MolmoModel(LlamaModel):
    """Molmo model in hf format; hyper-parameters come from ``config.json``."""

    Reader = MolmoReader

    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
        super().__init__(model_path, tokenizer_path, **kwargs)
        # Keep the raw config; model_info() reads every field from it.
        with open(osp.join(self.model_path, 'config.json')) as f:
            self.config = json.load(f)

    def model_info(self):
        """Translate the Molmo config into turbomind model info."""
        cfg = self.config
        num_layer = cfg['num_hidden_layers']
        norm_eps = cfg['layer_norm_eps']
        head_num = cfg['num_attention_heads']
        kv_heads = cfg['num_key_value_heads']
        hidden = cfg['hidden_size']
        theta = cfg['rope_theta']
        max_pos = cfg['max_position_embeddings']
        vocab = cfg['vocab_size']
        # Extra embedding rows appended after the base vocabulary, see
        # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L2041
        extra_vocab = 128
        # ff_proj stacks two projections, so each one is half of intermediate_size
        ffn_size = cfg['intermediate_size'] // 2
        qkv_bias = cfg['qkv_bias']
        rope = RopeParam(type='default', base=theta, dim=hidden // head_num)
        return dict(
            num_layer=num_layer,
            norm_eps=norm_eps,
            head_num=head_num,
            kv_head_num=kv_heads,
            hidden_units=hidden,
            attn_bias=int(qkv_bias),
            inter_size=ffn_size,
            vocab_size=vocab,
            # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L564
            embedding_size=vocab + extra_vocab,
            rope_param=rope,
            max_position_embeddings=max_pos,
        )


================================================
FILE: lmdeploy/turbomind/deploy/source_model/qwen.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import re

import torch

from ..config import RopeParam
from ..loader import create_loader
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class QwenReader(LlamaReader):
    """QwenReader for Qwen (v1) checkpoints with ``transformer.h.*`` weight names."""

    attn_layer_patten = r'transformer\.h\.([0-9]+).'
    tok_embeddings_key = 'transformer.wte.weight'
    norm_weight_key = 'transformer.ln_f.weight'
    output_weight_key = 'lm_head.weight'

    attn_pattern = r'attn'
    ffn_pattern = r'mlp'

    def _attn(self, i: int, kind: str):
        """Get q, k, v, o kind for layer i.

        ``c_attn`` stores q, k and v fused along dim 0 in three equal parts.
        """
        q, k, v, o = (None, ) * 4
        qkv = self.params[f'transformer.h.{i}.attn.c_attn.{kind}']
        qkv = self.transform(qkv, kind)
        if qkv is not None:
            q, k, v = torch.split(qkv, qkv.size(0) // 3, dim=0)
        o = self.params.get(f'transformer.h.{i}.attn.c_proj.{kind}')
        o = self.transform(o, kind)
        if o is None and q is not None:
            # e.g. c_attn has a bias but c_proj does not: export a zero o tensor.
            # The ``q is not None`` guard avoids torch.zeros_like(None).
            o = torch.zeros_like(q)
        return q, k, v, o

    def attn_norm(self, i: int):
        """Get attn norm for layer i."""
        return self.params[f'transformer.h.{i}.ln_1.weight']

    def _ffn(self, i: int, kind: str):
        """Get ffn kind for layer i, in the order (w2, c_proj, w1)."""
        result = []
        for key in ['w2', 'c_proj', 'w1']:
            tensor = self.params[f'transformer.h.{i}.mlp.{key}.{kind}']
            tensor = self.transform(tensor, kind)
            result.append(tensor)
        return (*result, )

    def ffn_norm(self, i: int):
        """Get ffn norm for layer i."""
        return self.params[f'transformer.h.{i}.ln_2.weight']


@INPUT_MODELS.register_module(name='qwen')
class QwenModel(LlamaModel):
    """Qwen (v1) model in hf format."""

    Reader = QwenReader

    def model_info(self):
        """Read model info from the Qwen (v1) ``config.json``."""
        with open(osp.join(self.model_path, 'config.json')) as f:
            config = json.load(f)

        hidden_units = config['hidden_size']
        num_layer = config['num_hidden_layers']
        norm_eps = config['layer_norm_epsilon']
        head_dim = config['kv_channels']
        theta = float(config.get('rotary_emb_base', 10000.0))
        head_num = config['num_attention_heads']
        # Without an explicit kv head count, fall back to plain MHA.
        kv_heads = config.get('num_key_value_heads', head_num)
        seq_length = config['seq_length']
        dynamic_ntk = int(config['use_dynamic_ntk'])
        logn_attn = int(config['use_logn_attn'])
        vocab_size = config['vocab_size']
        inter_size = config['intermediate_size']

        # For dynamic NTK, rope_scaling_factor has to be supplied through
        # TurbomindEngineConfig; the factor is left at 0 here.
        rope = RopeParam(type='dynamic' if dynamic_ntk else 'default',
                         base=theta,
                         dim=head_dim,
                         max_position_embeddings=seq_length,
                         factor=0)

        return dict(size_per_head=head_dim,
                    num_layer=num_layer,
                    norm_eps=norm_eps,
                    hidden_units=hidden_units,
                    head_num=head_num,
                    kv_head_num=kv_heads,
                    vocab_size=vocab_size,
                    inter_size=inter_size,
                    attn_bias=1,
                    rope_param=rope,
                    max_position_embeddings=seq_length,
                    use_dynamic_ntk=dynamic_ntk,
                    use_logn_attn=logn_attn)


@INPUT_MODELS.register_module(name='qwen2')
class Qwen2Model(LlamaModel):
    """Qwen2 model in hf format.

    Qwen2 weights follow the Llama layout. The one difference is that the
    attention bias covers q/k/v only, with no o_proj bias; hence ``attn_bias = 1``.
    """

    Reader = LlamaReader

    def model_info(self):
        """Return the llama model info with attention bias switched on."""
        info = super().model_info()
        info.update(attn_bias=1)
        return info


class Qwen2MoeReader(LlamaReader):
    """Reader for Qwen2-MoE: routed experts plus one gated shared expert."""

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Return ``(gate, down, up)`` of routed expert ``e`` in layer ``i``.

        An empty ``kind`` returns the filtered expert parameter names instead.
        """
        if not kind:
            return self.filter(r'experts', i)
        base = f'{self.attn_layer_prefix}.{i}.mlp.experts.{e}'
        return tuple(
            self.transform(self.params.get(f'{base}.{proj}_proj.{kind}'), kind) for proj in ('gate', 'down', 'up'))

    def moe_ffn_gate(self, i, kind):
        """Return the transformed router weight of layer ``i``."""
        router = self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.{kind}')
        return self.transform(router, kind)

    def _ffn(self, i: int, kind: str):
        """Return ``(gate, down, up)`` of the shared expert in layer ``i``."""
        if not kind:
            return self.filter(r'shared_expert\.', i)
        base = f'{self.attn_layer_prefix}.{i}.mlp.shared_expert'
        return tuple(self.transform(self.params[f'{base}.{proj}_proj.{kind}'], kind) for proj in ('gate', 'down', 'up'))

    def ffn(self, i: int, kind: str):
        """Public ffn accessor; the dense ffn here is the shared expert."""
        if kind:
            return self._ffn(i, kind)
        return self.filter(r'shared_expert\.', i)

    def moe_ffn_shared_gate(self, i):
        """Return the raw (untransformed) shared-expert gate weight of layer ``i``."""
        return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.shared_expert_gate.weight')


@INPUT_MODELS.register_module(name='qwen2-moe')
class Qwen2MoeModel(LlamaModel):
    """Qwen2-MoE model in hf format (routed experts + gated shared expert)."""

    Reader = Qwen2MoeReader

    def model_info(self):
        """Extend the llama model info with the Qwen2-MoE settings."""
        cfg = self.model_config
        info = super().model_info()
        info.update(
            expert_num=cfg['num_experts'],
            expert_inter_size=cfg['moe_intermediate_size'],
            experts_per_token=cfg['num_experts_per_tok'],
            # the dense ffn size is that of the shared expert
            inter_size=cfg['shared_expert_intermediate_size'],
            moe_shared_gate=True,
            norm_topk_prob=cfg['norm_topk_prob'],
            attn_bias=cfg.get('qkv_bias', 1),
        )
        return info


class Qwen3Reader(LlamaReader):
    """Llama reader extended with Qwen3's per-head q/k RMSNorm weights."""

    def qk_norm(self, i: int):
        """Return the transformed ``(q_norm, k_norm)`` weights of layer ``i``."""
        attn = f'{self.attn_layer_prefix}.{i}.self_attn'
        q_norm = self.transform(self.params.get(f'{attn}.q_norm.weight'), 'weight')
        k_norm = self.transform(self.params.get(f'{attn}.k_norm.weight'), 'weight')
        return q_norm, k_norm


@INPUT_MODELS.register_module(name='qwen3')
class Qwen3Model(LlamaModel):
    """Qwen3 dense model in hf format (llama layout plus q/k norm)."""
    Reader = Qwen3Reader

    def model_info(self):
        """Return the llama model info with q/k norm and the configured attention bias."""
        cfg = self.model_config
        info = super().model_info()
        info['qk_norm'] = True
        info['attn_bias'] = cfg.get('attention_bias', 0)
        return info


class Qwen3MoeReader(Qwen2MoeReader):
    """Qwen2-MoE reader extended with Qwen3's per-head q/k RMSNorm weights."""

    def qk_norm(self, i: int):
        """Return the transformed ``(q_norm, k_norm)`` weights of layer ``i``."""
        attn = f'{self.attn_layer_prefix}.{i}.self_attn'
        q_norm = self.transform(self.params.get(f'{attn}.q_norm.weight'), 'weight')
        k_norm = self.transform(self.params.get(f'{attn}.k_norm.weight'), 'weight')
        return q_norm, k_norm


@INPUT_MODELS.register_module(name='qwen3-moe')
class Qwen3MoeModel(LlamaModel):
    """Qwen3-MoE model in hf format (routed experts only, q/k norm)."""
    Reader = Qwen3MoeReader

    def model_info(self):
        """Extend the llama model info with Qwen3-MoE settings (with defaults)."""
        cfg = self.model_config
        info = super().model_info()
        moe_fields = dict(
            qk_norm=True,
            expert_num=cfg.get('num_experts', 128),
            experts_per_token=cfg.get('num_experts_per_tok', 8),
            expert_inter_size=cfg.get('moe_intermediate_size', 768),
            attn_bias=cfg.get('attention_bias', 0),
            # there is no shared expert, so no dense ffn
            inter_size=0,
            norm_topk_prob=cfg.get('norm_topk_prob', False),
        )
        info.update(moe_fields)
        return info


class Qwen3_5ReaderMixin:
    """Mixin providing linear attention weight reading for Qwen3.5 models.

    Qwen3.5 uses a zero-centered RMSNorm: ``output = norm(x) * (1 + weight)``
    where weight is initialized to zeros.  TurboMind's RMSNorm kernel computes
    ``norm(x) * weight`` (standard LLaMA style), so we add 1 to every
    RMSNorm weight during export.  The GDN-internal norm
    (``Qwen3_5MoeRMSNormGated``) uses standard weight and is NOT affected.

    Several overrides below call ``super()`` for the actual tensor lookup, so
    the mixin is listed before a ``LlamaReader``-derived base class, e.g.
    ``class Qwen3_5Reader(Qwen3_5ReaderMixin, Qwen3Reader)``.
    """

    # Layer-index regex matching both ``model.language_model.layers.N.``
    # (multimodal checkpoints) and ``model.layers.N.`` (text-only checkpoints).
    # Note the spelling differs from the base attribute ``attn_layer_patten``;
    # Qwen3_5MoeModel.readers looks this one up first via getattr.
    attn_layer_pattern = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.'

    # Parameter names under ``linear_attn.``; linear_attn() returns one entry
    # per key, in exactly this order.
    _LINEAR_ATTN_KEYS = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Multimodal checkpoints nest the language model under
        # ``model.language_model.``; retarget the layer prefix and the
        # embedding / final-norm keys when such names are present.
        if any(k.startswith('model.language_model.') for k in self.params.keys()):
            self.attn_layer_prefix = 'model.language_model.layers'
            self.tok_embeddings_key = 'model.language_model.embed_tokens.weight'
            self.norm_weight_key = 'model.language_model.norm.weight'
        tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False)
        if tie_word_embeddings:
            # Tied embeddings: the output head reuses the input embedding weight.
            self.output_weight_key = self.tok_embeddings_key

    # ---- zero-centered RMSNorm: add 1 to weights during export ----
    def attn_norm(self, i: int):
        """Attention RMSNorm weight of layer ``i`` as float32, shifted by +1."""
        w = super().attn_norm(i)
        if w is not None:
            w = w.float() + 1.0
        return w

    def ffn_norm(self, i: int):
        """FFN RMSNorm weight of layer ``i`` as float32, shifted by +1."""
        w = super().ffn_norm(i)
        if w is not None:
            w = w.float() + 1.0
        return w

    def norm_weight(self):
        """Final RMSNorm weight as float32, shifted by +1."""
        w = super().norm_weight()
        if w is not None:
            w = w.float() + 1.0
        return w

    def qk_norm(self, i: int):
        """``(q_norm, k_norm)`` of layer ``i``, each shifted by +1 (None kept)."""
        result = super().qk_norm(i)
        return tuple(w.float() + 1.0 if w is not None else w for w in result)

    # ---- handle mixed QKV(fp16) + O(AWQ) attention layers -------

    def _attn(self, i: int, kind: str):
        """Override to handle mixed QKV(fp16) + O(AWQ) attention layers.

        Some AWQ-quantized Qwen3.5 models keep QKV in fp16 while quantizing only the O projection.  TurboMind requires
        uniform weight types per layer, so we dequantize O to fp16 at export time.
        """
        prefix = f'{self.attn_layer_prefix}.{i}.self_attn'
        # A layer is "mixed" when q_proj has a plain weight but o_proj only an AWQ qweight.
        q_is_fp16 = f'{prefix}.q_proj.weight' in self.params
        o_is_awq = f'{prefix}.o_proj.qweight' in self.params

        if not (q_is_fp16 and o_is_awq):
            # Not a mixed-format layer, use standard behaviour.
            return super()._attn(i, kind)

        # Mixed format detected: QKV are fp16 but O is AWQ.
        if kind == 'weight':
            # Get fp16 QKV the normal way, then dequantize O.
            q, k, v, _ = super()._attn(i, kind)
            o = self._awq_dequant(f'{prefix}.o_proj')
            o = self.transform(o, kind)
            return (q, k, v, o)

        # For any quant kind (qweight/scales/qzeros), return all None
        # so that the AWQ handler skips this layer entirely — the O
        # weight is already handled via dequantization above.
        # NOTE(review): every kind other than 'weight' reaches this line,
        # including 'bias', so any q/k/v/o biases of a mixed layer would not be
        # exported — confirm such checkpoints have no attention bias.
        return (None, None, None, None)

    def _awq_dequant(self, prefix: str):
        """Dequantize an AWQ-quantized linear layer to fp16.

        AWQ stores weights in transposed form relative to PyTorch's
        convention ([in, out] vs [out, in]), so we transpose here to
        match the fp16 ``.weight`` layout that downstream export
        expects.
        """
        from lmdeploy.pytorch.backends.default.awq_modules import dequantize_gemm
        qweight = self.params[f'{prefix}.qweight']
        scales = self.params[f'{prefix}.scales']
        qzeros = self.params[f'{prefix}.qzeros']
        # Group size is inferred from the row counts; this assumes one row of
        # scales per quantization group along dim 0 — TODO confirm for all
        # AWQ exporters.
        group_size = qweight.shape[0] // scales.shape[0]
        # NOTE(review): the literal 4 is presumably the AWQ bit width (4-bit).
        w = dequantize_gemm(qweight, qzeros, scales, 4, group_size)
        return w.t()  # [in, out] → [out, in] (PyTorch convention)

    def linear_attn(self, i: int, kind: str):
        """Get the linear-attention tensors of layer ``i`` for ``kind``.

        With an empty ``kind`` this returns the filtered ``linear_attn.``
        parameter names. Otherwise it returns one entry per
        ``_LINEAR_ATTN_KEYS`` key (``None`` where missing), or an empty tuple
        when none of them exist.
        """
        if not kind:
            return self.filter(r'linear_attn\.', i)
        # Always return a fixed-length tuple with None placeholders to
        # preserve positional alignment with the name list in module.py.
        result = []
        for key in self._LINEAR_ATTN_KEYS:
            prefix = f'{self.attn_layer_prefix}.{i}.linear_attn.{key}'
            tensor = self.params.get(f'{prefix}.{kind}')
            # A_log and dt_bias are bare nn.Parameter (no .weight suffix)
            if tensor is None:
                tensor = self.params.get(prefix)
            # If requesting weight but only AWQ qweight exists,
            # dequantize on the fly so LinearAttn gets fp16 tensors.
            if tensor is None and kind == 'weight':
                if f'{prefix}.qweight' in self.params:
                    tensor = self._awq_dequant(prefix)
            if tensor is not None:
                tensor = self.transform(tensor, kind)
            result.append(tensor)  # keep None to preserve alignment
        if all(t is None for t in result):
            return tuple()
        return tuple(result)

    def linear_norm(self, i: int, kind: str = 'weight'):
        """Get the gated-norm tensor of layer ``i``'s linear attention.

        The weight is returned without the +1 shift (see the class docstring),
        or ``None`` when absent.
        """
        tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.linear_attn.norm.{kind}')
        if tensor is not None:
            return self.transform(tensor, kind)
        return None


class Qwen3_5Reader(Qwen3_5ReaderMixin, Qwen3Reader):
    """Dense Qwen3.5 reader: Qwen3 q/k-norm reading plus the Qwen3.5 mixin."""


@INPUT_MODELS.register_module(name='qwen3_5')
class Qwen3_5Model(Qwen3Model):
    """Qwen3.5 model in hf format (hybrid linear / full attention)."""
    Reader = Qwen3_5Reader

    def model_info(self):
        """Read model info, including MoE and linear-attention parameters.

        When the config nests the language model under ``text_config``,
        ``self.model_config`` is replaced by that nested dict first.
        """
        if 'text_config' in self.model_config:
            self.model_config = self.model_config['text_config']
        cfg = self.model_config
        info = super().model_info()
        # MoE parameters (same as Qwen2MoeModel / Qwen3MoeModel)
        info['expert_num'] = cfg.get('num_experts', 0)
        info['expert_inter_size'] = cfg.get('moe_intermediate_size', 0)
        info['experts_per_token'] = cfg.get('num_experts_per_tok', 0)
        # For MoE models, inter_size is the shared expert intermediate size;
        # for dense models, keep the value from super() (intermediate_size).
        shared_expert_size = cfg.get('shared_expert_intermediate_size')
        if shared_expert_size is not None:
            info['inter_size'] = shared_expert_size
        info['moe_shared_gate'] = True
        # Routed experts are scored with softmax; the shared expert has its own
        # gate (``moe_shared_gate`` above).
        info['scoring_func'] = 'softmax'
        info['norm_topk_prob'] = True
        # Fix RoPE dim for partial_rotary_factor. ``rope_parameters`` may be
        # missing or explicitly null in config.json.
        rope_params = cfg.get('rope_parameters') or {}
        partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0))
        if partial_rotary_factor < 1.0:
            info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor)
        # Linear attention parameters
        info['layer_types'] = cfg.get('layer_types', [])
        info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0)
        info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0)
        info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0)
        info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0)
        info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0)
        # attn_output_gate doubles Q projection for full-attention layers
        info['attn_output_gate'] = cfg.get('attn_output_gate', False)
        return info


class Qwen3_5MoeReader(Qwen3_5ReaderMixin, Qwen3MoeReader):
    """Qwen3.5 MoE reader handling both packed and per-expert expert tensors."""

    def _unpacked_moe_expert(self, e: int, i: int, kind: str):
        """Slice expert ``e`` of layer ``i`` out of packed expert tensors.

        Returns ``(gate, down, up)``, or ``None`` when the layer does not hold
        both packed ``gate_up_proj`` and ``down_proj`` tensors for ``kind``.
        """
        experts = f'{self.attn_layer_prefix}.{i}.mlp.experts'
        packed_gate_up = self.params.get(f'{experts}.gate_up_proj.{kind}')
        packed_down = self.params.get(f'{experts}.down_proj.{kind}')
        if packed_gate_up is None or packed_down is None:
            return None

        # Packed checkpoints stack every expert along dim 0; index one expert
        # before transform so quantized policies still see a 2D tensor.
        expert_gate_up = self.transform(packed_gate_up[e], kind)
        expert_down = self.transform(packed_down[e], kind)
        gate, up = expert_gate_up.chunk(2, dim=0)
        return (gate, expert_down, up)

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Prefer packed expert tensors; fall back to per-expert names."""
        if not kind:
            return self.filter(r'experts', i)
        unpacked = self._unpacked_moe_expert(e, i, kind)
        if unpacked is None:
            return super().moe_ffn_expert(e, i, kind)
        return unpacked


@INPUT_MODELS.register_module(name='qwen3_5-moe')
class Qwen3_5MoeModel(Qwen3MoeModel):
    """Qwen3.5 MoE model in hf format (hybrid linear / full attention)."""
    Reader = Qwen3_5MoeReader

    @staticmethod
    def map_packed_qwen35_experts(name: str):
        """Map packed expert names to weight names, i.e.,
        "mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj.weight" so that
        class Weight in parameter.py can classify them."""
        s = re.sub(r'(mlp\.experts\.(?:gate_up|down)_proj)$', r'\1.weight', name)
        return s

    def readers(self):
        """Yield ``(i, reader)`` for each item produced by the checkpoint loader.

        Packed expert tensor names are remapped only when both packed
        ``gate_up_proj`` and ``down_proj`` names appear in the loader index.
        """
        pattern = getattr(self.Reader, 'attn_layer_pattern', self.Reader.attn_layer_patten)
        loader = create_loader(self.model_path, pattern, [])

        has_packed_gate_up = any('mlp.experts.gate_up_proj' in k for k in loader.index.keys())
        has_packed_down = any('mlp.experts.down_proj' in k for k in loader.index.keys())
        if has_packed_gate_up and has_packed_down:
            loader.mappings = [self.map_packed_qwen35_experts]

        for i, param in loader.items():
            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant)
            yield i, reader
        torch.cuda.empty_cache()

    def model_info(self):
        """Read model info, including shared-expert and linear-attention parameters.

        When the config nests the language model under ``text_config``,
        ``self.model_config`` is replaced by that nested dict first.
        """
        if 'text_config' in self.model_config:
            self.model_config = self.model_config['text_config']
        cfg = self.model_config
        info = super().model_info()
        # Shared expert params (missing from Qwen3MoeModel base)
        info['inter_size'] = cfg.get('shared_expert_intermediate_size', 0)
        info['moe_shared_gate'] = True
        # Routed experts are scored with softmax; the shared expert has its own
        # gate (``moe_shared_gate`` above).
        info['scoring_func'] = 'softmax'
        info['norm_topk_prob'] = True
        # Fix RoPE dim for partial_rotary_factor. ``rope_parameters`` may be
        # missing or explicitly null in config.json.
        rope_params = cfg.get('rope_parameters') or {}
        partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0))
        if partial_rotary_factor < 1.0:
            info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor)
        # Linear attention parameters
        info['layer_types'] = cfg.get('layer_types', [])
        info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0)
        info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0)
        info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0)
        info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0)
        info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0)
        # attn_output_gate doubles Q projection for full-attention layers
        info['attn_output_gate'] = cfg.get('attn_output_gate', False)
        return info


================================================
FILE: lmdeploy/turbomind/deploy/source_model/xcomposer2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .base import INPUT_MODELS
from .internlm2 import InternLM2Model, InternLM2Reader


class Xcomposer2Reader(InternLM2Reader):
    """Xcomposer2 model reader: InternLM2 weights plus Plora adapters."""

    # Parameter-name filters that keep the base weights and their Plora_A /
    # Plora_B adapters while ignoring other lora weights.
    attn_pattern = r'attention.\w+(.Plora_[AB])?.\w+$'
    ffn_pattern = r'feed_forward.\w+(.Plora_[AB])?.\w+$'

    def _attn(self, i, kind):
        """Get attention tensors of layer ``i``.

        For a ``kind`` containing ``Plora_A`` only the ``(wqkv, wo)`` Plora_A
        weights are returned; every other kind goes to ``InternLM2Reader._attn``.
        """
        if 'Plora_A' not in kind:
            return super()._attn(i, kind)
        layer = f'model.layers.{i}.attention'
        qkv = self.params[f'{layer}.wqkv.Plora_A.weight']
        o = self.params[f'{layer}.wo.Plora_A.weight']
        return qkv, o


@INPUT_MODELS.register_module(name='xcomposer2')
class Xcomposer2Model(InternLM2Model):
    """Xcomposer2 model in hf format (InternLM2 with Plora lora config)."""

    Reader = Xcomposer2Reader

    def _lora_cfg_7b(self):
        """Lora config for internlm-xcomposer2-7b."""
        return {'lora_r': 256, 'lora_scale': 1.0, 'lora_policy': 'plora', 'lora_max_wo_r': 256}

    def _lora_cfg_4khd_7b(self, model_info: dict):
        """Lora config for internlm-xcomposer2-4khd-7b.

        ``model_info`` is accepted for interface compatibility but unused.
        """
        ranks = ','.join(['attention.w_qkv:8', 'attention.wo:256'])
        scales = ','.join(['attention.w_qkv:2.0', 'attention.wo:1.0'])
        return {
            'lora_r': 256,
            'lora_scale': 1.0,
            'lora_max_wo_r': 256,
            'lora_policy': 'plora',
            'lora_rank_pattern': ranks,
            'lora_scale_pattern': scales,
        }

    def model_info(self):
        """Return the InternLM2 model info merged with the matching lora config."""
        info = super().model_info()
        from lmdeploy.vl.model.xcomposer2 import ModelType, get_xcomposer_type
        model_type, _ = get_xcomposer_type(self.model_path)
        is_4khd = model_type == ModelType.XCOMPOSER2_4KHD
        info.update(self._lora_cfg_4khd_7b(info) if is_4khd else self._lora_cfg_7b())
        return info


================================================
FILE: lmdeploy/turbomind/deploy/target_model/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .fp import TurbomindModel  # noqa: F401


================================================
FILE: lmdeploy/turbomind/deploy/target_model/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import os.path as osp
from abc import ABC
from collections.abc import Sequence

import torch
import tqdm
import yaml
from mmengine import Registry

from ..config import AttentionConfig, LoraConfig, ModelConfig, TurbomindModelConfig, config_from_dict, config_to_dict
from ..source_model.base import BaseInputModel

# Registry of output (target) model classes; e.g. `TurbomindModel` registers itself here as 'tm'.
OUTPUT_MODELS = Registry('target model', locations=['lmdeploy.turbomind.deploy.target_model.base'])


def tprint(*args, **kwargs):
    """Print via ``tqdm.tqdm.write`` so messages do not break progress bars.

    Nothing is printed unless the keyword ``to_file`` is truthy (it defaults
    to ``False``). All other keyword arguments are forwarded to ``print``.
    """
    if kwargs.pop('to_file', False):
        from io import StringIO
        buf = StringIO()
        print(*args, file=buf, end='', **kwargs)
        tqdm.tqdm.write(buf.getvalue())


def _weight_dtype_map(weight_type: str, default=None):
    """Map literal data type to torch dtype."""

    _WEIGHT_DTYPE_MAP = dict(int4=torch.float16, float16=torch.float16, float32=torch.float16, bfloat16=torch.bfloat16)

    return _WEIGHT_DTYPE_MAP.get(weight_type, default)


def _pad_inter_size(inter_size: int, group_size: int, tp: int):
    group_size = max(1, group_size)
    group_num = (inter_size + group_size - 1) // group_size
    groups_per_rank = (group_num + tp - 1) // tp
    inter_size_padded = groups_per_rank * group_size * tp
    return inter_size_padded


class BaseOutputModel(ABC):
    """Base output model.

    Converts an input model into TurboMind weights. When ``out_dir`` is given,
    the config and every weight are written to that directory. Otherwise,
    weights are copied into the TurboMind tensors registered in
    ``self.tm_params`` (name -> list of tensors).
    """

    def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model_cls, out_dir: str = ''):
        """Build the exporter.

        Args:
            input_model: the source model. Its ``model_info()`` is merged into
                the model, attention and lora configs.
            cfg: the initial turbomind model config.
            model_cls: invoked as ``model_cls(self)`` once ``self`` is fully
                initialised. The result is later called as
                ``self.model(i, reader)`` for each reader of ``input_model``.
            out_dir: if non-empty, export to this directory instead of
                copying into ``tm_params``.
        """
        super().__init__()
        self.input_model = input_model
        self.model_config = cfg.model_config
        self.attention_config = cfg.attention_config
        self.lora_config = cfg.lora_config
        self.attn_tp_size = self.model_config.attn_tp_size
        self.attn_cp_size = self.model_config.attn_cp_size
        self.mlp_tp_size = self.model_config.mlp_tp_size
        self.out_dir = out_dir
        self.to_file = bool(out_dir)
        # name -> list of turbomind tensors; presumably populated by the caller before export (not set here)
        self.tm_params = dict()

        # get `model_info` first: it is merged into `self.model_config`,
        # `self.attention_config` and `self.lora_config` below
        self.input_model_info = self.input_model.model_info()
        self.input_model_info = self.single_to_list(self.input_model_info, keys=['inter_size', 'expert_num'])
        self.permute_qk = self.input_model_info.get('permute_qk', True)
        self.update_model_config()
        # pad FFN sizes so each MLP TP rank holds a whole number of `group_size` groups
        for i, v in enumerate(self.model_config.inter_size):
            self.model_config.inter_size[i] = _pad_inter_size(v, self.model_config.group_size, self.mlp_tp_size)
        if self.model_config.expert_num:
            self.model_config.expert_inter_size = _pad_inter_size(self.model_config.expert_inter_size,
                                                                  self.model_config.group_size, self.mlp_tp_size)

        # head_num must be divisible by tp, but kv_head_num need not be.
        # When tp is a multiple of kv_head_num, kv_head_num is raised to tp and
        # `repeat_kv` records the factor (presumably consumed when splitting kv weights).
        assert self.model_config.head_num % self.attn_tp_size == 0
        self.repeat_kv = 0
        if (self.attn_tp_size > self.model_config.kv_head_num
                and self.attn_tp_size % self.model_config.kv_head_num == 0):
            self.repeat_kv = (self.attn_tp_size // self.model_config.kv_head_num)
            self.model_config.kv_head_num = self.attn_tp_size

        self.model_config.verify()
        assert self.model_config.kv_head_num % self.attn_tp_size == 0

        self.update_attention_config()
        self.update_lora_config()
        # ! `model_cls` depends on the fully-initialised `self`, so it is built last
        self.model = model_cls(self)

    def single_to_list(self, config: dict, keys):
        """Broadcast scalar entries of ``config`` to per-layer lists.

        Each key in ``keys`` whose value is present, not None and not already
        a ``Sequence`` is replaced in place by ``config['num_layer']`` copies.
        Returns the same ``config`` object.
        """
        num_layer = int(config['num_layer'])
        for k in keys:
            v = config.get(k, None)
            if v is not None and not isinstance(v, Sequence):
                config[k] = [v] * num_layer
        return config

    def update_model_config(self):
        """Update `self.model_config` according to the input_model's
        `model_info`.

        ``embedding_size`` defaults to ``vocab_size`` when the input model
        does not report it.
        """
        final_cfg = config_to_dict(self.model_config)
        final_cfg.update(self.input_model_info)
        if 'embedding_size' not in self.input_model_info.keys():
            final_cfg.update(embedding_size=self.input_model_info['vocab_size'])

        self.model_config = config_from_dict(ModelConfig, final_cfg)

    def update_attention_config(self):
        """Update attention config according to input model's model info."""
        final_cfg = config_to_dict(self.attention_config)
        final_cfg.update(self.input_model_info)
        self.attention_config = config_from_dict(AttentionConfig, final_cfg)

    def update_lora_config(self):
        """Update lora config according to input model's model info."""
        final_cfg = config_to_dict(self.lora_config)
        final_cfg.update(self.input_model_info)
        self.lora_config = config_from_dict(LoraConfig, final_cfg)

    def export_config(self) -> None:
        """Write ``config.yaml`` to ``out_dir`` (only when exporting to file)."""
        if self.to_file:
            config_path = osp.join(self.out_dir, 'config.yaml')
            with open(config_path, 'w') as f:
                yaml.safe_dump(self.tm_config.to_dict(), f)

    def export_weight(self, param: torch.Tensor, name: str) -> None:
        """Export a single weight tensor under ``name``.

        - ``to_file``: floating-point tensors are cast to
          ``_weight_dtype_map(weight_type, torch.float16)`` and dumped as raw
          bytes to ``out_dir/name``.
        - otherwise, if ``tm_params`` is non-empty: the tensor is moved to
          CUDA and cast according to ``weight_type``/``data_type``. It is then
          copied into every turbomind tensor registered under ``name``, and
          that entry is removed from ``tm_params``.
        - otherwise the weight is skipped.
        """

        def _tofile(tensor, path):
            """Dump ``tensor`` as raw bytes to ``path``."""
            # numpy has no bfloat16: reinterpret the bits as half (same itemsize)
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.view(torch.half)
            tensor.contiguous().cpu().numpy().tofile(path)

        if self.to_file:
            if torch.is_floating_point(param):
                torch_type = _weight_dtype_map(self.model_config.weight_type, torch.float16)
                param = param.to(torch_type)
            # NOTE(review): no `to_file=` kwarg is passed, so this tprint is a no-op; confirm intent
            tprint(name, param.shape)
            _tofile(param, osp.join(self.out_dir, name))
        elif self.tm_params:
            tm_params = self.tm_params
            weight_type = self.model_config.weight_type
            data_type = self.model_config.data_type
            assert weight_type in ['float16', 'bfloat16', 'int4', 'fp8']

            # currently, the tensor type should be in
            # [torch.float, torch.half, torch.bfloat16, torch.int32, torch.uint8]
            torch_tensor = param if param.is_contiguous() else param.contiguous()
            torch_tensor = torch_tensor.cuda()
            assert torch_tensor.dtype in [torch.int32, torch.float, torch.half, torch.bfloat16, torch.uint8]
            FLOAT_TYPES = [torch.float, torch.half, torch.bfloat16]
            if weight_type == 'fp8':
                # avoid casting float scales to half; only bf16 -> half when data_type is float16
                if torch_tensor.dtype == torch.bfloat16 and data_type == 'float16':
                    torch_tensor = torch_tensor.half()
            elif torch_tensor.dtype in FLOAT_TYPES:
                if weight_type in ['float16', 'int4']:
                    torch_tensor = torch_tensor.half()
                elif weight_type == 'bfloat16':
                    torch_tensor = torch_tensor.bfloat16()
                else:
                    # unreachable while the weight_type assert above is active
                    torch_tensor = torch_tensor.half()
            if name in tm_params:
                try:
                    import _turbomind as _tm
                except ImportError:
                    _tm = None
                for tm_tensor in tm_params[name]:
                    # Match TurboMind tensor dtype to avoid byte_size mismatch (e.g. f32 256b vs f16 128b)
                    if _tm is not None:
                        if tm_tensor.type == _tm.DataType.TYPE_FP32 and torch_tensor.dtype in [
                                torch.float16, torch.bfloat16
                        ]:
                            torch_tensor = torch_tensor.float()
                        elif tm_tensor.type == _tm.DataType.TYPE_FP16 and torch_tensor.dtype == torch.float32:
                            torch_tensor = torch_tensor.half()
                    tm_tensor.copy_from(torch_tensor)
                tm_params.pop(name)
        else:
            tprint('skip export', name, param.shape)

    def save_split(self, tensor: torch.Tensor, name: str, split_dim=None, split_num=1, copy=False) -> None:
        """Save split.

        - 2D input
            shape must be (input_dims, output_dims)
        - 1D input (bias)
            shape must be (output_dims)
            split is skipped when split_dim == 0

        When splitting or copying, the parts are exported as
        ``{prefix}.{i}{ext}`` for ``i`` in ``range(split_num)``, where
        ``prefix, ext = os.path.splitext(name)``.

        Raises:
            RuntimeError: if ``tensor.shape[split_dim]`` is not divisible by
                ``split_num``.
        """

        if copy or (tensor.dim() == 1 and split_dim == 0):
            split_dim = None
            copy = True

        if split_dim is not None:
            tprint(f'*** splitting {name}, shape={tensor.shape}, '
                   f'split_dim={split_dim}, split_num={split_num}',
                   to_file=self.to_file)
            if tensor.shape[split_dim] % split_num != 0:
                raise RuntimeError(f'{name}: shape={list(tensor.shape)}, split_num={split_num}')
            split_size = tensor.shape[split_dim] // split_num
            prefix, ext = osp.splitext(name)
            for i, part in enumerate(torch.split(tensor, split_size, dim=split_dim)):
                self.export_weight(part, f'{prefix}.{i}{ext}')
        elif copy:
            tprint(f'### copying {name}, shape={tensor.shape}', to_file=self.to_file)
            prefix, ext = osp.splitext(name)
            # every rank receives the same (unsplit) tensor
            for i in range(split_num):
                self.export_weight(tensor, f'{prefix}.{i}{ext}')
        else:
            self.export_weight(tensor, name)

    def export(self) -> None:
        """Export to turbomind model format.

        The progress bar advances whenever ``self.model(i, reader)`` returns
        a truthy value. It is closed even if conversion raises.
        """
        num_layer = self.model_config.num_layer
        from tqdm import tqdm
        with tqdm(total=num_layer, desc='Convert to turbomind format', leave=self.to_file) as pbar:
            self.export_config()
            for i, reader in self.input_model.readers():
                if self.model(i, reader):
                    pbar.update(1)

    def export_iter(self):
        """Like :meth:`export`, but yield each reader index after it is processed."""
        self.export_config()
        for i, reader in self.input_model.readers():
            self.model(i, reader)
            yield i

    @property
    def tm_config(self):
        """The current model/attention/lora configs bundled as a TurbomindModelConfig."""
        return TurbomindModelConfig(model_config=self.model_config,
                                    attention_config=self.attention_config,
                                    lora_config=self.lora_config)


================================================
FILE: lmdeploy/turbomind/deploy/target_model/fp.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from .base import OUTPUT_MODELS, BaseOutputModel


@OUTPUT_MODELS.register_module(name='tm')
class TurbomindModel(BaseOutputModel):
    """TurboMind output model, registered under the name ``'tm'``.

    It adds nothing of its own; all export logic is inherited from
    ``BaseOutputModel``.
    """


================================================
FILE: lmdeploy/turbomind/supported_models.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.archs import get_model_arch, search_nested_config
from lmdeploy.utils import get_logger

# shared 'lmdeploy' logger
logger = get_logger('lmdeploy')

# HF architecture name (config.architectures[0]) -> name of the turbomind
# input model that handles it (e.g. 'xcomposer2' for XComposer2).
SUPPORTED_ARCHS = {
    # baichuan-7b
    'BaiChuanForCausalLM': 'baichuan',
    # baichuan2-7b, baichuan-13b, baichuan2-13b
    'BaichuanForCausalLM': 'baichuan2',
    # gpt-oss
    'GptOssForCausalLM': 'gpt-oss',
    # internlm
    'InternLMForCausalLM': 'llama',
    # internlm2
    'InternLM2ForCausalLM': 'internlm2',
    # internlm3
    'InternLM3ForCausalLM': 'llama',
    # llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
    # deepseek-coder, deepseek-llm
    'LlamaForCausalLM': 'llama',
    # Qwen 7B-72B, Qwen-VL-7B
    'QWenLMHeadModel': 'qwen',
    # Qwen2
    'Qwen2ForCausalLM': 'qwen2',
    'Qwen2MoeForCausalLM': 'qwen2-moe',
    # Qwen2-VL
    'Qwen2VLForConditionalGeneration': 'qwen2',
    # Qwen2.5-VL
    'Qwen2_5_VLForConditionalGeneration': 'qwen2',
    # Qwen3
    'Qwen3ForCausalLM': 'qwen3',
    'Qwen3MoeForCausalLM': 'qwen3-moe',
    # Qwen 3.5
    'Qwen3_5ForConditionalGeneration': 'qwen3_5',
    'Qwen3_5MoeForConditionalGeneration': 'qwen3_5-moe',
    # mistral
    'MistralForCausalLM': 'llama',
    # llava
    'LlavaLlamaForCausalLM': 'llama',
    'LlavaMistralForCausalLM': 'llama',
    'LlavaForConditionalGeneration': 'llava',
    # xcomposer2
    'InternLMXComposer2ForCausalLM': 'xcomposer2',
    # internvl
    'InternVLChatModel': 'internvl',
    # internvl3
    'InternVLForConditionalGeneration': 'internvl',
    'InternS1ForConditionalGeneration': 'internvl',
    # deepseek-vl
    'MultiModalityCausalLM': 'deepseekvl',
    'DeepseekV2ForCausalLM': 'deepseek2',
    # MiniCPMV
    'MiniCPMV': 'minicpmv',
    # chatglm2/3, glm4
    'ChatGLMModel': 'glm4',
    'ChatGLMForConditionalGeneration': 'glm4',
    # glm4-moe-lite (e.g. GLM-4.7-Flash)
    'Glm4MoeLiteForCausalLM': 'glm4-moe-lite',
    # mixtral
    'MixtralForCausalLM': 'mixtral',
    'MolmoForCausalLM': 'molmo',
}


def is_supported(model_path: str):
    """Check whether supported by turbomind engine.

    Args:
        model_path (str): the path of a model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
    Returns:
        support_by_turbomind (bool): Whether input model is supported by turbomind engine
    """  # noqa: E501
    import os

    def _is_head_dim_supported(cfg):
        """Return True if the attention head dim of ``cfg`` is 64 or 128."""
        # Some configs define `head_dim` but leave it None; derive it from
        # hidden_size / num_attention_heads in that case too, as for a missing attribute.
        head_dim = getattr(cfg, 'head_dim', None)
        if head_dim is None:
            head_dim = cfg.hidden_size // cfg.num_attention_heads
        return head_dim in [128, 64]

    # a directory that contains `triton_models` is already in turbomind format
    if os.path.exists(os.path.join(model_path, 'triton_models')):
        return True

    arch, cfg = get_model_arch(model_path)
    quant_method = search_nested_config(cfg.to_dict(), 'quant_method')
    if quant_method in ['smooth_quant']:
        # tm hasn't supported quantized models by applying smoothquant
        return False

    if arch not in SUPPORTED_ARCHS:
        return False

    support_by_turbomind = True
    # special cases
    if arch == 'BaichuanForCausalLM':
        # baichuan-13B, baichuan2-13B (40 attention heads) not supported by turbomind
        support_by_turbomind = cfg.num_attention_heads != 40
    elif arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
        support_by_turbomind = _is_head_dim_supported(cfg)
    elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'):
        # chatglm1/2/3 is not working yet
        support_by_turbomind = cfg.num_layers == 40
        if getattr(cfg, 'vision_config', None) is not None:
            # glm-4v-9b not supported
            support_by_turbomind = False
    elif arch == 'InternVLChatModel':
        llm_arch = cfg.llm_config.architectures[0]
        support_by_turbomind = (llm_arch in SUPPORTED_ARCHS and _is_head_dim_supported(cfg.llm_config))
    elif arch in ['LlavaForConditionalGeneration', 'InternVLForConditionalGeneration']:
        llm_arch = cfg.text_config.architectures[0]
        if llm_arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
            support_by_turbomind = _is_head_dim_supported(cfg.text_config)
    elif arch == 'MolmoForCausalLM':
        # TM hasn't supported allenai/Molmo-7B-O-0924 yet
        support_by_turbomind = cfg.num_key_value_heads is not None
    elif arch in ('DeepseekV2ForCausalLM', 'Glm4MoeLiteForCausalLM'):
        # variants that carry a vision tower are not supported
        if getattr(cfg, 'vision_config', None) is not None:
            support_by_turbomind = False

    return support_by_turbomind


================================================
FILE: lmdeploy/turbomind/tokenizer_info.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Borrowed from xgrammar's TokenizerInfo
"""This module provides the tokenizer info class to handle the tokenizer
information."""

import json
import logging
from enum import Enum
from typing import List, Optional, Union

import _xgrammar as _xgr  # noqa: E402

try:
    import sentencepiece
except ImportError:
    sentencepiece = None
try:
    import tiktoken
except ImportError:
    tiktoken = None

from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast

# module logger (stdlib logging, named after this module)
logger = logging.getLogger(__name__)


class VocabType(Enum):
    """How the tokens of a vocabulary are encoded, as consumed by TokenizerInfo.

    XGrammar distinguishes three vocabulary types: RAW, BYTE_FALLBACK and
    BYTE_LEVEL.
    """

    # Raw format: tokens are kept in their original form, without any processing.
    # Used by tiktoken tokenizers, e.g. microsoft/Phi-3-small-8k-instruct, Qwen/Qwen-7B-Chat.
    RAW = 0

    # Byte-fallback BPE: tokens are encoded through the byte-fallback conversion,
    # e.g. "\u001b" -> "<0x1B>", " apple" -> "▁apple".
    # Used by e.g. meta-llama/Llama-2-7b-chat, microsoft/Phi-3.5-mini-instruct.
    BYTE_FALLBACK = 1

    # Byte-level BPE: tokens are encoded through the byte-to-unicode conversion, as in
    # https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
    # Used by e.g. meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3.1-8B-Instruct.
    BYTE_LEVEL = 2


class TokenizerInfo(_xgr.TokenizerInfo):
    """The tokenizer info contains the vocabulary, the type of the vocabulary,
    and necessary information for the grammar-guided generation.

    Note that although some tokenizers will encode the tokens in a special format, e.g. "<0x1B>" for "\u001b" in the
    ByteFallback tokenizer, and "Ġ" for " " in the Byte-Level BPE tokenizer, TokenizerInfo always decodes the vocabulary
    to the original format (e.g. "\u001b" and " ").

    Also note that some models (e.g. Phi-3 and Deepseek-V2) may pad the vocabulary to a multiple of 32. In this case,
    the model's vocab_size is larger than the tokenizer's vocabulary size. Please pass the model's vocab_size to the
    vocab_size parameter in the constructor, because this information is used to determine the size of the token mask.
    """

    def __init__(
        self,
        encoded_vocab: Union[List[bytes], List[str]],
        vocab_type: VocabType = VocabType.RAW,
        *,
        vocab_size: Optional[int] = None,
        stop_token_ids: Optional[Union[List[int], int]] = None,
        add_prefix_space: bool = False,
    ) -> None:
        """Construct the tokenizer info.

        Parameters
        ----------
        encoded_vocab : Union[List[bytes], List[str]]
            The encoded vocabulary of the tokenizer.

        vocab_type : VocabType, default: VocabType.RAW
            The type of the vocabulary. See also VocabType.

        vocab_size : Optional[int], default: None
            The size of the vocabulary. If not provided, the vocabulary size will be len(encoded_vocab).

        stop_token_ids : Optional[Union[List[int], int]], default: None
            The stop token ids. If not provided, the stop token ids will be auto detected (but may not
            be correct). A single int is wrapped into a one-element list.

        add_prefix_space : bool, default: False
            Whether the tokenizer will prepend a space before the text in the tokenization process.
        """
        if isinstance(stop_token_ids, int):
            stop_token_ids = [stop_token_ids]

        super().__init__(encoded_vocab, vocab_type.value, vocab_size, stop_token_ids, add_prefix_space)

    @staticmethod
    def _is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
        """Return True if ``tokenizer`` wraps a tiktoken encoding or uses a tiktoken vocab file."""
        if tiktoken is None:
            return False

        # helper to check if tokenizer is a tiktoken tokenizer
        has_tiktoken_encoding = hasattr(tokenizer, 'tokenizer') and isinstance(tokenizer.tokenizer, tiktoken.Encoding)

        filename_pattern = (hasattr(tokenizer, 'vocab_files_names') and 'vocab_file' in tokenizer.vocab_files_names
                            and 'tiktoken' in tokenizer.vocab_files_names['vocab_file'])

        return has_tiktoken_encoding or filename_pattern

    @staticmethod
    def _is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
        """Return True if ``tokenizer`` (or ``tokenizer.tokenizer``) holds a SentencePieceProcessor."""
        if sentencepiece is None:
            return False

        # helper to check if tokenizer is a sentence piece tokenizer
        has_sp_model_attr = hasattr(tokenizer, 'sp_model') and isinstance(tokenizer.sp_model,
                                                                          sentencepiece.SentencePieceProcessor)

        has_nested_sp_model_attr = (hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model')
                                    and isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor))

        return has_sp_model_attr or has_nested_sp_model_attr

    @staticmethod
    def _default_stop_token_ids(tokenizer: PreTrainedTokenizerBase, sp_model=None) -> Optional[List[int]]:
        """Pick stop token ids when the caller did not provide any.

        Uses ``tokenizer.eos_token_id`` when it is set. Otherwise, if
        ``sp_model`` is given and ``sp_model.eos_id()`` is not -1, that id is
        used. If neither applies, a warning is logged and None is returned,
        leaving detection to the constructor.
        """
        eos_token_id = getattr(tokenizer, 'eos_token_id', None)
        if eos_token_id is not None:
            return [eos_token_id]
        if sp_model is not None:
            eos_id = sp_model.eos_id()
            if eos_id != -1:
                return [eos_id]
        logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '
                       'stop_token_ids is neither provided by user nor found from the tokenizer. '
                       'It will be automatically detected.')
        return None

    @staticmethod
    def from_huggingface(
        tokenizer: PreTrainedTokenizerBase,
        *,
        vocab_size: Optional[int] = None,
        stop_token_ids: Optional[Union[List[int], int]] = None,
    ) -> 'TokenizerInfo':
        """Construct the tokenizer info from the huggingface tokenizer. This
        constructor supports various tokenizer backends, including the
        huggingface fast tokenizer and tiktoken tokenizer. Necessary
        information is automatically detected from the tokenizer.

        The vocab_size parameter is introduced to handle the misalignment between the model's
        vocab_size and the tokenizer's vocabulary size. User should pass the model's vocab_size
        (could be defined in the model config) here. See docs of vocab_size for more details.

        The stop token ids is by default the eos_token_id of the tokenizer. If there are other
        stop tokens, you can specify them manually.

        Parameters
        ----------
        tokenizer : PreTrainedTokenizerBase
            The huggingface tokenizer.

        vocab_size : Optional[int], default: None
            The vocabulary size **defined by the model** (**not the tokenizer**). This equals to the
            vocab dimension of the model's lm_head. This is the size of the token mask.

            It can be:

            1. the same as the tokenizer's vocabulary size. This is the most common case.
            2. larger than the tokenizer's vocabulary size. This happens when the model has padding
               to lm_head, possibly due to aligning lm_head to the power of 2.
               E.g. Phi-3 and Deepseek-V2.
            3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer has
               some added tokens that are not supported by the model. E.g.
               Llama-3.2 Vision and Molmo-72B-0924 have padded `<|image|>` tokens, but they will not
               be considered in lm_head or generated by the model.

            vocab_size needs to be provided for cases 2 and 3. If not provided, it will be
            set to the tokenizer's vocabulary size.

        stop_token_ids : Optional[Union[List[int], int]], default: None
            The stop token ids. If not provided, the eos_token_id of the tokenizer will be used.

        Returns
        -------
        tokenizer_info : TokenizerInfo
            The tokenizer info.

        Raises
        ------
        ValueError
            If ``stop_token_ids`` is an empty list, the tokenizer has no usable
            ``get_vocab``, its vocabulary is empty, or the tokenizer type is unsupported.
        """
        if isinstance(stop_token_ids, int):
            stop_token_ids = [stop_token_ids]
        if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0:
            raise ValueError('stop_token_ids cannot be empty')

        try:
            vocab_dict = tokenizer.get_vocab()
        except AttributeError as e:
            msg = (f'Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer '
                   'should have a get_vocab method.')
            raise ValueError(msg) from e
        if not vocab_dict:
            raise ValueError(f'The vocabulary of the tokenizer {type(tokenizer)} is empty.')

        # Some tokenizers don't have token id 0, 1 or 2, so the max id can be larger than the
        # number of tokens.
        max_id = max(vocab_dict.values())
        tokenizer_vocab_size = max(len(vocab_dict), max_id + 1)

        vocab_size = vocab_size or tokenizer_vocab_size

        # maintain tokenizer's indexing; ids >= vocab_size are dropped, unused slots stay ''
        encoded_vocab = [''] * vocab_size
        for token, idx in vocab_dict.items():
            if idx < vocab_size:
                encoded_vocab[idx] = token

        if isinstance(tokenizer, PreTrainedTokenizerFast):
            # huggingface fast tokenizer
            # - the vocabulary is directly obtained from tokenizer.get_vocab()
            #   (tokenizer.backend_tokenizer.to_str() may not contain the full vocab, special
            #   tokens may be omitted)
            # - the vocab size is obtained from len(tokenizer.get_vocab()) or provided by user
            # - the vocab type and add_prefix_space are obtained from
            #   tokenizer.backend_tokenizer.to_str()
            # - stop token id is provided by user, or auto detected.
            backend_str = tokenizer.backend_tokenizer.to_str()
            if stop_token_ids is None:
                stop_token_ids = TokenizerInfo._default_stop_token_ids(tokenizer)
            metadata = json.loads(TokenizerInfo._detect_metadata_from_hf(backend_str))
            return TokenizerInfo(
                encoded_vocab,
                vocab_type=VocabType(metadata['vocab_type']),
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=metadata['add_prefix_space'],
            )

        elif TokenizerInfo._is_tiktoken_tokenizer(tokenizer):
            # tiktoken tokenizer
            # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
            if stop_token_ids is None:
                stop_token_ids = TokenizerInfo._default_stop_token_ids(tokenizer)
            return TokenizerInfo(
                encoded_vocab,
                VocabType.RAW,
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=False,
            )

        elif TokenizerInfo._is_sentencepiece_tokenizer(tokenizer):
            # sentencepiece tokenizer
            # e.g. Chatglm3-6b
            # Use whichever of `tokenizer.sp_model` / `tokenizer.tokenizer.sp_model` is the actual
            # SentencePieceProcessor; that is what _is_sentencepiece_tokenizer matched.
            sp_model = getattr(tokenizer, 'sp_model', None)
            if not isinstance(sp_model, sentencepiece.SentencePieceProcessor):
                sp_model = tokenizer.tokenizer.sp_model

            if stop_token_ids is None:
                stop_token_ids = TokenizerInfo._default_stop_token_ids(tokenizer, sp_model)
            # detect vocab_type of tokenizer: byte-fallback vocabs contain "<0x0A>" (newline byte)
            if '<0x0A>' in vocab_dict:
                vocab_type = VocabType.BYTE_FALLBACK
            else:
                vocab_type = VocabType.RAW

            return TokenizerInfo(
                encoded_vocab,
                vocab_type=vocab_type,
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )

        else:
            # TODO(yixin): unsupported tokenizer
            raise ValueError(f'Unsupported tokenizer type: {type(tokenizer)}')


================================================
FILE: lmdeploy/turbomind/turbomind.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import asyncio
import copy
import json
import math
import os
import os.path as osp
import sys
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from functools import partial
from multiprocessing.reduction import ForkingPickler
from queue import Queue
from typing import Any, Dict, List, Optional

import pybase64
import torch
import yaml

import lmdeploy
from lmdeploy.messages import EngineOutput, GenerationConfig, ResponseType, ScheduleMetrics, TurbomindEngineConfig
from lmdeploy.serve.openai.protocol import UpdateParamsRequest
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_max_batch_size, get_model

from .deploy.config import TurbomindModelConfig
from .supported_models import is_supported

# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm  # noqa: E402
import _xgrammar as _xgr  # noqa: E402

from .tokenizer_info import TokenizerInfo  # noqa: E402

logger = get_logger('lmdeploy')

# NOTE(review): presumably the upper bound on the number of logprobs that can be requested;
# its consumers are outside this chunk — confirm.
MAX_LOGPROBS = 1024


def _construct_stop_or_bad_words(words: List[int] = None):
    if words is None or len(words) == 0:
        return None
    offsets = list(range(1, len(words) + 1))
    combined = [words, offsets]
    return combined


def _np_dict_to_tm_dict(np_dict: dict):
    """Build a turbomind ``TensorMap`` from a dict of arrays.

    Each value is wrapped with ``_tm.from_dlpack``. The values are presumably
    DLPack-capable arrays (e.g. numpy ndarrays) — TODO confirm with callers.
    """
    converted = ((key, _tm.from_dlpack(value)) for key, value in np_dict.items())
    tensor_map = _tm.TensorMap()
    for key, tm_tensor in converted:
        tensor_map[key] = tm_tensor
    return tensor_map


def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
    """Convert a turbomind ``TensorMap`` into a plain dict of torch tensors.

    uint32 tensors are reinterpreted as int32 (same bit width) before the
    DLPack conversion. Presumably this is because torch lacks a matching
    uint32 path — TODO confirm.
    """

    def _to_torch(tm_tensor):
        if tm_tensor.type == _tm.DataType.TYPE_UINT32:
            tm_tensor = tm_tensor.view(_tm.DataType.TYPE_INT32)
        return torch.from_dlpack(tm_tensor)

    return {key: _to_torch(tm_tensor) for key, tm_tensor in tm_dict.items()}


def complete_parallel_config(cfg: TurbomindEngineConfig):
    """Complete an explicitly specified parallel layout in place.

    If any of ``attn_dp_size``, ``attn_tp_size``, ``mlp_dp_size``,
    ``mlp_tp_size`` or ``outer_dp_size`` is set, the unset ones default to 1.
    The greatest common divisor of ``attn_dp_size`` and ``mlp_dp_size`` is
    then moved into ``outer_dp_size``, and True is returned. If none is set,
    ``cfg`` is left untouched and False is returned.
    """
    size_fields = ('attn_dp_size', 'attn_tp_size', 'mlp_dp_size', 'mlp_tp_size', 'outer_dp_size')
    if not any(getattr(cfg, field) for field in size_fields):
        return False
    for field in size_fields:
        setattr(cfg, field, getattr(cfg, field) or 1)
    shared = math.gcd(cfg.mlp_dp_size, cfg.attn_dp_size)
    cfg.outer_dp_size *= shared
    cfg.mlp_dp_size //= shared
    cfg.attn_dp_size //= shared
    return True


def update_parallel_config(cfg: TurbomindEngineConfig):
    """Resolve the parallel layout and the local device list of ``cfg`` in place.

    ``device_num`` becomes ``len(devices) * nnodes`` when devices are given.
    When no explicit layout was set (see ``complete_parallel_config``), the
    sizes are derived from ``dp``/``tp``/``cp``:
    ``overlap = dp * tp // device_num`` is used as the attention data-parallel
    size, and the MLP tensor-parallel size is ``overlap * (tp // overlap)``.
    Finally, ``devices`` is defaulted or truncated to
    ``device_num // nnodes`` entries.
    """
    if cfg.devices:
        cfg.device_num = len(cfg.devices) * cfg.nnodes

    if not complete_parallel_config(cfg):
        world = cfg.dp * cfg.tp
        if not cfg.device_num:
            # never claim more devices than the requested dp * tp needs
            cfg.device_num = min(world, torch.cuda.device_count() * cfg.nnodes)
        assert world % cfg.device_num == 0
        overlap = world // cfg.device_num
        inner_tp = cfg.tp // overlap
        cfg.outer_dp_size = cfg.dp // overlap
        cfg.attn_dp_size = overlap
        cfg.attn_tp_size = inner_tp // cfg.cp
        cfg.attn_cp_size = cfg.cp
        cfg.mlp_dp_size = 1
        cfg.mlp_tp_size = overlap * inner_tp

    attn_world = cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size
    assert attn_world == cfg.mlp_dp_size * cfg.mlp_tp_size
    assert attn_world * cfg.outer_dp_size == cfg.device_num

    # update devices: one entry per local device on this node
    per_node = cfg.device_num // cfg.nnodes
    candidates = cfg.devices or list(range(per_node))
    cfg.devices = candidates[:per_node]
    assert len(cfg.devices) == per_node


class TurboMind:
    """LMDeploy's inference engine.

    Args:
        model_path (str): path of a local model directory; when no such path
            exists it is treated as a model id and fetched with `get_model`
        model_name (str): the name of the served model
        chat_template_name (str): the name of the chat template, which is
            supposed to be a builtin chat template defined in
            `lmdeploy/model.py`
        engine_config (TurbomindEngineConfig): the config of the inference
            engine. It is deep-copied before use, so the caller's object is
            not modified; a default `TurbomindEngineConfig()` is used when None
        kwargs: accepted for interface compatibility; currently unused
    """

    def __init__(self,
                 model_path: str,
                 model_name: str = None,
                 chat_template_name: str = None,
                 engine_config: TurbomindEngineConfig = None,
                 **kwargs):
        self.model_name = model_name
        self.chat_template_name = chat_template_name

        # work on a private copy so the caller's config is never mutated
        _engine_config = copy.deepcopy(engine_config)
        if _engine_config is None:
            _engine_config = TurbomindEngineConfig()
        if _engine_config.max_batch_size is None:
            _engine_config.max_batch_size = get_max_batch_size('cuda')
        assert _engine_config.max_batch_size > 0, 'max_batch_size should be' \
            f' greater than 0, but got {_engine_config.max_batch_size}'

        update_parallel_config(_engine_config)
        if _engine_config.nnodes > 1:
            logger.info(f'dist_init_addr={_engine_config.dist_init_addr}')
            assert _engine_config.dist_init_addr is not None
            hostname, port = _engine_config.dist_init_addr.split(':')
            os.environ['LMDEPLOY_DIST_INIT_ADDR'] = hostname
            os.environ['LMDEPLOY_DIST_INIT_PORT'] = port
            # this will block the process and ignore signals until all ranks done
            from torch.distributed import TCPStore
            self.store = TCPStore(host_name=hostname,
                                  port=int(port),
                                  world_size=_engine_config.nnodes,
                                  is_master=_engine_config.node_rank == 0)

        self.gpu_count = len(_engine_config.devices)
        self.devices = _engine_config.devices
        self._engine_created = False

        if not osp.exists(model_path):
            model_path = get_model(model_path, _engine_config.download_dir, _engine_config.revision)
        self.model_comm = self._from_hf(model_path=model_path, engine_config=_engine_config)
        self.is_dummy = self.model_comm.is_dummy_node()
        self.tokenizer = Tokenizer(model_path)
        # with `empty_init`, weights are provided later through `update_params`
        if not _engine_config.empty_init:
            self._load_weights()
            self._process_weights()
            self._create_engine()

        self.session_len = self.config.session_len

    def _check_unloaded_tm_params(self):
        """Warn about converter params that are still left in `tm_params`."""
        tm_params = self._tm_model.tm_params
        if len(tm_params) > 0:
            uninitialized = list(tm_params.keys())
            logger.warning('the model may not be loaded successfully '
                           f'with {len(tm_params)} uninitialized params:\n{uninitialized}')

    def _load_weights(self):
        """Collect the engine's weight handles and export the converted model
        into them on the first device."""
        self._get_model_params()

        with torch.cuda.device(self.devices[0]):
            self._tm_model.export()

        self._check_unloaded_tm_params()

    def _process_weights(self):
        """Run the engine's weight post-processing on every local device in
        parallel."""
        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            for _ in e.map(self.model_comm.process_weight, range(self.gpu_count)):
                pass

    def _create_engine(self):
        """Create the engine on every local device in parallel."""
        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            for _ in e.map(self.model_comm.create_engine, range(self.gpu_count)):
                pass
        self._engine_created = True

    def _create_weight(self, model_comm):
        """Allocate the (empty) weight buffers on every local device in
        parallel; any per-device error is re-raised here."""
        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = [executor.submit(model_comm.create_weights, device_id) for device_id in range(self.gpu_count)]
            for future in futures:
                future.result()

    def _get_model_params(self):
        """Gather each device's weight tensors into `self._tm_model.tm_params`.

        `tm_params` maps a param name to the list of per-device tensors.
        """

        model_comm = self.model_comm
        tm_params = self._tm_model.tm_params
        tm_params.clear()

        def _get_params(device_id, que):
            out = model_comm.get_weights(device_id)
            que.put(out)

        que = Queue()
        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for device_id in range(self.gpu_count):
                futures.append(executor.submit(_get_params, device_id, que))
            for future in futures:
                future.result()

        for _ in range(self.gpu_count):
            tensor_map = que.get()
            for k, v in tensor_map.items():
                if k not in tm_params:
                    tm_params[k] = [v]
                else:
                    tm_params[k].append(v)
        logger.warning(f'get {len(tm_params)} model params')

    def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig):
        """Merge `engine_config` into a copy of `tm_config`.

        Sets `self.config`, `self.engine_config` and `self.config_dict` (the
        model config plus an `engine_config` entry), and logs the result.
        """
        self.config = copy.deepcopy(tm_config)
        # Update the attribute values in `self.config` with the valid values
        # from the corresponding attributes in `engine_config`, such as
        # `session_len`, `quant_policy`, `rope_scaling_factor`, etc.
        self.config.update_from_engine_config(engine_config)

        # update some attributes of `engine_config` which depends on
        # `session_len`
        self.engine_config = engine_config

        # pack `self.config` and `self.engine_config` into a dict
        self.config_dict = self.config.to_dict()
        self.config_dict.update(dict(engine_config=asdict(self.engine_config)))
        logger.info(f'turbomind model config:\n\n'
                    f'{json.dumps(self.config_dict, indent=2)}')

    def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
        """Convert an hf-format model and create the engine handle with empty
        weight buffers.

        Returns:
            the `_tm.TurboMind` handle; the converter model is kept in
            `self._tm_model`.
        """
        assert is_supported(model_path), (f'turbomind does not support {model_path}. '
                                          'Plz try pytorch engine instead.')

        # convert transformers model into turbomind model
        from .deploy.converter import get_tm_model
        tm_model = get_tm_model(model_path, self.model_name, self.chat_template_name, engine_config)

        self._postprocess_config(tm_model.tm_config, engine_config)

        model_comm = _tm.TurboMind.create(model_dir='',
                                          config=yaml.safe_dump(self.config_dict),
                                          weight_type=self.config.model_config.weight_type)

        # create empty weight
        self._create_weight(model_comm)
        # output model
        self._tm_model = tm_model
        return model_comm

    def sleep(self, level: int = 1):
        """Sleep the model on every local device with the given `level`."""
        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            for _ in e.map(self.model_comm.sleep, range(self.gpu_count), [level] * self.gpu_count):
                pass

    def wakeup(self, tags: Optional[list[str]] = None):
        """Wake up the model on every local device.

        Args:
            tags: what to restore; defaults to ['weights', 'kv_cache'].
        """
        if tags is None:
            tags = ['weights', 'kv_cache']
        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count):
                pass

    def update_params(self, request: UpdateParamsRequest):
        """Update params.

        When using this function, you need to set empty_init=True when creating the engine.

        For each request, the serialized_named_tensors should be the full weights of a decoder layer or the misc weights
        (embedding, norm, lm_head). You should set finished=True when you call this function for the last time.
        """

        def _construct(item):
            """Deserialize a torch.Tensor.

            Args:
                item (Tuple[Callable, Tuple]): the return of reduce_tensor
            """
            func, args = item
            args = list(args)
            # rebuild on the current device; slot 6 of the rebuild args is the
            # device id (presumably `storage_device` of torch's
            # `rebuild_cuda_tensor` — confirm if torch changes the signature)
            args[6] = torch.cuda.current_device()  # device id.
            return func(*args).clone()

        if not hasattr(self, '_export_iter'):
            self._get_model_params()
            que = Queue()
            tm_model = self._tm_model
            # the exporter's input model is fed from this queue instead of a
            # path; each `next()` below presumably consumes one queued item
            tm_model.input_model.model_path = que
            self._update_params_que = que
            self._export_iter = tm_model.export_iter()

        with torch.cuda.device(self.devices[0]):
            if isinstance(request.serialized_named_tensors, str):
                weights = ForkingPickler.loads(pybase64.b64decode(request.serialized_named_tensors))
                weights = {k: _construct(v) for k, v in weights}
            else:
                weights = request.serialized_named_tensors
            self._update_params_que.put(weights)
            next(self._export_iter)

        if request.finished:
            self._check_unloaded_tm_params()
            self._process_weights()
            if self._engine_created is False:
                self._create_engine()

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        model_name: str = None,
                        chat_template_name: str = None,
                        engine_config: TurbomindEngineConfig = None,
                        **kwargs):
        """LMDeploy's turbomind inference engine.

        Args:
            pretrained_model_name_or_path (str):
                It could be one of the following options:
                    - i) A local directory path of a turbomind model which is
                      converted by `lmdeploy convert` command or download from
                      ii) and iii)
                    - ii) The model_id of a lmdeploy-quantized model hosted
                      inside a model repo on huggingface.co, such as
                      "InternLM/internlm-chat-20b-4bit",
                      "lmdeploy/llama2-chat-70b-4bit", etc.
                    - iii) The model_id of a model hosted inside a model repo
                      on huggingface.co, such as "internlm/internlm-chat-7b",
                      "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                      and so on.
            model_name (str): the name of the served model
            chat_template_name (str): the name of a builtin chat template
            engine_config (TurbomindEngineConfig): the engine config
            kwargs (remaining dictionary of keyword arguments, *optional*):
                forwarded to the constructor.
        """
        return cls(model_path=pretrained_model_name_or_path,
                   model_name=model_name,
                   chat_template_name=chat_template_name,
                   engine_config=engine_config,
                   **kwargs)

    def close(self):
        """Release the resources held by this object.

        Safe to call on a partially constructed instance.
        """
        if hasattr(self, '_tm_model'):
            # close immediately after init engine with empty_init=True
            self._tm_model.tm_params.clear()
        if hasattr(self, '_export_iter'):
            del self._export_iter
        # assign unconditionally: reading `self.model_comm` would raise
        # AttributeError if __init__ failed before it was set
        self.model_comm = None
        self._engine_created = False
        if hasattr(self, 'store'):
            del self.store

    def create_instance(self, cuda_stream_id=0):
        """Create a turbomind instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream
        Returns:
            TurboMindInstance: an instance of turbomind
        """
        return TurboMindInstance(self, self.config, cuda_stream_id)

    def get_schedule_metrics(self):
        """Return the scheduler metrics reported for rank 0."""
        # TODO: support dp
        tm_metrics = self.model_comm.get_schedule_metrics(0)
        return ScheduleMetrics(active_seqs=tm_metrics.active_seqs,
                               waiting_seqs=tm_metrics.waiting_seqs,
                               total_blocks=tm_metrics.total_blocks,
                               active_blocks=tm_metrics.active_blocks,
                               free_blocks=tm_metrics.free_blocks)


def _get_logits(outputs, offset: int):
    """Return a callback that attaches a prefix of ``outputs['logits']``.

    On each call, ``out.logits`` is set to the rows
    ``[:step - offset - 1]`` of the logits buffer.
    """
    logits_buffer = outputs['logits']

    def _func(out: EngineOutput, step: int, **kwargs):
        row_end = step - offset - 1
        out.logits = logits_buffer[:row_end, :]

    return _func


def _get_last_hidden_state(outputs, offset: int):
    """Return a callback that attaches a prefix of
    ``outputs['last_hidden_state']``.

    On each call, ``out.last_hidden_state`` is set to the rows
    ``[:step - offset - 1]`` of the hidden-state buffer.
    """
    hidden_buffer = outputs['last_hidden_state']

    def _func(out: EngineOutput, step: int, **kwargs):
        row_end = step - offset - 1
        out.last_hidden_state = hidden_buffer[:row_end, :]

    return _func


def _get_logprobs_impl(logprob_vals: torch.Tensor, logprob_idxs: torch.Tensor, logprob_nums: torch.Tensor,
                       output_ids: List[int], logprobs: int, offset: int):
    """Get logprob of each generated token.

    Args:
        logprob_vals (torch.Tensor): shape (max_new_tokens, 1024),
            1024 is the max_logprobs that turbomind engine can output
        logprob_idxs (torch.Tensor): shape (max_new_tokens, 1024)
        logprob_nums (torch.Tensor): shape (max_new_tokens,)
        output_ids (List[int]): new generated token ids
        logprobs (int): top n logprobs to return
        offset (int): offset to index logprob_vals, logprob_idxs and logprob_nums.
            It indicates where to start getting logprobs for the current generated tokens `output_ids`
    """
    out_logprobs = []
    # the total generated token number until now
    length = len(output_ids) + offset
    for (pos, idx, val, n) in zip(range(len(output_ids)), logprob_idxs[offset:length], logprob_vals[offset:length],
                                  logprob_nums[offset:length]):
        topn = min(n.item(), logprobs)
        tok_res = {idx[i].item(): val[i].item() for i in range(topn)}
        token_id = output_ids[pos]
        if token_id not in tok_res:
            valid_n = n.item()
            tok_res[token_id] = \
                val[:valid_n][idx[:valid_n] == token_id].item()
        ids = list(tok_res.keys())
        for k in ids:
            if tok_res[k] == float('-inf'):
                tok_res.pop(k)
        out_logprobs.append(tok_res)
    return out_logprobs


def _get_logprobs(outputs, output_logprobs: int):
    """Return a stateful callback that attaches per-token logprobs.

    Each call consumes as many buffer rows as ``out.token_ids`` holds. A
    running cursor remembers where the next call starts, so successive
    streamed chunks read consecutive rows.
    """
    vals = outputs['logprob_vals']  # shape {max_new_tokens, 1024}
    idxs = outputs['logprob_indexes']  # shape {max_new_tokens, 1024}
    nums = outputs['logprob_nums']  # shape {max_new_tokens,}
    cursor = 0

    def _func(out: EngineOutput, step: int, **kwargs):
        nonlocal cursor
        new_ids = out.token_ids
        out.logprobs = _get_logprobs_impl(vals, idxs, nums, new_ids, output_logprobs, cursor)
        cursor += len(new_ids)

    return _func


def _get_metrics(metrics):
    """Return a callback that attaches request metrics to each output.

    Every call records the current time as the token timestamp. The first
    call also attaches the QUEUED and SCHEDULED engine events, built from
    ``metrics.enqueue_time`` / ``metrics.scheduled_time`` divided by 1e6
    (presumably a microsecond-to-second conversion — confirm against csrc).
    """
    import time

    from lmdeploy.messages import EngineEvent, EventType, RequestMetrics

    events_reported = False

    def _func(out: EngineOutput, step: int, **kwargs):
        nonlocal events_reported
        if events_reported:
            out.req_metrics = RequestMetrics(token_timestamp=time.time())
            return
        engine_events = [
            EngineEvent(EventType.QUEUED, metrics.enqueue_time / 1000000),
            EngineEvent(EventType.SCHEDULED, metrics.scheduled_time / 1000000),
        ]
        out.req_metrics = RequestMetrics(token_timestamp=time.time(), engine_events=engine_events)
        events_reported = True

    return _func


class StreamingSemaphore:
    """Binary "something changed" signal for a single waiter on the running
    asyncio loop.

    ``release`` must run on ``loop``; callers from other threads schedule it
    with ``loop.call_soon_threadsafe``. ``acquire`` returns immediately if a
    release is pending, otherwise it waits for the next one. Several releases
    before an acquire collapse into one, and only one pending future is
    tracked.
    """

    def __init__(self):
        self.loop = asyncio.get_running_loop()
        self.fut = None
        self.val = 0

    async def acquire(self):
        if not self.val:
            # nothing pending: park on a fresh future until `release` fires
            self.fut = self.loop.create_future()
            await self.fut
            self.fut = None
        self.val = 0

    def release(self):
        if self.val:
            return
        self.val = 1
        waiter = self.fut
        if waiter and not waiter.done():
            waiter.set_result(None)


class TurboMindInstance:
    """Instance of TurboMind.

    Args:
        tm_model (TurboMind): the engine this instance sends requests to
        config (TurbomindModelConfig): the turbomind model config
        cuda_stream_id(int): identity of a cuda stream
    """

    def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_stream_id: int = 0):
        self.tm_model = tm_model
        self.cuda_stream_id = cuda_stream_id

        # create model instances; with `empty_init` the engine is created
        # only later, so the request handle is built lazily by `model_inst`
        lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
        self._model_inst = None if lazy_init else self._create_model_instance()

        self.config = config
        self.lock = None
        # error code map from csrc (refer to `struct Request` in src/turbomind/engine/request.h)
        # to lmdeploy.messages.ResponseType
        self.errcode_map = {
            0: ResponseType.SUCCESS,
            1: ResponseType.SESSION_NOT_EXIST,
            2: ResponseType.SESSION_REPEAT,
            3: ResponseType.SESSION_REPEAT,
            4: ResponseType.INTERNAL_ENGINE_ERROR,
            5: ResponseType.INTERNAL_ENGINE_ERROR,
            6: ResponseType.INPUT_LENGTH_ERROR,
            7: ResponseType.FINISH,
            8: ResponseType.CANCEL,
            9: ResponseType.PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE,
            10: ResponseType.NO_QUEUE,
            -1: ResponseType.INTERNAL_ENGINE_ERROR,
        }

    @property
    def model_inst(self):
        """The engine request handle, created on first access."""
        if self._model_inst is None:
            self._model_inst = self._create_model_instance()
        return self._model_inst

    def _create_model_instance(self):
        model_inst = self.tm_model.model_comm.create_request()
        return model_inst

    def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,
                                     input_len: int, metrics: '_tm.RequestMetrics'):
        """Build the callbacks that attach optional outputs (logits, hidden
        states, logprobs, metrics) to each `EngineOutput`."""

        def _get_offset(output_kind):
            # 'generation' mode slices the buffers with an input_len - 1 offset
            return input_len - 1 if output_kind == 'generation' else 0

        fs = []
        if gen_config.output_logits:
            offset = _get_offset(gen_config.output_logits)
            fs.append(_get_logits(outputs, offset))
        if gen_config.output_last_hidden_state:
            offset = _get_offset(gen_config.output_last_hidden_state)
            fs.append(_get_last_hidden_state(outputs, offset))
        if gen_config.logprobs:
            fs.append(_get_logprobs(outputs, gen_config.logprobs))
        if self.tm_model.engine_config.enable_metrics:
            fs.append(_get_metrics(metrics))
        return fs

    def prepare_embeddings(self, input_embeddings=None, input_embedding_ranges=None):
        """Concatenate the embedding segments into one CPU tensor.

        Returns:
            (values, ranges): `values` has shape (total_rows, hidden) in the
            model's data type; `ranges` is an int32 CPU tensor built from
            `input_embedding_ranges`. (None, None) when there are no
            embeddings.
        """
        if not input_embeddings:
            return None, None

        assert isinstance(input_embeddings, List)
        assert isinstance(input_embedding_ranges, List)
        assert len(input_embeddings) == len(input_embedding_ranges)

        length = sum([x.shape[0] for x in input_embeddings])

        _MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16)
        dtype = _MAP[self.tm_model.config.model_config.data_type]

        values = torch.empty((length, input_embeddings[0].shape[-1]), dtype=dtype, device='cpu')
        ranges = torch.tensor(input_embedding_ranges, dtype=torch.int32, device='cpu')

        offset = 0
        for embeds in input_embeddings:
            values[offset:offset + embeds.shape[0]].copy_(embeds)
            offset += embeds.shape[0]

        return values, ranges

    def prepare_mrope(self, input_meta: Dict[str, Any], input_len: int):
        """Return the transposed, contiguous mrope position ids (whose last
        dim must equal `input_len`) and the mrope position delta."""
        mrope_position_ids = input_meta['mrope_position_ids']
        mrope_position_delta = input_meta['mrope_position_delta']
        assert mrope_position_ids.size(-1) == input_len
        mrope_position_ids = mrope_position_ids.t().contiguous()
        return mrope_position_ids, mrope_position_delta

    def prepare_inputs(self,
                       input_ids,
                       gen_config: GenerationConfig,
                       input_embeddings=None,
                       input_embedding_ranges=None,
                       input_meta: Dict[str, Any] = None):
        """Convert inputs format.

        Returns:
            (inputs, input_len): a dict of CPU tensors keyed by the names the
            engine expects, and the prompt length.
        """
        assert isinstance(input_ids, Sequence)

        input_ids = torch.IntTensor(input_ids)
        input_len = len(input_ids)

        inputs = dict(input_ids=input_ids, )

        input_embeddings, input_embedding_ranges = self.prepare_embeddings(input_embeddings, input_embedding_ranges)
        if input_embeddings is not None:
            inputs['input_embeddings'] = input_embeddings.cpu()
            inputs['input_embedding_ranges'] = input_embedding_ranges

        if input_meta and 'mrope_position_ids' in input_meta:
            mrope_position_ids, mrope_position_delta = self.prepare_mrope(input_meta, input_len)
            inputs['mrope_position_ids'] = mrope_position_ids.type(torch.int32)
            inputs['mrope_position_delta'] = mrope_position_delta.type(torch.int32)
            inputs['mrope_length'] = torch.IntTensor([mrope_position_ids.shape[0]])

        return inputs, input_len

    async def async_cancel(self, session_id: int = None):
        self.model_inst.cancel()

    def async_end_cb(self, fut: asyncio.Future, status: int):
        """Executing on engine's signaling thread."""
        logger.info(f'[async_end_cb] session ended, status = {status}')
        fut.get_loop().call_soon_threadsafe(fut.set_result, status)

    async def async_end(self, session_id):
        fut = asyncio.get_running_loop().create_future()
        self.model_inst.end(partial(self.async_end_cb, fut), session_id)
        await fut

    def async_signal_cb(self, s: StreamingSemaphore):
        """Executing on engine's signaling thread."""
        s.loop.call_soon_threadsafe(s.release)

    async def async_stream_infer(self,
                                 session_id,
                                 input_ids,
                                 input_embeddings=None,
                                 input_embedding_ranges=None,
                                 input_meta: Dict[str, Any] = None,
                                 sequence_start: bool = True,
                                 sequence_end: bool = False,
                                 step=0,
                                 gen_config: GenerationConfig = None,
                                 stream_output=False,
                                 **kwargs):
        """Perform model inference.

        Args:
            session_id (int): the id of a session
            input_ids (Sequence[int]): the token ids of a prompt
            input_embeddings (List): embedding segments; each one's rows are
              copied into a single CPU tensor
            input_embedding_ranges (List[Tuple[int,int]]): the begin/end
              offsets of input_embeddings to input_ids
            input_meta (Dict[str, Any]): extra inputs; `mrope_position_ids`
              and `mrope_position_delta` are used when present
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            gen_config (GenerationConfig): generation config (required)
            stream_output (bool): indicator for stream output
            kwargs (dict): kwargs for backward compatibility

        Yields:
            EngineOutput: the newly generated token ids of each update, or an
            output with an error status when the engine reports a failure.
        """
        logger.info(f'[async_stream_infer] session {session_id} start')
        gen_cfg = self._get_generation_config(gen_config)

        inputs, input_len = self.prepare_inputs(input_ids=input_ids,
                                                input_embeddings=input_embeddings,
                                                input_embedding_ranges=input_embedding_ranges,
                                                input_meta=input_meta,
                                                gen_config=gen_config)

        if gen_config.response_format is not None:
            tokenizer = self.tm_model.tokenizer
            vocab_size = self.tm_model.config.model_config.vocab_size

            try:
                tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)
                decode_grammar_type = gen_config.response_format['type']
                if decode_grammar_type == 'json_schema':
                    decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
                elif decode_grammar_type == 'regex_schema':
                    decode_grammar = gen_config.response_format[decode_grammar_type]
                elif decode_grammar_type == 'json_object':
                    decode_grammar = '{"type" : "object", "additionalProperties": true}'

                compiler = _xgr.GrammarCompiler(tokenizer_info)

                if decode_grammar_type == 'json_schema':
                    decode_grammar = json.dumps(decode_grammar)
                    grammar = compiler.compile_json_schema(decode_grammar)
                elif decode_grammar_type == 'regex_schema':
                    decode_grammar = str(decode_grammar)
                    grammar = compiler.compile_regex(decode_grammar)
                elif decode_grammar_type == 'json_object':
                    decode_grammar = str(decode_grammar)
                    grammar = compiler.compile_json_schema(decode_grammar)
                else:
                    assert False, f'Decode grammar type {decode_grammar_type} should be in ' \
                                   '["json_schema", "regex_schema", "json_object"]'

                self.model_inst.set_grammar(grammar)
            except ValueError as e:
                logger.warning(f'Failed to initialize guided decoding for tokenizer {tokenizer}, '
                               f'disable guided decoding: {e}')
                gen_config.response_format = None

        session = _tm.SessionParam(id=session_id, step=step, start=sequence_start, end=sequence_end)

        inputs = _np_dict_to_tm_dict(inputs)

        sem = StreamingSemaphore()
        signal_cb = partial(self.async_signal_cb, sem)

        outputs, shared_state, metrics = self.model_inst.forward(inputs, session, gen_cfg, stream_output,
                                                                 self.tm_model.engine_config.enable_metrics, signal_cb)

        outputs = _tm_dict_to_torch_dict(outputs)

        extra_fs = self._get_extra_output_processors(outputs, gen_config, input_len, metrics)

        output_ids_buf = outputs['output_ids']

        finish = False
        state = None

        output_ids = []
        prev_len = step + input_len
        try:
            while True:
                await sem.acquire()
                state = shared_state.consume()

                status, seq_len = state.status, state.seq_len
                ret_status = ResponseType.SUCCESS

                if status in [7, 8]:  # finish / canceled
                    finish = True
                    ret_status = ResponseType.FINISH if status == 7 else ResponseType.CANCEL
                elif status:
                    logger.error(f'internal error. status_code {status}')
                    yield self._get_error_output(status)
                    break

                if seq_len == prev_len and not finish:
                    continue

                output_ids = output_ids_buf[prev_len:seq_len].tolist()
                output = EngineOutput(ret_status, output_ids)

                for f in extra_fs:
                    f(output, seq_len)

                prev_len = seq_len

                yield output

                if finish:
                    break

        except (GeneratorExit, asyncio.CancelledError) as e:
            logger.info(f'[async_stream_infer] {type(e).__name__}')
            self.model_inst.cancel()
        except Exception as e:
            logger.error(f'[async_stream_infer] {type(e).__name__} {e}')
            self.model_inst.cancel()
            yield self._get_error_output(-1)
        finally:
            # Contract: `cb` won't be called again if status is non-zero
            # wait for status to be set as `finish` or `error`
            while not state or state.status == 0:
                await sem.acquire()
                state = shared_state.consume()
            logger.info(f'[async_stream_infer] session {session_id} done')

    def _get_error_output(self, status):
        """Build an empty `EngineOutput` carrying the response type of
        `status`.

        Codes missing from `errcode_map` (e.g. added later in csrc) are
        reported as INTERNAL_ENGINE_ERROR instead of raising KeyError.
        """
        response_type = self.errcode_map.get(status, ResponseType.INTERNAL_ENGINE_ERROR)
        return EngineOutput(status=response_type, token_ids=[])

    def _get_generation_config(self, cfg: GenerationConfig):
        """Translate a `GenerationConfig` into the engine's
        `_tm.GenerationConfig`.

        Note: `cfg.logprobs` is clamped to MAX_LOGPROBS in place, so the
        caller's config is modified in that case.
        """
        c = _tm.GenerationConfig()
        c.max_new_tokens = cfg.max_new_tokens
        c.top_k = cfg.top_k
        c.top_p = cfg.top_p
        c.min_p = cfg.min_p
        c.temperature = cfg.temperature
        if cfg.stop_token_ids:
            c.eos_ids = cfg.stop_token_ids
        if cfg.bad_token_ids:
            c.bad_ids = _construct_stop_or_bad_words(cfg.bad_token_ids)
        if not cfg.ignore_eos and cfg.stop_token_ids:
            c.stop_ids = _construct_stop_or_bad_words(cfg.stop_token_ids)
        c.repetition_penalty = cfg.repetition_penalty
        if cfg.min_new_tokens:
            c.min_new_tokens = cfg.min_new_tokens
        output_type = dict(all=1, generation=2)
        if cfg.output_last_hidden_state:
            c.output_last_hidden_state = output_type[cfg.output_last_hidden_state]
        if cfg.output_logits:
            c.output_logits = output_type[cfg.output_logits]
        if cfg.logprobs:
            if cfg.logprobs > MAX_LOGPROBS:
                cfg.logprobs = MAX_LOGPROBS
                logger.warning(f'logprobs should be in range [1, {MAX_LOGPROBS}], '
                               f'update logprobs={cfg.logprobs}')
            c.output_logprobs = cfg.logprobs
        if cfg.random_seed is not None:
            c.random_seed = cfg.random_seed
        return c


================================================
FILE: lmdeploy/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import functools
import logging
import os
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger, LogRecord

import torch
from transformers import PretrainedConfig

# Module-level registry shared by the logging helpers of this module;
# presumably records logger names that have already been configured, so they
# are not set up twice (confirm against `get_logger`).
logger_initialized = {}


class _ASNI_COLOR:
    BRIGHT_RED = '\033[91m'
    RED = '\033[31m'
    YELLOW = '\033[33m'
    WHITE = '\033[37m'
    GREEN = '\033[32m'


# copy from: https://github.com/termcolor/termcolor
@functools.cache
def can_colorize(*, no_color: bool | None = None, force_color: bool | None = None) -> bool:
    """Check env vars and for tty/dumb terminal."""
    import io
    if no_color is not None and no_color:
        return False
    if force_color is not None and force_color:
        return True

    # Then check env vars:
    if os.environ.get('ANSI_COLORS_DISABLED'):
        return False
    if os.environ.get('NO_COLOR'):
        return False
    if os.environ.get('FORCE_COLOR'):
        return True

    # Then check system:
    if os.environ.get('TERM') == 'dumb':
        return False
    if not hasattr(sys.stdout, 'fileno'):
        return False

    try:
        return os.isatty(sys.stdout.fileno())
    except io.UnsupportedOperation:
        return sys.stdout.isatty()


class ColorFormatter(logging.Formatter):
    """Logging formatter that wraps the level name in ANSI color codes.

    Colorization is skipped when ``can_colorize()`` reports that colors are
    not supported. The record's ``levelname`` is restored after formatting,
    because the same ``LogRecord`` is handed to every handler. Without the
    restore, a second handler formatting it would wrap an already-colored
    level name again.
    """

    _LEVELNAME_COLOR_MAP = dict(CRITICAL=_ASNI_COLOR.BRIGHT_RED,
                                ERROR=_ASNI_COLOR.RED,
                                WARN=_ASNI_COLOR.YELLOW,
                                WARNING=_ASNI_COLOR.YELLOW,
                                INFO=_ASNI_COLOR.WHITE,
                                DEBUG=_ASNI_COLOR.GREEN)

    _RESET_COLOR = '\033[0m'

    def format(self, record: LogRecord):
        """Format ``record``, coloring its level name when supported."""
        if not can_colorize():
            # colors disabled or output is not a color-capable terminal
            return super().format(record)
        original_levelname = record.levelname
        level_color = self._LEVELNAME_COLOR_MAP.get(original_levelname, self._RESET_COLOR)
        record.levelname = f'{level_color}{original_levelname}{self._RESET_COLOR}'
        try:
            return super().format(record)
        finally:
            # undo the mutation so other handlers see the plain level name
            record.levelname = original_levelname


class FilterDuplicateWarning(logging.Filter):
    """Suppress WARNING records whose message was already emitted.

    Records of every other level always pass. Deduplication is keyed on
    ``record.msg`` (the unformatted message template), so warnings that share
    a template but differ only in their ``args`` count as duplicates.

    Args:
        name (str): name of the filter.
    """

    def __init__(self, name: str = 'lmdeploy'):
        super().__init__(name)
        self.seen: set = set()

    def filter(self, record: LogRecord) -> bool:
        """Decide whether ``record`` should be emitted.

        Args:
            record (LogRecord): The log record.

        Returns:
            bool: ``True`` to emit the record, ``False`` to drop it.
        """
        if record.levelno == logging.WARNING:
            first_time = record.msg not in self.seen
            self.seen.add(record.msg)
            return first_time
        return True


# default log line layout: time - logger name - level - file:line - message
_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'


def get_logger(name: str | None = None,
               log_file: str | None = None,
               log_level: int = logging.INFO,
               file_mode: str = 'a',
               log_formatter: str = _FORMAT) -> Logger:
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` (or the ``LMDEPLOY_LOG_FILE`` environment variable) is
    specified, a FileHandler will also be added. A logger whose name is a
    dotted child of an initialized logger (e.g. "a.b" after "a") is returned
    without adding handlers.

    Args:
        name (str | None): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'a'.
        log_formatter (str): The logger output format.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a" (its records propagate to
    # "a"). A mere string prefix such as "ab" is NOT a child of "a".
    if name is not None:
        for logger_name in logger_initialized:
            if logger_name is not None and name.startswith(f'{logger_name}.'):
                return logger

    # handle duplicate logs to the console
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    handlers = [stream_handler]

    # set log_file from env
    log_file = log_file or os.getenv('LMDEPLOY_LOG_FILE')

    if log_file is not None:
        log_file = os.path.expanduser(log_file)
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = ColorFormatter(log_formatter)
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        # logging.Filter calls len(name), so the root logger's None name
        # must be passed as ''.
        handler.addFilter(FilterDuplicateWarning(name or ''))
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger.propagate = False
    logger_initialized[name] = True

    return logger


def filter_suffix(response: str, suffixes: list[str] | None = None) -> str:
    """Filter response with suffixes.

    Args:
        response (str): generated response by LLMs.
        suffixes (str): a list of suffixes to be deleted.

    Return:
        str: a clean response.
    """
    if suffixes is None:
        return response
    for item in suffixes:
        if response.endswith(item):
            response = response[:len(response) - len(item)]
    return response


# TODO remove stop_word_offsets stuff and make it clean
def _stop_words(stop_words: list[int | str], tokenizer: object):
    """Return list of stop-words to numpy.ndarray."""
    import numpy as np
    if stop_words is None:
        return None
    assert isinstance(stop_words, list) and \
        all(isinstance(elem, (str, int)) for elem in stop_words), \
        f'stop_words must be a list but got {type(stop_words)}'
    stop_indexes = []
    for stop_word in stop_words:
        if isinstance(stop_word, str):
            stop_indexes += tokenizer.indexes_containing_token(stop_word)
        elif isinstance(stop_word, int):
            stop_indexes.append(stop_word)
    assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words'
    # each id in stop_indexes represents a stop word
    # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
    # detailed explanation about fastertransformer's stop_indexes
    stop_word_offsets = range(1, len(stop_indexes) + 1)
    stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(np.int32)
    return stop_words


def get_hf_gen_cfg(path: str):
    """Return the HuggingFace generation config found at ``path`` as a dict.

    An empty dict is returned when loading raises ``OSError`` (typically
    when no generation config can be found).
    """
    from transformers import GenerationConfig
    try:
        return GenerationConfig.from_pretrained(path, trust_remote_code=True).to_dict()
    except OSError:
        return {}


def get_model(pretrained_model_name_or_path: str,
              download_dir: str | None = None,
              revision: str | None = None,
              token: str | None = None):
    """Get model from huggingface, modelscope or openmind_hub.

    The hub is chosen by environment variables: ``LMDEPLOY_USE_MODELSCOPE``
    or ``LMDEPLOY_USE_OPENMIND_HUB`` set to ``true`` (case-insensitive);
    otherwise huggingface_hub is used. ``*.pth`` files are not downloaded.

    Args:
        pretrained_model_name_or_path (str): model id on the hub.
        download_dir (str | None): cache directory (``cache_dir``).
        revision (str | None): model revision to download.
        token (str | None): hub access token.

    Returns:
        str: local path returned by ``snapshot_download``.
    """
    if os.getenv('LMDEPLOY_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download
    elif os.getenv('LMDEPLOY_USE_OPENMIND_HUB', 'False').lower() == 'true':
        from openmind_hub import snapshot_download
    else:
        from huggingface_hub import snapshot_download

    # only forward options the caller actually provided
    optional_kwargs = dict(cache_dir=download_dir, revision=revision, token=token)
    download_kwargs = {key: value for key, value in optional_kwargs.items() if value is not None}

    return snapshot_download(pretrained_model_name_or_path, ignore_patterns=['*.pth'], **download_kwargs)


def logging_timer(op_name: str, logger: Logger, level: int = logging.DEBUG):
    """Decorator factory that logs how long a function call takes.

    Works for both regular and ``async`` functions; a decorated coroutine
    function remains a coroutine function. Timing is skipped entirely when
    ``logger`` is not enabled for ``level``, which accounts for the logger's
    effective level.

    Args:
        op_name (str): name shown in the log message.
        logger (Logger): logger that receives the timing message.
        level (int): level of the timing message. Defaults to DEBUG.

    Returns:
        Callable: a decorator for the target function.
    """
    import inspect

    @contextmanager
    def _timer():
        """Measure the wrapped region and log its duration in ms."""
        start = time.perf_counter()
        yield
        duration = (time.perf_counter() - start) * 1000
        logger.log(level, f'<{op_name}> take time: {duration:.2f} ms')

    def _decorator(func):
        """Wrap ``func`` with timing, preserving sync/async nature."""
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def _async_wrapper(*args, **kwargs):
                if not logger.isEnabledFor(level):
                    return await func(*args, **kwargs)
                with _timer():
                    return await func(*args, **kwargs)

            return _async_wrapper

        @functools.wraps(func)
        def _sync_wrapper(*args, **kwargs):
            if not logger.isEnabledFor(level):
                return func(*args, **kwargs)
            with _timer():
                return func(*args, **kwargs)

        return _sync_wrapper

    return _decorator


# modified from https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150  # noqa
def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: int | None,
) -> int:
    """Get and verify the model's maximum length.

    The model's limit is derived as the smallest value found among a list
    of well-known config keys (``max_position_embeddings``, ``n_positions``,
    ...). For VL models the nested language-model config is inspected
    first.

    Args:
        hf_config (PretrainedConfig): the model's HF config.
        max_model_len (int | None): user-requested maximum length.

    Returns:
        int: if no known key is found, ``max_model_len`` when given, else 2048
        (with a warning). Otherwise ``max_model_len`` when given. A warning is
        logged if it exceeds the derived limit and is not covered by
        ``model_max_length``; the user value is still returned. Otherwise the
        derived limit.
    """

    # vl configs hide session-len inside llm configs
    # each key is looked up on the result of the previous step, falling back
    # to the current config when the attribute is absent
    llm_keys = ['language_config', 'llm_config', 'text_config']
    for key in llm_keys:
        hf_config = getattr(hf_config, key, hf_config)

    logger = get_logger('lmdeploy')
    derived_max_model_len = float('inf')
    possible_keys = [
        # OPT
        'max_position_embeddings',
        # GPT-2
        'n_positions',
        # MPT
        'max_seq_len',
        # ChatGLM2
        'seq_length',
        # Command-R
        'model_max_length',
        # Others
        'max_sequence_length',
        'max_seq_length',
        'seq_len',
    ]
    max_len_key = None
    for key in possible_keys:
        max_len = None
        if hasattr(hf_config, key):
            max_len = getattr(hf_config, key)
        # presumably for dict-like configs; NOTE(review): confirm `in` is
        # supported on every config object reaching this branch
        elif key in hf_config:
            max_len = hf_config[key]
        if max_len is not None:
            # remember which key yields the current minimum (used in the warning)
            max_len_key = key if max_len < derived_max_model_len \
                else max_len_key
            derived_max_model_len = min(derived_max_model_len, max_len)
    if derived_max_model_len == float('inf'):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning("The model's config.json does not contain any of the following "
                       'keys to determine the original maximum length of the model: '
                       f"{possible_keys}. Assuming the model's maximum length is "
                       f'{default_max_len}.')
        derived_max_model_len = default_max_len

    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, 'model_max_length', None)
        if model_max_length is not None and max_model_len <= model_max_length:
            pass
        else:
            # only warns: the user-specified value is still returned below
            logger.warning(f'User-specified max_model_len ({max_model_len}) is greater '
                           'than the derived max_model_len '
                           f'({max_len_key}={derived_max_model_len} or model_max_length='
                           f"{model_max_length} in model's config.json).")
    return int(max_model_len)


def get_max_batch_size(device_type: str):
    """Return the default max inference batch size for ``device_type``.

    CUDA GPUs are matched by a substring of the device name: A100/A800 get
    384; H100/H800/L20Y/H200 get 1024; any other GPU gets 128. 'ascend',
    'maca' and 'camb' devices get 256.

    Args:
        device_type (str): one of 'cuda', 'ascend', 'maca', 'camb'.
    """
    assert device_type in ['cuda', 'ascend', 'maca', 'camb']
    if device_type in ('ascend', 'maca', 'camb'):
        return 256
    if device_type == 'cuda':
        import torch
        gpu_name = torch.cuda.get_device_name(0).lower()
        known_sizes = {'a100': 384, 'a800': 384, 'h100': 1024, 'h800': 1024, 'l20y': 1024, 'h200': 1024}
        # first known GPU family whose key occurs in the device name wins
        return next((size for key, size in known_sizes.items() if key in gpu_name), 128)
    return None


def is_bf16_supported(device_type: str = 'cuda'):
    """Report whether ``device_type`` supports bfloat16.

    For CUDA this requires a CUDA build of major version >= 11 and a current
    device with compute-capability major >= 8. 'ascend', 'maca', 'camb' and
    'rocm' are assumed to support bf16; any other device type does not.

    Args:
        device_type (str): the type of device
    """
    if device_type == 'cuda':
        import torch
        device = torch.cuda.current_device()
        cuda_version = torch.version.cuda
        return (cuda_version is not None and int(cuda_version.split('.')[0]) >= 11
                and torch.cuda.get_device_properties(device).major >= 8)
    # ascend: the capability query does not work reliably on multi-NPU hosts,
    # and the ascend910 supports bf16, so True is returned as a workaround.
    return device_type in ('ascend', 'maca', 'camb', 'rocm')


def try_import_deeplink(device_type: str):
    """Import the dlinfer lmdeploy extension for deeplink device types.

    For 'ascend', 'npu', 'maca' and 'camb' the extension module is imported
    (presumably for its side effects, as the name itself is unused). Any
    failure is logged and the process exits with status 1. Other device
    types are ignored.

    Args:
        device_type (str): the type of device
    """
    deeplink_device_type_list = [
        'ascend',
        'npu',
        'maca',
        'camb',
    ]
    if device_type in deeplink_device_type_list:
        try:
            import dlinfer.framework.lmdeploy_ext  # noqa: F401
        except Exception as e:
            logger = get_logger('lmdeploy')
            logger.error(f'{type(e).__name__}: {e}')
            # builtin `exit` is added by `site` for interactive use and may be
            # missing (e.g. `python -S`, frozen apps); sys.exit always exists.
            sys.exit(1)


def serialize_state_dict(state_dict: dict) -> str:
    """Serialize state dict to str.

    The consumer should use it on same node. As the producer and consumer may
    have different GPU visibility, we use reduce_tensor instead of ForkingPickler.dumps
    to fix the device_id when loading the serialized tensor.

    Two layouts are accepted. A dict containing both ``metadata`` and
    ``flattened_tensor`` is serialized as a dict; only ``flattened_tensor`` is
    reduced, and only if it is a tensor. Any other mapping is serialized as a
    list of ``(name, reduced_tensor)`` pairs. The input dict is not modified.

    Args:
        state_dict (dict[str, torch.Tensor]): state dict to serialize.
    Returns:
        str: serialized state dict.
    """
    import copy
    from io import BytesIO
    from multiprocessing.reduction import ForkingPickler

    import pybase64
    from torch.multiprocessing.reductions import reduce_tensor

    # flattened_tensor
    if 'metadata' in state_dict and 'flattened_tensor' in state_dict:
        # shallow copy so the caller's dict keeps its original tensor
        data = copy.copy(state_dict)
        if isinstance(data['flattened_tensor'], torch.Tensor):
            data['flattened_tensor'] = reduce_tensor(data['flattened_tensor'])
    else:
        data = [(k, reduce_tensor(v)) for k, v in state_dict.items()]

    buf = BytesIO()
    ForkingPickler(buf).dump(data)
    buf.seek(0)
    return pybase64.b64encode(buf.read()).decode('utf-8')


def is_dlblas_installed():
    """Return ``True`` if the optional ``dlblas`` package can be imported.

    Any exception raised while importing counts as "not installed".
    """
    try:
        import dlblas  # noqa: F401
    except Exception:
        return False
    return True


# from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/weight_sync/tensor_bucket.py


@dataclass
class FlattenedTensorMetadata:
    """Metadata for flatten bucket tensor.

    Describes where one tensor lives inside a flattened bucket:
    ``[start_idx, end_idx)`` is its slice of the flat buffer and ``numel``
    is its element count.
    """
    # identifier of the packed tensor
    name: str
    # original shape, restored on reconstruction
    shape: torch.Size
    # original dtype
    dtype: torch.dtype
    # slice bounds inside the flattened tensor (end exclusive)
    start_idx: int
    end_idx: int
    # number of elements of the tensor
    numel: int


class FlattenedTensorBucket:
    """Pack multiple flattened tensors into one to transfer efficiently.

    Every tensor is flattened and concatenated into a single 1-D tensor. A
    ``FlattenedTensorMetadata`` entry per tensor records its name, shape,
    dtype and slice bounds so the originals can be reconstructed.
    """

    def __init__(
        self,
        named_tensors: list[tuple[str, torch.Tensor]] | None = None,
        flattened_tensor: torch.Tensor | None = None,
        metadata: list[FlattenedTensorMetadata] | None = None,
    ):
        """Initialize a tensor bucket from a list of named tensors or from pre-
        flattened data.

        Args:
            named_tensors: List of (name, tensor) tuples (for creating new bucket).
                All tensors must share one dtype. An empty list yields an
                empty bucket whose flattened tensor has no elements.
            flattened_tensor: Pre-flattened tensor (for reconstruction)
            metadata: Pre-computed metadata (for reconstruction)

        Raises:
            ValueError: if ``named_tensors`` mixes dtypes, or if neither
                ``named_tensors`` nor both ``flattened_tensor`` and
                ``metadata`` are given.
        """
        if named_tensors is None:
            if flattened_tensor is None or metadata is None:
                raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata')
            self.metadata = metadata
            self.flattened_tensor = flattened_tensor
            return

        dtypes = [t.dtype for _, t in named_tensors]
        if any(d != dtypes[0] for d in dtypes[1:]):
            raise ValueError(f'All tensors should have same dtype, but given {dtypes}')

        self.metadata = []
        flat_parts = []
        current_idx = 0
        for name, tensor in named_tensors:
            numel = tensor.numel()
            flat_parts.append(tensor.flatten())
            self.metadata.append(
                FlattenedTensorMetadata(name=name,
                                        shape=tensor.shape,
                                        dtype=tensor.dtype,
                                        start_idx=current_idx,
                                        end_idx=current_idx + numel,
                                        numel=numel))
            current_idx += numel

        # torch.cat rejects an empty list; keep the attribute a tensor either way
        self.flattened_tensor = torch.cat(flat_parts, dim=0) if flat_parts else torch.empty(0)

    def get_flattened_tensor(self) -> torch.Tensor:
        """Get the flattened tensor containing multiple tensors."""
        return self.flattened_tensor

    def get_metadata(self) -> list[FlattenedTensorMetadata]:
        """Get all metadatas for all tensors in the bucket."""
        return self.metadata

    def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]:
        """Reconstruct original tensors.

        Returns:
            list of ``(name, tensor)`` pairs in metadata order. Each tensor is
            the reshaped slice of the flattened tensor, cast to the recorded
            dtype when it differs.
        """
        reconstructed = []
        for meta in self.metadata:
            tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape)
            if tensor.dtype != meta.dtype:
                tensor = tensor.to(meta.dtype)
            reconstructed.append((meta.name, tensor))
        return reconstructed


================================================
FILE: lmdeploy/version.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

# Package version string; `short_version` is an alias of it.
__version__ = '0.12.2'
short_version = __version__


def parse_version_info(version_str: str) -> Tuple:
    """Parse version from a string.

    Components are separated by ``.``:

    - purely numeric components become ints (``"12"`` -> ``12``);
    - a component containing ``rc`` is split into an int and an ``rcN``
      string (``"2rc1"`` -> ``2, 'rc1'``);
    - any other component keeps its leading integer, if present
      (``"2+cu118"`` -> ``2``, ``"0a1"`` -> ``0``);
    - components without a leading integer (``"dev1"``, ``"post1"``) are
      skipped.

    Args:
        version_str (str): A string represents a version info.

    Returns:
        tuple: A sequence of integer and string represents version.
    """
    import re

    _version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            _version_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            _version_info.append(int(patch_version[0]))
            _version_info.append(f'rc{patch_version[1]}')
        else:
            # keep the numeric prefix of local/pre-release suffixes such as
            # "2+cu118" instead of silently dropping the whole component
            leading = re.match(r'\d+', x)
            if leading is not None:
                _version_info.append(int(leading.group()))
    return tuple(_version_info)


# Tuple form of `__version__` (e.g. (0, 12, 2)) for programmatic comparison.
version_info = parse_version_info(__version__)


================================================
FILE: lmdeploy/vl/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (encode_image_base64, encode_time_series_base64, encode_video_base64, load_image, load_time_series,
                    load_video)

# Public API of `lmdeploy.vl`: loaders and base64 encoders for image, video
# and time-series inputs.
__all__ = [
    'load_image',
    'load_video',
    'load_time_series',
    'encode_image_base64',
    'encode_video_base64',
    'encode_time_series_base64',
]


================================================
FILE: lmdeploy/vl/constants.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from enum import Enum

# Presumably the placeholder marking image positions in prompts.
# NOTE(review): the value is an empty string, which cannot be located or split
# on inside a prompt. Confirm this is intended; the literal may have been lost
# (e.g. an angle-bracketed token such as '<IMAGE_TOKEN>'). TODO verify.
IMAGE_TOKEN = ''


class Modality(str, Enum):
    """Kinds of multimodal input.

    Members subclass ``str``, so each compares equal to its string value.
    """

    # still images
    IMAGE = 'image'
    # video clips
    VIDEO = 'video'
    # audio inputs
    AUDIO = 'audio'
    # numeric time-series data
    TIME_SERIES = 'time_series'


================================================
FILE: lmdeploy/vl/engine.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Union

import torch

from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.builder import load_vl_model

# module-level logger, obtained under the shared 'lmdeploy' name
logger = get_logger('lmdeploy')


def _raise_exception_on_finish(task: asyncio.Task) -> None:
    """Done-callback that re-raises the exception of a finished task/future.

    A cancelled task is silently ignored; a successful one is a no-op.
    """
    try:
        # result() re-raises whatever exception the task finished with
        task.result()
    except asyncio.CancelledError:
        # cancellation is expected and not treated as an error
        pass


def _accepts_arg(func, arg_name: str) -> bool:
    """Return whether ``func``'s signature declares a parameter ``arg_name``.

    Only explicitly named parameters count; a ``**kwargs`` catch-all does
    not make arbitrary names "accepted".
    """
    declared = inspect.signature(func).parameters
    return arg_name in declared


class ImageEncoder:
    """Image encoder.

    Wraps the vision-language model returned by ``load_vl_model``. Its
    ``preprocess`` and ``forward`` calls are dispatched to a single-worker
    thread pool, so the asyncio event loop is not blocked while they run.
    """

    def __init__(
        self,
        model_path: str,
        backend: str,
        vision_config: Optional[VisionConfig] = None,
        backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
    ):
        """
        Args:
            model_path (str): path of the model, forwarded to ``load_vl_model``.
            backend (str): backend name, forwarded to ``load_vl_model``.
            vision_config (VisionConfig | None): vision settings; a default
                ``VisionConfig()`` is used when ``None``. Its
                ``max_batch_size`` is passed to ``self.model.forward``.
            backend_config: engine config forwarded to ``load_vl_model``.
        """
        self.model = load_vl_model(model_path, backend, backend_config=backend_config)
        if vision_config is None:
            vision_config = VisionConfig()
        self.vision_config = vision_config
        self.max_batch_size = vision_config.max_batch_size
        # a single worker serializes all model calls
        self.executor = ThreadPoolExecutor(max_workers=1)
        # free unused cached CUDA memory (e.g. left over from model loading)
        torch.cuda.empty_cache()

    async def preprocess(self,
                         messages: List[Dict],
                         mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> List[Dict]:
        """Preprocess multimodal data in the messages.

        Runs ``self.model.preprocess`` in the worker thread.
        ``mm_processor_kwargs`` is forwarded (positionally) only when the
        model's ``preprocess`` declares a parameter with that name.
        """
        loop = asyncio.get_running_loop()
        if _accepts_arg(self.model.preprocess, 'mm_processor_kwargs'):
            future = loop.run_in_executor(self.executor, self.model.preprocess, messages, mm_processor_kwargs)
        else:
            future = loop.run_in_executor(self.executor, self.model.preprocess, messages)
        # the callback re-raises failures from within the event loop as well
        future.add_done_callback(_raise_exception_on_finish)
        return await future

    async def async_infer(self, messages: List[Dict]) -> List[Dict]:
        """Get multimodal embedding.

        Runs ``self.model.forward(messages, self.max_batch_size)`` in the
        worker thread.

        Args:
            messages (List[Dict]): a list of message, which is the output
                of `preprocess()`
        """
        loop = asyncio.get_running_loop()
        future = loop.run_in_executor(self.executor, self.model.forward, messages, self.max_batch_size)
        future.add_done_callback(_raise_exception_on_finish)
        return await future

    async def wrap_for_pytorch(
        self,
        messages: List[Dict],
        chat_template,
        tokenizer,
        sequence_start,
        tools: Optional[List[object]] = None,
        chat_template_kwargs: Optional[Dict] = None,
    ) -> List[Dict]:
        """Build the pytorch engine inputs from preprocessed messages.

        Uses ``self.model.to_pytorch`` unless the messages already carry
        input ids, in which case ``self.model.to_pytorch_with_input_ids`` is
        used. Afterwards the ``'preprocess'`` entry of every message whose
        content is a list is set to ``None`` to drop the preprocessed data.

        Args:
            messages (List[Dict]): a list of message, which is supposed to be
                the output of `preprocess`
            chat_template, tokenizer, sequence_start, tools,
            chat_template_kwargs: forwarded to ``self.model.to_pytorch``.

        Returns:
            a dict which will be passed to pytorch engine_instance's forward,
            like::

                dict(
                    prompt='the prompt after applying chat template',
                    input_ids=[],
                    multimodal={'pixel_values': torch.Tensor, ...},
                )
        """
        has_input_ids = self.model.has_input_ids(messages)
        if not has_input_ids:
            result = self.model.to_pytorch(messages,
                                           chat_template,
                                           tokenizer,
                                           sequence_start,
                                           tools=tools,
                                           chat_template_kwargs=chat_template_kwargs)
        else:
            result = self.model.to_pytorch_with_input_ids(messages)
        # clear data
        for message in messages:
            if isinstance(message['content'], list):
                message['preprocess'] = None
        return result

    async def wrap_for_turbomind(
        self,
        messages: List[Dict],
        chat_template,
        tokenizer,
        sequence_start,
        tools: Optional[List[object]] = None,
        chat_template_kwargs: Optional[Dict] = None,
    ) -> Dict:
        """Build the turbomind engine inputs from encoded messages.

        Delegates to ``self.model.to_turbomind``. Afterwards the
        ``'preprocess'`` and ``'forward'`` entries of every message whose
        content is a list are set to ``None`` to drop intermediate data.

        Args:
            messages (List[Dict]): a list of message, which is supposed to be
                the output of `async_infer`
            chat_template, tokenizer, sequence_start, tools,
            chat_template_kwargs: forwarded to ``self.model.to_turbomind``.

        Returns:
            the dict produced by ``self.model.to_turbomind``, like::

                dict(
                    prompt='the prompt after applying chat template',
                    input_ids=[],
                    input_embeddings=list[torch.Tensor],
                    input_embedding_ranges=list[torch.Tensor],
                    ...
                )
        """
        result = self.model.to_turbomind(messages,
                                         chat_template,
                                         tokenizer,
                                         sequence_start,
                                         tools=tools,
                                         chat_template_kwargs=chat_template_kwargs)
        # clear data
        for message in messages:
            if isinstance(message['content'], list):
                message['preprocess'] = None
                message['forward'] = None
        return result


================================================
FILE: lmdeploy/vl/media/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/vl/media/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/base.py

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Generic, TypeVar

_T = TypeVar('_T')


class MediaIO(ABC, Generic[_T]):
    """Abstract loader that decodes one kind of media into objects of ``_T``.

    Concrete subclasses implement loading from raw bytes, from a base64
    payload, and from a local file.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode media from raw ``data``."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """Decode media from a base64 string ``data`` of type ``media_type``."""
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load media from the local file at ``filepath``."""
        raise NotImplementedError


================================================
FILE: lmdeploy/vl/media/connection.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from pathlib import Path
from typing import TypeVar
from urllib.parse import ParseResult, urlparse
from urllib.request import url2pathname

import requests

from .base import MediaIO
from .image import ImageMediaIO
from .video import VideoMediaIO

# Type of the media object produced by a MediaIO loader.
_M = TypeVar('_M')

# Browser-like User-Agent sent with HTTP media downloads.
# NOTE(review): presumably to avoid servers rejecting the default client UA;
# confirm.
headers = {
    'User-Agent':
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
    '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
}


def _load_http_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:
    """Download media over HTTP(S) and decode it with ``media_io``.

    The request timeout (seconds) defaults to 10. For images it is read from
    ``LMDEPLOY_IMAGE_FETCH_TIMEOUT`` (default 10); for videos from
    ``LMDEPLOY_VIDEO_FETCH_TIMEOUT`` (default 30).

    Raises:
        ValueError: if the scheme is not ``http`` or ``https``.
        requests.HTTPError: if the server answers with an error status.
    """
    if url_spec.scheme not in ('http', 'https'):
        raise ValueError(f'Unsupported URL scheme: {url_spec.scheme}')

    fetch_timeout = 10
    if isinstance(media_io, ImageMediaIO):
        fetch_timeout = int(os.environ.get('LMDEPLOY_IMAGE_FETCH_TIMEOUT', 10))
    elif isinstance(media_io, VideoMediaIO):
        fetch_timeout = int(os.environ.get('LMDEPLOY_VIDEO_FETCH_TIMEOUT', 30))

    # close the session (and its pooled connections) once the body is read;
    # a per-call Session that is never closed leaks sockets
    with requests.Session() as client:
        response = client.get(url_spec.geturl(), headers=headers, timeout=fetch_timeout)
        response.raise_for_status()
        content = response.content

    return media_io.load_bytes(content)


def _load_data_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:
    """Decode a ``data:`` URL (RFC 2397) with ``media_io``.

    Only base64-encoded payloads are supported. Extra media-type parameters
    are tolerated (e.g. ``image/png;charset=utf-8;base64``); the encoding is
    taken from the last ``;``-separated parameter, and only the bare media
    type is passed on.

    Raises:
        ValueError: if the URL has no ``,`` between header and payload.
        NotImplementedError: if the payload is not base64-encoded.
    """
    url_spec_path = url_spec.path or ''
    data_spec, sep, data = url_spec_path.partition(',')
    if not sep:
        raise ValueError('Invalid data URL: missing "," between the media type and the payload.')

    media_type, *params = data_spec.split(';')
    data_type = params[-1] if params else ''
    # media_type may start with a leading "/" (e.g., "/video/jpeg")
    media_type = media_type.lstrip('/')

    if data_type != 'base64':
        msg = 'Only base64 data URLs are supported for now.'
        raise NotImplementedError(msg)

    return media_io.load_base64(media_type, data)


def _load_file_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:
    """Load media from a ``file:`` URL (or a parsed local path).

    ``netloc`` and ``path`` are joined and converted to a local path with
    ``url2pathname`` before being handed to ``media_io.load_file``.
    """
    location = (url_spec.netloc or '') + (url_spec.path or '')
    return media_io.load_file(Path(url2pathname(location)))


def load_from_url(url: str, media_io: MediaIO[_M]) -> _M:
    """Load media from a HTTP, data or file url, or a raw local file path.

    Raises:
        ValueError: if ``url`` is none of the supported forms.
    """
    url_spec = urlparse(url)

    if url_spec.scheme and url_spec.scheme.startswith('http'):
        return _load_http_url(url_spec, media_io)

    if url_spec.scheme == 'data':
        return _load_data_url(url_spec, media_io)

    if url_spec.scheme == 'file':
        return _load_file_url(url_spec, media_io)

    # raw file path that exists as given: use it verbatim. Routing it through
    # URL parsing would drop a Windows drive letter (parsed as the scheme),
    # cut the name at '?'/'#', and percent-decode literal '%xx' sequences.
    if os.path.exists(url):
        return media_io.load_file(Path(url))

    # e.g. a path followed by a query/fragment that only exists without it
    if os.path.exists(url_spec.path):
        return _load_file_url(url_spec, media_io)

    msg = 'The URL must be either a HTTP, data or file URL.'
    raise ValueError(msg)


================================================
FILE: lmdeploy/vl/media/image.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/image.py

from io import BytesIO
from pathlib import Path

import pybase64
from PIL import Image, ImageFile

from .base import MediaIO

# Process-wide PIL setting (set at import time): allow loading images whose
# data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class ImageMediaIO(MediaIO[Image.Image]):
    """Load and base64-encode PIL images, normalizing them to ``image_mode``."""

    def __init__(self, image_mode: str = 'RGB', **kwargs) -> None:
        super().__init__()
        self.image_mode = image_mode

        # for potential custom arguments from --media-io-kwargs
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode encoded image bytes and convert to ``self.image_mode``."""
        return Image.open(BytesIO(data)).convert(self.image_mode)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode a base64 image payload (``media_type`` is not used)."""
        raw = pybase64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, file_path: Path) -> Image.Image:
        """Read the image at ``file_path`` and convert to ``self.image_mode``."""
        with open(file_path, 'rb') as f:
            raw = f.read()
        return Image.open(BytesIO(raw)).convert(self.image_mode)

    def encode_base64(self, image: Image.Image, image_format: str = 'PNG') -> str:
        """Convert ``image`` to ``self.image_mode`` and return it base64-encoded."""
        with BytesIO() as buffer:
            image.convert(self.image_mode).save(buffer, image_format)
            payload = buffer.getvalue()

        return pybase64.b64encode(payload).decode('utf-8')


================================================
FILE: lmdeploy/vl/media/time_series.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO
from pathlib import Path

import numpy as np
import numpy.typing as npt
import pybase64

from lmdeploy.utils import get_logger

from .base import MediaIO

# Shared 'lmdeploy' logger; used below to report CSV loading failures.
logger = get_logger('lmdeploy')


class TimeSeriesMediaIO(MediaIO[npt.NDArray]):
    """Media IO handler for time-series data held as numpy arrays.

    In-memory payloads (raw bytes / base64) must be in NPY format. Files may
    be ``.npy``, ``.csv`` (comma separated, parsed as float32) or audio
    (``.wav`` / ``.mp3`` / ``.flac``, decoded with the optional ``soundfile``
    package).
    """

    def __init__(self, **kwargs):
        super().__init__()

        # for potential custom arguments from --media-io-kwargs (stored, not read here)
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Deserialize an NPY payload; pickled object arrays are rejected."""
        return np.load(BytesIO(data), allow_pickle=False)

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode a base64 NPY payload; ``media_type`` is ignored."""
        return self.load_bytes(pybase64.b64decode(data))

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Load a time series from disk, dispatching on the file suffix.

        Args:
            filepath: path to a ``.npy``, ``.csv``, ``.wav``, ``.mp3`` or
                ``.flac`` file (``str`` is accepted as well).

        Returns:
            The loaded array. For audio files only the samples are returned;
            the sample rate reported by ``soundfile`` is discarded.

        Raises:
            ValueError: the suffix is unsupported, or the CSV yields no data.
            ImportError: an audio file is given but ``soundfile`` is missing.
        """
        # accept plain strings too: the suffix dispatch below needs a Path
        filepath = Path(filepath)
        suffix = filepath.suffix.lower()

        if suffix == '.npy':
            return np.load(filepath, allow_pickle=False)

        if suffix == '.csv':
            try:
                ts_array = np.genfromtxt(filepath, delimiter=',', dtype=np.float32)
                if ts_array.size == 0:
                    raise ValueError(f'CSV file {filepath} yielded no data.')
                return ts_array
            except Exception as e:
                logger.error(f'Failed to load CSV {filepath}: {e}')
                raise

        if suffix in ('.wav', '.mp3', '.flac'):
            try:
                import soundfile as sf
            except ImportError as e:
                # chain so the original import failure stays visible in the traceback
                raise ImportError('Please install soundfile via `pip install soundfile`.') from e

            ts_array, _ = sf.read(filepath)
            return ts_array

        raise ValueError(f'Unsupported file format: {suffix}')

    def encode_base64(self, data: npt.NDArray) -> str:
        """Encode numpy array to base64 string using NPY format."""
        with BytesIO() as buffer:
            np.save(buffer, data, allow_pickle=False)
            payload = buffer.getvalue()
        return pybase64.b64encode(payload).decode('utf-8')


================================================
FILE: lmdeploy/vl/media/video.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/video.py

import base64
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
from PIL import Image

from lmdeploy.utils import get_logger

from .base import MediaIO
from .image import ImageMediaIO
from .video_loader import (DecordVideoLoader, OpenCVVideoLoader, TorchCodecVideoLoader, TorchVisionVideoLoader,
                           VideoLoader)

# Shared 'lmdeploy' logger; used for backend-selection warnings and metadata fallbacks below.
logger = get_logger('lmdeploy')


class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
    """Media IO handler that decodes videos into ``(frames, metadata)`` pairs.

    Container decoding is delegated to the first usable backend among OpenCV,
    decord, torchcodec and torchvision. Videos sent as comma-separated base64
    JPEG frames (media type ``video/jpeg``) are decoded frame by frame through
    ``image_io`` and get synthesized metadata.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        """
        Args:
            image_io: image handler used to decode/encode individual frames.
            num_frames: frame budget forwarded to the video loader.
            **kwargs: extra options from --media-io-kwargs; forwarded to every
                loader call, and ``fps`` is also read for JPEG frame sequences.
        """
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames

        # for potential custom arguments from --media-io-kwargs
        self.kwargs = kwargs
        self.video_loader = self._get_video_loader_backend()

    def _get_video_loader_backend(self) -> VideoLoader:
        """Return an instance of the first usable video loader backend.

        Backends are probed in order (OpenCV, decord, torchcodec, torchvision).
        A backend is skipped when importing its module or constructing its
        loader raises ImportError or RuntimeError.

        Raises:
            ImportError: if no backend is usable.
        """
        # vLLM:          OpenCV
        # SGLang:        Decord
        # qwen-vl-utils: TorchCodec -> Decord -> TorchVision (deprecated soon)
        backends = [
            ('cv2', OpenCVVideoLoader),
            ('decord', DecordVideoLoader),
            ('torchcodec', TorchCodecVideoLoader),
            ('torchvision', TorchVisionVideoLoader),
        ]

        for module_name, loader_cls in backends:
            try:
                __import__(module_name)
                return loader_cls()
            except (ImportError, RuntimeError):
                logger.warning(f"Video backend '{module_name}' not found. Trying next backend...")
                continue

        raise ImportError(
            'No video backend found. Install either opencv-python-headless, decord, torchcodec, or torchvision.')

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode an encoded video container held in memory."""
        return self.video_loader.load_bytes(data, num_frames=self.num_frames, **self.kwargs)

    def load_base64(self, media_type: str, data: str) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode a base64 video payload.

        For ``video/jpeg`` the payload is a comma-separated list of base64
        JPEG frames, which are stacked into one array. Because such input has
        no container metadata, default metadata is synthesized from the frame
        count and the ``fps`` kwarg (2 when not given). Any other media type is
        treated as an encoded container and passed to ``load_bytes``.
        """
        if media_type.lower() == 'video/jpeg':
            load_frame = partial(
                self.image_io.load_base64,
                'image/jpeg',
            )

            # NOTE: known issue in https://github.com/QwenLM/Qwen3-VL/issues/1643
            # when passing a video as a sequence of JPEG frames, we cannot obtain the video metadata
            # therefore we construct a default metadata dictionary with common values.
            frames = np.stack([np.asarray(load_frame(frame_data)) for frame_data in data.split(',')])

            total_frames_num = int(frames.shape[0])
            fps = float(self.kwargs.get('fps', 2))  # default to 2 fps if not specified
            duration = (total_frames_num / fps) if fps > 0 else 0
            frame_idx = list(range(total_frames_num))

            metadata = {
                'total_num_frames': total_frames_num,
                'fps': fps,
                'duration': duration,
                'video_backend': 'jpeg_sequence',
                'frames_indices': frame_idx,
            }

            # trailing space keeps the two sentences of the message separated
            logger.info('Loading video from base64-encoded JPEG frames misses video metadata. '
                        f'Fall back to default metadata values:\n{metadata}')
            return frames, metadata

        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode a video file through the selected backend."""
        return self.video_loader.load_file(filepath, num_frames=self.num_frames, **self.kwargs)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = 'JPEG',
    ) -> str:
        """Encode frames as comma-separated base64 JPEG images.

        Each element of ``media`` is passed to ``Image.fromarray``, so it is
        presumably an HWC uint8 frame (TODO confirm for tensor-returning
        backends).

        Raises:
            NotImplementedError: for any ``video_format`` other than 'JPEG'.
        """
        video = media

        if video_format == 'JPEG':
            encode_frame = partial(
                self.image_io.encode_base64,
                image_format=video_format,
            )

            return ','.join(encode_frame(Image.fromarray(frame)) for frame in video)

        msg = 'Only JPEG format is supported for now.'
        raise NotImplementedError(msg)


================================================
FILE: lmdeploy/vl/media/video_loader.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/video.py
# adapted from https://github.com/QwenLM/Qwen3-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py

import math
import os
import tempfile
from abc import abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

from lmdeploy.utils import get_logger

# Shared 'lmdeploy' logger; used to report unreadable frames during OpenCV decoding.
logger = get_logger('lmdeploy')


class VideoLoader:
    """Common interface and frame-sampling helper for video decoding backends.

    Backends override ``load_file`` and ``load_bytes`` as classmethods that
    return ``(frames, metadata)``. NOTE: this class does not derive from
    ``abc.ABC``, so the ``@abstractmethod`` markers below are advisory only
    and do not prevent instantiation.
    """

    @classmethod
    @abstractmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1, **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode a video held in memory; must be overridden by backends."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load_file(cls, filepath: Path, num_frames: int = -1, **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode a video stored at ``filepath``; must be overridden by backends."""
        raise NotImplementedError

    @classmethod
    def smart_nframes(cls, total_frames_num: int, num_frames: int, fps: int, duration: int) -> tuple[int, list[int]]:
        """Choose how many frames to sample and at which indices.

        The count starts at ``total_frames_num`` and is capped by
        ``num_frames`` (when > 0) and by ``floor(duration * fps)`` (when
        ``fps`` > 0), i.e. the stricter of the two limits wins; it is then
        clamped to at least 1.

        Args:
            total_frames_num: number of frames reported by the decoder.
            num_frames: requested frame budget; values <= 0 disable this cap.
            fps: target sampling rate; values <= 0 disable this cap.
            duration: video length, combined with ``fps``.

        Returns:
            ``(count, indices)``: every index ``0 .. count-1`` when ``count``
            equals ``total_frames_num``; otherwise ``count`` indices spread
            evenly over ``[0, total_frames_num - 1]`` (truncated to int).
        """
        num_frames_to_sample = total_frames_num
        if num_frames > 0:
            num_frames_to_sample = min(num_frames, total_frames_num)
        if fps > 0:
            num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
        num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample

        if num_frames_to_sample == total_frames_num:
            frame_idx = list(range(0, num_frames_to_sample))
        else:
            uniform_sampled_frames = np.linspace(0, total_frames_num - 1, num_frames_to_sample, dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
        return num_frames_to_sample, frame_idx


class OpenCVVideoLoader(VideoLoader):
    """Video loader backed by OpenCV, decoding directly from in-memory bytes."""

    def get_cv2_video_api(self):
        """Pick an OpenCV stream-buffered backend able to decode from a buffer.

        Returns the id of the first available backend that is either built in
        or a plugin reporting ABI/API version >= 1.2, or ``None`` if none
        qualifies.
        """
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref

    @staticmethod
    def _read_frames(
        cap,
        frame_indices: set[int],
        num_expected_frames: int,
        max_frame_idx: int,
    ) -> tuple[npt.NDArray, int, list[int]]:
        """Grab frames ``0 .. max_frame_idx`` and decode those in ``frame_indices``.

        Decoded frames are converted BGR -> RGB and stored in a preallocated
        ``(num_expected_frames, H, W, 3)`` uint8 buffer whose H/W come from the
        capture properties. Frames that cannot be grabbed or retrieved are
        skipped with a warning.

        Returns:
            ``(frames, n, indices)`` where ``n`` is the number of frames that
            were decoded, ``frames`` holds exactly those ``n`` frames and
            ``indices`` are their source frame indices.
        """
        import cv2

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)  # THWC

        i = 0
        valid_frame_indices = []
        for idx in range(max_frame_idx + 1):
            # grab() advances the stream cheaply; only wanted frames are decoded via retrieve()
            ok = cap.grab()
            if not ok:
                # Frame is broken/unreadable, log warning
                if idx in frame_indices:
                    logger.warning(
                        'Failed to grab frame %d during video loading. '
                        'This frame will be skipped.',
                        idx,
                    )
                continue
            if idx in frame_indices:
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    valid_frame_indices.append(idx)
                    i += 1
                else:
                    # retrieve() failed even though grab() succeeded
                    logger.warning(
                        'Failed to retrieve frame %d during video loading. '
                        'This frame will be skipped.',
                        idx,
                    )

        valid_num_frames = len(valid_frame_indices)
        if valid_num_frames < num_expected_frames:
            logger.warning(
                'Video loading completed with %d broken/unreadable frames. '
                'Expected %d frames but only loaded %d frames.',
                num_expected_frames - valid_num_frames,
                num_expected_frames,
                valid_num_frames,
            )

        return frames[:valid_num_frames], valid_num_frames, valid_frame_indices

    @classmethod
    def load_file(
        cls,
        filepath: Path,
        num_frames: int = -1,
        fps: int = -1,
        max_duration: int = 300,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Read the whole file into memory and decode it with ``load_bytes``."""
        with open(filepath, 'rb') as f:
            data = f.read()
        return cls.load_bytes(data, num_frames=num_frames, fps=fps, max_duration=max_duration, **kwargs)

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = -1,
        max_duration: int = 300,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Load video frames from bytes.

        Args:
            data: Raw video bytes
            num_frames: Target number of frames to sample (-1 for all)
            fps: Target FPS for sampling (-1 for original)
            max_duration: Maximum duration (unused in base backend)

        Returns:
            Tuple of (frames_array, metadata_dict)

        Raises:
            ValueError: if OpenCV cannot open the stream.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        try:
            if not cap.isOpened():
                raise ValueError('Could not open video stream')

            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames_num / original_fps if original_fps > 0 else 0

            num_frames_to_sample, frame_idx = cls.smart_nframes(total_frames_num, num_frames, fps, duration)

            frames, valid_num_frames, valid_frame_indices = cls._read_frames(cap, set(frame_idx),
                                                                             num_frames_to_sample, max(frame_idx))
        finally:
            # free the native decoder even when opening or decoding fails
            cap.release()

        # Use transformers.video_utils.VideoMetadata format
        # For models like Qwen3-VL/GLM4.5V, this metadata
        # can cause incorrect timestamp calculation without num_frames=-1.
        # TODO: zhouxinyu, support per-request do_sample_frames
        metadata = {
            'total_num_frames': total_frames_num,
            'fps': original_fps,
            'duration': duration,
            'video_backend': 'opencv',
            'frames_indices': valid_frame_indices,
            # extra field used to control hf processor's video
            # sampling behavior
            # "do_sample_frames": valid_num_frames == total_frames_num,
        }
        return frames, metadata


class DecordVideoLoader(VideoLoader):
    """Video loader backed by decord."""

    @classmethod
    def load_file(cls,
                  filepath: Path,
                  num_frames: int = -1,
                  fps: int = -1,
                  max_duration: int = 300,
                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode the sampled frames of a video file.

        Returns:
            ``(video, metadata)``; ``video`` is a THWC numpy array obtained
            from ``VideoReader.get_batch(...).asnumpy()``. ``max_duration`` is
            accepted for interface compatibility but not used.
        """
        import decord
        vr = decord.VideoReader(str(filepath))
        total_frames_num = len(vr)
        original_fps = vr.get_avg_fps()
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        _, frame_idx = cls.smart_nframes(total_frames_num, num_frames, fps, duration)

        video = vr.get_batch(frame_idx).asnumpy()  # THWC
        metadata = {
            'total_num_frames': total_frames_num,
            'fps': original_fps,
            'duration': duration,
            'video_backend': 'decord',
            'frames_indices': frame_idx,
        }
        return video, metadata

    @classmethod
    def load_bytes(cls,
                   data: bytes,
                   num_frames: int = -1,
                   fps: int = -1,
                   max_duration: int = 300,
                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode an in-memory video by spilling it to a temporary ``.mp4`` file."""
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        try:
            # ``with`` closes the handle even if ``write`` fails; an open handle
            # can block the unlink below on some platforms (e.g. Windows)
            with tmp_file:
                tmp_file.write(data)
            return cls.load_file(Path(tmp_file.name),
                                 num_frames=num_frames,
                                 fps=fps,
                                 max_duration=max_duration,
                                 **kwargs)
        finally:
            # always cleanup, even if load_file crashes; best effort only
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass


class TorchCodecVideoLoader(VideoLoader):
    """Video loader backed by torchcodec."""

    @classmethod
    def load_file(cls,
                  filepath: Path,
                  num_frames: int = -1,
                  fps: int = -1,
                  max_duration: int = 300,
                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode the sampled frames of a video file with torchcodec.

        ``max_duration`` is accepted for interface compatibility but not used.
        """
        # torchcodec requires matched ffmpeg, torchcodec, and torch versions
        # ffmpeg 5.1.2, torch 2.8.0, torchcodec 0.7.0 are verified to work together
        from torchcodec.decoders import VideoDecoder

        torch_codec_num_threads = 8
        decoder = VideoDecoder(str(filepath), num_ffmpeg_threads=torch_codec_num_threads)
        total_frames_num = decoder.metadata.num_frames
        original_fps = decoder.metadata.average_fps
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        _, frame_idx = cls.smart_nframes(total_frames_num, num_frames, fps, duration)

        # NOTE(review): ``.data`` is a torch tensor, and torchcodec's default
        # dimension order is NCHW, unlike the THWC numpy arrays returned by the
        # OpenCV/decord loaders -- confirm downstream consumers handle this.
        video = decoder.get_frames_at(frame_idx).data
        metadata = {
            'total_num_frames': total_frames_num,
            'fps': original_fps,
            'duration': duration,
            'video_backend': 'torchcodec',
            'frames_indices': frame_idx,
        }
        return video, metadata

    @classmethod
    def load_bytes(cls,
                   data: bytes,
                   num_frames: int = -1,
                   fps: int = -1,
                   max_duration: int = 300,
                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode an in-memory video by spilling it to a temporary ``.mp4`` file."""
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        try:
            # ``with`` closes the handle even if ``write`` fails; an open handle
            # can block the unlink below on some platforms (e.g. Windows)
            with tmp_file:
                tmp_file.write(data)
            return cls.load_file(Path(tmp_file.name),
                                 num_frames=num_frames,
                                 fps=fps,
                                 max_duration=max_duration,
                                 **kwargs)
        finally:
            # always cleanup, even if load_file crashes; best effort only
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass


class TorchVisionVideoLoader(VideoLoader):
    """Video loader backed by ``torchvision.io.read_video``."""

    @classmethod
    def load_file(cls,
                  filepath: Path,
                  num_frames: int = -1,
                  fps: int = -1,
                  max_duration: int = 300,
                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode a whole video file with torchvision, then keep the sampled frames.

        All frames are decoded first (THWC output format) and then indexed.
        ``max_duration`` is accepted for interface compatibility but not used.
        """
        import torchvision

        # read_video documents a str filename; convert like the other loaders do
        video, _audio, info = torchvision.io.read_video(
            str(filepath),
            pts_unit='sec',
            output_format='THWC',
        )
        total_frames_num = video.size(0)
        original_fps = info['video_fps']
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        _, frame_idx = cls.smart_nframes(total_frames_num, num_frames, fps, duration)

        video = video[frame_idx]
        metadata = {
            'total_num_frames': total_frames_num,
            'fps': original_fps,
            'duration': duration,
            'video_backend': 'torchvision',
            'frames_indices': frame_idx,
        }
        return video, metadata

    @classmethod
    def load_bytes(cls,
                   data: bytes,
                   num_frames: int = -1,
                   fps: int = -1,
                   max_duration: int = 300,
                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode an in-memory video by spilling it to a temporary ``.mp4`` file."""
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        try:
            # ``with`` closes the handle even if ``write`` fails; an open handle
            # can block the unlink below on some platforms (e.g. Windows)
            with tmp_file:
                tmp_file.write(data)
            return cls.load_file(Path(tmp_file.name),
                                 num_frames=num_frames,
                                 fps=fps,
                                 max_duration=max_duration,
                                 **kwargs)
        finally:
            # always cleanup, even if load_file crashes; best effort only
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass


================================================
FILE: lmdeploy/vl/model/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/vl/model/base.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from itertools import groupby
from typing import Dict, List, Union

import numpy as np
from mmengine import Registry
from transformers import AutoConfig, AutoTokenizer

from lmdeploy.archs import get_model_arch

# mmengine registry of vision model classes. builder.py imports it together with
# every model module (unused "noqa F401" imports), presumably so the modules can
# register their VisionModel subclasses here at import time -- confirm in the model files.
VISION_MODELS = Registry('vision_model')


class VisionModel(ABC):
    """Base class of the vision part of a VLM, which extracts image features.

    Subclasses set ``_arch`` to the HF architecture name(s) they support (see
    ``match``) and implement ``build_preprocessor`` / ``preprocess`` plus the
    hooks of the backend they serve: ``to_pytorch`` for the pytorch engine,
    ``build_model`` / ``forward`` / ``to_turbomind`` for turbomind.
    """
    _arch: Union[str, List[str]] = None

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        """init.

        Args:
            model_path (str): path or hub repo id of the model.
            with_llm (bool): whether to load the whole VLM instead of only its
                vision part (see ``build_model``).
            max_memory (Dict[int, int]): stored as-is on the instance.
            hf_config (AutoConfig): model config; obtained through
                ``get_model_arch(model_path)`` when None.
            backend (str): inference backend name, e.g. 'turbomind' or
                'pytorch'.
        """
        self.model_path = model_path
        self.with_llm = with_llm
        self.max_memory = max_memory
        self.backend = backend
        if hf_config is None:
            _, hf_config = get_model_arch(model_path)
        self.hf_config = hf_config
        # the pad token id doubles as the image placeholder token id (0 if unknown)
        self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0

    def get_pad_token_id(self, model_path, hf_config):
        """Get pad_token_id from hf_config or tokenizer.

        Falls back to the tokenizer's ``pad_token_id`` when ``hf_config`` has
        none. Returns None if neither provides one; a tokenizer that fails to
        load is logged and ignored.
        """
        pad_token_id = getattr(hf_config, 'pad_token_id', None)
        if pad_token_id is None:
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
                pad_token_id = getattr(tokenizer, 'pad_token_id', None)
            except Exception as e:
                # best effort: a missing/broken tokenizer must not prevent
                # building the vision model, but report it via the logger
                from lmdeploy.utils import get_logger
                get_logger('lmdeploy').warning(f'Failed to get pad_token_id from the tokenizer of {model_path}: {e}')
        return pad_token_id

    @abstractmethod
    def build_preprocessor(self, ):
        """Build the preprocessor.

        NOTE: When the derived class implements this method, try not to
        introduce the upper stream model repo as a thirdparty package
        """
        raise NotImplementedError()

    def build_model(self, ):
        """Build the vision part of a VLM model when backend is turbomind.

        But when `with_llm=True`, load the whole VLM model
        """
        if self.backend == 'turbomind' or self.with_llm:
            raise NotImplementedError()

    @abstractmethod
    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Preprocess multimodal data in the messages.

        The derived class,
        i.e., a specific vision model, takes the charge of image preprocessing
        and the result management.
        It can integrate the result into the messages list, or insert it to
        the individual image item.
        Args:
            messages(List[Dict]): multimodal data in a list of dicts, which is as follows:
            [
                {'role': 'user', 'content': 'user prompt'},
                {'role': 'assistant', 'content': 'AI response'},
                {
                    'role': 'user',
                    'content': [
                        {
                            'type': 'text',
                            'text': 'string',
                        },
                        {
                            'type': 'image',
                            'image': pillow.Image,
                            'key1': value1,
                            ...
                        },
                        {
                            'type': 'image',
                            'image': pillow.Image,
                            'key1': value1,
                            ...
                        },
                        ...
                    ]
                }
                {....}
            ]
        Returns:
            the message list with preprocessing results included, which is
            determined by the derived classes
        """  # noqa
        raise NotImplementedError()

    def has_input_ids(self, messages: List[Dict]) -> bool:
        """Check whether the messages contain input_ids directly.

        That is the case when there is exactly one user message whose content
        is a non-empty list and whose first item's ``text`` is itself a list
        (of token ids).

        Args:
            messages (List[Dict]): a list of message, which is supposed to be
                the output of `preprocess`
        Returns:
            bool: whether the messages contain input_ids directly
        """
        users = [x['content'] for x in messages if x['role'] == 'user']
        return (len(users) == 1 and isinstance(users[0], list) and len(users[0]) > 0
                and isinstance(users[0][0].get('text', ''), list))

    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included, which is
            determined by the derived classes
        """
        if self.backend == 'turbomind':
            raise NotImplementedError()

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
        """Pack the preprocessing results in a format compatible with what is
        required by pytorch engine. ONLY implement it when the backend is
        pytorch engine.

        Args:
            messages(List[Dict]): the output of `preprocess`
            chat_template: the chat template defined in `lmdeploy/model.py`
            tokenizer: the tokenizer model
            sequence_start: starting flag of a sequence
            chat_template_kwargs: additional arguments for chat template
                processing, such as `add_vision_id` and `enable_thinking`
        """
        if self.backend == 'pytorch':
            raise NotImplementedError()

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
        """Pack the forwarding results in a format compatible with what is
        required by turbomind engine. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the output of `preprocess`
            chat_template: the chat template defined in `lmdeploy/model.py`
            tokenizer: the tokenizer model
            sequence_start: starting flag of a sequence
            chat_template_kwargs: additional arguments for chat template
                processing, such as `add_vision_id` and `enable_thinking`
        """
        if self.backend == 'turbomind':
            raise NotImplementedError()

    @staticmethod
    def collect_multimodal_items(messages):
        """Gather all multimodal items along with their respective parameters
        from the messages and compile them into a single list.

        Items whose ``type`` is missing or 'text' are skipped, as are non-dict
        items and messages whose content is not a list.

        Args:
            messages (List[Dict]): a list of message
        Returns:
            List[Tuple[Modality, Any, Dict]]: a list of (modality, data, params) for each multimodal item
        """
        multimodal_items = []
        for message in messages:
            content = message['content']
            if not isinstance(content, list):
                continue

            for x in content:
                if not isinstance(x, dict):
                    continue

                modality = x.get('type')
                if modality is None or modality == 'text':
                    continue

                data = x.get('data')
                params = {k: v for k, v in x.items() if k not in ['type', 'data']}
                multimodal_items.append((modality, data, params))

        return multimodal_items

    @staticmethod
    def IMAGE_TOKEN_included(messages):
        """Check whether the IMAGE_TOKEN is included in the messages.

        Only user messages are inspected, either as a plain string or through
        the ``text`` of their 'text' items.

        Args:
            messages (List[Dict]): a list of message
        Returns:
            bool: whether the IMAGE_TOKEN is included in the messages
        """
        # NOTE: testing for the empty string ('' in s) is always True, so the
        # actual placeholder literal must be used here.
        image_token = '<IMAGE_TOKEN>'
        for message in messages:
            role, content = message['role'], message['content']
            if role != 'user':
                continue
            if isinstance(content, str) and image_token in content:
                return True
            elif isinstance(content, list):
                texts = [x['text'] for x in content if x.get('type') == 'text']
                if any(image_token in x for x in texts):
                    return True
        return False

    def to_pytorch_with_input_ids(self, messages):
        """Pack the preprocessing results in a format compatible with what is
        required by pytorch engine when input_ids are provided directly.

        Each occurrence of ``self.image_token_id`` in the provided ids marks
        one image; it is expanded to ``image_tokens`` placeholder ids and the
        matching preprocessing result gets its ``offset`` set.

        Args:
            messages(List[Dict]): the output of `preprocess`
        """
        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        _input_ids = messages[0]['content'][0]['text']
        # Split the ids at image placeholders: a run of k consecutive
        # placeholders separates k + 1 segments, so k - 1 empty segments are
        # added between the surrounding text segments.
        segs = []
        for is_image, group in groupby(_input_ids, lambda x: x == self.image_token_id):
            if not is_image:
                segs.append(list(group))
            else:
                segs.extend([[]] * (len(list(group)) - 1))
        # leading / trailing placeholders also delimit an empty outer segment
        if _input_ids[0] == self.image_token_id:
            segs = [[]] + segs
        if _input_ids[-1] == self.image_token_id:
            segs = segs + [[]]

        if preps:
            assert self.image_token_id == preps[0]['image_token_id']
        assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '
                                             f'to input images, {len(segs) - 1} vs {len(preps)}')
        input_ids = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(preps):
                preps[i - 1].update(offset=len(input_ids))
                image_tokens = preps[i - 1]['image_tokens']
                input_ids.extend([self.image_token_id] * image_tokens)
            input_ids.extend(seg)

        return dict(prompt=None, input_ids=input_ids, multimodal=preps)

    def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
        """Auxiliary function to pack the preprocessing results in a format
        compatible with what is required by pytorch engine.

        Args:
            messages(List[Dict]): the output of `preprocess`
            prompt(str): the prompt after applying chat template
            IMAGE_TOKEN(str): a placeholder where image tokens will be
                inserted
            tokenizer: the tokenizer model
            sequence_start: starting flag of a sequence
        """
        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        # split prompt into segments and validate data
        segs = prompt.split(IMAGE_TOKEN)
        assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal '
                                             f'to input images, {len(segs) - 1} vs {len(preps)}')

        # calculate the image token offset for each image
        input_ids = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(preps):
                preps[i - 1].update(offset=len(input_ids))
                image_tokens = preps[i - 1]['image_tokens']
                assert self.image_token_id == preps[i - 1]['image_token_id']
                input_ids.extend([self.image_token_id] * image_tokens)
            # only the very first segment of a new sequence gets a BOS token
            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(token_ids)

        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

    def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
        """Auxiliary function to pack the forwarding results in a format
        compatible with what is required by turbomind engine.

        Args:
            messages(List[Dict]): the output of `preprocess`
            prompt(str): the prompt after applying chat template
            IMAGE_TOKEN(str): a placeholder where image tokens will be
                inserted
            tokenizer: the tokenizer model
            sequence_start: starting flag of a sequence
        Returns:
            dict with the prompt, input_ids, the per-image features (moved to
            CPU) and their ``[begin, end)`` ranges inside input_ids.
        """
        # collect image features from messages
        features = [x['content'] for x in messages if x['role'] == 'forward']
        features = features[0]
        features = [x.cpu() for x in features]
        # split prompt into segments and validate data
        segs = prompt.split(IMAGE_TOKEN)
        assert len(segs) == len(features) + 1, (f'the number of {IMAGE_TOKEN} is not equal '
                                                f'to input images, {len(segs) - 1} vs {len(features)}')

        # tokenizer prompt, and get input_embeddings and input_embedding_ranges
        input_ids = []
        begins = []
        ends = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(features):
                # one placeholder id per row of the image feature
                image_dim = features[i - 1].shape[0]
                begins.append(len(input_ids))
                ends.append(begins[-1] + image_dim)
                input_ids.extend([self.image_token_id] * image_dim)
            seg_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(seg_ids)
        ranges = np.stack([begins, ends], axis=1).tolist()
        return dict(prompt=prompt, input_ids=input_ids, input_embeddings=features, input_embedding_ranges=ranges)

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model.

        The first entry of ``config.architectures`` must equal ``_arch`` when
        it is a string, or be one of its elements when it is a list/tuple.
        Returns False when either side is missing.
        """
        archs = getattr(config, 'architectures', None)
        arch = archs[0] if archs else None
        if arch is None or cls._arch is None:
            return False
        if isinstance(cls._arch, str):
            # exact comparison: ``arch in str`` would be a substring test
            return arch == cls._arch
        return arch in cls._arch


================================================
FILE: lmdeploy/vl/model/builder.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional, Union

import torch

from lmdeploy.archs import get_model_arch
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.utils import get_logger, get_model
from lmdeploy.vl.model.base import VISION_MODELS

from .cogvlm import CogVLMVisionModel  # noqa F401
from .deepseek import DeepSeekVisionModel  # noqa F401
from .deepseek_vl2 import DeepSeek2VisionModel  # noqa F401
from .gemma3_vl import Gemma3VisionModel  # noqa F401
from .glm4_1v import GLM4_1_VisionModel  # noqa F401
from .glm4_v import GLM4VisionModel  # noqa F401
from .interns1_pro import InternS1ProVisionModel  # noqa F401
from .internvl import InternVLVisionModel  # noqa F401
from .internvl3_hf import InternVL3VisionModel  # noqa F401
from .internvl_llava import InternVLLlavaVisionModel  # noqa F401
from .llama4 import LLama4VisionModel  # noqa F401
from .llava import LlavaVisionModel  # noqa F401
from .llava_hf import LlavaHfVisionModel  # noqa F401
from .llava_next import LlavaNextVisionModel  # noqa F401
from .minicpmv import MiniCPMVModel  # noqa F401
from .mllama import MllamaVLModel  # noqa F401
from .molmo import MolmoVisionModel  # noqa F401
from .phi3_vision import Phi3VisionModel  # noqa F401
from .qwen import QwenVisionModel  # noqa F401
from .qwen2 import Qwen2VLModel  # noqa F401
from .qwen3 import Qwen3VLModel  # noqa F401
from .qwen3_5 import Qwen3_5Model  # noqa F401
from .xcomposer2 import Xcomposer2VisionModel  # noqa F401
from .yi import YiVisionModel  # noqa F401

logger = get_logger('lmdeploy')


def load_vl_model(model_path: str,
                  backend: str,
                  with_llm: bool = False,
                  backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None):
    """Build the registered vision model that matches ``model_path``.

    Args:
        model_path(str): a local directory, or a repo id on the model hub
        backend(str): the name of the inference backend
        with_llm(bool): whether to load the LLM part as well. False for VLM
            inference, True for VLM quantization
        backend_config: the inference engine config; its ``revision``,
            ``download_dir`` and ``tp`` attributes are read if present

    Returns:
        The first entry of ``VISION_MODELS`` whose ``match`` accepts the HF
        config, with its preprocessor built. Its weights are also built when
        the backend is turbomind or ``with_llm`` is True.

    Raises:
        ValueError: if no registered vision model matches the config.
    """
    if not os.path.exists(model_path):
        # Not a local path: resolve / download it from the model hub.
        model_path = get_model(model_path,
                               revision=getattr(backend_config, 'revision', None),
                               download_dir=getattr(backend_config, 'download_dir', None))

    max_memory = None
    if not with_llm:
        num_devices = getattr(backend_config, 'tp', 1)
        if backend == 'turbomind':
            # free memory of each device used by tensor parallelism
            max_memory = {dev: torch.cuda.mem_get_info(dev)[0] for dev in range(num_devices)}

    _, hf_config = get_model_arch(model_path)
    ctor_kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, hf_config=hf_config, backend=backend)
    load_weights = backend == 'turbomind' or with_llm

    for name, module in VISION_MODELS.module_dict.items():
        try:
            if not module.match(hf_config):
                continue
            logger.info(f'matching vision model: {name}')
            model = module(**ctor_kwargs)
            model.build_preprocessor()
            # turbomind needs the vision tower; quantization needs the whole VLM
            if load_weights:
                model.build_model()
            return model
        except Exception as e:
            logger.error(f'build vision model {name} failed, {e}')
            raise

    raise ValueError(f'unsupported vl model with config {hf_config}')


================================================
FILE: lmdeploy/vl/model/cogvlm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class CogVLMVisionModel(VisionModel):
    """CogVLM vision model."""

    _arch = 'CogVLMForCausalLM'

    def build_preprocessor(self):
        """Build the image transform and compute the number of tokens each
        image occupies in the LLM input."""
        from torchvision import transforms
        vision_config = self.hf_config.vision_config
        image_size = vision_config['image_size']
        patch_size = vision_config['patch_size']
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, ) * 2, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        if vision_config['num_positions'] == 1226:
            # cogvlm-chat-hf, https://huggingface.co/THUDM/cogvlm-chat-hf/blob/e29dc3ba206d524bf8efbfc60d80fc4556ab0e3c/modeling_cogvlm.py#L820 # noqa E501
            self.n_token_per_image = 2 + (image_size // patch_size)**2
        else:
            # cogvlm2, https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B/blob/2c2226281325649d49b8aa237a932367c7da4f26/modeling_cogvlm.py#L819 # noqa E501
            self.n_token_per_image = 2 + (image_size // patch_size // 2)**2

    def build_model(self):
        """Load the whole VLM on CPU when `self.with_llm` is True.

        Raises:
            NotImplementedError: otherwise, since turbomind does not support
                cogvlm.
        """
        if self.with_llm:
            from transformers import AutoModelForCausalLM
            self.vl_model = AutoModelForCausalLM.from_pretrained(self.model_path,
                                                                 device_map='cpu',
                                                                 trust_remote_code=True)
        else:
            raise NotImplementedError('turbomind has not supported cogvlm yet')

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to the spec of `super().preprocess`"""
        images = self.collect_multimodal_items(messages)
        outputs = []
        for _, image, _ in images:
            image = image.convert('RGB')
            pixel_values = self.image_transform(image)
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_size=image.size,
                     image_tokens=self.n_token_per_image,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Returns:
            tuple[str, str]: the prompt, and the placeholder string that
                marks where image embeddings are inserted. Each message's
                placeholders are put in front of that message's templated
                text.
        """
        # Must be non-empty: `to_pytorch_aux` splits the prompt on it.
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        prompt_messages = []
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            # an image-only message has no text item
            text = content[0] if content else ''
            prompt_messages.append(dict(role='user', content=text, num_images=n_images))

        from lmdeploy.model import Vicuna
        llm_chat_template = Vicuna(eoa='', stop_words=chat_template.stop_words)
        prompt = ''
        for i, msg in enumerate(prompt_messages):
            num_images = msg.pop('num_images', 0)
            if num_images == 0:
                # text-only turns are first rendered with a Vicuna template
                role = msg['role']
                msg = llm_chat_template.messages2prompt([msg], sequence_start and i == 0)
                msg = dict(role=role, content=msg)
            prompt_i = chat_template.messages2prompt([msg], sequence_start and i == 0)
            if num_images > 0:
                prompt_i = (IMAGE_TOKEN * num_images) + prompt_i
            prompt += prompt_i
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the preprocessed messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/deepseek.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


def check_deepseek_vl_install():
    """Check that the ``deepseek_vl`` package can be imported.

    Raises:
        ImportError: with installation instructions if the import fails. The
            original import error is chained as ``__cause__`` so that the
            real reason (e.g. a missing transitive dependency) stays visible.
    """
    try:
        import deepseek_vl  # noqa: F401
    except ImportError as err:
        raise ImportError('To use DeepSeekVLModel, please install deepseek_vl by '
                          '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'
                          ' --no-deps`') from err


@VISION_MODELS.register_module()
class DeepSeekVisionModel(VisionModel):
    """DeepSeek-VL vision model."""

    _arch = 'MultiModalityCausalLM'

    def build_preprocessor(self):
        """Load DeepSeek-VL's chat processor and keep its image processor and
        the token id of its image tag."""
        check_deepseek_vl_install()
        from deepseek_vl.models import VLChatProcessor
        vl_chat_processor = VLChatProcessor.from_pretrained(self.model_path)
        tokenizer = vl_chat_processor.tokenizer
        self.image_token_id = tokenizer.vocab.get(vl_chat_processor.image_tag)
        self.image_processor = vl_chat_processor.image_processor

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights
        with init_empty_weights(), warnings.catch_warnings():
            # Scope the 'ignore' filter to building the empty model instead of
            # installing it for the whole process.
            warnings.simplefilter('ignore')
            model = AutoModelForCausalLM.from_pretrained(self.model_path)
            self.vl_model = model
            if not self.with_llm:
                del model.language_model

        from accelerate.utils import get_balanced_memory, infer_auto_device_map
        max_memory = get_balanced_memory(model,
                                         max_memory=self.max_memory,
                                         dtype=torch.half,
                                         no_split_module_classes=['Block'])
        device_map = infer_auto_device_map(model,
                                           no_split_module_classes=['Block'],
                                           max_memory=max_memory,
                                           dtype=torch.half)
        # keep each tower's pos_embed on the same device as its patch_embed
        same_device_keys = [('vision_model.vision_tower_high.vision_tower.pos_embed',
                             'vision_model.vision_tower_high.vision_tower.patch_embed'),
                            ('vision_model.vision_tower_low.vision_tower.pos_embed',
                             'vision_model.vision_tower_low.vision_tower.patch_embed')]
        for (a, b) in same_device_keys:
            if a in device_map and b in device_map:
                device_map[b] = device_map[a]
        # put the (last-listed) downsamples entry of the high-res tower on the
        # device of its hd_alpha_downsamples entry
        downsamples = []
        ka = 'vision_model.vision_tower_high.vision_tower.downsamples'
        kb = 'vision_model.vision_tower_high.vision_tower.hd_alpha_downsamples'  # noqa: E501
        for k in device_map:
            if k.startswith(ka):
                downsamples.append(k)
        if len(downsamples) == 1:
            device_map[ka] = device_map[kb]
        elif len(downsamples) > 1:
            numbers = [int(x[len(ka) + 1:]) for x in downsamples]
            device_map[f'{ka}.{numbers[-1]}'] = device_map[kb]

        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map=device_map if not self.with_llm else {'': 'cpu'},
                                         dtype=torch.half)

        self.model = model.eval()
        self.vision_model = model.vision_model.eval()
        self.aligner = model.aligner.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to the spec of `super.preprocess()"""
        images = self.collect_multimodal_items(messages)
        outputs = []
        for _, image, _ in images:
            image = image.convert('RGB')
            pixel_values = self.image_processor([image], return_tensors='pt').pixel_values
            outputs.append(
                dict(
                    pixel_values=pixel_values,
                    image_size=image.size,
                    # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py  # noqa
                    # which is hardcoded 576
                    image_tokens=576,
                    image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(device=next(self.vision_model.parameters()).device, dtype=torch.float16)
            # [b x n_images, T2, D]
            logger.info(f'vision forward shape: {pixel_values.shape}')
            feats = self.aligner(self.vision_model(pixel_values))
            # one feature tensor per image
            feats = torch.split(feats, 1, dim=0)
            outputs.extend([x.squeeze() for x in feats])
        messages.append(dict(role='forward', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        The user may place ``<IMAGE_TOKEN>`` in the text. If there are none,
        or their count differs from the number of images, the placeholders
        are generated automatically.

        Returns:
            tuple[str, str]: the prompt and the image placeholder string.
        """
        prompt_messages = []
        # Must be non-empty: it is counted, replaced and split on.
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            # an image-only message has no text item
            content = content[0] if content else ''
            n_image = sum([1 for x in message['content'] if x['type'] == 'image'])
            n_placeholder = content.count(IMAGE_TOKEN)
            if n_placeholder == 0:
                logger.warning(f"""for deepseek-vl model, the user should insert the {IMAGE_TOKEN}
                    to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html
                    for more details.""")  # noqa
            if n_placeholder != 0 and n_placeholder != n_image:
                logger.error(f'unmatched placeholder and image: {n_placeholder} vs '
                             f'{n_image}. Ignore the placeholder')
                content = content.replace(IMAGE_TOKEN, '')
                n_placeholder = 0
            if n_placeholder == 0:
                if n_image == 1:
                    content = f'{IMAGE_TOKEN}{content}'
                else:
                    content = ''.join([f'{IMAGE_TOKEN} is Figure {str(i)}.\n' for i in range(n_image)]) + content
            prompt_messages.append(dict(role='user', content=content))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the preprocessed messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the turbomind-engine inputs from the forwarded messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/deepseek_vl2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
from contextlib import redirect_stdout
from typing import Dict, List

import torch
from transformers import AutoConfig

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


def check_deepseek_vl2_install():
    """Check that the ``deepseek_vl2`` package can be imported.

    Raises:
        ImportError: with installation instructions if the import fails. The
            original import error is chained as ``__cause__``.
    """
    try:
        import deepseek_vl2  # noqa: F401
    except ImportError as err:
        raise ImportError('To use DeepSeek-VL2, please install deepseek_vl2 by '
                          '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL2.git'
                          ' --no-deps`') from err


def check_trans_version():
    """Check that the installed ``transformers`` is older than the maximum
    version supported by deepseek_vl2 (4.48.0).

    Raises:
        AssertionError: if ``transformers.__version__`` is not below the
            maximum. It is raised explicitly rather than via ``assert`` so
            the check still runs under ``python -O``.
    """
    import transformers
    from packaging import version

    max_version = '4.48.0'
    installed_version = transformers.__version__
    if not version.parse(installed_version) < version.parse(max_version):
        raise AssertionError(f'deepseek_vl2 requires transformers version < {max_version}, '
                             f'but found version: {installed_version}. Please downgrade.')


@VISION_MODELS.register_module()
class DeepSeek2VisionModel(VisionModel):
    """DeepSeek2 vision model."""

    _arch = 'DeepseekV2ForCausalLM'

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model."""
        if hasattr(config, 'language_config') and hasattr(config, 'vision_config'):
            arch = config.language_config.get('architectures', [None])[0]
            return arch == cls._arch
        return False

    def build_preprocessor(self):
        """Load the DeepSeek-VL2 processor and keep its image token id."""
        check_trans_version()
        check_deepseek_vl2_install()
        from deepseek_vl2.models.processing_deepseek_vl_v2 import DeepseekVLV2Processor

        # suppress deepseek-vl2 processor initialization print logs
        with open(os.devnull, 'w') as devnull:
            with redirect_stdout(devnull):
                # Must be the same placeholder that `proc_single_message`
                # inserts into the text passed to the processor in
                # `preprocess`.
                self.image_processor = DeepseekVLV2Processor.from_pretrained(self.model_path,
                                                                             image_token='<IMAGE_TOKEN>')
                self.image_token_id = self.image_processor.image_token_id

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to the spec of `super.preprocess()"""
        images = self.collect_multimodal_items(messages)

        # convert to upstream api formats
        images = [item[1] for item in images]
        formatted_messages = []
        for message in messages:
            text_content = DeepSeek2VisionModel.proc_single_message(message)
            if text_content is None:
                # internal 'images' / 'preprocess' / 'forward' message
                continue
            content = message['content']
            # Plain-text messages carry no images. Iterating a str here would
            # index characters with 'type' and raise TypeError.
            if isinstance(content, str):
                image_content = []
            else:
                image_content = [x['image'] for x in content if x['type'] == 'image']
            formatted_messages.append(dict(role=message['role'], content=text_content, images=image_content))

        # NOTE: DeepseekVLV2Processor inputs
        # conversations (List[Dict]): conversations with a list of messages;
        # images (List[ImageType]): the list of images;
        # force_batchify (bool): force batchify the inputs;
        # inference_mode (bool): if True, then remove the last eos token;
        prepare = self.image_processor(conversations=formatted_messages,
                                       images=images,
                                       force_batchify=False,
                                       inference_mode=False)

        messages.append(
            dict(role='preprocess',
                 content=[
                     dict(
                         pixel_values=prepare.images,
                         image_tokens=prepare.num_image_tokens[0],
                         image_token_id=self.image_processor.image_token_id,
                         image_size=self.image_processor.image_size,
                         images_spatial_crop=prepare.images_spatial_crop,
                     )
                 ]))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    @staticmethod
    def proc_single_message(message):
        """Return the prompt text of one message.

        Returns:
            str | None: a plain-text message's content unchanged; for a
                multimodal message, its first text item with one
                ``<IMAGE_TOKEN>`` per image (added automatically when the
                user's placeholders are missing or miscounted); None for the
                internal 'images' / 'preprocess' / 'forward' messages.
        """
        IMAGE_TOKEN = '<IMAGE_TOKEN>'

        if isinstance(message['content'], str):
            return message['content']
        elif message['role'] in ['images', 'preprocess', 'forward']:
            return None

        content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
        # an image-only message has no text item
        content = content[0] if content else ''
        n_image = sum([1 for x in message['content'] if x['type'] == 'image'])
        n_placeholder = content.count(IMAGE_TOKEN)
        if n_placeholder == 0:
            logger.warning(f"""for deepseek-vl2 model, the user should insert the {IMAGE_TOKEN}
                to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html
                for more details.""")  # noqa
        if n_placeholder != 0 and n_placeholder != n_image:
            logger.error(f'unmatched placeholder and image: {n_placeholder} vs '
                         f'{n_image}. Ignore the placeholder')
            content = content.replace(IMAGE_TOKEN, '')
            n_placeholder = 0
        if n_placeholder == 0:
            if n_image == 1:
                content = f'{IMAGE_TOKEN}{content}'
            else:
                content = ''.join([f'{IMAGE_TOKEN} is Figure {str(i)}.\n' for i in range(n_image)]) + content
        return content

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Returns:
            tuple[str, str]: the prompt and the image placeholder string.
        """
        prompt_messages = []
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        for message in messages:
            content = DeepSeek2VisionModel.proc_single_message(message)
            if content is None:
                continue
            # Plain-text turns keep their role (system / user / assistant);
            # multimodal turns are sent as user turns.
            role = message['role'] if isinstance(message['content'], str) else 'user'
            prompt_messages.append(dict(role=role, content=content))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the preprocessed messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the turbomind-engine inputs from the forwarded messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/gemma3_vl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from transformers import AutoConfig, AutoProcessor
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


class Gemma3ImagesKwargs(ImagesKwargs):
    # Gemma3-specific image-processor kwargs, declared on top of the generic
    # `ImagesKwargs`. Defaults for the pan-and-scan keys are set in
    # `Gemma3ProcessorKwargs._defaults`.
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    # Per-modality kwargs schema; image kwargs use the Gemma3 keys.
    images_kwargs: Gemma3ImagesKwargs
    # Default kwargs, presumably merged in by `processor._merge_kwargs`
    # (called with this class in Gemma3VisionModel.preprocess).
    # Pan-and-scan is disabled by default.
    _defaults = dict(
        text_kwargs=dict(padding=False),
        images_kwargs=dict(
            do_pan_and_scan=False,
            pan_and_scan_min_crop_size=256,
            pan_and_scan_max_num_crops=4,
            pan_and_scan_min_ratio_to_activate=1.2,
        ),
    )


@VISION_MODELS.register_module()
class Gemma3VisionModel(VisionModel):
    """Gemma3 vision model."""

    _arch = 'Gemma3ForConditionalGeneration'

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        super().__init__(model_path, with_llm, max_memory, hf_config, backend)

    def build_preprocessor(self):
        """Load the HF processor and keep the image token id, the number of
        tokens per image, and the tokenizer init kwargs."""
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        tokenizer = self.processor.tokenizer
        self.image_token_id = tokenizer.encode(tokenizer.image_token)[-1]
        self.image_tokens = self.processor.image_seq_length
        self.tokenizer_init_kwargs = tokenizer.init_kwargs

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess() for spec."""
        from transformers.image_utils import make_nested_list_of_images
        output_kwargs = self.processor._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer_init_kwargs,
            **{
                'return_tensors': 'pt',
                'add_special_tokens': False
            },
        )
        images = self.collect_multimodal_items(messages)
        images = [image.convert('RGB') for _, image, _ in images]
        num_image = len(images)
        images = make_nested_list_of_images(images)
        image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])
        outputs = []
        for idx in range(num_image):
            # NOTE(review): assumes one pixel_values row per image, which
            # holds when pan-and-scan is off (the default here).
            pixel_values = image_inputs['pixel_values'][idx:idx + 1, ...]
            # The old `[:idx:idx + 1]` sliced with stop=idx, step=idx+1, which
            # is empty for the first image.
            num_crops = image_inputs['num_crops'][idx:idx + 1]
            data = dict(pixel_values=pixel_values,
                        num_crops=num_crops,
                        image_tokens=self.image_tokens,
                        image_token_id=self.image_token_id)
            outputs.append(data)

        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Each image becomes
        ``'\\n\\n<start_of_image>' + IMAGE_TOKEN + '<end_of_image>\\n\\n'`` in
        front of the message text. This mirrors the image block wrapping used
        by the Hugging Face Gemma3 processor. NOTE(review): confirm the
        boi/eoi strings against the model's tokenizer.

        Returns:
            tuple[str, str]: the prompt and the image placeholder string.
        """
        prompt_messages = []
        # Must be non-empty: `to_pytorch_aux` splits the prompt on it.
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            content = [item['text'] for item in message['content'] if item['type'] == 'text']
            # an image-only message has no text item
            text = content[0] if content else ''
            prompt = ('\n\n<start_of_image>' + IMAGE_TOKEN + '<end_of_image>\n\n') * n_images + text
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the preprocessed messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the turbomind-engine inputs from the forwarded messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/glm4_1v.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from transformers import AutoConfig

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class GLM4_1_VisionModel(VisionModel):
    """GLM-4.1V-9B-Thinking model."""

    _arch = ['Glm4vForConditionalGeneration']

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model."""
        arch = config.architectures[0] if config.architectures else None
        if arch in cls._arch and hasattr(config, 'vision_config'):
            return True
        return False

    def build_preprocessor(self):
        """Load the HF processor and keep the id of its image token."""
        from transformers import AutoProcessor
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        tokenizer = self.processor.tokenizer
        image_token = self.processor.image_token
        self.image_token_id = tokenizer.encode(image_token)[-1]

    def build_model(self):
        """Raise: turbomind does not support GLM-4.1V."""
        raise NotImplementedError('turbomind has not supported glm4v yet')

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess()` for spec."""
        images = self.collect_multimodal_items(messages)
        outputs = []
        # NOTE(review): per-image resize params (resized_height/width,
        # min/max_pixels) are not forwarded to the image processor; the
        # previous code built a dict of them that was never used.
        for _, image, _ in images:
            image = image.convert('RGB')
            result = self.processor.image_processor(images=image, videos=None, return_tensors='pt')
            merge_length = self.processor.image_processor.merge_size**2
            # token count = grid size (t*h*w) reduced by the spatial merge
            image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
            result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
            outputs.append(result)
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Placeholders the user wrote are wrapped with
        ``<|begin_of_image|>``/``<|end_of_image|>`` unless the prompt already
        contains the begin marker. If the prompt has no placeholder, one
        wrapped placeholder per image is prepended.

        Returns:
            tuple[str, str]: the prompt and the image placeholder string.
        """
        prompt_messages = []
        # Must be non-empty: `str.replace('', ...)` would insert markers
        # between every character, and `to_pytorch_aux` splits on it.
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        wrapped = f'<|begin_of_image|>{IMAGE_TOKEN}<|end_of_image|>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            content = [item['text'] for item in message['content'] if item['type'] == 'text']
            # an image-only message has no text item
            prompt = content[0] if content else ''
            if IMAGE_TOKEN in prompt:
                # The user placed the images. Wrap them unless already
                # wrapped; never prepend duplicates.
                if '<|begin_of_image|>' not in prompt:
                    prompt = prompt.replace(IMAGE_TOKEN, wrapped)
            else:
                prompt = wrapped * n_images + prompt
            prompt_messages.append(dict(role=message['role'], content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the preprocessed messages."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/glm4_v.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from transformers import AutoConfig

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

# Module logger obtained from lmdeploy.utils.get_logger under the name 'lmdeploy'.
logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class GLM4VisionModel(VisionModel):
    """Glm-4v-9b vision model.

    Images are resized to a fixed square resolution, so every image costs the
    same number of tokens (``n_token_per_image``). The model is only built
    together with the LLM (``with_llm=True``); turbomind is not supported.
    """

    _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration']

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model.

        Matches when the first entry of ``config.architectures`` is one of
        ``cls._arch`` and the config has a ``vision_config`` attribute.
        Configs lacking ``architectures`` (or with an empty list) do not
        match instead of raising.
        """
        architectures = getattr(config, 'architectures', None)
        arch = architectures[0] if architectures else None
        return arch in cls._arch and hasattr(config, 'vision_config')

    def build_preprocessor(self):
        """Build the image transform and the fixed per-image token count."""
        from torchvision import transforms
        image_size = self.hf_config.vision_config['image_size']
        patch_size = self.hf_config.vision_config['patch_size']
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, ) * 2, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        # (image_size // patch_size // 2)**2 grid tokens plus 2 extra tokens;
        # presumably the extra two are begin/end-of-image markers -- TODO confirm
        self.n_token_per_image = 2 + (image_size // patch_size // 2)**2

    def build_model(self):
        """Load the whole VLM on CPU; only supported when ``with_llm`` is set.

        Raises:
            NotImplementedError: when called without the LLM (turbomind path).
        """
        if self.with_llm:
            from transformers import AutoModelForCausalLM
            self.vl_model = AutoModelForCausalLM.from_pretrained(self.model_path,
                                                                 device_map='cpu',
                                                                 trust_remote_code=True)
        else:
            raise NotImplementedError('turbomind has not supported glm4v yet')

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to the spec of `super.preprocess()`.

        Every image found in list-content messages is converted to RGB and
        transformed. One dict per image (``pixel_values``, ``image_size``,
        ``image_tokens``, ``image_token_id``) is collected, and the list is
        appended to ``messages`` as a ``role='preprocess'`` message.
        """
        outputs = []
        for message in messages:
            if not isinstance(message['content'], list):
                continue
            images = [x['image'] for x in message['content'] if x['type'] == 'image']
            if len(images) > 1:
                logger.warning(f'glm4v does not support the input of multiple images'
                               f' in a single chat round, but got {len(images)} images.')
            # we still pass all the images to the model and let the
            # model decide what to do
            for image in images:
                image = image.convert('RGB')
                outputs.append(
                    dict(pixel_values=self.image_transform(image),
                         image_size=image.size,
                         image_tokens=self.n_token_per_image,
                         image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        String-content messages pass through unchanged, and 'preprocess' /
        'forward' messages are skipped. Every other message becomes a user
        message. Its text is ``IMAGE_TOKEN`` followed by a newline, once per
        image, then all of its text items concatenated.

        Returns:
            Tuple[str, str]: the rendered prompt and ``IMAGE_TOKEN``.
        """
        prompt_messages = []
        IMAGE_TOKEN = ''
        for message in messages:
            content = message['content']
            if isinstance(content, str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['preprocess', 'forward']:
                continue
            # Join every text item: indexing the first one raised IndexError
            # for image-only messages and silently dropped any extra text.
            text = ''.join(x['text'] for x in content if x['type'] == 'text')
            n_images = len([1 for x in content if x['type'] == 'image'])
            prompt = f'{IMAGE_TOKEN}\n' * n_images + text
            # NOTE: list-content messages are always emitted with role 'user'.
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Render the prompt and pack it with the image data for pytorch."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/interns1_pro.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from transformers import AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

# Module logger obtained from lmdeploy.utils.get_logger under the name 'lmdeploy'.
logger = get_logger('lmdeploy')


def check_transformers():
    """Ensure the installed transformers provides the Qwen3-VL model classes.

    Raises:
        ImportError: if ``Qwen3VLForConditionalGeneration`` or
            ``Qwen3VLMoeForConditionalGeneration`` cannot be imported from
            transformers (including when transformers is not installed). The
            original import error is chained as ``__cause__`` so the real
            reason is not lost.
    """
    try:
        from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration  # noqa: F401
    except ImportError as err:
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from err


@VISION_MODELS.register_module()
class InternS1ProVisionModel(VisionModel):
    """InternS1Pro model.

    Basically the same preprocessing as Qwen3VL, but with Time Series support.
    Only the pytorch-engine path (``preprocess`` / ``to_pytorch``) is
    implemented; the turbomind hooks are stubs.
    """

    _arch = ['InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration']

    def build_preprocessor(self):
        """Load the HF processor and cache the special tokens it exposes.

        Time-series tokens are read with ``getattr`` and are ``None`` when the
        processor does not define them.
        """
        check_transformers()
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)

        # image tokens
        self.image_token = self.processor.image_token
        self.image_token_id = self.processor.image_token_id

        # video tokens
        self.video_token = self.processor.video_token
        self.video_token_id = self.processor.video_token_id

        # time series tokens
        self.ts_token = getattr(self.processor, 'ts_token', None)
        self.ts_token_id = getattr(self.processor, 'ts_token_id', None)

        # vision start and end tokens
        self.vision_start_token = self.processor.vision_start_token
        self.vision_end_token = self.processor.vision_end_token

    def get_processor_args(self, mm_processor_kwargs: Optional[Dict[str, Any]] = None):
        """Resolve the ``(min_pixels, max_pixels)`` bounds for image resizing.

        Defaults come from the image processor's ``size['shortest_edge']`` and
        ``size['longest_edge']``. ``min_pixels`` / ``max_pixels`` in
        ``mm_processor_kwargs`` override them. An override that would make
        min exceed max (against the other override or the default) is
        rejected with a warning, and both defaults are returned.
        """
        min_pixels = self.processor.image_processor.size['shortest_edge']
        max_pixels = self.processor.image_processor.size['longest_edge']

        if mm_processor_kwargs is None:
            return min_pixels, max_pixels

        input_min_pixels = mm_processor_kwargs.get('min_pixels', None)
        input_max_pixels = mm_processor_kwargs.get('max_pixels', None)

        # boundary check for min_pixels and max_pixels
        if input_min_pixels is None:
            if input_max_pixels is not None:
                # only max_pixels is given in the input
                if input_max_pixels < min_pixels:
                    logger.warning(
                        f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.')
                    return min_pixels, max_pixels
                max_pixels = input_max_pixels
        else:
            if input_max_pixels is None:
                # only min_pixels is given in the input
                if input_min_pixels > max_pixels:
                    logger.warning(
                        f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.')
                    return min_pixels, max_pixels
            else:
                if input_min_pixels > input_max_pixels:
                    logger.warning(
                        f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.')
                    return min_pixels, max_pixels
                max_pixels = input_max_pixels
            min_pixels = input_min_pixels

        return min_pixels, max_pixels

    def check_time_series_input(self, messages):
        """Set ``self.has_time_series_input`` if any list-content message has a
        'time_series' item."""
        has_time_series_input = any(
            isinstance(message['content'], list) and any(item['type'] == 'time_series' for item in message['content'])
            for message in messages)
        self.has_time_series_input = has_time_series_input

    def _preprocess_image(self,
                          data: Any,
                          params: Dict[str, Any],
                          mm_processor_kwargs: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Run the HF image processor on a single image.

        Args:
            data: one image; must support ``convert('RGB')`` and ``.size``
                (presumably a PIL image).
            params: per-item parameters (not used here).
            mm_processor_kwargs: optional ``min_pixels`` / ``max_pixels``
                overrides, validated by ``get_processor_args``.

        Returns:
            The processor output extended with ``image_size``,
            ``image_tokens`` (product of ``image_grid_thw`` divided by
            ``merge_size**2``) and ``image_token_id``.
        """
        image = data.convert('RGB')
        min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs)

        result = self.processor.image_processor(images=image,
                                                size={
                                                    'shortest_edge': min_pixels,
                                                    'longest_edge': max_pixels
                                                },
                                                return_tensors='pt')
        merge_length = self.processor.image_processor.merge_size**2
        image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
        result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
        return result

    def _preprocess_video(self,
                          data: Any,
                          params: Dict[str, Any],
                          mm_processor_kwargs: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Run the HF video processor on one (already frame-sampled) video.

        ``params['video_metadata']`` is passed to the processor and is
        mutated in place: a missing/None ``fps`` is set to 24.

        Returns:
            The processor output extended with ``curr_timestamp`` (per-frame
            timestamps), ``frame_seqlen`` (tokens per frame) and
            ``video_token_id``.
        """
        # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs
        metadata = params['video_metadata']
        video_kwargs = dict(return_metadata=True,
                            do_resize=True,
                            do_sample_frames=False,
                            video_metadata=metadata,
                            return_tensors='pt')
        result = self.processor.video_processor(videos=data, **video_kwargs)
        video_grid_thw = result['video_grid_thw']

        merge_length = self.processor.video_processor.merge_size**2
        if metadata.get('fps') is None:
            logger.warning_once('Qwen3VL: fps not found, defaulting to 24.')
            # Assign directly: reading metadata['fps'] here raised KeyError
            # whenever the key was absent, which is exactly this branch.
            metadata['fps'] = 24

        # compute per-frame timestamps from the sampled frame indices and fps
        curr_timestamp = self.processor._calculate_timestamps(
            metadata['frames_indices'],
            metadata['fps'],
            self.processor.video_processor.merge_size,
        )

        frame_seqlen = video_grid_thw[0][1:].prod() // merge_length
        result.update(curr_timestamp=curr_timestamp, frame_seqlen=frame_seqlen, video_token_id=self.video_token_id)
        return result

    def _preprocess_time_series(self,
                                data: Any,
                                params: Dict[str, Any],
                                mm_processor_kwargs: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Normalize one time series and compute how many tokens it occupies.

        The series is z-score normalized along axis 0, truncated to 240000
        steps and reshaped to ``[T, C]`` when 1-D. A missing or non-positive
        ``params['sampling_rate']`` defaults to ``max(T / 4, 1.0)``.

        Returns:
            Dict with single-element lists ``ts_values``, ``ts_sr``,
            ``ts_lens``, ``ts_tokens`` plus ``ts_token_id``.
        """
        ts_input = data
        sr = params.get('sampling_rate') if params is not None else None

        if not isinstance(ts_input, np.ndarray):
            ts_input = np.array(ts_input, dtype=np.float32)

        mean = ts_input.mean(axis=0, keepdims=True)
        std = ts_input.std(axis=0, keepdims=True)
        ts_input = (ts_input - mean) / (std + 1e-8)

        # truncate to 240k to avoid OOM
        max_ts_len = 240000
        if len(ts_input) > max_ts_len:
            ts_input = ts_input[:max_ts_len]

        if ts_input.ndim == 1:
            ts_input = ts_input[:, None]  # [T,C]

        ts_len = ts_input.shape[0]

        # set the default value to ts_len / 4 if sr is not provided or invalid
        if sr is None or sr <= 0:
            sr = max(ts_len / 4, 1.0)

        # compute num ts tokens
        stride = np.floor(160 / ((1 + np.exp(-sr / 100))**6))
        patch_size = stride * 2
        embed_length = (np.ceil((ts_len - patch_size) / stride) + 1)
        ts_tokens = int((embed_length // 2 + 1) // 2)

        return dict(ts_values=[ts_input],
                    ts_sr=[sr],
                    ts_lens=[ts_len],
                    ts_tokens=[ts_tokens],
                    ts_token_id=self.ts_token_id)

    def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:
        """Refer to `super().preprocess()` for spec.

        Also records on the instance whether video (``contains_video_input``)
        or time-series (``contains_ts_input``) items were seen. Those flags
        are read later by ``proc_messages`` / ``to_pytorch``.
        """
        outputs = []
        self.contains_video_input = False
        self.contains_ts_input = False

        mm_items = self.collect_multimodal_items(messages)
        for modality, data, params in mm_items:
            result = {}
            if modality == Modality.IMAGE:
                result = self._preprocess_image(data, params, mm_processor_kwargs)
            elif modality == Modality.VIDEO:
                self.contains_video_input = True
                result = self._preprocess_video(data, params, mm_processor_kwargs)
            elif modality == Modality.TIME_SERIES:
                self.contains_ts_input = True
                result = self._preprocess_time_series(data, params, mm_processor_kwargs)

            result.update(modality=modality)
            outputs.append(result)

        messages.append(dict(role='preprocess', content=outputs))
        return messages

    def proc_messages(self,
                      messages,
                      chat_template,
                      sequence_start,
                      tools: Optional[List[object]] = None,
                      chat_template_kwargs=None):
        """Apply chat template to get the prompt.

        Expects ``preprocess`` to have run first on this instance (it reads
        ``self.contains_ts_input``). When time-series input is present,
        ``enable_thinking=False`` is forced on a private copy of
        ``chat_template_kwargs``; the caller's dict is never modified.

        Returns:
            Tuple[str, None]: the rendered prompt and ``None``.
        """
        # copy so the enable_thinking override below does not leak into the caller's dict
        chat_template_kwargs = dict(chat_template_kwargs or {})
        prompt_messages = []
        IMAGE_TOKEN = ''
        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]

        if VisionModel.IMAGE_TOKEN_included(messages):
            # backward compatibility
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                content = [x['text'] for x in content if x['type'] == 'text']
                prompt = ''.join(content)
                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')
                prompt_messages.append(dict(role='user', content=prompt))
        else:
            prompt_messages = messages

        # time series input requires enabling_thinking = False
        if self.contains_ts_input:
            chat_template_kwargs['enable_thinking'] = False

        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, tools=tools, **chat_template_kwargs)
        return prompt, None

    def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start):
        """Pack the video input to the compatible format with pytorch
        engine.

        The prompt is split on ``vision_start + video_token + vision_end``.
        Each split point is expanded into per-frame ``<t seconds>`` markers,
        each followed by ``frame_seqlen`` video tokens wrapped in vision
        start/end tokens. The ``offset`` and ``video_tokens`` of each
        preprocessed video are recorded in place. ``VIDEO_TOKEN`` is not used;
        the split token is built from the instance's tokens.
        """

        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        # split prompt into segments and validate data
        segs = prompt.split(self.vision_start_token + self.video_token + self.vision_end_token)
        assert len(segs) == len(preps) + 1, (f'the number of {self.video_token} is not equal '
                                             f'to input videos, {len(segs) - 1} vs {len(preps)}')

        # calculate the video token offset for each video
        input_ids = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(preps):
                preps[i - 1].update(offset=len(input_ids))
                frame_seqlen = preps[i - 1]['frame_seqlen']
                assert self.video_token_id == preps[i - 1]['video_token_id']

                video_grid_thw = preps[i - 1]['video_grid_thw']
                curr_timestamp = preps[i - 1]['curr_timestamp']

                # update prompt with timestamp index tokens and video pad tokens
                video_placeholder = ''
                for frame_idx in range(video_grid_thw[0][0]):
                    curr_time = curr_timestamp[frame_idx]
                    video_placeholder += f'<{curr_time:.1f} seconds>'
                    video_placeholder += (self.vision_start_token + '<|placeholder|>' * frame_seqlen +
                                          self.vision_end_token)

                video_placeholder = video_placeholder.replace('<|placeholder|>', self.video_token)
                video_token_ids = tokenizer.encode(video_placeholder)
                input_ids.extend(video_token_ids)

                preps[i - 1].update(video_tokens=len(video_token_ids))

            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(token_ids)

        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

    def to_pytorch_aux_ts(self, messages, prompt, TS_TOKEN, tokenizer, sequence_start):
        """Pack the time series input to the compatible format with pytorch
        engine.

        Each ``TS_TOKEN`` occurrence in the prompt is replaced by
        ``ts_tokens`` copies of ``ts_token_id``. The matching preprocessed
        entry is converted in place to tensors (``ts_values`` as bfloat16,
        ``ts_lens`` / ``ts_sr`` as tensors) and gets an ``offset``.
        """
        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        # split prompt into segments and validate data
        segs = prompt.split(TS_TOKEN)
        assert len(segs) == len(preps) + 1, (f'the number of {TS_TOKEN} is not equal '
                                             f'to input time series data, {len(segs) - 1} vs {len(preps)}')

        input_ids = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(preps):
                preps[i - 1].update(offset=len(input_ids))
                ts_tokens = preps[i - 1]['ts_tokens']

                ts_tokens = ts_tokens[0]
                ts_array = np.array(preps[i - 1]['ts_values'])

                preps[i - 1].update(ts_tokens=ts_tokens)
                preps[i - 1].update(ts_values=torch.from_numpy(ts_array).to(dtype=torch.bfloat16))
                preps[i - 1].update(ts_lens=torch.tensor(preps[i - 1]['ts_lens']))
                preps[i - 1].update(ts_sr=torch.tensor(preps[i - 1]['ts_sr']))

                assert self.ts_token_id == preps[i - 1]['ts_token_id']
                input_ids.extend([self.ts_token_id] * ts_tokens)
            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(token_ids)

        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

    def to_pytorch(self,
                   messages,
                   chat_template,
                   tokenizer,
                   sequence_start,
                   tools: Optional[List[object]] = None,
                   chat_template_kwargs: Optional[Dict] = None,
                   **kwargs):
        """Return the information needed by pytorch engine.

        Dispatch uses the flags set by ``preprocess``: the video path takes
        priority, then time series, otherwise the generic image path.
        """
        prompt, _ = self.proc_messages(messages,
                                       chat_template,
                                       sequence_start,
                                       tools=tools,
                                       chat_template_kwargs=chat_template_kwargs)

        if self.contains_video_input:
            return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start)
        elif self.contains_ts_input:
            return self.to_pytorch_aux_ts(messages, prompt, self.ts_token, tokenizer, sequence_start)
        else:
            return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start)

    def build_model(self):
        # TODO: implement for turbomind
        pass

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        # TODO: implement for turbomind
        pass

    def to_turbomind(self,
                     messages,
                     chat_template,
                     tokenizer,
                     sequence_start,
                     chat_template_kwargs: Optional[Dict] = None,
                     **kwargs):
        # TODO: implement for turbomind
        pass


================================================
FILE: lmdeploy/vl/model/internvl.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

# Module logger obtained from lmdeploy.utils.get_logger under the name 'lmdeploy'.
logger = get_logger('lmdeploy')


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Adapted from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5.

    Pick the (cols, rows) grid in ``target_ratios`` whose cols/rows ratio is
    nearest to ``aspect_ratio``. On an exact tie, a later candidate replaces
    the current best only when the image area ``width * height`` exceeds half
    of that candidate's tiled area (``image_size**2 * cols * rows``).
    Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size
    best_diff, best = float('inf'), (1, 1)
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff, best = diff, candidate
        elif diff == best_diff and image_area > half_tile_area * cols * rows:
            best = candidate
    return best


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Adapted from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5.

    Split ``image`` into ``image_size`` x ``image_size`` tiles. Every
    (cols, rows) grid with ``min_num <= cols * rows <= max_num`` is a
    candidate. The one closest to the image's aspect ratio is chosen by
    ``find_closest_aspect_ratio``. The image is resized to exactly cover that
    grid and cropped into tiles in row-major order. With ``use_thumbnail``
    and more than one tile, a whole-image resize to
    ``image_size`` x ``image_size`` is appended at the end.
    """
    src_width, src_height = image.size
    src_ratio = src_width / src_height

    # every admissible (cols, rows) grid, ordered from fewest to most tiles
    candidates = set((cols, rows) for n in range(min_num, max_num + 1) for cols in range(1, n + 1)
                     for rows in range(1, n + 1) if min_num <= cols * rows <= max_num)
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(src_ratio, candidates, src_width, src_height, image_size)
    grid_width = image_size * grid[0]
    grid_height = image_size * grid[1]
    n_tiles = grid[0] * grid[1]
    tiles_per_row = grid_width // image_size

    canvas = image.resize((grid_width, grid_height))
    tiles = []
    for idx in range(n_tiles):
        row, col = divmod(idx, tiles_per_row)
        left, upper = col * image_size, row * image_size
        tiles.append(canvas.crop((left, upper, left + image_size, upper + image_size)))
    assert len(tiles) == n_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


@VISION_MODELS.register_module()
class InternVLVisionModel(VisionModel):
    """InternVL vision model."""

    _arch = 'InternVLChatModel'

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        super().__init__(model_path, with_llm, max_memory, hf_config, backend)
        self.image_token = ''
        # look up the id of the image context token with the model's (slow) tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)

    def build_preprocessor(self):
        """Choose the preprocessing/forward pair and the tokens per tile.

        The V1-5 dynamic-tiling path is used when the config enables
        ``dynamic_image_size`` or no ``CLIPImageProcessor`` can be loaded from
        the model path; otherwise the CLIP image processor path is used.
        """
        self.config = self.hf_config
        dynamic_image_size = getattr(self.config, 'dynamic_image_size', False)
        image_processor = None
        try:
            image_processor = CLIPImageProcessor.from_pretrained(self.model_path)
        except OSError:
            # deliberate best-effort: absence of a processor config selects the V1-5 path
            pass

        if dynamic_image_size or image_processor is None:
            logger.info('using InternVL-Chat-V1-5 vision preprocess')
            MEAN = (0.485, 0.456, 0.406)
            STD = (0.229, 0.224, 0.225)
            import torchvision.transforms as T
            from torchvision.transforms.functional import InterpolationMode
            input_size = self.config.vision_config.image_size
            self.transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD)
            ])
            self.processor = self._preprocess_v1_5
            self._forward_func = self._forward_v1_5
        else:
            self.processor = self._preprocess
            self.image_processor = image_processor
            self._forward_func = self._forward

        force_image_size = self.hf_config.force_image_size
        patch_size = self.hf_config.vision_config.patch_size
        downsample_ratio = self.hf_config.downsample_ratio
        # tokens contributed by each tile after patchifying and downsampling
        self.image_tokens_per_patch = int((force_image_size // patch_size)**2 * (downsample_ratio**2))

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights
        with init_empty_weights():
            # transformers below 4.37.0 may raise error about flash_attn
            self.config.llm_config.attn_implementation = 'eager'
            model = AutoModel.from_config(self.config, trust_remote_code=True)
            self.vl_model = model
            if not self.with_llm:
                del model.language_model

        model.half()
        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         max_memory=self.max_memory,
                                         no_split_module_classes=['InternVisionEncoderLayer'],
                                         dtype=torch.half)

        # We need eval mode to freeze the weights in model, thus,
        # avoid randomness in inference.
        self.model = model.eval()

    def _preprocess_v1_5(self, image, params=None):
        """Tile ``image`` dynamically (InternVL-Chat-V1-5) and transform each tile.

        The tile budget is ``params['max_dynamic_patch']`` when it is an int.
        Otherwise it is taken from ``params['detail']`` ('low' -> 6,
        'medium' -> 12, 'high' -> 24), falling back to
        ``config.max_dynamic_patch``.

        Returns:
            torch.Tensor: stacked tiles, (patch) x c x h x w.
        """
        # params defaults to None: treat that as "no per-image overrides"
        # instead of failing on None.get(...)
        if params is None:
            params = {}
        image_res = {'low': 6, 'medium': 12, 'high': 24}
        max_num = params.get('max_dynamic_patch')
        if max_num is None or not isinstance(max_num, int):
            res_key = params.get('detail', 'default')
            max_num = image_res.get(res_key, self.config.max_dynamic_patch)
        out = dynamic_preprocess(image,
                                 min_num=self.config.min_dynamic_patch,
                                 max_num=max_num,
                                 image_size=self.config.vision_config.image_size,
                                 use_thumbnail=self.config.use_thumbnail)
        pixel_values = [self.transform(x) for x in out]
        # (patch) x c x h x w
        pixel_values = torch.stack(pixel_values)
        return pixel_values

    def _forward_v1_5(self, inputs, max_batch_size):
        """Forward for internvl-chat-v1-5.

        Tiles of up to ``max_batch_size`` images are concatenated and run in
        one batch; the features are split back per image and flattened to
        ``(num_tokens, hidden)``.
        """
        assert all(x.get('pixel_values') is not None for x in inputs)
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            split = [x.shape[0] for x in pixel_values]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            feats = self.model.extract_feature(pixel_values)
            feats = torch.split(feats, split, dim=0)
            outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
        return outputs

    def _preprocess(self, image, params=None):
        """Preprocess for internvl-chat-v1-1, internvl-chat-v1-2 (``params`` is unused)."""
        pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values
        return pixel_values

    def _forward(self, inputs, max_batch_size):
        """Forward for internvl-chat-v1-1, internvl-chat-v1-2."""
        assert all(x.get('pixel_values') is not None for x in inputs)
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            feats = self.model.extract_feature(pixel_values)
            feats = torch.split(feats, 1, dim=0)
            outputs.extend([x.squeeze() for x in feats])
        return outputs

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess() for spec.

        Each collected item is converted to RGB and run through the selected
        processor. ``image_tokens`` is the number of tiles times
        ``image_tokens_per_patch``.
        """
        images = self.collect_multimodal_items(messages)
        outputs = []
        for _modality, image, params in images:
            image = image.convert('RGB')
            pixel_values = self.processor(image, params)
            image_tokens = (pixel_values.shape[0] * self.image_tokens_per_patch)
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_tokens=image_tokens,
                     image_token_id=self.image_token_id,
                     image_size=image.size))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = self._forward_func(inputs, max_batch_size)
        messages.append(dict(role='forward', content=outputs))
        return messages

    def proc_messages(
        self,
        messages,
        chat_template,
        sequence_start,
        tools: Optional[List[object]] = None,
        chat_template_kwargs: Optional[Dict] = None,
    ):
        """Apply chat template to get the prompt.

        'preprocess' / 'forward' messages are dropped. Two paths exist,
        selected by ``VisionModel.IMAGE_TOKEN_included`` (presumably whether
        user text already embeds the legacy ``IMAGE_TOKEN``):

        - legacy: user text items are joined and ``IMAGE_TOKEN`` is replaced
          by ``self.image_token``;
        - otherwise: user items are rendered in order; every 'image' /
          'image_url' item becomes ``self.image_token`` plus a newline.

        Raises:
            ValueError: on an unsupported user content item type.

        Returns:
            Tuple[str, str]: the rendered prompt and ``self.image_token``.
        """
        chat_template_kwargs = chat_template_kwargs or {}
        prompt_messages = []
        IMAGE_TOKEN = ''
        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
        if VisionModel.IMAGE_TOKEN_included(messages):
            # backward compatibility
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                content = [x['text'] for x in content if x['type'] == 'text']
                prompt = ''.join(content)
                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'{self.image_token}')
                prompt_messages.append(dict(role='user', content=prompt))
        else:
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                _content = []
                for item in content:
                    item_type = item['type']
                    if item_type == 'text':
                        _content.append(item['text'])
                    elif item_type in ['image', 'image_url']:
                        _content.append(f'{self.image_token}\n')
                    else:
                        raise ValueError(f'Unsupported message type: {item["type"]}')
                prompt_messages.append(dict(role='user', content=''.join(_content)))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, tools=tools, **chat_template_kwargs)
        return prompt, self.image_token

    def to_pytorch(self,
                   messages,
                   chat_template,
                   tokenizer,
                   sequence_start,
                   tools: Optional[List[object]] = None,
                   chat_template_kwargs: Optional[Dict] = None,
                   **kwargs):
        """Render the prompt and pack it with the preprocessed images for pytorch."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages,
                                                 chat_template,
                                                 sequence_start,
                                                 tools=tools,
                                                 chat_template_kwargs=chat_template_kwargs)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self,
                     messages,
                     chat_template,
                     tokenizer,
                     sequence_start,
                     tools: Optional[List[object]] = None,
                     chat_template_kwargs: Optional[Dict] = None,
                     **kwargs):
        """Render the prompt and pack it with the vision features for turbomind."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages,
                                                 chat_template,
                                                 sequence_start,
                                                 tools=tools,
                                                 chat_template_kwargs=chat_template_kwargs)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/internvl3_hf.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.internvl import VISION_MODELS, InternVLVisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


class InternVLImagesKwargs(ImagesKwargs, total=False):
    # Image-processor options for InternVL on top of transformers' generic
    # `ImagesKwargs`; every key is optional (`total=False`).
    # NOTE(review): the per-key meanings below are inferred from the names only;
    # confirm against the transformers InternVL image processor.
    crop_to_patches: Optional[bool]  # presumably: whether to tile the image into patches
    min_patches: Optional[int]  # presumably: lower bound on the number of tiles
    max_patches: Optional[int]  # presumably: upper bound on the number of tiles


class InternVLProcessorKwargs(ProcessingKwargs, total=False):
    # Processor kwargs schema passed to `processor._merge_kwargs` by
    # `InternVL3VisionModel.preprocess`; it narrows `images_kwargs` to the
    # InternVL-specific image options.
    images_kwargs: InternVLImagesKwargs
    # Per-modality defaults: no text padding and `crop_to_patches` enabled for
    # images. Presumably these are what `_merge_kwargs` falls back to when a
    # key is not given explicitly; confirm against transformers.
    _defaults = {
        'text_kwargs': {
            'padding': False,
        },
        'images_kwargs': {
            'crop_to_patches': True,
        },
        'videos_kwargs': {},
    }


@VISION_MODELS.register_module()
class InternVL3VisionModel(InternVLVisionModel):
    """Internvl3 vision model.

    Serves HF-format `InternVLForConditionalGeneration` and
    `InternS1ForConditionalGeneration` checkpoints. Preprocessing uses the
    checkpoint's `AutoProcessor`; `forward` extracts image features through
    the loaded model's `get_image_features`.
    """

    # Architectures this class is registered for. Presumably consulted by the
    # base-class `match`; confirm.
    _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration']

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        """Initialize through the parent class and record the architecture.

        Args:
            model_path (str): checkpoint path.
            with_llm (bool): when True, `build_model` keeps the language model
                and loads everything on CPU.
            max_memory (Dict[int, int]): used as `max_memory` when dispatching
                the checkpoint in `build_model`.
            hf_config (AutoConfig): the HF config of the model.
            backend (str): engine backend name.
        """
        super().__init__(model_path, with_llm, max_memory, hf_config, backend)
        # the first declared architecture selects the model class in `build_model`
        self.arch = self.hf_config.architectures[0]

    def build_preprocessor(self):
        """Load the HF processor and cache the image-token metadata from it."""
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        tokenizer = self.processor.tokenizer
        # placeholder string and token id used for image context tokens
        self.image_token = self.processor.image_token
        self.image_token_id = tokenizer.context_image_token_id
        # number of image tokens emitted per image patch (see `preprocess`)
        self.image_tokens_per_patch = self.processor.image_seq_length
        # kept so `preprocess` can merge processor kwargs with the tokenizer's init kwargs
        self.tokenizer_init_kwargs = tokenizer.init_kwargs

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights
        # build the module skeleton under accelerate's `init_empty_weights`;
        # real weights are loaded below by `load_checkpoint_and_dispatch`
        with init_empty_weights():
            if self.arch == 'InternVLForConditionalGeneration':
                model = AutoModel.from_config(self.hf_config, trust_remote_code=True)
                if not self.with_llm:
                    # vision-only: drop the language model before loading weights
                    del model.language_model
            elif self.arch == 'InternS1ForConditionalGeneration':
                model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True)
                if not self.with_llm:
                    # Intern-S1 keeps the language model under `model.model`
                    del model.model.language_model
            else:
                raise ValueError(f'unsupported model arch {self.arch}')

        model.half()
        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            # vision-only models use device_map='auto'; with the LLM included
            # everything is placed on CPU. Vision layers are never split.
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         max_memory=self.max_memory,
                                         no_split_module_classes=['InternVLVisionLayer', 'InternS1VisionLayer'],
                                         dtype=torch.half)
        # We need eval mode to freeze the weights in model, thus,
        # avoid randomness in inference.
        self.model = model.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess()` for spec.

        Runs the HF image processor over all collected images in one call, then
        splits the returned pixel values back per image using `num_patches`.
        Each image yields a dict with `pixel_values`, `image_tokens`
        (`image_tokens_per_patch` times its patch count) and `image_token_id`;
        the list is appended to `messages` as a 'preprocess' message.
        """
        from transformers.image_utils import make_flat_list_of_images
        # merge the InternVL processor defaults with the tokenizer init kwargs
        # and the explicit overrides given here
        output_kwargs = self.processor._merge_kwargs(
            InternVLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer_init_kwargs,
            **{
                'return_tensors': 'pt',
                'add_special_tokens': False
            },
        )
        images = self.collect_multimodal_items(messages)
        images = [image.convert('RGB') for modality, image, _ in images]
        num_image = len(images)
        images = make_flat_list_of_images(images)
        image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])
        # patch count per image; the pixel values of all images are stacked along dim 0
        image_num_patches = image_inputs.pop('num_patches').cpu().numpy().tolist()
        image_pixel_values = image_inputs.pop('pixel_values')
        outputs = []
        cum_num_patches = 0
        for idx in range(num_image):
            cur_num_patches = image_num_patches[idx]
            # slice out this image's patches from the stacked pixel values
            pixel_values = image_pixel_values[cum_num_patches:cum_num_patches + cur_num_patches, ...]
            cum_num_patches += cur_num_patches
            data = dict(pixel_values=pixel_values,
                        image_tokens=self.image_tokens_per_patch * cur_num_patches,
                        image_token_id=self.image_token_id)
            outputs.append(data)

        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        assert all(x.get('pixel_values') is not None for x in inputs)
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            # number of patches of each image, used to split the features back
            split = [x.shape[0] for x in pixel_values]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            feats = self.model.get_image_features(
                pixel_values,
                vision_feature_layer=self.hf_config.vision_feature_layer,
                vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy,
            )
            feats = torch.split(feats, split, dim=0)
            # flatten every image's features to (num_tokens, hidden_size)
            outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
        messages.append(dict(role='forward', content=outputs))
        return messages


================================================
FILE: lmdeploy/vl/model/internvl_llava.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import warnings
from contextlib import contextmanager
from typing import Dict, List

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel
from lmdeploy.vl.model.utils import rewrite_ctx

from .utils import disable_logging, disable_transformers_logging

logger = get_logger('lmdeploy')


def check_llava_install():
    """Make sure the InternVL fork of llava is importable.

    Raises:
        ImportError: with installation instructions when
            `llava.model.multimodal_encoder.clip_encoder.InternVisionModel`
            cannot be imported.
    """
    try:
        from llava.model.multimodal_encoder.clip_encoder import InternVisionModel  # noqa: F401
    except ImportError:
        install_hint = (
            'To use LlavaVLModel, please install llava by '
            '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`')
        raise ImportError(install_hint)


def _intern_vision_model__from_pretrained(vision_tower_name: str):
    """Stand-in for `InternVisionModel.from_pretrained`.

    Builds the model from its config only (`_from_config`), without loading
    pretrained weights, and disables gradients on all its parameters.
    """
    logger.info(f'init empty InternVisionModel: {vision_tower_name}')
    from llava.model.multimodal_encoder.intern_vit_6b.modeling_intern_vit import InternVisionConfig, InternVisionModel
    vit = InternVisionModel._from_config(InternVisionConfig.from_pretrained(vision_tower_name))
    vit.requires_grad_(False)
    return vit


def _intern_vl_model__from_pretrained(vision_tower_name: str):
    """Stand-in for `InternVLModel.from_pretrained`.

    Builds the model from its config only (`_from_config`), without loading
    pretrained weights, and disables gradients on all its parameters.
    """
    logger.info(f'init empty InternVLModel: {vision_tower_name}')

    from llava.model.multimodal_encoder.internvl_14b.modeling_internvl import InternVLConfig, InternVLModel

    internvl = InternVLModel._from_config(InternVLConfig.from_pretrained(vision_tower_name))
    internvl.requires_grad_(False)
    return internvl


@contextmanager
def init_empty_vit():
    """Skip download vision model if possible.

    For the duration of the context, the `from_pretrained` methods of llava's
    `InternVisionModel` and `InternVLModel` are replaced by helpers that only
    build the models from their configs.
    """
    # target dotted path -> replacement (dicts preserve insertion order)
    replacements = {
        'llava.model.multimodal_encoder.intern_vit_6b.modeling_intern_vit.InternVisionModel.from_pretrained':  # noqa: E501
        _intern_vision_model__from_pretrained,
        'llava.model.multimodal_encoder.internvl_14b.modeling_internvl.InternVLModel.from_pretrained':  # noqa: E501
        _intern_vl_model__from_pretrained,
    }
    with rewrite_ctx(list(replacements), list(replacements.values())):
        yield


@VISION_MODELS.register_module()
class InternVLLlavaVisionModel(LlavaVisionModel):
    """Llava visual model.

    Variant of `LlavaVisionModel` for `LlavaLlamaForCausalLM` checkpoints whose
    vision tower comes from OpenGVLab, built with the InternVL fork of llava.
    """

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model.

        True for `LlavaLlamaForCausalLM` configs whose `mm_vision_tower`
        contains 'OpenGVLab'.
        """
        arch = config.architectures[0] if config.architectures else None
        if arch == 'LlavaLlamaForCausalLM':
            mm_vision_tower = getattr(config, 'mm_vision_tower', '')
            if 'OpenGVLab' in mm_vision_tower:
                return True
        return False

    def build_preprocessor(self):
        """Same as the parent class implementation."""
        return super().build_preprocessor()

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        check_llava_install()
        # currently, only support llava llama
        from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM  # noqa
        self.config = LlavaConfig.from_pretrained(self.model_path)
        assert self.config.model_type in ['llava', 'llava_llama'], \
            'currently, only support llava llama'

        # init empty model, skip layer initialization
        from accelerate import init_empty_weights
        with init_empty_weights(), warnings.catch_warnings(), \
                disable_transformers_logging():
            warnings.simplefilter('ignore')
            self.config.quantization_config = {}  # disable vision part quantization
            model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
            self.vl_model = model
            if not self.with_llm:
                # vision-only: drop the LM head, embeddings, decoder layers and final norm
                del model.lm_head
                del model.model.embed_tokens
                del model.model.layers
                del model.model.norm

            # build the vision tower with `from_pretrained` patched by `init_empty_vit`
            with init_empty_vit():
                vision_tower = model.get_vision_tower()
                vision_tower.is_loaded = False
                vision_tower.load_model()
            crop_size = vision_tower.image_processor.crop_size['height']
            image_size = vision_tower.config.image_size
            patch_size = vision_tower.config.patch_size
            # if the processor's crop size differs from the tower's image size,
            # resize the position embeddings and align the tower config and the
            # image processor with the crop size
            if crop_size != image_size:
                vision_tower.vision_tower.resize_pos_embeddings(image_size, crop_size, patch_size)
                vision_tower.vision_tower.embeddings.image_size = crop_size
                vision_tower.config.image_size = crop_size
                vision_tower.image_processor.crop_size = dict(height=crop_size, width=crop_size)
                vision_tower.image_processor.size = dict(shortest_edge=crop_size)

        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            # vision-only models use device_map='auto'; with the LLM included
            # everything is placed on CPU
            load_checkpoint_and_dispatch(model=model,
                                         max_memory=self.max_memory,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=['InternVisionEncoderLayer'],
                                         dtype=torch.half)

        self.model = model.model.eval()
        self.vision_tower = model.model.vision_tower.eval()
        self.mm_projector = model.model.mm_projector.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess()` for spec."""
        return super().preprocess(messages)

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            # leading-dim size of each input, used to split the features back
            split_sizes = [x.shape[0] for x in pixel_values]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            if pixel_values.ndim == 5:
                # NOTE(review): presumably multi-crop inputs; features are split per
                # input and their first two dims merged. Confirm the expected layout.
                feats = self.encode_images(pixel_values)
                feats = torch.split(feats, split_sizes, dim=0)
                feats = [x.flatten(0, 1) for x in feats]
            else:
                # iterate along dim 0 (assumes `encode_images` returns a tensor)
                feats = self.encode_images(pixel_values)
                feats = [x for x in feats]
            outputs.extend(feats)
        messages.append(dict(role='forward', content=outputs))
        return messages


================================================
FILE: lmdeploy/vl/model/llama4.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import torch
from transformers import AutoConfig

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


def check_trans_version():
    """Ensure the installed `transformers` is new enough for llama4.

    Raises:
        AssertionError: if the installed `transformers` version is older
            than 4.51.0.
    """
    import transformers
    from packaging import version

    min_version = '4.51.0'
    installed_version = transformers.__version__
    # Explicit raise instead of `assert`: an assert statement is stripped under
    # `python -O`, which would silently skip this check. AssertionError is kept
    # so existing callers see the same exception type.
    if version.parse(installed_version) < version.parse(min_version):
        raise AssertionError(f'llama4 requires transformers version >= {min_version}, '
                             f'but found version: {installed_version}. Please upgrade.')


@VISION_MODELS.register_module()
class LLama4VisionModel(VisionModel):
    """Llama4 vision model.

    Only preprocessing is implemented: pixel values plus the textual image
    prompt that is spliced into the token stream. `build_model` and `forward`
    raise NotImplementedError (TODO: turbomind engine).
    """

    _arch = 'Llama4ForConditionalGeneration'

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model.

        A config without `architectures` never matches; it no longer raises
        while indexing `None`.
        """
        arch = config.architectures[0] if config.architectures else None
        return arch == cls._arch

    def build_preprocessor(self):
        """Load the Llama4 processor and precompute the image kwargs."""
        check_trans_version()
        from transformers.models.llama4 import Llama4Processor
        from transformers.models.llama4.processing_llama4 import Llama4ProcessorKwargs
        self.processor = Llama4Processor.from_pretrained(
            self.model_path,
            padding_side='left',
        )
        img_patch_token = self.processor.img_patch_token
        # id of the patch token; checked against each image in `to_pytorch_aux`
        self.image_token_id = self.processor.tokenizer.encode(img_patch_token, add_special_tokens=False)[0]
        self.images_kwargs = self.processor._merge_kwargs(
            Llama4ProcessorKwargs,
            tokenizer_init_kwargs=self.processor.tokenizer.init_kwargs,
            return_tensors='pt',
            add_special_tokens=False,
        )['images_kwargs']

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess()` for spec.

        For every collected image, runs the Llama4 image processor and stores a
        dict with `pixel_values`, `image_tokens`, `image_token_id`,
        `image_size` and `image_prompts`, the textual image prompt consumed
        by `to_pytorch_aux`.
        """
        images = self.collect_multimodal_items(messages)
        outputs = []
        processor = self.processor
        patch_size = processor.patch_size
        downsample_ratio = processor.downsample_ratio
        images_kwargs = self.images_kwargs
        for _, image, _ in images:
            image_inputs = processor.image_processor(images=[image], **images_kwargs)
            pixel_values = image_inputs['pixel_values']
            image_height, image_width = pixel_values[0].shape[-2:]
            num_patches_per_chunk = int((image_height // patch_size) * (image_width // patch_size) // downsample_ratio)
            aspect_ratios = image_inputs.pop('aspect_ratios')
            image_prompts = processor._prompt_split_image(aspect_ratios[0], num_patches_per_chunk)
            # number of '<|'-prefixed special tokens in the image prompt minus 2;
            # presumably the two excluded are the image start/end markers. Confirm.
            image_tokens = image_prompts.count('<|') - 2
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_tokens=image_tokens,
                     image_token_id=self.image_token_id,
                     image_size=image.size,
                     image_prompts=image_prompts))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        # TODO, implement for turbomind engine
        raise NotImplementedError()

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Text-only messages pass through unchanged, and the internal
        'preprocess'/'forward' messages are skipped. For a multimodal message
        the first text item is the prompt; if it does not already contain the
        image placeholder, one placeholder per image is prepended.

        Returns:
            tuple: (templated prompt, image placeholder string)
        """
        prompt_messages = []
        # The placeholder must be a non-empty marker: `to_pytorch_aux` splits the
        # prompt on it, and `str.split('')` raises ValueError. Also, `'' in s` is
        # always True, so an empty marker would never be inserted.
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            if message['role'] in ['preprocess', 'forward']:
                continue
            n_images = sum(1 for x in message['content'] if x['type'] == 'image')
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            # image-only messages have no text item; treat them as empty text
            prompt = content[0] if content else ''
            if IMAGE_TOKEN not in prompt:
                prompt = f'{IMAGE_TOKEN * n_images}' + prompt
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
        """Auxiliary function to pack the preprocessing results in a format
        compatible with what is required by pytorch engine.

        Args:
            messages(List[Dict]): the output of `preprocess`
            prompt(str): the prompt after applying chat template
            IMAGE_TOKEN(str): a placeholder where image tokens will be
                inserted
            tokenizer: the tokenizer model
            sequence_start: starting flag of a sequence
        Returns:
            dict: `prompt`, the token ids `input_ids`, and `multimodal` (the
            per-image preprocessing results, each annotated with `offset`)
        """
        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        # split prompt into segments and validate data
        segs = prompt.split(IMAGE_TOKEN)
        assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal '
                                             f'to input images, {len(segs) - 1} vs {len(preps)}')

        # calculate the image token offset for each image
        input_ids = []
        for i, seg in enumerate(segs):
            if 0 < i <= len(preps):
                prep = preps[i - 1]
                # the image prompt is text: take it out of the multimodal payload
                # and prepend it to the following text segment
                image_prompts = prep.pop('image_prompts', '')
                # +1 presumably skips the image-start marker token. Confirm.
                prep.update(offset=len(input_ids) + 1)
                assert self.image_token_id == prep['image_token_id']
                seg = image_prompts + seg
            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(token_ids)
        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build pytorch-engine inputs; extra `kwargs` are ignored."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build turbomind-engine inputs; extra `kwargs` are ignored."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/llava.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/haotian-liu/LLaVA.git

import ast
import math
import warnings
from contextlib import contextmanager
from typing import Dict, List

import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging, rewrite_ctx

logger = get_logger('lmdeploy')


def check_llava_install():
    """Check llava install.

    Raises:
        ImportError: with installation instructions when the `llava`
            package cannot be imported.
    """
    try:
        import llava  # noqa: F401
    except ImportError:
        install_hint = ('To use LlavaVLModel, please install llava by '
                        '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`')
        raise ImportError(install_hint)


def _clip_vision_tower_load_model(self, **kwargs):
    """Replacement for llava's `CLIPVisionTower.load_model`.

    Builds the CLIP vision model from its config only (`_from_config`) rather
    than from pretrained weights, freezes it and marks the tower as loaded.
    `kwargs` are accepted for signature compatibility and are not used.
    """
    logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')
    from transformers import CLIPVisionConfig, CLIPVisionModel

    clip_config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
    tower = CLIPVisionModel._from_config(config=clip_config)
    tower.requires_grad_(False)
    self.vision_tower = tower
    self.is_loaded = True


@contextmanager
def init_llava_vision_tower(config):
    """Skip download vision model if possible.

    When `config.unfreeze_mm_vision_tower` is truthy, llava's
    `CLIPVisionTower.load_model` is replaced by
    `_clip_vision_tower_load_model` while the context is active. Otherwise
    the context manager does nothing.
    """
    if not getattr(config, 'unfreeze_mm_vision_tower', False):
        yield
        return
    patched_targets = ['llava.model.multimodal_encoder.clip_encoder.CLIPVisionTower.load_model']  # noqa: E501
    with rewrite_ctx(patched_targets, [_clip_vision_tower_load_model]):
        yield


def select_best_resolution(original_size, possible_resolutions):
    """Pick the candidate resolution that best fits an image.

    Each candidate (width, height) is scored by the number of image pixels kept
    after scaling the image to fit inside it, capped at the original pixel
    count. Ties are broken by the fewest unused canvas pixels. Among exact
    ties the earlier candidate wins.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height), or None when `possible_resolutions` is empty.
    """  # noqa
    orig_w, orig_h = original_size
    orig_area = orig_w * orig_h
    best = None
    best_effective = 0
    best_wasted = float('inf')

    for cand_w, cand_h in possible_resolutions:
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        effective = min(int(orig_w * ratio) * int(orig_h * ratio), orig_area)
        wasted = cand_w * cand_h - effective

        strictly_better = effective > best_effective
        same_but_tighter = effective == best_effective and wasted < best_wasted
        if strictly_better or same_but_tighter:
            best, best_effective, best_wasted = (cand_w, cand_h), effective, wasted

    return best


def resize_and_pad_image(image, target_resolution):
    """Fit an image inside a target canvas, keeping its aspect ratio.

    The image is scaled along its limiting axis (the other side is rounded up
    but never exceeds the target), then centered on a black RGB canvas of the
    target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """  # noqa
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h
    if ratio_w < ratio_h:
        fit_w, fit_h = dst_w, min(math.ceil(src_h * ratio_w), dst_h)
    else:
        fit_w, fit_h = min(math.ceil(src_w * ratio_h), dst_w), dst_h

    fitted = image.resize((fit_w, fit_h))

    canvas = Image.new('RGB', (dst_w, dst_h), (0, 0, 0))
    canvas.paste(fitted, ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2))
    return canvas


def divide_to_patches(image, patch_size):
    """Cut an image into square tiles of side `patch_size`.

    Tiles are produced row by row (top to bottom), left to right within a row.
    Tiles on the right/bottom edge may extend past the image border when its
    size is not a multiple of `patch_size`.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]


def process_anyres_image(image, processor, grid_pinpoints):
    """Process an image with variable resolutions.

    The image is resized/padded to the best-fitting candidate resolution and
    cut into `processor.crop_size['height']`-sized patches. A copy of the
    whole image, resized to `processor.size['shortest_edge']` squared, is put
    in front of the patches.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str | list | tuple): The possible (width, height) resolutions, either as a sequence or as its string representation.

    Returns:
        torch.Tensor: A tensor containing the processed image patches, stacked along dim 0.
    """  # noqa
    # Accept any list/tuple (including subclasses) as-is; only other values,
    # e.g. a string such as '[(336, 672), (672, 336)]', are parsed.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] for image_patch in image_patches
    ]
    return torch.stack(image_patches, dim=0)


def expand2square(pil_img, background_color):
    """Pad an image to a square canvas filled with `background_color`.

    A square image is returned unchanged (the same object). Otherwise the
    canvas side is the longer edge and the image is centered along the
    shorter one.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    square = Image.new(pil_img.mode, (side, side), background_color)
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    square.paste(pil_img, offset)
    return square


def process_images(images, image_processor, model_cfg):
    """Convert images to pixel-value tensors following the llava config.

    Args:
        images (list): input images.
        image_processor: the image processor.
        model_cfg: model config. Its `image_aspect_ratio` selects the mode:
            'pad' pads every image to a square (filled with
            `image_processor.image_mean` scaled by 255) before preprocessing,
            'anyres' uses `process_anyres_image`, and anything else runs
            `image_processor` on the whole list at once.

    Returns:
        In 'pad'/'anyres' mode: a stacked tensor when all per-image results
        share one shape, otherwise the list of per-image tensors (an empty
        list for empty input). Otherwise: `image_processor`'s `pixel_values`.
    """
    image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
    if image_aspect_ratio == 'pad':
        new_images = []
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            new_images.append(image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0])
    elif image_aspect_ratio == 'anyres':
        new_images = [process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) for image in images]
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    # `all()` is True for an empty list, but `torch.stack` rejects an empty
    # sequence, so only stack when there is something to stack.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


@VISION_MODELS.register_module()
class LlavaVisionModel(LlavaHfVisionModel):
    """Llava visual model.

    Serves ``LlavaLlamaForCausalLM`` / ``LlavaMistralForCausalLM`` checkpoints
    that are built with the ``llava`` python package (see
    ``check_llava_install`` in ``build_model``).
    """

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model.

        InternVL-llava (vision tower from ``OpenGVLab``) and Yi-VL (projector
        type containing ``_Norm``) share these architectures and are rejected
        here.
        """
        arch = config.architectures[0] if config.architectures else None
        if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:
            # internvl-llava has vision_tower of OpenGVLab/xxx
            mm_vision_tower = getattr(config, 'mm_vision_tower', '')
            # yi-vl has projector type of xxx_Norm
            projector_type = getattr(config, 'mm_projector_type', 'linear')
            if '_Norm' in projector_type:
                return False
            if 'OpenGVLab' in mm_vision_tower:
                return False
            return True
        return False

    def build_preprocessor(self):
        """Build the CLIP image processor of the vision tower and compute the
        number of image tokens produced per image."""
        from transformers import CLIPImageProcessor
        self.image_processor = CLIPImageProcessor.from_pretrained(self.hf_config.mm_vision_tower)
        config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower)
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.n_token_per_image = (image_size // patch_size)**2
        if self.hf_config.mm_vision_select_feature == 'cls_patch':
            # the CLS token is kept in addition to the patch tokens
            self.n_token_per_image += 1

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        check_llava_install()

        self.arch = self.hf_config.architectures[0]
        model = None
        if self.arch == 'LlavaLlamaForCausalLM':
            from llava.model.language_model.llava_llama import LlavaConfig
            self.config = LlavaConfig.from_pretrained(self.model_path)
            assert self.config.model_type in ['llava', 'llava_llama'], \
                f'expect model_type llava and llava_llama '\
                f'but got {self.config.model_type}'
        elif self.arch == 'LlavaMistralForCausalLM':
            from llava.model.language_model.llava_mistral import LlavaMistralConfig
            self.config = LlavaMistralConfig.from_pretrained(self.model_path)
        else:
            assert 0, f'unsupported arch {self.arch}'

        from accelerate import init_empty_weights

        # init empty model, skip layer initialization
        with init_empty_weights(), warnings.catch_warnings(), \
                init_llava_vision_tower(self.config):
            warnings.simplefilter('ignore')
            self.config.quantization_config = {}  # disable vision part quantization
            model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)

        self.vl_model = model
        if not self.with_llm:
            # remove the LLM part from llava model.
            del model.lm_head
            del model.model.embed_tokens
            del model.model.layers
            del model.model.norm

        # init empty vision_tower, the embedding layer in CLIPVisionModel
        # can't init right under init_empty_weights
        with init_llava_vision_tower(self.config):
            vision_tower = model.get_vision_tower()
            vision_tower.is_loaded = False
            vision_tower.load_model()
            # for llava-v1.5, the vit is not in llm ckpt
            vision_tower.to(dtype=torch.half)

        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         max_memory=self.max_memory,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=['CLIPEncoderLayer'],
                                         dtype=torch.half)

        self.model = model.model.eval()
        self.vision_tower = model.model.vision_tower.half().eval()
        self.mm_projector = model.model.mm_projector.half().eval()

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images with the vision tower, then project the features
        with the multimodal projector."""
        image_features = self.vision_tower(images)
        image_features = self.mm_projector(image_features)
        return image_features

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess() for spec."""
        images = self.collect_multimodal_items(messages)
        outputs = []
        for modality, image, params in images:
            image = image.convert('RGB')
            pixel_values = process_images([image], self.image_processor, self.config)
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_size=image.size,
                     image_tokens=self.n_token_per_image,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        Raises:
            NotImplementedError: spatial patch merging combined with an
                ``image_aspect_ratio`` other than 'anyres'.
            ValueError: an unknown ``mm_patch_merge_type``.
        """

        from llava.model.llava_arch import get_anyres_image_grid_shape, unpad_image

        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            image_sizes = [x['image_size'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            if pixel_values[0].ndim == 5:
                # 'anyres' inputs (see `process_images`): each item holds all
                # patches of one image along dim 1; features are split back
                # per image after encoding
                split_sizes = [x.shape[1] for x in pixel_values]
                pixel_values = torch.cat(pixel_values, dim=1)
                logger.info(f'vision forward shape: {pixel_values.shape}')
                pixel_values = pixel_values.squeeze(0)
                pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)
                feats = self.encode_images(pixel_values)
                feats = torch.split(feats, split_sizes, dim=0)
                mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
                image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
                if mm_patch_merge_type == 'flat':
                    # one flattened feature per image; lists have no `expand`
                    # method, so `extend` is required here
                    outputs.extend([x.flatten(0, 1) for x in feats])
                elif mm_patch_merge_type.startswith('spatial'):
                    for img_idx, feat in enumerate(feats):
                        if feat.shape[0] > 1:
                            # first patch is the resized whole image, the rest
                            # are the grid crops
                            base_feat = feat[0]
                            feat = feat[1:]
                            height = self.vision_tower.num_patches_per_side
                            width = self.vision_tower.num_patches_per_side
                            assert height * width == base_feat.shape[0]
                            if image_aspect_ratio == 'anyres':
                                num_patch_width, num_patch_height = \
                                    get_anyres_image_grid_shape(
                                        image_sizes[img_idx],
                                        self.config.image_grid_pinpoints,
                                        self.vision_tower.config.image_size)
                                feat = feat.view(num_patch_height, num_patch_width, height, width, -1)
                            else:
                                raise NotImplementedError
                            if 'unpad' in mm_patch_merge_type:
                                feat = feat.permute(4, 0, 2, 1, 3).contiguous()
                                feat = feat.flatten(1, 2).flatten(2, 3)
                                feat = unpad_image(feat, image_sizes[img_idx])
                                # append an `image_newline` column to each row
                                feat = torch.cat((feat, self.model.image_newline[:, None, None].expand(
                                    *feat.shape[:-1], 1).to(feat.device)),
                                                 dim=-1)
                                feat = feat.flatten(1, 2).transpose(0, 1)
                            else:
                                feat = feat.permute(0, 2, 1, 3, 4).contiguous()
                                feat = feat.flatten(0, 3)
                            feat = torch.cat((base_feat, feat), dim=0)
                        else:
                            feat = feat[0]
                            if 'unpad' in mm_patch_merge_type:
                                feat = torch.cat((feat, self.model.image_newline[None].to(feat.device)), dim=0)
                        outputs.append(feat)
                else:
                    raise ValueError('Unexpected mm_patch_merge_type: '
                                     f'{mm_patch_merge_type}')
            else:
                pixel_values = torch.cat(pixel_values, dim=0)
                pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)
                logger.info(f'vision forward shape: {pixel_values.shape}')
                feats = self.encode_images(pixel_values)
                outputs.extend([x for x in feats])
        messages.append(dict(role='forward', content=outputs))
        return messages


================================================
FILE: lmdeploy/vl/model/llava_hf.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List

import torch
from transformers import AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class LlavaHfVisionModel(VisionModel):
    """Llava hf vision model."""

    _arch = 'LlavaForConditionalGeneration'

    def build_preprocessor(self):
        """Load the HF processor, keep only its image processor and compute
        the number of image tokens produced per image."""
        processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        if hasattr(processor, 'tokenizer'):
            # only the image processor is needed; drop the tokenizer
            del processor.tokenizer
            processor.tokenizer = None
        self.processor = processor.image_processor
        image_size = self.hf_config.vision_config.image_size
        patch_size = self.hf_config.vision_config.patch_size
        self.n_token_per_image = (image_size // patch_size)**2
        if self.hf_config.vision_feature_select_strategy == 'full':
            # 'full' keeps the CLS token on top of the patch tokens
            self.n_token_per_image += 1

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights, load_checkpoint_and_dispatch

        with init_empty_weights(), warnings.catch_warnings():
            warnings.simplefilter('ignore')
            from transformers import LlavaForConditionalGeneration
            model = LlavaForConditionalGeneration._from_config(self.hf_config)
            self.vl_model = model
            if not self.with_llm:
                del model.language_model

        # fix for llava-hf/llava-interleave-qwen-7b-hf
        model.config.tie_word_embeddings = False
        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         max_memory=self.max_memory,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=['CLIPEncoderLayer', 'SiglipEncoderLayer'],
                                         dtype=torch.half)
        model.eval()
        self.model = model

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess() for spec."""
        images = self.collect_multimodal_items(messages)
        outputs = []
        for modality, image, params in images:
            image = image.convert('RGB')
            pixel_values = self.processor(image, return_tensors='pt', input_data_format='channels_last').pixel_values
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_size=image.size,
                     image_tokens=self.n_token_per_image,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        Raises:
            ValueError: an unknown ``vision_feature_select_strategy``.
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = torch.cat(pixel_values, dim=0)
            pixel_values = pixel_values.to(device=self.model.device, dtype=self.model.dtype)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            image_outputs = self.model.vision_tower.forward(pixel_values, output_hidden_states=True)
            image_features = image_outputs.hidden_states[self.hf_config.vision_feature_layer]
            if self.hf_config.vision_feature_select_strategy == 'default':
                # drop the first (CLS) position
                image_features = image_features[:, 1:]
            elif self.hf_config.vision_feature_select_strategy == 'full':
                image_features = image_features
            else:
                raise ValueError('Unexpected select feature strategy: '
                                 f'{self.hf_config.vision_feature_select_strategy}')
            image_features = self.model.multi_modal_projector(image_features)
            image_features = torch.split(image_features, 1, dim=0)
            # squeeze only the split (batch) dim so other size-1 dims survive
            outputs.extend([x.squeeze(0) for x in image_features])
        messages.append(dict(role='forward', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        Each user message with multimodal content is rewritten as one
        ``IMAGE_TOKEN`` line per image followed by its first text item (or an
        empty string when the message has no text item).

        Returns:
            (prompt, IMAGE_TOKEN): the templated prompt and the placeholder
            string that marks image positions in it.
        """
        prompt_messages = []
        # the placeholder must be non-empty: an empty string cannot mark
        # image positions (splitting a prompt on '' raises ValueError)
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            content = [item['text'] for item in message['content'] if item['type'] == 'text']
            text = content[0] if content else ''
            prompt = (IMAGE_TOKEN + '\n') * n_images + text
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the pytorch-engine inputs from the templated prompt."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build the turbomind-engine inputs from the templated prompt."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/llava_next.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings
from typing import Dict, List

import torch

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class LlavaNextVisionModel(LlavaHfVisionModel):
    """Llava hf vision model."""

    _arch = 'LlavaNextForConditionalGeneration'

    def build_preprocessor(self):
        """Build the image processor and an empty-weight model that
        `preprocess` uses to compute image token counts."""
        super().build_preprocessor()
        # build the model with empty weights. The model will be used in
        # `preprocess` to get the image token number
        from accelerate import init_empty_weights
        with init_empty_weights(), warnings.catch_warnings():
            warnings.simplefilter('ignore')
            from transformers import LlavaNextForConditionalGeneration
            self.model = LlavaNextForConditionalGeneration._from_config(self.hf_config)
            self.vl_model = self.model
            if not self.with_llm:
                del self.model.language_model

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import load_checkpoint_and_dispatch
        from accelerate.utils import get_balanced_memory, infer_auto_device_map

        no_split_module_classes = ['CLIPEncoderLayer']
        max_memory = get_balanced_memory(self.model,
                                         max_memory=self.max_memory,
                                         dtype=torch.half,
                                         no_split_module_classes=no_split_module_classes)
        device_map = infer_auto_device_map(self.model,
                                           no_split_module_classes=no_split_module_classes,
                                           max_memory=max_memory,
                                           dtype=torch.half)

        # put every listed module on the device of the first listed module
        # found in the device map
        same_device_keys = [('multi_modal_projector', 'image_newline')]
        for keys in same_device_keys:
            keys = [k for k in keys if k in device_map]
            if len(keys) <= 1:
                continue
            for k in keys[1:]:
                device_map[k] = device_map[keys[0]]

        with disable_logging():
            load_checkpoint_and_dispatch(model=self.model,
                                         checkpoint=self.model_path,
                                         device_map=device_map if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=no_split_module_classes,
                                         dtype=torch.half)
        self.model.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to the spec of `super.preprocess()"""
        from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
        images = self.collect_multimodal_items(messages)
        outputs = []
        for modality, image, params in images:
            image = image.convert('RGB')
            result = self.processor(image, return_tensors='pt', input_data_format='channels_last')
            # ! infer image_num_patches from image_sizes
            image_num_patches = [
                image_size_to_num_patches(
                    image_size=imsize,
                    grid_pinpoints=self.hf_config.image_grid_pinpoints,
                    patch_size=self.hf_config.vision_config.image_size,
                ) for imsize in result['image_sizes']
            ]

            # run `pack_image_features` on zero features only to learn the
            # number of tokens this image expands to
            hidden_size = self.hf_config.text_config.hidden_size
            fake_image_features = torch.zeros([image_num_patches[0], self.n_token_per_image, hidden_size])
            image_sizes = result['image_sizes']
            image_newline = torch.randn(self.hf_config.text_config.hidden_size)
            strategy = self.hf_config.vision_feature_select_strategy
            _, image_tokens = self.model.pack_image_features([fake_image_features],
                                                             image_sizes,
                                                             vision_feature_select_strategy=strategy,
                                                             image_newline=image_newline)
            result.update(
                dict(image_size=image.size,
                     image_patches=image_num_patches,
                     image_tokens=image_tokens,
                     image_token_id=self.image_token_id))
            outputs.append(result)
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        Raises:
            ValueError: ``pixel_values`` that are neither 4-D nor 5-D, or an
                unknown ``vision_feature_select_strategy``.
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            batch = inputs[idx:idx + max_batch_size]
            pixel_values = [x['pixel_values'].to(device=self.model.device, dtype=self.model.dtype) for x in batch]
            pixel_values = torch.cat(pixel_values, dim=0)
            # keep the processor's dtype for `image_sizes`: casting to the
            # fp16 model dtype rounds (odd) sizes above 2048 and can make the
            # packed feature length differ from the count made in `preprocess`
            image_sizes = [x['image_sizes'].to(device=self.model.device) for x in batch]
            image_sizes = torch.cat(image_sizes, dim=0)
            # `preprocess` stores the per-image patch counts under `image_patches`
            image_num_patches = [x['image_patches'] for x in batch]
            image_num_patches = list(itertools.chain(*image_num_patches))
            # figure out if pixel_values is concatenated or stacked
            if pixel_values.dim() == 5:
                # stacking when input is
                # (batch_size, num_patches, num_channels, height, width)
                _pixel_values_list = [
                    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
                ]
                pixel_values = torch.cat(_pixel_values_list, dim=0)
            elif pixel_values.dim() != 4:
                # otherwise has to be stacked from list of
                # (num_patches, num_channels, height, width)
                raise ValueError(f'pixel_values of shape {pixel_values.shape}, '
                                 'expect to be of 4 or 5 dimensions')
            logger.info(f'vision forward shape: {pixel_values.shape}')
            image_outputs = self.model.vision_tower.forward(pixel_values, output_hidden_states=True)
            image_features = image_outputs.hidden_states[self.hf_config.vision_feature_layer]
            strategy = self.hf_config.vision_feature_select_strategy
            if strategy == 'default':
                # drop the first (CLS) position
                image_features = image_features[:, 1:]
            elif strategy == 'full':
                image_features = image_features
            else:
                raise ValueError('Unexpected select feature strategy: '
                                 f'{strategy}')
            image_features = self.model.multi_modal_projector(image_features)
            image_features = torch.split(image_features, image_num_patches, dim=0)
            image_features, feature_lens = self.model.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=strategy,
                image_newline=self.model.image_newline,
            )
            image_features = torch.split(image_features, feature_lens.cpu().numpy().tolist(), dim=0)
            outputs.extend(image_features)
        messages.append(dict(role='forward', content=outputs))
        return messages


================================================
FILE: lmdeploy/vl/model/minicpmv.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class MiniCPMVModel(VisionModel):
    """MiniCPMV vision model."""

    _arch = 'MiniCPMV'

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        """Initialise the base vision model and validate the MiniCPM-V
        version declared in the HF config.

        Raises:
            ValueError: the config has no ``version`` attribute, or its
                version is neither '2.5' nor '2.6'.
        """
        super().__init__(model_path, with_llm, max_memory, hf_config, backend)
        if not hasattr(self.hf_config, 'version'):
            raise ValueError('Can not find `version` in config.json. '
                             'Please checkout the latest model')
        ver = str(self.hf_config.version)
        if ver in ('2.5', '2.6'):
            self.version = ver
            return
        raise ValueError(f'Only support v2.5 and v2.6, but got version {ver}')

    def build_preprocessor(self):
        """Load the HF processor and pick the preprocessing routine that
        matches ``self.version``."""
        from transformers import AutoProcessor
        processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        self.processor = processor
        self.image_processor = processor.image_processor
        if self.version == '2.5':
            self._preprocess_func = self._preprocess_v2_5
        else:
            self._preprocess_func = self._preprocess_v2_6

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights, load_checkpoint_and_dispatch
        config = self.hf_config
        # instantiate with empty weights; real weights are loaded below
        with init_empty_weights(), warnings.catch_warnings():
            warnings.simplefilter('ignore')
            assert config.slice_mode is True, 'only support slice mode'
            config.quantization_config = {}  # disable vision part quantization
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        self.vl_model = model
        if not self.with_llm:
            # vision-only build: the `llm` submodule is dropped
            del model.llm

        device_map = {'': 'cpu'} if self.with_llm else 'auto'
        with disable_logging():
            load_checkpoint_and_dispatch(
                model=model,
                max_memory=self.max_memory,
                checkpoint=self.model_path,
                device_map=device_map,
                no_split_module_classes=['Idefics2EncoderLayer', 'Resampler', 'SiglipEncoderLayer'],
                dtype=torch.half)

        # move `pos_embed` next to the resampler projection; presumably it is
        # not placed by the dispatch above — TODO confirm
        resampler = model.resampler
        resampler.pos_embed = resampler.pos_embed.to(device=resampler.proj.device)
        self.config = config
        self.model = model.eval()

    def _get_slice_image(self, image: Image):
        """Slice ``image`` with the image processor.

        Returns:
            (slices, best_grid): ``slices`` is the source image followed by
            every patch in row-major order (the column count is taken from
            the first row); ``best_grid`` comes straight from
            ``image_processor.slice_image``.
        """
        source_image, patches, best_grid = self.image_processor.slice_image(image)
        slices = [source_image]
        n_rows = len(patches)
        if n_rows:
            n_cols = len(patches[0])
            slices.extend(patches[r][c] for r in range(n_rows) for c in range(n_cols))
        return slices, best_grid

    def _reshape_by_patch(self, slice_images):
        """Transform each slice and reshape it into patches.

        Returns:
            (patches, tgt_sizes): the tensors produced by
            ``image_processor.reshape_by_patch`` for each slice, and for each
            slice an int32 tensor ``[H // patch_size, W // patch_size]`` built
            from the transformed slice's last two dims.
        """
        patches, tgt_sizes = [], []
        for img in slice_images:
            tensor = self.model.transform(img)
            height, width = tensor.shape[1:]
            reshaped = self.image_processor.reshape_by_patch(tensor.numpy())
            patches.append(torch.from_numpy(reshaped))
            grid = [height // self.config.patch_size, width // self.config.patch_size]
            tgt_sizes.append(torch.Tensor(grid).type(torch.int32))
        return patches, tgt_sizes

    def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict:
        """Image preprocessing for MiniCPM-Llama3-V-2_5.

        ``params`` is accepted for interface parity with v2.6 and is unused.
        """
        slices, grid = self._get_slice_image(image)
        # both are lists of torch tensors, one entry per slice
        pixel_values, tgt_sizes = self._reshape_by_patch(slices)
        return {
            'pixel_values': pixel_values,
            'tgt_sizes': tgt_sizes,
            'best_grid': grid,
            'num_patches': len(pixel_values),
            'image_tokens': 1,
            'image_token_id': self.image_token_id,
        }

    def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict:
        """Image preprocessing for MiniCPM-V-2_6.

        Args:
            image: the image to preprocess.
            params: optional per-image overrides. ``max_slice_nums`` and
                ``use_image_id`` fall back to the image processor's values.
                ``None`` (the default) means no overrides.

        Returns:
            dict with per-slice ``pixel_values`` / ``tgt_sizes`` lists, the
            slicing grid, the slice count and image-token metadata.
        """
        # the signature allows params=None; treat it as "no overrides"
        # instead of failing on `None.get`
        params = params or {}
        max_slice_nums = params.get('max_slice_nums', self.image_processor.max_slice_nums)
        use_image_id = params.get('use_image_id', self.image_processor.use_image_id)
        outputs = self.image_processor(image, max_slice_nums=max_slice_nums)
        pixel_values = outputs['pixel_values'][0]
        num_patches = len(pixel_values)
        pixel_values = [torch.as_tensor(x) for x in pixel_values]
        tgt_sizes = [torch.as_tensor(x) for x in outputs['tgt_sizes'][0]]
        grid = self.image_processor.get_sliced_grid(image_size=image.size, max_slice_nums=max_slice_nums)
        return dict(
            pixel_values=pixel_values,  # a list
            tgt_sizes=tgt_sizes,  # a list
            best_grid=grid,
            num_patches=num_patches,
            image_tokens=1,
            image_token_id=self.image_token_id,
            use_image_id=use_image_id)

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess() for spec.

        Every user message whose content is a list gets a ``preprocess``
        entry holding the preprocessing results of that message's images
        only (an empty list when it has none).
        """
        for i, message in enumerate(messages):
            if message['role'] != 'user' or not isinstance(message['content'], List):
                continue
            # a fresh list per message: `forward` concatenates the
            # `preprocess` entries of all messages, so one list shared across
            # messages would make each image appear several times
            outputs = []
            for item in message['content']:
                if item['type'] == 'image':
                    image = item['image'].convert('RGB')
                    # any remaining keys of the item are per-image options
                    params = {k: v for k, v in item.items() if k not in {'type', 'image'}}
                    outputs.append(self._preprocess_func(image, params))
            messages[i].update(dict(preprocess=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        # collect the per-message preprocess results and flatten them into a
        # single list holding one dict per image
        inputs = [x['preprocess'] for x in messages if 'preprocess' in x.keys()]
        inputs = list(itertools.chain(*inputs))
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            batch = inputs[idx:idx + max_batch_size]
            tgt_sizes = [x['tgt_sizes'] for x in batch]
            pixel_values = [x['pixel_values'] for x in batch]
            num_patches = [x['num_patches'] for x in batch]
            # flatten the per-image lists into one list of slices
            tgt_sizes = list(itertools.chain(*tgt_sizes))
            pixel_values = list(itertools.chain(*pixel_values))
            pixel_values = [x.to(dtype=torch.half, device=self.model.device) for x in pixel_values]
            pixel_values = [x.flatten(end_dim=1).permute(1, 0) for x in pixel_values]
            # zero-pad slices to a common length so they can be batched
            pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0.0)
            B, L, _ = pixel_values.shape
            pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=self.model.device)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            if self.version == '2.5':
                # NOTE(review): this slice indexes dim 1 (size 1), so the whole
                # row j becomes True whenever the product is positive; kept
                # as-is, presumably to mirror the v2.5 reference code — confirm
                for j in range(B):
                    patch_attn_mask[j, :tgt_sizes[j][0] * tgt_sizes[j][1]] = True
                embeddings = self.model.vpm(pixel_values.type(torch.half),
                                            patch_attention_mask=patch_attn_mask).last_hidden_state
            else:
                for j in range(B):
                    patch_attn_mask[j, 0, :tgt_sizes[j][0] * tgt_sizes[j][1]] = True
                embeddings = self.model.vpm(pixel_values.type(torch.half),
                                            patch_attention_mask=patch_attn_mask,
                                            tgt_sizes=tgt_sizes).last_hidden_state

            embeddings = self.model.resampler(embeddings, tgt_sizes)
            # the split also checks that the slice count matches `num_patches`
            embeddings = torch.split(embeddings, num_patches, 0)
            for embedding in embeddings:
                embedding = embedding.split(1, dim=0)
                # squeeze only the split dim so other size-1 dims survive
                outputs.extend([x.squeeze(0) for x in embedding])
        messages.append(dict(role='forward', content=outputs))
        return messages

    def proc_messages(self, messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        String-content messages pass through unchanged. List-content messages
        without a ``preprocess`` entry are dropped. For the others, one
        placeholder string is built per preprocess result (optionally prefixed
        with a running image id, and followed by a grid of slice placeholders
        when ``best_grid`` is set). The message's first text item is appended
        after all placeholders.

        Returns:
            tuple: (prompt from ``chat_template.messages2prompt``, IMAGE_TOKEN)
        """
        prompt_messages = []
        # NOTE(review): IMAGE_TOKEN is the empty string here, so the
        # placeholders built below contain no token text at all (only ids and
        # newlines). Confirm the intended placeholder literal.
        IMAGE_TOKEN = ''
        # running image id, shared across all messages
        idx = 0
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            if 'preprocess' not in message.keys():
                continue
            prompts = []
            for x in message['preprocess']:
                prompt = f'{IMAGE_TOKEN}'
                if x.get('use_image_id', False):
                    prompt = f'{idx}' + prompt
                    idx += 1
                grid = x['best_grid']
                if grid is not None:
                    # grid[1] lines of grid[0] placeholders each
                    if self.version == '2.5':
                        slice = '\n'.join([f'{IMAGE_TOKEN}' * grid[0]] * grid[1])
                        prompt = f'{prompt}{slice}\n'
                    elif self.version == '2.6':
                        slice = '\n'.join([f'{IMAGE_TOKEN}' * grid[0]] * grid[1])
                        prompt = prompt + slice
                        prompt += '\n'
                else:
                    prompt = (prompt + '\n' if self.version == '2.6' else prompt)
                prompts.append(prompt)
            # NOTE(review): raises IndexError if the message has no text item
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            prompt = ''.join(prompts) + content[0]
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build pytorch-engine inputs.

        The prompt is rendered by ``proc_messages`` and then handed, together
        with its image placeholder, to ``to_pytorch_aux``. Extra keyword
        arguments are accepted for interface compatibility and ignored.
        """
        rendered = self.proc_messages(messages, chat_template, sequence_start)
        prompt_text, placeholder = rendered
        return self.to_pytorch_aux(messages, prompt_text, placeholder, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Build turbomind-engine inputs.

        The prompt is rendered by ``proc_messages`` and then handed, together
        with its image placeholder, to ``to_turbomind_aux``. Extra keyword
        arguments are accepted for interface compatibility and ignored.
        """
        prompt_text, placeholder = self.proc_messages(messages, chat_template, sequence_start)
        build = self.to_turbomind_aux
        return build(messages, prompt_text, placeholder, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/mllama.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List

from lmdeploy.vl.model.base import VISION_MODELS, VisionModel


def check_transformers():
    """Ensure the installed transformers provides ``MllamaForConditionalGeneration``.

    Raises:
        ImportError: if the class cannot be imported. The underlying import
            error is chained as ``__cause__`` so the real reason stays visible.
    """
    try:
        from transformers import MllamaForConditionalGeneration  # noqa: F401
    except ImportError as err:
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from err


@VISION_MODELS.register_module()
class MllamaVLModel(VisionModel):
    """llama3.2 model.

    `build_model` raises NotImplementedError when ``with_llm`` is False, so
    only the path that loads the whole model (pytorch engine) is usable.
    """

    _arch = 'MllamaForConditionalGeneration'

    def build_preprocessor(self):
        """Load the HF processor and set the image token id."""
        from transformers import AutoProcessor
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        # id attached to every preprocess result as `image_token_id`
        self.image_token_id = 128256

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to the spec of `super().preprocess`

        Every collected image is converted to RGB and run through the HF image
        processor. Each result is tagged with ``image_size``,
        ``image_tokens=1`` and ``image_token_id``, and all results are appended
        as one ``preprocess`` message.
        """
        images = self.collect_multimodal_items(messages)
        outputs = []
        for _, image, _ in images:
            image = image.convert('RGB')
            results = self.processor.image_processor(images=image, return_tensors='pt')
            results.update(image_size=image.size, image_tokens=1, image_token_id=self.image_token_id)
            outputs.append(results)
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    def build_model(self):
        """Load the whole model on CPU when ``with_llm`` is set.

        Raises:
            ImportError: if transformers lacks Mllama support.
            NotImplementedError: when ``with_llm`` is False (turbomind).
        """
        check_transformers()
        if self.with_llm:
            from transformers import MllamaForConditionalGeneration
            model = MllamaForConditionalGeneration.from_pretrained(self.model_path, device_map='cpu')
            self.vl_model = model
        else:
            raise NotImplementedError('turbomind has not supported mllama yet')

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        String-content messages pass through unchanged, and the internal
        ``images``/``preprocess``/``forward`` messages are skipped. Every other
        message becomes a user message made of one ``<|image|>`` per image item
        followed by its first text item. When there is no text item, the text
        is empty; previously this raised IndexError.

        Returns:
            tuple: (prompt, IMAGE_TOKEN)
        """
        prompt_messages = []
        IMAGE_TOKEN = '<|image|>'
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = sum(1 for x in message['content'] if x['type'] == 'image')
            texts = [item.get('text', '') for item in message['content'] if item['type'] == 'text']
            # an image-only message has no text item; use '' instead of crashing
            text = texts[0] if texts else ''
            prompt = IMAGE_TOKEN * n_images + text
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Render the prompt and build the pytorch-engine inputs."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/molmo.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class MolmoVisionModel(VisionModel):
    """Molmo's vision model.

    Unlike most vision models in this package, the preprocess and forward
    results are stored on the originating message (keys ``preprocess`` and
    ``forward``) instead of being appended as separate messages. Only the
    turbomind path is implemented; ``to_pytorch`` asserts.
    """

    _arch = 'MolmoForCausalLM'

    def build_preprocessor(self):
        """Load Molmo's remote-code processor from ``self.model_path``."""
        # NOTE(review): torch_dtype / device_map are passed to a processor
        # here; confirm they have any effect on AutoProcessor.
        self.processor = AutoProcessor.from_pretrained(self.model_path,
                                                       trust_remote_code=True,
                                                       torch_dtype=torch.half,
                                                       device_map='auto')

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights, load_checkpoint_and_dispatch
        # instantiate the architecture without materializing weights; the
        # checkpoint is loaded by load_checkpoint_and_dispatch below
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True)

            self.vl_model = model
            if not self.with_llm:
                # Remove nn modules other than embedding from the LLM model
                for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
                    del model.model.transformer[key]
            # the word embedding is kept: forward() embeds input_ids with it
            self.token_embedding = model.model.transformer.wte

        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map='auto' if not self.with_llm else {'': 'cpu'},
                                         max_memory=self.max_memory,
                                         no_split_module_classes=['ResidualAttentionBlock', 'Embedding'],
                                         dtype=torch.half)

        # eval() turns off training-mode behaviour (e.g. dropout) so inference
        # is deterministic; gradients are disabled by @torch.no_grad in forward.
        self.model = model.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to the `super.preprocess() for spec.

        Every message whose content is a list (regardless of role) is
        tokenized as ``' User: {first text}'`` together with its images by
        ``self.processor.process``. The result, with the leading BOS removed
        from ``input_ids``, is stored on the message under ``preprocess``.
        """
        for i, message in enumerate(messages):
            if not isinstance(message['content'], List):
                continue
            images = [x['image'] for x in message['content'] if x['type'] == 'image']
            # NOTE(review): content[0] raises IndexError when there is no text item
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            prompt = f' User: {content[0]}'
            tokens = self.processor.tokenizer.encode(prompt, add_special_tokens=False)
            # preprocess images. The output is a dict with (at least) the keys
            # 'input_ids', 'images', 'image_input_idx' and 'image_masks', all
            # torch.Tensor.
            # NOTE(review): the shapes formerly listed here as (n_patch, d_model)
            # disagree with forward(), which asserts that image_input_idx is
            # (1, num_image, num_patch) after unsqueeze(0); confirm.
            result = self.processor.process(images=images, tokens=tokens)
            # remove the bos from input_ids which is prepended by molmo's
            # processor
            input_ids = result['input_ids'][1:]
            result.update(input_ids=input_ids)
            messages[i].update(preprocess=result)
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Each message with a ``preprocess`` entry is processed on its own
        (batch size 1; ``max_batch_size`` is not used). The token embeddings of
        its input_ids are computed, and the vision features are added at the
        positions given by ``image_input_idx``. The message then gets
        ``forward=dict(input_ids=..., embeddings=...)``.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        for i, message in enumerate(messages):
            if 'preprocess' not in message.keys():
                continue
            inputs = message['preprocess']
            # get input_ids of embedding
            inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
            input_ids = inputs['input_ids']
            # (batch_size, num_image, num_patch, d_model)
            images = inputs['images']
            # (batch_size, num_image, num_patch)
            image_input_idx = inputs['image_input_idx']
            image_masks = inputs['image_masks']
            batch_size, seq_len = input_ids.size()
            assert batch_size == 1
            # replace every -1 id with 0 so the embedding lookup is valid;
            # presumably -1 marks image slots (TODO confirm). Note that these
            # rewritten ids are what gets stored in the forward result below.
            input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
            embeddings = self.model.model.transformer.wte(input_ids)
            images = images.to(self.model.dtype)
            image_masks = image_masks.to(self.model.dtype)
            logger.info(f'vision forward shape: {images.shape}')
            image_features, _ = self.model.model.vision_backbone(images, image_masks)
            num_image, num_patch = image_features.shape[1:3]
            assert image_input_idx.shape == (batch_size, num_image, num_patch)

            # insert the image feature into the embedding.
            image_features = image_features.view(batch_size, num_image * num_patch, -1)
            image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)
            # negative indices mark features that are not placed anywhere
            valid = image_input_idx >= 0
            batch_idx = torch.arange(batch_size, device=embeddings.device)
            batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
            image_features = image_features.to(embeddings.device)
            # Since we remove bos_id from input_ids during `preprocess`,
            # the index `image_input_idx[valid]` should be shift to left
            # by subtracting 1
            index = image_input_idx[valid] - 1
            embeddings[batch_idx[valid], index] += image_features[valid]
            assert embeddings.shape[:2] == (batch_size, seq_len)
            messages[i].update(dict(forward=dict(input_ids=input_ids.flatten(), embeddings=embeddings)))
        return messages

    @staticmethod
    def proc_messages(messages):
        """Render the conversation as a Molmo-style text prompt.

        The result is ``' User: ...'`` / ``' Assistant:...'`` segments followed
        by a trailing ``' Assistant:'``. Used in ``to_turbomind`` only for the
        returned ``prompt`` field; the token ids are built separately there.
        """
        prompt = []
        # NOTE(review): IMAGE_TOKEN is '' here, so each image contributes only
        # '\n' to this prompt; confirm the intended placeholder literal.
        IMAGE_TOKEN = ''
        for message in messages:
            role, content = message['role'], message['content']
            if isinstance(content, List):
                n_images = len([1 for x in content if x['type'] == 'image'])
                content = [x['text'] for x in content if x['type'] == 'text']
                prompt.append(' User: ' + (IMAGE_TOKEN + '\n') * n_images + content[0])
            else:
                if role == 'user':
                    prompt.append(f' User: {content}')
                elif role == 'assistant':
                    prompt.append(f' Assistant:{content}')
                else:
                    assert 0, f'molmo does not support role {role}, message is {message}'  # noqa
        prompt.append(' Assistant:')
        return ''.join(prompt)

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Not supported: molmo only runs on the turbomind engine."""
        assert 0, 'molmo is not supported by pytorch engine'

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Assemble turbomind inputs.

        The token ids are the concatenation of: BOS (or EOS), the precomputed
        ``forward`` ids of every multimodal message, and the encoded text of
        every plain message, with ``' Assistant:'`` appended after the last
        message. ``input_embedding_ranges`` holds the ``(start, end)`` span of
        each multimodal segment inside ``input_ids``. Pops the ``forward`` key
        from the multimodal messages.
        """
        # results is a list of tuple(input_ids, embeddings)
        results = []
        # Prepend BOS
        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic
        # separator token.
        bos = (self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id)
        results.append(([bos], None))

        for i, message in enumerate(messages):
            prompt = ''
            role, content = message['role'], message['content']
            if isinstance(content, List):
                forward_result = message.pop('forward')
                input_ids = forward_result['input_ids']
                embeddings = forward_result['embeddings']
                results.append((input_ids.tolist(), embeddings))
            else:
                if role == 'user':
                    prompt = f' User: {content}'
                elif role == 'assistant':
                    prompt = f' Assistant:{content}'
                else:
                    assert 0, f'molmo does not support role {role}, message is {message}'  # noqa
            if i == len(messages) - 1:
                # the last message
                # NOTE(review): validation by assert is stripped under `python -O`
                assert role == 'user', f'the role of last message is expected to be user, but got {role}'  # noqa
                prompt += ' Assistant:'
            if prompt:
                input_ids = self.processor.tokenizer.encode(prompt, add_special_tokens=False)
                results.append((input_ids, None))

        # concat input_ids from results, calculate the range in the input_ids
        # where embeddings will be copied to
        input_ids = []
        input_embeddings = []
        input_embedding_ranges = []
        start = 0
        for _input_ids, _embeddings in results:
            if _embeddings is not None:
                input_embeddings.append(_embeddings.cpu())
                end = start + len(_input_ids)
                input_embedding_ranges.append((start, end))
            input_ids += _input_ids
            start += len(_input_ids)

        prompt = self.proc_messages(messages)
        return dict(prompt=prompt,
                    input_ids=input_ids,
                    input_embeddings=input_embeddings,
                    input_embedding_ranges=input_embedding_ranges)


================================================
FILE: lmdeploy/vl/model/phi3_vision.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List

from transformers import AutoProcessor

from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel


@VISION_MODELS.register_module()
class Phi3VisionModel(LlavaHfVisionModel):
    """Phi3-vision model.

    Builds on the llava-hf vision model but loads Phi-3-vision's remote-code
    processor. ``build_model`` only supports loading the whole model
    (``with_llm``); otherwise it raises NotImplementedError.
    """

    _arch = 'Phi3VForCausalLM'

    def build_preprocessor(self):
        """Load the remote-code processor and detach its tokenizer.

        If the processor exposes a ``tokenizer`` attribute, it is deleted and
        re-set to ``None``; only the image processor is used in `preprocess`.
        """
        proc = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        if hasattr(proc, 'tokenizer'):
            del proc.tokenizer
            proc.tokenizer = None
        self.processor = proc

    def build_model(self):
        """Load the full remote-code causal LM on CPU.

        Raises:
            NotImplementedError: when ``with_llm`` is False (turbomind).
        """
        if not self.with_llm:
            raise NotImplementedError('turbomind has not supported phi3v yet')
        from transformers import AutoModelForCausalLM
        self.vl_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map='cpu',
            trust_remote_code=True,
        )

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess() for spec.

        Each collected image is run through the image processor. The result
        is tagged with ``image_size``, ``image_tokens`` (the processor's
        ``num_img_tokens``) and ``image_token_id``. All results are appended
        as one ``preprocess`` message.
        """
        collected = self.collect_multimodal_items(messages)
        processed = []
        for _modality, img, _params in collected:
            feats = self.processor.image_processor([img], return_tensors='pt')
            extra = dict(image_size=img.size,
                         image_tokens=feats['num_img_tokens'],
                         image_token_id=self.image_token_id)
            feats.update(extra)
            processed.append(feats)
        messages.append(dict(role='preprocess', content=processed))
        return messages


================================================
FILE: lmdeploy/vl/model/qwen.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


@VISION_MODELS.register_module()
class QwenVisionModel(VisionModel):
    """Qwen vision model.

    For turbomind, only ``transformer.visual`` is kept and dispatched in fp16;
    with ``with_llm`` the whole model is loaded on CPU.
    """

    _arch = 'QWenLMHeadModel'

    def build_preprocessor(self):
        """Build the torchvision transform: bicubic resize to a square of
        ``hf_config.visual['image_size']``, to-tensor, and per-channel
        normalization."""
        from torchvision import transforms
        from torchvision.transforms import InterpolationMode
        # per-channel normalization constants
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        image_size = self.hf_config.visual['image_size']
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights
        with init_empty_weights():
            # NOTE: this mutates self.hf_config in place
            config = self.hf_config
            config.quantization_config = {}  # disable vision part quantization
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            self.vl_model = model
            if not self.with_llm:
                # drop the language-model parts; only transformer.visual is used
                del model.lm_head
                for key in ['wte', 'h', 'ln_f']:
                    setattr(model.transformer, key, None)

        from accelerate.utils import get_balanced_memory, infer_auto_device_map
        max_memory = get_balanced_memory(model,
                                         max_memory=self.max_memory,
                                         dtype=torch.half,
                                         no_split_module_classes=['VisualAttentionBlock', 'Resampler'])
        device_map = infer_auto_device_map(model,
                                           no_split_module_classes=['VisualAttentionBlock', 'Resampler'],
                                           max_memory=max_memory,
                                           dtype=torch.half)
        # force each pair of modules onto the same device
        same_device_keys = [('transformer.visual.conv1', 'transformer.visual.positional_embedding'),
                            ('transformer.visual.ln_post', 'transformer.visual.proj')]
        for (a, b) in same_device_keys:
            if a in device_map and b in device_map:
                device_map[b] = device_map[a]

        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            # NOTE(review): 'Resampler' is in the no-split list used for the
            # device map above but not here; confirm this is intended.
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map=device_map if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=['VisualAttentionBlock'],
                                         dtype=torch.half)

        self.model = model.transformer.visual.eval()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refers to `super.preprocess() for spec.

        Each collected image is converted to RGB and transformed to
        ``pixel_values``. Every image is recorded with ``image_tokens=256``,
        i.e. a fixed number of image tokens. All results are appended as one
        ``preprocess`` message.
        """
        images = self.collect_multimodal_items(messages)
        outputs = []
        for modality, image, params in images:
            image = image.convert('RGB')
            pixel_values = self.image_transform(image)
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_size=image.size,
                     image_tokens=256,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Uses the first ``preprocess`` message. Its images are stacked in
        chunks of ``max_batch_size``, passed through the visual tower, and
        split back to one feature tensor per image, appended in order as a
        ``forward`` message.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = torch.stack(pixel_values, dim=0)
            logger.info(f'vision forward shape: {pixel_values.shape}')
            feats = self.model(pixel_values)
            feats = torch.split(feats, 1, dim=0)
            outputs.extend([x.squeeze() for x in feats])
        messages.append(dict(role='forward', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start):
        """Apply chat template to get the prompt.

        String-content messages pass through unchanged, and the internal
        ``images``/``preprocess``/``forward`` messages are skipped. For a
        multimodal message, the first text item is used. When it does not
        already contain IMAGE_TOKEN, a ``Picture i:`` line per image is
        prepended.

        Returns:
            tuple: (prompt, IMAGE_TOKEN)
        """
        prompt_messages = []
        # NOTE(review): IMAGE_TOKEN is '' here, so `IMAGE_TOKEN in prompt` is
        # always True and the 'Picture i:' lines are never prepended. The
        # placeholder literal appears to have been lost; confirm.
        IMAGE_TOKEN = ''
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
            prompt = content[0]
            if IMAGE_TOKEN in prompt:
                pass
            else:
                prompt = ''.join([f'Picture {str(i)}:{IMAGE_TOKEN}\n' for i in range(n_images)]) + prompt
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Render the prompt and build the pytorch-engine inputs."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Render the prompt and build the turbomind-engine inputs."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/qwen2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging


def check_qwen_vl_deps_install():
    """Check qwen_vl_utils and a transformers build with Qwen2-VL support.

    Raises:
        ImportError: if ``qwen_vl_utils`` or
            ``transformers.Qwen2VLForConditionalGeneration`` cannot be
            imported. The underlying import error is chained as ``__cause__``.
    """
    try:
        import qwen_vl_utils  # noqa: F401
    except ImportError as err:
        raise ImportError('please install qwen_vl_utils by `pip install qwen_vl_utils`') from err
    try:
        from transformers import Qwen2VLForConditionalGeneration  # noqa: F401
    except ImportError as err:
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from err


@VISION_MODELS.register_module()
class Qwen2VLModel(VisionModel):
    """Qwen2VL model (also registered for Qwen2.5-VL)."""

    _arch = ['Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration']

    def build_preprocessor(self):
        """Load the HF processor and resolve the image token and its id."""
        check_qwen_vl_deps_install()
        from transformers import AutoProcessor
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        tokenizer = self.processor.tokenizer
        self.image_token = self.processor.image_token
        # the id is the last id produced by encoding the image token text
        self.image_token_id = tokenizer.encode(self.image_token)[-1]

    def preprocess(self, messages: list[dict]) -> list[dict]:
        """Refer to `super().preprocess()` for spec.

        Each collected image is converted to RGB. Its optional
        ``resized_height``/``resized_width``/``min_pixels``/``max_pixels``
        params are forwarded to ``qwen_vl_utils.process_vision_info``, and the
        result is run through the HF image processor. ``image_tokens`` is the
        patch count of ``image_grid_thw`` divided by ``merge_size**2``. All
        results are appended as one ``preprocess`` message.
        """
        from qwen_vl_utils import process_vision_info

        images = self.collect_multimodal_items(messages)
        optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'}
        outputs = []
        for _, image, params in images:
            image = image.convert('RGB')

            item = dict(type='image', image=image)
            item.update({key: params[key] for key in params.keys() if key in optional_keys})
            image_inputs, _ = process_vision_info([dict(content=[item])])
            result = self.processor.image_processor(images=image_inputs, return_tensors='pt')
            merge_length = self.processor.image_processor.merge_size**2
            image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
            result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
            outputs.append(result)
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    def build_model(self):
        """Load the whole model on CPU when ``with_llm`` is set. Otherwise
        build only the vision tower in fp16 and dispatch it with accelerate.

        Raises:
            ValueError: for an architecture other than Qwen2-VL / Qwen2.5-VL.
        """
        check_qwen_vl_deps_install()
        arch = self.hf_config.architectures[0]
        if arch == 'Qwen2VLForConditionalGeneration':
            from transformers import Qwen2VLForConditionalGeneration as AutoModelCls
        elif arch == 'Qwen2_5_VLForConditionalGeneration':
            from transformers import Qwen2_5_VLForConditionalGeneration as AutoModelCls
        else:
            raise ValueError(f'Unsupported arch={arch}')

        if self.with_llm:
            self.vl_model = AutoModelCls.from_pretrained(self.model_path, device_map='cpu')
        else:
            from accelerate import init_empty_weights
            with init_empty_weights():
                config = self.hf_config
                # disable accelerate check_tied_parameters_in_config for Qwen2-VL-2B-Instruct
                config.tie_word_embeddings = False
                if hasattr(config, 'text_config'):
                    config.text_config.tie_word_embeddings = False
                model = AutoModelCls._from_config(config)
                # keep only the vision tower; the language model and lm_head are dropped
                model.visual = model.model.visual
                del model.model
                del model.lm_head
                model.half()

            from accelerate import load_checkpoint_and_dispatch
            with disable_logging():
                # this branch only runs without the LLM, so the map is always 'auto'
                load_checkpoint_and_dispatch(model=model,
                                             checkpoint=self.model_path,
                                             device_map='auto',
                                             max_memory=self.max_memory,
                                             no_split_module_classes=['Qwen2VLVisionBlock', 'Qwen2_5_VLVisionBlock'],
                                             dtype=torch.half)
            self.model = model.eval()

    @torch.no_grad()
    def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Images of the first ``preprocess`` message are concatenated in chunks
        of ``max_batch_size`` and passed through ``self.model.visual``. The
        output is split back into one embedding tensor per image, sized by
        ``prod(image_grid_thw) // merge_size**2``.

        Args:
            messages(list[dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]
        dtype = torch.half
        device = next(self.model.visual.parameters()).device
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            pixel_values = [x['pixel_values'].type(dtype) for x in inputs[idx:idx + max_batch_size]]
            image_grid_thw = [x['image_grid_thw'] for x in inputs[idx:idx + max_batch_size]]
            pixel_values = torch.cat(pixel_values, dim=0).to(device)
            image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device)
            image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
            if hasattr(image_embeds, 'pooler_output'):
                # transformers >= 5.0.0, the type of image_embeds is `BaseModelOutputWithPooling`
                # rather than torch.Tensor
                image_embeds = image_embeds.pooler_output
            merge_length = self.processor.image_processor.merge_size**2
            split_size = image_grid_thw.prod(dim=1) // merge_length
            image_embeds = image_embeds.split(split_size.tolist())
            outputs.extend(image_embeds)
        messages.append(dict(role='forward', content=outputs))
        return messages

    def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None):
        """Apply chat template to get the prompt.

        ``preprocess`` / ``forward`` messages are dropped. When
        ``VisionModel.IMAGE_TOKEN_included(messages)`` is true (backward
        compatibility), IMAGE_TOKEN in the joined user text is replaced with
        the vision-token sequence. Otherwise each image / image_url item
        becomes ``<|vision_start|>{image_token}<|vision_end|>`` and text items
        are concatenated in order. ``chat_template_kwargs`` is accepted for
        interface compatibility but is currently unused.

        Returns:
            tuple: (prompt, self.image_token)

        Raises:
            ValueError: for a user content item of an unsupported type.
        """
        prompt_messages = []
        # NOTE(review): IMAGE_TOKEN is '' here; `str.replace('', x)` inserts x
        # between every character, so confirm the intended placeholder literal.
        IMAGE_TOKEN = ''
        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
        if VisionModel.IMAGE_TOKEN_included(messages):
            # backward compatibility
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                content = [x['text'] for x in content if x['type'] == 'text']
                prompt = ''.join(content)
                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')
                prompt_messages.append(dict(role='user', content=prompt))
        else:
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                _content = []
                for item in content:
                    if item['type'] == 'text':
                        _content.append(item['text'])
                    elif item['type'] in ['image', 'image_url']:
                        _content.append(f'<|vision_start|>{self.image_token}<|vision_end|>')
                    else:
                        raise ValueError(f'Unsupported message type: {item["type"]}')
                message = dict(role=role, content=''.join(_content))
                prompt_messages.append(message)
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, self.image_token

    @staticmethod
    def get_mrope_info(seq_len: int,
                       grid_thws: list[tuple[int, int, int]] = None,
                       ranges: list[tuple[int, int]] = None):
        """Compute M-RoPE position ids for a prompt with image embeddings.

        Args:
            seq_len: total number of input ids in the prompt.
            grid_thws: per-image ``(t, h, w)`` patch grids, aligned with
                ``ranges``. ``h`` and ``w`` are halved below, which assumes a
                spatial merge size of 2 (TODO confirm against merge_size).
            ranges: per-image ``(start, end)`` spans of the image embeddings
                in the input ids, in prompt order.

        Returns:
            tuple: ``(position_ids, delta)``. ``position_ids`` stacks the
            three rope axes along dim 0. ``delta`` is a 1-element long tensor
            holding the final position counter minus ``seq_len``. With no
            image ranges, every token gets plain 1-D positions and delta is 0
            (previously this raised IndexError on ``ranges[0]``).
        """
        grid_thws = grid_thws or []
        ranges = ranges or []
        # text before the first image (or the whole prompt if there is none)
        first_start = ranges[0][0] if ranges else seq_len
        mrope_position_ids = [torch.arange(first_start).expand(3, -1)]
        st_idx = first_start
        for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):
            llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
            llm_grid_h //= 2
            llm_grid_w //= 2
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx)
            st_idx += max(llm_grid_h, llm_grid_w)
            # text between this image and the next one (or the end of prompt)
            if i < len(ranges) - 1:
                text_len = ranges[i + 1][0] - ranges[i][1]
            else:
                text_len = seq_len - embedding_range[1]
            mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx)
            st_idx += text_len
        mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
        mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long)
        return mrope_position_ids, mrope_position_delta

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
        """Return to the information needed by pytorch engine."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
        """Return turbomind inputs extended with M-RoPE metadata.

        ``input_meta`` holds ``mrope_position_ids`` / ``mrope_position_delta``
        computed from each image's first ``image_grid_thw`` row and the
        embedding ranges reported by ``to_turbomind_aux``.
        """
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)
        info = super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]
        grid_thws = [x['image_grid_thw'].tolist()[0] for x in inputs]
        seq_len = len(info['input_ids'])
        ranges = info['input_embedding_ranges']
        mrope_position_ids, mrope_position_delta = self.get_mrope_info(seq_len, grid_thws, ranges)
        meta = dict(mrope_position_ids=mrope_position_ids, mrope_position_delta=mrope_position_delta)
        info.update(dict(input_meta=meta))
        return info


================================================
FILE: lmdeploy/vl/model/qwen3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


def check_transformers():
    """Ensure the installed transformers provides the Qwen3-VL model classes.

    Raises:
        ImportError: with an installation hint; the original import error is
            chained as ``__cause__`` so the real reason stays visible.
    """
    try:
        from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration  # noqa: F401
    except ImportError as e:
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from e


@VISION_MODELS.register_module()
class Qwen3VLModel(VisionModel):
    """Qwen3VL model.

    Implements image/video preprocessing and prompt packing for the pytorch
    engine. The turbomind hooks (`build_model`, `forward`, `to_turbomind`) are
    not implemented yet.
    """

    _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration']

    def build_preprocessor(self):
        """Load the HF processor and cache its vision special tokens/ids."""
        check_transformers()
        self.processor = AutoProcessor.from_pretrained(self.model_path)

        # image tokens
        self.image_token = self.processor.image_token
        self.image_token_id = self.processor.image_token_id

        # video tokens
        self.video_token = self.processor.video_token
        self.video_token_id = self.processor.video_token_id

        # vision start and end tokens
        self.vision_start_token = self.processor.vision_start_token
        self.vision_end_token = self.processor.vision_end_token

    def get_processor_args(self, mm_processor_kwargs: Dict[str, Any] | None = None):
        """Resolve the ``(min_pixels, max_pixels)`` pair used to resize images.

        Defaults come from the image processor's ``size['shortest_edge']`` and
        ``size['longest_edge']``. ``min_pixels`` / ``max_pixels`` given in
        `mm_processor_kwargs` override them. If an override would make min exceed
        max (against the other override or the other default), a warning is
        logged and both defaults are returned instead.

        Returns:
            tuple: ``(min_pixels, max_pixels)``.
        """
        min_pixels = self.processor.image_processor.size['shortest_edge']
        max_pixels = self.processor.image_processor.size['longest_edge']

        if mm_processor_kwargs is None:
            return min_pixels, max_pixels

        input_min_pixels = mm_processor_kwargs.get('min_pixels', None)
        input_max_pixels = mm_processor_kwargs.get('max_pixels', None)

        # boundary check for min_pixels and max_pixels
        if input_min_pixels is None:
            if input_max_pixels is not None:
                # only max_pixels is given in the input
                if input_max_pixels < min_pixels:
                    logger.warning(
                        f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.')
                    return min_pixels, max_pixels
                max_pixels = input_max_pixels
        else:
            if input_max_pixels is None:
                # only min_pixels is given in the input
                if input_min_pixels > max_pixels:
                    logger.warning(
                        f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.')
                    return min_pixels, max_pixels
            else:
                if input_min_pixels > input_max_pixels:
                    logger.warning(
                        f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.')
                    return min_pixels, max_pixels
                max_pixels = input_max_pixels
            min_pixels = input_min_pixels

        return min_pixels, max_pixels

    def _preprocess_image(self,
                          data: Any,
                          params: Dict[str, Any],
                          mm_processor_kwargs: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Resize/normalize one image and count its placeholder tokens.

        Args:
            data: an image object supporting ``convert('RGB')`` (presumably a
                PIL image — confirm against `collect_multimodal_items`).
            params: per-item parameters (not used for images).
            mm_processor_kwargs: optional per-request ``min_pixels``/``max_pixels``.

        Returns:
            The image processor output extended with ``image_size``,
            ``image_tokens`` (grid size divided by ``merge_size**2``) and
            ``image_token_id``.
        """
        image = data.convert('RGB')
        min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs)

        result = self.processor.image_processor(images=image,
                                                size={
                                                    'shortest_edge': min_pixels,
                                                    'longest_edge': max_pixels
                                                },
                                                return_tensors='pt')
        merge_length = self.processor.image_processor.merge_size**2
        image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
        result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
        return result

    def _preprocess_video(self,
                          data: Any,
                          params: Dict[str, Any],
                          mm_processor_kwargs: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Process one video and compute per-frame timestamps/token counts.

        Args:
            data: the video frames, passed as-is to the HF video processor.
            params: per-item parameters; must contain ``video_metadata`` (a dict
                with at least ``frames_indices``; ``fps`` defaults to 24 and is
                written back into that dict when missing or None).
            mm_processor_kwargs: currently unused for videos (see TODO).

        Returns:
            The video processor output extended with ``curr_timestamp``,
            ``frame_seqlen`` (tokens per frame after spatial merge) and
            ``video_token_id``.
        """
        # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs
        metadata = params['video_metadata']
        video_kwargs = dict(return_metadata=True,
                            do_resize=True,
                            do_sample_frames=False,
                            video_metadata=metadata,
                            return_tensors='pt')
        result = self.processor.video_processor(videos=data, **video_kwargs)
        video_grid_thw = result['video_grid_thw']

        merge_length = self.processor.video_processor.merge_size**2
        if metadata.get('fps') is None:
            # `fps` may be absent or explicitly None; assign the default directly
            # (reading `metadata['fps']` here would raise KeyError when absent).
            logger.warning_once('Qwen3VL: fps not found, defaulting to 24.')
            metadata['fps'] = 24

        # per-frame timestamps (rendered as "<t seconds>" in the video prompt)
        curr_timestamp = self.processor._calculate_timestamps(
            metadata['frames_indices'],
            metadata['fps'],
            self.processor.video_processor.merge_size,
        )

        frame_seqlen = video_grid_thw[0][1:].prod() // merge_length
        result.update(curr_timestamp=curr_timestamp, frame_seqlen=frame_seqlen, video_token_id=self.video_token_id)
        return result

    def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:
        """Refer to `super().preprocess()` for spec.

        Each result is tagged with its ``modality``; the list is appended to
        `messages` as a ``preprocess`` message.
        """
        outputs = []
        # kept for backward compatibility; `to_pytorch` derives video presence
        # from the request's own preprocess results instead of this attribute
        self.contains_video_input = False

        mm_items = self.collect_multimodal_items(messages)
        for modality, data, params in mm_items:
            result = {}
            if modality == Modality.IMAGE:
                result = self._preprocess_image(data, params, mm_processor_kwargs)
            elif modality == Modality.VIDEO:
                self.contains_video_input = True
                result = self._preprocess_video(data, params, mm_processor_kwargs)

            result.update(modality=modality)
            outputs.append(result)

        messages.append(dict(role='preprocess', content=outputs))
        return messages

    def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None):
        """Apply chat template to get the prompt.

        Messages that still use the legacy ``<IMAGE_TOKEN>`` marker have it
        expanded to ``<|vision_start|>{image_token}<|vision_end|>``.

        Returns:
            tuple: ``(prompt, None)``.
        """
        chat_template_kwargs = chat_template_kwargs or {}
        prompt_messages = []
        # legacy placeholder users may embed in text; must be non-empty, since
        # `str.replace('', x)` would insert `x` between every character
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
        if VisionModel.IMAGE_TOKEN_included(messages):
            # backward compatibility
            for message in messages:
                role, content = message['role'], message['content']
                if role != 'user' or isinstance(content, str):
                    prompt_messages.append(message)
                    continue
                content = [x['text'] for x in content if x['type'] == 'text']
                prompt = ''.join(content)
                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')
                prompt_messages.append(dict(role='user', content=prompt))
        else:
            prompt_messages = messages
        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs)
        return prompt, None

    def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start):
        """Pack the video input to the compatible format with pytorch engine.

        Each ``vision_start + video_token + vision_end`` occurrence in `prompt`
        is expanded into per-frame ``<t seconds>`` markers followed by
        ``frame_seqlen`` video tokens wrapped in vision start/end tokens.
        """

        # collect all preprocessing result from messages
        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
        assert len(preps) == 1
        preps = preps[0]

        # split prompt into segments and validate data
        segs = prompt.split(self.vision_start_token + self.video_token + self.vision_end_token)
        assert len(segs) == len(preps) + 1, (f'the number of {self.video_token} is not equal '
                                             f'to input videos, {len(segs) - 1} vs {len(preps)}')

        # calculate the video token offset for each video
        input_ids = []
        for i, seg in enumerate(segs):
            if i > 0 and i <= len(preps):
                preps[i - 1].update(offset=len(input_ids))
                frame_seqlen = preps[i - 1]['frame_seqlen']
                assert self.video_token_id == preps[i - 1]['video_token_id']

                video_grid_thw = preps[i - 1]['video_grid_thw']
                curr_timestamp = preps[i - 1]['curr_timestamp']

                # update prompt with timestamp index tokens and video pad tokens
                video_placeholder = ''
                for frame_idx in range(video_grid_thw[0][0]):
                    curr_time = curr_timestamp[frame_idx]
                    video_placeholder += f'<{curr_time:.1f} seconds>'
                    video_placeholder += (self.vision_start_token + '<|placeholder|>' * frame_seqlen +
                                          self.vision_end_token)

                video_placeholder = video_placeholder.replace('<|placeholder|>', self.video_token)
                video_token_ids = tokenizer.encode(video_placeholder)
                input_ids.extend(video_token_ids)

                preps[i - 1].update(video_tokens=len(video_token_ids))

            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
            input_ids.extend(token_ids)

        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)

    @staticmethod
    def _contains_video(messages) -> bool:
        """Return True if this request's preprocess results include a video."""
        for message in messages:
            if message['role'] != 'preprocess':
                continue
            if any(item.get('modality') == Modality.VIDEO for item in message['content']):
                return True
        return False

    def to_pytorch(self,
                   messages,
                   chat_template,
                   tokenizer,
                   sequence_start,
                   chat_template_kwargs: Dict | None = None,
                   **kwargs):
        """Return to the information needed by pytorch engine."""
        prompt, _ = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)

        # decide from the request itself rather than from `self.contains_video_input`,
        # which is shared instance state and can be overwritten by another request
        if self._contains_video(messages):
            return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start)
        else:
            return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start)

    def build_model(self):
        # TODO: implement for turbomind
        pass

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        # TODO: implement for turbomind
        pass

    def to_turbomind(self,
                     messages,
                     chat_template,
                     tokenizer,
                     sequence_start,
                     chat_template_kwargs: Dict | None = None,
                     **kwargs):
        # TODO: implement for turbomind
        pass


================================================
FILE: lmdeploy/vl/model/qwen3_5.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import AutoProcessor

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS

from .qwen3 import Qwen3VLModel

logger = get_logger('lmdeploy')


def check_transformers():
    """Ensure the installed transformers provides the Qwen3.5 model classes.

    Raises:
        ImportError: with an installation hint; the original import error is
            chained as ``__cause__`` so the real reason stays visible.
    """
    try:
        from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5MoeForConditionalGeneration  # noqa: F401
    except ImportError as e:
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from e


@VISION_MODELS.register_module()
class Qwen3_5Model(Qwen3VLModel):
    """Qwen3_5 model.

    Reuses all preprocessing and prompt logic of `Qwen3VLModel`; only the
    transformers availability check and the registered architectures differ.
    """

    _arch = ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration']

    # processor attributes mirrored onto the model, in assignment order
    _PROCESSOR_TOKEN_ATTRS = (
        'image_token',
        'image_token_id',
        'video_token',
        'video_token_id',
        'vision_start_token',
        'vision_end_token',
    )

    def build_preprocessor(self):
        """Load the HF processor and cache its image/video/vision special tokens."""
        check_transformers()

        self.processor = AutoProcessor.from_pretrained(self.model_path)
        for attr_name in self._PROCESSOR_TOKEN_ATTRS:
            setattr(self, attr_name, getattr(self.processor, attr_name))


================================================
FILE: lmdeploy/vl/model/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import inspect
from contextlib import contextmanager
from typing import Callable, MutableSequence

import torch


@contextmanager
def disable_transformers_logging():
    """Temporarily set the transformers log verbosity to ERROR.

    The previous verbosity is restored on exit, including when the managed
    block raises.
    """
    import transformers
    from transformers.utils import logging
    previous_level = logging.get_verbosity()
    logging.set_verbosity(transformers.logging.ERROR)
    try:
        yield
    finally:
        logging.set_verbosity(previous_level)


@contextmanager
def disable_logging():
    """Temporarily disable stdlib logging records at ERROR level and below.

    The previous ``logging.disable`` threshold is restored on exit, including
    when the managed block raises.
    """
    import logging
    previous_level = logging.root.manager.disable
    logging.disable(logging.ERROR)
    try:
        yield
    finally:
        logging.disable(previous_level)


def _set_func(origin_func_path: str | None, rewrite_func: Callable, origin_func: Callable = None):
    """Replace old function with the new function.

    Args:
        origin_func_path (str): original function path
        rewrite_func (Callable): function to replace with
        origin_func (Callable): function to replace
    """
    # import module
    if isinstance(origin_func_path, str):
        split_path = origin_func_path.split('.')
        for i in range(len(split_path), 0, -1):
            try:
                exec('import {}'.format('.'.join(split_path[:i])))
                break
            except Exception:
                continue

        origin_func = eval(origin_func_path) \
            if origin_func is None else origin_func

    method_class = inspect.ismethod(origin_func)

    # replace method
    if not method_class:
        import gc
        refs = gc.get_referrers(origin_func)
        obj_id = id(origin_func)
        for ref in refs:
            if isinstance(ref, dict):
                for x, y in ref.items():
                    if id(y) == obj_id:
                        ref[x] = rewrite_func
            elif isinstance(ref, MutableSequence):
                for i, v in enumerate(ref):
                    if id(v) == obj_id:
                        ref[i] = rewrite_func
    if isinstance(origin_func_path, str):
        exec(f'{origin_func_path} = rewrite_func')
    elif method_class:
        raise NotImplementedError

    return origin_func


@contextmanager
def rewrite_ctx(origin_func_path: list[str | Callable], rewrite_func: list[Callable]):
    """Rewrite context.

    Temporarily replace each entry of `origin_func_path` (a dotted path or the
    function object itself) with the function at the same index of
    `rewrite_func`. The originals are restored on exit, including when the
    managed block raises. If applying one replacement fails, the ones already
    applied are restored before the error propagates.
    """
    assert len(origin_func_path) == len(rewrite_func)
    # (func_path, dst_func, origin_func) for every replacement that succeeded
    applied = []
    try:
        for (func_path, dst_func) in zip(origin_func_path, rewrite_func):
            if isinstance(func_path, Callable):
                origin_func = _set_func(None, dst_func, func_path)
            else:
                origin_func = _set_func(func_path, dst_func)
            applied.append((func_path, dst_func, origin_func))
        yield
    finally:
        for (func_path, dst_func, origin_func) in applied:
            if isinstance(func_path, Callable):
                _set_func(None, origin_func, dst_func)
            else:
                _set_func(func_path, origin_func, dst_func)


def add_device_hook(module: torch.nn.Module, device: torch.device, fn: Callable = None):
    """Attach an accelerate hook that post-processes `module`'s forward output.

    Without `fn` the output is moved to `device`; with `fn` the output is
    passed through `fn` instead (and `device` is not used by the hook).
    """
    from accelerate.hooks import ModelHook, add_hook_to_module

    class ToDevice(ModelHook):
        """ToDevice hook."""

        def __init__(self, device):
            self.device = device

        def post_forward(self, module, output):
            if fn is None:
                return output.to(device=self.device)
            return fn(output)

    hook = ToDevice(device=device)
    add_hook_to_module(module=module, hook=hook, append=True)


================================================
FILE: lmdeploy/vl/model/xcomposer2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import enum
import os
import sys
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import add_device_hook, disable_logging, rewrite_ctx

logger = get_logger('lmdeploy')


def check_xcomposer_install():
    """Ensure `decord` is importable.

    Raises:
        ImportError: with an installation hint; the original import error is
            chained as ``__cause__``.
    """
    try:
        # WARNING! we have to do this otherwise the model_type is wrong for
        # xcomposer2d5
        import decord  # noqa: F401
    except ImportError as e:
        raise ImportError("No module named 'decord'. Please install decord by `pip install decord`") from e


class ModelType(enum.Enum):
    """InternLM-XComposer model variant, as detected by `get_xcomposer_type`.

    XCOMPOSER2: internlm-xcomposer2-7b (fixed-size input).
    XCOMPOSER2_4KHD: internlm-xcomposer2-4khd-7b (``ixc_utils.HD_transform``).
    XCOMPOSER2D5: internlm-xcomposer2d5-7b (``ixc_utils.Image_transform``).
    """
    XCOMPOSER2 = enum.auto()
    XCOMPOSER2_4KHD = enum.auto()
    XCOMPOSER2D5 = enum.auto()


def get_xcomposer_type(model_path: str) -> Tuple[ModelType, Any]:
    """Get xcomposer type.

    Probes the checkpoint's remote code for the variant-specific transform in
    priority order. Returns ``(model_type, transform)`` for the first one that
    loads, or ``(ModelType.XCOMPOSER2, None)`` if none does.
    """
    from transformers.dynamic_module_utils import get_class_from_dynamic_module
    candidates = (
        ('ixc_utils.Image_transform', ModelType.XCOMPOSER2D5),
        ('ixc_utils.HD_transform', ModelType.XCOMPOSER2_4KHD),
    )
    for class_ref, model_type in candidates:
        try:
            transform = get_class_from_dynamic_module(class_ref, model_path)
        except Exception:
            continue
        return model_type, transform
    return ModelType.XCOMPOSER2, None


def _CLIPVisionModel_from_pretrained(vision_tower_name):
    """Stand-in for `CLIPVisionModel.from_pretrained` that only builds from config.

    The model is constructed with `_from_config`; no weights are loaded here.
    """
    from transformers import CLIPVisionConfig, CLIPVisionModel
    vision_config = CLIPVisionConfig.from_pretrained(vision_tower_name)
    return CLIPVisionModel._from_config(vision_config)


@contextmanager
def init_empty_vit(model_path):
    """Skip download vision model.

    Within the context, `transformers.CLIPVisionModel.from_pretrained` builds
    from config only. For xcomposer2d5, the remote-code `get_font` helper is
    also replaced with a no-op.
    """
    targets = ['transformers.CLIPVisionModel.from_pretrained']
    replacements = [_CLIPVisionModel_from_pretrained]

    model_type, _ = get_xcomposer_type(model_path)
    if model_type == ModelType.XCOMPOSER2D5:
        from transformers.dynamic_module_utils import get_class_from_dynamic_module
        from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME

        # loading the attribute ensures the remote module is imported into sys.modules
        get_class_from_dynamic_module('modeling_internlm_xcomposer2.get_font', model_path)
        folder = model_path.rstrip(os.sep).split(os.sep)[-1]
        module_name = f'{TRANSFORMERS_DYNAMIC_MODULE_NAME}.{folder}.modeling_internlm_xcomposer2'
        targets.append(sys.modules[module_name].get_font)
        replacements.append(lambda: None)

    with rewrite_ctx(targets, replacements):
        yield


@VISION_MODELS.register_module()
class Xcomposer2VisionModel(VisionModel):
    """InternLM-Xcomposer2 vision model.

    The variant (xcomposer2, xcomposer2-4khd or xcomposer2d5) is detected from
    the checkpoint's remote code via `get_xcomposer_type` and selects the image
    preprocessing routine.
    """

    def __init__(self,
                 model_path: str,
                 with_llm: bool = False,
                 max_memory: Dict[int, int] = None,
                 hf_config: AutoConfig = None,
                 backend: str = ''):
        model_path = model_path.rstrip(os.sep)
        super().__init__(model_path, with_llm, max_memory, hf_config, backend)
        check_xcomposer_install()
        self.model_type, self.module = get_xcomposer_type(self.model_path)
        logger.info(f'matching type of {self.model_type}')

    @classmethod
    def match(cls, config: AutoConfig):
        """Check whether the config match the model."""
        arch = config.architectures[0] if config.architectures else None
        target = 'InternLMXComposer2ForCausalLM'
        if arch == target:
            return True
        # `auto_map` may be absent or explicitly None
        auto_map = getattr(config, 'auto_map', None) or {}
        return any(target in v for v in auto_map.values())

    def build_preprocessor(self):
        """Create the image transforms and select the per-variant preprocess function."""

        import torchvision.transforms as transforms
        from torchvision.transforms.functional import InterpolationMode

        if self.model_type in [ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD]:
            # HD variants: the remote-code transform does the resizing
            self.HD_transform = self.module
            self.vis_processor = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ])
            self.preprocess_func = (self._preprocess_2d5
                                    if self.model_type == ModelType.XCOMPOSER2D5 else self._preprocess_4khd_7b)
        else:
            self.vis_processor = transforms.Compose([
                transforms.Resize((self.hf_config.img_size, self.hf_config.img_size),
                                  interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ])
            self.preprocess_func = self._preprocess_7b

    def build_model(self):
        """Build the vision part of a VLM model when backend is turbomind, or
        load the whole VLM model when `self.with_llm==True`"""
        from accelerate import init_empty_weights
        with init_empty_weights(), warnings.catch_warnings(), \
                init_empty_vit(self.model_path):
            warnings.simplefilter('ignore')
            config = self.hf_config
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model.vit.load_model()
            model.vit.resize_pos()
            if hasattr(self.hf_config, 'img_size'):
                model.vit.vision_tower.vision_model.embeddings.image_size = \
                    self.hf_config.img_size
            model.vit.vision_tower.vision_model.post_layernorm.to_empty(device='cpu').half()
            self.vl_model = model
            if not self.with_llm:
                # drop the language model parts; only the vision side is needed
                del model.model
                del model.output

        from accelerate.utils import get_balanced_memory, infer_auto_device_map
        max_memory = get_balanced_memory(model,
                                         max_memory=self.max_memory,
                                         dtype=torch.half,
                                         no_split_module_classes=['CLIPEncoderLayer'])
        device_map = infer_auto_device_map(model,
                                           no_split_module_classes=['CLIPEncoderLayer'],
                                           max_memory=max_memory,
                                           dtype=torch.half)
        # make all tensor on same device for postprocess
        if 'plora_glb_GN' in device_map:
            device_map['plora_sub_GN'] = device_map['plora_glb_GN']

        from accelerate import load_checkpoint_and_dispatch
        with disable_logging():
            load_checkpoint_and_dispatch(model=model,
                                         checkpoint=self.model_path,
                                         device_map=device_map if not self.with_llm else {'': 'cpu'},
                                         no_split_module_classes=['CLIPEncoderLayer'],
                                         dtype=torch.half)

        if 'plora_glb_GN' in device_map:
            # move the last encoder layer's output next to plora_glb_GN
            add_device_hook(model.vit.vision_tower.vision_model.encoder.layers[-1], device_map['plora_glb_GN'],
                            lambda x: (x[0].to(device=device_map['plora_glb_GN']), ))

        self.model = model.eval()

    def _preprocess_2d5(self, image: Image, params: Dict) -> Tuple[torch.Tensor, int]:
        """Image preprocessing for internlm-xcomposer2d5-7b.

        ``params['hd_num']`` (default 24) is forwarded to the HD transform.

        Returns:
            ``(pixel_values, n_token_per_image)``.
        """
        hd_num = params.get('hd_num', 24)
        image = self.HD_transform(image, hd_num=hd_num)
        pixel_values = self.vis_processor(image).unsqueeze(0).half()
        w, h = image.size
        w, h = w // 560, h // 560
        n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20)
        return pixel_values, n_token_per_image

    def _preprocess_7b(self, image: Image, params: Dict) -> Tuple[torch.Tensor, int]:
        """Image preprocessing for internlm-xcomposer2-7b.

        Returns:
            ``(pixel_values, 256)``; the token count is fixed for this variant.
        """
        pixel_values = self.vis_processor(image).unsqueeze(0).half()
        return pixel_values, 256

    def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Tuple[torch.Tensor, int]:
        """Image preprocessing for internlm-xcomposer2-4khd-7b.

        Returns:
            ``(pixel_values, n_token_per_image)``.
        """
        image = self.HD_transform(image, hd_num=25)
        pixel_values = self.vis_processor(image).unsqueeze(0).half()
        w, h = image.size
        w, h = w // 336, h // 336
        n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
        return pixel_values, n_token_per_image

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess() for spec."""
        images = self.collect_multimodal_items(messages)
        outputs = []
        for _, image, params in images:
            image = image.convert('RGB')
            pixel_values, n_token = self.preprocess_func(image, params)
            outputs.append(
                dict(pixel_values=pixel_values,
                     image_size=image.size,
                     image_tokens=n_token,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=outputs))
        return messages

    @torch.no_grad()
    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
        """Extract image feature. ONLY implement it when the backend is
        turbomind engine.

        Args:
            messages(List[Dict]): the outputs of `preprocess`
            max_batch_size(int): the max batch size when forwarding vision
                model
        Return:
            the message list with forwarding results included
        """
        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
        inputs = inputs[0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            if self.model_type in [ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD]:
                # HD variants: the vit takes a list of images and returns split sizes
                pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
                embeds, split = self.model.vit(pixel_values, self.model.plora_glb_GN, self.model.plora_sub_GN)
                embeds = self.model.vision_proj(embeds)
                embeds = torch.split(embeds, split, dim=1)
                embeds = [x.squeeze() for x in embeds]
            else:
                pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]
                pixel_values = torch.cat(pixel_values, dim=0)
                logger.info(f'vision forward shape: {pixel_values.shape}')
                embeds = self.model.vit(pixel_values)
                embeds = self.model.vision_proj(embeds)
                embeds = torch.split(embeds, 1, dim=0)
                embeds = [x.squeeze() for x in embeds]
            outputs.extend(embeds)
        messages.append(dict(role='forward', content=outputs))
        return messages

    @staticmethod
    def proc_messages(messages, chat_template, sequence_start, model_type):
        """Apply chat template to get the prompt.

        If a user message's text does not already contain ``<IMAGE_TOKEN>``,
        one placeholder per image is inserted. For xcomposer2d5 with a single
        image, the placeholder is instead prepended to the whole templated
        prompt.

        Returns:
            tuple: ``(prompt, IMAGE_TOKEN)``.
        """
        prompt_messages = []
        # image placeholder; must be non-empty, since it is used for membership
        # tests and as a split separator downstream
        IMAGE_TOKEN = '<IMAGE_TOKEN>'
        prefix_image_token = ''
        for message in messages:
            if isinstance(message['content'], str):
                prompt_messages.append(message)
                continue
            elif message['role'] in ['images', 'preprocess', 'forward']:
                continue
            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
            content = [item['text'] for item in message['content'] if item['type'] == 'text']
            # image-only messages carry no text item
            text = content[0] if content else ''
            if IMAGE_TOKEN not in text:
                if model_type == ModelType.XCOMPOSER2D5:
                    if n_images == 1:
                        prefix_image_token, prompt = IMAGE_TOKEN, text
                    else:
                        prompt = ''.join([f'Image{i+1} {IMAGE_TOKEN}; ' for i in range(n_images)]) + text
                else:
                    prompt = ''.join([IMAGE_TOKEN] * n_images) + text
            else:
                prompt = text
            prompt_messages.append(dict(role='user', content=prompt))
        prompt = prefix_image_token + chat_template.messages2prompt(prompt_messages, sequence_start)
        return prompt, IMAGE_TOKEN

    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Return the information needed by the pytorch engine."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, self.model_type)
        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
        """Return the information needed by the turbomind engine."""
        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, self.model_type)
        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)


================================================
FILE: lmdeploy/vl/model/yi.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.

import os
from contextlib import contextmanager
from os import path as osp
from typing import Dict, List

import torch.nn as nn
from transformers import AutoConfig

from lmdeploy.vl.model.base import VISION_MODELS
from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install, process_images

from .utils import disable_transformers_logging, rewrite_ctx

# Checkpoint directory of the Yi model currently being built. Set by
# `YiVisionModel.build_model` and read by `_build_vision_tower` to resolve a
# vision tower path given relative to the checkpoint. Module-global, so
# presumably not safe for building several models concurrently — confirm.
_model_path = None


def _build_vision_projector(config, delay_load=False, **kwargs):
    """Build yi projector."""
    # copy from https://github.com/01-ai/Yi/blob/main/VL/llava/model/multimodal_projector/builder.py # noqa: E501
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    import re
    use_norm = False
    if '_Norm' in projector_type:
        use_norm = True
        projector_type = projector_type.replace('_Norm', '')
    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        if use_norm:
            modules = [
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size),
            ]
        else:
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            if use_norm:
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
                modules.append(nn.LayerNorm(config.hidden_size))
            else:
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return nn.Identity()

    raise ValueError(f'Unknown projector type: {projector_type}')


def _build_vision_tower(vision_tower_cfg, **kwargs):
    """Build yi vision tower.

    The tower name (``mm_vision_tower``, else ``vision_tower``) is first tried
    relative to the module-global `_model_path`. A CLIPVisionTower is returned
    when the resolved name is an existing path, starts with ``openai`` or
    ``laion``, or contains ``ShareGPT4V``; otherwise ValueError is raised.
    """
    tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    local_candidate = os.path.join(_model_path, tower)
    if os.path.exists(local_candidate):
        tower = local_candidate

    from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower
    if os.path.exists(tower) or tower.startswith(('openai', 'laion')) or 'ShareGPT4V' in tower:
        return CLIPVisionTower(tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {tower}')


@contextmanager
def init_yi_model():
    """Swap llava's projector and vision-tower builders for the Yi versions.

    The replacements are installed via `rewrite_ctx` for the ``with`` block.
    """
    patches = {
        'llava.model.multimodal_projector.builder.build_vision_projector': _build_vision_projector,
        'llava.model.multimodal_encoder.builder.build_vision_tower': _build_vision_tower,
    }
    with rewrite_ctx(list(patches.keys()), list(patches.values())):
        yield


@VISION_MODELS.register_module()
class YiVisionModel(LlavaVisionModel):
    """Vision model for Yi-VL checkpoints.

    Yi-VL is built on the llava code base. It is distinguished from plain
    llava by an ``mm_projector_type`` that contains ``'_Norm'``.
    """

    @classmethod
    def match(cls, config: AutoConfig):
        """Return True for a ``LlavaLlamaForCausalLM`` config whose
        ``mm_projector_type`` (default ``'linear'``) contains ``'_Norm'``."""
        architectures = config.architectures
        if not architectures or architectures[0] != 'LlavaLlamaForCausalLM':
            return False
        return '_Norm' in getattr(config, 'mm_projector_type', 'linear')

    def build_preprocessor(self):
        """Load the CLIP image processor and compute tokens per image.

        The tower directory is ``mm_vision_tower`` joined to the model path.
        The token count is ``(image_size // patch_size) ** 2``, plus one when
        ``mm_vision_select_feature`` is ``'cls_patch'``.
        """
        from transformers import CLIPImageProcessor
        tower_dir = osp.join(self.model_path, self.hf_config.mm_vision_tower)
        self.image_processor = CLIPImageProcessor.from_pretrained(tower_dir)
        tower_config = AutoConfig.from_pretrained(tower_dir)
        patches_per_side = tower_config.image_size // tower_config.patch_size
        self.n_token_per_image = patches_per_side**2
        if self.hf_config.mm_vision_select_feature == 'cls_patch':
            # presumably one extra token for the CLS feature
            self.n_token_per_image += 1

    def build_model(self):
        """Build the vision part when the backend is turbomind, or the whole
        VLM when `self.with_llm==True`, using the Yi builder overrides."""
        check_llava_install()

        # _build_vision_tower reads this module-level path to resolve a
        # vision tower directory relative to the model directory.
        global _model_path
        _model_path = self.model_path

        with init_yi_model(), disable_transformers_logging():
            super().build_model()

    def preprocess(self, messages: List[Dict]) -> List[Dict]:
        """Refer to `super().preprocess() for spec.

        Appends a ``role='preprocess'`` message holding, for each image, its
        pixel values, size, token count and image token id.
        """
        items = self.collect_multimodal_items(messages)
        processed = []
        for _modality, img, _params in items:
            rgb_image = img.convert('RGB')
            processed.append(
                dict(pixel_values=process_images([rgb_image], self.image_processor, self.config),
                     image_size=rgb_image.size,
                     image_tokens=self.n_token_per_image,
                     image_token_id=self.image_token_id))
        messages.append(dict(role='preprocess', content=processed))
        return messages


================================================
FILE: lmdeploy/vl/tools/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.


================================================
FILE: lmdeploy/vl/tools/merge_xcomposer2d5_task.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil

import fire
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def main(src_path: str, dst_path: str, task: str):
    """Merge internlm-xcomposer2d5-7b LoRA model weights.

    For every ``PLoRA`` module in the model, the LoRA adapter pair(s) that
    belong to ``task`` are folded into the module's base ``weight``. The
    merged model and tokenizer are then saved to ``dst_path``.

    Args:
        src_path (str): the source model path of internlm-xcomposer2d5-7b
        dst_path (str): the target model path of merged model. Any existing
            directory at this path is removed right before saving.
        task (str): the task of source model, should choose from
            ['web', 'write']

    Raises:
        ValueError: if ``task`` is not supported or ``dst_path`` is the same
            location as ``src_path``. Both are checked before anything is
            loaded or deleted.
    """
    # Attribute prefixes of the LoRA adapters to merge for each task; a
    # prefix `k` refers to the `k_A` / `k_B` submodules of a PLoRA module.
    to_merged = dict(web=['lora_web'], write=['lora_sft', 'lora_dpo'])
    if task not in to_merged:
        raise ValueError(f'Unsupported task {task!r}, should choose from {list(to_merged)}')
    if os.path.realpath(src_path) == os.path.realpath(dst_path):
        # dst_path is wiped before saving, which would destroy the source.
        raise ValueError('dst_path must differ from src_path')
    keys = to_merged[task]

    # load model
    model = AutoModelForCausalLM.from_pretrained(src_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(src_path, trust_remote_code=True)

    # merge lora weight to base model
    @torch.inference_mode()
    def _merge(module: torch.nn.Module, lora_weights: list) -> None:
        """Add sum(B @ A) over ``lora_weights`` to ``module.weight`` (stored as fp16)."""
        # merge lora weight first to reduce precision loss: the deltas are
        # summed in fp32 before the base weight is added and cast to fp16.
        mw = None
        for wa, wb in lora_weights:
            if mw is None:
                mw = (wb.float() @ wa.float())
            else:
                mw += (wb.float() @ wa.float())
        mw += module.weight.float()
        module.weight.data = mw.half()

    def _extract_lora(module: torch.nn.Module, prefixes: list[str]) -> list:
        """Return ``[(A.weight, B.weight), ...]`` for each adapter prefix."""
        lora_weights = []
        for prefix in prefixes:
            wa = getattr(module, f'{prefix}_A').weight
            wb = getattr(module, f'{prefix}_B').weight
            lora_weights.append((wa, wb))
        return lora_weights

    for _, module in tqdm(model.named_modules()):
        if type(module).__name__ == 'PLoRA':
            _merge(module, _extract_lora(module, keys))

    # Remove the old output only now, so a failed load/merge leaves it intact.
    if os.path.exists(dst_path):
        shutil.rmtree(dst_path)

    # save model
    model.save_pretrained(dst_path, torch_dtype=torch.half)
    tokenizer.save_pretrained(dst_path)


# Command-line entry point: python-fire exposes main()'s parameters
# (src_path, dst_path, task) as command-line arguments.
if __name__ == '__main__':
    fire.Fire(main)


================================================
FILE: lmdeploy/vl/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Tuple

import numpy.typing as npt
from PIL import Image

from .media.connection import load_from_url
from .media.image import ImageMediaIO
from .media.time_series import TimeSeriesMediaIO
from .media.video import VideoMediaIO


def load_image(image_url: str, **kwargs) -> Image.Image:
    """Load an image from a URL, a local path, or a base64 string.

    Keyword arguments configure the ``ImageMediaIO`` decoder.
    """
    return load_from_url(image_url, ImageMediaIO(**kwargs))


def load_video(video_url: str, **kwargs) -> Tuple[npt.NDArray, Dict[str, Any]]:
    """Load video frames from a URL, a local path, or a base64 string.

    Keyword arguments configure ``VideoMediaIO``; frames are decoded with a
    default ``ImageMediaIO``. Returns the frames and a metadata dict.
    """
    decoder = VideoMediaIO(image_io=ImageMediaIO(), **kwargs)
    return load_from_url(video_url, decoder)


def load_time_series(ts_url: str, **kwargs) -> npt.NDArray:
    """Load a time series from a URL, a local path, or a base64 string.

    Keyword arguments configure the ``TimeSeriesMediaIO`` decoder.
    """
    return load_from_url(ts_url, TimeSeriesMediaIO(**kwargs))


def encode_image_base64(image: str | Image.Image, format: str = 'PNG', **kwargs) -> str:
    """Return ``image`` encoded as a base64 string in the given ``format``.

    A ``str`` is first loaded with ``load_image``; keyword arguments are
    passed both to that loader and to the encoding ``ImageMediaIO``.
    """
    pil_image = load_image(image, **kwargs) if isinstance(image, str) else image
    encoder = ImageMediaIO(**kwargs)
    return encoder.encode_base64(pil_image, image_format=format)


def encode_video_base64(video: str | npt.NDArray, format: str = 'JPEG', **kwargs) -> str:
    """Return video frames encoded as a base64 string (frames in ``format``).

    A ``str`` is first decoded with ``load_video`` (its metadata is
    discarded); keyword arguments go to both the loader and the encoder.
    """
    if isinstance(video, str):
        frames, _metadata = load_video(video, **kwargs)
    else:
        frames = video
    encoder = VideoMediaIO(image_io=ImageMediaIO(), **kwargs)
    return encoder.encode_base64(frames, video_format=format)


def encode_time_series_base64(data: str | npt.NDArray, **kwargs) -> str:
    """Return a time series encoded as a base64 string.

    A ``str`` is first loaded with ``load_time_series``; keyword arguments go
    to both the loader and the encoding ``TimeSeriesMediaIO``.
    """
    series = load_time_series(data, **kwargs) if isinstance(data, str) else data
    encoder = TimeSeriesMediaIO(**kwargs)
    return encoder.encode_base64(series)


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = [
    "cmake_build_extension",
]
build-backend = "setuptools.build_meta"


================================================
FILE: setup.py
================================================
import os
import re
import subprocess
import sys
from pathlib import Path

from setuptools import find_packages, setup

# Directory containing this setup.py. readme() and get_version() resolve their
# files against it (parse_requirements() opens paths relative to the CWD).
pwd = os.path.dirname(__file__)
# File whose `__version__ = '...'` line get_version() parses.
version_file = 'lmdeploy/version.py'


def get_target_device():
    """Return the build target device from ``LMDEPLOY_TARGET_DEVICE`` (default ``'cuda'``)."""
    device = os.environ.get('LMDEPLOY_TARGET_DEVICE', 'cuda')
    return device


def readme():
    """Return the UTF-8 contents of README.md next to this setup.py."""
    readme_path = os.path.join(pwd, 'README.md')
    with open(readme_path, encoding='utf-8') as fh:
        return fh.read()


def get_version():
    """Return the package version parsed from ``lmdeploy/version.py``.

    The first line matching ``__version__ = '<version>'`` (single quotes,
    characters ``[0-9A-Za-z.-]``) wins.

    Raises:
        RuntimeError: if no such line exists. This used to be an ``assert``,
            which is stripped under ``python -O`` and would make the function
            silently return None.
    """
    file_path = os.path.join(pwd, version_file)
    pattern = re.compile(r"\s*__version__\s*=\s*'([0-9A-Za-z.-]+)'")
    with open(file_path, 'r') as f:
        for line in f:
            m = pattern.match(line)
            if m:
                return m.group(1)
    raise RuntimeError(f'No version found {file_path}')


def get_turbomind_deps():
    """Return the NVIDIA wheel requirements matching the local CUDA toolkit.

    The CUDA major version is read from ``<compiler> --version``. The
    compiler is ``$CUDACXX``, else ``$CMAKE_CUDA_COMPILER``, else ``nvcc``.
    The result is always empty on Windows.

    Returns:
        List[str]: package names, suffixed with ``-cu<major>`` where needed.

    Raises:
        RuntimeError: if the compiler output has no ``release <major>.``
            version string.
    """
    if os.name == 'nt':
        return []

    cuda_compiler = os.getenv('CUDACXX', os.getenv('CMAKE_CUDA_COMPILER', 'nvcc'))
    nvcc_output = subprocess.check_output([cuda_compiler, '--version'], stderr=subprocess.DEVNULL).decode()
    # e.g. "Cuda compilation tools, release 12.4, V12.4.131" -> "12"
    match = re.search(r'release\s+(\d+)\.', nvcc_output)
    if match is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(f'Cannot determine the CUDA version from `{cuda_compiler} --version`:\n{nvcc_output}')
    cuda_major = match.group(1)
    if int(cuda_major) >= 13:
        # From CUDA 13 on, only the NCCL requirement keeps the -cuXX suffix.
        return [
            f'nvidia-nccl-cu{cuda_major}',
            'nvidia-cuda-runtime',
            'nvidia-cublas',
            'nvidia-curand',
        ]
    return [
        f'nvidia-nccl-cu{cuda_major}',
        f'nvidia-cuda-runtime-cu{cuda_major}',
        f'nvidia-cublas-cu{cuda_major}',
        f'nvidia-curand-cu{cuda_major}',
    ]


def parse_requirements(fname='requirements.txt', with_version=True):
    """Parse the package dependencies listed in a requirements file.

    Blank lines and ``#`` comments (whole-line, or inline after whitespace)
    are skipped, and ``-r other.txt`` includes are followed. Environment
    markers (the part after ``;``) are always kept. Version specifiers are
    kept or stripped according to ``with_version``.

    Args:
        fname (str): path to the file. A missing file yields an empty list.
        with_version (bool, default=True): if True include version specs

    Returns:
        List[str]: list of requirements items

    CommandLine:
        python -c "import setup; print(setup.parse_requirements())"
    """
    require_fpath = fname
    # Longest operators first, so e.g. '>=' is not split as '>' + '=1.0'.
    version_op = re.compile(r'(===|==|>=|<=|~=|!=|>|<)')
    # '#' starts a comment at line start or after whitespace; a '#' inside a
    # token (e.g. a URL's '#egg=name' fragment) is kept.
    inline_comment = re.compile(r'(^|\s)#.*$')

    def parse_line(line):
        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line[len('-r '):].strip()
            yield from parse_require_file(target)
            return

        info = {'line': line}
        if line.startswith('-e '):
            info['package'] = line.split('#egg=')[1]
        elif '@git+' in line or ' @ ' in line:
            # Direct references are passed through verbatim.
            info['package'] = line
        else:
            # Split off the environment marker before looking for a version
            # operator, so operators inside the marker (e.g.
            # "sys_platform=='win32'") are not mistaken for a version spec.
            spec, has_marker, marker = line.partition(';')
            parts = [p.strip() for p in version_op.split(spec, maxsplit=1)]
            info['package'] = parts[0]
            if len(parts) > 1:
                op, version = parts[1:]
                info['version'] = (op, version)
            if has_marker:
                info['platform_deps'] = marker.strip()
        yield info

    def parse_require_file(fpath):
        with open(fpath, 'r') as f:
            for raw_line in f:
                line = inline_comment.sub('', raw_line.strip()).strip()
                if line:
                    yield from parse_line(line)

    def gen_packages_items():
        if not os.path.exists(require_fpath):
            return
        for info in parse_require_file(require_fpath):
            parts = [info['package']]
            if with_version and 'version' in info:
                parts.extend(info['version'])
            platform_deps = info.get('platform_deps')
            if platform_deps is not None:
                parts.append(';' + platform_deps)
            yield ''.join(parts)

    return list(gen_packages_items())


# Build the `_turbomind` CMake extension only for CUDA targets, and only when
# DISABLE_TURBOMIND is not set to a truthy value.
if get_target_device() == 'cuda' and os.getenv('DISABLE_TURBOMIND', '').lower() not in ('yes', 'true', 'on', 't', '1'):
    import cmake_build_extension

    ext_modules = [
        cmake_build_extension.CMakeExtension(
            name='_turbomind',
            install_prefix='lmdeploy/lib',
            cmake_depends_on=['pybind11'],
            source_dir=str(Path(__file__).parent.absolute()),
            cmake_generator='Ninja' if os.name != 'nt' else None,
            cmake_build_type=os.getenv('CMAKE_BUILD_TYPE', 'RelWithDebInfo'),
            cmake_configure_options=[
                f'-DPython3_ROOT_DIR={Path(sys.prefix)}',
                f'-DPYTHON_EXECUTABLE={Path(sys.executable)}',
                '-DCALL_FROM_SETUP_PY:BOOL=ON',
                '-DBUILD_SHARED_LIBS:BOOL=OFF',
                # Select the bindings implementation
                '-DBUILD_PY_FFI=ON',
                # Multi-GPU support and NVTX are disabled on Windows.
                f"-DBUILD_MULTI_GPU={'OFF' if os.name == 'nt' else 'ON'}",
                f"-DUSE_NVTX={'OFF' if os.name == 'nt' else 'ON'}",
            ],
        ),
    ]
    extra_deps = get_turbomind_deps()
    cmdclass = {'build_ext': cmake_build_extension.BuildExtension}
else:
    ext_modules, cmdclass, extra_deps = [], {}, []

# Call setup() only when this file is executed as a script, not when it is
# imported (e.g. `import setup` in parse_requirements' CommandLine example).
if __name__ == '__main__':
    setup(
        name='lmdeploy',
        version=get_version(),
        description='A toolset for compressing, deploying and serving LLM',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='OpenMMLab',
        author_email='openmmlab@gmail.com',
        packages=find_packages(exclude=()),
        include_package_data=True,
        setup_requires=parse_requirements('requirements/build.txt'),
        tests_require=parse_requirements('requirements/test.txt'),
        # Device-specific runtime requirements, plus the NVIDIA wheels from
        # get_turbomind_deps() when TurboMind is built (extra_deps is [] otherwise).
        install_requires=parse_requirements(f'requirements/runtime_{get_target_device()}.txt') + extra_deps,
        # Optional extras. Note that 'all' reads a top-level
        # `requirements_<device>.txt`, not a file under requirements/.
        extras_require={
            'all': parse_requirements(f'requirements_{get_target_device()}.txt'),
            'lite': parse_requirements('requirements/lite.txt'),
            'serve': parse_requirements('requirements/serve.txt'),
        },
        classifiers=[
            'Programming Language :: Python :: 3.10',
            'Programming Language :: Python :: 3.11',
            'Programming Language :: Python :: 3.12',
            'Programming Language :: Python :: 3.13',
            'Intended Audience :: Developers',
            'Intended Audience :: Education',
            'Intended Audience :: Science/Research',
        ],
        # Installs the `lmdeploy` command, which dispatches to lmdeploy.cli.run.
        entry_points={'console_scripts': ['lmdeploy = lmdeploy.cli:run']},
        ext_modules=ext_modules,
        cmdclass=cmdclass,
    )


================================================
FILE: src/CMakeLists.txt
================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# All TurboMind sources live under src/turbomind.
add_subdirectory(turbomind)


================================================
FILE: src/turbomind/CMakeLists.txt
================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Component subdirectories; they presumably define the targets linked into
# `turbomind` below (engine, models, core, ...).
add_subdirectory(utils)
add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(comm)
add_subdirectory(generation)
add_subdirectory(models)
add_subdirectory(engine)

# Python bindings, enabled only with BUILD_PY_FFI (setup.py passes -DBUILD_PY_FFI=ON).
if(BUILD_PY_FFI)
    add_subdirectory(python)
endif()

# Aggregate static library built from turbomind.cc. It is compiled as
# position-independent code, presumably so it can be linked into the shared
# Python extension module.
add_library(turbomind STATIC turbomind.cc)
set_property(TARGET turbomind PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(turbomind PUBLIC
        engine
        models
        device_comm
        host_comm
        core
        memory_utils
        nvtx_utils
        CUDA::cublasLt
        CUDA::cudart
        yaml-cpp::yaml-cpp)


================================================
FILE: src/turbomind/comm/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)

find_package(Threads)

# Host communicator library (host_comm.cc, thread_comm.cc).
add_library(host_comm STATIC host_comm.cc thread_comm.cc)
target_link_libraries(host_comm PRIVATE core logger Threads::Threads)
set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON)

# Device communicator front end; CUDA device symbols are resolved at this library.
add_library(device_comm STATIC device_comm.cc)
target_link_libraries(device_comm PRIVATE core logger)
set_property(TARGET device_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET device_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

if (BUILD_MULTI_GPU)
    # Backends are attached as INTERFACE dependencies, so any target linking
    # device_comm / host_comm also links the enabled backends.
    add_subdirectory(cuda_ipc)
    target_link_libraries(device_comm INTERFACE cuda_ipc_comm)

    if (USE_NCCL)
        add_subdirectory(nccl)
        target_link_libraries(device_comm INTERFACE nccl_comm)
    endif ()

    add_subdirectory(gloo)
    target_link_libraries(host_comm INTERFACE gloo_comm)

    if (BUILD_TEST)
        add_executable(test_comm test_comm.cu)
        # NOTE(review): links raw `pthread`, while the other targets use Threads::Threads.
        target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
        # NOTE(review): -march/-mtune=native are host-compiler flags applied to a
        # target whose source is a .cu file; confirm the CUDA compiler accepts them.
        target_compile_options(test_comm PRIVATE -march=native -mtune=native)

        add_executable(test_host_comm test_host_comm.cc)
        target_link_libraries(test_host_comm PRIVATE host_comm core Threads::Threads)
    endif ()
endif ()


================================================
FILE: src/turbomind/comm/barrier.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#if defined(_MSC_VER) && !defined(__clang__)

#include 
#include 
#include 

namespace turbomind::comm {

// Reusable barrier for `count` threads, built on a mutex and a condition
// variable. This variant is compiled for MSVC (_MSC_VER without __clang__).
class Barrier {
public:
    explicit Barrier(int count): threshold_{count}, count_{count} {}

    // Block until `threshold_` threads have arrived. The last thread to arrive
    // advances the phase, resets the counter for the next round and wakes the
    // others. Waiters compare against the phase they arrived in, which also
    // makes the wait robust to spurious wake-ups.
    void arrive_and_wait()
    {
        std::unique_lock lock{mutex_};
        auto             phase = phase_;
        if (--count_ == 0) {
            ++phase_;
            count_ = threshold_;
            cv_.notify_all();
        }
        else {
            cv_.wait(lock, [this, phase] { return phase_ != phase; });
        }
    }

private:
    std::mutex              mutex_;
    std::condition_variable cv_;

    int threshold_;  // number of participating threads per round
    int count_;      // arrivals still missing in the current round

    uint32_t phase_{};  // round counter, incremented when a round completes
};

}  // namespace turbomind::comm

#else

#include 

namespace turbomind::comm {

// Reusable barrier for `count` threads, backed by a POSIX pthread_barrier_t.
class Barrier {
public:
    explicit Barrier(int count): barrier_{}
    {
        // Default barrier attributes; the return code is not checked.
        pthread_barrier_init(&barrier_, {}, count);
    }

    ~Barrier()
    {
        pthread_barrier_destroy(&barrier_);
    }

    // The implicit copy would duplicate an initialized pthread_barrier_t and
    // destroy it twice. Deleting the copy operations also suppresses the
    // implicit move operations.
    Barrier(const Barrier&) = delete;
    Barrier& operator=(const Barrier&) = delete;

    // Block until `count` threads have called this; the barrier is then
    // reusable for the next round.
    void arrive_and_wait()
    {
        pthread_barrier_wait(&barrier_);
    }

private:
    pthread_barrier_t barrier_;
};

}  // namespace turbomind::comm

#endif


================================================
FILE: src/turbomind/comm/cuda_ipc/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)

# CUDA IPC communication backend: one source file per collective
# (allreduce, allgather, fused allreduce variants, broadcast).
add_library(cuda_ipc_comm STATIC
        cuda_ipc_comm.cu
        allreduce.cu
        allgather.cu
        fused_allreduce.cu
        fused_allreduce_ex.cu
        broadcast.cu)

# rms_norm is presumably needed by the fused allreduce kernels; the CUDA driver
# API (CUDA::cuda_driver) is linked explicitly.
target_link_libraries(cuda_ipc_comm PRIVATE
        rms_norm
        host_comm
        core
        cuda_utils
        CUDA::cuda_driver
        logger)

set_property(TARGET cuda_ipc_comm PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET cuda_ipc_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)


================================================
FILE: src/turbomind/comm/cuda_ipc/allgather.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/multimem.cuh"
#include "src/turbomind/comm/cuda_ipc/semaphore.cuh"

#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Device-side barrier kernel, launched by CudaIpcCommImpl::Barrier as one block
// with `ranks` threads. Each thread builds a SystemSemaphore for its
// (blockIdx.x, threadIdx.x) slot, signals and waits with `true` as the Relaxed
// argument, then calls Update() -- presumably to persist the semaphore state
// for the next use (see semaphore.cuh).
__global__ void Barrier_V2(SystemSemaphoreInfo* semaphores, int ranks)
{
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(true);
    sem.Wait(true);
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// Enqueue a barrier across all ranks of `group` on `stream`. The host call only
// launches the kernel; the ranks synchronize when it executes on the device.
void CudaIpcCommImpl::Barrier(int group, cudaStream_t stream)
{
    const int ranks = n_ranks(group);
    // One block with one thread per rank.
    Barrier_V2<<<1, ranks, 0, stream>>>(groups_.at(group).semaphore.handle(), ranks);
}

// Pull-based all-gather kernel, used when no multicast pointer is available.
// `uc[i]` is rank i's copy of the symmetric receive buffer, and each rank's
// contribution is `slice` elements of T at offset `slice * i`. Every rank copies
// each peer's slice from that peer's buffer into its own buffer `uc[rank]`.
template
__global__ void __launch_bounds__(1024, 1) Allgather_Simple_Pull(
    Array uc, SystemSemaphoreInfo* semaphores, int rank, int ranks, int64_t slice, Relaxed relaxed)
{
    // First handshake: presumably ensures every peer's slice is in place before it is read.
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    auto local = uc[rank];

    // Visit peers in rotated order rank+1, rank+2, ... (mod ranks), skipping self.
    for (int i = 1; i < ranks; ++i) {
        const int p  = rank + i < ranks ? rank + i : rank + i - ranks;
        const T*  ch = cvta_generic_to_global(uc[p]);
        // Grid-stride copy of peer p's slice to the same offset in the local buffer.
        for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {
            local[slice * p + idx] = ch[slice * p + idx];
        }
    }

    __syncthreads();

    // Second handshake: presumably keeps peers from reusing their buffers until all ranks have pulled.
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// Multicast all-gather kernel. The body is only compiled when
// TURBOMIND_ARCH_SM90 is set; otherwise the kernel is empty. Each rank writes
// its own slice (offset `slice * rank`) from its unicast buffer `uc` to the
// multicast address `mc` with multimem_st, which presumably delivers it to every
// peer's buffer (see multimem.cuh).
template
__global__ void __launch_bounds__(1024, 1) Allgather_NVLS_V2(
    T* uc, T* mc, SystemSemaphoreInfo* semaphores, int rank, int ranks, int64_t slice, Relaxed relaxed)
{
#if TURBOMIND_ARCH_SM90
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    // Grid-stride loop over this rank's slice.
    for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {
        multimem_st(&mc[slice * rank + idx], uc[slice * rank + idx]);
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
#endif
}

// All-gather over the ranks of `group`. Afterwards every rank's `recvbuff` holds
// all contributions, with rank i's `sendcount` elements at byte offset
// i * bytesize.
//
// NOTE(review): `sendbuff` is never read. The code assumes this rank's data is
// already at recvbuff + rank * bytesize (in-place all-gather). AllGather2D
// enforces the equivalent with TM_CHECK_EQ; this function does not.
void CudaIpcCommImpl::AllGather(
    const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream)
{
    const size_t bytesize = turbomind::byte_size(type) * sendcount;

    const int ranks = this->n_ranks(group);
    const int rank  = this->rank(group);

    auto semaphore = groups_.at(group).semaphore.handle();

    // Kernel path; the argument `t` only selects the element type T used for copying.
    auto invoke = [&](auto t) {
        using T               = decltype(t);
        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);
        const size_t slice    = bytesize / sizeof(T);
        const int    threads  = 1024;
        if (symm_ptr.mc) {
            // Multicast address available: NVLS kernel with at most 4 blocks.
            const int blocks = std::min(4, (slice + threads - 1) / threads);
            Allgather_NVLS_V2<<>>(
                symm_ptr.uc[rank], symm_ptr.mc, semaphore, rank, ranks, slice, std::false_type{});
        }
        else {
            // Fallback: pull from peers' unicast buffers; block count capped via max_ctas_.
            const int blocks = std::min(max_ctas_.apply(32), (slice + threads - 1) / threads);
            Allgather_Simple_Pull
                <<>>(symm_ptr.uc, semaphore, rank, ranks, slice, std::false_type{});
        }
    };

    // Copy-engine path: push this rank's slice into every peer's buffer with
    // cudaMemcpyAsync, with a device-side barrier before and after.
    auto invoke_copy_engine = [&] {
        auto symm_ptr = get_symmetric_v2((char*)recvbuff, group);

        Barrier(group, stream);

        for (int i = 1; i < ranks; ++i) {
            const int p = (rank + i) % ranks;
            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[p] + rank * bytesize,  //
                                             (char*)recvbuff + rank * bytesize,
                                             bytesize,
                                             cudaMemcpyDefault,
                                             stream));
        }

        Barrier(group, stream);
    };

    // Below copy_threshold_: kernel with the widest vector type that evenly
    // divides the byte size. Otherwise: copy engine.
    if (bytesize < copy_threshold_) {
        if (bytesize % sizeof(uint4) == 0) {
            invoke(uint4{});
        }
        else if (bytesize % sizeof(uint2) == 0) {
            invoke(uint2{});
        }
        else if (bytesize % sizeof(uint) == 0) {
            invoke(uint{});
        }
        else {
            TM_CHECK(0) << "not implemented";
        }
    }
    else {
        invoke_copy_engine();
    }
}

// 2D pull-based all-gather kernel. Element (y, x) of rank p's block lies at
// offset `stride * p + y * pitch + x` in every rank's buffer. The block's
// threads are split into 2^log2_groups groups of `threads` threads each. A group
// handles rows bi, bi + bn, ..., and the threads within a group stride across
// columns. `log2_block_dim` is presumably log2(blockDim.x); the caller passes
// constant<10> with 1024 threads.
template
__global__ void __launch_bounds__(1024, 1) Allgather2D_Simple_Pull(T*                   local,
                                                                   Array uc,
                                                                   SystemSemaphoreInfo* semaphores,
                                                                   int                  rank,
                                                                   int                  ranks,
                                                                   int64_t              pitch,
                                                                   int64_t              stride,
                                                                   int                  width,
                                                                   int                  height,
                                                                   int                  log2_groups,
                                                                   constant,
                                                                   Relaxed relaxed)
{
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    // Signal early and wait after the index setup below.
    sem.Signal(relaxed);

    const int log2_threads = log2_block_dim - log2_groups;
    const int threads      = 1 << log2_threads;
    const int groups       = 1 << log2_groups;

    const int gi = threadIdx.x >> log2_threads;   // group index within the block
    const int di = (threadIdx.x & (threads - 1));  // thread index within the group
    const int bi = blockIdx.x * groups + gi;       // global group index (first row)
    const int bn = gridDim.x * groups;             // total number of groups (row step)

    sem.Wait(relaxed);

    __syncthreads();

    // Visit peers in rotated order rank+1, rank+2, ... (mod ranks), skipping self.
    for (int i = 1; i < ranks; ++i) {
        const int     p      = rank + i < ranks ? rank + i : rank + i - ranks;
        const T*      ch     = cvta_generic_to_global(uc[p]);
        const int64_t offset = stride * p;
        for (int x = di; x < width; x += threads) {
            for (int y = bi; y < height; y += bn) {
                local[offset + y * pitch + x] = ch[offset + y * pitch + x];
            }
        }
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// 2D multicast all-gather kernel (the body is compiled only under
// TURBOMIND_ARCH_SM90). Each rank writes its own rows at offset `stride * rank`
// from its unicast buffer to the multicast address via multimem_st.
// Thread-to-row/column mapping is the same as in Allgather2D_Simple_Pull.
template
__global__ void __launch_bounds__(1024, 1) Allgather2D_NVLS_V2(T*                   uc_buf,
                                                               T*                   mc_buf,
                                                               SystemSemaphoreInfo* semaphores,
                                                               int                  rank,
                                                               int                  ranks,
                                                               int64_t              pitch,
                                                               int64_t              stride,
                                                               int                  width,
                                                               int                  height,
                                                               int                  log2_groups,
                                                               constant,
                                                               Relaxed relaxed)
{

#if TURBOMIND_ARCH_SM90

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    const int log2_threads = log2_block_dim - log2_groups;
    const int threads      = 1 << log2_threads;
    const int groups       = 1 << log2_groups;

    const int gi = threadIdx.x >> log2_threads;   // group index within the block
    const int di = (threadIdx.x & (threads - 1));  // thread index within the group
    const int bi = blockIdx.x * groups + gi;       // global group index (first row)
    const int bn = gridDim.x * groups;             // total number of groups (row step)

    __syncthreads();

    const int64_t offset = stride * rank;
    for (int y = bi; y < height; y += bn) {
        for (int x = di; x < width; x += threads) {
            const int64_t idx = offset + y * pitch + x;
            multimem_st(&mc_buf[idx], uc_buf[idx]);
        }
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
#endif
}

// 2D all-gather over the ranks of `group`. Each rank contributes `height` rows
// of `width` elements (row pitch `pitch`), located at element offset
// `rank * stride` of `recvbuff`. `sendbuff` must alias exactly that location,
// which is checked below (in-place all-gather).
void CudaIpcCommImpl::AllGather2D(const void*  sendbuff,
                                  void*        recvbuff,
                                  size_t       pitch,
                                  size_t       stride,
                                  int          width,
                                  int          height,
                                  DataType     type,
                                  int2         flags,
                                  int          group,
                                  cudaStream_t stream)
{
    const size_t byte_width  = byte_size(type, width);
    const size_t byte_pitch  = byte_size(type, pitch);
    const size_t byte_stride = byte_size(type, stride);

    const size_t nbytes = byte_width * height;

    const int ranks = this->n_ranks(group);
    const int rank  = this->rank(group);

    TM_CHECK_EQ((char*)sendbuff, (char*)recvbuff + rank * byte_stride);

    auto semaphore = groups_.at(group).semaphore.handle();

    // Kernel path; the argument `t` only selects the element type T used for copying.
    auto invoke = [&](auto t) {
        using T = decltype(t);

        // Split the 1024 threads into 2^log2_groups row groups: keep halving the
        // threads per group while a group's span (threads * sizeof(T) >> log2_groups
        // bytes) is more than twice the row width.
        const int threads     = 1024;
        int       log2_groups = 0;
        while ((threads * sizeof(T) >> log2_groups) > byte_width * 2) {
            ++log2_groups;
        }
        const int groups = 1 << log2_groups;

        auto symm_ptr = get_symmetric_v2((T*)recvbuff, group);

        // Note: both kernels are launched with std::true_type as Relaxed here,
        // whereas AllGather passes std::false_type.
        if (symm_ptr.mc) {
            // Multicast address available: NVLS kernel with at most 4 blocks.
            const int blocks = std::min(4, (height + groups - 1) >> log2_groups);
            Allgather2D_NVLS_V2<<>>((T*)recvbuff,
                                                                   symm_ptr.mc,
                                                                   semaphore,
                                                                   rank,
                                                                   this->n_ranks(group),
                                                                   byte_pitch / sizeof(T),
                                                                   byte_stride / sizeof(T),
                                                                   byte_width / sizeof(T),
                                                                   height,
                                                                   log2_groups,
                                                                   constant<10>{},
                                                                   std::true_type{});
        }
        else {
            // Fallback: pull from peers' unicast buffers; block count capped via max_ctas_.
            const int blocks = std::min(max_ctas_.apply(48), (height + groups - 1) >> log2_groups);
            Allgather2D_Simple_Pull<<>>((T*)recvbuff,  //
                                                                       symm_ptr.uc,
                                                                       semaphore,
                                                                       rank,
                                                                       ranks,
                                                                       byte_pitch / sizeof(T),
                                                                       byte_stride / sizeof(T),
                                                                       byte_width / sizeof(T),
                                                                       height,
                                                                       log2_groups,
                                                                       constant<10>{},
                                                                       std::true_type{});
        }
    };

    // Copy-engine path: push this rank's rows into every peer's buffer with
    // cudaMemcpy2DAsync, with a device-side barrier before and after.
    auto invoke_copy_engine = [&] {
        auto symm_ptr = get_symmetric_v2((char*)recvbuff, group);

        Barrier(group, stream);

        for (int i = 1; i < ranks; ++i) {
            const int p = (rank + i) % ranks;
            check_cuda_error(cudaMemcpy2DAsync(symm_ptr.uc[p] + rank * byte_stride,
                                               byte_pitch,
                                               (char*)recvbuff + rank * byte_stride,
                                               byte_pitch,
                                               byte_width,
                                               height,
                                               cudaMemcpyDefault,
                                               stream));
        }

        Barrier(group, stream);
    };

    // Below copy_threshold_: kernel with the widest vector type that evenly
    // divides the row width. Otherwise: copy engine.
    if (nbytes < copy_threshold_) {
        if (byte_width % sizeof(uint4) == 0) {
            invoke(uint4{});
        }
        else if (byte_width % sizeof(uint2) == 0) {
            invoke(uint2{});
        }
        else if (byte_width % sizeof(uint) == 0) {
            invoke(uint{});
        }
        else {
            TM_CHECK(0) << "not implemented";
        }
    }
    else {
        invoke_copy_engine();
    }
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/allreduce.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/mscclpp.h"
#include "src/turbomind/comm/cuda_ipc/multimem.cuh"
#include "src/turbomind/comm/cuda_ipc/semaphore.cuh"

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

using mscclpp::LLPacket;

// reduce-scatter + allgather using LL16Packet
//
// One-shot in-place-capable allreduce (sum) exchanged through LLPacket channels, one uint2 payload
// per packet.
//
// Units: `slice` (padded) and `count` (actual) are measured in vectors of `vec_size` elements,
// where one vector is one uint2 (8 bytes of T). The data is split into `ranks` slices of `slice`
// vectors. Because `count` is the real total, the last slice may be short; `n` is clamped with
// min(count, ...), and a non-positive `n` simply skips the loops.
//
// Packet buffer layout. `incoming` is this rank's packet buffer and `outgoing[q]` is peer q's:
//   [0, ranks * slice)              slot q * slice + idx           <- peer q's values of this rank's slice
//   [ranks * slice, 2*ranks*slice)  slot (ranks + q) * slice + idx <- peer q's fully reduced slice q
//
// Grid: blocks are grouped `ctas_per_peer` at a time. Group g = blockIdx.x / ctas_per_peer serves
// peer p = (rank + g + 1) % ranks in the send and final-receive phases, and `bi` is the block's
// index within its group. The reduce phase strides over the whole grid.
//
// `flag` tags every packet written by this call (AllReduceSum passes `flag_++`). `read(flag)` is
// presumably a spin-wait until a packet carrying that flag is present (mscclpp LLPacket semantics;
// confirm in mscclpp.h).
template
__global__ void __launch_bounds__(1024, 1) Allreduce_LL16_V2(T*                          dst,
                                                             const T*                    src,
                                                             LLPacket*                   incoming,
                                                             Array outgoing,
                                                             int                         rank,
                                                             int                         ranks,
                                                             int                         slice,  // padded slice
                                                             int                         count,  // actual count
                                                             uint32_t                    flag,
                                                             CtasPerPeer                 ctas_per_peer)
{

    constexpr int vec_size = sizeof(uint2) / sizeof(T);

    using Vec = Array;

    const int bi = blockIdx.x % ctas_per_peer;
    // peer served by this block group: rank + 1, rank + 2, ... wrapped modulo `ranks`
    const int p  = [&, i = blockIdx.x / ctas_per_peer + 1] { return rank + i < ranks ? rank + i : rank + i - ranks; }();
    // number of valid vectors in peer p's slice (the last slice may be partial)
    const int n  = min(count, p * slice + slice) - p * slice;

    {  // send slice of `src` to peers  (src -> packet0)
        // peer p's slice of our input lands in p's phase-1 region at offset `rank * slice`
        auto chn = outgoing[p] + rank * slice;
        for (int idx = threadIdx.x + bi * blockDim.x; idx < n; idx += ctas_per_peer * blockDim.x) {
            chn[idx].write(*((const uint2*)src + p * slice + idx), flag);
        }
    }

    // device-wide barrier not required as what we are sending is not what we are going to modify

    {  // recv data | reduce | send results (src -> packet0 -> packet1)
        using namespace ops;
        // number of valid vectors in this rank's own slice
        const int n = min(count, rank * slice + slice) - rank * slice;
        for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < n; idx += blockDim.x * gridDim.x) {
            Vec vec;
            Load(vec, src + (rank * slice + idx) * vec_size);
            // add every peer's contribution for this vector (phase-1 packets)
            for (int i = 1; i < ranks; ++i) {
                const int p    = rank + i < ranks ? rank + i : rank + i - ranks;
                uint2     data = incoming[p * slice + idx].read(flag);
                vec            = vec + (Vec&)data;
            }
            Store(dst + (rank * slice + idx) * vec_size, vec);
            // push the reduced vector into every peer's phase-2 region at offset (ranks + rank) * slice
            for (int i = 1; i < ranks; ++i) {
                const int p = rank + i < ranks ? rank + i : rank + i - ranks;
                outgoing[p][(ranks + rank) * slice + idx].write((uint2&)vec, flag);
            }
        }
    }

    {  // recv results (packet1 -> dst)
        // peer p's reduced slice arrives in our phase-2 region at (ranks + p) * slice
        incoming += (ranks + p) * slice;
        dst += p * slice * vec_size;
        // ! note that `dst` MUST have same partition as we are sending `src`
        for (int idx = threadIdx.x + bi * blockDim.x; idx < n; idx += ctas_per_peer * blockDim.x) {
            uint2 data = incoming[idx].read(flag);
            Store(dst + idx * vec_size, (Vec&)data);
        }
    }
}

// Modified from
// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/test/mscclpp-test/allreduce_test.cu#L963
//
// In-place allreduce (sum) by pulling from peers' buffers.
//   `buf`           this rank's buffer (input and output)
//   `chns`          per-rank pointers to the same buffer on every rank (`chns[p]` is peer p's copy;
//                   `chns[rank]` is not used because the peer loops start at i = 1)
//   `slice`/`count` measured in vectors of `vec_size` elements (AllReduceSum uses uint4-sized vectors)
// Phase 1: each rank accumulates every peer's values of its own slice
//          [rank * slice, min(count, rank * slice + slice)) into `buf`.
// Phase 2: each rank copies every peer's reduced slice from that peer's buffer.
// The SystemSemaphore Signal/Wait pairs presumably act as cross-rank barriers between the phases.
// The first two rounds use `relaxed` (AllReduceSum passes false_type); the last one passes a literal `true`.
template
__global__ void Allreduce_Simple_Pull(T*                   buf,
                                      Array chns,
                                      SystemSemaphoreInfo* semaphores,
                                      int                  rank,
                                      int                  ranks,
                                      int                  slice,
                                      int                  count,
                                      constant,
                                      Relaxed relaxed)
{
    const int block_num  = gridDim.x;
    const int thread_num = blockDim.x * block_num;
    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    using Vec = Array;

    using namespace ops;

    // this rank's slice, clamped to the real vector count
    const int first = rank * slice;
    const int last  = min(count, first + slice);

    // phase 1: buf[slice(rank)] += peer p's buf[slice(rank)] for every peer p
    for (int i = 1; i < ranks; ++i) {
        const int p   = rank + i < ranks ? rank + i : rank + i - ranks;
        auto      chn = cvta_generic_to_global(chns[p]);
        for (int idx = first + thread_idx; idx < last; idx += thread_num) {
            Vec acc, tmp;
            Load(tmp, chn + idx * vec_size);
            Load(acc, buf + idx * vec_size);
            acc = acc + tmp;
            Store(buf + idx * vec_size, acc);
        }
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    // phase 2: fetch each peer's reduced slice from that peer
    for (int i = 1; i < ranks; ++i) {
        const int p     = rank + i < ranks ? rank + i : rank + i - ranks;
        const int first = p * slice;
        const int last  = min(count, first + slice);
        auto      chn   = cvta_generic_to_global(chns[p]);
        for (int idx = first + thread_idx; idx < last; idx += thread_num) {
            Vec vec;
            Load(vec, chn + idx * vec_size);
            Store(buf + idx * vec_size, vec);
        }
    }

    __syncthreads();

    sem.Signal(true);
    sem.Wait(true);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// In-place allreduce (sum) that pushes data into peers' buffers.
//   `buf` / `scratch`            this rank's data buffer and scratch buffer
//   `symm_buf` / `symm_scratch`  per-rank pointers to those buffers on every rank
//   `slice` / `count`            measured in vectors of `vec_size` elements
// Phase 1: write peer p's slice of `buf` into p's scratch at offset `rank * slice`.
// Phase 2 (after a semaphore round): sum this rank's slice with the partials peers left in
//          `scratch` at `p * slice`, store the sum into `buf`, and write it into every peer's `buf`
//          at the same offset.
// A final semaphore round (literal `true`) precedes the semaphore update.
template
__global__ void Allreduce_Simple_Push_v3(T*                   buf,
                                         T*                   scratch,
                                         Array symm_buf,
                                         Array symm_scratch,
                                         SystemSemaphoreInfo* semaphores,
                                         int                  rank,
                                         int                  ranks,
                                         int                  slice,  // in vec
                                         int                  count,  // in vec
                                         constant,
                                         Relaxed relaxed)
{
    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;
    const int thread_num = blockDim.x * gridDim.x;

    using Vec = Array;

    // phase 1: scatter peers' slices of our input into their scratch buffers
    for (int i = 1; i < ranks; ++i) {
        const int p = rank + i < ranks ? rank + i : rank + i - ranks;
        const int n = min(count, p * slice + slice) - p * slice;
        for (int idx = thread_idx; idx < n; idx += thread_num) {
            Vec vec;
            Load(vec, buf + (p * slice + idx) * vec_size);
            Store(symm_scratch[p] + (rank * slice + idx) * vec_size, vec);
        }
    }

    __syncthreads();

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    using namespace ops;
    // phase 2: reduce our slice from the local scratch, then broadcast the result to every peer
    const int n = min(count, rank * slice + slice) - rank * slice;
    for (int idx = thread_idx; idx < n; idx += thread_num) {
        Vec acc;
        Load(acc, buf + (rank * slice + idx) * vec_size);
        for (int i = 1; i < ranks; ++i) {
            const int p = rank + i < ranks ? rank + i : rank + i - ranks;
            Vec       tmp;
            Load(tmp, scratch + (p * slice + idx) * vec_size);
            acc = acc + tmp;
        }
        Store(buf + (rank * slice + idx) * vec_size, acc);
        for (int i = 1; i < ranks; ++i) {
            const int p = rank + i < ranks ? rank + i : rank + i - ranks;
            Store(symm_buf[p] + (rank * slice + idx) * vec_size, acc);
        }
    }

    __syncthreads();

    sem.Signal(true);
    sem.Wait(true);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// In-place allreduce (sum) through a multicast mapping. The body is compiled only when
// TURBOMIND_ARCH_SM90 is set; otherwise the kernel is empty.
//   `mc_buf`       multicast address of the symmetric buffer
//   [first, last)  this rank's range of vectors (`vec_size` elements each)
// Each vector is read with multimem_ld_reduce_sum, which presumably returns the sum over all
// ranks' copies. The result is written with multimem_st, which presumably stores it to every
// rank's copy. Confirm both against multimem.cuh.
template
__global__ void Allreduce_NVLS_V2(
    T* mc_buf, SystemSemaphoreInfo* semaphores, int ranks, int first, int last, constant, Relaxed relaxed)
{
#if TURBOMIND_ARCH_SM90
    const int block_num  = gridDim.x;
    const int thread_num = blockDim.x * block_num;
    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    using Vec = Array;

    using namespace ops;

    for (int idx = first + thread_idx; idx < last; idx += thread_num) {
        Vec vsum = multimem_ld_reduce_sum((const Vec*)(mc_buf + idx * vec_size));
        multimem_st(mc_buf + idx * vec_size, vsum);
    }

    __syncthreads();

    sem.Signal(true);
    sem.Wait(true);

    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
#endif
}

// In-place sum-allreduce of `count` elements of `type` across `group`.
// `sendbuff` must equal `recvbuff` (checked). The buffer is resolved with get_symmetric_v2, so it
// presumably has to be a registered allocation.
//
// Path selection, in priority order:
//   1. NVLS    - when the buffer has a multicast pointer (`symm_ptr.mc` non-null)
//   2. LL16    - when `bytesize` rounded up to a multiple of 2 * n_ranks * sizeof(LLPacket) is at most
//                min(1 MB, kPacketBuffSize)
//   3. Push v3 - when `bytesize` rounded up to a multiple of n_ranks * sizeof(uint4) is at most
//                min(6 MB, kScratchBuffSize)
//   4. Pull    - otherwise
// Every path works on vectors (uint4, or uint2 for LL16) split into n_ranks slices.
//
// NOTE(review): every path passes `count / vec_size`. When `count` is not a multiple of the vector
// width, the trailing elements are not reduced. Confirm that callers always pass padded counts.
void CudaIpcCommImpl::AllReduceSum(
    const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream)
{
    FT_CHECK(sendbuff == recvbuff);

    void* data = recvbuff;

    const int n_ranks = this->n_ranks(group);
    const int rank    = this->rank(group);

    auto semaphore = groups_.at(group).semaphore.handle();

    // `t` only carries the element type selected by TM_DISPATCH_PRIMARY_DTYPES below
    auto invoke = [&](auto t) {
        using T               = decltype(t);
        const size_t bytesize = sizeof(T) * count;

        auto symm_ptr = get_symmetric_v2((T*)data, group);

        if (symm_ptr.mc) {
            // NVLS: this rank reduces vectors [first, last) through the multicast address
            constexpr int vec_size = sizeof(uint4) / sizeof(T);
            constexpr int threads  = 1024;
            const int     slice    = (count / vec_size + n_ranks - 1) / n_ranks;
            const int     first    = rank * slice;
            const int     last     = std::min(count / vec_size, first + slice);
            const int     max_ctas = max_ctas_.apply(8);
            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);
            Allreduce_NVLS_V2<<>>(symm_ptr.mc,  //
                                                              semaphore,
                                                              n_ranks,
                                                              first,
                                                              last,
                                                              constant{},
                                                              std::false_type{});
        }
        // LL16 small-message path; compiled in by the `#if 1` below
#if 1
        else if (round_up(bytesize, 2 * n_ranks * sizeof(LLPacket)) <= std::min(1 << 20, kPacketBuffSize)) {
            constexpr int vec_size      = sizeof(uint2) / sizeof(T);
            const int     slice         = (count / vec_size + n_ranks - 1) / n_ranks;
            constexpr int ctas_per_peer = 4;
            constexpr int threads       = 1024;
            // one group of `ctas_per_peer` blocks per peer
            const int     blocks        = (n_ranks - 1) * ctas_per_peer;
            auto          incoming      = (LLPacket*)packet_buff_;
            auto          outgoing      = get_symmetric_v2(incoming, group).uc;
            // `flag_++` gives each LL16 call a new packet tag
            Allreduce_LL16_V2<<>>((T*)data,  //
                                                              (T*)data,
                                                              incoming,
                                                              outgoing,
                                                              rank,
                                                              n_ranks,
                                                              slice,
                                                              count / vec_size,
                                                              flag_++,
                                                              constant{});
        }
#endif
        else if (round_up(bytesize, n_ranks * sizeof(uint4)) <= std::min(6 << 20, kScratchBuffSize)) {
            // push partial slices into peers' scratch buffers, then push the reduced slice back
            constexpr int vec_size = sizeof(uint4) / sizeof(T);
            constexpr int threads  = 1024;
            const int     slice    = (count / vec_size + n_ranks - 1) / n_ranks;
            const int     max_ctas = max_ctas_.apply(48);
            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);
            Allreduce_Simple_Push_v3<<>>((T*)data,
                                                                     (T*)scratch_buff_,
                                                                     symm_ptr.uc,
                                                                     get_symmetric_v2((T*)scratch_buff_, group).uc,
                                                                     semaphore,
                                                                     rank,
                                                                     n_ranks,
                                                                     slice,
                                                                     count / vec_size,
                                                                     constant{},
                                                                     std::false_type{});
        }
        else {
            // large messages: pull peers' data directly from their symmetric buffers
            constexpr int vec_size = sizeof(uint4) / sizeof(T);
            constexpr int threads  = 1024;
            const int     slice    = (count / vec_size + n_ranks - 1) / n_ranks;
            const int     max_ctas = max_ctas_.apply(48);
            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);
            Allreduce_Simple_Pull<<>>((T*)data,
                                                                  symm_ptr.uc,
                                                                  semaphore,
                                                                  rank,
                                                                  n_ranks,
                                                                  slice,
                                                                  count / vec_size,
                                                                  constant{},
                                                                  std::false_type{});
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(type, invoke);
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/bootstrap.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include "src/turbomind/comm/barrier.h"
#include "src/turbomind/comm/device_comm.h"

namespace turbomind::comm {

// Inspired by
// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/include/mscclpp/core.hpp#L31
//
// Bootstrap for ranks that share one address space (e.g. threads of one process).
// All ranks hold the same `State`. Point-to-point messages travel through per-(sender, receiver)
// queues, and collectives use a shared barrier. Peers' buffers are read directly through raw
// pointers, so every rank must live in the same process. The world size is also reported as the
// per-node rank count.
class LocalBootstrap {
public:
    // State shared by all `n` ranks.
    struct State {

        explicit State(int n): num(n), barrier(n), ptrs(n), queues(n * n)
        {
            // one mutex per receiving rank, appended in place rather than constructed by size
            for (int i = 0; i < n; ++i) {
                mutexes.emplace_back();
            }
        }

        // presumably a queue of byte vectors; send() pushes a copy of the message bytes
        using Queue = std::queue>;

        // queue carrying messages from rank `from` to rank `to`
        Queue& get_que(int from, int to)
        {
            return queues[from * num + to];
        }

        // number of ranks
        int num;

        // shared barrier for all `num` ranks
        comm::Barrier barrier;

        // per-rank buffer pointers published by allGather()
        std::vector     ptrs;
        // mutexes[r] guards every queue whose receiver is rank r
        std::deque mutexes;
        // num * num queues, indexed by from * num + to
        std::vector     queues;
    };

    LocalBootstrap(int world_size, int rank, std::shared_ptr state):
        world_size_{world_size}, rank_{rank}, state_{state}
    {
    }

    int getRank()
    {
        return rank_;
    }

    int getNranks()
    {
        return world_size_;
    }

    int getNranksPerNode()
    {
        return world_size_;
    }

    // Copy `size` bytes from `data` into the rank_ -> peer queue. `tag` is ignored: messages between
    // one pair of ranks are delivered in FIFO order.
    void send(void* data, int size, int peer, int tag)
    {
        // std::cerr << "send " << size << " " << rank_ << " -> " << peer << " " << tag << "\n";
        std::lock_guard lock{state_->mutexes[peer]};
        auto&           que = state_->get_que(rank_, peer);
        que.push(std::vector((uint8_t*)data, (uint8_t*)data + size));
    }

    // Wait (yielding between polls) for the next message from `peer`, check that its length equals
    // `size`, and copy it into `data`. `tag` is ignored (FIFO order per pair).
    void recv(void* data, int size, int peer, int tag)
    {
        // std::cerr << "recv " << size << " " << rank_ << " <- " << peer << " " << tag << "\n";
        auto& que = state_->get_que(peer, rank_);
        while (true) {
            {
                // the receiver's mutex, the same one send() locks for this queue
                std::lock_guard lock{state_->mutexes[rank_]};
                if (!que.empty()) {
                    FT_CHECK(que.front().size() == (size_t)size);
                    std::copy_n(que.front().begin(), size, (uint8_t*)data);
                    que.pop();
                    return;
                }
            }
            std::this_thread::yield();
        }
    }

    // `allData` holds world_size_ slots of `size` bytes, and this rank's slot (index rank_) must
    // already be filled. Every other slot i is copied from rank i's buffer.
    void allGather(void* allData, int size)
    {
        barrier();

        state_->ptrs[rank_] = allData;

        // all pointers are published past this point
        barrier();

        for (int i = 0; i < world_size_; ++i) {
            if (i == rank_) {
                continue;
            }
            const auto offset = i * (size_t)size;
            std::copy_n((uint8_t*)state_->ptrs[i] + offset, size, (uint8_t*)allData + offset);
        }

        // no rank returns (and may reuse its buffer) until every peer has finished reading it
        barrier();
    }

    void barrier()
    {
        state_->barrier.arrive_and_wait();
    }

private:
    int world_size_;
    int rank_;

    std::shared_ptr state_;
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/broadcast.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/multimem.cuh"
#include "src/turbomind/comm/cuda_ipc/semaphore.cuh"

#include "src/turbomind/comm/cuda_ipc/semaphore.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Broadcast through a multicast mapping. The body is compiled only when TURBOMIND_ARCH_SM90 is set.
// Each rank reads its slice [rank * slice, min(count, rank * slice + slice)) from `uc` (the host
// passes the root's pointer) and stores it through the multicast address `mc` with multimem_st.
// `root` is not referenced inside the kernel. Both semaphore rounds use `relaxed`.
template
__global__ void __launch_bounds__(1024, 1) Broadcast_NVLS_V2(const T*             uc,
                                                             T*                   mc,
                                                             SystemSemaphoreInfo* semaphores,
                                                             int                  rank,
                                                             int                  ranks,
                                                             int                  root,
                                                             int64_t              slice,
                                                             int64_t              count,
                                                             Relaxed              relaxed)
{

#if TURBOMIND_ARCH_SM90
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    // this rank's share of the broadcast, clamped to `count`
    int64_t first = rank * slice;
    int64_t last  = min(first + slice, count);

    for (int64_t idx = first + threadIdx.x + blockIdx.x * blockDim.x; idx < last; idx += blockDim.x * gridDim.x) {
        multimem_st(&mc[idx], uc[idx]);
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
#endif
}

// Broadcast by pulling. Between two semaphore rounds, every non-root rank copies `slice` elements
// from the root's buffer (`uc[root]`) into its own (`uc[rank]`).
// `rank` and `root` are kernel arguments and therefore uniform across the block, so the
// __syncthreads() inside `if (rank != root)` is reached by all threads of a block or by none.
template
__global__ void __launch_bounds__(1024, 1) Broadcast_Simple_Pull(Array uc,
                                                                 SystemSemaphoreInfo* semaphores,
                                                                 int                  rank,
                                                                 int                  ranks,
                                                                 int                  root,
                                                                 int64_t              slice,
                                                                 Relaxed              relaxed)
{
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    auto dst = uc[rank];
    auto src = uc[root];

    if (rank != root) {
        for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {
            dst[idx] = src[idx];
        }
        __syncthreads();
    }

    sem.Signal(relaxed);
    sem.Wait(relaxed);
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// Broadcast with split work. The root's buffer (`uc[root]`) is divided into slices of `slice`
// elements, and each non-root rank handles slice `index` (the host numbers the non-root ranks
// 0 .. ranks - 2). For every element of its slice, a non-root rank reads the root's value once,
// stores it into its own buffer, and forwards it to every other non-root rank, so every rank ends
// up with the whole buffer. The root only takes part in the two semaphore rounds.
template
__global__ void __launch_bounds__(1024, 1) Broadcast_Simple_V2(Array uc,
                                                               SystemSemaphoreInfo* semaphores,
                                                               int                  index,
                                                               int                  rank,
                                                               int                  ranks,
                                                               int                  root,
                                                               int64_t              slice,
                                                               int64_t              count,
                                                               Relaxed              relaxed)
{
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    auto dst = uc[rank];
    auto src = uc[root];

    // slice owned by this rank, clamped to `count`
    int64_t first = index * slice;
    int64_t last  = min(first + slice, count);

    if (rank != root) {
        for (int64_t idx = first + threadIdx.x + blockIdx.x * blockDim.x; idx < last; idx += blockDim.x * gridDim.x) {
            // Read the root's value once and forward that register copy instead of re-reading
            // dst[idx] from global memory for every peer.
            const auto value = src[idx];
            dst[idx]         = value;
            // Start at i = 1: i = 0 selects this rank itself (uc[rank] == dst), which was just written.
            for (int i = 1; i < ranks; ++i) {
                const int p = rank + i < ranks ? rank + i : rank + i - ranks;
                if (p != root) {
                    uc[p][idx] = value;
                }
            }
        }
    }

    __syncthreads();

    sem.Signal(relaxed);
    sem.Wait(relaxed);
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// Broadcasts `count` elements of `type` from `root` to every rank of `group`, in place in the
// symmetric buffer `recvbuff`. None of the algorithms below reads `sendbuff`, so the root's payload
// must already be in the root's `recvbuff`.
//
// `algo` is fixed to 5. The other branches are alternative implementations kept in the source:
//   0: non-root ranks pull the whole buffer from the root with the copy engine
//   1: pipelined chain of 16 slices (rank r pulls slice s from rank r - 1); needs root == 0
//   2: 8-rank copy-engine tree (0->4, {0,4}->{2,6}, even->odd); needs root == 0
//   3: NVLS kernel: each rank multicasts its slice of the root's buffer
//   4: kernel: every non-root rank pulls the whole buffer from the root
//   5: kernel: the buffer is split into `ranks - 1` slices; each non-root rank pulls one slice from
//      the root and forwards it to the other non-root ranks
//   6: 8-rank copy-engine tree scatter of slices, then every rank pushes its slice to all peers
//
// Degenerate calls (a single-rank group or an empty payload) return immediately.
void CudaIpcCommImpl::Broadcast(const void*  sendbuff,  //
                                void*        recvbuff,
                                size_t       count,
                                DataType     type,
                                int          root,
                                int          group,
                                cudaStream_t stream)
{

    const int rank  = this->rank(group);
    const int ranks = this->n_ranks(group);

    const size_t bytesize = turbomind::byte_size(type, count);

    // Nothing to exchange. The group size and `count` are the same on every rank, so all ranks take
    // this exit together. This also keeps algo 5 from dividing by `ranks - 1 == 0` and from
    // launching an empty (0-block) grid.
    if (ranks == 1 || bytesize == 0) {
        return;
    }

    auto semaphore = groups_.at(group).semaphore.handle();

    const int algo = 5;

    if (algo == 0) {
        Barrier(group, stream);
        if (rank != root) {
            SymmetricPtr_V2 symm_ptr = get_symmetric_v2((char*)recvbuff, group);
            check_cuda_error(cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[root], bytesize, cudaMemcpyDefault, stream));
        }
        Barrier(group, stream);
    }
    else if (algo == 1) {
        const int    slices = 16;
        const size_t slice  = bytesize / slices;
        TM_CHECK(bytesize % slices == 0);
        TM_CHECK_EQ(root, 0);
        SymmetricPtr_V2 symm_ptr = get_symmetric_v2((char*)recvbuff, group);
        for (int i = 1; i <= ranks + slices - 2; ++i) {
            Barrier(group, stream);
            int s = i - rank;
            if (0 <= s && s < slices && rank != root) {
                check_cuda_error(cudaMemcpyAsync(
                    (char*)recvbuff + s * slice, symm_ptr.uc[rank - 1] + s * slice, slice, cudaMemcpyDefault, stream));
            }
        }
        Barrier(group, stream);
    }
    else if (algo == 2) {
        SymmetricPtr_V2 symm_ptr = get_symmetric_v2((char*)recvbuff, group);
        TM_CHECK_EQ(ranks, 8);
        TM_CHECK_EQ(root, 0);
        Barrier(group, stream);
        if (rank == 4) {
            check_cuda_error(
                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 4], bytesize, cudaMemcpyDefault, stream));
        }
        Barrier(group, stream);
        if (rank == 2 || rank == 6) {
            check_cuda_error(
                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 2], bytesize, cudaMemcpyDefault, stream));
        }
        Barrier(group, stream);
        if (rank & 1) {
            check_cuda_error(
                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 1], bytesize, cudaMemcpyDefault, stream));
        }
        Barrier(group, stream);
    }
    else if (algo == 3) {
        using T               = uint4;
        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);
        const size_t count    = bytesize / sizeof(T);
        const size_t slice    = cdiv(count, ranks);
        const int    threads  = 1024;
        const int    blocks   = std::min(2, (slice + threads - 1) / threads);
        Broadcast_NVLS_V2<<>>(
            symm_ptr.uc[root], symm_ptr.mc, semaphore, rank, ranks, root, slice, count, std::true_type{});
    }
    else if (algo == 4) {
        using T               = uint4;
        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);
        const size_t slice    = bytesize / sizeof(T);
        const int    threads  = 1024;
        const int    blocks   = std::min(32, (slice + threads - 1) / threads);
        Broadcast_Simple_Pull
            <<>>(symm_ptr.uc, semaphore, rank, ranks, root, slice, std::false_type{});
    }
    else if (algo == 5) {
        // `count` and `slice` are in uint4 units; bytes past the last full uint4 are not copied.
        using T               = uint4;
        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);
        const size_t count    = bytesize / sizeof(T);
        const int    peers    = ranks - 1;
        const size_t slice    = (count + peers - 1) / peers;
        const int    threads  = 1024;
        const int    blocks   = std::min(32, (slice + threads - 1) / threads);
        // numbers the non-root ranks 0 .. ranks - 2; the root's own value is never used by the kernel
        const int    index    = rank >= root ? rank - 1 : rank;
        Broadcast_Simple_V2<<>>(
            symm_ptr.uc, semaphore, index, rank, ranks, root, slice, count, std::false_type{});
    }
    else if (algo == 6) {
        TM_CHECK_EQ(ranks, 8);
        TM_CHECK_EQ(root, 0);
        const auto   symm_ptr = get_symmetric_v2((char*)recvbuff, group);
        const size_t count    = bytesize;
        const size_t slice    = cdiv(count, ranks);

        // 0->4
        // 0->2,       4->6
        // 0->1, 2->3, 4->5, 6->7
        Barrier(group, stream);
        if (rank == 0) {
            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 4] + slice * (rank + 4),
                                             symm_ptr.uc[rank + 0] + slice * (rank + 4),
                                             slice * 4,
                                             cudaMemcpyDefault,
                                             stream));
        }
        Barrier(group, stream);
        if (rank == 0 || rank == 4) {
            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 2] + slice * (rank + 2),
                                             symm_ptr.uc[rank + 0] + slice * (rank + 2),
                                             slice * 2,
                                             cudaMemcpyDefault,
                                             stream));
        }
        Barrier(group, stream);
        if (rank % 2 == 0) {
            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 1] + slice * (rank + 1),
                                             symm_ptr.uc[rank + 0] + slice * (rank + 1),
                                             slice * 1,
                                             cudaMemcpyDefault,
                                             stream));
        }
        Barrier(group, stream);
        // every rank now holds slice `rank`; push it to all peers
        for (int i = 1; i < ranks; ++i) {
            const int p = (rank + i) % ranks;
            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[p] + rank * slice,  //
                                             (char*)recvbuff + rank * slice,
                                             slice,
                                             cudaMemcpyDefault,
                                             stream));
        }
        Barrier(group, stream);
    }
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/common.h
================================================
#pragma once

#include "src/turbomind/kernels/core/array.h"

namespace turbomind::comm {

// Header-level constants for the CUDA-IPC communicator. All of them are `inline constexpr`, so
// every translation unit that includes this header shares one definition, matching kMaxRanks.

// Presumably the capacity of the per-rank pointer arrays (e.g. SymmetricPtr_V2::uc); confirm.
inline constexpr int kMaxRanks        = 8;
// Size of the per-rank packet buffer (packet_buff_); it also bounds the LL16 allreduce path.
inline constexpr int kPacketBuffSize  = 8 << 20;  // 8 MB
// Size of the per-rank scratch buffer (scratch_buff_); it also bounds the push allreduce path.
inline constexpr int kScratchBuffSize = 8 << 20;  // 8 MB
// Upper bound, together with the SM count, on the default CTA limit (max_ctas_).
inline constexpr int kMaxChannels     = 64;

// Symmetric view of one buffer across the ranks of a group, as returned by get_symmetric_v2.
//   uc: one pointer per rank, indexed by group-local rank (kernels read peer p's copy via uc[p])
//   mc: multicast pointer used by the NVLS paths. It may be null; AllReduceSum tests it before
//       choosing the NVLS kernel.
template
struct SymmetricPtr_V2 {
    Array uc;
    T*                   mc;
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/cuda_ipc_comm.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 

#include 

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/kernels/core/math.h"

#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/comm/env.h"
#include "src/turbomind/comm/host_comm.h"

#include "src/turbomind/comm/cuda_ipc/semaphore.h"

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind::comm {

// Environment knobs for the CUDA-IPC communicator, declared as TM_ENV_VAR(prefix, name, default).
// MAX_CTAS: optional cap on CTAs used by collective kernels. The constructor applies an env value
//           only when it is non-zero, and only to lower the SM-count-based default (presumably this
//           variable; confirm).
TM_ENV_VAR(COMM, MAX_CTAS, 0);
// NVLS_ENABLE: presumably gates the multicast (NVLS) capability query in the constructor; default on.
TM_ENV_VAR(COMM, NVLS_ENABLE, 1);
// per-rank send size threshold to use copy engine instead of p2p for all-gather colls
TM_ENV_VAR(COMM, COPY_THRESHOLD, INT64_MAX);

// Creates a sub-group of `group` from the ranks that pass the same `color`, and returns the new
// group's index. New local ranks follow the lexicographic order of (color, key, parent-local rank).
// Existing allocations are registered with the new group, and a zero-initialized semaphore buffer
// is allocated for it.
int CudaIpcCommImpl::Split(int color, int key, int group)
{
    FT_CHECK(color >= 0);
    FT_CHECK(rank(group) >= 0);

    auto& parent = groups_.at(group);

    // every rank publishes (color, key, its local rank in the parent group)
    auto members = comm::AllGather(h_comm_, std::make_tuple(color, key, parent.g2l[global_rank_]));

    // drop ranks that asked for another color; the survivors keep their relative order
    const auto other_color = [&](const auto& m) { return std::get<0>(m) != color; };
    members.erase(std::remove_if(members.begin(), members.end(), other_color), members.end());

    // default tuple ordering: (color, key, parent-local rank)
    std::stable_sort(members.begin(), members.end());

    std::vector l2g;
    std::vector g2l(parent.g2l.size(), -1);

    int new_rank = 0;
    for (const auto& m : members) {
        const int global = parent.l2g.at(std::get<2>(m));
        l2g.push_back(global);
        g2l[global] = new_rank++;
    }

    const int index = groups_.size();

    auto& g = groups_.emplace_back(Group{l2g, g2l});

    // register every existing allocation with the new group
    for (auto& a : allocation_) {
        Register(a, index);
    }

    g.semaphore.Allocate(l2g.size(), g2l[global_rank_], [&](size_t size) {
        auto buf = (uint64_t*)Allocate(size);
        check_cuda_error(cudaMemsetAsync(buf, 0, size));
        check_cuda_error(cudaStreamSynchronize(0));
        Register(buf, size);
        return get_symmetric_v2(buf, index);
    });

    return index;
}

// Builds the device communicator on top of the host communicator `h_comm`.
// The host collectives used here (comm::AllGather / comm::AllReduce) must be entered by every rank
// in the same order, so every rank must construct this object along the same path.
// Steps:
//   - exchange device ordinals
//   - agree on the CTA limit
//   - optionally query multicast support
//   - read the copy threshold
//   - build access descriptors
//   - create group 0 (identity mapping)
//   - allocate and zero the packet and scratch buffers, then the group-0 semaphores
//   - register the buffers
CudaIpcCommImpl::CudaIpcCommImpl(HostComm h_comm):
    h_comm_{h_comm}, global_n_ranks_{h_comm->n_ranks()}, global_rank_{h_comm->rank()}
{
    // NOTE(review): redundant, since h_comm_ is already set in the member-initializer list
    h_comm_ = h_comm;

    const int n_ranks = global_n_ranks_;
    const int rank    = global_rank_;

    // Exchange device ordinals
    ordinals_.resize(n_ranks);
    check_cuda_error(cudaGetDevice(&ordinals_[rank]));
    comm::AllGather(h_comm_, ordinals_.data(), 1);

    // Default CTA cap is min(SM count, kMaxChannels). A non-zero env override (presumably
    // COMM_MAX_CTAS) can only lower it. All ranks must end up with the same value.
    max_ctas_ = {std::min(getSMCount(), kMaxChannels)};
    if (auto v = GetEnv()) {
        max_ctas_.set_value(std::min(v, max_ctas_.value()));
    }
    auto minval = comm::AllReduce(h_comm_, max_ctas_.value(), RedOp::kMin);
    TM_CHECK_EQ(max_ctas_.value(), minval) << "MAX_CTAS set to different values";

    // Multicast is considered only for >= 4 ranks and when the env switch (presumably
    // COMM_NVLS_ENABLE) is on. The capability is min-reduced, so every rank must support it.
#if __CUDACC_VER_MAJOR__ >= 12
    if (global_n_ranks_ >= 4 && GetEnv()) {  // solve 2n-2>n+1 -> n>3
        CUDRVCHECK(
            cuDeviceGetAttribute(&multicast_capability_, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, ordinals_[rank]));
        multicast_capability_ = comm::AllReduce(h_comm_, multicast_capability_, RedOp::kMin);
    }
#endif

    // presumably COMM_COPY_THRESHOLD
    copy_threshold_ = GetEnv();

    // Prepare access descriptors (read/write access from every rank's device)
    alloc_access_descs_.resize(n_ranks);
    for (int r = 0; r < n_ranks; ++r) {
        alloc_access_descs_[r].location.id   = ordinals_[r];
        alloc_access_descs_[r].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        alloc_access_descs_[r].flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    }

    // Initialize group mapping (group 0: local rank == global rank)
    std::vector idxs(n_ranks);
    std::iota(idxs.begin(), idxs.end(), 0);
    auto& g = groups_.emplace_back();
    g.l2g = g.g2l = idxs;

    // Prepare packet buffer (zeroed on stream 0, synchronized below)
    packet_buff_ = Allocate(kPacketBuffSize);
    check_cuda_error(cudaMemsetAsync(packet_buff_, 0, kPacketBuffSize));

    // Prepare scratch buffer (zeroed on stream 0, synchronized below)
    scratch_buff_ = Allocate(kScratchBuffSize);
    check_cuda_error(cudaMemsetAsync(scratch_buff_, 0, kScratchBuffSize));

    // NOTE(review): the destructor does release group semaphores via semaphore.Free(); this TODO may be stale.
    /// TODO: release
    g.semaphore.Allocate(global_n_ranks_, global_rank_, [this](size_t size) {
        auto buf = (uint64_t*)Allocate(size);
        check_cuda_error(cudaMemsetAsync(buf, 0, size));
        check_cuda_error(cudaStreamSynchronize(0));
        Register(buf, size);
        return get_symmetric_v2(buf, 0);
    });

    check_cuda_error(cudaStreamSynchronize(0));

    Register(packet_buff_, kPacketBuffSize);
    Register(scratch_buff_, kScratchBuffSize);
}

// Tears down the communicator. It releases the internal packet/scratch buffers and every group's
// semaphore storage, then reports any allocation that was never passed to `Free`.
CudaIpcCommImpl::~CudaIpcCommImpl()
{
    // Internal buffers are removed from every group's symmetric set before their memory is released
    Deregister(scratch_buff_);
    Deregister(packet_buff_);

    Free(scratch_buff_);
    Free(packet_buff_);

    // Release semaphore storage in reverse creation order (group 0 last)
    for (auto i = (int)groups_.size() - 1; i >= 0; --i) {
        groups_[i].semaphore.Free([this](void* ptr) {
            Deregister(ptr);
            Free(ptr);
        });
    }

    // Remaining entries were obtained from `Allocate` but never freed. They are only reported here.
    // `%zu` is the portable specifier for `size_t`; `%lu` is wrong on LLP64 targets such as Windows,
    // where `unsigned long` is 32-bit.
    for (const auto& a : allocation_) {
        TM_LOG_WARNING("[COMM][%d] Allocation (%p, %zu) is not freed", global_rank_, a.uc_beg, a.size);
    }

    cudaStreamSynchronize(0);
}

// Allocates `size` bytes (rounded up to the required granularity) of pinned device memory on this
// rank's device. The memory is mapped with read/write access for every rank's device, and the
// base address is exchanged with all ranks.
//
// Collective: `comm::AllGather` over `h_comm_` is called, so every rank must call this together.
void* CudaIpcCommImpl::Allocate(size_t size)
{
    // Physical memory: pinned, located on this rank's device
    CUmemAllocationProp alloc_prop{};
    alloc_prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
    alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    alloc_prop.location.id   = ordinals_[global_rank_];

    // With multicast enabled, the buffer must later be bindable to a multicast object spanning
    // every rank, so the multicast granularity is used instead of the allocation granularity
    size_t granularity{};
    if (!multicast_capability_) {
        CUDRVCHECK(cuMemGetAllocationGranularity(&granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    }
    else {
#if __CUDACC_VER_MAJOR__ >= 12
        CUmulticastObjectProp mc_prop{};
        mc_prop.numDevices = alloc_access_descs_.size();
        mc_prop.size       = size;
        CUDRVCHECK(cuMulticastGetGranularity(&granularity, &mc_prop, CU_MULTICAST_GRANULARITY_MINIMUM));
#else
        TM_CHECK(0);
#endif
    }

    const size_t padded = round_up(size, granularity);

    CUmemGenericAllocationHandle handle{};
    CUDRVCHECK(cuMemCreate(&handle, padded, &alloc_prop, 0));

    // Reserve a VA range, back it with the physical allocation, then grant every rank's device access
    CUdeviceptr va{};
    CUDRVCHECK(cuMemAddressReserve(&va, padded, granularity, 0, 0));
    CUDRVCHECK(cuMemMap(va, padded, 0, handle, 0));
    CUDRVCHECK(cuMemSetAccess(va, padded, alloc_access_descs_.data(), alloc_access_descs_.size()));

    void* base = reinterpret_cast<void*>(va);

    Allocation record{};
    record.handle    = handle;
    record.size      = padded;
    record.uc_beg    = base;
    record.uc_end    = static_cast<char*>(base) + padded;
    record.alignment = granularity;
    record.uc_ptrs   = comm::AllGather(h_comm_, record.uc_beg);  // base address on every global rank

    allocation_.emplace(record);

    return base;
}

// Releases an allocation obtained from `Allocate`.
//
// The lookup in `allocation_` matches any address inside [uc_beg, uc_end), as
// `Register`/`Deregister` also do, so `ptr` may be an interior pointer; the whole allocation is
// released. A pointer not owned by this module only produces a warning.
void CudaIpcCommImpl::Free(void* ptr)
{
    auto it = allocation_.find(ptr);
    if (it == allocation_.end()) {
        TM_LOG_WARNING("[TM][COMM][%d] Freeing %p which is not allocated by this module", global_rank_, ptr);
        return;
    }

    const auto& a = *it;
    // Unmap/free must cover exactly the range that was reserved and mapped. That range starts at
    // `uc_beg`, not at `ptr`, which may be an interior address.
    const auto base = reinterpret_cast<CUdeviceptr>(a.uc_beg);
    CUDRVCHECK(cuMemUnmap(base, a.size));
    CUDRVCHECK(cuMemRelease(a.handle));
    CUDRVCHECK(cuMemAddressFree(base, a.size));
    allocation_.erase(it);
}

// Registers the allocation containing `ptr` as symmetric memory in every existing group.
// `size` is only used in diagnostics; the whole enclosing allocation is registered. When multicast
// is enabled, the per-group registration performs an AllGather over `h_comm_`, so all ranks must
// call this together.
void CudaIpcCommImpl::Register(void* ptr, size_t size)
{
    // register for all groups; group 0 serves as the reference for "already registered"
    auto& symm = groups_.at(0).symmetric;

    if (symm.find(ptr) != symm.end()) {
        // `%zu` is the portable format specifier for `size_t`
        TM_LOG_WARNING("[TM][COMM][%d] Duplicated registration on (%p, %zu)", global_rank_, ptr, size);
        return;
    }

    auto alloc = allocation_.find(ptr);
    TM_CHECK(alloc != allocation_.end());

    for (size_t i = 0; i < groups_.size(); ++i) {
        Register(*alloc, i);
    }
}

// Creates the symmetric view of `alloc` for `group` and inserts it into `groups_[group].symmetric`.
// The view records every group member's unicast base address in group-local rank order. When
// multicast is enabled and the group has more than one rank, it also holds a multicast object
// bound to `alloc` and the address it is mapped at.
void CudaIpcCommImpl::Register(const Allocation& alloc, int group)
{
    auto size = alloc.size;

    auto& g = groups_.at(group);

    Symmetric s{};
    s.size   = size;
    s.uc_beg = alloc.uc_beg;
    s.uc_end = alloc.uc_end;

    // `alloc.uc_ptrs` is indexed by global rank; reorder it into group-local order via `l2g`
    for (auto r : g.l2g) {
        s.uc_ptrs.push_back(alloc.uc_ptrs[r]);
    }

    const int ranks = n_ranks(group);
    const int rank  = this->rank(group);

    // NOTE(review): the AllGather below runs over the global `h_comm_`, but whether a rank reaches
    // it depends on the size of *its* group. If a split produces a mix of single-rank and
    // multi-rank groups, only some ranks would enter the collective. Confirm callers never do that.
    if (multicast_capability_ && ranks > 1) {  // ! `cuMulticastCreate` fails for `ranks == 1`
#if __CUDACC_VER_MAJOR__ >= 12
        CUmulticastObjectProp mc_prop{};
        mc_prop.numDevices = ranks;
        mc_prop.size       = size;
        // Only the group-local root creates the multicast object...
        if (rank == 0) {
            CUDRVCHECK(cuMulticastCreate(&s.mc_handle, &mc_prop));
        }
        // ...and every rank adopts the handle created by its group's root (global rank `l2g[0]`)
        auto handles = comm::AllGather(h_comm_, s.mc_handle);
        s.mc_handle  = handles.at(g.l2g[0]);
        // Attach this device, bind the local physical allocation, and map the multicast object
        // into a freshly reserved VA range accessible by this rank's device
        CUDRVCHECK(cuMulticastAddDevice(s.mc_handle, ordinals_[global_rank_]));
        CUDRVCHECK(cuMulticastBindMem(s.mc_handle, 0, alloc.handle, 0, size, 0));
        CUdeviceptr mc_ptr{};
        CUDRVCHECK(cuMemAddressReserve(&mc_ptr, size, alloc.alignment, 0, 0));
        CUDRVCHECK(cuMemMap(mc_ptr, size, 0, s.mc_handle, 0));
        CUDRVCHECK(cuMemSetAccess(mc_ptr, size, &alloc_access_descs_[global_rank_], 1));
        s.mc_ptr = reinterpret_cast(mc_ptr);
        if (rank != 0) {
            // Increase reference count to the original handle so that all handles can be released
            // without explicit synchronization
            CUDRVCHECK(cuMemRetainAllocationHandle(&s.mc_handle, s.mc_ptr));
        }
#else
        TM_CHECK(0);
#endif
    }

    g.symmetric.insert(std::move(s));
}

// Releases the multicast resources attached to a symmetric view by `Register`, then clears the
// multicast handle and pointer. Views without a multicast handle need nothing here.
void CudaIpcCommImpl::Deregister(Symmetric& s)
{
    if (!s.mc_handle) {
        return;
    }
#if __CUDACC_VER_MAJOR__ >= 12
    // Drop the multicast VA mapping first, then detach this device's memory from the multicast
    // object and release this rank's reference to it
    const auto mc_addr = reinterpret_cast<CUdeviceptr>(s.mc_ptr);
    CUDRVCHECK(cuMemUnmap(mc_addr, s.size));
    CUDRVCHECK(cuMemAddressFree(mc_addr, s.size));
    CUDRVCHECK(cuMulticastUnbind(s.mc_handle, ordinals_.at(global_rank_), 0, s.size));
    CUDRVCHECK(cuMemRelease(s.mc_handle));
    s.mc_ptr    = {};
    s.mc_handle = {};
#else
    TM_CHECK(0);
#endif
}

// Removes the symmetric view containing `ptr` from every group and releases its multicast
// resources, if any. The underlying allocation is left in place; use `Free` for that.
// A group without a matching view only produces a warning.
void CudaIpcCommImpl::Deregister(void* ptr)
{
    for (auto& g : groups_) {
        auto& symm = g.symmetric;
        if (auto it = symm.find(ptr); it != symm.end()) {
            // `extract` detaches the node, giving mutable access to the otherwise-const set element
            Deregister(symm.extract(it).value());
        }
        else {
            TM_LOG_WARNING("[TM][COMM][%d] Deregistering non-registered address %p", global_rank_, ptr);
        }
    }
}

// Reports optional capabilities: only `kHasAllGather2D` is supported (returns 1); every other
// attribute yields 0.
int CudaIpcCommImpl::Query(QueryAttr attr) const noexcept
{
    const bool supported = (attr == kHasAllGather2D);
    return supported ? 1 : 0;
}

// Translates `ptr`, an address inside a region registered for `group`, into the address at the
// same offset in every group member's buffer (`uc`, group-local rank order). It also returns the
// matching address in the multicast mapping (`mc`), if the region has one.
// Fails a TM_CHECK if `ptr` is not inside a registered region.
auto CudaIpcCommImpl::get_symmetric_v2_impl(void* ptr, int group) -> SymmetricPtr_V2
{
    auto& g = groups_.at(group);

    auto symm = g.symmetric.find(ptr);
    TM_CHECK(symm != g.symmetric.end());

    auto offset = (char*)ptr - (char*)symm->uc_beg;

    SymmetricPtr_V2 p{};

    // The fixed-capacity `uc` array must hold one pointer per group member
    TM_CHECK_LE((int)symm->uc_ptrs.size(), p.uc.size());

    for (size_t i = 0; i < symm->uc_ptrs.size(); ++i) {
        p.uc[i] = (char*)symm->uc_ptrs[i] + offset;
    }

    if (symm->mc_ptr) {
        p.mc = (char*)symm->mc_ptr + offset;
    }

    return p;
}

// Factory for the CUDA-IPC device communicator. `n_ranks` and `rank` are not consulted; the
// implementation takes both from `h_comm`.
DeviceComm CreateCudaIpcCommunicator(int n_ranks, int rank, HostComm h_comm)
{
    return DeviceComm{std::make_unique<CudaIpcCommImpl>(h_comm)};
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/comm/cuda_ipc/semaphore.h"
#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/comm/host_comm.h"

#include "src/turbomind/kernels/core/array.h"

#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Upper bound on the number of CTAs a collective kernel may launch.
//
// Until `set_value` is called, `value_` is only a ceiling, and `apply` clamps the caller's
// default to it. Once a value has been set explicitly, `apply` returns that value unconditionally.
class MaxCtas {
public:
    // Intentionally implicit: the communicator assigns it with `max_ctas_ = {...}`
    MaxCtas(int value = 0): is_set_{}, value_{value} {}

    // Sets an explicit value that `apply` will always return
    void set_value(int value)
    {
        value_  = value;
        is_set_ = true;
    }

    int value() const
    {
        return value_;
    }

    // Returns the CTA count to use given a kernel-specific `_default`
    int apply(int _default) const
    {
        if (!is_set_) {  // `value_` is max possible value in this case
            return std::min(_default, value_);
        }
        else {
            return value_;
        }
    }

private:
    bool is_set_;
    int  value_;
};

// Device communicator whose buffers are cuMem* (virtual memory management) allocations mapped for
// access by every rank's device. Peers' base addresses are exchanged over the host communicator,
// and buffers can additionally be bound to multicast objects when supported.
class CudaIpcCommImpl: public DeviceCommImpl {
    struct Allocation;
    struct Symmetric;

public:
    ~CudaIpcCommImpl() override;

    explicit CudaIpcCommImpl(HostComm h_comm);

    // Number of ranks in `group`
    int n_ranks(int group) const override
    {
        return groups_.at(group).l2g.size();
    }

    // Local index of this rank within `group`
    int rank(int group) const override
    {
        return groups_.at(group).g2l.at(global_rank_);
    }

    void* Allocate(size_t size) override;

    void Free(void* ptr) override;

    void Register(void* ptr, size_t size) override;

    void Deregister(void* ptr) override;

    int Split(int color, int key, int group) override;

    int Query(QueryAttr attr) const noexcept override;

    void AllReduceSum(
        const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream) override;

    void AllGather(
        const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream) override;

    void Broadcast(const void*  sendbuff,
                   void*        recvbuff,
                   size_t       count,
                   DataType     type,
                   int          root,
                   int          group,
                   cudaStream_t stream) override;

    void
    Gather(const void* sendbuff, void* recvbuff, size_t count, DataType type, int root, int group, cudaStream_t stream);

    void Barrier(int group, cudaStream_t stream);

    void AllreduceResidualBiasRMSnorm(void*        hidden,
                                      void*        residual,
                                      const void*  bias,
                                      const void*  weights,
                                      float        eps,
                                      int          dim,
                                      int          token_num,
                                      DataType     dtype,
                                      int          group,
                                      cudaStream_t stream) override;

    void AllreduceResidualBiasRMSnormEx(void*        hidden,
                                        void*        residual,
                                        const void*  bias,
                                        const void*  weights,
                                        float        eps,
                                        int          dim,
                                        DataType     type,
                                        int          group0,
                                        int          group1,
                                        const int*   local_token_nums,
                                        cudaStream_t stream) override;

    void AllGather2D(const void*  sendbuff,
                     void*        recvbuff,
                     size_t       pitch,
                     size_t       stride,
                     int          width,
                     int          height,
                     DataType     type,
                     int2         flags,
                     int          group,
                     cudaStream_t stream) override;

private:
    // Typed wrapper around `get_symmetric_v2_impl`: casts the multicast pointer and every
    // per-rank unicast pointer back to `T*`
    template
    inline SymmetricPtr_V2 get_symmetric_v2(T* ptr, int group)
    {
        auto               tmp = get_symmetric_v2_impl(ptr, group);
        SymmetricPtr_V2 ret{};
        ret.mc = static_cast(tmp.mc);
        for (int i = 0; i < ret.uc.size(); ++i) {
            ret.uc[i] = static_cast(tmp.uc[i]);
        }
        return ret;
    }

    // Maps `ptr` (inside a region registered for `group`) to the same offset in every member's
    // buffer and, if present, in the multicast mapping
    SymmetricPtr_V2 get_symmetric_v2_impl(void* ptr, int group);

    // Creates the symmetric view of `alloc` for a single group
    void Register(const Allocation& alloc, int group);

    // Releases the multicast resources (if any) attached to a symmetric view
    void Deregister(Symmetric& s);

private:
    HostComm h_comm_;

    int global_n_ranks_;
    int global_rank_;

    std::vector ordinals_;  // device ordinal of every global rank

    // Per-group view of a registered allocation
    struct Symmetric {
        void*              uc_beg;
        void*              uc_end;
        size_t             size;
        std::vector uc_ptrs;  // peers, in group-local rank order
        void*              mc_ptr;  // multicast mapping, null when multicast is not used

        CUmemGenericAllocationHandle mc_handle;

        // Heterogeneous comparisons: a raw address is equivalent to the region [uc_beg, uc_end)
        // that contains it, so `find(ptr)` locates the region holding `ptr`
        friend bool operator<(const Symmetric& a, const Symmetric& b)
        {
            return (char*)a.uc_beg < (char*)b.uc_beg;
        }
        friend bool operator<(const Symmetric& a, void* b)
        {
            return (char*)a.uc_end <= (char*)b;
        }
        friend bool operator<(void* a, const Symmetric& b)
        {
            return (char*)a < (char*)b.uc_beg;
        }
    };

    void*    packet_buff_{};
    void*    scratch_buff_{};
    uint32_t flag_{1};

    // A physical allocation made by `Allocate`
    struct Allocation {
        void*                        uc_beg;
        void*                        uc_end;
        size_t                       size;       // rounded up to `alignment`
        size_t                       alignment;  // granularity used for the allocation
        std::vector           uc_ptrs;  // ranks: base address on every global rank
        CUmemGenericAllocationHandle handle;

        // Same heterogeneous ordering as `Symmetric`: an address is equivalent to the allocation
        // containing it
        friend bool operator<(const Allocation& a, const Allocation& b)
        {
            return (char*)a.uc_beg < (char*)b.uc_beg;
        }
        friend bool operator<(const Allocation& a, void* b)
        {
            return (char*)a.uc_end <= (char*)b;
        }
        friend bool operator<(void* a, const Allocation& b)
        {
            return (char*)a < (char*)b.uc_beg;
        }
    };

    // One read/write access descriptor per global rank's device
    std::vector alloc_access_descs_{};

    // Nonzero when buffers are also bound to multicast objects
    int multicast_capability_{false};

    // All live allocations made by `Allocate`
    std::set> allocation_;

    struct Group {
        std::vector l2g;  // local -> global
        std::vector g2l;  // global -> local (-1 if not a member)

        // Semaphore storage whose handle is passed to this group's kernels
        SystemSemaphoreStorage semaphore;

        // Regions registered for this group
        std::set> symmetric;
    };

    // Group 0 spans all ranks; further groups are appended by `Split`
    std::vector groups_;

    MaxCtas max_ctas_;  // CTA cap for collective kernels
    size_t  copy_threshold_{INT64_MAX};  // overridden from the environment in the constructor
};

// Index arithmetic for a rank among `peers` other ranks. Peer indices enumerate the other ranks
// and skip `rank` itself. `peers` is presumably the total rank count minus one; confirm at callers.
struct Rank {
    int                     rank;
    int                     peers;
    // (i + rank) mod peers, assuming 0 <= i + rank < 2 * peers
    __host__ __device__ int get_next_peer(int i) const
    {
        return i + rank < peers ? i + rank : i + rank - peers;
    }
    // Same rotation, with the peer order reversed
    __host__ __device__ int get_prev_peer(int i) const
    {
        return get_next_peer(peers - 1 - i);
    }
    __host__ __device__ int get_peer_rank(int p) const  // rank of `p`
    {
        return p < rank ? p : p + 1;
    }
    __host__ __device__ int inverse_peer(int p) const  // peer idx of `rank` on peer `p`
    {
        return p < rank ? rank - 1 : rank;
    }
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/fused_allreduce.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 

#include "cub/block/block_reduce.cuh"

#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/group_sum.h"
#include "src/turbomind/comm/cuda_ipc/multimem.cuh"
#include "src/turbomind/comm/cuda_ipc/semaphore.cuh"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/kernels/norm/rms_norm.h"

#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Fused all-reduce + residual add (+ optional bias) + RMSNorm, "pull" variant.
//
// Token rows are partitioned into chunks of `slice` rows; chunk r (rows [r*slice, r*slice+slice)
// clipped to `count`) belongs to rank r. `symm[p]` is read as rank p's copy of the buffer.
// Each rank:
//   1. accumulates its own chunk from every peer except the previous rank (rank - 1) into `buf`,
//   2. folds in that last peer, adds the residual (and bias), writes the new residual to `res`,
//      applies RMSNorm with `weights`, and stores the normalized rows into its chunk of `buf`,
//   3. copies every other rank's finished chunk out of that rank's buffer into `buf`.
// Cross-rank barriers (`SystemSemaphore` Signal/Wait) separate the phases.
//
// Thread layout: each block is split into `groups` sub-groups of `block_dim / groups` threads.
// A sub-group processes one token row at a time; thread `di` handles the `di`-th vector of
// `vec_size` elements, and there are `vdim` vectors per row. Rows are strided by
// `gridDim.x * groups`.
//
// NOTE(review): only this rank's chunk of `res` is read/written; presumably the residual is kept
// sharded by chunk across calls. Confirm with the callers.
// NOTE(review): assumes `ranks > 1`. With a single rank the "last peer" is this rank itself, so
// `buf` would be added to itself.
template
__global__ void AllreduceResidualBiasRMSnorm_Simple_Pull(T*                   buf,
                                                         T*                   res,
                                                         const T*             bias,
                                                         const T*             weights,
                                                         Array symm,
                                                         SystemSemaphoreInfo* semaphores,
                                                         int                  rank,
                                                         int                  ranks,
                                                         int                  slice,
                                                         int                  count,
                                                         int                  vdim,
                                                         float                inv_dim,
                                                         float                eps,
                                                         constant,
                                                         constant,
                                                         constant,
                                                         Relaxed relaxed)
{
    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    // Cross-rank barrier before reading peers' buffers
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    using Vec = Array;

    using namespace ops;

    static_assert(block_dim % groups == 0);
    constexpr int threads = block_dim / groups;

    static_assert(threads % WARP_SIZE == 0);
    constexpr int warps = threads / WARP_SIZE;

    const int xi = threadIdx.x / threads;  // sub-group index within the block
    const int di = threadIdx.x % threads;  // vector index within a row
    const int bi = blockIdx.x * groups + xi;
    const int bn = gridDim.x * groups;

    // Named barrier (id 15 - xi) over only the `threads` threads of this sub-group
    auto syncgroup = [&] {  //
        asm volatile("bar.sync %0, %1;" : : "r"(15 - xi), "r"(threads) : "memory");
    };

    // Row range of this rank's chunk
    const int first = rank * slice;
    const int last  = min(count, first + slice);

    // Phase 1: add peers rank+1 .. rank+ranks-2 (mod ranks) into the local chunk of `buf`
    for (int i = 1; i < ranks - 1; ++i) {
        const int  p   = rank + i < ranks ? rank + i : rank + i - ranks;
        const auto src = cvta_generic_to_global(symm[p]);
        Vec        acc, tmp;
        for (int ti = first + bi; ti < last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            if (di < vdim) {
                Load(tmp, src + idx);
                Load(acc, buf + idx);
                acc = acc + tmp;
                Store(buf + idx, acc);
            }
        }
    }

    Vec b_vec{};
    if (bias && di < vdim) {
        Ldg(b_vec, bias + di * vec_size);
    }

    Vec w_vec;
    if (di < vdim) {
        Ldg(w_vec, weights + di * vec_size);
    }

    // Phase 2: fold in the last peer, update the residual, and normalize
    {
        const int p   = rank > 0 ? rank - 1 : ranks - 1;  // last peer
        auto      chn = cvta_generic_to_global(symm[p]);
        for (int ti = first + bi; ti < last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            Vec       acc, tmp;
            Vec       r_vec{};
            float     sum{};
            if (di < vdim) {
                Load(tmp, chn + idx);
                Load(acc, buf + idx);
                acc = acc + tmp;
                Load(r_vec, res + idx);
                r_vec = r_vec + acc;
                if (bias) {
                    r_vec = r_vec + b_vec;
                }
                Store(res + idx, r_vec);
                PRAGMA_UNROLL
                for (int i = 0; i < vec_size; ++i) {
                    sum += (float)r_vec[i] * (float)r_vec[i];
                }
            }
            // presumably reduces `sum` over the sub-group's `warps` warps
            sum = detail::GroupSum(sum, warps, syncgroup);
            __shared__ float shared_sum[groups];
            if (di == 0) {
                // rsqrt(mean of squares + eps), shared with the rest of the sub-group
                shared_sum[xi] = rsqrtf(sum * inv_dim + eps);
            }
            syncgroup();
            sum = shared_sum[xi];
            if (di < vdim) {
                PRAGMA_UNROLL
                for (int i = 0; i < vec_size; ++i) {
                    r_vec[i] = static_cast(((float)r_vec[i] * sum)) * w_vec[i];
                }
                Store(buf + idx, r_vec);
            }
        }
    }

    __syncthreads();

    // Cross-rank barrier: every chunk is normalized before anyone gathers
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    // Phase 3: gather the normalized chunks of all other ranks
    for (int i = 1; i < ranks; ++i) {
        const int p     = rank + i < ranks ? rank + i : rank + i - ranks;
        const int first = slice * p;
        const int last  = min(count, first + slice);
        auto      src   = cvta_generic_to_global(symm[p]);
        for (int ti = first + bi; ti < last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            if (di < vdim) {
                Vec vec;
                Load(vec, src + idx);
                Store(buf + idx, vec);
            }
        }
    }

    __syncthreads();

    // Final cross-rank barrier, presumably so peers finish reading this rank's buffer before it
    // is reused
    sem.Signal(true);
    sem.Wait(true);

    // presumably persists semaphore state for the next launch
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

// Fused all-reduce + residual add (+ optional bias) + RMSNorm using a multicast mapping.
//
// Each rank handles its own chunk of `slice` token rows (clipped to `count`).
// `multimem_ld_reduce_sum` on the multicast address `mc_buf` supplies the element-wise sum over
// all ranks (presumably a multimem.ld_reduce instruction). The residual, bias and RMSNorm are
// applied locally, and `multimem_st` writes the normalized rows through the multicast address,
// presumably reaching every rank's buffer.
//
// Thread layout: `groups` sub-groups of `block_dim / groups` threads per block. Each sub-group
// handles one row at a time, thread `di` owning the `di`-th vector of `vec_size` elements.
//
// The body is compiled only when TURBOMIND_ARCH_SM90 is set; otherwise the kernel is empty.
// NOTE(review): `uc_buf` is not referenced in the body.
// NOTE(review): only this rank's chunk of `res` is read/written; presumably the residual is kept
// sharded by chunk across calls.
template
__global__ void AllreduceResidualBiasRMSnorm_NVLS(T*                   mc_buf,
                                                  T*                   uc_buf,
                                                  T*                   res,
                                                  const T*             bias,
                                                  const T*             weights,
                                                  SystemSemaphoreInfo* semaphores,
                                                  int                  rank,
                                                  int                  ranks,
                                                  int                  slice,
                                                  int                  count,
                                                  int                  vdim,
                                                  float                inv_dim,
                                                  float                eps,
                                                  constant,
                                                  constant,
                                                  constant,
                                                  Relaxed relaxed)
{

#if TURBOMIND_ARCH_SM90

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    sem.Signal(relaxed);

    static_assert(block_dim % groups == 0);
    constexpr int threads = block_dim / groups;

    static_assert(threads % WARP_SIZE == 0);
    constexpr int warps = threads / WARP_SIZE;

    const int xi = threadIdx.x / threads;
    const int di = threadIdx.x % threads;

    using Vec = Array;

    // Bias/weight loads are issued between Signal and Wait, overlapping the cross-rank barrier
    Vec b_vec{};
    if (bias && di < vdim) {
        Ldg(b_vec, bias + di * vec_size);
    }

    Vec w_vec;
    if (di < vdim) {
        Ldg(w_vec, weights + di * vec_size);
    }

    sem.Wait(relaxed);

    __syncthreads();

    using namespace ops;

    const int bi = blockIdx.x * groups + xi;
    const int bn = gridDim.x * groups;

    // Named barrier (id 15 - xi) over only the `threads` threads of this sub-group
    auto syncgroup = [&] {  //
        asm volatile("bar.sync %0, %1;" : : "r"(15 - xi), "r"(threads) : "memory");
    };

    const int first = rank * slice;
    const int last  = min(count, first + slice);

    for (int ti = first + bi; ti < last; ti += bn) {
        const int idx = (ti * vdim + di) * vec_size;
        float     sum{};
        Vec       vec;
        if (di < vdim) {
            Vec acc = multimem_ld_reduce_sum((const Vec*)(mc_buf + idx));
            Load(vec, res + idx);
            vec = vec + acc;
            if (bias) {
                vec = vec + b_vec;
            }
            Store(res + idx, vec);
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                sum += (float)vec[i] * (float)vec[i];
            }
        }
        // presumably reduces `sum` over the sub-group's `warps` warps
        sum = detail::GroupSum(sum, warps, syncgroup);
        __shared__ float shared_sum[groups];
        if (di == 0) {
            // rsqrt(mean of squares + eps), shared with the rest of the sub-group
            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);
        }
        syncgroup();
        sum = shared_sum[xi];
        if (di < vdim) {
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                vec[i] = static_cast(((float)vec[i] * sum)) * w_vec[i];
            }
            multimem_st(mc_buf + idx, vec);
        }
    }

    __syncthreads();

    // Final cross-rank barrier after all multicast stores
    sem.Signal(true);
    sem.Wait(true);

    // presumably persists semaphore state for the next launch
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);

#endif
}

// Fused all-reduce + residual add (+ optional bias) + RMSNorm, "push" variant through a scratch
// buffer. Chunk r holds `slice` token rows (clipped to `count`) and belongs to rank r.
//   1. Each rank pushes chunk p of its `buf` into peer p's scratch (`symm_scratch[p]`), into the
//      slot reserved for the sending rank, which starts `rank * slice` rows in.
//   2. After a cross-rank barrier, each rank reduces its own chunk: its `buf` plus the slots
//      every peer wrote into its local `scratch`. It then updates `res` (+ bias), applies RMSNorm,
//      and stores the result locally and into every peer's `buf` (`symm_buf[p]`).
//   3. A final cross-rank barrier completes the operation.
//
// Thread layout: `groups` sub-groups of `block_dim / groups` threads per block. Each sub-group
// handles one row at a time, thread `di` owning the `di`-th vector of `vec_size` elements.
//
// NOTE(review): slots are `slice` rows apart, so scratch needs up to `ranks * slice` rows. That
// exceeds `count` rows when `count % ranks != 0`; confirm the host-side size check leaves headroom.
// NOTE(review): only this rank's chunk of `res` is read/written; presumably the residual is kept
// sharded by chunk across calls.
template
__global__ void AllreduceResidualBiasRMSnorm_Simple_Push(T*                   buf,
                                                         T*                   res,
                                                         const T*             bias,
                                                         const T*             weights,
                                                         T*                   scratch,
                                                         Array symm_buf,
                                                         Array symm_scratch,
                                                         SystemSemaphoreInfo* semaphores,
                                                         int                  rank,
                                                         int                  ranks,
                                                         int                  slice,
                                                         int                  count,
                                                         int                  vdim,
                                                         float                inv_dim,
                                                         float                eps,
                                                         constant,
                                                         constant,
                                                         constant,
                                                         Relaxed relaxed)
{
    using Vec = Array;

    using namespace ops;

    static_assert(block_dim % groups == 0);
    constexpr int threads = block_dim / groups;

    static_assert(threads % WARP_SIZE == 0);
    constexpr int warps = threads / WARP_SIZE;

    const int xi = threadIdx.x / threads;
    const int di = threadIdx.x % threads;
    const int bi = blockIdx.x * groups + xi;
    const int bn = gridDim.x * groups;

    // Named barrier (id 15 - xi) over only the `threads` threads of this sub-group
    auto syncgroup = [&] {  //
        asm volatile("bar.sync %0, %1;" : : "r"(15 - xi), "r"(threads) : "memory");
    };

    // Phase 1: push chunk p of the local buffer into this rank's slot in peer p's scratch
    for (int i = 1; i < ranks; ++i) {
        const int  p   = rank + i < ranks ? rank + i : rank + i - ranks;
        const int  n   = min(count, p * slice + slice) - p * slice;
        const auto src = buf + p * slice * vdim * vec_size;
        const auto dst = symm_scratch[p] + rank * slice * vdim * vec_size;
        for (int ti = bi; ti < n; ti += bn) {
            if (di < vdim) {
                Vec vec;
                Load(vec, src + (ti * vdim + di) * vec_size);
                Store(dst + (ti * vdim + di) * vec_size, vec);
            }
        }
    }

    __syncthreads();

    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);

    // Cross-rank barrier: all pushes into this rank's scratch are done
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

    Vec b_vec{};
    if (bias && di < vdim) {
        Ldg(b_vec, bias + di * vec_size);
    }

    Vec w_vec;
    if (di < vdim) {
        Ldg(w_vec, weights + di * vec_size);
    }

    // Rows in this rank's chunk
    const int n = min(count, rank * slice + slice) - rank * slice;

    // Phase 2: reduce, update the residual, normalize, then broadcast the chunk to every peer
    for (int ti = bi; ti < n; ti += bn) {
        const int idx = ((rank * slice + ti) * vdim + di) * vec_size;  // idx into local buffers
        Vec       r_vec{};
        float     sum{};
        if (di < vdim) {
            Vec acc;
            Load(acc, buf + idx);
            for (int i = 1; i < ranks; ++i) {
                const int p = rank + i < ranks ? rank + i : rank + i - ranks;
                Vec       tmp;
                // Slot of sender p in the local scratch
                Load(tmp, scratch + ((p * slice + ti) * vdim + di) * vec_size);
                acc = acc + tmp;
            }
            Load(r_vec, res + idx);
            r_vec = r_vec + acc;
            if (bias) {
                r_vec = r_vec + b_vec;
            }
            Store(res + idx, r_vec);
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                sum += (float)r_vec[i] * (float)r_vec[i];
            }
        }

        // presumably reduces `sum` over the sub-group's `warps` warps
        sum = detail::GroupSum(sum, warps, syncgroup);
        __shared__ float shared_sum[groups];
        if (di == 0) {
            // rsqrt(mean of squares + eps), shared with the rest of the sub-group
            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);
        }
        syncgroup();
        sum = shared_sum[xi];

        if (di < vdim) {
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                r_vec[i] = static_cast(((float)r_vec[i] * sum)) * w_vec[i];
            }
            Store(buf + idx, r_vec);
            for (int i = 1; i < ranks; ++i) {
                const int p = rank + i < ranks ? rank + i : rank + i - ranks;
                Store(symm_buf[p] + ((rank * slice + ti) * vdim + di) * vec_size, r_vec);
            }
        }
    }

    __syncthreads();

    // Final cross-rank barrier after all remote stores
    sem.Signal(true);
    sem.Wait(true);

    // presumably persists semaphore state for the next launch
    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);
}

void CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void*        hidden,
                                                   void*        residual,
                                                   const void*  bias,
                                                   const void*  weights,
                                                   float        eps,
                                                   int          dim,
                                                   int          token_num,
                                                   DataType     dtype,
                                                   int          group,
                                                   cudaStream_t stream)
{

    const size_t elemsize = byte_size(dtype);
    const size_t bytesize = elemsize * token_num * dim;

    const int n_ranks = this->n_ranks(group);
    const int rank    = this->rank(group);

    auto semaphore = groups_.at(group).semaphore.handle();

    auto invoke = [&](auto t, auto groups) {
        using T                = decltype(t);
        auto          symm_ptr = get_symmetric_v2((T*)hidden, group);
        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        const int     slice    = (token_num + n_ranks - 1) / n_ranks;
        const int     count    = token_num;

        if (symm_ptr.mc) {
            constexpr int block_dim = 1024;
            const int     max_ctas  = max_ctas_.apply(8);
            const int     blocks    = std::min((slice + groups - 1) / groups, max_ctas);
            AllreduceResidualBiasRMSnorm_NVLS<<>>(symm_ptr.mc,
                                                                                (T*)hidden,
                                                                                (T*)residual,
                                                                                (const T*)bias,
                                                                                (const T*)weights,
                                                                                semaphore,
                                                                                rank,
                                                                                n_ranks,
                                                                                slice,
                                                                                count,
                                                                                dim / vec_size,
                                                                                1.f / dim,
                                                                                eps,
                                                                                constant{},
                                                                                constant{},
                                                                                groups,
                                                                                std::false_type{});
        }
#if 1
        else if (bytesize <= 1 << 19) {
            return false;
        }
#endif
        else if (bytesize <= kScratchBuffSize && bytesize <= 6 << 20) {
            constexpr int block_dim    = 1024;
            const int     max_ctas     = max_ctas_.apply(48);
            const int     blocks       = std::min((slice + groups - 1) / groups, max_ctas);
            auto          symm_scratch = get_symmetric_v2((T*)scratch_buff_, group).uc;
            AllreduceResidualBiasRMSnorm_Simple_Push<<>>((T*)hidden,
                                                                                       (T*)residual,
                                                                                       (const T*)bias,
                                                                                       (const T*)weights,
                                                                                       (T*)scratch_buff_,
                                                                                       symm_ptr.uc,
                                                                                       symm_scratch,
                                                                                       semaphore,
                                                                                       rank,
                                                                                       n_ranks,
                                                                                       slice,
                                                                                       count,
                                                                                       dim / vec_size,
                                                                                       1.f / dim,
                                                                                       eps,
                                                                                       constant{},
                                                                                       constant{},
                                                                                       groups,
                                                                                       std::false_type{});
        }
        else {
            constexpr int block_dim = 1024;
            const int     max_ctas  = max_ctas_.apply(48);
            const int     blocks    = std::min((slice + groups - 1) / groups, max_ctas);
            AllreduceResidualBiasRMSnorm_Simple_Pull<<>>((T*)hidden,
                                                                                       (T*)residual,
                                                                                       (const T*)bias,
                                                                                       (const T*)weights,
                                                                                       symm_ptr.uc,
                                                                                       semaphore,
                                                                                       rank,
                                                                                       n_ranks,
                                                                                       slice,
                                                                                       count,
                                                                                       dim / vec_size,
                                                                                       1.f / dim,
                                                                                       eps,
                                                                                       constant{},
                                                                                       constant{},
                                                                                       groups,
                                                                                       std::false_type{});
        }

        return true;
    };

    auto dispatch_D = [&](auto t) {
        using T                = decltype(t);
        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        if (dim % vec_size) {
            return false;  // non-aligned
        }
        const int vdim = dim / vec_size;
        if (0) {}
        else if (vdim <= 256) {
            return invoke(t, constant<4>{});
        }
        else if (vdim <= 512) {
            return invoke(t, constant<2>{});
        }
        else if (vdim <= 1024) {
            return invoke(t, constant<1>{});
        }
        return false;  // > 1024 vdim
    };

    auto dispatch = [&]() -> bool { TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D); };

    if (dispatch()) {
        return;
    }

    // fallback
    AllReduceSum(hidden, hidden, token_num * dim, dtype, group, stream);
    invokeResidualBiasRMSNorm(hidden, residual, weights, bias, dtype, dim, token_num, eps, stream);
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/cuda_ipc/common.h"

#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/group_sum.h"
#include "src/turbomind/comm/cuda_ipc/semaphore.cuh"

#include "src/turbomind/comm/cuda_ipc/multimem.cuh"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Fused all-reduce + residual/bias add + RMSNorm for the case where token counts differ between
// rank groups ("V" variant). Implemented over peer-mapped (unicast) buffers as a pull-based
// reduce-scatter, a fused residual/RMSNorm pass over the rows this rank owns, and a push-based
// all-gather.
//
//   buf           this rank's hidden buffer; rows are addressed as `offset + ti`. Rows in
//                 [offset + first, offset + last) hold this rank's contribution on entry and the
//                 normalized result on exit; rows owned by all-gather peers are written by them.
//   res           residual rows, addressed as `ti` (without `offset`); updated in place.
//   bias          optional (may be null) bias added after the reduction.
//   weights       RMSNorm scale.
//   rs_buf        views of the reduce-scatter peers' buffers, indexed by group-local rank.
//   ag_buf        views of the all-gather peers' buffers, indexed by group-local rank.
//   g_semaphores  barrier semaphores shared by the `g_ranks` participating ranks.
//   g_rank        NOTE(review): not read by this kernel.
//   ag_ranges     per-peer row ranges; only read by the disabled `#else` branch below.
//   vdim          row length in vectors of `vec_size` elements; inv_dim is 1 / (row length).
//   relaxed       true selects relaxed ordering for the entry and middle barriers
//                 (see SystemSemaphore::Signal / Wait).
template
__global__ void AllreduceResidualBiasRMSnormV_Simple_Pull(T*                     buf,
                                                          T*                     res,
                                                          const T*               bias,
                                                          const T*               weights,
                                                          Array   rs_buf,
                                                          Array   ag_buf,
                                                          SystemSemaphoreInfo*   g_semaphores,
                                                          int                    rs_rank,
                                                          int                    ag_rank,
                                                          int                    rs_ranks,
                                                          int                    ag_ranks,
                                                          int                    g_rank,
                                                          int                    g_ranks,
                                                          int                    offset,
                                                          int                    first,
                                                          int                    last,
                                                          Array ag_ranges,
                                                          int                    vdim,
                                                          float                  inv_dim,
                                                          float                  eps,
                                                          constant,
                                                          constant,
                                                          constant,
                                                          Relaxed relaxed)
{
    // One semaphore set per block (channel = blockIdx.x); threads with threadIdx.x < g_ranks each
    // signal / wait on one peer.
    SystemSemaphore sem(g_semaphores, g_ranks, blockIdx.x, threadIdx.x);

    // Entry barrier, split so that the bias/weight loads below overlap with the wait.
    sem.Signal(relaxed);

    using Vec = Array;

    using namespace ops;

    // The block is split into `groups` token groups of `threads` threads each; every token group
    // processes one row at a time.
    static_assert(block_dim % groups == 0);
    constexpr int threads = block_dim / groups;

    static_assert(threads % WARP_SIZE == 0);
    constexpr int warps = threads / WARP_SIZE;

    const int xi = threadIdx.x / threads;  // token group within the block
    const int di = threadIdx.x % threads;  // vector index within the row (active while di < vdim)

    Vec b_vec{};
    if (bias && di < vdim) {
        Ldg(b_vec, bias + di * vec_size);
    }

    Vec w_vec;
    if (di < vdim) {
        Ldg(w_vec, weights + di * vec_size);
    }

    sem.Wait(relaxed);

    __syncthreads();

    // Global token-group index and count; rows are distributed round-robin over token groups.
    const int bi = blockIdx.x * groups + xi;
    const int bn = gridDim.x * groups;

    // Reduce-scatter, part 1: add rows [offset + first, offset + last) of every reduce-scatter peer
    // except the last one (rs_rank - 1) into this rank's buffer, visiting peers from rs_rank + 1 with
    // wrap-around. The remaining peer is folded into the normalization pass below.
    for (int i = 1; i < rs_ranks - 1; ++i) {
        const int  p   = rs_rank + i < rs_ranks ? rs_rank + i : rs_rank + i - rs_ranks;
        const auto src = cvta_generic_to_global(rs_buf[p]);
        Vec        acc, tmp;
        for (int ti = offset + first + bi; ti < offset + last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            if (di < vdim) {
                Load(tmp, src + idx);
                Load(acc, buf + idx);
                acc = acc + tmp;
                Store(buf + idx, acc);
            }
        }
    }

    // Named barrier over the `threads` threads of token group xi. Barrier ids count down from 15,
    // presumably to stay clear of barrier 0, which __syncthreads() uses.
    auto syncgroup = [&] {  //
        asm volatile("bar.sync %0, %1;" : : "r"(15 - xi), "r"(threads) : "memory");
    };

    // Reduce-scatter, part 2, fused with the residual/bias add and RMSNorm: for every owned row, add
    // the last peer's contribution, update the residual, then write the normalized row back to buf.
    {
        const T* chn{};
        if (rs_ranks > 1) {
            const int p = rs_rank > 0 ? rs_rank - 1 : rs_ranks - 1;  // last peer
            chn         = cvta_generic_to_global(rs_buf[p]);
        }
        for (int ti = first + bi; ti < last; ti += bn) {
            const int idx = ((offset + ti) * vdim + di) * vec_size;
            Vec       acc, tmp;
            Vec       r_vec{};
            float     sum{};
            if (di < vdim) {
                if (chn) {
                    Load(tmp, chn + idx);
                }
                Load(acc, buf + idx);
                if (chn) {
                    acc = acc + tmp;
                }
                // Residual rows are indexed without `offset`.
                Load(r_vec, res + (ti * vdim + di) * vec_size);
                r_vec = r_vec + acc;
                if (bias) {
                    r_vec = r_vec + b_vec;
                }
                Store(res + (ti * vdim + di) * vec_size, r_vec);
                PRAGMA_UNROLL
                for (int i = 0; i < vec_size; ++i) {
                    sum += (float)r_vec[i] * (float)r_vec[i];
                }
            }
            // Sum of squares over the row; every thread of the token group must take part.
            sum = detail::GroupSum(sum, warps, syncgroup);
            // The group's first thread publishes rsqrt(mean(x^2) + eps) for the whole group.
            __shared__ float shared_sum[groups];
            if (di == 0) {
                shared_sum[xi] = rsqrtf(sum * inv_dim + eps);
            }
            syncgroup();
            sum = shared_sum[xi];
            if (di < vdim) {
                PRAGMA_UNROLL
                for (int i = 0; i < vec_size; ++i) {
                    r_vec[i] = static_cast(((float)r_vec[i] * sum)) * w_vec[i];
                }
                Store(buf + idx, r_vec);
            }
        }
    }

    __syncthreads();

    // Middle barrier (per block index, across the g_ranks ranks): no block starts the all-gather
    // until the same-index block on every rank has finished the pass above.
    sem.Signal(relaxed);
    sem.Wait(relaxed);

    __syncthreads();

#if 1
    // All-gather (push): copy this rank's normalized rows into the same rows of every other
    // all-gather peer's buffer, starting at ag_rank + 1 with wrap-around.
    for (int i = 1; i < ag_ranks; ++i) {
        const int p   = ag_rank + i < ag_ranks ? ag_rank + i : ag_rank + i - ag_ranks;
        auto      dst = cvta_generic_to_global(ag_buf[p]);
        for (int ti = offset + first + bi; ti < offset + last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            if (di < vdim) {
                Vec vec;
                Load(vec, buf + idx);
                Store(dst + idx, vec);
            }
        }
    }
#else
    // Disabled alternative (pull): copy each peer's rows, given by ag_ranges[p], into this rank's
    // buffer.
    for (int i = 1; i < ag_ranks; ++i) {
        const int p              = ag_rank + i < ag_ranks ? ag_rank + i : ag_rank + i - ag_ranks;
        const auto [first, last] = ag_ranges[p];
        auto src                 = cvta_generic_to_global(ag_buf[p]);
        for (int ti = first + bi; ti < last; ti += bn) {
            const int idx = (ti * vdim + di) * vec_size;
            if (di < vdim) {
                Vec vec;
                Load(vec, src + idx);
                Store(buf + idx, vec);
            }
        }
    }
#endif

    __syncthreads();

    // Exit barrier: this block waits until the same-index block on every peer has issued its pushes.
    // Note that Signal(true) / Wait(true) select relaxed ordering.
    sem.Signal(true);
    sem.Wait(true);

    // Persist the per-peer expected counters so the next launch continues from the current counts.
    sem.Update(g_semaphores, g_ranks, blockIdx.x, threadIdx.x);
}

// Multicast (NVLS) variant of the fused reduce-scatter + residual/bias add + RMSNorm +
// all-gather launched by CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx.
//
// For every row ti in [first, last), distributed round-robin over the token groups of the grid:
//   1. load-reduce row `offset + ti` through the reduce-scatter multicast pointer `rs_mc_buf`
//      (PTX multimem.ld_reduce ... add), i.e. the element-wise sum over the mapped ranks;
//   2. res[ti] += sum (+ bias), stored back in place (residual rows carry no `offset`);
//   3. normalize: res[ti] * rsqrt(mean(res[ti]^2) + eps) * weights;
//   4. store the result to row `offset + ti` through the all-gather multicast pointer `ag_mc_buf`.
//
// Entry and exit barriers across the `g_ranks` ranks use one semaphore set per block
// (channel = blockIdx.x). `g_rank` is not read. The body is compiled only when
// TURBOMIND_ARCH_SM90 is set; otherwise the kernel does nothing.
template
__global__ void AllreduceResidualBiasRMSnormV_NVLS(T*                   rs_mc_buf,
                                                   T*                   ag_mc_buf,
                                                   T*                   res,
                                                   const T*             bias,
                                                   const T*             weights,
                                                   SystemSemaphoreInfo* semaphores,
                                                   int                  g_rank,
                                                   int                  g_ranks,
                                                   int                  first,
                                                   int                  last,
                                                   int                  offset,
                                                   int                  vdim,
                                                   float                inv_dim,
                                                   float                eps,
                                                   constant,
                                                   constant,
                                                   constant,
                                                   Relaxed relaxed)
{

#if TURBOMIND_ARCH_SM90

    SystemSemaphore sem(semaphores, g_ranks, blockIdx.x, threadIdx.x);

    // Entry barrier, split so that the bias/weight loads below overlap with the wait.
    sem.Signal(relaxed);

    using Vec = Array;

    using namespace ops;

    // The block is split into `groups` token groups of `threads` threads each.
    static_assert(block_dim % groups == 0);
    constexpr int threads = block_dim / groups;

    static_assert(threads % WARP_SIZE == 0);
    constexpr int warps = threads / WARP_SIZE;

    const int xi = threadIdx.x / threads;  // token group within the block
    const int di = threadIdx.x % threads;  // vector index within the row (active while di < vdim)

    Vec b_vec{};
    if (bias && di < vdim) {
        Ldg(b_vec, bias + di * vec_size);
    }

    Vec w_vec;
    if (di < vdim) {
        Ldg(w_vec, weights + di * vec_size);
    }

    sem.Wait(relaxed);

    __syncthreads();

    // Global token-group index and count; rows are distributed round-robin over token groups.
    const int bi = blockIdx.x * groups + xi;
    const int bn = gridDim.x * groups;

    // Named barrier over the `threads` threads of token group xi (ids count down from 15).
    auto syncgroup = [&] {  //
        asm volatile("bar.sync %0, %1;" : : "r"(15 - xi), "r"(threads) : "memory");
    };

    for (int ti = first + bi; ti < last; ti += bn) {
        const int idx = ((offset + ti) * vdim + di) * vec_size;
        float     sum{};
        Vec       vec;
        if (di < vdim) {
            // Cross-rank sum of this row, read through the multicast mapping.
            Vec acc = multimem_ld_reduce_sum((const Vec*)(rs_mc_buf + idx));
            Load(vec, res + (ti * vdim + di) * vec_size);
            vec = vec + acc;
            if (bias) {
                vec = vec + b_vec;
            }
            Store(res + (ti * vdim + di) * vec_size, vec);
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                sum += (float)vec[i] * (float)vec[i];
            }
        }
        // Sum of squares over the row; every thread of the token group must take part.
        sum = detail::GroupSum(sum, warps, syncgroup);
        // The group's first thread publishes rsqrt(mean(x^2) + eps) for the whole group.
        __shared__ float shared_sum[groups];
        if (di == 0) {
            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);
        }
        syncgroup();
        sum = shared_sum[xi];
        if (di < vdim) {
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                vec[i] = static_cast(((float)vec[i] * sum)) * w_vec[i];
            }
            // Broadcast the normalized row to every rank mapped by the all-gather multicast pointer.
            multimem_st(ag_mc_buf + idx, vec);
        }
    }

    __syncthreads();

    // Exit barrier (Signal(true) / Wait(true) select relaxed ordering).
    sem.Signal(true);
    sem.Wait(true);

    // Persist the per-peer expected counters for the next launch.
    sem.Update(semaphores, g_ranks, blockIdx.x, threadIdx.x);
#endif
}

// Fused all-reduce + residual/bias add + RMSNorm across two communication groups: rows are reduced
// over `group0` and the normalized rows are gathered over `group1`. Token counts may differ
// between groups of `inner_tp` consecutive global ranks (`local_token_nums`).
//
//   hidden            hidden rows of all rank groups back to back; reduced and normalized in place.
//   residual          residual rows, indexed by row within this rank's group (no offset); updated
//                     in place.
//   bias              optional bias (may be null); `weights` is the RMSNorm scale.
//   dim               hidden size in elements. It must be a multiple of 16 bytes / sizeof(dtype)
//                     and at most 1024 such vectors; otherwise dispatch() returns false and the
//                     TM_CHECK at the end fails (there is no fallback path).
//   local_token_nums  token count per group of `inner_tp` consecutive global ranks.
void CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx(void*        hidden,
                                                     void*        residual,
                                                     const void*  bias,
                                                     const void*  weights,
                                                     float        eps,
                                                     int          dim,
                                                     DataType     dtype,
                                                     int          group0,
                                                     int          group1,
                                                     const int*   local_token_nums,
                                                     cudaStream_t stream)
{
    // One of the two groups must be group 0, whose rank, size and semaphores are used below.
    FT_CHECK(group0 * group1 == 0);

    // NOTE(review): g0 is not used below; the lookup only validates that group0 exists.
    const auto& g0 = groups_.at(group0);
    const auto& g1 = groups_.at(group1);

    const int tp0 = n_ranks(group0);
    const int tp1 = n_ranks(group1);

    // Rows are partitioned among `inner_tp` consecutive global ranks sharing one token count.
    const int inner_tp = std::min(tp0, tp1);

    FT_CHECK(tp0 % inner_tp == 0 && tp1 % inner_tp == 0);

    Array offsets{};
    Array firsts{};
    Array lasts{};

    // For every global rank i: offsets[i] = first row of its group inside `hidden`, and
    // [firsts[i], lasts[i]) = its ceil-divided slice of that group's rows, clamped to the count.
    for (int i = 0, offset = 0; i < global_n_ranks_; ++i) {
        const int num   = local_token_nums[i / inner_tp];
        const int slice = (num + inner_tp - 1) / inner_tp;
        const int first = std::min(num, i % inner_tp * slice);
        const int last  = std::min(num, first + slice);

        std::tie(offsets[i], firsts[i], lasts[i]) = std::tie(offset, first, last);

        if ((i + 1) % inner_tp == 0) {
            offset += num;
        }
    }
    const int g_rank = rank(0);

    const int first  = firsts[g_rank];
    const int last   = lasts[g_rank];
    const int offset = offsets[g_rank];

    // Cross-rank barriers in both kernels use group 0's semaphores.
    auto semaphore = groups_.at(0).semaphore.handle();

    // NOTE(review): the `groups` argument is not used; both launches pass constant<1>{} where the
    // kernels take their token-group count — presumably deliberate, confirm.
    auto invoke = [&](auto t, auto groups) {
        using T                = decltype(t);
        constexpr int vec_size = sizeof(uint4) / sizeof(T);

        auto rs_symm_ptr = get_symmetric_v2((T*)hidden, group0);
        auto ag_symm_ptr = get_symmetric_v2((T*)hidden, group1);

        // Multicast path when both groups expose a multicast mapping of `hidden`.
        if (rs_symm_ptr.mc && ag_symm_ptr.mc) {
            const int max_ctas = max_ctas_.apply(40);
            AllreduceResidualBiasRMSnormV_NVLS<<>>(rs_symm_ptr.mc,
                                                                              ag_symm_ptr.mc,
                                                                              (T*)residual,
                                                                              (const T*)bias,
                                                                              (const T*)weights,
                                                                              semaphore,
                                                                              g_rank,
                                                                              n_ranks(0),
                                                                              first,
                                                                              last,
                                                                              offset,
                                                                              dim / vec_size,
                                                                              1.f / dim,
                                                                              eps,
                                                                              constant{},
                                                                              constant<1024>{},
                                                                              constant<1>{},
                                                                              std::true_type{});
        }
        else {
            // Absolute row range owned by each all-gather peer; only read by the kernel's disabled
            // pull-based all-gather branch.
            Array ag_ranges{};
            for (int i = 0; i < tp1; ++i) {
                const auto r = g1.l2g[i];
                ag_ranges[i] = {offsets[r] + firsts[r], offsets[r] + lasts[r]};
            }
            const int max_ctas = max_ctas_.apply(48);
            AllreduceResidualBiasRMSnormV_Simple_Pull<<>>((T*)hidden,
                                                                                     (T*)residual,
                                                                                     (const T*)bias,
                                                                                     (const T*)weights,
                                                                                     rs_symm_ptr.uc,
                                                                                     ag_symm_ptr.uc,
                                                                                     semaphore,
                                                                                     rank(group0),
                                                                                     rank(group1),
                                                                                     tp0,
                                                                                     tp1,
                                                                                     rank(0),
                                                                                     n_ranks(0),
                                                                                     offset,
                                                                                     first,
                                                                                     last,
                                                                                     ag_ranges,
                                                                                     dim / vec_size,
                                                                                     1.f / dim,
                                                                                     eps,
                                                                                     constant{},
                                                                                     constant<1024>{},
                                                                                     constant<1>{},
                                                                                     std::true_type{});
        }
        return true;
    };

    // NOTE(review): this check runs before the kernel launch, so it can only report errors that
    // were already pending, not errors from the launch below — confirm the placement is intended.
    sync_check_cuda_error();

    // Select a per-block token-group count from the row length in 16-byte vectors.
    auto dispatch_D = [&](auto t) {
        using T                = decltype(t);
        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        if (dim % vec_size) {
            return false;  // non-aligned
        }
        const int vdim = dim / vec_size;
        if (0) {}
        else if (vdim <= 256) {
            return invoke(t, constant<4>{});
        }
        else if (vdim <= 512) {
            return invoke(t, constant<2>{});
        }
        else if (vdim <= 1024) {
            return invoke(t, constant<1>{});
        }
        return false;  // > 1024 vdim
    };

    auto dispatch = [&]() -> bool {  //
        TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D);
    };

    TM_CHECK(dispatch());
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/group_sum.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"

namespace turbomind::comm {

namespace detail {

// Sum `val` over all threads of the caller's thread group and return the total to every thread
// of that group.
//
// A group is `warps` consecutive warps of the (1-D) block; a thread belongs to group
// `(threadIdx.x / WARP_SIZE) / warps`. Requirements:
//   - all 32 lanes of every warp of the group call this function (full-mask shuffle);
//   - `syncgroup()` is a barrier over exactly the threads of the caller's group;
//   - the block has at most 32 warps (one `smem` slot per warp).
//
// `smem` is a single shared array reused on every call. A caller that invokes this repeatedly
// (e.g. once per row) must synchronize the group again after the call returns and before the
// next call, so that no warp overwrites its slot while another warp of the group still reads it.
//
// All threads add the group's warp totals in the same order (first warp of the group first), so
// every thread of a group obtains a bit-identical result. An earlier version only produced the
// correct total in the group's first warp: other warps counted their own total twice and missed
// the first warp's.
template<class Syncgroup>
__device__ float GroupSum(const float val, int warps, Syncgroup syncgroup)
{
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    float     sum     = val;
    // Butterfly reduction: afterwards every lane holds the same warp total.
    PRAGMA_UNROLL
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync((uint32_t)-1, sum, mask);
    }
    __shared__ float smem[32];
    if (lane_id == 0) {
        smem[warp_id] = sum;
    }
    syncgroup();
    // Combine the warp totals of this group in a fixed order.
    const int group_first_warp = warp_id / warps * warps;
    sum                        = smem[group_first_warp];
    for (int i = 1; i < warps; ++i) {
        sum += smem[group_first_warp + i];
    }
    return sum;
}

}  // namespace detail

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/mscclpp.h
================================================
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#pragma once

#include 
#include 

namespace mscclpp {

// Copied from
// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/include/mscclpp/packet_device.hpp#L19
//
// 16-byte packet carrying 8 bytes of payload. Each 4-byte data word is paired with a 4-byte flag,
// laid out as {data1, flag1, data2, flag2}. A reader polls until both flag slots equal the
// expected flag, which tells it that both halves of the payload have arrived.
union alignas(16) LL16Packet {
    // Assume data is written with an atomicity of 8 bytes (IB/RDMA).
    struct {
        uint32_t data1;
        uint32_t flag1;
        uint32_t data2;
        uint32_t flag2;
    };
    using Payload = uint2;

    // The same 16 bytes viewed as one vector, used for whole-packet loads/stores.
    ulonglong2 raw_;

    __device__ LL16Packet() {}

    /// Construct a packet with `val` as payload and `flag` in both flag slots.
    /// Uses plain member assignments, not the single volatile vector store of write().
    __device__ LL16Packet(uint2 val, uint32_t flag)
    {
        data1 = val.x;
        flag1 = flag;
        data2 = val.y;
        flag2 = flag;
    }

    /// Write 8 bytes of data to the packet with one 16-byte volatile store
    /// ({val1, flag, val2, flag}).
    /// @param val1 The first 4-byte data to write.
    /// @param val2 The second 4-byte data to write.
    /// @param flag The flag to write.
    __device__ void write(uint32_t val1, uint32_t val2, uint32_t flag)
    {
        asm volatile(
            "st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(&raw_), "r"(val1), "r"(flag), "r"(val2), "r"(flag));
    }

    /// Write 8 bytes of data to the packet (low 32 bits first).
    /// @param val The 8-byte data to write.
    /// @param flag The flag to write.
    __device__ void write(uint64_t val, uint32_t flag)
    {
        write((uint32_t)val, (uint32_t)(val >> 32), flag);
    }

    /// Write 8 bytes of data to the packet.
    /// @param val The 8-byte data to write.
    /// @param flag The flag to write.
    __device__ void write(uint2 val, uint32_t flag)
    {
        write(val.x, val.y, flag);
    }

    /// Helper of @ref read(): one 16-byte volatile load of the packet.
    /// @param flag The flag to read.
    /// @param data The 8-byte data read.
    /// @return True if either flag slot differs from `flag`, i.e. the packet is not ready yet.
    __device__ bool readOnce(uint32_t flag, uint2& data) const
    {
        uint32_t flag1, flag2;
        asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2)
                     : "l"(&raw_));
        return (flag1 != flag) || (flag2 != flag);
    }

    /// Read 8 bytes of data from the packet, spinning (without a timeout) until both flag slots
    /// equal `flag`.
    /// @param flag The flag to read.
    /// @return The 8-byte data read.
    __device__ uint2 read(uint32_t flag) const
    {
        uint2 data;
        while (readOnce(flag, data)) {}
        return data;
    }

    /// Clear the packet (zero all 16 bytes with an ordinary, non-volatile store).
    __device__ void clear()
    {
        raw_ = make_ulonglong2(0, 0);
    }
};

using LLPacket = LL16Packet;

}  // namespace mscclpp


================================================
FILE: src/turbomind/comm/cuda_ipc/multimem.cuh
================================================
#pragma once

#include "src/turbomind/kernels/core/array.h"
#include 

namespace turbomind {

// Wrappers around the PTX multimem instructions (sm_90+). `mc_ptr` must be a multicast address:
// a load-reduce returns the element-wise sum over the memories bound to it, and a store writes to
// all of them. The `.weak` qualifier carries no ordering; callers presumably order these accesses
// with their own barriers (the NVLS kernels use SystemSemaphore).

// Generic fallback. NOTE(review): returns a value-initialized (zero) array for element types
// without an overload below instead of failing to compile — make sure it is never reached.
template
inline __device__ Array multimem_ld_reduce_sum(const Array* mc_ptr)
{
    return {};
}

// half: 16 bytes = four f16x2 registers, reinterpreted through the union.
inline __device__ Array multimem_ld_reduce_sum(const Array* mc_ptr)
{
    union {
        Array     x;
        Array u;
    };
    // LDGMC.E.ADD.F16x8.RN.STRONG.SYS
    asm volatile("multimem.ld_reduce.weak.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
                 : "=r"(u[0]), "=r"(u[1]), "=r"(u[2]), "=r"(u[3])
                 : "l"(mc_ptr)
                 : "memory");
    return x;
}

// nv_bfloat16: 16 bytes = four bf16x2 registers, reinterpreted through the union.
inline __device__ Array multimem_ld_reduce_sum(const Array* mc_ptr)
{
    union {
        Array x;
        Array    u;
    };
    asm volatile("multimem.ld_reduce.weak.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
                 : "=r"(u[0]), "=r"(u[1]), "=r"(u[2]), "=r"(u[3])
                 : "l"(mc_ptr)
                 : "memory");
    return x;
}

// Generic fallback. NOTE(review): does nothing for element types without an overload below.
template
inline __device__ void multimem_st(T* mc_ptr, const Array& vec)
{
}

// half: store 16 bytes (four f16x2 registers) through the multicast address.
inline __device__ void multimem_st(half* mc_ptr, const Array& vec)
{
    union {
        Array     x;
        Array u;
    };
    x = vec;
    // STG.E.128
    asm volatile("multimem.st.weak.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr),
                 "r"(u[0]),
                 "r"(u[1]),
                 "r"(u[2]),
                 "r"(u[3]));
}

// nv_bfloat16: store 16 bytes (four bf16x2 registers) through the multicast address.
inline __device__ void multimem_st(nv_bfloat16* mc_ptr, const Array& vec)
{
    union {
        Array x;
        Array    u;
    };
    x = vec;
    asm volatile("multimem.st.weak.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr),
                 "r"(u[0]),
                 "r"(u[1]),
                 "r"(u[2]),
                 "r"(u[3]));
}

// Raw 16-byte store. Uses the .f16x2 type for arbitrary bits — presumably fine because a store
// applies no arithmetic; confirm against the PTX ISA.
inline __device__ void multimem_st(uint4* mc_ptr, const uint4& u)
{
    asm volatile(
        "multimem.st.weak.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr), "r"(u.x), "r"(u.y), "r"(u.z), "r"(u.w));
}

// NOTE(review): no-op placeholders — 8-byte and 4-byte multicast stores are silently dropped.
inline __device__ void multimem_st(uint2* mc_ptr, const uint2& u) {}

inline __device__ void multimem_st(uint* mc_ptr, const uint& u) {}

}  // namespace turbomind


================================================
FILE: src/turbomind/comm/cuda_ipc/semaphore.cuh
================================================
#pragma once

#include 

#include "src/turbomind/kernels/core/array.h"

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/comm/cuda_ipc/semaphore.h"

namespace turbomind::comm {

// Convert a generic address into a global-state-space address (PTX `cvta.to.global.u64`).
template
__device__ T* cvta_generic_to_global(T* p)
{
    uintptr_t ret;
    asm("cvta.to.global.u64 %0, %1;" : "=l"(ret) : "l"(p));
    return reinterpret_cast(ret);
}

// Device-side view of one channel's cross-rank counters, used as a barrier between ranks.
// Construct it with channel = block index; thread t < ranks of the block then handles peer t:
//   - Signal() atomically adds 1 to the counter this rank owns in peer t's memory (outbound_);
//   - Wait() raises the local expectation by one and spins until the counter that peer t
//     increments in this rank's memory (inbound_) reaches it;
//   - Update() writes the expectation back to the info table so the next launch continues the
//     count.
// Threads with thread_idx >= ranks do nothing in any of these calls.
struct SystemSemaphore {

    using T = uint64_t;

    T* outbound_;  // counter slot in the peer's buffer that this rank increments
    T* inbound_;   // counter slot in this rank's buffer that the peer increments
    T  expected_;  // number of signals from the peer waited for so far
    // T* mc_ptr_;

    bool uc_predicate_;  // true for threads that handle a peer (thread_idx < ranks)
    // bool mc_predicate_;

    __device__ SystemSemaphore(const SystemSemaphoreInfo* info, int ranks, int channel, int thread_idx)
    {
        uc_predicate_ = thread_idx < ranks;
        // mc_predicate_ = thread_idx == 0;

        if (uc_predicate_) {
            int index = channel * kMaxRanks + thread_idx;
            inbound_  = info->inbound[index];
            outbound_ = info->outbound[index];
            expected_ = info->expected[index];
            // mc_ptr_   = info->mc_ptr[channel];
        }
    }

    // Store the current expectation so the next construction from `info` resumes from it.
    __device__ void Update(SystemSemaphoreInfo* info, int ranks, int channel, int thread_idx)
    {
        if (uc_predicate_) {
            info->expected[channel * kMaxRanks + thread_idx] = expected_;
        }
    }

    // Increment the peer's counter with a system-scope atomic add: relaxed ordering when `relaxed`
    // is true, release ordering (prior writes ordered before the increment) otherwise.
    __device__ void Signal(bool relaxed)
    {
        if (uc_predicate_) {
            if (relaxed) {
                asm volatile("atom.relaxed.sys.global.add.u64 _, [%0], %1;" ::"l"(outbound_), "n"(1) : "memory");
            }
            else {
                asm volatile("atom.release.sys.global.add.u64 _, [%0], %1;" ::"l"(outbound_), "n"(1) : "memory");
            }
        }
    }

    // Wait for one more signal from the peer: spin (without a timeout) on system-scope loads,
    // relaxed when `relaxed` is true and acquire otherwise, until the counter reaches expected_.
    __device__ void Wait(bool relaxed)
    {
        if (uc_predicate_) {
            ++expected_;
            T x{};
            do {
                if (relaxed) {
                    asm volatile("ld.relaxed.sys.global.u64 %0,[%1];" : "=l"(x) : "l"(inbound_) : "memory");
                }
                else {
                    asm volatile("ld.acquire.sys.global.u64 %0,[%1];" : "=l"(x) : "l"(inbound_) : "memory");
                }
            } while (x < expected_);
        }
    }

    // Disabled multicast signal variant, kept for reference.
    //     __device__ void SignalMulticast(bool relaxed)
    //     {
    // #if TURBOMIND_ARCH_SM90
    //         if (mc_predicate_) {
    //             if (relaxed) {
    //                 asm volatile("multimem.red.relaxed.sys.global.add.u64 [%0], %1;" ::"l"(mc_ptr_), "n"(1) :
    //                 "memory");
    //             }
    //             else {
    //                 asm volatile("multimem.red.release.sys.global.add.u64 [%0], %1;" ::"l"(mc_ptr_), "n"(1) :
    //                 "memory");
    //             }
    //             asm volatile("fence.proxy.alias;" ::: "memory");
    //         }
    // #endif
    //     }
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/cuda_ipc/semaphore.h
================================================
#pragma once

#include 

#include "src/turbomind/comm/cuda_ipc/common.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Pointer/counter table for the system semaphores. Every array is indexed as
// [channel * kMaxRanks + r], where r is a peer rank. SystemSemaphoreStorage::Allocate builds it
// on the host and copies it into device memory.
struct SystemSemaphoreInfo {
    uint64_t* outbound[kMaxChannels * kMaxRanks];  // slot inside peer rank r's buffer that this rank writes
    uint64_t* inbound[kMaxChannels * kMaxRanks];   // slot inside this rank's buffer that peer rank r writes
    uint64_t  expected[kMaxChannels * kMaxRanks];  // saved wait targets (zero-initialized in Allocate)
    // uint64_t* mc_ptr[kMaxChannels];
};

// Host-side owner of the symmetric semaphore counter buffer and of the device-resident
// SystemSemaphoreInfo table that points into it.
struct SystemSemaphoreStorage {

    uint64_t*            data_{};  // uint64[kMaxChannels][kMaxRanks], symmetric; this rank's copy
    SystemSemaphoreInfo* info_{};  // device memory holding the table built in Allocate()

    // Allocate the symmetric counter buffer through `alloc_reg`, build the per-(channel, rank)
    // inbound/outbound pointer table, and upload it to device memory.
    // alloc_reg(bytes) must return a SymmetricPtr_V2 whose uc[r] is rank r's pointer to the
    // buffer (presumably a unicast mapping; confirm).
    // Synchronizes stream 0 before returning.
    // NOTE(review): the counters themselves are not zeroed here; presumably alloc_reg returns
    // zeroed memory -- confirm.
    template
    void Allocate(int ranks, int rank, AllocReg alloc_reg)
    {
        const size_t byte_size = sizeof(uint64_t) * kMaxChannels * kMaxRanks;

        SymmetricPtr_V2 v = alloc_reg(byte_size);

        data_ = v.uc[rank];

        SystemSemaphoreInfo info{};  // value-init: `expected` starts at zero

        for (int c = 0; c < kMaxChannels; ++c) {  // block idx
            for (int r = 0; r < ranks; ++r) {     // thread idx
                // Slot r in our buffer is written by rank r; slot `rank` in rank r's buffer is ours to write.
                info.inbound[c * kMaxRanks + r]  = v.uc[rank] + c * kMaxRanks + r;
                info.outbound[c * kMaxRanks + r] = v.uc[r] + c * kMaxRanks + rank;
                // info.mc_ptr[c]                   = v.mc + c * kMaxRanks + rank;
            }
        }

        check_cuda_error(cudaMallocAsync(&info_, sizeof(SystemSemaphoreInfo), 0));
        check_cuda_error(cudaMemcpyAsync(info_, &info, sizeof(SystemSemaphoreInfo), cudaMemcpyDefault, 0));

        // `info` lives on the host stack, so the async copy must finish before returning.
        check_cuda_error(cudaStreamSynchronize(0));
    }

    // Free the device table (asynchronously on stream 0) and pass the counter buffer to
    // `dereg_free` for deregistration/release. Both members are reset to null.
    template
    void Free(DeregFree dereg_free)
    {
        check_cuda_error(cudaFreeAsync(info_, 0));
        info_ = {};

        dereg_free(data_);
        data_ = {};
    }

    // Device pointer to the SystemSemaphoreInfo table (null before Allocate / after Free).
    SystemSemaphoreInfo* handle()
    {
        return info_;
    }
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/device_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {

// Out-of-line definition of the virtual destructor declared in device_comm.h.
DeviceCommImpl::~DeviceCommImpl() = default;

// Backend factories defined in other translation units. They are referenced below only when
// BUILD_MULTI_GPU (and, for NCCL, USE_NCCL) is enabled.
DeviceComm CreateNcclCommunicator(int n_ranks, int rank, HostComm h_comm);

DeviceComm CreateCudaIpcCommunicator(int n_ranks, int rank, HostComm h_comm);

// Create the device communicator for `backend`:
//   "nccl"                -> NCCL backend (needs BUILD_MULTI_GPU and USE_NCCL)
//   "native" / "cuda-ipc" -> CUDA IPC backend (needs BUILD_MULTI_GPU)
// Any other name, or a backend compiled out of this build, fails the TM_CHECK.
DeviceComm CreateDeviceCommunicator(const std::string& backend, int n_ranks, int rank, HostComm h_comm)
{
#if BUILD_MULTI_GPU
#if USE_NCCL
    if (backend == "nccl") {
        return CreateNcclCommunicator(n_ranks, rank, h_comm);
    }
#endif
    const bool is_cuda_ipc = backend == "native" || backend == "cuda-ipc";
    if (is_cuda_ipc) {
        return CreateCudaIpcCommunicator(n_ranks, rank, h_comm);
    }
#endif

    TM_CHECK(0) << "Unknown communication backend: " << backend;
    return {};
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/device_comm.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include 

#include 

#include "src/turbomind/comm/host_comm.h"

namespace turbomind::comm {

// Capability attributes that can be passed to DeviceCommImpl::Query().
enum QueryAttr
{
    kHasAllGather2D  // presumably: nonzero when the backend implements AllGather2D() -- confirm
};

// Abstract interface implemented by each device communication backend.
// `group` arguments select a communicator group. Group ids are presumably the values returned by
// Split(), with 0 as the initial group -- confirm against the backends.
// Optional operations have default bodies that throw std::runtime_error("not implemented").
class DeviceCommImpl {
public:
    virtual ~DeviceCommImpl();

    // Number of ranks in `group`.
    virtual int n_ranks(int group) const = 0;

    // The caller's rank within `group`.
    virtual int rank(int group) const = 0;

    // Allocate / free a buffer through the backend.
    virtual void* Allocate(size_t size) = 0;

    virtual void Free(void* ptr) = 0;

    // Register / deregister an existing buffer of `size` bytes with the backend.
    virtual void Register(void* ptr, size_t size) = 0;

    virtual void Deregister(void* ptr) = 0;

    // Split `group` into sub-groups by `color` (presumably MPI-style, with `key` ordering ranks)
    // and return the new group id. Optional.
    virtual int Split(int color, int key, int group)
    {
        throw std::runtime_error("not implemented");
    }

    // Query a backend capability; see QueryAttr.
    virtual int Query(QueryAttr attr) const noexcept = 0;

    // Element-wise sum of `count` elements of `type` across `group`, issued on `stream`.
    virtual void AllReduceSum(const void*  sendbuff,  //
                              void*        recvbuff,
                              size_t       count,
                              DataType     type,
                              int          group,
                              cudaStream_t stream) = 0;

    // Gather `sendcount` elements from every rank of `group` into `recvbuff`, issued on `stream`.
    virtual void AllGather(const void*  sendbuff,  //
                           void*        recvbuff,
                           size_t       sendcount,
                           DataType     type,
                           int          group,
                           cudaStream_t stream) = 0;

    // Reduce across `group` and leave `recvcount` elements on each rank. Optional.
    virtual void ReduceScatter(const void*  sendbuff,  //
                               void*        recvbuff,
                               size_t       recvcount,
                               DataType     type,
                               int          group,
                               cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // Fused all-reduce + residual/bias add + RMSNorm over `token_num` x `dim` values.
    // Exact in-place semantics are defined by the backend. Optional.
    virtual void AllreduceResidualBiasRMSnorm(void*        hidden,
                                              void*        residual,
                                              const void*  bias,
                                              const void*  weights,
                                              float        eps,
                                              int          dim,
                                              int          token_num,
                                              DataType     dtype,
                                              int          group,
                                              cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // Variant of the fused op above that involves two groups and per-rank token counts
    // (`local_token_nums`). Optional.
    virtual void AllreduceResidualBiasRMSnormEx(void*        hidden,
                                                void*        residual,
                                                const void*  bias,
                                                const void*  weights,
                                                float        eps,
                                                int          dim,
                                                DataType     type,
                                                int          group0,
                                                int          group1,
                                                const int*   local_token_nums,
                                                cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // All-gather of a strided 2D region (`width` x `height`, laid out by `pitch`/`stride`).
    // `flags` marks the first/last call of a sequence. Optional.
    virtual void AllGather2D(const void*  sendbuff,
                             void*        recvbuff,
                             size_t       pitch,
                             size_t       stride,
                             int          width,
                             int          height,
                             DataType     type,
                             int2         flags,  // (is_first, is_last)
                             int          group,
                             cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // Broadcast `count` elements from `root` to every rank of `group`. Optional.
    virtual void Broadcast(const void*  sendbuff,  //
                           void*        recvbuff,
                           size_t       count,
                           DataType     type,
                           int          root,
                           int          group,
                           cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }
};

// Owning handle around a DeviceCommImpl. It holds a unique_ptr, so it is move-only.
// It forwards operator-> to the implementation and converts implicitly to the raw pointer.
// A default-constructed handle is null.
class DeviceComm {
public:
    DeviceComm() = default;

    /* implicit */ DeviceComm(std::unique_ptr impl): impl_{std::move(impl)} {}

    DeviceCommImpl* operator->() const noexcept
    {
        return impl_.get();
    }

    operator DeviceCommImpl*() const noexcept
    {
        return impl_.get();
    }

private:
    std::unique_ptr impl_;
};

// Create a device communicator for `backend` ("nccl", "native" or "cuda-ipc", depending on the
// build flags) covering `n_ranks` ranks, with this process at `rank`.
DeviceComm CreateDeviceCommunicator(const std::string& backend,  //
                                    int                n_ranks,
                                    int                rank,
                                    HostComm           h_comm);

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/env.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 

#include "src/turbomind/utils/logger.h"

namespace turbomind {

// Return the value of the environment variable described by E, read once and cached in a
// function-local static (initialized on first call).
// E is a descriptor generated by TM_ENV_VAR. It supplies init() (the default value, which also
// fixes the result type T), full_name (the variable looked up), and prefix/name (used for logging).
// Integral T is parsed with std::stoll, floating-point T with std::stod, and std::string is copied.
// When the variable is present, its value is logged once at INFO level.
// NOTE(review): parse errors are swallowed by catch (...), which leaves the default in place.
// is_set is already true at that point, so the log reports the default as if it had been read
// from the environment. Consider warning on parse failure.
// NOTE(review): the std::stoll result is narrowed to T without a range check.
template
auto GetEnv()
{
    static auto value = [] {
        bool is_set{};
        auto x  = E::init();
        using T = decltype(x);
        try {
            if (auto p = std::getenv(E::full_name)) {
                is_set = true;
                if constexpr (std::is_integral_v) {
                    x = std::stoll(p);
                }
                else if constexpr (std::is_floating_point_v) {
                    x = std::stod(p);
                }
                else if constexpr (std::is_same_v) {
                    x = std::string{p};
                }
                else {
                    // dependent-false assertion: only fires when this branch is instantiated
                    static_assert(!std::is_same_v, "not implemented");
                }
            }
        }
        catch (...) {
            // best-effort: keep the default value on any parse error
        }
        if (is_set) {
            std::stringstream ss;
            ss << x;
            TM_LOG_INFO("[%s] %s=%s", E::prefix, E::name, ss.str().c_str());
        }
        return x;
    }();
    return value;
}

// Declare an environment-variable descriptor struct named prefix_##_##name_ for use with GetEnv.
// The variable that gets read is "TM_<prefix_>_<name_>". init_ is the default value, and its type
// is the type the variable is parsed into.
// Example (hypothetical names): TM_ENV_VAR(COMM, FOO, 0); then GetEnv<COMM_FOO>().
#define TM_ENV_VAR(prefix_, name_, init_)                                                                              \
    struct prefix_##_##name_ {                                                                                         \
        static auto init()                                                                                             \
        {                                                                                                              \
            return init_;                                                                                              \
        }                                                                                                              \
        static constexpr auto prefix    = #prefix_;                                                                    \
        static constexpr auto name      = #name_;                                                                      \
        static constexpr auto full_name = "TM_" #prefix_ "_" #name_;                                                   \
    }

}  // namespace turbomind


================================================
FILE: src/turbomind/comm/gloo/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.8)

# Fetch gloo, pinned to the commit used by PyTorch v2.8.0-rc4.
include(FetchContent)
FetchContent_Declare(
  gloo
  GIT_REPOSITORY https://github.com/pytorch/gloo.git
  GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4
)

# Configure gloo before it is added: a static, non-installed library without NCCL, tests
# or ibverbs.
# NOTE(review): USE_NCCL / BUILD_TEST / USE_IBVERBS are plain, directory-scoped variables.
# Whether gloo's option() calls honor them depends on policy CMP0077 -- confirm.
set(GLOO_INSTALL OFF CACHE BOOL "" FORCE)
set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE)
set(USE_NCCL OFF)
set(BUILD_TEST OFF)
set(USE_IBVERBS OFF)
FetchContent_MakeAvailable(gloo)

# gloo's build does not attach its include directories to the target, so add them here...
target_include_directories(gloo PUBLIC
    $
    $ # config.h generated at cmake config time
)

# Silence warnings in the third-party gloo sources (/W0 for MSVC, -w elsewhere).
target_compile_options(gloo PRIVATE
    $<$:/W0>
    $<$,$>:-w>
)

# Host communicators built on gloo: the gloo comm, the hybrid (gloo + thread) comm, and the TCP store client.
add_library(gloo_comm STATIC
    gloo_comm.cc
    hybrid_comm.cc
    tcp_store.cc
)
set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(gloo_comm PUBLIC gloo host_comm logger xgrammar)

# Standalone test executable for the host communicators.
add_executable(test_ipc_comm test_ipc_comm.cc)
target_link_libraries(test_ipc_comm PRIVATE gloo_comm Threads::Threads)


================================================
FILE: src/turbomind/comm/gloo/gloo_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#if GLOO_HAVE_TRANSPORT_IBVERBS
#include "gloo/transport/ibverbs/device.h"
#endif

#include "src/turbomind/comm/gloo/tcp_store.h"
#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind::comm {

// Environment variable naming the network interface used by gloo's TCP transport.
const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
// Field separator in the serialized store descriptor "host,port,prefix".
const char  STORE_INFO_DELIM       = ',';

std::shared_ptr<::gloo::transport::Device> createGlooDevice()
{
#if GLOO_HAVE_TRANSPORT_IBVERBS
    if (auto transport = std::getenv("GLOO_DEVICE_TRANSPORT");
        transport != nullptr && strcmp(transport, "ibverbs") == 0) {
        ::gloo::transport::ibverbs::attr ib_attr{};
        ib_attr.name  = "";
        ib_attr.port  = 1;
        ib_attr.index = 3;  // use IBV_GID_TYPE_ROCE_V2 and ipv4
        return ::gloo::transport::ibverbs::CreateDevice(ib_attr);
    }
#endif
    ::gloo::transport::tcp::attr attr;
    if (auto ifname = std::getenv(GLOO_SOCKET_IFNAME_ENV); ifname) {
        attr.iface = ifname;
    }
    else {
        attr.hostname = ::gloo::getHostname();
    }
    return ::gloo::transport::tcp::CreateDevice(attr);
}

// gloo PrefixStore whose backing store connects to host:port (the TCPStore from tcp_store.h,
// presumably). Keys are namespaced by `prefix`. host/port are kept so New() can open another
// store on the same server under a derived prefix.
class Store: public ::gloo::rendezvous::PrefixStore {
public:
    // The base class is constructed first regardless of the order written below. It receives a
    // null store, which is replaced by the assignment to store_ in the body.
    explicit Store(const std::string& host, int port, const std::string& prefix):
        host_(host), port_(port), ::gloo::rendezvous::PrefixStore(prefix, nullptr)
    {
        store_ = std::make_shared(host_, port_);
    };

    ~Store() = default;

    // Open a new Store on the same server, with prefix "<prefix>/<this store's prefix>".
    // NOTE(review): the child component comes before the parent's prefix; confirm this order is intended.
    std::shared_ptr New(const std::string& prefix)
    {
        std::string new_prefix = prefix + "/" + prefix_;
        return std::make_shared(host_, port_, new_prefix);
    }

public:
    std::string host_;
    int         port_;

    // Re-expose the base class's protected members.
    using ::gloo::rendezvous::PrefixStore::store_;
    using ::gloo::rendezvous::PrefixStore::prefix_;
};

// Process-wide singleton that creates and parses store descriptors of the form
// "host,port,prefix". host/port come from LMDEPLOY_DIST_INIT_ADDR / LMDEPLOY_DIST_INIT_PORT.
// prefix is a counter local to this process.
// NOTE(review): prefixes are unique only among descriptors minted by this process; presumably a
// single process mints ids and exports them to the others -- confirm.
class GlobalStoreFactory {
public:
    static GlobalStoreFactory& Instance()
    {
        static GlobalStoreFactory instance;
        return instance;
    }

    // Mint a new descriptor. Fails a TM_CHECK if either environment variable is unset;
    // std::stoi throws if the port is not numeric.
    std::string New()
    {
        std::lock_guard lock(mutex_);
        TM_CHECK(std::getenv("LMDEPLOY_DIST_INIT_ADDR") != nullptr) << "LMDEPLOY_DIST_INIT_ADDR not set";
        TM_CHECK(std::getenv("LMDEPLOY_DIST_INIT_PORT") != nullptr) << "LMDEPLOY_DIST_INIT_PORT not set";

        std::string host = std::getenv("LMDEPLOY_DIST_INIT_ADDR");
        int         port = std::stoi(std::getenv("LMDEPLOY_DIST_INIT_PORT"));

        std::stringstream ss;
        ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++;
        return ss.str();
    }

    // Parse a descriptor produced by New() and open the matching Store.
    // Requires exactly three comma-separated fields.
    std::shared_ptr Load(const std::string& info)
    {
        std::stringstream        ss(info);
        std::vector keys;
        std::string              local;
        while (getline(ss, local, STORE_INFO_DELIM)) {
            keys.push_back(std::move(local));  // getline clears `local` before refilling it
        }
        TM_CHECK(keys.size() == 3);

        std::string host   = keys[0];
        int         port   = stoi(keys[1]);
        std::string prefix = keys[2];

        return std::make_shared(host, port, prefix);
    }

private:
    GlobalStoreFactory() {}

    std::mutex mutex_;      // guards prefix_
    int        prefix_{0};  // next prefix handed out by New()
};

// Signature of gloo's element-wise reduction callbacks: (output, input_a, input_b, element_count).
typedef void (*ReduceFunc)(void*, const void*, const void*, size_t);

// Host communicator built on gloo collectives, with peers connected through a rendezvous Store.
// Every rank is treated as a separate process (is_same_process() always returns false).
struct GlooCommImpl: public HostCommImpl {

    // (color, rank) pair exchanged during Split(); ordered by color, then by rank.
    struct SplitInfo {
        int color;
        int rank;

        bool operator<(const SplitInfo& other) const
        {
            return (color < other.color) || (color == other.color && rank < other.rank);
        }

        bool operator==(const SplitInfo& other) const
        {
            return (color == other.color) && (rank == other.rank);
        }
    };

    // Create the transport device and connect a full mesh among `n_ranks` peers, rendezvousing
    // through `store`.
    GlooCommImpl(std::shared_ptr store, int n_ranks, int rank):
        store_{std::move(store)}, rank_{rank}, n_ranks_{n_ranks}
    {
        device_  = createGlooDevice();
        context_ = std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_);
        context_->setTimeout(kTimeOut);
        context_->connectFullMesh(store_, device_);
    }

    ~GlooCommImpl() {}

    int rank() const override
    {
        return rank_;
    }

    int n_ranks() const override
    {
        return n_ranks_;
    }

    bool is_same_process() const override
    {
        return false;
    }

    // Split into sub-groups by `color`. Ranks sharing a color form a new communicator, ordered by
    // their rank in this group, on a fresh store prefix "<color>:<split index>".
    // NOTE(review): `key` is ignored, unlike MPI/NCCL split semantics where it orders the new ranks.
    std::shared_ptr Split(int color, int key) override
    {
        auto vec  = comm::AllGather(this, SplitInfo{color, rank_});
        auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) {  //
            return x.color == color;
        });
        vec.erase(last, vec.end());
        std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) {  //
            return a < b;
        });

        auto new_prefix  = std::to_string(color) + ":" + std::to_string(n_split_++);
        auto new_store   = store_->New(new_prefix);
        int  new_n_ranks = vec.size();
        int  new_rank    = std::find(vec.begin(), vec.end(), SplitInfo{color, rank_}) - vec.begin();
        return std::make_shared(new_store, new_n_ranks, new_rank);
    }

    // Barrier across the whole group; `blocking` is ignored.
    void Sync(bool blocking) override
    {
        ::gloo::BarrierOptions opts(context_);
        ::gloo::barrier(opts);
    }

    // Broadcast `count` objects from `root`. Without ser/des, the objects are sent as raw bytes.
    // With ser/des, root serializes; the byte count is broadcast first, then the bytes, and the
    // non-root ranks deserialize.
    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override
    {
        // trivially copyable if no ser/des function
        if (!ser || !des) {
            return Broadcast(data, count, dtype, root);
        }

        // broadcast buffer size
        size_t size;
        if (root == rank()) {
            ser(data, 0, count, size, nullptr);
        }
        Broadcast(&size, 1, data_type_v, root);

        // serialize data on root rank
        std::vector bytes;
        // resize, not reserve: ser() and the gloo broadcast write `size` bytes through
        // bytes.data(), and des() reads them back, so those elements must really exist.
        bytes.resize(size);
        if (root == rank()) {
            ser(data, 0, count, size, bytes.data());
        }

        // broadcast serialized data
        Broadcast(bytes.data(), size, data_type_v, root);

        // deserialize data on all ranks
        if (root != rank()) {
            des(data, 0, count, bytes.data(), size);
        }
    }

    // All-gather `count` objects per rank into `data`, where rank i owns [i * count, (i + 1) * count).
    // With ser/des, each rank serializes its slice into a slot sized to the largest slice, the
    // slots are gathered, and every other rank's slot is deserialized.
    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override
    {
        // trivially copyable if no ser/des function
        if (!ser || !des) {
            return AllGather(data, count, dtype);
        }

        // get buffer size on each rank and find max size
        size_t size;
        ser(data, count * rank(), count, size, nullptr);
        std::vector sizes(n_ranks());
        sizes[rank()] = size;
        AllGather(sizes.data(), 1, data_type_v);
        auto max_size = *std::max_element(sizes.begin(), sizes.end());

        // serialize data on each rank
        std::vector bytes(max_size * n_ranks());
        ser(data, count * rank(), count, size, bytes.data() + rank() * max_size);

        // gather serialized data
        AllGather(bytes.data(), max_size, data_type_v);

        // deserialize data on each rank
        for (int i = 0; i < n_ranks(); ++i) {
            if (i != rank()) {
                des(data, i * count, count, bytes.data() + i * max_size, sizes[i]);
            }
        }
    }

    // In-place raw-byte broadcast of `count` elements of `dtype` from `root`.
    void Broadcast(void* data, int count, DataType dtype, int root)
    {
        ::gloo::BroadcastOptions opts(context_);
        opts.setRoot(root);
        opts.setOutput((char*)data, count * byte_size(dtype));
        ::gloo::broadcast(opts);
    }

    // In-place raw-byte all-gather. `data` holds n_ranks_ * count elements of `dtype`.
    void AllGather(void* data, int count, DataType dtype)
    {
        ::gloo::AllgatherOptions opts(context_);
        opts.setOutput((char*)data, count * byte_size(dtype) * n_ranks_);
        ::gloo::allgather(opts);
    }

    // Map (dtype, red_op) to a gloo reduction function. Only 32/64-bit signed and unsigned
    // integers with sum/max/min are supported; anything else throws std::runtime_error.
    static ReduceFunc getReduceFunc(DataType dtype, RedOp red_op)
    {

        auto dispatch_op = [&](auto t) -> ReduceFunc {
            using T = decltype(t);
            switch (red_op) {
                case RedOp::kSum:
                    return ::gloo::sum;
                case RedOp::kMax:
                    return ::gloo::max;
                case RedOp::kMin:
                    return ::gloo::min;
                default:
                    return {};
            }
        };

        auto dispatch = [&]() -> ReduceFunc {
            switch (dtype) {
                case kInt32:
                    return dispatch_op(int32_t{});
                case kInt64:
                    return dispatch_op(int64_t{});
                case kUint32:
                    return dispatch_op(uint32_t{});
                case kUint64:
                    return dispatch_op(uint64_t{});
                default:
                    return {};
            }
        };

        if (auto fn = dispatch()) {
            return fn;
        }
        else {
            throw std::runtime_error("not implemented");
            return {};
        }
    }

    // In-place all-reduce of `count` integers of `dtype`. Unsupported types throw std::runtime_error.
    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override
    {
        ::gloo::AllreduceOptions opts(context_);
        opts.setReduceFunction(getReduceFunc(dtype, red_op));
        switch (dtype) {
            case kInt32:
                opts.setOutput((int32_t*)data, count);
                break;
            case kInt64:
                opts.setOutput((int64_t*)data, count);
                break;
            case kUint32:
                opts.setOutput((uint32_t*)data, count);
                break;
            case kUint64:
                opts.setOutput((uint64_t*)data, count);
                break;
            default:
                throw std::runtime_error("not implemented");
        }
        ::gloo::allreduce(opts);
    }

    // there might be very long intervals between receiving requests.
    static constexpr std::chrono::milliseconds kTimeOut = std::chrono::milliseconds(1000LL * 3600 * 24 * 365);

    int                                          n_split_{};  // number of Split() calls, used in child prefixes
    std::shared_ptr<::gloo::transport::Device>   device_;
    std::shared_ptr<::gloo::rendezvous::Context> context_;
    std::shared_ptr                       store_;
    int                                          rank_;
    int                                          n_ranks_;
};

// HostGroupId for GlooCommImpl. The id is a "host,port,prefix" store descriptor created by
// GlobalStoreFactory and passed to the other ranks through Export/Import.
// NOTE(review): there is no access specifier, so the overrides below are private (they are still
// callable through HostGroupId*). The `store_` member is never used.
class GlooGroupId: public HostGroupId {

    // Create a fresh descriptor (reads LMDEPLOY_DIST_INIT_ADDR / _PORT).
    void Initialize() override
    {
        info_ = GlobalStoreFactory::Instance().New();
        TM_LOG_INFO("[TM][COMM] GlooGroupId=%s", info_.c_str());
    }

    void Export(std::ostream& os) override
    {
        os << info_;
    }

    // Reads everything remaining in `is` as the descriptor.
    void Import(std::istream& is) override
    {
        std::stringstream ss;
        ss << is.rdbuf();
        info_ = ss.str();
    }

    // Connect a GlooCommImpl for `rank` of `n_ranks`. Requires Initialize() or Import() first;
    // `node_rank` is unused.
    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override
    {
        TM_CHECK(info_ != "");
        auto impl = std::make_shared(GlobalStoreFactory::Instance().Load(info_), n_ranks, rank);
        return std::static_pointer_cast(impl);
    }

private:
    std::string                                info_;  // ip,port,prefix
    std::shared_ptr<::gloo::rendezvous::Store> store_;
};

// Factory for gloo-backed host group ids (see GlooGroupId).
std::unique_ptr CreateGlooGroupId()
{
    return std::make_unique();
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/gloo/hybrid_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/check.h"

namespace turbomind::comm {

// Group-id factories defined in other translation units: a thread-based one for intra-node
// communication, and the gloo-based one from gloo_comm.cc.
extern std::unique_ptr CreateThreadGroupId();
extern std::unique_ptr CreateGlooGroupId();

// Host communicator that combines a gloo communicator spanning every rank of the group with an
// intra-node communicator created from a thread group id.
// - All ranks on one node: collectives use intra_comm_ only.
// - Otherwise: init_inter_comm() splits an inter-node communicator off gloo_comm_ (one per
//   intra-node index), and collectives run across nodes and then within each node.
struct HybridCommImpl: public HostCommImpl {

    // Build from group ids. Every rank's node rank is gathered over gloo to choose the layout.
    HybridCommImpl(int n_ranks, int rank, int node_rank, HostGroupId* gloo_group_id, HostGroupId* thread_group_id):
        n_ranks_{n_ranks},  //
        rank_{rank},
        node_rank_(node_rank)
    {
        gloo_comm_     = gloo_group_id->CreateCommunicator(n_ranks, rank);
        rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank);
        // NOTE(review): only the first and last node ranks are compared, which assumes ranks are
        // laid out contiguously by node.
        same_process_  = rank_to_nodes_.front() == rank_to_nodes_.back();
        if (same_process_) {
            intra_comm_ = thread_group_id->CreateCommunicator(n_ranks, rank);
        }
        else {
            init_inter_comm();
            intra_comm_ = thread_group_id->CreateCommunicator(intra_n_ranks_, rank_to_intra_[rank_]);
        }
    }

    // Wrap already-split gloo and intra-node communicators (used by Split()).
    HybridCommImpl(std::shared_ptr gloo_comm, std::shared_ptr intra_comm, int node_rank):
        gloo_comm_{std::move(gloo_comm)},
        intra_comm_{std::move(intra_comm)},
        rank_{gloo_comm_->rank()},
        n_ranks_{gloo_comm_->n_ranks()},
        node_rank_(node_rank)
    {
        rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank);
        same_process_  = rank_to_nodes_.front() == rank_to_nodes_.back();
        if (same_process_) {}
        else {
            init_inter_comm();
        }
    }

    // Multi-node setup:
    // - compute this rank's index within its node;
    // - check that every node has the same number of ranks;
    // - gather the rank -> intra-index map;
    // - split the inter-node communicator with color = intra-index;
    // - gather the rank -> inter-rank map.
    // Populates intra_n_ranks_, rank_to_intra_, inter_comm_ and rank_to_inter_.
    void init_inter_comm()
    {
        int intra_n_ranks = 0;
        int intra_rank    = -1;
        for (int r = 0; r < n_ranks_; ++r) {
            if (rank_to_nodes_[r] == node_rank_) {
                if (r == rank_) {
                    intra_rank = intra_n_ranks;
                }
                intra_n_ranks++;
            }
        }

        intra_n_ranks_ = intra_n_ranks;
        gloo_comm_->AllReduce(&intra_n_ranks_, 1, DataType::kInt, RedOp::kMin);
        TM_CHECK_EQ(intra_n_ranks_, intra_n_ranks) << "The number of ranks in each node should be same.";
        TM_CHECK_GT(intra_rank, -1) << "Invalid intra_rank.";
        rank_to_intra_ = ::turbomind::comm::AllGather(gloo_comm_, intra_rank);

        inter_comm_    = gloo_comm_->Split(rank_to_intra_[rank_], 0);
        rank_to_inter_ = ::turbomind::comm::AllGather(gloo_comm_, inter_comm_->rank());
    }

    // Split by `color`. A multi-node group splits both underlying communicators and wraps them;
    // a single-node group returns the intra-node split directly.
    std::shared_ptr Split(int color, int key) override
    {
        if (!is_same_process()) {
            auto new_gloo_comm  = gloo_comm_->Split(color, key);
            auto new_intra_comm = intra_comm_->Split(color, key);
            return std::make_shared(new_gloo_comm, new_intra_comm, node_rank_);
        }
        else {
            return intra_comm_->Split(color, key);
        }
    }

    int rank() const override
    {
        return rank_;
    }

    int n_ranks() const override
    {
        return n_ranks_;
    }

    bool is_same_process() const override
    {
        return same_process_;
    }

    // Multi-node: the intra-index-0 ranks first synchronize across nodes. Every rank then
    // synchronizes within its node.
    void Sync(bool blocking) override
    {
        if (!is_same_process() && rank_to_intra_[rank_] == 0) {
            inter_comm_->Sync(blocking);
        }
        intra_comm_->Sync(blocking);
    }

    // Broadcast `count` objects from `root`. Serialized objects travel across nodes with ser/des,
    // then are copied within each node.
    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override
    {
        if (!ser || !des) {
            return Broadcast(data, count, dtype, root, copy);
        }

        if (is_same_process()) {
            // rank_to_intra_, rank_to_inter_ and inter_comm_ are only populated by init_inter_comm()
            // (multi-node case). Indexing them here would read an empty vector and a null communicator.
            return intra_comm_->Broadcast(data, count, dtype, root, copy, ser, des);
        }

        if (rank_to_intra_[root] == rank_to_intra_[rank_]) {  // same ith rank in node
            inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy, ser, des);
        }
        intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy);
    }

    // Broadcast without serialization: across nodes (among ranks sharing root's intra index),
    // then within each node.
    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy)
    {
        if (is_same_process()) {
            return intra_comm_->Broadcast(data, count, dtype, root, copy);
        }

        if (rank_to_intra_[root] == rank_to_intra_[rank_]) {  // same ith rank in node
            inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy);
        }
        intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy);
    }

    // Serialized all-gather always goes through gloo_comm_.
    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override
    {
        if (!ser || !des) {
            return AllGather(data, count, dtype, copy);
        }

        return gloo_comm_->AllGather(data, count, dtype, copy, ser, des);
    }

    // Non-serialized all-gather: intra-node when single-node, gloo otherwise.
    void AllGather(void* data, int count, DataType dtype, copy_fn copy)
    {
        if (is_same_process()) {
            return intra_comm_->AllGather(data, count, dtype, copy);
        }

        // TODO: support allgatherv in gloo comm (each node may has different rank size)
        return gloo_comm_->AllGather(data, count, dtype, copy);
    }

    // Multi-node all-reduce runs in three steps:
    // 1. reduce within each node;
    // 2. reduce across nodes among the intra-index-0 ranks;
    // 3. broadcast the result (as raw bytes) from intra rank 0 within each node.
    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override
    {
        if (is_same_process()) {
            return intra_comm_->AllReduce(data, count, dtype, red_op);
        }

        intra_comm_->AllReduce(data, count, dtype, red_op);
        if (rank_to_intra_[rank_] == 0) {
            inter_comm_->AllReduce(data, count, dtype, red_op);
        }
        intra_comm_->Broadcast(data, byte_size(dtype) * count, data_type_v, 0, detail::copy_fn);
    }

    HostComm gloo_comm_{};   // primitive comm, used for initializing inter_comm and intra_comm
    HostComm inter_comm_{};  // inter-node comm
    HostComm intra_comm_{};  // intra-node comm

    int rank_;       // group rank
    int n_ranks_;    // group size
    int node_rank_;  // node rank
    int intra_n_ranks_;  // ranks per node (set only in the multi-node case)

    std::vector rank_to_nodes_{};  // map group rank to node rank (not global)
    std::vector rank_to_intra_{};  // map group rank to intra-node rank
    std::vector rank_to_inter_{};  // map group rank to inter-node rank

    bool same_process_;
};

// HostGroupId for HybridCommImpl. It bundles a thread group id (intra-node) and a gloo group id.
// Export/Import write and then read the thread id first and the gloo id second, in one stream.
// NOTE(review): GlooGroupId::Import consumes the rest of the stream, so the gloo id must stay last.
// CreateCommunicator below has no `override`, and its node_rank parameter has no default.
class HybridGroupId: public HostGroupId {
public:
    HybridGroupId()
    {
        thread_group_id_ = CreateThreadGroupId();
        gloo_group_id_   = CreateGlooGroupId();
    }

    void Initialize() override
    {
        thread_group_id_->Initialize();
        gloo_group_id_->Initialize();
    }

    void Export(std::ostream& os) override
    {
        thread_group_id_->Export(os);
        gloo_group_id_->Export(os);
    }

    void Import(std::istream& is) override
    {
        thread_group_id_->Import(is);
        gloo_group_id_->Import(is);
    }

    // Create a HybridCommImpl for `rank` of `n_ranks`, running on node `node_rank`.
    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank)
    {
        auto impl = std::make_shared(n_ranks,  //
                                                     rank,
                                                     node_rank,
                                                     gloo_group_id_.get(),
                                                     thread_group_id_.get());
        return std::static_pointer_cast(impl);
    }

    std::unique_ptr thread_group_id_;
    std::unique_ptr gloo_group_id_;
};

// Factory for hybrid (gloo + thread) host group ids (see HybridGroupId).
std::unique_ptr CreateHybridGroupId()
{
    return std::make_unique();
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/gloo/tcp_store.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 

#include 
#include 

#include "src/turbomind/comm/gloo/tcp_store.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind::comm {

namespace {

// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.8.0-rc4/torch/csrc/distributed/c10d/TCPStoreBackend.hpp
// These definitions are part of the c10d TCPStore wire protocol. Each QueryType is sent as its
// one-byte underlying value, so the enumerator order must stay identical to PyTorch's.

// Magic number sent right after QueryType::VALIDATE when a connection is opened.
static const uint32_t validationMagicNumber = 0x3C85F7CE;

enum class CheckResponseType : uint8_t
{
    READY,
    NOT_READY
};

enum class QueryType : uint8_t
{
    VALIDATE,
    SET,
    COMPARE_SET,
    GET,
    ADD,
    CHECK,
    WAIT,
    GETNUMKEYS,
    DELETE_KEY,
    APPEND,
    MULTI_GET,
    MULTI_SET,
    CANCEL_WAIT,
    PING,
    QUEUE_PUSH,
    QUEUE_POP,
    QUEUE_LEN,
};

}  // namespace

// Growable byte buffer used to assemble TCPStore request messages before a single socket write.
struct Buffer {
    std::vector buffer;

    // Append the raw in-memory bytes of `val` (sizeof(T) bytes, host byte order).
    template>>
    void append(T val)
    {
        char* ptr = (char*)&val;
        buffer.insert(buffer.end(), ptr, ptr + sizeof(T));
    }

    // Append a uint64 length prefix followed by the vector's bytes.
    void append(const std::vector& vec)
    {
        append((uint64_t)vec.size());
        buffer.insert(buffer.end(), vec.begin(), vec.end());
    }

    // Append a uint64 length prefix followed by the string's bytes.
    void append(const std::string& str)
    {
        append((uint64_t)str.size());
        buffer.insert(buffer.end(), str.begin(), str.end());
    }

    const char* data() const
    {
        return buffer.data();
    }

    // Number of bytes accumulated so far.
    size_t count() const
    {
        return buffer.size();
    }
};

// Send the VALIDATE handshake: the query type followed by the protocol magic number.
// No reply is read.
void validate(std::shared_ptr<::gloo::transport::tcp::Socket>& socket)
{
    Buffer request;
    request.append(QueryType::VALIDATE);
    request.append(validationMagicNumber);
    socket->write(request.data(), request.count());
}

// Round-trip check: send a PING with this process id as nonce and require the server to echo it back.
// Throws std::runtime_error when the echoed value differs (including when the read fails).
void ping(std::shared_ptr<::gloo::transport::tcp::Socket>& socket)
{
    const uint32_t nonce = getpid();

    Buffer request;
    request.append(QueryType::PING);
    request.append(nonce);
    socket->write(request.data(), request.count());

    uint32_t  echoed = -1;  // sentinel kept when nothing is read
    const int nread  = socket->read(&echoed, sizeof(echoed));
    if (echoed == nonce) {
        return;
    }

    std::stringstream msg;
    msg << "Ping failed, nonce=" << nonce << ", returnedNonce=" << echoed << ", socket read=" << nread;
    throw std::runtime_error(msg.str());
}

// Connects to the TCPStore server at host:port.
//
// Each attempt resolves the address and tries every returned candidate in turn: connect, set NODELAY and
// 5 s send/recv timeouts, then validate() and ping(). The first candidate that passes all steps becomes socket_.
// A failure on one candidate (e.g. an IPv6 address the server does not listen on) does not abort the attempt.
// If no candidate works, a warning is logged and the attempt is repeated after 1 s, with no upper bound on retries.
TCPStore::TCPStore(const std::string& host, int port)
{
    for (int retry = 0; socket_ == nullptr; ++retry) {
        try {
            ::addrinfo hints{}, *res{};
            hints.ai_flags    = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
            hints.ai_family   = AF_UNSPEC;
            hints.ai_socktype = SOCK_STREAM;

            const int status = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res);

            // Releases the resolved address list on every exit path of this attempt, including exceptions.
            std::shared_ptr<::addrinfo> holder(res, [](::addrinfo* p) {
                if (p != nullptr) {
                    freeaddrinfo(p);
                }
            });

            if (status != 0) {
                throw std::runtime_error("getaddrinfo failed: " + std::string(gai_strerror(status)));
            }

            std::string last_error = "no usable address";
            for (::addrinfo* addr = res; addr != nullptr && socket_ == nullptr; addr = addr->ai_next) {
                const int fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
                if (fd == -1) {
                    continue;
                }
                try {
                    // NOTE(review): assumes the gloo socket owns `fd` and closes it when destroyed, so a
                    // candidate that fails below does not leak its descriptor.
                    auto socket = std::make_shared<::gloo::transport::tcp::Socket>(fd);
                    socket->connect(addr->ai_addr, addr->ai_addrlen);
                    socket->noDelay(true);
                    socket->recvTimeout(std::chrono::milliseconds(5000));
                    socket->sendTimeout(std::chrono::milliseconds(5000));
                    validate(socket);  // validate the connection
                    ping(socket);      // check send/recv
                    socket_ = std::move(socket);
                }
                catch (const std::exception& e) {
                    // Remember why this candidate failed and move on to the next address.
                    last_error = e.what();
                }
            }

            if (socket_ == nullptr) {
                throw std::runtime_error("unable to connect to " + host + ":" + std::to_string(port) + " ("
                                         + last_error + ")");
            }
        }
        catch (const std::exception& e) {
            TM_LOG_WARNING("[TM][COMM] Failed to connect to store after %d retries: %s", retry, e.what());
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    }
}

// Stores `data` under `key` with a SET query. No reply is read.
// Throws std::runtime_error if the request cannot be fully written to the socket.
void TCPStore::set(const std::string& key, const std::vector<char>& data)
{
    std::lock_guard<std::mutex> lock(mutex_);
    Buffer                      buffer;
    buffer.append(QueryType::SET);
    buffer.append(key);
    buffer.append(data);

    // A single write may transfer only part of a large request (e.g. when the 5 s send timeout expires or a
    // signal interrupts it). Keep writing until everything is sent; a short request would desynchronize
    // the protocol stream.
    const char* ptr  = buffer.data();
    size_t      left = buffer.count();
    while (left > 0) {
        const auto n = socket_->write(ptr, left);
        if (n <= 0) {
            throw std::runtime_error("TCPStore::set('" + key + "') failed to write to socket");
        }
        ptr += n;
        left -= static_cast<size_t>(n);
    }
}

std::vector TCPStore::get(const std::string& key)
{
    wait({key});
    std::lock_guard lock(mutex_);
    Buffer                      buffer;
    buffer.append(QueryType::GET);
    buffer.append(key);
    socket_->write(buffer.data(), buffer.count());

    uint64_t vec_size;
    socket_->read(&vec_size, sizeof(vec_size));
    std::vector value(vec_size);
    socket_->read(value.data(), value.size());
    return value;
}

// Sends a CHECK query for `keys` and returns true when the server answers READY. Does not block waiting for
// the keys; see wait().
// Throws std::runtime_error if the one-byte reply cannot be read.
bool TCPStore::check(const std::vector<std::string>& keys)
{
    std::lock_guard<std::mutex> lock(mutex_);
    Buffer                      buffer;
    buffer.append(QueryType::CHECK);
    buffer.append((uint64_t)keys.size());
    for (const auto& key : keys) {
        buffer.append(key);
    }
    socket_->write(buffer.data(), buffer.count());

    // Previously a failed read (e.g. recv timeout) left `response` uninitialized and it was still compared.
    CheckResponseType response = CheckResponseType::NOT_READY;
    if (socket_->read(&response, sizeof(response)) <= 0) {
        throw std::runtime_error("TCPStore::check failed to read response from socket");
    }
    return response == CheckResponseType::READY;
}

void TCPStore::wait(const std::vector& keys, const std::chrono::milliseconds& timeout)
{
    const auto start = std::chrono::steady_clock::now();
    while (!check(keys)) {
        const auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start);
        if (elapsed > timeout) {
            std::stringstream ss;
            ss << "Wait timeout for key(s): [";
            for (const auto& key : keys) {
                ss << key << " ";
            }
            ss << "]";
            TM_LOG_ERROR("[TM][COMM] %s, elapsed %lld s", ss.str().c_str(), elapsed.count());
            throw std::runtime_error("Wait timeout for key(s): " + ss.str());
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
}

// Defaulted out of line; socket_ is released by its shared_ptr member destructor.
TCPStore::~TCPStore() = default;

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/gloo/tcp_store.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include 
#include 

namespace turbomind::comm {

// gloo rendezvous store backed by a PyTorch-compatible TCPStore server. All requests share one socket and are
// serialized by mutex_.
class TCPStore: public gloo::rendezvous::Store {
public:
    // Connects to host:port, retrying once per second until the connection, validation and ping succeed.
    explicit TCPStore(const std::string& host, int port);

    ~TCPStore();

    // Stores `data` under `key`.
    void set(const std::string& key, const std::vector& data) override;

    // Waits until `key` exists, then returns its value.
    std::vector get(const std::string& key) override;

    // Single non-blocking query; returns whether the server reports `keys` as READY.
    bool check(const std::vector& keys);

    // Waits for `keys` with a default timeout of 30 seconds.
    void wait(const std::vector& keys) override
    {
        wait(keys, std::chrono::seconds(30));
    }

    // Polls check() until it succeeds; throws std::runtime_error once `timeout` is exceeded.
    void wait(const std::vector& keys, const std::chrono::milliseconds& timeout) override;

private:
    std::shared_ptr<::gloo::transport::tcp::Socket> socket_;  // connection to the store server
    std::mutex                                      mutex_;   // guards request/response pairs on socket_
};

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/gloo/test_ipc_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/comm/host_comm.h"

using namespace turbomind::comm;

// 1: test the trivially-copyable path (plain int buffers).
// 0: test the serialization path (shared_ptr<vector<int>> objects).
#define TEST_TRIVIALLY_COPYABLE 1

// #define SKIP_SERIALIZE 0 // useless now

// Backend name passed to CreateHostGroupId(). "hybrid" and "gloo" are only available in BUILD_MULTI_GPU builds;
// any other value (e.g. "") selects the in-process thread backend.
// const std::string backend = "";
const std::string backend = "hybrid";
// const std::string backend = "gloo";

// Runs a PyTorch TCPStore in a background `python` process for the lifetime of this object.
// Node 0 hosts the store (is_master=True); other nodes start a client. The helper script polls once per second
// and exits when this process is gone or when the script file is deleted (see stop()). It also exports
// LMDEPLOY_DIST_INIT_ADDR / LMDEPLOY_DIST_INIT_PORT for the communicator backends.
struct Store {
    std::string hostname_;
    std::string port_;
    int         nnodes_;
    int         node_rank_;
    std::string py_script_;
    std::string py_file_path_ = "/tmp/start_tcp_store.py";

    std::thread thread_;

    Store(const std::string& hostname, const std::string& port, int nnodes, int node_rank):
        hostname_(hostname), port_(port), nnodes_(nnodes), node_rank_(node_rank)
    {

        const int pid = getpid();

        // The script checks `py_file_path_` itself (not a second hard-coded copy), so deleting that file in
        // stop() is guaranteed to end the loop.
        // clang-format off
    py_script_ =
"import psutil\n"
"import os\n"
"import time\n"
"from torch.distributed import TCPStore\n"
"store = TCPStore(host_name='" + hostname_ + "',\n"
"                 port=" + port_ + ",\n"
"                 world_size=" + std::to_string(nnodes_) + ",\n"
"                 is_master=" + (node_rank_ == 0 ? "True" : "False") + ")\n"
"while True:\n"
"    time.sleep(1)\n"
"    if not psutil.pid_exists(" + std::to_string(pid) + "):\n"
"        break\n"
"    if not os.path.exists('" + py_file_path_ + "'):\n"
"        break\n";

        // clang-format on
        std::ofstream py_file(py_file_path_);
        py_file << py_script_;
        py_file.close();

        // Presumably read by the gloo/hybrid host backends to locate the store -- confirm against the backend.
        setenv("LMDEPLOY_DIST_INIT_ADDR", hostname_.c_str(), 1);
        setenv("LMDEPLOY_DIST_INIT_PORT", port_.c_str(), 1);

        start();
        // wait a moment for the store to start.
        std::this_thread::sleep_for(std::chrono::seconds(3));
    }

    ~Store()
    {
        stop();
    }

    // Runs the script on a background thread; `system` blocks until the python process exits.
    // A non-zero exit status (e.g. missing torch/psutil) is reported instead of being silently dropped.
    void start()
    {
        const std::string cmd = ("python " + py_file_path_);
        thread_               = std::thread(
            [](const std::string& cmd) {
                if (const int rc = system(cmd.c_str()); rc != 0) {
                    std::cerr << "[Store] '" << cmd << "' exited with status " << rc << std::endl;
                }
            },
            cmd);
    }

    // Deletes the script file, which makes the python loop exit, then joins the helper thread.
    // Safe to call more than once.
    void stop()
    {
        std::remove(py_file_path_.c_str());
        if (thread_.joinable()) {
            thread_.join();
        }
    }
};

// Multi-threaded test driver for host communicators. Each node runs n_ranks_per_node threads, one per rank.
// Local thread i has global rank n_ranks_per_node * node_rank + i and uses h_comm_[i].
struct TestGlooComm {
    std::string hostname_;
    std::string port_;
    int         nnodes_;
    int         node_rank_;
    int         n_ranks_per_node_;

    // One communicator per local rank, indexed by local rank.
    std::vector h_comm_;

    TestGlooComm(const std::string& host, const std::string& port, int nnodes, int node_rank, int n_ranks_per_node):
        hostname_(host), port_(port), nnodes_(nnodes), node_rank_(node_rank), n_ranks_per_node_(n_ranks_per_node)
    {
        h_comm_.resize(n_ranks_per_node_);
    }

    // Creates a group id and exports it to a string. Every local rank thread then imports that string into its
    // own group id and creates its communicator concurrently.
    void init()
    {
        std::unique_ptr group_id = CreateHostGroupId(backend);
        std::string                  group_id_data;
        // NOTE(review): `if (1)` means every node initializes its own group id rather than receiving one from a
        // master. Presumably the backend rendezvouses through the LMDEPLOY_DIST_INIT_* store -- confirm.
        if (1) {  // master
            group_id->Initialize();
            std::stringstream ss;
            group_id->Export(ss);
            group_id_data = ss.str();
        }

        auto init = [&](int rank) {
            // initialize host communicators
            std::stringstream            ss(group_id_data);
            std::unique_ptr host_id = CreateHostGroupId(backend);
            host_id->Import(ss);
            h_comm_[rank % n_ranks_per_node_] =
                host_id->CreateCommunicator(n_ranks_per_node_ * nnodes_, rank, node_rank_);
        };

        std::vector threads;
        for (int i = 0; i < n_ranks_per_node_; ++i) {
            threads.emplace_back(init, n_ranks_per_node_ * node_rank_ + i);
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // For every root r: each rank fills its buffer with i + rank * count, broadcasts from r, and checks that it
    // received root r's pattern i + r * count. Mismatches are printed, not fatal.
    void test_broadcast()
    {
        const int count = 10;

        auto fun = [&](HostComm& comm, int rank) {
            for (int r = 0; r < comm->n_ranks(); ++r) {

#if TEST_TRIVIALLY_COPYABLE
                std::vector data(count);
#else
                std::shared_ptr> data_ptr = std::make_shared>(count);
                int*                              data     = data_ptr->data();
#endif

                for (int i = 0; i < count; ++i) {
                    data[i] = i + rank * count;  // i + rank * count
                }

#if TEST_TRIVIALLY_COPYABLE
                Broadcast(comm, data.data(), count, r);
#else
                Broadcast(comm, data_ptr, r);
                // the broadcast may replace the pointed-to object, so re-read the data pointer
                data = data_ptr->data();
#endif
                // check result
                for (int i = 0; i < count; ++i) {
                    int expected = i + r * count;
                    if (data[i] != expected) {
                        printf("Rank %d: Broadcast failed at root %d, index %d, got %d, expected %d\n",
                               rank,
                               r,
                               i,
                               data[i],
                               expected);
                    }
                }
            }
        };

        std::vector threads;
        // NOTE(review): `i` is size_t while n_ranks_per_node_ is int (signed/unsigned comparison); the rank
        // argument is converted back to int.
        for (size_t i = 0; i < n_ranks_per_node_; ++i) {
            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Each rank contributes `count` values i + rank * count in its own slot. After AllGather, slot r must hold
    // rank r's values. Assumes comm->rank() equals the global rank passed in.
    void test_allgather()
    {
        const int count = 40;

        auto fun = [&](HostComm& comm, int rank) {

#if TEST_TRIVIALLY_COPYABLE
            std::vector data(count * comm->n_ranks());
            for (int i = 0; i < count; ++i) {
                data[i + count * comm->rank()] = i + rank * count;  // i + rank * count
            }
#else
            std::vector>> data_ptrs(comm->n_ranks());
            data_ptrs[comm->rank()] = std::make_shared>(count);
            int* data = data_ptrs[comm->rank()]->data();
            for (int i = 0; i < count; ++i) {
                data[i] = i + rank * count;  // i + rank * count
            }
#endif

#if TEST_TRIVIALLY_COPYABLE
            AllGather(comm, data.data(), count);
            for (int r = 0; r < comm->n_ranks(); ++r) {
                for (int j = 0; j < count; ++j) {
                    int expected = j + r * count;
                    if (data[j + r * count] != expected) {
                        printf("Rank %d: AllGather failed, index %d, got %d, expected %d\n",
                               rank,
                               j + r * count,
                               data[j + r * count],
                               expected);
                    }
                }
            }
#else
            AllGather(comm, data_ptrs.data(), 1);
            for (int r = 0; r < comm->n_ranks(); ++r) {
                data = data_ptrs[r]->data();
                for (int j = 0; j < count; ++j) {
                    int expected = j + r * count;
                    if (data[j] != expected) {
                        printf("Rank %d: AllGather failed, index %d, got %d, expected %d\n",
                               rank,
                               j + r * count,
                               data[j],
                               expected);
                    }
                }
            }
#endif
        };

        std::vector threads;
        for (size_t i = 0; i < n_ranks_per_node_; ++i) {
            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Each rank contributes i + rank * count. After a kSum AllReduce, element j must equal the sum of
    // j + r * count over all ranks r.
    void test_allreduce()
    {
        const int count = 10;

        auto fun = [&](HostComm& comm, int rank) {
            std::vector data(count);
            for (int i = 0; i < count; ++i) {
                data[i] = i + rank * count;  // i + rank * count
            }

            AllReduce(comm, data.data(), count, RedOp::kSum);
            for (int j = 0; j < count; ++j) {
                int expected{};
                for (int r = 0; r < comm->n_ranks(); ++r) {
                    expected += j + r * count;
                }
                if (data[j] != expected) {
                    printf("Rank %d: AllReduce failed, index %d, got %d, expected %d\n", rank, j, data[j], expected);
                }
            }
        };

        std::vector threads;
        for (size_t i = 0; i < n_ranks_per_node_; ++i) {
            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Broadcast benchmark from root 0 for each element count in `count`.
    // A 5-iteration warmup sizes the first run; the iteration count then grows by kItersMultiplier until one
    // run takes at least kMinDurationNs. Rank 0 prints avg/p50/p99 latency and bandwidth for every run.
    void test_perf()
    {
        const long  kMinDurationNs   = 2e9;  // 2 second
        const long  kWarmupIter      = 5;    // warmup iter
        const float kItersMultiplier = 1.2;

        std::vector count = {1024, 262144, 524288, 1048576, 2097152, 4194304, 67108864};
        //                              1M,     2M,     4M,      8M,      16M,     256M

        if (node_rank_ == 0) {
            printf("%10s %10s %10s %10s %11s %18s %10s\n",
                   "size(MB)",
                   "elements",
                   "avg(us)",
                   "p50(us)",
                   "p99(us)",
                   "bandwidth(GB/s)",
                   "iterations");
        }

        auto fun = [&](HostComm& comm, int rank, int n) {

#if TEST_TRIVIALLY_COPYABLE
            std::vector data(n);
#else
            std::shared_ptr> sptr;
            if (rank == 0) {
                sptr = std::make_shared>(n);
            }
#endif

            std::vector times;

            // Runs n_iters timed broadcasts and returns their total duration in ns. The total is broadcast from
            // rank 0, so every rank makes the same stop decision below.
            auto job = [&](int n_iters) {
                times.clear();
                int64_t total = 0;
                // NOTE(review): this `ns` is unused; it is shadowed by the per-iteration `ns` in the loop.
                int64_t ns    = 0;
                comm->Sync();
                for (int i = 0; i < n_iters; ++i) {
                    auto start = std::chrono::high_resolution_clock::now();
#if TEST_TRIVIALLY_COPYABLE
                    Broadcast(comm, data.data(), n, 0);
#else
                    Broadcast(comm, sptr, 0);
#endif
                    auto    now = std::chrono::high_resolution_clock::now();
                    int64_t ns  = std::chrono::duration_cast(now - start).count();
                    total += ns;
                    times.push_back(ns);
                }
                Broadcast(comm, total, 0);
                return total;
            };

            // NOTE(review): integer division below -- a warmup average of 0 ns would divide by zero.
            auto warmup_dur = job(kWarmupIter) / kWarmupIter;
            auto iter       = (int)std::max(kMinDurationNs / warmup_dur * 0.5f, 100.f);

            while (1) {
                auto dur = job(iter);
                std::sort(times.begin(), times.end());

                if (rank == 0) {
                    size_t bytes = n * sizeof(int);
                    int    p50   = std::min(times.size() / 2, times.size() - 1);
                    int    p99   = std::min((int)(times.size() * 0.99), (int)times.size() - 1);
                    printf("%10.5f %10d %10lld %10lld %10lld %18.3f %10lld\n",
                           bytes / 1024.f / 1024.f,
                           n,
                           static_cast(dur / 1e3f / iter),
                           static_cast(times[p50] / 1e3f),
                           static_cast(times[p99] / 1e3f),
                           (bytes * iter) / (dur / 1e9f) / (1024 * 1024 * 1024),
                           static_cast(iter));
                }

                if (dur >= kMinDurationNs) {
                    break;
                }
                iter = std::max(iter * kItersMultiplier, iter + 1.f);
            }
        };

        for (auto n : count) {
            std::vector threads;
            for (size_t i = 0; i < n_ranks_per_node_; ++i) {
                threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i, n);
            }
            for (auto& t : threads) {
                t.join();
            }
        }
    }
};

// ./test_gloo_comm <nnodes> <node_rank> <n_ranks_per_node> <host:port>
//
// Starts (or joins) the TCPStore at host:port, then runs the broadcast / allgather / allreduce checks with
// n_ranks_per_node threads on this node. Returns -1 on invalid arguments.
int main(int argc, char* argv[])
{
    if (argc != 5) {
        std::cerr << "Usage: " << argv[0] << " <nnodes> <node_rank> <n_ranks_per_node> <host:port>" << std::endl;
        return -1;
    }

    const int nnodes           = std::atoi(argv[1]);
    const int node_rank        = std::atoi(argv[2]);
    const int n_ranks_per_node = std::atoi(argv[3]);

    // atoi yields 0 for non-numeric input; reject configurations the tests cannot run with.
    if (nnodes <= 0 || n_ranks_per_node <= 0 || node_rank < 0 || node_rank >= nnodes) {
        std::cerr << "Invalid arguments: nnodes=" << nnodes << ", node_rank=" << node_rank
                  << ", n_ranks_per_node=" << n_ranks_per_node << std::endl;
        return -1;
    }

    // Split at the LAST ':' so hosts containing ':' (IPv6 literals) keep their full address. Previously a
    // missing ':' silently produced host == port == the whole argument.
    const std::string init_addr = argv[4];
    const auto        pos       = init_addr.rfind(':');
    if (pos == std::string::npos || pos == 0 || pos + 1 == init_addr.size()) {
        std::cerr << "Invalid init_addr '" << init_addr << "', expected <host:port>" << std::endl;
        return -1;
    }
    const std::string host = init_addr.substr(0, pos);
    const std::string port = init_addr.substr(pos + 1);

    Store store(host, port, nnodes, node_rank);

    {
        TestGlooComm test(host, port, nnodes, node_rank, n_ranks_per_node);
        test.init();

        test.test_broadcast();
        test.test_allgather();
        test.test_allreduce();

        // test.test_perf();
    }

    return 0;
}


================================================
FILE: src/turbomind/comm/host_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/host_comm.h"

namespace turbomind::comm {

// Out-of-line definition of the virtual destructor declared in host_comm.h.
HostCommImpl::~HostCommImpl() = default;

// Backend factories defined in other translation units (the hybrid one lives in the gloo directory).
std::unique_ptr CreateThreadGroupId();

std::unique_ptr CreateGlooGroupId();

std::unique_ptr CreateHybridGroupId();

std::unique_ptr CreateHostGroupId(const std::string& backend)
{
#ifdef BUILD_MULTI_GPU
    if (backend == "hybrid") {
        return CreateHybridGroupId();
    }
    if (backend == "gloo") {
        return CreateGlooGroupId();
    }
#endif

    return CreateThreadGroupId();
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/host_comm.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/serdes.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind::comm {

// Element-wise reduction applied by HostCommImpl::AllReduce.
enum class RedOp : int
{
    kSum,
    kMin,
    kMax,
};

// Type-erased element callbacks passed to HostCommImpl.

// Copies the `n` elements starting at element `offset` of `src` into the same positions of `dst`.
using copy_fn = void (*)(void* src, int n, void* dst, int offset);

// Reduction callback; presumably folds `src` into `dst` over the same range as copy_fn -- no use is visible here.
using reduce_fn = void (*)(void* src, int n, void* dst, int offset);

// Serializes `n` elements starting at `offset`. With out == nullptr it only reports the byte count in `size`;
// otherwise it writes into `out`, whose capacity is `size`.
using ser_fn = void (*)(void* data, int offset, int n, size_t& size, void* out);

// Deserializes `n` elements starting at `offset` from the `size` bytes at `in`.
using des_fn = void (*)(void* data, int offset, int n, void* in, size_t size);

// Abstract CPU-side communicator. Instances are obtained from HostGroupId::CreateCommunicator and are normally
// used through the typed Broadcast / AllGather / AllReduce helpers below.
class HostCommImpl {
public:
    virtual ~HostCommImpl();

    // Rank of the caller within this communicator.
    virtual int rank() const = 0;

    // Number of ranks in this communicator.
    virtual int n_ranks() const = 0;

    // True when all ranks share one process. The typed helpers then copy non-trivially-copyable objects
    // directly with copy_fn instead of serializing them.
    virtual bool is_same_process() const = 0;

    // Creates a sub-communicator; presumably MPI_Comm_split-like (ranks with equal `color`, ordered by `key`)
    // -- confirm against the backends.
    virtual std::shared_ptr Split(int color, int key) = 0;

    // Synchronizes the ranks; the exact meaning of `blocking` is backend-defined (not visible here).
    virtual void Sync(bool blocking = false) = 0;

    // Broadcasts `count` units of `data` from `root`. The typed helpers pass a byte count with a concrete dtype
    // for trivially copyable data, or an element count with kNull plus copy/ser/des callbacks otherwise.
    virtual void Broadcast(void*    data,  //
                           int      count,
                           DataType dtype,
                           int      root,
                           copy_fn  copy,
                           ser_fn   ser = nullptr,
                           des_fn   des = nullptr) = 0;

    // Gathers `count` units from every rank into `data`, laid out by rank (n_ranks() * count units in total).
    virtual void AllGather(void*    data,  //
                           int      count,
                           DataType dtype,
                           copy_fn  copy,
                           ser_fn   ser = nullptr,
                           des_fn   des = nullptr) = 0;

    // In-place element-wise reduction of `count` elements of type `dtype` across all ranks.
    virtual void AllReduce(void* data, int count, DataType dtype, RedOp red_op) = 0;
};

class HostComm {
public:
    HostComm() = default;

    /* implicit */ HostComm(std::shared_ptr impl): impl_{std::move(impl)} {}

    HostCommImpl* operator->() const noexcept
    {
        return impl_.get();
    }

    operator HostCommImpl*() const noexcept
    {
        return impl_.get();
    }

private:
    std::shared_ptr impl_;
};

// Callback instantiations matching the copy_fn / ser_fn / des_fn signatures for element type T.
namespace detail {
// Copies elements [offset, offset + n) of `src` to the same positions in `dst`.
template
void copy_fn(void* src, int n, void* dst, int offset)
{
    std::copy_n((T*)src + offset, n, (T*)dst + offset);
}

// Two-phase serializer for elements [offset, offset + n):
//  - out == nullptr: only measures, storing the required byte count in `size`;
//  - otherwise: writes the elements into the `size`-byte buffer at `out`.
template
void ser_fn(void* data, int offset, int n, size_t& size, void* out)
{
    if (out == nullptr) {
        size = 0;
        core::BinarySizeArchive sa;
        for (int i = 0; i < n; ++i) {
            sa&((T*)data)[offset + i];
        }
        size = sa.size();
    }
    else {
        core::BinaryOutputArchive oa(core::ArrayWrapper((std::byte*)out, size));
        for (int i = 0; i < n; ++i) {
            oa&((T*)data)[offset + i];
        }
    }
}

// Reads elements [offset, offset + n) of `data` back from the `size` bytes at `in`.
template
void des_fn(void* data, int offset, int n, void* in, size_t size)
{
    core::BinaryInputArchive ia(core::ArrayWrapper((std::byte*)in, size));
    for (int i = 0; i < n; ++i) {
        ia&((T*)data)[offset + i];
    }
}

}  // namespace detail

//////////////////////////////////////////////////////////////////////////////////
// Typed array interface
// Each helper selects one of three HostCommImpl call forms:
//  - trivially copyable T: the array is sent as raw bytes (count = sizeof(T) * n);
//  - other T, same process: objects are copied with detail::copy_fn (dtype kNull, element count n);
//  - other T, different processes: detail::ser_fn / detail::des_fn are passed as well.
template
void Broadcast(HostCommImpl* comm, T* data, int n, int root)
{
    if constexpr (std::is_trivially_copyable_v) {
        comm->Broadcast(data, sizeof(T) * n, data_type_v, root, detail::copy_fn);
    }
    else {
        if (comm->is_same_process()) {
            /// TODO: Constness should be considered
            comm->Broadcast(data, n, kNull, root, detail::copy_fn);
        }
        else {
            comm->Broadcast(data, n, kNull, root, detail::copy_fn, detail::ser_fn, detail::des_fn);
        }
    }
}

// `data` holds n_ranks() * n elements. Each rank fills its own slice [rank * n, (rank + 1) * n) beforehand,
// and every slice is valid on return.
template
void AllGather(HostCommImpl* comm, T* data, int n)
{
    if constexpr (std::is_trivially_copyable_v) {
        comm->AllGather(data, sizeof(T) * n, data_type_v, detail::copy_fn);
    }
    else {
        if (comm->is_same_process()) {
            /// TODO: Constness should be considered
            comm->AllGather(data, n, kNull, detail::copy_fn);
        }
        else {
            comm->AllGather(data, n, kNull, detail::copy_fn, detail::ser_fn, detail::des_fn);
        }
    }
}

// In-place reduction of n elements; T must map to a DataType through data_type_v.
template
void AllReduce(HostCommImpl* comm, T* data, int n, RedOp red_op)
{
    comm->AllReduce(data, n, data_type_v, red_op);
}

//////////////////////////////////////////////////////////////////////////////////
// Typed value interface
template
void Broadcast(HostCommImpl* comm, T& value, int root)
{
    Broadcast(comm, &value, 1, root);
}

template
std::vector AllGather(HostCommImpl* comm, const T& value)
{
    std::vector ret(comm->n_ranks());
    ret.at(comm->rank()) = value;
    AllGather(comm, ret.data(), 1);
    return ret;
}

template
T AllReduce(HostCommImpl* comm, const T& value, RedOp red_op)
{
    T tmp = value;
    AllReduce(comm, &tmp, 1, red_op);
    return tmp;
}

// Identifies a communicator group. One instance is initialized and exported to a stream; the others import
// that stream and create communicators from it (see test_ipc_comm.cc).
class HostGroupId {
public:
    virtual ~HostGroupId() = default;

    virtual void Initialize()             = 0;  // create a new group identity
    virtual void Export(std::ostream& os) = 0;  // write the identity to `os`
    virtual void Import(std::istream& is) = 0;  // read an identity produced by Export

    // Creates the communicator for global rank `rank` among `n_ranks`; `node_rank` is the caller's node index.
    virtual HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) = 0;
};

// Backend selection by name ("hybrid", "gloo", otherwise the thread backend); defined in host_comm.cc.
std::unique_ptr CreateHostGroupId(const std::string& backend);

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/nccl/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)

# Static NCCL-backed device communicator; it is linked into shared targets, hence PIC, and it resolves its
# CUDA device symbols itself.
add_library(nccl_comm STATIC nccl.cu)
target_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger)
target_include_directories(nccl_comm PRIVATE ${NCCL_INCLUDE_DIRS})

set_target_properties(nccl_comm PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON)


================================================
FILE: src/turbomind/comm/nccl/nccl.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 

#include 

#include 

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/string_utils.h"

#include "src/turbomind/kernels/norm/rms_norm.h"

// Evaluates the NCCL call `e` once and throws std::runtime_error with file/line and NCCL's error string if it
// did not return ncclSuccess. Wrapped in do { } while (0) so it behaves as a single statement: the previous bare
// `if` form could silently capture a following `else` at the call site.
#define NCCLCHECK(e)                                                                                                   \
    do {                                                                                                               \
        if (auto ec = (e); ec != ncclSuccess) {                                                                        \
            auto msg = fmtstr("NCCL error %s:%d '%s'", __FILE__, __LINE__, ncclGetErrorString(ec));                    \
            throw std::runtime_error(msg.c_str());                                                                     \
        }                                                                                                              \
    } while (0)

// Headers older than NCCL 2.27 lack the window-registration flags. Define them so code using them still
// compiles; the window APIs themselves are only called if nccl_apis() resolves them at runtime.
#if NCCL_VERSION_CODE < NCCL_VERSION(2, 27, 0)
/* Window Registration flags */
#define NCCL_WIN_DEFAULT 0x00
#define NCCL_WIN_COLL_SYMMETRIC 0x01
#endif

namespace turbomind::comm {

// Maps a turbomind DataType to the NCCL element type. Only fp32, fp16, bf16 and uint8 are supported;
// anything else throws std::runtime_error.
static inline ncclDataType_t to_nccl_dtype(DataType type)
{
    if (type == kFloat32) {
        return ncclFloat;
    }
    if (type == kFloat16) {
        return ncclHalf;
    }
    if (type == kBfloat16) {
        return ncclBfloat16;
    }
    if (type == kUint8) {
        return ncclUint8;
    }
    throw std::runtime_error("not supported");
}

// Optional NCCL entry points resolved at runtime by nccl_apis(). A null pointer means the entry point is not
// used: the runtime NCCL version is too old, dlsym failed, or libnccl.so.2 could not be opened.
struct NcclApis {
    ncclResult_t (*ncclMemAlloc)(void** ptr, size_t size);  // NCCL >= 2.19
    ncclResult_t (*ncclMemFree)(void* ptr);                 // NCCL >= 2.19
    ncclResult_t (*ncclCommRegister)(const ncclComm_t comm, void* buff, size_t size, void** handle);  // NCCL >= 2.19
    ncclResult_t (*ncclCommDeregister)(const ncclComm_t comm, void* handle);                         // NCCL >= 2.19
    ncclResult_t (*ncclCommWindowRegister)(ncclComm_t comm, void* buff, size_t size, void** win, int winFlags);  // >= 2.27
    ncclResult_t (*ncclCommWindowDeregister)(ncclComm_t comm, void* win);                                       // >= 2.27
    // `ncclConfig_t` varies between versions, should be fine as long as we are passing nullptr to it
    ncclResult_t (*ncclCommSplit)(ncclComm_t comm, int color, int key, ncclComm_t* newcomm, void* config);  // >= 2.18
};

// Returns the process-wide table of optional NCCL entry points, resolved once on first use (function-local
// static initialization is thread-safe).
// Symbols are looked up in libnccl.so.2 via dlopen/dlsym and gated on the runtime NCCL version. Any pointer left
// null means "unavailable", and callers such as Allocate/Free fall back to plain CUDA APIs.
static NcclApis& nccl_apis()
{
    static auto value = [] {
        int version{};
        ncclGetVersion(&version);
        // The handle is deliberately never dlclose()d: the resolved pointers are stored in this static table.
        auto     handle = dlopen("libnccl.so.2", RTLD_LAZY);
        NcclApis apis{};
        if (!handle) {
            // Report why every optional API is being disabled instead of failing silently.
            const char* err = dlerror();
            TM_LOG_WARNING("[NCCL] Failed to load libnccl.so.2 (%s), optional NCCL APIs are disabled.",
                           err ? err : "unknown error");
            return apis;
        }
        auto load_symbol = [&](auto& dst, auto name) {
            using T = std::remove_reference_t<decltype(dst)>;
            dst     = reinterpret_cast<T>(dlsym(handle, name));
        };
        if (version >= NCCL_VERSION(2, 27, 0)) {
            if (version < NCCL_VERSION(2, 28, 0)) {
                TM_LOG_WARNING(
                    "[NCCL] Window registration may cause memory leaks in NCCL 2.27, use NCCL 2.28+ or disable the feature by setting NCCL_WIN_ENABLE=0.");
            }
            load_symbol(apis.ncclCommWindowRegister, "ncclCommWindowRegister");
            load_symbol(apis.ncclCommWindowDeregister, "ncclCommWindowDeregister");
        }
        else {
            TM_LOG_WARNING(
                "[NCCL] Window registration is not supported by NCCL %d, use NCCL 2.28+ for better performance.",
                version);
        }
        if (version >= NCCL_VERSION(2, 19, 0)) {
            load_symbol(apis.ncclMemAlloc, "ncclMemAlloc");
            load_symbol(apis.ncclMemFree, "ncclMemFree");
            load_symbol(apis.ncclCommRegister, "ncclCommRegister");
            load_symbol(apis.ncclCommDeregister, "ncclCommDeregister");
        }
        if (version >= NCCL_VERSION(2, 18, 0)) {
            load_symbol(apis.ncclCommSplit, "ncclCommSplit");
        }
        else {
            TM_LOG_WARNING("[NCCL] Splitting communicators is not supported by NCCL %d, use NCCL 2.18+ if needed.",
                           version);
        }
        return apis;
    }();
    return value;
}

class NcclCommImpl: public DeviceCommImpl {
public:
    NcclCommImpl(ncclComm_t comm, int n_ranks, int rank, HostComm h_comm):
        h_comm_{h_comm}, global_n_ranks_{n_ranks}, global_rank_{rank}, groups_{comm}
    {
        handles_.emplace_back();
    }

    ~NcclCommImpl()
    {
        for (const auto& [ptr, _] : handles_.at(0)) {
            TM_LOG_WARNING("[NCCL][%d] Buffer %p is not deregistered", global_rank_, ptr);
        }

        for (const auto& [ptr, size] : buffers_) {
            TM_LOG_WARNING("[NCCL][%d] Allocation (%p, %lu) is not freed", global_rank_, ptr, size);
        }

        for (auto& c : groups_) {
            if (auto ec = ncclCommDestroy(c); ec != ncclSuccess) {
                TM_LOG_ERROR("[NCCL][%d] Failed to destroy communicator: %s", global_rank_, ncclGetErrorString(ec));
            }
        }
    }

    int rank(int group) const override
    {
        int rank{};
        NCCLCHECK(ncclCommUserRank(groups_.at(group), &rank));
        return rank;
    }

    int n_ranks(int group) const override
    {
        int n_ranks{};
        NCCLCHECK(ncclCommCount(groups_.at(group), &n_ranks));
        return n_ranks;
    }

    void* Allocate(size_t size) override
    {
        void* ptr{};
        if (auto alloc_fn = nccl_apis().ncclMemAlloc) {
            NCCLCHECK(alloc_fn(&ptr, size));
        }
        else {
            check_cuda_error(cudaMalloc(&ptr, size));
        }
        buffers_.emplace(ptr, size);
        return ptr;
    }

    void Free(void* ptr) override
    {
        if (auto it = buffers_.find(ptr); it != buffers_.end()) {
            if (auto free_fn = nccl_apis().ncclMemFree) {
                NCCLCHECK(free_fn(ptr));
            }
            else {
                check_cuda_error(cudaFree(ptr));
            }
            buffers_.erase(ptr);
        }
        else {
            TM_LOG_WARNING("[NCCL][%d] Freeing %p which is not allocated by NcclComm", global_rank_, ptr);
        }
    }

    void Register(void* ptr, size_t size) override
    {
        if (!handles_.at(0).count(ptr)) {
            for (size_t i = 0; i < handles_.size(); ++i) {
                Register(i, ptr, size);
            }
        }
        else {
            TM_LOG_WARNING("[NCCL][%d] Duplicated registration on (%p, %lu)", global_rank_, ptr, size);
        }
    }

    void Deregister(void* ptr) override
    {
        if (handles_.at(0).count(ptr)) {
            for (size_t i = 0; i < handles_.size(); ++i) {
                Deregister(i, ptr);
            }
        }
        else {
            TM_LOG_WARNING("[NCCL][%d] Deregistering non-registered address %p", global_rank_, ptr);
        }
    }

    void Register(int group, void* buff, size_t size)
    {
        void* handle{};
        auto  comm = groups_.at(group);
        if (auto func = nccl_apis().ncclCommWindowRegister) {
            NCCLCHECK(func(comm, buff, size, &handle, NCCL_WIN_COLL_SYMMETRIC));
        }
        else if (auto func = nccl_apis().ncclCommRegister) {
            NCCLCHECK(func(comm, buff, size, &handle));
        }
        handles_.at(group).emplace(buff, std::make_pair(handle, size));
    }

    void Deregister(int group, void* buff)
    {
        auto& handles = handles_.at(group);
        if (auto it = handles.find(buff); it != handles.end()) {
            if (auto func = nccl_apis().ncclCommWindowDeregister) {
                NCCLCHECK(func(groups_.at(group), it->second.first));
            }
            else if (auto func = nccl_apis().ncclCommDeregister) {
                NCCLCHECK(func(groups_.at(group), it->second.first));
            }
            handles.erase(it);
        }
    }

    int Split(int color, int key, int group) override
    {
        auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);

        ncclComm_t comm{};
        NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));

        int index = groups_.size();
        groups_.push_back(comm);
        handles_.emplace_back();

        // register all existing buffers on the group
        for (const auto& [k, v] : handles_.at(0)) {
            Register(index, k, v.second);
        }

        return index;
    }

    int Query(QueryAttr attr) const noexcept override
    {
        return 0;
    }

    void AllReduceSum(
        const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream) override
    {
        NCCLCHECK(ncclGroupStart());
        NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, count, to_nccl_dtype(type), ncclSum, groups_.at(group), stream));
        NCCLCHECK(ncclGroupEnd());
    }

    void AllGather(
        const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream) override
    {
        NCCLCHECK(ncclGroupStart());
        NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, to_nccl_dtype(type), groups_.at(group), stream));
        NCCLCHECK(ncclGroupEnd());
    }

    void ReduceScatter(
        const void* sendbuff, void* recvbuff, size_t recvcount, DataType type, int group, cudaStream_t stream) override
    {
        NCCLCHECK(ncclGroupStart());
        NCCLCHECK(
            ncclReduceScatter(sendbuff, recvbuff, recvcount, to_nccl_dtype(type), ncclSum, groups_.at(group), stream));
        NCCLCHECK(ncclGroupEnd());
    }

    void AllreduceResidualBiasRMSnorm(void*        hidden,
                                      void*        residual,
                                      const void*  bias,
                                      const void*  weights,
                                      float        eps,
                                      int          dim,
                                      int          token_num,
                                      DataType     dtype,
                                      int          group,
                                      cudaStream_t stream) override
    {
        const auto elem_size = byte_size(dtype);

        auto rms_norm = [&](int64_t first, int64_t count) {
            invokeResidualBiasRMSNorm((char*)hidden + elem_size * first * dim,
                                      (char*)residual + elem_size * first * dim,
                                      weights,
                                      bias,
                                      dtype,
                                      dim,
                                      count,
                                      eps,
                                      stream);
        };

        if (1) {
            AllReduceSum(hidden, hidden, token_num * dim, dtype, group, stream);
            rms_norm(0, token_num);
        }
        else {  // Only useful for large input size
            const int    n_ranks   = this->n_ranks(group);
            const int    rank      = this->rank(group);
            const int    slice     = (token_num + n_ranks - 1) / n_ranks;
            const size_t recvcount = slice * dim;
            auto         sendbuff  = hidden;
            auto         recvbuff  = (char*)hidden + elem_size * rank * recvcount;
            ReduceScatter(sendbuff, recvbuff, recvcount, dtype, group, stream);
            rms_norm(rank * slice, slice);
            AllGather(recvbuff, sendbuff, recvcount, dtype, group, stream);
        }
    }

    void AllreduceResidualBiasRMSnormEx(void*        hidden,
                                        void*        residual,
                                        const void*  bias,
                                        const void*  weights,
                                        float        eps,
                                        int          dim,
                                        DataType     type,
                                        int          group0,
                                        int          group1,
                                        const int*   local_token_nums,
                                        cudaStream_t stream) override
    {
        const size_t         elem_size = byte_size(type);
        const ncclDataType_t nccl_type = to_nccl_dtype(type);

        FT_CHECK(group0 == 0 || group1 == 0);

        ncclComm_t comm0 = groups_.at(group0);
        ncclComm_t comm1 = groups_.at(group1);

        int tp0{}, tp1{};
        NCCLCHECK(ncclCommCount(comm0, &tp0));
        NCCLCHECK(ncclCommCount(comm1, &tp1));

        const int inner_tp = std::min(tp0, tp1);

        FT_CHECK(tp0 % inner_tp == 0 && tp1 % inner_tp == 0);

        std::vector> tasks;
        tasks.reserve(global_n_ranks_);

        for (int i = 0, offset = 0; i < global_n_ranks_; ++i) {
            const int num   = local_token_nums[i / inner_tp];
            const int slice = (num + inner_tp - 1) / inner_tp;
            const int first = std::min(num, i % inner_tp * slice);
            const int last  = std::min(num, first + slice);
            tasks.emplace_back(offset, first, last - first);
            if ((i + 1) % inner_tp == 0) {
                offset += num;
            }
        }

        if (tp0 > 1) {
            NCCLCHECK(ncclGroupStart());
            for (int i = 0; i < global_n_ranks_; ++i) {
                if (auto& [offset, first, num] = tasks[i]; num > 0) {
                    char* buff = (char*)hidden + elem_size * (offset + first) * dim;
                    NCCLCHECK(ncclReduce(buff, buff, (size_t)num * dim, nccl_type, ncclSum, i % tp0, comm0, stream));
                }
            }
            NCCLCHECK(ncclGroupEnd());
            sync_check_cuda_error();
        }

        if (auto& [offset, first, num] = tasks[global_rank_]; num > 0) {
            char* buff = (char*)hidden + elem_size * (offset + first) * dim;
            invokeResidualBiasRMSNorm(
                buff, (char*)residual + elem_size * first * dim, weights, bias, type, dim, num, eps, stream);
            sync_check_cuda_error();
        }

        if (tp1 > 1) {
            NCCLCHECK(ncclGroupStart());
            for (int i = 0; i < global_n_ranks_; ++i) {
                if (auto& [offset, first, num] = tasks[i]; num > 0) {
                    char* buff = (char*)hidden + elem_size * (offset + first) * dim;
                    NCCLCHECK(ncclBroadcast(buff, buff, (size_t)num * dim, nccl_type, i % tp1, comm1, stream));
                }
            }
            NCCLCHECK(ncclGroupEnd());
            sync_check_cuda_error();
        }
    }

    void Broadcast(const void*  sendbuff,  //
                   void*        recvbuff,
                   size_t       count,
                   DataType     type,
                   int          root,
                   int          group,
                   cudaStream_t stream) override
    {
        NCCLCHECK(ncclBroadcast(recvbuff, recvbuff, count, to_nccl_dtype(type), root, groups_.at(group), stream));
    }

private:
    HostComm h_comm_;

    int global_n_ranks_;
    int global_rank_;

    std::vector groups_;

    std::vector>> handles_;

    std::unordered_map buffers_;
};

// Create an NCCL-backed device communicator spanning `n_ranks` ranks.
//
// Rank 0 generates the NCCL unique id, which is shared with every rank through the host
// communicator before all ranks join the NCCL communicator with ncclCommInitRank.
DeviceComm CreateNcclCommunicator(int n_ranks, int rank, HostComm h_comm)
{
    // the id is sent through the host communicator (presumably as raw bytes)
    static_assert(std::is_trivially_copyable_v<ncclUniqueId>);

    ncclUniqueId unique_id{};
    const bool   is_root = rank == 0;
    if (is_root) {
        NCCLCHECK(ncclGetUniqueId(&unique_id));
    }
    Broadcast(h_comm, unique_id, 0);

    ncclComm_t nccl_comm{};
    NCCLCHECK(ncclCommInitRank(&nccl_comm, n_ranks, unique_id, rank));

    auto impl = std::make_unique<NcclCommImpl>(nccl_comm, n_ranks, rank, h_comm);
    return DeviceComm{std::move(impl)};
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/comm/test_comm.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

// #include 
#include 

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/utils/cuda_utils.h"

using namespace turbomind::comm;
using turbomind::data_type_v;
using turbomind::check;
using turbomind::myAssert;
using std::vector;

// Compile-time switch (currently off); NOTE(review): by its name presumably meant for runs under Nsight Compute.
[[maybe_unused]] static constexpr bool is_ncu = false;

struct Context {

    cudaStream_t stream;

    cudaEvent_t ev_start;
    cudaEvent_t ev_end;

    std::vector buffers;

    template
    float exec(F func)
    {
        check_cuda_error(cudaStreamSynchronize(stream));
        check_cuda_error(cudaEventRecord(ev_start, stream));

        func(stream);

        check_cuda_error(cudaEventRecord(ev_end, stream));
        check_cuda_error(cudaEventSynchronize(ev_end));
        float ms{};
        check_cuda_error(cudaEventElapsedTime(&ms, ev_start, ev_end));
        return ms;
    }

    template
    T* malloc(size_t count)
    {
        T* data;
        check_cuda_error(cudaMallocAsync(&data, sizeof(T) * count, stream));
        buffers.push_back(data);
        return data;
    }

    template
    void copy_n(const T* src, size_t count, T* dst)
    {
        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(T) * count, cudaMemcpyDefault, stream));
    }

    void sync()
    {
        check_cuda_error(cudaStreamSynchronize(stream));
    }

    Context(int device_id)
    {
        check_cuda_error(cudaSetDevice(device_id));
        check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
        check_cuda_error(cudaEventCreate(&ev_start));
        check_cuda_error(cudaEventCreate(&ev_end));
    }
    ~Context()
    {
        for (auto& p : buffers) {
            cudaFreeAsync(p, stream);
            p = {};
        }
        cudaStreamSynchronize(stream);
        cudaEventDestroy(ev_end);
        cudaEventDestroy(ev_start);
        cudaStreamDestroy(stream);
    }
};

struct TestComm {
    std::vector   h_comm_;
    std::vector d_comm_;
    std::vector   h_split_;
    std::vector        d_split_;

    int              warmup_;
    int              iters_;
    std::vector tokens_;
    size_t           max_tokens_;

    static auto Init(int n_ranks, int split, const std::string& backend)
    {

        std::unique_ptr group_id = CreateHostGroupId({});
        std::string                  group_id_data;
        if (1) {  // master
            group_id->Initialize();
            std::stringstream ss;
            group_id->Export(ss);
            group_id_data = ss.str();
        }

        std::vector d_comm(n_ranks);
        std::vector   h_comm(n_ranks);
        std::vector        d_split(n_ranks);
        std::vector   h_split(n_ranks);

        auto init = [&](int rank) {
            // initialize host communicators
            std::stringstream            ss(group_id_data);
            std::unique_ptr host_id = CreateHostGroupId({});
            host_id->Import(ss);
            h_comm[rank] = host_id->CreateCommunicator(n_ranks, rank);

            // initialize device communicators
            cudaSetDevice(rank);
            d_comm[rank] = CreateDeviceCommunicator(backend, n_ranks, rank, h_comm[rank]);

            // split communicators
            if (split) {
                h_split[rank] = h_comm[rank]->Split(rank / split, 0);
                d_split[rank] = d_comm[rank]->Split(rank / split, 0, 0);
            }
            else {
                h_split[rank] = h_comm[rank];
                d_split[rank] = 0;
            }
        };

        std::vector threads;
        for (int i = 0; i < n_ranks; ++i) {
            threads.emplace_back(init, i);
        }
        for (auto& t : threads) {
            t.join();
        }

        return std::make_tuple(h_comm, std::move(d_comm), h_split, d_split);
    }

    // Benchmark entry point: creates communicators for every visible CUDA device, stores the
    // benchmark parameters and runs the enabled collective tests for each token count in `tokens`.
    //
    // tp < 0 means "use all devices"; `tp`, `vocab_size` and `g` are only used by the
    // commented-out tests below.
    void Run(int hidden_dim, int vocab_size, int tp, int warmup, int iters, std::vector tokens)
    {
        // one rank per visible device (the return code is not checked)
        int device_num{};
        cudaGetDeviceCount(&device_num);

        std::cout << "Device count: " << device_num << "\n";

        if (tp < 0) {
            tp = device_num;
        }

        // NOTE(review): split size 4 and the "cuda-ipc" backend are hard-coded here
        std::tie(h_comm_, d_comm_, h_split_, d_split_) = Init(device_num, 4, "cuda-ipc");

        TM_CHECK_GT(h_comm_.size(), 0);
        TM_CHECK_GT(d_comm_.size(), 0);

        warmup_ = warmup;
        iters_  = iters;
        tokens_ = tokens;

        // NOTE(review): dereferencing max_element is undefined behavior when `tokens` is empty;
        // consider checking `tokens_.size() > 0` first.
        max_tokens_ = *std::max_element(tokens_.begin(), tokens_.end());

        const int g = 0;

        TestAllReduce(hidden_dim, 0);
        // TestAllreduceResidualBiasRMSnorm(hidden_dim, g);
        // TestAllreduceResidualBiasRMSnormEx(hidden_dim, 0, 0);
        // TestAllreduceResidualBiasRMSnormEx(hidden_dim, 1, 0);
        // TestAllreduceResidualBiasRMSnormEx(hidden_dim, 0, 1);
        // TestAllGather(hidden_dim / tp, g);  // tp embedding
        // TestAllGather(vocab_size / tp, g);
        // TestBroadcast(32768, g);
    }

    // Verify and time `AllReduceSum` on `group` for every token count in `tokens_`.
    //
    // One thread per device communicator. Each rank fills its input with small random integers and
    // all ranks cooperatively build the reference sum. Then, for each token count, the in-place
    // all-reduce runs `warmup_ + iters_` times. Only the last `iters_` runs are timed and the
    // final result is compared exactly against the reference. Rank 0 prints a summary table.
    template
    void TestAllReduce(size_t dim, int group = 0)
    {
        const auto dtype = data_type_v;

        const int tp_size = d_comm_[0]->n_ranks(group);
        const int dp_size = d_comm_.size() / tp_size;

        //    dp         tp           dim
        std::vector>> data(dp_size);
        //    dp         dim
        std::vector> ref_data(dp_size);

        for (int i = 0; i < dp_size; ++i) {
            data[i].resize(tp_size);
            ref_data[i].resize(max_tokens_ * dim);
        }

        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {
            const int rank    = d_comm->rank(group);
            const int n_ranks = d_comm->n_ranks(group);
            const int g_rank  = d_comm->rank(0);
            // replica index of this rank's group; assumes a group consists of consecutive
            // global ranks -- TODO confirm against the split scheme used in Init
            const int d       = g_rank / n_ranks;

            const size_t max_count = max_tokens_ * dim;

            // Small integers (0..31) keep the sums exact; `verify` compares with `!=`.
            std::mt19937                  gen{(unsigned)index};
            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits
            if (g_rank == 0) {
                std::cout << "preparing data ... " << std::flush;
            }
            data[d][rank].resize(max_count);
            for (size_t i = 0; i < max_count; ++i) {
                data[d][rank][i] = T(dist(gen));
            }
            // wait until every rank has written its input (Sync presumably acts as a barrier)
            h_comm->Sync();
            // each rank computes the reference sum for its own 1/n_ranks slice of the elements
            const size_t slice = (max_count + n_ranks - 1) / n_ranks;
            for (int r = 0; r < n_ranks; ++r) {
                for (size_t i = rank * slice; i < (rank + 1) * slice && i < max_count; ++i) {
                    ref_data[d][i] += data[d][r][i];
                }
            }
            h_comm->Sync();
            if (g_rank == 0) {
                std::cout << "done.\n";
            }

            Context ctx{g_rank};

            T* d_data = ctx.malloc(max_count);

            // the reduction buffer is allocated and registered through the communicator
            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count);
            d_comm->Register(d_tmp, sizeof(T) * max_count);

            ctx.copy_n(data[d][rank].data(), max_count, d_data);

            // Copy the first `count` reduced elements back and abort on any exact mismatch.
            [[maybe_unused]] auto verify = [&](auto count) {
                std::vector res(count);
                ctx.copy_n(d_tmp, count, res.data());
                ctx.sync();
                size_t diff = 0;
                for (size_t i = 0; i < count; ++i) {
                    auto& x = res[i];
                    auto& y = ref_data[d][i];
                    diff += x != y;
                    if (diff == 1) {
                        printf("%d: %f vs %f\n", (int)i, (float)x, (float)y);
                    }
                }
                if (diff) {
                    printf("[rank %d] count = %d, diff = %lu\n", g_rank, (int)count, diff);
                    std::this_thread::sleep_for(std::chrono::seconds(1));
                    std::abort();
                }
            };

            std::vector deltas;
            for (const auto& n : tokens_) {
                const size_t count = (size_t)n * dim;
                auto&        delta = deltas.emplace_back();
                h_comm->Sync();
                for (int i = 0; i < warmup_ + iters_; ++i) {
                    // restore the input first; only the collective itself is timed
                    ctx.copy_n(d_data, count, d_tmp);
                    auto ms = ctx.exec([&](auto stream) {  //
                        d_comm->AllReduceSum(d_tmp, d_tmp, count, dtype, group, stream);
                    });
                    if (i >= warmup_) {
                        delta += ms;
                    }
                    // verify(count);
                }
                // the input is restored before every run, so the last result is a single reduction
                verify(count);
            }

            if (g_rank == 0) {
                SummaryHeader("allreduce", dim, n_ranks);
                for (size_t i = 0; i < tokens_.size(); ++i) {
                    const float  avg   = deltas[i] / iters_;
                    const size_t count = tokens_[i] * dim;
                    // algbw in GB/s (avg is in ms); busbw applies the 2 * (n - 1) / n all-reduce factor
                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;
                    const float  busbw = algbw * (2 * (n_ranks - 1)) / n_ranks;
                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);
                }
            }

            d_comm->Deregister(d_tmp);
            d_comm->Free(d_tmp);
        };

        // one worker thread per device communicator
        std::vector threads;
        for (size_t i = 0; i < d_comm_.size(); ++i) {
            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Verify and time the fused `AllreduceResidualBiasRMSnorm` on `group`.
    //
    // The host reference computes, per token:
    //   r' = r + sum_over_ranks(h) + b
    //   h' = r' * w / (sqrt(mean(r'^2)) + eps)
    // The residual output is compared exactly on this rank's token slice. The hidden output is
    // checked by mean absolute error (threshold 0.1) over all tokens.
    // NOTE(review): the reference places eps outside the sqrt; confirm this matches the kernel's
    // convention (the loose tolerance may hide a mismatch).
    template
    void TestAllreduceResidualBiasRMSnorm(size_t dim, int group)
    {
        vector weight(dim);
        vector bias(dim);

        constexpr float eps      = 1e-5;
        constexpr bool  has_bias = true;

        std::cout << "preparing data ... " << std::flush;

        // weight and bias are shared by all ranks (default-seeded generator)
        {
            std::mt19937                  gen{};
            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits
            for (size_t i = 0; i < dim; ++i) {
                weight[i] = T(dist(gen));
            }
            if (has_bias) {
                for (size_t i = 0; i < dim; ++i) {
                    bias[i] = T(dist(gen));
                }
            }
        }

        const auto dtype = data_type_v;

        const int tp_size = d_comm_[0]->n_ranks(group);
        const int dp_size = d_comm_.size() / tp_size;
        // dp    tp     dim
        vector>> src_data(dp_size);
        // dp    dim
        vector> ref_data(dp_size);
        vector> src_res(dp_size);
        vector> ref_res(dp_size);

        for (int i = 0; i < dp_size; ++i) {
            src_data[i].resize(tp_size);
            ref_data[i].resize(max_tokens_ * dim);
            src_res[i].resize(max_tokens_ * dim);
            ref_res[i].resize(max_tokens_ * dim);
        }

        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {
            const int rank    = d_comm->rank(group);
            const int n_ranks = d_comm->n_ranks(group);
            const int g_rank  = d_comm->rank(0);
            // replica index; assumes groups consist of consecutive global ranks -- TODO confirm
            const int d       = g_rank / n_ranks;

            const size_t max_count = max_tokens_ * dim;

            std::mt19937                  gen{(unsigned)index};
            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits

            src_data[d][rank].resize(max_count);
            for (size_t i = 0; i < max_count; ++i) {
                src_data[d][rank][i] = T(dist(gen));
            }
            // wait until every rank has written its input (Sync presumably acts as a barrier)
            h_comm->Sync();
            // Each rank builds the reference for its own slice of tokens: the reduced hidden
            // state, a random input residual, the expected residual and the expected output.
            const size_t slice = (max_tokens_ + n_ranks - 1) / n_ranks;
            for (size_t t = rank * slice; t < (rank + 1) * slice && t < max_tokens_; ++t) {
                for (int r = 0; r < n_ranks; ++r) {
                    for (size_t i = 0; i < dim; ++i) {
                        ref_data[d][t * dim + i] += src_data[d][r][t * dim + i];
                    }
                }
                float sum = 0.f;
                for (size_t i = 0; i < dim; ++i) {
                    const size_t idx = t * dim + i;
                    src_res[d][idx]  = T(dist(gen));
                    ref_res[d][idx]  = src_res[d][idx] + ref_data[d][idx] + bias[i];  // r' <- r + (h + b)
                    sum += (float)ref_res[d][idx] * (float)ref_res[d][idx];
                }
                // `sum` becomes the normalization factor 1 / (rms + eps)
                sum = 1 / (sqrtf(sum / dim) + eps);
                for (size_t i = 0; i < dim; ++i) {
                    const size_t idx = t * dim + i;
                    float        tmp = (float)ref_res[d][idx];
                    ref_data[d][idx] = tmp * sum * (float)weight[i];  // h' <- norm(r) * w
                }
            }
            h_comm->Sync();
            if (g_rank == 0) {
                std::cout << "done.\n";
            }

            Context ctx{g_rank};

            T* d_bias   = ctx.malloc(dim);
            T* d_weight = ctx.malloc(dim);

            T* d_data    = ctx.malloc(max_count);
            T* d_res     = ctx.malloc(max_count);
            T* d_tmp_res = ctx.malloc(max_count);

            // only the hidden-state buffer goes through the communicator's allocator and registration
            T* d_tmp_data = (T*)d_comm->Allocate(sizeof(T) * max_count);
            d_comm->Register(d_tmp_data, sizeof(T) * max_count);

            ctx.copy_n(src_data[d][rank].data(), max_count, d_data);
            ctx.copy_n(src_res[d].data(), max_count, d_res);
            ctx.copy_n(bias.data(), dim, d_bias);
            ctx.copy_n(weight.data(), dim, d_weight);

            // Check the residual exactly on this rank's token slice. Check the hidden output by mean
            // absolute error over all tokens. Abort on failure.
            [[maybe_unused]] auto verify = [&](auto token_num) {
                const size_t count = (size_t)token_num * dim;
                vector    h_data(count);
                vector    h_res(count);
                ctx.copy_n(d_tmp_data, count, h_data.data());
                ctx.copy_n(d_tmp_res, count, h_res.data());
                ctx.sync();
                const size_t slice    = (token_num + n_ranks - 1) / n_ranks * dim;
                const size_t first    = rank * slice;
                const size_t last     = std::min(first + slice, count);
                size_t       res_diff = 0;
                for (size_t i = first; i < last; ++i) {
                    auto& x       = h_res[i];
                    auto& y       = ref_res[d][i];
                    int   is_diff = !(x == y);
                    if (!res_diff && is_diff) {
                        printf("[rank %d], %ld: %f vs %f\n", g_rank, i - first, (float)x, (float)y);
                    }
                    res_diff += is_diff;
                }
                float data_diff = 0;
                for (size_t i = 0; i < count; ++i) {
                    float diff = (float)h_data[i] - (float)ref_data[d][i];
                    data_diff += std::abs(diff);
                }
                data_diff /= count;
                if (res_diff || data_diff > 0.1f || std::isnan(data_diff)) {
                    printf("[rank %d] count = %d, res_diff = %lu, data_diff = %f\n",
                           g_rank,
                           (int)token_num,
                           res_diff,
                           data_diff);
                    std::this_thread::sleep_for(std::chrono::seconds(5));
                    std::abort();
                }
                else if (g_rank == 0) {
                    printf("[rank %d] count = %d, data_diff = %f\n", g_rank, (int)token_num, data_diff);
                }
            };

            vector deltas;
            for (const auto& n : tokens_) {
                const size_t count = (size_t)n * dim;
                auto&        delta = deltas.emplace_back();
                h_comm->Sync();
                for (int i = 0; i < warmup_ + iters_; ++i) {
                    // restore hidden state and residual before each run; only the fused op is timed
                    ctx.copy_n(d_data, count, d_tmp_data);
                    ctx.copy_n(d_res, count, d_tmp_res);
                    auto ms = ctx.exec([&](auto stream) {  //
                        d_comm->AllreduceResidualBiasRMSnorm(d_tmp_data,
                                                             d_tmp_res,
                                                             has_bias ? d_bias : nullptr,
                                                             d_weight,
                                                             eps,
                                                             dim,
                                                             n,
                                                             dtype,
                                                             group,
                                                             stream);
                    });
                    if (i >= warmup_) {
                        delta += ms;
                    }
                    // verify(n);
                }
                verify(n);
            }

            d_comm->Deregister(d_tmp_data);
            d_comm->Free(d_tmp_data);

            if (g_rank == 0) {
                SummaryHeader("allreduce | rmsnorm", dim, n_ranks);
                for (size_t i = 0; i < tokens_.size(); ++i) {
                    const float  avg   = deltas[i] / iters_;
                    const size_t count = tokens_[i] * dim;
                    // algbw in GB/s (avg is in ms); busbw applies the 2 * (n - 1) / n all-reduce factor
                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;
                    const float  busbw = algbw * (2 * (n_ranks - 1)) / n_ranks;
                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);
                }
            }
        };

        // one worker thread per device communicator
        std::vector threads;
        for (size_t i = 0; i < d_comm_.size(); ++i) {
            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Verify and time in-place `AllGather` on `group` for every token count in `tokens_`.
    //
    // Each rank contributes `n * dim` elements, placed at offset `rank * count` of a buffer that
    // holds all ranks' contributions. After the timed runs, every slot is compared exactly with
    // the producing rank's input. Rank 0 prints a summary table.
    template
    void TestAllGather(size_t dim, int group)
    {
        const auto dtype = data_type_v;

        const int tp_size = d_comm_[0]->n_ranks(group);
        const int dp_size = d_comm_.size() / tp_size;

        vector>> data(dp_size);

        for (int i = 0; i < dp_size; ++i) {
            data[i].resize(tp_size);
        }

        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {
            const int rank    = d_comm->rank(group);
            const int n_ranks = d_comm->n_ranks(group);
            const int g_rank  = d_comm->rank(0);
            // replica index; assumes groups consist of consecutive global ranks -- TODO confirm
            const int d       = g_rank / n_ranks;

            const size_t max_count = max_tokens_ * dim;

            if (h_comm->rank() == 0) {
                std::cout << "preparing data ... " << std::flush;
            }
            std::mt19937                  gen{(unsigned)index};
            std::uniform_int_distribution dist{0, 100};
            data[d][rank].resize(max_count);
            for (size_t i = 0; i < max_count; ++i) {
                data[d][rank][i] = T(dist(gen));
            }
            h_comm->Sync();
            if (h_comm->rank() == 0) {
                std::cout << "done.\n";
            }

            Context ctx{g_rank};

            T* d_data = ctx.malloc(max_count);

            // output buffer holds the contributions of all n_ranks ranks
            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count * n_ranks);
            d_comm->Register(d_tmp, sizeof(T) * max_count * n_ranks);

            ctx.copy_n(data[d][rank].data(), max_count, d_data);

            // Compare every rank's slot of the gathered buffer with that rank's input; abort on mismatch.
            [[maybe_unused]] auto verify = [&](int64_t count) {
                auto           total_count = count * n_ranks;
                std::vector res(total_count);
                ctx.copy_n(d_tmp, total_count, res.data());
                ctx.sync();
                size_t diff = 0;
                for (int r = 0; r < n_ranks; ++r) {
                    for (auto i = 0; i < count; ++i) {
                        auto& x = res[r * count + i];
                        auto& y = data[d][r][i];
                        diff += (x != y);
                        if (diff == 1) {
                            printf("%d: %f vs %f\n", (int)i, (float)x, (float)y);
                        }
                    }
                }
                if (diff) {
                    printf("[rank %d] count = %d, diff = %lu\n", g_rank, (int)count, diff);
                    std::this_thread::sleep_for(std::chrono::seconds(1));
                    std::abort();
                }
            };

            std::vector deltas;
            for (const auto& n : tokens_) {
                const size_t count = (size_t)n * dim;  // dim = hidden_dim / tp
                auto&        delta = deltas.emplace_back();
                h_comm->Sync();
                for (int i = 0; i < warmup_ + iters_; ++i) {
                    // clear the output so results of an earlier run cannot mask a failure, then
                    // place this rank's input in its own slot (in-place all-gather)
                    check_cuda_error(cudaMemsetAsync(d_tmp, 0, sizeof(T) * count * n_ranks, ctx.stream));
                    ctx.copy_n(d_data, count, d_tmp + rank * count);
                    auto ms = ctx.exec([&](auto stream) {  //
                        // the AllGather2D path is disabled by `&& 0`
                        if (d_comm->Query(kHasAllGather2D) && 0) {
                            d_comm->AllGather2D(
                                d_tmp + rank * count, d_tmp, dim, count, dim, n, dtype, {1, 1}, group, stream);
                        }
                        else {
                            d_comm->AllGather(d_tmp + rank * count, d_tmp, count, dtype, group, stream);
                        }
                    });
                    if (i >= warmup_) {
                        delta += ms;
                    }
                    // verify(count);
                }
                verify(count);
            }

            if (g_rank == 0) {
                SummaryHeader("allgather", dim, n_ranks);
                for (size_t i = 0; i < tokens_.size(); ++i) {
                    const float  avg   = deltas[i] / iters_;
                    const size_t count = n_ranks * tokens_[i] * dim;
                    // algbw in GB/s over the gathered size (avg is in ms); busbw applies the (n - 1) / n factor
                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;
                    const float  busbw = algbw * (n_ranks - 1) / n_ranks;

                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);
                }
            }

            d_comm->Deregister(d_tmp);
            d_comm->Free(d_tmp);
        };

        // one worker thread per device communicator
        std::vector threads;
        for (size_t i = 0; i < d_comm_.size(); ++i) {
            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Benchmark and verify Broadcast inside communicator group `group`.
    //
    // For every token count n in tokens_, the group rank `root` uploads n * dim elements of its group's
    // host reference data[d] and broadcasts them to the rest of the group. After the timed iterations,
    // each rank copies the result back and compares it with data[d]; any mismatch aborts. Global rank 0
    // prints a bandwidth summary.
    //
    // d = g_rank / n_ranks identifies the group, which assumes groups are formed from consecutive
    // global ranks — TODO confirm against how `group` is split.
    template
    void TestBroadcast(size_t dim, int group)
    {
        const auto dtype = data_type_v;

        const int tp_size = d_comm_[0]->n_ranks(group);
        const int dp_size = d_comm_.size() / tp_size;

        constexpr int root = 0;

        // One host reference buffer per group; written only by that group's root.
        vector> data(dp_size);

        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {
            const int rank    = d_comm->rank(group);
            const int n_ranks = d_comm->n_ranks(group);
            const int g_rank  = d_comm->rank(0);
            const int d       = g_rank / n_ranks;

            const size_t max_count = max_tokens_ * dim;

            // The reference for group d must be produced by the same rank that later uploads it
            // (`rank == root` below). Keying this on the host-communicator rank would leave data[d]
            // empty for every group whose root is not host rank 0 when h_comm spans several groups.
            if (rank == root) {
                std::cout << "preparing data ... " << std::flush;
                std::mt19937                  gen{(unsigned)index};
                std::uniform_int_distribution dist{0, 100};
                data[d].resize(max_count);
                for (size_t i = 0; i < max_count; ++i) {
                    data[d][i] = T(dist(gen));
                }
                std::cout << "done.\n";
            }

            // Make data[d] visible to every rank that reads it in `verify`.
            h_comm->Sync();

            Context ctx{g_rank};

            T* d_data = ctx.malloc(max_count);

            // Communication buffer, registered with the device communicator.
            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count);
            d_comm->Register(d_tmp, sizeof(T) * max_count);

            if (rank == root) {
                ctx.copy_n(data[d].data(), max_count, d_data);
            }

            // Copy the first `count` elements of d_tmp to the host and compare them with data[d].
            [[maybe_unused]] auto verify = [&](int64_t count) {
                auto           total_count = count;
                std::vector res(total_count);
                ctx.copy_n(d_tmp, total_count, res.data());
                ctx.sync();
                size_t diff = 0;
                for (auto i = 0; i < count; ++i) {
                    auto& x = res[i];
                    auto& y = data[d][i];
                    diff += (x != y);
                    if (diff == 1) {
                        printf("%d: %f vs %f\n", (int)i, (float)x, (float)y);
                    }
                }
                if (diff) {
                    // %zu: `diff` is size_t (%lu is wrong where long is 32-bit, e.g. 64-bit Windows).
                    printf("[rank %d] count = %d, diff = %zu\n", g_rank, (int)count, diff);
                    std::this_thread::sleep_for(std::chrono::seconds(1));
                    std::abort();
                }
            };

            std::vector deltas;
            for (const auto& n : tokens_) {
                const size_t count = (size_t)n * dim;  // dim = hidden_dim / tp
                auto&        delta = deltas.emplace_back();
                h_comm->Sync();
                for (int i = 0; i < warmup_ + iters_; ++i) {
                    // Non-root ranks start from zeros so a missed broadcast shows up in `verify`.
                    check_cuda_error(cudaMemsetAsync(d_tmp, 0, sizeof(T) * count, ctx.stream));
                    if (rank == root) {
                        ctx.copy_n(d_data, count, d_tmp);
                    }
                    auto ms = ctx.exec([&](auto stream) {  //
                        d_comm->Broadcast(d_tmp, d_tmp, count, dtype, root, group, stream);
                    });
                    if (i >= warmup_) {
                        delta += ms;
                    }
                    // Uncomment to check every iteration instead of only the last one:
                    // verify(count);
                }
                verify(count);
            }

            if (g_rank == 0) {
                SummaryHeader("broadcast", dim, n_ranks);
                for (size_t i = 0; i < tokens_.size(); ++i) {
                    const float  avg   = deltas[i] / iters_;
                    const size_t count = tokens_[i] * dim;
                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;
                    const float  busbw = algbw;
                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);
                }
            }

            d_comm->Deregister(d_tmp);
            d_comm->Free(d_tmp);
        };

        std::vector threads;
        for (size_t i = 0; i < d_comm_.size(); ++i) {
            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Benchmark and verify the fused AllreduceResidualBiasRMSnormEx collective.
    //
    // Host reference, computed below: the per-rank inputs src_data[r] are summed over the tp_size_0
    // ranks of group0. Then, for every token,
    //     r' = r + h + bias,   h' = r' / (sqrt(mean(r'^2)) + eps) * weight.
    // Each rank checks its slice of the residual (d_tmp_res) and the region of d_tmp_data that belongs
    // to its group1 partition (dp_rank_1). The summary label is "rs | rmsnorm | ag"; the kernel itself
    // is not visible here.
    //
    // dp ranks are derived as g_rank / tp_size, which assumes groups are built from consecutive global
    // ranks — TODO confirm against how group0/group1 are split.
    template
    void TestAllreduceResidualBiasRMSnormEx(size_t dim, int group0, int group1)
    {
        const int tp_size_0 = d_comm_.at(0)->n_ranks(group0);
        const int tp_size_1 = d_comm_.at(0)->n_ranks(group1);
        const int dp_size_0 = d_comm_.size() / tp_size_0;
        const int dp_size_1 = d_comm_.size() / tp_size_1;

        // Each rank's residual slice is 1/inner_tp of its local tokens (see `verify`).
        const int inner_tp = std::gcd(tp_size_0, tp_size_1);

        const auto dtype = data_type_v;

        std::mt19937                  gen{};
        std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits

        TM_LOG_INFO("dp_size_0 %d, tp_size_0 %d", dp_size_0, tp_size_0);
        TM_LOG_INFO("dp_size_1 %d, tp_size_1 %d", dp_size_1, tp_size_1);
        TM_LOG_INFO("inner_tp %d", inner_tp);

        // Token counts per group0 data-parallel slice: ceil(tokens_ / dp_size_0), sorted and unique.
        vector tokens = tokens_;
        for (auto& x : tokens) {
            x = (x + dp_size_0 - 1) / dp_size_0;
        }
        std::sort(tokens.begin(), tokens.end());
        tokens.erase(std::unique(tokens.begin(), tokens.end()), tokens.end());
        const size_t max_tokens = tokens.back();

        // Host reference buffers covering all dp_size_0 slices at the largest token count.
        vector ref_data(dp_size_0 * max_tokens * dim);
        vector src_res(ref_data.size());
        vector ref_res(ref_data.size());

        vector weight(dim);
        vector bias(dim);

        constexpr float eps      = 1e-5;
        constexpr bool  has_bias = true;

        std::cout << "preparing data ... " << std::flush;

        for (size_t i = 0; i < dim; ++i) {
            weight[i] = T(dist(gen));
        }

        if (has_bias) {
            for (size_t i = 0; i < dim; ++i) {
                bias[i] = T(dist(gen));
            }
        }

        // src_data[r]: input contributed by rank r of group0.
        std::vector> src_data(tp_size_0);
        for (int r = 0; r < tp_size_0; ++r) {
            src_data[r].resize(ref_data.size());
            for (size_t i = 0; i < ref_data.size(); ++i) {
                src_data[r][i] = T(dist(gen));
            }
        }

        for (size_t i = 0; i < src_res.size(); ++i) {
            src_res[i] = T(dist(gen));
        }

        // ref_data <- sum of all group0 contributions.
        for (int r = 0; r < tp_size_0; ++r) {
            for (size_t i = 0; i < ref_data.size(); ++i) {
                ref_data[i] += src_data[r][i];
            }
        }

        for (size_t i = 0; i < dp_size_0 * max_tokens; ++i) {
            float sum = 0.f;
            for (size_t d = 0; d < dim; ++d) {
                size_t idx   = i * dim + d;
                ref_res[idx] = src_res[idx] + ref_data[idx] + bias[d];  // r' <- r + (h + b)
                sum += (float)ref_res[idx] * (float)ref_res[idx];
            }
            // NOTE(review): eps is added after the square root here. The common RMSNorm form is
            // rsqrt(mean + eps). The averaged data_diff tolerance in `verify` likely hides the
            // difference — confirm which form the kernel uses before tightening that check.
            sum = 1 / (sqrtf(sum / dim) + eps);
            for (size_t d = 0; d < dim; ++d) {
                size_t idx    = i * dim + d;
                ref_data[idx] = (float)ref_res[idx] * sum * (float)weight[d];  // h' <- norm(r) * w
            }
        }

        std::cout << "done" << std::endl;

        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {
            const int g_rank    = d_comm->rank(0);
            const int g_n_ranks = d_comm->n_ranks(0);
            const int dp_rank_0 = g_rank / tp_size_0;
            const int dp_rank_1 = g_rank / tp_size_1;
            const int tp_rank_0 = d_comm->rank(group0);
            const int tp_rank_1 = d_comm->rank(group1);
            const int local_id  = g_rank / inner_tp;  // which local partition this rank belongs to

            const size_t max_count = max_tokens * dim;

            Context ctx{g_rank};

            T* d_bias    = ctx.malloc(dim);
            T* d_weight  = ctx.malloc(dim);
            T* d_data    = ctx.malloc(max_count);
            T* d_res     = ctx.malloc(max_count);
            T* d_tmp_res = ctx.malloc(max_count);

            // Communication buffer, registered with the device communicator; holds dp_size_0 slices.
            T* d_tmp_data = (T*)d_comm->Allocate(sizeof(T) * dp_size_0 * max_count);
            d_comm->Register(d_tmp_data, sizeof(T) * dp_size_0 * max_count);

            ctx.copy_n(bias.data(), dim, d_bias);
            ctx.copy_n(weight.data(), dim, d_weight);

            // Compare this rank's residual slice (exact) and its group1 data region (mean abs error)
            // with the host reference; abort on mismatch.
            [[maybe_unused]] auto verify = [&](auto n) {
                const size_t dst_tokens = n / dp_size_1 * dp_size_0;
                const size_t dst_count  = dst_tokens * dim;
                vector    h_data(dst_count);
                ctx.copy_n(d_tmp_data + dp_rank_1 * dst_count, dst_count, h_data.data());
                const size_t local_tokens = (size_t)n / dp_size_1;
                const size_t local_count  = local_tokens * dim;
                const size_t slice        = (local_tokens + inner_tp - 1) / inner_tp * dim;
                const size_t first        = std::min(local_count, g_rank % inner_tp * slice);
                const size_t last         = std::min(local_count, first + slice);
                vector    h_res(last - first);
                ctx.copy_n(d_tmp_res + first, h_res.size(), h_res.data());
                ctx.sync();
                size_t res_diff = 0;
                for (size_t i = first; i < last; ++i) {
                    auto& val  = h_res[i - first];
                    auto& ref  = ref_res[local_id * local_count + i];
                    int   diff = !(val == ref);
                    if (res_diff < 5 && diff) {
                        // %zu: `i - first` is size_t (%ld is wrong where long is 32-bit).
                        printf("[rank %d], %zu: %f vs %f\n", g_rank, i - first, (float)val, (float)ref);
                    }
                    res_diff += diff;
                }
                float data_diff = 0;
                for (size_t i = 0; i < dst_count; ++i) {
                    float diff = (float)h_data[i] - (float)ref_data[dp_rank_1 * dst_count + i];
                    data_diff += std::abs(diff);
                }
                data_diff /= dst_count;
                if (res_diff || data_diff > 0.1f || std::isnan(data_diff)) {
                    printf(
                        "[rank %d] count = %d, res_diff = %zu, data_diff = %f\n", g_rank, (int)n, res_diff, data_diff);
                    std::this_thread::sleep_for(std::chrono::seconds(5));
                    std::abort();
                }
                else if (tp_rank_1 == 0) {
                    printf("[rank %d] count = %d, data_diff = %f\n", g_rank, (int)n, data_diff);
                }
            };

            std::vector> stats;
            for (const auto& n : tokens) {
                // Token counts that do not split evenly across group1's dp slices are not tested.
                if (n % dp_size_1) {
                    if (g_rank == 0) {
                        TM_LOG_INFO("Skipped %d", n);
                    }
                    continue;
                }
                const size_t count       = (size_t)n * dim;
                const size_t local_count = count / dp_size_1;
                std::vector  local_token_nums(dp_size_0 * dp_size_1, n / dp_size_1);
                ctx.copy_n(src_data[tp_rank_0].data() + dp_rank_0 * count, count, d_data);
                ctx.copy_n(src_res.data() + local_id * local_count, local_count, d_res);
                auto& [_, delta] = stats.emplace_back(n * dp_size_0, 0.f);
                h_comm->Sync();
                for (int i = 0; i < warmup_ + iters_; ++i) {
                    // The op works in place, so restore both inputs before every iteration.
                    ctx.copy_n(d_data, count, d_tmp_data + dp_rank_0 * count);
                    ctx.copy_n(d_res, local_count, d_tmp_res);
                    auto ms = ctx.exec([&](auto stream) {  //
                        d_comm->AllreduceResidualBiasRMSnormEx(d_tmp_data,
                                                               d_tmp_res,
                                                               has_bias ? d_bias : nullptr,
                                                               d_weight,
                                                               eps,
                                                               dim,
                                                               dtype,
                                                               group0,
                                                               group1,
                                                               local_token_nums.data(),
                                                               stream);
                    });
                    if (i >= warmup_) {
                        delta += ms;
                    }
                    // Uncomment to check every iteration instead of only the last one:
                    // verify(n);
                }
                verify(n);
            }

            d_comm->Deregister(d_tmp_data);
            d_comm->Free(d_tmp_data);

            if (g_rank == 0) {
                SummaryHeader("rs | rmsnorm | ag", dim, g_n_ranks);
                for (const auto& [num, ms] : stats) {
                    const float  avg    = ms / iters_;
                    const size_t count  = num * dim;
                    const float  algbw  = sizeof(T) * count / 1e9f / avg * 1000.f;
                    const float  factor = (tp_size_0 + tp_size_1 - 2) / (float)g_n_ranks;
                    const float  busbw  = algbw * factor;
                    SummaryEntry(num, count, sizeof(T), avg, algbw, busbw);
                }
            }
        };

        std::vector threads;
        for (size_t i = 0; i < d_comm_.size(); ++i) {
            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));
        }
        for (auto& t : threads) {
            t.join();
        }
    }

    // Print the title line for one benchmark, followed by the two header rows (column names, then
    // units) of the table that SummaryEntry fills in.
    void SummaryHeader(const char* name, int dim, int world_size)
    {
        static constexpr const char* kColumnFmt = "%15s%15s%15s%15s%15s%15s\n";

        printf("[%s] dim %d tp %d warmup %d iters %d\n", name, dim, world_size, warmup_, iters_);
        printf(kColumnFmt, "num", "count", "size", "time", "algbw", "busbw");
        printf(kColumnFmt, "(tokens)", "(elements)", "(MB)", "(us)", "(GB/s)", "(GB/s)");
    }

    // Print one row of the table started by SummaryHeader.
    //   num          token count for this row
    //   count        number of elements; elem_size is bytes per element (size printed in MiB)
    //   time         average time per iteration; callers pass milliseconds, printed here as us
    //   algbw/busbw  bandwidths in GB/s, printed as given
    void SummaryEntry(int num, size_t count, size_t elem_size, float time, float algbw, float busbw)
    {
        const float mb_size = count * elem_size / (1024.f * 1024);
        // %zu for size_t: %ld is undefined behavior where long is 32-bit (e.g. 64-bit Windows).
        printf("%15d%15zu%15.2f%15.3f%15.3f%15.3f\n", num, count, mb_size, time * 1e3f, algbw, busbw);
    }
};

int main(int argc, char* argv[])
{
    TestComm test;

    // Arguments are forwarded positionally to TestComm::Run, which is defined outside this view.
    // NOTE(review): the roles of 2048 / 128000 / -1 / 10 / 10000 are fixed by Run's signature —
    // confirm there before changing them. The trailing list is presumably the token counts swept by
    // the tests (it matches the "(tokens)" column of the summary tables).
    test.Run(2048,
             128000,
             -1,
             10,
             10000,
             {1, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128,  //
              192, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 16384});

    return 0;
}


================================================
FILE: src/turbomind/comm/test_host_comm.cc
================================================

#include 
#include 
#include 

#include "src/turbomind/comm/host_comm.h"

using namespace turbomind;
using namespace turbomind::comm;

int main(int argc, char* argv[])
{
    // Smoke test for host communicators: N threads form a world, split it into 4 groups of N/4
    // consecutive ranks (color = r / (N / 4)), run AllReduce and AllGather inside their group, and
    // print the results followed by the elapsed time, one rank at a time.
    const int                    N        = 32;
    std::unique_ptr group_id = CreateHostGroupId({});
    group_id->Initialize();

    std::vector threads;
    for (int r = 0; r < N; ++r) {
        threads.emplace_back([&, r] {
            HostComm world = group_id->CreateCommunicator(N, r);
            HostComm group = world->Split(r / (N / 4), 0);

            // Run `emit` on each rank in turn, with a world-wide Sync before every turn so that
            // output lines appear in rank order.
            auto in_rank_order = [&](auto&& emit) {
                for (int turn = 0; turn < N; ++turn) {
                    world->Sync();
                    if (turn == r) {
                        emit();
                    }
                }
            };

            const auto start = std::chrono::steady_clock::now();

            volatile int reduced;
            volatile int gathered_sum;

            reduced             = AllReduce(group, r, RedOp::kSum);
            const auto gathered = AllGather(group, r);
            gathered_sum        = std::accumulate(gathered.begin(), gathered.end(), 0);
            in_rank_order([&] { std::cout << reduced << " " << gathered_sum << std::endl; });

            const auto stop = std::chrono::steady_clock::now();

            in_rank_order([&] { std::cout << std::chrono::duration(stop - start).count() << std::endl; });
        });
    }

    std::cout << "main thread waiting.\n";

    for (auto& worker : threads) {
        worker.join();
    }

    return 0;
}


================================================
FILE: src/turbomind/comm/thread_comm.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/comm/barrier.h"
#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/serdes.h"
namespace turbomind::comm {

// Host communicator whose ranks are threads of one process.
//
// Each ordered pair of ranks (from, to) owns a single-slot mailbox `channel(from, to)` holding a
// void*. An empty slot holds nullptr. A sender publishes a non-null value with a release CAS,
// spinning while the slot is still occupied. The receiver spins on an acquire load, consumes the
// value, and resets the slot to nullptr with a release store. Payload pointers refer to the sender's
// own buffers, so a sender must not modify or free them until it has observed (acquire) that every
// receiver emptied its slot.
struct ThreadCommImpl: public HostCommImpl {

    // Shared by all ranks of one communicator: n * n mailboxes plus a blocking barrier.
    class State {
    public:
        explicit State(int n): n_{n}, channels_(n * n), barrier_{n} {}

        // Mailbox for messages from rank `from` to rank `to`.
        std::atomic& channel(int from, int to)
        {
            return channels_[from * n_ + to];
        }

        // Blocking barrier across all n ranks.
        void sync()
        {
            barrier_.arrive_and_wait();
        }

    private:
        int                            n_;
        std::deque> channels_;
        Barrier                        barrier_;
    };

    std::shared_ptr state_;

    int n_ranks_;
    int rank_;

    ThreadCommImpl(int n_ranks, std::shared_ptr state, int rank):
        state_{std::move(state)}, n_ranks_{n_ranks}, rank_{rank}
    {
    }

    int rank() const override
    {
        return rank_;
    }

    int n_ranks() const override
    {
        return n_ranks_;
    }

    bool is_same_process() const override
    {
        return true;
    }

    std::atomic& channel(int from, int to)
    {
        return state_->channel(from, to);
    }

    // Collective: partition the ranks by `color` (must be >= 0). Within a color, new ranks are ordered
    // by (key, old rank). The new rank 0 of each color creates the shared State and hands it to its
    // peers through an AllGather over this (parent) communicator.
    std::shared_ptr Split(int color, int key) override
    {
        TM_CHECK(color >= 0);

        auto ranks = comm::AllGather(this, std::make_tuple(color, key, rank_));

        // Keep only entries with our color, then order them by (color, key, rank).
        auto same_color = [&](auto x) { return std::get<0>(x) == color; };
        ranks.erase(std::stable_partition(ranks.begin(), ranks.end(), same_color), ranks.end());
        std::stable_sort(ranks.begin(), ranks.end());

        std::shared_ptr state;

        int rank = -1;
        for (int i = 0; i < (int)ranks.size(); ++i) {
            if (std::get<2>(ranks[i]) == rank_) {
                rank = i;
                break;
            }
        }

        TM_CHECK_GE(rank, 0);

        if (rank == 0) {
            state = std::make_shared(ranks.size());
        }

        // states[r] is what parent rank r contributed; it is non-null only for each color's new root.
        auto states = comm::AllGather(this, state);
        if (rank != 0) {
            const int root = std::get<2>(ranks[0]);  // parent rank of this color's new root
            state          = states[root];
        }

        return std::make_shared(ranks.size(), state, rank);
    }

    // Barrier across all ranks. With `blocking`, use the shared Barrier. Otherwise, post a token
    // ((void*)1) to every peer, then spin until a token from every peer has been consumed.
    void Sync(bool blocking) override
    {
        if (n_ranks_ == 1) {
            return;
        }

        if (blocking) {
            state_->sync();
            return;
        }

        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(rank_, r);
                void* expected{};
                while (!c.compare_exchange_weak(expected, (void*)1, std::memory_order_release)) {
                    expected = {};
                }
            }
        }
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c        = channel(r, rank_);
                void* expected = (void*)1;
                while (!c.compare_exchange_weak(expected, nullptr, std::memory_order_acquire)) {
                    expected = (void*)1;
                }
            }
        }
    }

    // Copy `count` elements of `root`'s `data` into `data` on every other rank, via `copy`.
    // `root` is a rank of this communicator (same numbering as rank_); ser/des are unused because
    // all ranks share one address space.
    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override
    {
        TM_CHECK(copy);
        if (n_ranks_ == 1) {
            return;
        }
        if (rank_ == root) {
            for (int r = 0; r < n_ranks_; ++r) {
                if (r != rank_) {
                    auto& c = channel(rank_, r);
                    void* expected{};
                    while (!c.compare_exchange_weak(expected, data, std::memory_order_release)) {
                        expected = {};
                    }
                }
            }
            // Acquire pairs with the receivers' release store: their reads of `data` happen-before
            // the caller is allowed to modify it again.
            for (int r = 0; r < n_ranks_; ++r) {
                if (r != rank_) {
                    auto& c = channel(rank_, r);
                    while (c.load(std::memory_order_acquire)) {}
                }
            }
        }
        else {
            auto& c = channel(root, rank_);
            void* incoming{};
            while (!(incoming = c.load(std::memory_order_acquire))) {}
            copy(incoming, count, data, 0);
            // Release: the copy above must be complete before the root observes the empty slot.
            c.store(nullptr, std::memory_order_release);
        }
    }

    // Every rank publishes its `data`; each peer copies rank r's `count` elements using element offset
    // r * count. NOTE(review): `copy` presumably applies the offset to both source and destination,
    // like `reduce` below — confirm against copy_fn's definition.
    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override
    {
        TM_CHECK(copy);
        if (n_ranks_ == 1) {
            return;
        }
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(rank_, r);
                void* expected{};
                while (!c.compare_exchange_weak(expected, data, std::memory_order_release)) {
                    expected = {};
                }
            }
        }
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(r, rank_);
                void* incoming{};
                while (!(incoming = c.load(std::memory_order_acquire))) {}
                copy(incoming, count, data, r * count);
                // Release: our reads of the peer's buffer complete before it sees the empty slot.
                c.store(nullptr, std::memory_order_release);
            }
        }
        // Acquire: peers' reads of our `data` happen-before we return and the caller reuses it.
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(rank_, r);
                while (c.load(std::memory_order_acquire)) {}
            }
        }
    }

    // Element-wise dst[offset + i] = op(dst[offset + i], src[offset + i]) for i in [0, n).
    template
    static void reduce(void* src, int n, void* dst, int offset)
    {
        for (int i = 0; i < n; ++i) {
            auto& s = *((T*)src + offset + i);
            auto& a = *((T*)dst + offset + i);
            if constexpr (op == RedOp::kSum) {
                a += s;
            }
            else if constexpr (op == RedOp::kMin) {
                a = std::min(a, s);
            }
            else if constexpr (op == RedOp::kMax) {
                a = std::max(a, s);
            }
            else {
                static_assert(sizeof(T) != sizeof(T), "not implemented");
            }
        }
    }

    // Select the reducer for (dtype, red_op). Only 32/64-bit signed and unsigned integers are
    // supported; anything else throws std::runtime_error.
    static reduce_fn get_reduce(DataType dtype, RedOp red_op)
    {
        auto dispatch_op = [&](auto t) -> reduce_fn {
            using T = decltype(t);
            switch (red_op) {
                case RedOp::kSum:
                    return reduce;
                case RedOp::kMax:
                    return reduce;
                case RedOp::kMin:
                    return reduce;
                default:
                    return {};
            }
        };
        auto dispatch = [&]() -> reduce_fn {
            switch (dtype) {
                case kInt32:
                    return dispatch_op(int32_t{});
                case kInt64:
                    return dispatch_op(int64_t{});
                case kUint32:
                    return dispatch_op(uint32_t{});
                case kUint64:
                    return dispatch_op(uint64_t{});
                default:
                    return {};
            }
        };
        if (auto fn = dispatch()) {
            return fn;
        }
        throw std::runtime_error("not implemented");
    }

    // In-place all-reduce of `count` elements. Each rank publishes a private snapshot `tmp` (because
    // `data` is being reduced in place) and folds every peer's snapshot into `data`. The reducer is
    // resolved first, so an unsupported (dtype, red_op) throws even with a single rank.
    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override
    {
        const auto reduce    = get_reduce(dtype, red_op);
        const auto elem_size = byte_size(dtype);
        if (n_ranks_ == 1) {
            return;
        }
        std::unique_ptr tmp((char*)::operator new[](elem_size* count));
        std::copy_n((char*)data, elem_size * count, tmp.get());
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(rank_, r);
                void* expected{};
                while (!c.compare_exchange_weak(expected, (void*)tmp.get(), std::memory_order_release)) {
                    expected = {};
                }
            }
        }
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(r, rank_);
                void* incoming{};
                while (!(incoming = c.load(std::memory_order_acquire))) {}
                reduce(incoming, count, data, 0);
                // Release: our reads of the peer's snapshot complete before it may free it.
                c.store(nullptr, std::memory_order_release);
            }
        }
        // Acquire: every peer's reads of `tmp` happen-before `tmp` is freed at scope exit.
        for (int r = 0; r < n_ranks_; ++r) {
            if (r != rank_) {
                auto& c = channel(rank_, r);
                while (c.load(std::memory_order_acquire)) {}
            }
        }
    }
};

// Host group id for communicators whose ranks are threads of the same process.
// All ranks must share one `Internal` record, either by using the same ThreadGroupId object or by
// Export-ing its address and Import-ing it into another ThreadGroupId in the same process.
class ThreadGroupId: public HostGroupId {
public:
    // Create the `Internal` record that all ranks of this group will share.
    void Initialize() override
    {
        internal_ = std::make_shared();
    }

    // Write this object's address as raw bytes. The result is only meaningful inside the same
    // process, and only while this object is alive.
    void Export(std::ostream& os) override
    {
        TM_CHECK((bool)internal_);  // `Initialize` must come before `Export`

        const void* ptr = this;
        os.write((const char*)&ptr, sizeof(ptr));
    }

    // Read an address written by Export and share that ThreadGroupId's `Internal` record.
    // NOTE(review): the exporting object must still be alive; nothing here can check that.
    void Import(std::istream& is) override
    {
        void* ptr{};
        is.read((char*)&ptr, sizeof(ptr));
        internal_ = reinterpret_cast(ptr)->internal_;

        TM_CHECK((bool)internal_);
    }

    // Create the communicator for `rank`. The first caller creates the shared State sized by its own
    // `n_ranks` (std::call_once); later callers reuse it. `node_rank` is not used here.
    // NOTE(review): `n_ranks` is not checked for agreement across callers.
    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override
    {
        auto init_shared_state = [&] {  //
            internal_->state = std::make_shared(n_ranks);
        };

        TM_CHECK((bool)internal_);

        // Exactly one rank (the first to get here) initializes the shared state
        std::call_once(internal_->flag, init_shared_state);

        TM_CHECK((bool)internal_->state);

        auto impl = std::make_shared(n_ranks, internal_->state, rank);

        return std::static_pointer_cast(impl);
    }

private:
    // Shared by every ThreadGroupId that was Initialize-d or Import-ed from the same source.
    struct Internal {
        std::once_flag                         flag;
        std::shared_ptr state;
    };

private:
    std::shared_ptr internal_;
};

// Factory for an in-process (thread) host group id. The returned id must be Initialize()-d or
// Import()-ed before CreateCommunicator is called; CreateCommunicator checks this with TM_CHECK.
std::unique_ptr CreateThreadGroupId()
{
    return std::make_unique();
}

// Serialization hooks for a shared_ptr type (template arguments not visible here). They trap if ever
// invoked. NOTE(review): presumably they exist only so that `comm::AllGather(this, state)` in
// ThreadCommImpl::Split can be instantiated. The thread communicator exchanges values through
// `copy` and ignores ser/des, so these are expected never to run.
template
void save(Archive& ar, const std::shared_ptr& p)
{
    TM_CHECK(false) << "should never be called";
}

template
void load(Archive& ar, std::shared_ptr& p)
{
    TM_CHECK(false) << "should never be called";
}

}  // namespace turbomind::comm


================================================
FILE: src/turbomind/core/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)

# Static `core` library built from the sources listed below (host .cc files plus one .cu file).
add_library(core STATIC
        check.cc
        allocator.cc
        stream.cc
        context.cc
        buffer.cc
        layout.cc
        tensor.cc
        tensor.cu
        module.cc
        copy.cc)

# NOTE(review): CUDA::cudart / CUDA::cuda_driver are imported targets, presumably provided by
# find_package(CUDAToolkit) in a parent CMakeLists — confirm.
target_link_libraries(core PUBLIC cuda_utils logger CUDA::cudart CUDA::cuda_driver)

# Compile as position-independent code (needed if this static library ends up in a shared object),
# and perform CUDA device-symbol resolution (device link) for this target.
set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET core PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

# Ask ptxas for verbose resource-usage output (registers / shared memory per kernel). The generator
# expression presumably restricts this to CUDA compilation.
target_compile_options(core PRIVATE $<$:-Xptxas=-v>)

# Catch2 unit tests for the core library, built only when BUILD_TEST is enabled.
if (BUILD_TEST)
    add_executable(test_core test_core.cc)
    target_link_libraries(test_core PRIVATE core logger Catch2::Catch2WithMain)
endif ()


================================================
FILE: src/turbomind/core/allocator.cc
================================================

#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"

namespace turbomind::core {

// Out-of-line defaulted destructor of the allocator interface.
AllocatorImpl::~AllocatorImpl() = default;

// Default for allocators that are not stream-ordered: a default-constructed (empty) Stream.
// Stream-ordered implementations such as CudaMemPoolAllocator override this.
Stream AllocatorImpl::stream() const noexcept
{
    return Stream{};
}

// Stream-ordered device allocator backed by a CUDA memory pool on the device that is current at
// construction time. It either borrows the device's default pool, or owns a dedicated pool whose
// release threshold is set to the maximum, so freed memory stays reserved in the pool instead of
// being returned to the driver.
class CudaMemPoolAllocator: public AllocatorImpl {
public:
    CudaMemPoolAllocator(Stream stream, bool use_default_pool):
        pool_{}, stream_{stream}, device_{kDEVICE}, use_default_pool_{use_default_pool}
    {
        check_cuda_error(cudaGetDevice(&device_.id));

        if (use_default_pool_) {
            // Borrowed pool: owned by the device, never destroyed by this object.
            check_cuda_error(cudaDeviceGetDefaultMemPool(&pool_, device_.id));
            return;
        }

        cudaMemPoolProps pool_props{};
        pool_props.allocType     = cudaMemAllocationTypePinned;
        pool_props.handleTypes   = cudaMemHandleTypeNone;
        pool_props.location.type = cudaMemLocationTypeDevice;
        pool_props.location.id   = device_.id;
        check_cuda_error(cudaMemPoolCreate(&pool_, &pool_props));

        cuuint64_t release_threshold = (cuuint64_t)-1;
        check_cuda_error(cudaMemPoolSetAttribute(pool_, cudaMemPoolAttrReleaseThreshold, &release_threshold));
    }

    ~CudaMemPoolAllocator() override
    {
        // Only a pool created in the constructor is destroyed; the default pool is left alone.
        const bool owns_pool = !use_default_pool_;
        if (owns_pool) {
            check_cuda_error(cudaMemPoolDestroy(pool_));
        }
        pool_ = {};
    }

    // Stream-ordered allocation from the pool on stream_.
    void* allocate(ssize_t size) override
    {
        void* p = nullptr;
        check_cuda_error(cudaMallocFromPoolAsync(&p, size, pool_, stream_.handle()));
        return p;
    }

    // Stream-ordered release on stream_; the size argument is not needed.
    void deallocate(void* p, ssize_t /* size */) override
    {
        check_cuda_error(cudaFreeAsync(p, stream_.handle()));
    }

    Device device() const noexcept override
    {
        return device_;
    }

    Stream stream() const noexcept override
    {
        return stream_;
    }

    // Release unused pool memory back to the driver, keeping at least `bytes_to_keep` reserved.
    void trim(size_t bytes_to_keep)
    {
        check_cuda_error(cudaMemPoolTrimTo(pool_, bytes_to_keep));
    }

private:
    cudaMemPool_t pool_;
    Stream        stream_;
    Device        device_;
    bool          use_default_pool_;
};

// Synchronous device allocator: plain cudaMalloc / cudaFree on the current CUDA device.
class CudaAllocator: public AllocatorImpl {
public:
    void* allocate(ssize_t size) override
    {
        void* device_ptr = nullptr;
        check_cuda_error(cudaMalloc(&device_ptr, size));
        return device_ptr;
    }

    // cudaFree does not need the size.
    void deallocate(void* p, ssize_t) override
    {
        check_cuda_error(cudaFree(p));
    }

    Device device() const noexcept override
    {
        return Device{kDEVICE};
    }
};

// Page-locked (pinned) host memory via cudaHostAlloc / cudaFreeHost.
class CudaHostAllocator: public AllocatorImpl {
public:
    void* allocate(ssize_t size) override
    {
        void* host_ptr = nullptr;
        check_cuda_error(cudaHostAlloc(&host_ptr, size, cudaHostAllocDefault));
        return host_ptr;
    }

    // cudaFreeHost does not need the size.
    void deallocate(void* p, ssize_t) override
    {
        check_cuda_error(cudaFreeHost(p));
    }

    Device device() const noexcept override
    {
        return Device{kCPUpinned};
    }
};

// Ordinary host memory from the global allocation functions (::operator new / delete).
class HostAllocator: public AllocatorImpl {
public:
    void* allocate(ssize_t size) override
    {
        void* host_ptr = ::operator new(size);
        return host_ptr;
    }

    // The unsized ::operator delete is used, so the size is ignored.
    void deallocate(void* p, ssize_t) override
    {
        ::operator delete(p);
    }

    Device device() const noexcept override
    {
        return Device{kCPU};
    }
};

// Builds one of the built-in allocator implementations for `type`.
// NOTE(review): the make_shared template arguments are not visible in this copy; judging by the
// classes above, presumably kCPU -> HostAllocator, kDEVICE -> CudaAllocator,
// kCPUpinned -> CudaHostAllocator — confirm against the real source.
Allocator::Allocator(DeviceType type)
{
    impl_ = [&]() -> shared_ptr {
        switch (type) {
            case kCPU:
                return std::make_shared();
            case kDEVICE:
                return std::make_shared();
            case kCPUpinned:
                return std::make_shared();
        }
        // Value outside the enumeration: leave impl_ empty so the check below aborts.
        return {};
    }();
    TM_CHECK_NOTNULL(impl_);
}

// Stream-ordered pool allocator bound to `stream`.
// NOTE(review): the make_shared template argument is not visible here; CudaMemPoolAllocator is
// the only implementation in this file constructible from (Stream, bool) — presumably that one.
Allocator::Allocator(Stream stream, bool use_default_pool)
{
    impl_ = std::make_shared(std::move(stream), use_default_pool);
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/allocator.h
================================================
#pragma once

#include 
#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/common.h"
#include "src/turbomind/core/stream.h"

#include "src/turbomind/kernels/core/math.h"

namespace turbomind {

// Memory space an allocation lives in. The values are written out explicitly; they match
// the implicit 0, 1, 2 numbering.
enum class DeviceType : int
{
    kCPU       = 0,  // host memory
    kCPUpinned = 1,  // page-locked (pinned) host memory
    kDEVICE    = 2,  // CUDA device memory
};

// Unscoped aliases so call sites can write `kCPU` instead of `DeviceType::kCPU`.
inline constexpr DeviceType kCPU       = DeviceType::kCPU;
inline constexpr DeviceType kCPUpinned = DeviceType::kCPUpinned;
inline constexpr DeviceType kDEVICE    = DeviceType::kDEVICE;

// Lower-case name of `device`; the empty string for values outside the enumeration.
constexpr const char* to_string(DeviceType device)
{
    const char* name = "";
    switch (device) {
        case kCPU:
            name = "cpu";
            break;
        case kCPUpinned:
            name = "cpu_pinned";
            break;
        case kDEVICE:
            name = "device";
            break;
    }
    return name;
}

inline std::ostream& operator<<(std::ostream& os, DeviceType device)
{
    const char* name = to_string(device);
    return os << name;
}

}  // namespace turbomind

namespace turbomind::core {

// A memory space (`type`) plus a device ordinal (`id`); `id` is -1 when not specified.
struct Device {
    DeviceType type;
    int        id;

    Device(): Device{kCPU} {}
    // Intentionally non-explicit so a bare DeviceType converts to a Device (e.g. `return kDEVICE;`).
    Device(DeviceType type_): Device{type_, -1} {}
    Device(DeviceType type_, int device_): type{type_}, id{device_} {}

    friend bool operator!=(const Device& a, const Device& b)
    {
        return a.type != b.type || a.id != b.id;
    }
    friend bool operator==(const Device& a, const Device& b)
    {
        return !(a != b);
    }
};

// Abstract memory allocator. Implementations report the memory space they serve through
// device() and, when stream-ordered, the stream they are bound to through stream().
class AllocatorImpl {
public:
    virtual ~AllocatorImpl();

    // Returns storage for `size` bytes.
    virtual void* allocate(ssize_t size) = 0;

    // Releases storage obtained from allocate(); `size` is the size originally requested.
    virtual void deallocate(void* p, ssize_t size) = 0;

    // Returns invalid stream by default
    virtual Stream stream() const noexcept;

    virtual Device device() const noexcept = 0;

    // Optional hook to release cached memory; the base implementation does nothing.
    virtual void trim(size_t /*bytes_to_keep*/) {}
};

// Copyable handle to a shared AllocatorImpl. Copies refer to the same implementation object
// and compare equal; a default-constructed handle is null.
class Allocator {
public:
    // Null handle: operator bool() is false and operator-> aborts via TM_CHECK_NOTNULL.
    Allocator() = default;

    // One of the built-in implementations for `type` (defined in allocator.cc).
    explicit Allocator(DeviceType type);

    // Stream-ordered memory-pool allocator bound to `stream` (defined in allocator.cc).
    Allocator(Stream stream, bool use_default_pool);

    // Wraps an existing implementation.
    Allocator(shared_ptr impl): impl_{std::move(impl)} {};

    // Access to the implementation; aborts when the handle is null.
    AllocatorImpl* operator->() const
    {
        TM_CHECK_NOTNULL(impl_);
        return impl_.get();
    }

    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

    // Identity comparison: equal iff both handles share the same implementation object.
    friend bool operator==(const Allocator& a, const Allocator& b)
    {
        return a.impl_ == b.impl_;
    }

    friend bool operator!=(const Allocator& a, const Allocator& b)
    {
        return !(a == b);
    }

    // Creates a new implementation constructed from (impl_, args...), i.e. layered on top of this
    // allocator's implementation.
    // NOTE(review): the template parameters are not visible in this copy; presumably the first one
    // is the implementation type to create (e.g. StackAllocatorImpl, whose constructor takes an
    // underlying implementation) — confirm.
    template
    shared_ptr adapt(Args&&... args) const
    {
        return {std::make_shared(impl_, ((Args &&) args)...)};
    }

private:
    shared_ptr impl_;
};

class StackAllocatorImpl: public AllocatorImpl {
public:
    static constexpr ssize_t kAlignment = 256;

    explicit StackAllocatorImpl(shared_ptr underlying_impl): underlying_impl_{std::move(underlying_impl)}
    {
    }

    ~StackAllocatorImpl() override
    {
        if (cached_beg_) {
            underlying_impl_->deallocate(cached_beg_, cached_end_ - cached_beg_);
        }
    }

    void* allocate(ssize_t size) override
    {
        size = round_up(size, kAlignment);

        void* p{};
        if (cached_ptr_ + size <= cached_end_) {
            p = cached_ptr_;
            cached_ptr_ += size;
        }
        else {
            TM_CHECK(!cached_beg_);
            p = underlying_impl_->allocate(size);
        }

        // TM_LOG_ERROR("allocate %p, %ld", p, size);

        size_ += size;
        ++num_;
        max_size_ = std::max(size_, max_size_);
        num_      = std::max(num_, max_num_);
        return p;
    }

    void deallocate(void* p, ssize_t size) override
    {
        size = round_up(size, kAlignment);

        // TM_LOG_ERROR("deallocate %p, %p, %ld", p, cached_ptr_, size);

        if ((char*)p + size == cached_ptr_) {
            cached_ptr_ -= size;
        }
        else {
            TM_CHECK(!cached_beg_);
            underlying_impl_->deallocate(p, size);
        }
        size_ -= size;
        --num_;
    }

    Stream stream() const noexcept override
    {
        return underlying_impl_->stream();
    }

    Device device() const noexcept override
    {
        return underlying_impl_->device();
    }

    void iter()
    {
        TM_CHECK_EQ((void*)cached_beg_, (void*)cached_ptr_);
        auto excpected = max_size_ + kAlignment * max_num_;
        if (cached_end_ - cached_beg_ < excpected) {
            if (cached_beg_) {
                underlying_impl_->deallocate(cached_beg_, cached_end_ - cached_beg_);
            }
            cached_ptr_ = cached_beg_ = (char*)underlying_impl_->allocate(excpected);
            cached_end_               = cached_beg_ + excpected;
        }
        size_ = num_ = max_size_ = max_num_ = 0;
    }

private:
    ssize_t size_{};
    ssize_t num_{};
    ssize_t max_size_{};
    ssize_t max_num_{};

    char* cached_beg_{};
    char* cached_end_{};
    char* cached_ptr_{};

    std::shared_ptr underlying_impl_;
};

class SimpleAllocator: public AllocatorImpl {
public:
    template
    static Allocator Create(Alloc&& alloc, Dealloc&& dealloc, Device device)
    {
        return Allocator{std::make_shared((Alloc &&) alloc, (Dealloc &&) dealloc, device)};
    }

    template
    SimpleAllocator(Alloc&& alloc, Dealloc&& dealloc, Device device):
        alloc_{std::move(alloc)}, dealloc_{std ::move(dealloc)}, device_{device}
    {
    }

    void* allocate(ssize_t size) override
    {
        return alloc_(size);
    };

    void deallocate(void* p, ssize_t size) override
    {
        return dealloc_(p, size);
    }

    Device device() const noexcept override
    {
        return device_;
    }

private:
    std::function       alloc_;
    std::function dealloc_;
    Device                              device_;
};

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/buffer.cc
================================================

#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/stream.h"
namespace turbomind::core {

// Reinterprets the same storage as elements of `dtype`. The element count and the element
// offset are rescaled from this buffer's byte extent. Same-dtype views are plain copies.
Buffer Buffer::view(DataType dtype) const
{
    Buffer result = *this;
    if (dtype != dtype_) {
        result.dtype_ = dtype;
        result.size_  = numel(dtype, byte_size());
        // Re-express a non-zero element offset in units of the new dtype.
        if (base_) {
            result.base_ = numel(dtype, turbomind::byte_size(dtype_, base_));
        }
    }
    return result;
}

// Returns a view of `size` elements starting `base` elements into this buffer. The view shares
// ownership of the storage. `size == -1` means "through the end of the buffer".
// Aborts (TM_CHECK) when the requested range does not lie within [0, size()).
Buffer Buffer::slice(ssize_t base, ssize_t size) const
{
    TM_CHECK_GE(base, 0);
    TM_CHECK_LE(base, size_);
    if (size == -1) {
        // Resolve "rest of the buffer" before bounds checking. Checking `base + (-1)` instead
        // let base == size_ + 1 through and produced a negative size.
        size = size_ - base;
    }
    TM_CHECK_GE(size, 0);
    TM_CHECK_LE(base + size, size_);

    auto b = *this;
    b.base_ += base;
    b.size_ = size;
    return b;
}

// Formats as "<dtype>[<size>]@<data pointer>", with "+<base>" appended for a non-zero offset.
std::ostream& operator<<(std::ostream& os, const Buffer& b)
{
    os << b.dtype() << "[" << b.size() << "]@" << b.data_;
    return b.base_ ? (os << "+" << b.base_) : os;
}

// Copies the first `n` elements of `a` into `b_` with cudaMemcpyAsync(cudaMemcpyDefault) on
// `stream`. The dtypes must match and `n` must fit in both buffers. Nothing is enqueued when
// the byte count is zero.
void Copy(const Buffer& a, ssize_t n, Ref b_, const Stream& stream)
{
    auto& b = b_.get();
    TM_CHECK_EQ(a.dtype(), b.dtype());
    TM_CHECK_LE(n, a.size());
    TM_CHECK_LE(n, b.size());
    if (auto size = byte_size(a.dtype(), n)) {
        check_cuda_error(cudaMemcpyAsync(b.raw_data(), a.raw_data(), size, cudaMemcpyDefault, stream.handle()));
    }
}

// Same, on the current thread's Context stream.
void Copy(const Buffer& a, ssize_t n, Ref b_)
{
    Copy(a, n, b_, Context::stream());
}

// Whole-buffer copy: both buffers must have the same element count.
void Copy(const Buffer& a, Ref b_, const Stream& stream)
{
    TM_CHECK_EQ(a.size(), b_.get().size());
    Copy(a, a.size(), b_, stream);
}

// Whole-buffer copy on the current Context stream.
void Copy(const Buffer& a, Ref b_)
{
    Copy(a, b_, Context::stream());
}

namespace detail {

// Enqueues a copy of `n` bytes from `a` to `b` on `stream` and returns `b + n`, so consecutive
// copies can be chained. A zero-length copy enqueues nothing.
void* Copy(const void* a, ssize_t n, void* b, const Stream& stream)
{
    if (n != 0) {
        check_cuda_error(cudaMemcpyAsync(b, a, n, cudaMemcpyDefault, stream.handle()));
    }
    return static_cast<uint8_t*>(b) + n;
}

}  // namespace detail

void Clear(Ref b_, const Stream& stream)
{
    auto& b = b_.get();
    if (auto size = b.byte_size()) {
        check_cuda_error(cudaMemsetAsync(b.raw_data(), 0, b.byte_size(), stream.handle()));
    }
}

void Clear(Ref b_)
{
    Clear(b_, Context::stream());
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/buffer.h
================================================
#pragma once

#include 

#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/common.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/serdes.h"

namespace turbomind::core {

// Type-erased handle to `size_` elements of runtime type `dtype_` located in `device_` memory.
// Copies share the storage (`data_` is a shared_ptr). `base_` is an element offset into that
// storage, so view()/slice() never copy data.
class Buffer {
public:
    // Empty buffer: null storage, zero size/offset, default Device (kCPU, id -1) and DataType.
    Buffer(): data_{}, base_{}, size_{}, device_{}, dtype_{} {}

    // Typed empty buffer
    explicit Buffer(DataType dtype): Buffer()
    {
        dtype_ = dtype;
    }

    // Reference into `data` buffer
    // Non-owning: the no-op deleter means the caller must keep `data` alive.
    template
    Buffer(T* data, ssize_t size, Device device):
        data_{data, [](auto) {}}, base_{}, size_{size}, device_{device}, dtype_{data_type_v}
    {
    }

    // Non-owning reference with an explicit runtime dtype.
    Buffer(void* data, ssize_t size, DataType dtype, Device device):
        data_{data, [](auto) {}}, base_{}, size_{size}, device_{device}, dtype_{dtype}
    {
    }

    // Share ownership of `data`
    Buffer(shared_ptr data, ssize_t size, DataType dtype, Device device):
        data_{std::move(data)}, base_{}, size_{size}, device_{device}, dtype_{dtype}
    {
    }

    // Create from the allocator
    // The deleter captures a copy of `alloc` (a shared handle), which keeps the allocator
    // implementation alive for as long as any buffer shares this storage.
    Buffer(ssize_t size, DataType dtype, Allocator& alloc):
        base_{}, size_{size}, device_{alloc->device()}, dtype_{dtype}
    {
        auto bytes = turbomind::byte_size(dtype, size);
        data_      = {alloc->allocate(bytes), [=](auto p) { alloc->deallocate(p, bytes); }};
    }

    // Allocates from the current thread's Context allocator for `device`.
    Buffer(ssize_t size, DataType dtype, Device device): Buffer{size, dtype, Context::alloc(device)} {}

    // Typed pointer to the first element (offset applied). Aborts if T does not match the runtime
    // dtype or if the buffer has no storage.
    template
    T* data()
    {
        TM_CHECK_EQ(data_type_v, dtype_);
        return (T*)((char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size(base_));
    }

    template
    const T* data() const
    {
        return const_cast(this)->data();
    }

    // Untyped address of element `base_ + offset`, computed with the runtime dtype.
    // Aborts if the buffer has no storage.
    void* raw_data(ssize_t offset = 0)
    {
        return (char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size(dtype_, base_ + offset);
    }

    const void* raw_data(ssize_t offset = 0) const
    {
        return const_cast(this)->raw_data(offset);
    }

    // Like data(), but returns `other` when the buffer has no storage. For void, no dtype check is made.
    template
    T* data_or(T* other) noexcept
    {
        if constexpr (std::is_void_v) {
            return data_ ? (T*)raw_data() : other;
        }
        else {
            return data_ ? data() : other;
        }
    }

    template
    const T* data_or(const T* other) const noexcept
    {
        return const_cast(this)->data_or(other);
    }

    DataType dtype() const
    {
        return dtype_;
    }

    Device device() const
    {
        return device_;
    }

    // Number of elements (not bytes).
    ssize_t size() const
    {
        return size_;
    }

    ssize_t byte_size() const
    {
        return turbomind::byte_size(dtype_, size_);
    }

    // True when the buffer holds storage.
    explicit operator bool() const noexcept
    {
        return static_cast(data_);
    }

    // Same bytes reinterpreted as another dtype (defined in buffer.cc).
    Buffer view(DataType dtype) const;

    template
    Buffer view() const
    {
        return view(data_type_v);
    }

    // Sub-range [base, base + size); size == -1 means "to the end" (defined in buffer.cc).
    Buffer slice(ssize_t base, ssize_t size) const;

    // Non-owning copy pointing at the same elements (offset folded into the pointer).
    // The caller must keep this buffer's storage alive while the result is used.
    Buffer borrow() const
    {
        return Buffer{const_cast(raw_data()), size_, dtype_, device_};
    }

    friend bool operator==(const Buffer& a, const Buffer& b);

    friend bool operator!=(const Buffer& a, const Buffer& b);

    friend std::ostream& operator<<(std::ostream& os, const Buffer& b);

protected:
    // Fields used for equality; data_ compares by stored pointer.
    auto as_tuple() const
    {
        return std::tie(data_, base_, size_, dtype_, device_);
    }

    shared_ptr data_;
    ssize_t          base_;  // element offset into data_
    ssize_t          size_;  // element count
    Device           device_;
    DataType         dtype_;
};

// Buffers are equal when they share the same storage pointer, offset, size, dtype and device.
inline bool operator==(const Buffer& a, const Buffer& b)
{
    const auto lhs = a.as_tuple();
    const auto rhs = b.as_tuple();
    return lhs == rhs;
}

inline bool operator!=(const Buffer& a, const Buffer& b)
{
    const bool same = (a == b);
    return !same;
}

// Fresh (uninitialized) buffer with the same element count as `buffer`, placed on `device`.
inline Buffer empty_like(const Buffer& buffer, Device device)
{
    return Buffer{buffer.size(), buffer.dtype(), device};
}

// Fresh buffer with the same element count, dtype and device as `buffer`.
inline Buffer empty_like(const Buffer& buffer)
{
    return empty_like(buffer, buffer.device());
}

// Fresh buffer with the same element count and device as `buffer`, but element type `dtype`.
inline Buffer empty_like(const Buffer& buffer, DataType dtype)
{
    return Buffer{buffer.size(), dtype, buffer.device()};
}

// Statically typed Buffer: the element type is a compile-time template parameter (its name is
// not visible in this copy; the members use `T`). Conversions from an untyped Buffer check the
// runtime dtype. Element access then skips the per-call dtype check of Buffer::data<T>().
template
struct Buffer_: public Buffer {

    // Empty buffer with the dtype of T.
    Buffer_(): Buffer{data_type_v} {}

    // Non-owning reference to `size` elements at `data`.
    Buffer_(T* data, ssize_t size, Device device): Buffer{data, size, device} {}

    // Shares ownership of `data`.
    Buffer_(shared_ptr data, ssize_t size, Device device): Buffer{std::move(data), size, data_type_v, device}
    {
    }

    Buffer_(ssize_t size, Allocator& alloc): Buffer{size, data_type_v, alloc} {}

    // Allocates from the current Context allocator for `device`.
    Buffer_(ssize_t size, Device device): Buffer{size, data_type_v, device} {}

    Buffer_(const Buffer_&) = default;
    Buffer_& operator=(const Buffer_&) = default;

    Buffer_(Buffer_&&) noexcept = default;
    Buffer_& operator=(Buffer_&&) noexcept = default;

    // Conversions from an untyped Buffer: abort (TM_CHECK_EQ in ensure_dtype) when b's dtype
    // differs from T's.
    Buffer_(const Buffer& b)
    {
        *static_cast(this) = ensure_dtype(b);
    }
    Buffer_(Buffer&& b) noexcept
    {
        *static_cast(this) = ensure_dtype(std::move(b));
    }

    // `other` when the buffer has no storage, otherwise data().
    T* data_or(T* other)
    {
        return data_ ? data() : other;
    }

    const T* data_or(const T* other) const
    {
        return data_ ? data() : other;
    }

    // Address of element base_ + offset; aborts when there is no storage.
    // NOTE(review): byte_size's template argument is stripped in this copy; presumably byte_size<T>.
    void* raw_data(ssize_t offset = 0)
    {
        return (char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size(base_ + offset);
    }

    const void* raw_data(ssize_t offset = 0) const
    {
        return const_cast(this)->raw_data(offset);
    }

    T* data()
    {
        return static_cast(raw_data());
    }

    const T* data() const
    {
        return static_cast(raw_data());
    }

    // Iteration over [data(), data() + size()), usable in range-for.
    T* begin()
    {
        return data();
    }

    const T* begin() const
    {
        return data();
    }

    T* end()
    {
        return begin() + size();
    }

    const T* end() const
    {
        return begin() + size();
    }

    // Unchecked element access.
    T& operator[](ssize_t i)
    {
        return data()[i];
    }

    const T& operator[](ssize_t i) const
    {
        return data()[i];
    }

    // Checked element access (upper bound only).
    // NOTE(review): negative `i` is not rejected; consider TM_CHECK_GE(i, 0).
    T& at(ssize_t i)
    {
        TM_CHECK_LT(i, size());
        return data()[i];
    }

    // NOTE(review): this const overload returns `T&`, but data() const yields `const T*`, so the
    // return binds a non-const reference to a const element and is ill-formed once instantiated.
    // It should presumably return `const T&`.
    T& at(ssize_t i) const
    {
        TM_CHECK_LT(i, size());
        return data()[i];
    }

    // Always T's dtype (hides Buffer::dtype()).
    constexpr DataType dtype() const noexcept
    {
        return data_type_v;
    }

private:
    // Checks the runtime dtype, then passes `u` through with its value category preserved.
    template
    static decltype(auto) ensure_dtype(U&& u) noexcept
    {
        TM_CHECK_EQ(u.dtype(), data_type_v);
        return (U &&) u;
    }
};

// Non-owning reference wrapper that binds to both lvalues and rvalues. Functions such as
// Copy()/Clear() can therefore take a temporary (e.g. a Buffer returned by value) as destination.
template <class T>
class Ref {
public:
    Ref(T& x): ref_{x} {}
    // Binds to the temporary itself; valid only until the end of the caller's full-expression.
    Ref(T&& x): ref_{x} {}

    T& get()
    {
        return ref_;
    }

    operator T&()
    {
        return get();
    }

private:
    T& ref_;
};

// Buffer-to-buffer copies via cudaMemcpyAsync(cudaMemcpyDefault) (defined in buffer.cc).
// Overloads without `stream` use the current Context stream. Overloads without `n` copy the
// whole buffer and require equal element counts.
void Copy(const Buffer& a, ssize_t n, Ref b_, const Stream& stream);

void Copy(const Buffer& a, ssize_t n, Ref b_);

void Copy(const Buffer& a, Ref b_, const Stream& stream);

void Copy(const Buffer& a, Ref b_);

// Static type checking
template
inline void Copy_(const Buffer_& a, ssize_t n, Buffer_& b_)
{
    Copy((const Buffer&)a, n, (Buffer&)b_);
}

namespace detail {

// Byte-level asynchronous copy of `n` bytes from `a` to `b`; returns `b + n` (defined in buffer.cc).
void* Copy(const void* a, ssize_t n, void* b, const Stream& stream);

}  // namespace detail

template
inline T* Copy(const T* a, ssize_t n, T* b, const Stream& stream)
{
    return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, stream);
}

template
inline T* Copy(const T* a, ssize_t n, T* b)
{
    return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, Context::stream());
}

// Function object that forwards its arguments to the `Copy` overload set, so that set can be
// passed wherever a single callable is expected.
struct CopyT {
    template <class... Args>
    auto operator()(Args&&... args) const
    {
        return Copy(static_cast<Args&&>(args)...);
    }
};

// Zero-fill the whole buffer asynchronously; without `stream`, the current Context stream is used
// (defined in buffer.cc).
void Clear(Ref b_, const Stream& stream);

void Clear(Ref b_);

template
std::vector to_vector(const Buffer_& b)
{
    TM_CHECK(b.device().type == kCPU || b.device().type == kCPUpinned);
    return std::vector(b.begin(), b.end());
}

// clang-format off
// Serialization hooks. The archive's template parameter list is not visible in this copy;
// presumably `Archive` is any type supporting `ar & value` — confirm against serdes.h.
// Wire format: element count, dtype, then the raw bytes.
template
void save(Archive& ar, const Buffer& buffer)
{
    // Only plain host (kCPU) buffers can be written; pinned and device buffers abort.
    TM_CHECK(buffer.device().type == kCPU);
    ar & buffer.size();
    ar & buffer.dtype();
    ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size());
}

template
void load(Archive& ar, Buffer& buffer)
{
    decltype(buffer.size())  size;
    decltype(buffer.dtype()) dtype;

    ar & size;
    ar & dtype;
    // Always replaces `buffer` with a freshly allocated host (kCPU) buffer before reading the bytes.
    buffer = Buffer(size, dtype, kCPU);
    ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size());
}
// clang-format on

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/check.cc
================================================

#include 
#include 
#include 
#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind::core {

namespace {

// Shortens a source path for diagnostics to the part after the last "turbomind" path
// component, e.g. "src/turbomind/core/check.cc" -> "core/check.cc". Paths without that
// component are returned unchanged. Setting TM_SRC_FULL_PATH (read once) disables stripping.
std::string StripSrcPrefix(const char* file)
{
    static const char* const keep_full_path = std::getenv("TM_SRC_FULL_PATH");
    if (keep_full_path != nullptr) {
        return file;
    }

    const std::filesystem::path original{file};
    std::filesystem::path       stripped;
    bool                        after_anchor = false;

    for (const auto& component : original) {
        if (component == "turbomind") {
            // Restart at every anchor so that the last occurrence wins.
            stripped.clear();
            after_anchor = true;
            continue;
        }
        if (after_anchor) {
            stripped /= component;
        }
    }

    return after_anchor ? stripped.string() : original.string();
}

}  // namespace

CheckOpStringBuilder::CheckOpStringBuilder()
{
    oss_ = new std::ostringstream;
}

std::ostream* CheckOpStringBuilder::ForVal1()
{
    (*oss_) << "(";
    return oss_;
}
std::ostream* CheckOpStringBuilder::ForVal2()
{
    (*oss_) << " vs. ";
    return oss_;
}
std::string* CheckOpStringBuilder::NewString()
{
    (*oss_) << ")";
    return new std::string{oss_->str()};
}

// Starts the failure message: "<file>(<line>): Check failed: <expr> ".
CheckErrorStream::CheckErrorStream(const char* file, int line, const char* expr): oss_{new std::ostringstream{}}
{
    *oss_ << StripSrcPrefix(file) << "(" << line << "): Check failed: " << expr << " ";
}

// Variant used by TM_CHECK_<OP>: also appends the "(v1 vs. v2)" text.
CheckErrorStream::CheckErrorStream(const char* file, int line, const char* expr, std::string* str):
    CheckErrorStream{file, line, expr}
{
    const std::string& values = *str;
    *oss_ << values << " ";
}

// Prints the accumulated message to stderr and terminates the process. Output goes through
// iostreams (no printf-style formatting), so `%` in the message is printed verbatim.
void CheckErrorStream::Report()
{
    const std::string message = oss_->str();
    std::cerr << "[TM][FATAL] " << message << "\n";
    std::abort();
}

// Fatal error for TM_CHECK_NOTNULL: prints "<file>(<line>): '<expr>' Must be non NULL" and aborts.
// iostreams are used, so `%` in `expr` is printed verbatim.
void ReportNullError(const char* file, int line, const char* expr)
{
    const std::string location = StripSrcPrefix(file);
    std::cerr << "[TM][FATAL] " << location << "(" << line << "): '" << expr << "' Must be non NULL\n";
    std::abort();
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/check.h
================================================

// Inspired by 

#pragma once

#include 

namespace turbomind::core {

// Portability shims: branch-prediction hints, a noinline attribute and an unreachable marker.
// MSVC (without clang) has no __builtin_expect, so there the hints degrade to the plain expression.
#if defined(_MSC_VER) && !defined(__clang__)
#define TM_LIKELY(expr) (expr)
#define TM_UNLIKELY(expr) (expr)
#define TM_NOINLINE
#define TM_UNREACHABLE __assume(0)
#else
#define TM_LIKELY(expr) (__builtin_expect(bool(expr), 1))
#define TM_UNLIKELY(expr) (__builtin_expect(bool(expr), 0))
#define TM_NOINLINE __attribute__((noinline))
#define TM_UNREACHABLE __builtin_unreachable()
#endif

// Compile-time switches (both off):
//   TM_DISABLE_CHECK_STREAM - when 1, messages streamed into a failing check with << are discarded.
//   TM_DISABLE_CHECK_OP     - when 1, TM_CHECK_EQ/NE/LE/LT/GE/GT fall back to TM_CHECK and do not
//                             print the operand values.
#define TM_DISABLE_CHECK_STREAM 0
#define TM_DISABLE_CHECK_OP 0

// Failure message accumulator created as a temporary by the TM_CHECK* macros only when a check
// fails. Messages streamed with << are appended. When the temporary is destroyed at the end of
// the full-expression, the destructor calls Report(), which prints and aborts ([[noreturn]]).
class CheckErrorStream {
public:
    CheckErrorStream(const char* file, int line, const char* expr);

    // `str` is the "(v1 vs. v2)" text produced by MakeCheckOpString.
    CheckErrorStream(const char* file, int line, const char* expr, std::string* str);

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 4722)  // MSVC warns dtor never return
#endif
    ~CheckErrorStream()
    {
        Report();
    }
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif

    // Appends `msg` to the failure message (a no-op when TM_DISABLE_CHECK_STREAM is set).
    template
    CheckErrorStream& operator<<(const T& msg)
    {
#if TM_DISABLE_CHECK_STREAM
#else
        *oss_ << msg;
#endif
        return *this;
    }

private:
    [[noreturn]] void Report();

    // Heap-allocated and never freed: the object only exists on the abort path.
    std::ostringstream* oss_;
};

// Builds "(v1 vs. v2)" for a failed comparison check: ForVal1() / ForVal2() return the stream
// to write each operand into, and NewString() returns the finished text as a new heap string.
// NOTE(review): oss_ is allocated in the constructor and never freed (no destructor). That is
// harmless in practice because the builder is only used on the check-failure path, which aborts.
class CheckOpStringBuilder {
public:
    CheckOpStringBuilder();
    std::ostream* ForVal1();
    std::ostream* ForVal2();
    std::string*  NewString();

private:
    std::ostringstream* oss_;
};

template
std::string* MakeCheckOpString(const T1& v1, const T2& v2) TM_NOINLINE;

template
std::string* MakeCheckOpString(const T1& v1, const T2& v2)
{
    CheckOpStringBuilder builder;
    *builder.ForVal1() << v1;
    *builder.ForVal2() << v2;
    return builder.NewString();
}

// Generates `<name>Impl(v1, v2)`. It returns {false, nullptr} when `v1 op v2` holds, and otherwise
// {true, heap string "(v1 vs. v2)"} for the failure message. Each operand is evaluated once by the
// caller and bound to a const reference here.
// (Template argument lists are not visible in this copy.)
#define DEFINE_CHECK_OP_IMPL(name, op)                                                                                 \
    template                                                                                       \
    inline std::pair name##Impl(const T1& v1, const T2& v2)                                        \
    {                                                                                                                  \
        if (TM_LIKELY(v1 op v2))                                                                                       \
            return {false, nullptr};                                                                                   \
        else                                                                                                           \
            return {true, MakeCheckOpString(v1, v2)};                                                                  \
    }

DEFINE_CHECK_OP_IMPL(Check_EQ, ==);
DEFINE_CHECK_OP_IMPL(Check_NE, !=);
DEFINE_CHECK_OP_IMPL(Check_LE, <=);
DEFINE_CHECK_OP_IMPL(Check_LT, <);
DEFINE_CHECK_OP_IMPL(Check_GE, >=);
DEFINE_CHECK_OP_IMPL(Check_GT, >);

#undef DEFINE_CHECK_OP_IMPL

// clang-format off
// Both macros expand to `if (<ok>) {} else <error stream>` instead of `if (!<ok>) <error stream>`.
// When a check is the body of an unbraced `if`, a following `else` written by the caller then
// binds to the caller's `if` rather than to the macro's (dangling-else). Usage such as
// `TM_CHECK(x) << "msg";` is unchanged, and so is the reported text (`#e` / `#a op #b`).
#define TM_CHECK(e)                                                                  \
    if (TM_LIKELY(e)) {} else turbomind::core::CheckErrorStream(__FILE__, __LINE__, #e)

#define TM_CHECK_OP(name, op, a, b)                                                  \
    if (auto&& [__p, __s] = turbomind::core::Check##name##Impl(a, b); TM_LIKELY(!__p)) {} else \
        turbomind::core::CheckErrorStream(__FILE__, __LINE__, #a " " #op " " #b, __s)
// clang-format on

#if TM_DISABLE_CHECK_OP

// Fallback without operand printing. The operands are parenthesized so that arguments containing
// lower-precedence operators (e.g. `TM_CHECK_EQ(x & mask, y)`) are compared as a whole rather than
// parsed as `x & (mask == y)`.
#define TM_CHECK_EQ(a, b) TM_CHECK((a) == (b))
#define TM_CHECK_NE(a, b) TM_CHECK((a) != (b))
#define TM_CHECK_LE(a, b) TM_CHECK((a) <= (b))
#define TM_CHECK_LT(a, b) TM_CHECK((a) < (b))
#define TM_CHECK_GE(a, b) TM_CHECK((a) >= (b))
#define TM_CHECK_GT(a, b) TM_CHECK((a) > (b))

#else

// Default: the failure message includes both operand values, e.g. "a == b (1 vs. 2)".
#define TM_CHECK_EQ(a, b) TM_CHECK_OP(_EQ, ==, a, b)
#define TM_CHECK_NE(a, b) TM_CHECK_OP(_NE, !=, a, b)
#define TM_CHECK_LE(a, b) TM_CHECK_OP(_LE, <=, a, b)
#define TM_CHECK_LT(a, b) TM_CHECK_OP(_LT, <, a, b)
#define TM_CHECK_GE(a, b) TM_CHECK_OP(_GE, >=, a, b)
#define TM_CHECK_GT(a, b) TM_CHECK_OP(_GT, >, a, b)

#endif

// Prints "<file>(<line>): '<expr>' Must be non NULL" and aborts (defined in check.cc).
[[noreturn]] void ReportNullError(const char* file, int line, const char* expr);

// Aborts when `p == nullptr`, otherwise returns `p` itself. `(T &&) p` combined with
// decltype(auto) yields T&&, so the check can sit inline in an expression, e.g.
// `TM_CHECK_NOTNULL(data_).get()`.
// NOTE: for a temporary argument the returned reference is valid only until the end of the
// enclosing full-expression.
template
decltype(auto) EnsureNotNull(const char* file, int line, const char* expr, T&& p)
{
    if (TM_UNLIKELY(p == nullptr)) {
        ReportNullError(file, line, expr);
    }
    return (T &&) p;
}

#define TM_CHECK_NOTNULL(p) ::turbomind::core::EnsureNotNull(__FILE__, __LINE__, #p, (p))

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/common.h
================================================

#pragma once

#include 
#include 
#include 

/// TODO: remove this dependency
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::core {

// Forward declarations of the core types, so headers can refer to them without full definitions.
class Allocator;
class Buffer;
class Stream;
class Event;
class Context;

using std::shared_ptr;
using std::vector;

// Signed size/count type used throughout core (a project-local alias rather than POSIX ssize_t).
using ssize_t = std::ptrdiff_t;

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/context.cc
================================================

#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/context.h"

namespace turbomind::core {

namespace {

// Per-thread stacks of the "current" stream and allocators (one stack per memory space).
// Every push records in mask_ which stacks it touched, so pop() undoes exactly that push.
struct ContextStorage {
    enum
    {
        stream_bit       = 1,
        host_alloc_bit   = 2,
        device_alloc_bit = 4,
        pinned_alloc_bit = 8,
    };

    std::stack    stream_;
    std::stack host_alloc_;
    std::stack device_alloc_;
    std::stack pinned_alloc_;
    std::stack       mask_;

    // Seed the host stack so host_alloc() always works on a fresh thread.
    ContextStorage()
    {
        push(Allocator{kCPU});
    }

    // An invalid stream is not pushed, but a (zero) mask entry is still recorded so that
    // push/pop counts stay balanced.
    void push(const Stream& stream)
    {
        int mask{};
        if (stream) {
            stream_.push(stream);
            mask = stream_bit;
        }
        mask_.push(mask);
    }

    // Routes the allocator to the stack matching its device type. A null allocator only records
    // a zero mask entry.
    void push(const Allocator& alloc)
    {
        int mask{};
        if (alloc) {
            const auto type = alloc->device().type;
            if (type == kCPU) {
                mask = host_alloc_bit;
                host_alloc_.push(alloc);
            }
            else if (type == kDEVICE) {
                mask = device_alloc_bit;
                device_alloc_.push(alloc);
            }
            else if (type == kCPUpinned) {
                mask = pinned_alloc_bit;
                pinned_alloc_.push(alloc);
            }
        }
        mask_.push(mask);
    }

    // Undoes the most recent push.
    // NOTE: mask_.top() on an empty stack is undefined; callers must keep push/pop balanced
    // (ContextGuard does).
    void pop()
    {
        if (mask_.top() & stream_bit) {
            stream_.pop();
        }
        if (mask_.top() & host_alloc_bit) {
            host_alloc_.pop();
        }
        if (mask_.top() & device_alloc_bit) {
            device_alloc_.pop();
        }
        if (mask_.top() & pinned_alloc_bit) {
            pinned_alloc_.pop();
        }
        mask_.pop();
    }

    // One instance per thread, created on first use.
    static ContextStorage& instance()
    {
        thread_local ContextStorage inst{};
        return inst;
    }
};

}  // namespace

// Thin forwarders to the calling thread's ContextStorage.
void Context::push(const Stream& stream)
{
    auto& storage = ContextStorage::instance();
    storage.push(stream);
}

void Context::push(const Allocator& alloc)
{
    auto& storage = ContextStorage::instance();
    storage.push(alloc);
}

void Context::pop()
{
    auto& storage = ContextStorage::instance();
    storage.pop();
}

// Accessors for the top of the calling thread's stacks. Each aborts with a descriptive message
// when nothing of that kind has been pushed. The host stack is seeded with Allocator{kCPU} when
// the thread's storage is created.
Stream& Context::stream()
{
    auto& stream_ = ContextStorage::instance().stream_;
    TM_CHECK(!stream_.empty()) << "No STREAM available in current context";
    return stream_.top();
}

Allocator& Context::host_alloc()
{
    auto& host_alloc_ = ContextStorage::instance().host_alloc_;
    TM_CHECK(!host_alloc_.empty()) << "No HOST memory allocator available in current context";
    return host_alloc_.top();
}

Allocator& Context::device_alloc()
{
    auto& device_alloc_ = ContextStorage::instance().device_alloc_;
    TM_CHECK(!device_alloc_.empty()) << "No DEVICE memory allocator available in current context";
    return device_alloc_.top();
}

Allocator& Context::pinned_alloc()
{
    auto& pinned_alloc_ = ContextStorage::instance().pinned_alloc_;
    TM_CHECK(!pinned_alloc_.empty()) << "No PINNED memory allocator available in current context";
    return pinned_alloc_.top();
}

// Current allocator for `device`'s memory space.
// Aborts when no allocator of that kind is available, or when `device.type` is not a valid
// DeviceType.
Allocator& Context::alloc(Device device)
{
    switch (device.type) {
        case kDEVICE:
            return device_alloc();
        case kCPU:
            return host_alloc();
        case kCPUpinned:
            return pinned_alloc();
    }
    // An out-of-range value (e.g. from a bad cast) previously fell straight into TM_UNREACHABLE,
    // which is undefined behavior. Fail loudly instead, as Allocator(DeviceType) does.
    TM_CHECK(0) << "Invalid device type: " << static_cast<int>(device.type);
    TM_UNREACHABLE;
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/context.h
================================================
#pragma once

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/common.h"
#include "src/turbomind/core/stream.h"

namespace turbomind::core {

// Static access to the calling thread's current stream and allocators (per-thread storage lives
// in context.cc). push/pop are private and reachable only through ContextGuard.
class Context {
public:
    static Stream&    stream();
    static Allocator& host_alloc();
    static Allocator& device_alloc();
    static Allocator& pinned_alloc();
    // Dispatches on device.type to one of the three allocator accessors above.
    static Allocator& alloc(Device device);

private:
    friend class ContextGuard;
    static void push(const Stream& stream);
    static void push(const Allocator& alloc);
    static void pop();
};

class ContextGuard {
public:
    template
    explicit ContextGuard(Args&&... args): n_{}
    {
        (Context::push((Args &&) args), ...);
        n_ = sizeof...(Args);
    }
    ~ContextGuard()
    {
        for (int i = 0; i < n_; ++i) {
            Context::pop();
        }
    }

private:
    int n_;
};

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/copy.cc
================================================

#include "src/turbomind/core/copy.h"

#include 
#include 
#include 

#include 

#include "src/turbomind/core/check.h"

namespace turbomind::core {

// picked from "cudaTypedefs.h" / "cuda.h"

// Local mirrors of the CUDA 12.8 driver declarations for cuMemcpyBatchAsync (the "_v12080" suffix).
// They live in namespace turbomind::core, presumably so the file also builds with headers that
// predate these declarations. Values and layouts must match the driver ABI exactly; do not edit.
typedef enum CUmemcpyFlags_enum
{
    CU_MEMCPY_FLAG_DEFAULT                     = 0x0,
    CU_MEMCPY_FLAG_PREFER_OVERLAP_WITH_COMPUTE = 0x1
} CUmemcpyFlags;

typedef enum CUmemcpySrcAccessOrder_enum
{
    CU_MEMCPY_SRC_ACCESS_ORDER_INVALID         = 0x0,
    CU_MEMCPY_SRC_ACCESS_ORDER_STREAM          = 0x1,
    CU_MEMCPY_SRC_ACCESS_ORDER_DURING_API_CALL = 0x2,
    CU_MEMCPY_SRC_ACCESS_ORDER_ANY             = 0x3,
    CU_MEMCPY_SRC_ACCESS_ORDER_MAX             = 0x7FFFFFFF
} CUmemcpySrcAccessOrder;

typedef struct CUmemcpyAttributes_st {
    CUmemcpySrcAccessOrder srcAccessOrder;
    CUmemLocation          srcLocHint;
    CUmemLocation          dstLocHint;
    unsigned int           flags;
} CUmemcpyAttributes_v1;

// Function-pointer type of cuMemcpyBatchAsync as resolved through cudaGetDriverEntryPoint.
typedef CUresult(CUDAAPI* PFN_cuMemcpyBatchAsync_v12080)(CUdeviceptr_v2*        dsts,
                                                         CUdeviceptr_v2*        srcs,
                                                         size_t*                sizes,
                                                         size_t                 count,
                                                         CUmemcpyAttributes_v1* attrs,
                                                         size_t*                attrIdxs,
                                                         size_t                 numAttrs,
                                                         size_t*                failIdx,
                                                         CUstream               hStream);

/// TODO: add `PFN_cuMemcpyBatchAsync_v13000`

namespace {

// Resolves the driver entry point "cuMemcpyBatchAsync" once (function-local static) and returns
// it wrapped in a variant.
// NOTE(review): the variant's alternatives are not visible in this copy. Presumably they are an
// empty alternative (selected by `return {}`) and PFN_cuMemcpyBatchAsync_v12080 — confirm.
const auto& GetCopyAPI()
{
    static auto inst = []() -> std::variant {
        const auto                      symbol = "cuMemcpyBatchAsync";
        cudaDriverEntryPointQueryResult status{};
        void*                           fpn{};
        // Failure of the query call itself is fatal; "symbol not found" is reported via `status`.
        TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);
        if (fpn && status == cudaDriverEntryPointSuccess) {
            return (PFN_cuMemcpyBatchAsync_v12080)fpn;
        }
        else {
            // Symbol unavailable: BatchCopy::Run falls back to one copy per entry.
            return {};
        }
    }();
    return inst;
}

}  // namespace

// Defaulted out of line, next to the constructor.
BatchCopy::~BatchCopy() = default;

// self_ (declared in copy.h) is initialized to point at this object; Reset() then sets up the
// initial, empty state.
BatchCopy::BatchCopy(): self_{this}
{
    this->Reset();
}

void BatchCopy::Run()
{
    if (src_.empty()) {
        return;
    }

    std::visit(
        [&](auto&& copy) {
            using T = std::decay_t;
            if constexpr (std::is_same_v) {
                CUmemcpyAttributes_v1 attr{};
                attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
                attr.flags          = CU_MEMCPY_FLAG_PREFER_OVERLAP_WITH_COMPUTE;
                std::vector ais(src_.size(), 0);
                size_t              fail_idx{SIZE_MAX};

                auto status = copy((CUdeviceptr_v2*)dst_.data(),
                                   (CUdeviceptr_v2*)src_.data(),
                                   size_.data(),
                                   src_.size(),
                                   &attr,
                                   ais.data(),
                                   1,
                                   &fail_idx,
                                   core::Context::stream().handle());

                if (auto i = fail_idx; i != SIZE_MAX) {
                    TM_CHECK(0) << (void*)src_[i] << " " << size_[i] << " " << (void*)dst_[i] << " code " << status;
                }
            }
            else {
                for (unsigned i = 0; i < src_.size(); ++i) {
                    core::Copy(src_[i], size_[i], dst_[i]);
                }
            }
        },
        GetCopyAPI());

    Reset();
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/copy.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/check.h"

namespace turbomind::core {

class BatchCopy {
public:
    ~BatchCopy();

    BatchCopy();

    BatchCopy(const BatchCopy&) = delete;
    BatchCopy& operator=(const BatchCopy&) = delete;
    BatchCopy(BatchCopy&&) noexcept        = delete;
    BatchCopy& operator=(BatchCopy&&) noexcept = delete;

    // clang-format off
    class Group {
    public:
        ~Group() { parent_.group_end(); }
        Group(BatchCopy& parent): parent_{parent} { parent_.group_begin(); }
        explicit constexpr operator bool() const noexcept { return true; }
    private:
        BatchCopy& parent_;
    };
    // clang-format on

    friend Group;

    Group group()
    {
        return {*this};
    }

    template
    T* operator()(const T* src, ssize_t size, T* dst)
    {
        // return core::Copy(src, size, dst);

        /// TODO: verify this is actually a fast path in a loop (without extra jump)
        if (TM_LIKELY(group_ && src == (const T*)src_ptr_ && dst == (T*)dst_ptr_)) {
            src_ptr_ += sizeof(T) * size;
            dst_ptr_ += sizeof(T) * size;
            gsize_ += sizeof(T) * size;
            count_ += 1;
            return dst + size;
        }
        else if (group_) {
            group_commit();
            gsize_   = sizeof(T) * size;
            src_ptr_ = reinterpret_cast(src + size);
            dst_ptr_ = reinterpret_cast(dst + size);
            count_ += 1;
            return dst + size;
        }
        else {
            gsize_   = sizeof(T) * size;
            src_ptr_ = reinterpret_cast(src + size);
            dst_ptr_ = reinterpret_cast(dst + size);
            count_   = 1;
            group_commit();
            return dst + size;
        }
    }

    void operator()(const Buffer& src, ssize_t size, Ref dst_)
    {
        auto& dst = dst_.get();
        TM_CHECK_EQ(src.dtype(), dst.dtype());
        TM_CHECK_LE(size, src.size());
        TM_CHECK_LE(size, dst.size());
        (*this)((const char*)src.raw_data(), byte_size(src.dtype(), size), (char*)dst.raw_data());
    }

    void Run();

    Buffer_ buf()
    {
        return {&self_, 1, kCPU};
    }

    friend std::ostream& operator<<(std::ostream& os, const BatchCopy& a)
    {
        os << "(" << a.count_ << ", " << a.src_.size() << ")";
        return os;
    }

private:
    void Reset()
    {
        src_.clear();
        dst_.clear();
        size_.clear();
        count_ = 0;
    }

    void group_begin()
    {
        TM_CHECK(!group_) << "Nested group is not supported";
        group_ = true;
    }

    void group_end()
    {
        TM_CHECK(group_) << "Mismatched group end";
        group_commit();
        group_ = false;
    }

    void group_commit()
    {
        if (gsize_) {
            src_.push_back(src_ptr_ - gsize_);
            dst_.push_back(dst_ptr_ - gsize_);
            size_.push_back(gsize_);
        }
        src_ptr_ = dst_ptr_ = {};
        gsize_              = {};
    }

private:
    std::vector src_;
    std::vector       dst_;
    std::vector      size_;

    int         group_   = 0;
    size_t      gsize_   = 0;
    const char* src_ptr_ = {};
    char*       dst_ptr_ = {};

    size_t count_;

    BatchCopy* self_;
};

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/core.h
================================================
#pragma once

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/copy.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/layout.h"
#include "src/turbomind/core/ranges.h"
#include "src/turbomind/core/stream.h"
#include "src/turbomind/core/tensor.h"

namespace turbomind {

// Re-export the core vocabulary into the top-level `turbomind` namespace.

using core::ssize_t;

using core::Allocator;
using core::BatchCopy;
using core::Buffer;
using core::Buffer_;
using core::Event;
using core::Layout;
using core::Ref;
using core::Stream;
using core::Tensor;
using core::Tensor_;
using core::TensorMap;

using core::subrange;

}  // namespace turbomind


================================================
FILE: src/turbomind/core/cuda_data_type.h
================================================


#include 
#include 

#include 

#include 
#include 
#include 

#include "src/turbomind/core/data_type.h"

namespace turbomind {

// clang-format off

/// Translate a turbomind DataType into the matching `cudaDataType`.
/// Throws std::runtime_error for types that have no mapping here.
constexpr cudaDataType to_cuda_dtype(DataType type)
{
    switch (type) {
        // floating point
        case kFloat16:     return CUDA_R_16F;
        case kBfloat16:    return CUDA_R_16BF;
        case kFloat32:     return CUDA_R_32F;
        case kFloat64:     return CUDA_R_64F;
        case kFloat8_e4m3: return CUDA_R_8F_E4M3;
        case kFloat8_e5m2: return CUDA_R_8F_E5M2;
        // signed integers
        case kInt8:        return CUDA_R_8I;
        case kInt16:       return CUDA_R_16I;
        case kInt32:       return CUDA_R_32I;
        case kInt64:       return CUDA_R_64I;
        // unsigned integers
        case kUint8:       return CUDA_R_8U;
        case kUint16:      return CUDA_R_16U;
        case kUint32:      return CUDA_R_32U;
        case kUint64:      return CUDA_R_64U;
        default:           break;
    }
    throw std::runtime_error("Not supported " + std::string{to_string(type)});
}

/// Inverse of to_cuda_dtype: map a `cudaDataType` back to a turbomind DataType.
/// Throws std::runtime_error (with the numeric CUDA enum value) for unmapped types.
constexpr DataType from_cuda_dtype(cudaDataType type) {
    switch (type) {
        // floating point
        case CUDA_R_16F:     return kFloat16;
        case CUDA_R_16BF:    return kBfloat16;
        case CUDA_R_32F:     return kFloat32;
        case CUDA_R_64F:     return kFloat64;
        case CUDA_R_8F_E4M3: return kFloat8_e4m3;
        case CUDA_R_8F_E5M2: return kFloat8_e5m2;
        // signed integers
        case CUDA_R_8I:      return kInt8;
        case CUDA_R_16I:     return kInt16;
        case CUDA_R_32I:     return kInt32;
        case CUDA_R_64I:     return kInt64;
        // unsigned integers
        case CUDA_R_8U:      return kUint8;
        case CUDA_R_16U:     return kUint16;
        case CUDA_R_32U:     return kUint32;
        case CUDA_R_64U:     return kUint64;
        default:             break;
    }
    throw std::runtime_error("Not supported " + std::string{std::to_string(type)});
}

#if __CUDACC_VER_MAJOR__ >= 12

/// Element type to use for a CUDA tensor map (TMA descriptor) over `type`.
/// FP8 types are described as raw bytes (UINT8). Throws std::runtime_error otherwise.
constexpr CUtensorMapDataType to_CUtensorMap_dtype(DataType type) {
    switch (type) {
        case kFloat16:     return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
        case kBfloat16:    return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
        case kFloat32:     return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
        case kFloat8_e4m3:
        case kFloat8_e5m2: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
        default:           break;
    }
    throw std::runtime_error("Not supported " + std::string{to_string(type)});
}

#endif

// clang-format on

}  // namespace turbomind


================================================
FILE: src/turbomind/core/data_type.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/check.h"

#include 
#include 
#include 

// forward declarations for CUDA floating point types
struct __half;
struct __nv_bfloat16;
struct __nv_fp8_e4m3;
struct __nv_fp8_e5m2;

namespace turbomind {

// clang-format off

// Empty tag types for packed sub-byte unsigned integers; they carry no storage
// and are only used to name the type (see CVT_DATA_TYPE and bitsof_t).
struct uint2_t {};
struct uint4_t {};
struct uint6_t {};

// Integral-constant helper used to express bit widths.
template 
struct int_constant: std::integral_constant {};

// Bit width of a type; the specializations below give the sub-byte tag types
// their nominal widths (2, 4, 6).
template 
struct bitsof_t: int_constant {};

template <> struct bitsof_t: int_constant<2> {};
template <> struct bitsof_t: int_constant<4> {};
template <> struct bitsof_t: int_constant<6> {};

// Variable-template shorthand for bitsof_t.
template 
inline constexpr bitsof_t bitsof{};

// Aliases for the CUDA floating-point types forward-declared at the top of this header.
using half_t = __half;
using bfloat16_t = __nv_bfloat16;
using fp8_e4m3_t = __nv_fp8_e4m3;
using fp8_e5m2_t = __nv_fp8_e5m2;

// Tag type for 4-bit e2m1 floats (no storage), given a width of 4 bits below.
struct fp4_e2m1_t {};

template <> struct bitsof_t: int_constant<4> {};


// Pack a numeric format as: bit 16 = sign, bits 8.. = exponent bits, low bits = mantissa
// bits (for integer types the low field holds the full bit width).
constexpr int encode_data_type(bool sign, int exponent, int mantissa) {
    const int sign_bit = sign ? 1 : 0;
    const int hi       = sign_bit << 16;
    const int mid      = exponent << 8;
    return hi | mid | mantissa;
}

// Numeric element types. Enumerator values come from encode_data_type(sign, exponent, mantissa);
// for integer types the "mantissa" argument is the bit width.
// NOTE: kPointer has no explicit value, so it is one more than the preceding enumerator
// (kUint6 == 6, hence kPointer == 7). Reordering enumerators would change it.
enum class DataType: int {
    kNull        = 0,
    kBool        = 1,
    kUint8       = encode_data_type(0,  0,  8),
    kUint16      = encode_data_type(0,  0, 16),
    kUint32      = encode_data_type(0,  0, 32),
    kUint64      = encode_data_type(0,  0, 64),
    kInt8        = encode_data_type(1,  0,  8),
    kInt16       = encode_data_type(1,  0, 16),
    kInt32       = encode_data_type(1,  0, 32),
    kInt64       = encode_data_type(1,  0, 64),
    kFloat16     = encode_data_type(1,  5, 10),
    kFloat32     = encode_data_type(1,  8, 23),
    kFloat64     = encode_data_type(1, 11, 52),
    kBfloat16    = encode_data_type(1,  8,  7),
    kFloat4_e2m1 = encode_data_type(1,  2,  1),
    kFloat6_e2m3 = encode_data_type(1,  2,  3),
    kFloat6_e3m2 = encode_data_type(1,  3,  2),
    kFloat8_e4m3 = encode_data_type(1,  4,  3),
    kFloat8_e5m2 = encode_data_type(1,  5,  2),
    kUint2       = encode_data_type(0,  0,  2),
    kUint4       = encode_data_type(0,  0,  4),
    kUint6       = encode_data_type(0,  0,  6),
    kPointer,
    // Aliases (same values as the enumerators they name).
    kUint        = kUint32,
    kInt         = kInt32,
    kFloat       = kFloat32,
    kHalf        = kFloat16,
    kDouble      = kFloat64,
    kE2m1        = kFloat4_e2m1,
    kE2m3        = kFloat6_e2m3,
    kE3m2        = kFloat6_e3m2,
    kE4m3        = kFloat8_e4m3,
    kE5m2        = kFloat8_e5m2,
};

// Namespace-scope shorthands, so code can write `kFloat16` instead of `DataType::kFloat16`.

// special
inline constexpr DataType kNull    = DataType::kNull;
inline constexpr DataType kBool    = DataType::kBool;
inline constexpr DataType kPointer = DataType::kPointer;

// unsigned integers (including packed sub-byte widths)
inline constexpr DataType kUint2  = DataType::kUint2;
inline constexpr DataType kUint4  = DataType::kUint4;
inline constexpr DataType kUint6  = DataType::kUint6;
inline constexpr DataType kUint8  = DataType::kUint8;
inline constexpr DataType kUint16 = DataType::kUint16;
inline constexpr DataType kUint32 = DataType::kUint32;
inline constexpr DataType kUint64 = DataType::kUint64;

// signed integers
inline constexpr DataType kInt8  = DataType::kInt8;
inline constexpr DataType kInt16 = DataType::kInt16;
inline constexpr DataType kInt32 = DataType::kInt32;
inline constexpr DataType kInt64 = DataType::kInt64;

// floating point
inline constexpr DataType kFloat4_e2m1 = DataType::kFloat4_e2m1;
inline constexpr DataType kFloat8_e4m3 = DataType::kFloat8_e4m3;
inline constexpr DataType kFloat8_e5m2 = DataType::kFloat8_e5m2;
inline constexpr DataType kFloat16     = DataType::kFloat16;
inline constexpr DataType kBfloat16    = DataType::kBfloat16;
inline constexpr DataType kFloat32     = DataType::kFloat32;
inline constexpr DataType kFloat64     = DataType::kFloat64;

// aliases
inline constexpr DataType kUint   = DataType::kUint;
inline constexpr DataType kInt    = DataType::kInt;
inline constexpr DataType kHalf   = DataType::kHalf;
inline constexpr DataType kFloat  = DataType::kFloat;
inline constexpr DataType kDouble = DataType::kDouble;

// Compile-time mapping C++ type -> DataType (`to_data_type<T>::value`).
template 
struct to_data_type;

// Compile-time mapping DataType -> C++ type (`from_data_type<D>::type`).
template 
struct from_data_type;

// CVT_DATA_TYPE(D, T) specializes both mappings for one (enumerator, type) pair.
#define CVT_DATA_TYPE(D, T) \
    template <> struct to_data_type { static constexpr auto value = DataType::D; }; \
    template <> struct from_data_type { using type = T; }

CVT_DATA_TYPE(kNull, void);

CVT_DATA_TYPE(kBool, bool);
CVT_DATA_TYPE( kUint8, uint8_t);
CVT_DATA_TYPE(kUint16, uint16_t);
CVT_DATA_TYPE(kUint32, uint32_t);
CVT_DATA_TYPE(kUint64, uint64_t);

CVT_DATA_TYPE( kInt8, int8_t);  // NOTE: `int8_t` is `signed char` and is different from `char`
CVT_DATA_TYPE(kInt16, int16_t);
CVT_DATA_TYPE(kInt32, int32_t);
CVT_DATA_TYPE(kInt64, int64_t);

CVT_DATA_TYPE(kFloat16, half_t);
CVT_DATA_TYPE(kFloat32, float);
CVT_DATA_TYPE(kFloat64, double);
CVT_DATA_TYPE(kBfloat16, bfloat16_t);
CVT_DATA_TYPE(kFloat4_e2m1, fp4_e2m1_t);
CVT_DATA_TYPE(kFloat8_e4m3, fp8_e4m3_t);
CVT_DATA_TYPE(kFloat8_e5m2, fp8_e5m2_t);

// Packed sub-byte integers map to the empty tag types.
CVT_DATA_TYPE(kUint2, uint2_t);
CVT_DATA_TYPE(kUint4, uint4_t);
CVT_DATA_TYPE(kUint6, uint6_t);

// NOTE(review): kFloat6_e2m3 / kFloat6_e3m2 have no mapping here.

#undef CVT_DATA_TYPE

// Pointer types all map to kPointer; kPointer maps back to `void*`.
template  struct to_data_type { static constexpr auto value = DataType::kPointer; };
template <>  struct from_data_type { using type = void*; };

// DataType of a C++ type (cv/ref-qualifiers presumably stripped by the elided trait — TODO confirm).
template 
inline constexpr auto data_type_v = to_data_type>::value;

// C++ type of a DataType enumerator.
template 
using data_type_t = typename from_data_type::type;

/// Bytes occupied by `size` elements of `type`. Packed sub-byte types use truncating
/// division; kNull and unlisted types yield 0.
constexpr std::ptrdiff_t byte_size(DataType type, std::ptrdiff_t size = 1) {
    switch (type) {
        // 1 byte per element
        case kBool: case kUint8: case kInt8: case kFloat8_e4m3: case kFloat8_e5m2:
            return size;
        // 2 bytes per element
        case kUint16: case kInt16: case kFloat16: case kBfloat16:
            return size * 2;
        // 4 bytes per element
        case kUint32: case kInt32: case kFloat32:
            return size * 4;
        // 8 bytes per element
        case kUint64: case kInt64: case kFloat64:
            return size * 8;
        // packed sub-byte elements
        case kUint2:
            return size * 2 / 8;
        case kUint4: case kFloat4_e2m1:
            return size * 4 / 8;
        case kUint6:
            return size * 6 / 8;
        case kPointer:
            return size * sizeof(void*);
        case kNull:
        default:
            return 0;
    }
}

// Typed convenience overload: byte size of `size` elements of C++ type T.
template 
constexpr std::ptrdiff_t byte_size(std::ptrdiff_t size = 1) { return byte_size(data_type_v, size); }

/// Inverse of byte_size: the number of `type` elements held by `size` bytes
/// (truncating division). kNull and unlisted types yield 0.
constexpr std::ptrdiff_t numel(DataType type, std::ptrdiff_t size = 1) {
    switch (type) {
        // 1 byte per element
        case kBool: case kUint8: case kInt8: case kFloat8_e4m3: case kFloat8_e5m2:
            return size;
        // 2 bytes per element
        case kUint16: case kInt16: case kFloat16: case kBfloat16:
            return size / 2;
        // 4 bytes per element
        case kUint32: case kInt32: case kFloat32:
            return size / 4;
        // 8 bytes per element
        case kUint64: case kInt64: case kFloat64:
            return size / 8;
        // packed sub-byte elements
        case kUint2:
            return size * 8 / 2;
        case kUint4: case kFloat4_e2m1:
            return size * 8 / 4;
        case kUint6:
            return size * 8 / 6;
        case kPointer:
            return size / sizeof(void*);
        case kNull:
        default:
            return 0;
    }
}

// Typed convenience overload: number of T elements in `size` bytes.
template 
constexpr std::ptrdiff_t numel(std::ptrdiff_t size) { return numel(data_type_v, size); }

/// Short name of a data type (e.g. "f16", "bf16", "u4"); "unknown" for unnamed values.
constexpr const char* to_string(DataType type) {
    switch (type) {
        case kNull: return "nil";
        case kBool: return "bool";
        case kUint8: return "u8";
        case kUint16: return "u16";
        case kUint32: return "u32";
        case kUint64: return "u64";
        case kInt8: return "i8";
        case kInt16: return "i16";
        case kInt32: return "i32";
        case kInt64: return "i64";
        case kFloat16: return "f16";
        case kFloat32: return "f32";
        case kFloat64: return "f64";
        case kBfloat16: return "bf16";
        case kFloat8_e4m3: return "e4m3";
        case kFloat8_e5m2: return "e5m2";
        case kFloat4_e2m1: return "e2m1";
        case DataType::kFloat6_e2m3: return "e2m3";
        case DataType::kFloat6_e3m2: return "e3m2";
        case kUint2: return "u2";
        case kUint4: return "u4";
        case kUint6: return "u6";  // previously "u8", indistinguishable from kUint8
        case kPointer: return "pointer";
        default:
            return "unknown";
    }
}

// Stream a DataType using its short name from to_string().
inline std::ostream& operator<<(std::ostream& os, DataType type) {
    return os << to_string(type);
}

/// TODO: mapping with DLPack

// clang-format on

// Preprocessor helpers for the dtype-dispatch macros below.

// TM_PP_NARGS(...) expands to the argument count (1 through 8; an empty list also yields 1).
#define TM_PP_NARGS(...) TM_PP_NARGS_IMPL(__VA_ARGS__, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#define TM_PP_NARGS_IMPL(_0, _1, _2, _3, _4, _5, _6, _7, N, ...) N

#define TM_PP_CAT(a, b) a##b
#define TM_PP_STR(x) #x

// TM_PP_DISPATCH_N(prefix, args...) names `prefix<N>` where N = number of args. The
// extra _IMPL level makes TM_PP_NARGS expand before token pasting.
#define TM_PP_DISPATCH_N(macro, ...) TM_PP_DISPATCH_N_IMPL(macro, TM_PP_NARGS(__VA_ARGS__))
#define TM_PP_DISPATCH_N_IMPL(macro, x) TM_PP_CAT(macro, x)

// TM_PP_INVOKE_<N>(macro, f, t0..tN-1) expands to `macro(f, t0); ...; macro(f, tN-1)`
// with no trailing semicolon. Only N = 1..5 are defined.
#define TM_PP_INVOKE_1(macro, f, _0) macro(f, _0)

#define TM_PP_INVOKE_2(macro, f, _0, _1)                                                                               \
    macro(f, _0);                                                                                                      \
    macro(f, _1)

#define TM_PP_INVOKE_3(macro, f, _0, _1, _2)                                                                           \
    macro(f, _0);                                                                                                      \
    macro(f, _1);                                                                                                      \
    macro(f, _2)

#define TM_PP_INVOKE_4(macro, f, _0, _1, _2, _3)                                                                       \
    macro(f, _0);                                                                                                      \
    macro(f, _1);                                                                                                      \
    macro(f, _2);                                                                                                      \
    macro(f, _3)

#define TM_PP_INVOKE_5(macro, f, _0, _1, _2, _3, _4)                                                                   \
    macro(f, _0);                                                                                                      \
    macro(f, _1);                                                                                                      \
    macro(f, _2);                                                                                                      \
    macro(f, _3);                                                                                                      \
    macro(f, _4)

// One `case` per listed C++ type `t`: call `f(t{})` (returning its result in the _RET form).
#define TM_DISPATCH_DTYPE_RET_CASE(f, t)                                                                               \
    case ::turbomind::data_type_v:                                                                                  \
        return f(t{});

#define TM_DISPATCH_DTYPE_CASE(f, t)                                                                                   \
    case ::turbomind::data_type_v:                                                                                  \
        f(t{});                                                                                                        \
        break

// TM_DISPATCH_DTYPES[_RET](var, f, T...): switch on DataType `var` and invoke `f(T{})` for the
// matching listed type (at most 5 types). Any other value fails TM_CHECK; the _RET form then
// executes `return {};`.
// clang-format off
#define TM_DISPATCH_DTYPES_RET(var, f, ...)                                                                            \
    switch (var) {                                                                                                     \
        TM_PP_DISPATCH_N(TM_PP_INVOKE_, __VA_ARGS__)(TM_DISPATCH_DTYPE_RET_CASE, f, __VA_ARGS__);                      \
        default:                                                                                                       \
            TM_CHECK(0) << "unsupported type: "  << to_string(var);                                                    \
            return {};                                                                                                 \
    }

#define TM_DISPATCH_DTYPES(var, f, ...)                                                                                \
    switch (var) {                                                                                                     \
        TM_PP_DISPATCH_N(TM_PP_INVOKE_, __VA_ARGS__)(TM_DISPATCH_DTYPE_CASE, f, __VA_ARGS__);                          \
        default:                                                                                                       \
            TM_CHECK(0) << "unsupported type: "  << to_string(var);                                                    \
    }
// clang-format on

// The "primary" compute dtypes: half always, bfloat16 when ENABLE_BF16, float when ENABLE_FP32.
#define TM_PRIMARY_DTYPES_0 ::turbomind::half_t

#if ENABLE_BF16
#define TM_PRIMARY_DTYPES_1 TM_PRIMARY_DTYPES_0, ::turbomind::bfloat16_t
#else
#define TM_PRIMARY_DTYPES_1 TM_PRIMARY_DTYPES_0
#endif

#if ENABLE_FP32
#define TM_PRIMARY_DTYPES TM_PRIMARY_DTYPES_1, float
#else
#define TM_PRIMARY_DTYPES TM_PRIMARY_DTYPES_1
#endif

// Dispatch `func` over the primary dtypes selected above.
#define TM_DISPATCH_PRIMARY_DTYPES(var, func) TM_DISPATCH_DTYPES(var, func, TM_PRIMARY_DTYPES)

#define TM_DISPATCH_PRIMARY_DTYPES_RET(var, func) TM_DISPATCH_DTYPES_RET(var, func, TM_PRIMARY_DTYPES)

}  // namespace turbomind


================================================
FILE: src/turbomind/core/interval.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

namespace turbomind {

/// Half-open integer interval [first, last).
///
/// `Interval(first)` is open-ended, with its end at INT_MAX. Arithmetic that moves
/// an end point (construction from a Size, dilate/erode, size()) saturates to
/// [INT_MIN, INT_MAX] rather than overflowing. Signed overflow is undefined
/// behavior; in practice it wrapped an open-ended interval into an empty one.
class Interval {
public:
    /// Strongly-typed length; distinguishes `Interval(first, Size)` from `Interval(first, last)`.
    struct Size {
        int      x;
        explicit operator int() const noexcept
        {
            return x;
        }
        friend bool operator<(const Size& a, const Size& b)
        {
            return a.x < b.x;
        }
    };

    Interval(): first_{0}, last_{0} {}

    explicit Interval(int first): first_{first}, last_{INT_MAX} {}

    Interval(int first, int last): first_{first}, last_{last} {}

    Interval(int first, Size size): first_{first}, last_{saturate((long long)first + (int)size)} {}

    bool empty() const noexcept
    {
        return first_ >= last_;
    }

    explicit operator bool() const noexcept
    {
        return !empty();
    }

    /// Number of integers in the interval (0 when empty, INT_MAX at most).
    Size size() const noexcept
    {
        return Size{saturate(std::max(0LL, (long long)last_ - first_))};
    }

    int begin() const noexcept
    {
        return first_;
    }

    int end() const noexcept
    {
        return last_;
    }

    /// Intersection.
    friend Interval operator&(const Interval& a, const Interval& b)
    {
        return {std::max(a.first_, b.first_), std::min(a.last_, b.last_)};
    }

    /// Smallest interval covering both (hull).
    friend Interval operator|(const Interval& a, const Interval& b)
    {
        return {std::min(a.first_, b.first_), std::max(a.last_, b.last_)};
    }

    // dilate (x > 0) / erode (x < 0) left
    friend Interval operator|(int x, const Interval& a)
    {
        return {saturate((long long)a.begin() - x), a.end()};
    }

    // dilate (x > 0) / erode (x < 0) right
    friend Interval operator|(const Interval& a, int x)
    {
        return {a.begin(), saturate((long long)a.end() + x)};
    }

    friend std::ostream& operator<<(std::ostream& os, const Interval& a)
    {
        return os << "[" << a.first_ << ", " << a.last_ << ")";
    }

    friend std::ostream& operator<<(std::ostream& os, const Interval* a)
    {
        return os << *a;
    }

private:
    // Clamp a widened result back into the range of int.
    static int saturate(long long v) noexcept
    {
        return (int)std::min<long long>(std::max<long long>(v, INT_MIN), INT_MAX);
    }

    int first_;
    int last_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/core/layout.cc
================================================

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/layout.h"

namespace turbomind::core {

// Row-major layout: the last dimension has unit stride and each earlier stride is
// the product of all later extents. size_ is the product of all extents.
Layout::Layout(std::vector<ssize_t> shape): shape_{std::move(shape)}
{
    TM_CHECK(shape_.size());
    const int n = shape_.size();
    stride_.assign(n, 0);
    ssize_t running = 1;
    for (int d = n; d-- > 0;) {
        stride_[d] = running;
        running *= shape_[d];
    }
    size_ = running;
}

// Layout with explicit strides. Requires a non-empty shape, one stride per
// dimension and a non-negative element count.
Layout::Layout(vector<ssize_t> shape, vector<ssize_t> stride): shape_{std::move(shape)}, stride_{std::move(stride)}
{
    TM_CHECK(shape_.size());
    TM_CHECK_EQ(shape_.size(), stride_.size());

    ssize_t total = 1;
    for (const auto extent : shape_) {
        total = total * extent;
    }
    size_ = total;

    TM_CHECK_GE(size_, 0);
}

/// One past the largest offset reachable from the origin with non-negative strides:
/// 1 + sum((shape[i] - 1) * stride[i]).
///
/// A layout with no elements (rank 0, or any zero extent) has cosize 0. Without
/// that guard, the `shape[i] - 1 == -1` terms could produce a negative value
/// (e.g. shape (0, 5), stride (10, 1) gave -5).
ssize_t Layout::cosize() const noexcept
{
    if (rank() == 0 || size_ == 0) {
        return 0;
    }
    ssize_t value{1};
    for (size_t i = 0; i < shape_.size(); ++i) {
        value += (shape_[i] - 1) * stride_[i];
    }
    return value;
}

/// Merge adjacent dimensions that can be addressed as one. Extent-1 dimensions are
/// dropped, and (outer, inner) pairs with outer_stride == inner_extent * inner_stride
/// are fused. A rank-0 layout is returned unchanged; previously this called front()
/// on an empty vector (UB) inside this noexcept function.
Layout Layout::coalesce() const noexcept
{
    if (shape_.empty()) {
        return *this;
    }

    vector<ssize_t> shape{shape_.front()};
    vector<ssize_t> stride{stride_.front()};

    for (size_t i = 1; i < shape_.size(); ++i) {
        if (shape_[i] == 1) {
            continue;  // contributes nothing to addressing
        }
        else if (shape.back() == 1) {
            // replace a leading extent-1 dimension outright
            shape.back()  = shape_[i];
            stride.back() = stride_[i];
        }
        else if (stride.back() == shape_[i] * stride_[i]) {
            // outer dimension steps exactly over the whole inner one: fuse them
            stride.back() = stride_[i];
            shape.back() *= shape_[i];
        }
        else {
            shape.push_back(shape_[i]);
            stride.push_back(stride_[i]);
        }
    }

    return Layout{shape, stride};
}

/// Reinterpret the same elements with a new `shape`. At most one extent may be -1;
/// it is inferred from the element count.
///
/// Contiguous (or empty) layouts get fresh row-major strides. Otherwise strides are
/// derived from the coalesced layout, and splitting across non-contiguous
/// dimensions fails TM_CHECK.
Layout Layout::view(vector<ssize_t> shape) const
{
    if (shape == shape_) {
        return *this;
    }

    TM_CHECK(!shape.empty());

    // size check & wildcard resolution
    auto wildcard = std::find(shape.begin(), shape.end(), -1);
    if (wildcard != shape.end()) {
        TM_CHECK(std::find(wildcard + 1, shape.end(), -1) == shape.end());
        *wildcard = 1;
    }
    auto new_size = std::accumulate(shape.begin(), shape.end(), ssize_t{1}, std::multiplies<>{});
    if (wildcard != shape.end()) {
        // A zero extent elsewhere leaves the wildcard undetermined; `size_ % 0` would be UB.
        TM_CHECK(new_size != 0) << "cannot infer -1 dimension when another dimension is 0";
        TM_CHECK(size_ % new_size == 0) << size_ << " % " << new_size;
        *wildcard = size_ / new_size;
    }
    else {
        TM_CHECK_EQ(size_, new_size);
    }

    // An empty view addresses nothing, so row-major strides are as valid as any.
    // This also keeps zero extents out of the `s % shape[i]` below.
    if (size_ == 0 || is_contiguous()) {
        return Layout{shape};
    }

    const Layout c = coalesce();  // merge contiguous dimensions

    ssize_t p = c.rank();
    ssize_t s = 1;  // extent still available in the current coalesced dimension
    ssize_t d = 0;  // stride of the next (outer) piece of that dimension

    vector<ssize_t> stride(shape.size());

    for (int i = shape.size() - 1; i >= 0; --i) {
        if (shape[i] == 1) {
            stride[i] = 0;
        }
        else {
            if (s == 1) {
                --p;
                s = c.shape().at(p);
                d = c.stride().at(p);
            }
            TM_CHECK_EQ(s % shape[i], 0);  // crossing non-contiguous dimensions
            stride[i] = d;
            d *= shape[i];
            s /= shape[i];
        }
    }
    return Layout{std::move(shape), std::move(stride)};
}

std::pair Layout::slice(const vector& base, vector shape) const
{
    TM_CHECK_EQ(base.size(), shape.size());
    TM_CHECK_EQ(shape_.size(), shape.size());
    ssize_t offset = 0;
    for (size_t i = 0; i < shape.size(); ++i) {
        const auto space = shape_[i] - base[i];
        TM_CHECK_GE(space, 0);
        if (shape[i] == -1) {
            shape[i] = space;
        }
        TM_CHECK_LE(shape[i], space);
        offset += base[i] * stride_[i];
    }
    return {Layout{std::move(shape), stride_}, offset};
}

std::ostream& operator<<(std::ostream& os, const Layout& x)
{
    os << "(";
    for (int i = 0; i < x.rank(); ++i) {
        os << (i ? "," : "") << x.shape_[i];
    }
    os << "):(";
    for (int i = 0; i < x.rank(); ++i) {
        os << (i ? "," : "") << x.stride_[i];
    }
    os << ")";
    return os;
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/layout.h
================================================

#pragma once

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/common.h"

namespace turbomind::core {

class Layout {
public:
    Layout(): size_{0} {}

    /* implicit */ Layout(vector shape);

    /* implicit */ Layout(std::initializer_list shape): Layout(vector(shape)) {}

    Layout(vector shape, vector stride);

    ssize_t size() const noexcept
    {
        return size_;
    }

    ssize_t cosize() const noexcept;

    ssize_t rank() const noexcept
    {
        return shape_.size();
    }

    auto& shape() const noexcept
    {
        return shape_;
    }

    auto shape(int i) const
    {
        return shape_.at(wrap(i));
    }

    template
    auto shapes(Is... is) const
    {
        return std::make_tuple(shape(is)...);
    }

    auto& stride() const noexcept
    {
        return stride_;
    }

    auto stride(int i) const
    {
        return stride_.at(wrap(i));
    }

    template
    auto strides(Is... is) const
    {
        return std::make_tuple(stride(is)...);
    }

    bool is_contiguous() const noexcept
    {
        if (stride_.back() != 1) {
            return false;
        }
        if (size() != cosize()) {
            return false;
        }
        for (int i = 0; i < rank() - 1; ++i) {
            // TODO: skip when shape == 1
            if (stride_[i] < stride_[i + 1]) {
                return false;
            }
        }
        return true;
    }

    Layout permute(const vector& dims) const
    {
        TM_CHECK((int)dims.size() == rank());
        auto a = *this;
        for (int i = 0; i < rank(); ++i) {
            a.shape_[i]  = shape_[dims[i]];
            a.stride_[i] = stride_[dims[i]];
        }
        return a;
    }

    Layout transpose(int a, int b) const
    {
        TM_CHECK_LT(a, rank());
        TM_CHECK_LT(b, rank());
        auto x = *this;
        std::swap(x.shape_[a], x.shape_[b]);
        std::swap(x.stride_[a], x.stride_[b]);
        return x;
    }

    ssize_t offset(const vector& idxs) const
    {
        TM_CHECK((int)idxs.size() < rank());
        ssize_t val = 0;
        for (size_t i = 0; i < idxs.size(); ++i) {
            TM_CHECK_LT(idxs[i], shape_[i]);
            val += idxs[i] * stride_[i];
        }
        return val;
    }

    ssize_t offset(ssize_t idx0) const
    {
        TM_CHECK(rank());
        TM_CHECK_LT(idx0, shape_[0]);
        return stride_[0] * idx0;
    }

    Layout coalesce() const noexcept;

    Layout view(vector shape) const;

    std::pair slice(const vector& base, vector shape) const;

    Layout squeeze(int dim) const
    {
        if (rank() == 1 || shape(dim) != 1) {
            return *this;
        }
        Layout a;
        a.shape_.reserve(rank() - 1);
        a.stride_.reserve(rank() - 1);
        for (int i = 0; i < rank(); ++i) {
            if (i != dim) {
                a.shape_.push_back(shape_[i]);
                a.stride_.push_back(stride_[i]);
            }
        }
        a.size_ = size_;
        return a;
    }

    friend std::ostream& operator<<(std::ostream& os, const Layout& x);

    friend bool operator==(const Layout& a, const Layout& b)
    {
        return a.shape_ == b.shape_ && a.stride_ == b.stride_;
    }

    friend bool operator!=(const Layout& a, const Layout& b)
    {
        return !(a == b);
    }

private:
    int wrap(int dim) const noexcept
    {
        return dim < 0 ? dim + shape_.size() : dim;
    }

private:
    vector shape_;
    vector stride_;
    ssize_t         size_;
};

// Render a Layout with its stream operator, e.g. "(2,3):(3,1)".
inline std::string to_string(const Layout& x)
{
    std::ostringstream out;
    out << x;
    return out.str();
}

// clang-format off
// Serialize a Layout as its shape vector followed by its stride vector.
template
void save(Archive& ar, const Layout& layout)
{
    ar & layout.shape();
    ar & layout.stride();
}

// Deserialize a Layout written by save(); rebuilding it through the (shape, stride)
// constructor re-runs that constructor's checks.
template
void load(Archive& ar, Layout& layout)
{
    vector shape;
    vector stride;
    ar & shape;
    ar & stride;
    layout = Layout(std::move(shape), std::move(stride));
}
// clang-format on

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/module.cc
================================================

#include "src/turbomind/core/module.h"
#include "src/turbomind/core/check.h"
#include 

namespace turbomind::core {

// A new module starts detached (no parent).
Module::Module(): parent_{nullptr} {}

// Unlink this module from both directions of the module tree.
//
// Children that are still registered here (i.e. they outlive this module) get
// their parent pointer cleared. Otherwise their own destructor would later call
// remove_module() on this already-destroyed object. Children that are members
// of a derived class have already deregistered themselves by the time this base
// destructor runs.
Module::~Module()
{
    for (auto& entry : modules_) {
        entry.second->parent_ = nullptr;
    }
    if (parent_) {
        parent_->remove_module(*this);
        parent_ = {};
    }
}

// Attach `module` as a child under `name`, suffixed with ".<index>" when `index` is
// given (e.g. "layers.3"), and make this module its parent. Only a non-owning
// pointer is stored.
// NOTE(review): a module already registered under another parent is not removed
// from that parent first — confirm callers never re-register.
void Module::register_module(std::string name, Module& module, std::optional index)
{
    module.parent_ = this;
    if (index) {
        name += ".";
        name += std::to_string(*index);
    }
    // std::cout << "register Module " << name << " " << &module << ", parent " << this << "\n";
    modules_.emplace_back(std::move(name), &module);
}

// Register `param` as a direct parameter of this module under `name`.
// Only the tensor's address is stored, so `param` must stay alive while it is
// registered (or be removed first with remove_parameter()).
void Module::register_parameter(std::string name, Tensor& param)
{
    params_.emplace_back(std::move(name), &param);
}

// Unregister the first child entry that refers to `module`.
// Fails through TM_CHECK if `module` is not a registered child.
void Module::remove_module(Module& module)
{
    for (size_t i = 0; i < modules_.size(); ++i) {
        if (modules_[i].second != &module) {
            continue;
        }
        modules_.erase(modules_.begin() + i);
        return;
    }
    TM_CHECK(0) << "module " << &module << " not found";
}

// Unregister the first parameter entry that refers to `param`.
// Fails through TM_CHECK if `param` is not registered on this module.
void Module::remove_parameter(Tensor& param)
{
    for (auto it = params_.begin(); it != params_.end(); ++it) {
        if (it->second == &param) {
            params_.erase(it);
            return;
        }
    }
    TM_CHECK(0) << "param " << &param << " not found";
}

// Return every parameter of this module and of its registered descendants,
// keyed by dot-separated path (a child's parameters are prefixed with
// "<child name>.").
std::unordered_map Module::get_parameters() const
{
    std::unordered_map m;
    get_parameters_impl({}, m);
    return m;
}

// Recursive worker for get_parameters().
// Adds this module's parameters to `m`, keyed "<prefix>.<name>" (or just
// "<name>" at the root). Then recurses into each child, using
// "<prefix>.<child name>" as the child's prefix.
// NOTE(review): `emplace` keeps the first entry when a key repeats, so name
// collisions are silently ignored rather than reported.
void Module::get_parameters_impl(std::string prefix, std::unordered_map& m) const
{
    if (!prefix.empty()) {
        prefix += ".";
    }
    for (const auto& [k, v] : params_) {
        m.emplace(prefix + k, v);
    }
    for (const auto& [k, v] : modules_) {
        v->get_parameters_impl(prefix + k, m);
    }
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/module.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TURBOMIND_CORE_MODULE_H
#define TURBOMIND_CORE_MODULE_H

#include "src/turbomind/core/tensor.h"

namespace turbomind::core {

// Base class for a tree of named sub-modules and parameter tensors.
// Children and parameters are registered by raw pointer, so they must stay
// alive (or be removed) while registered.
class Module {
public:
    // Unregisters this module from its parent, if it has one.
    virtual ~Module();

    Module();

    // Copying and moving are disabled: parents and children refer to each
    // other by address, which a copy or move would leave dangling.
    Module(const Module&) = delete;
    Module& operator=(const Module&) = delete;

    Module(Module&&) noexcept = delete;
    Module& operator=(Module&&) noexcept = delete;

    // Add a child module, optionally suffixing its name with ".<index>".
    void register_module(std::string name, Module& module, std::optional index = {});
    // Add a parameter tensor that is owned elsewhere.
    void register_parameter(std::string name, Tensor& param);

    // Remove a previously registered child / parameter (checked: must exist).
    void remove_module(Module& module);
    void remove_parameter(Tensor& param);

    // All parameters of this module and its descendants, keyed by dotted path.
    std::unordered_map get_parameters() const;

private:
    void get_parameters_impl(std::string prefix, std::unordered_map& m) const;

protected:
    // Module this one is registered under; null while unregistered.
    Module* parent_;

    // (name, child) and (name, parameter) entries in registration order.
    std::vector> modules_;
    std::vector> params_;
};

}  // namespace turbomind::core

#endif  // TURBOMIND_CORE_MODULE_H


================================================
FILE: src/turbomind/core/ranges.h
================================================
#pragma once

namespace turbomind::core {

// Lightweight, non-owning view over the half-open iterator range
// [first, last), usable in range-for loops. All accessors are const, so a
// `const subrange` can be iterated too. `size()` requires `last - first` to be
// well-formed (random-access iterators).
template<class Iterator>
class subrange {
public:
    subrange(Iterator first, Iterator last): first_{first}, last_{last} {}

    Iterator begin() const
    {
        return first_;
    }

    Iterator end() const
    {
        return last_;
    }

    auto empty() const
    {
        return first_ == last_;
    }

    auto size() const
    {
        return last_ - first_;
    }

private:
    Iterator first_;
    Iterator last_;
};

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/serdes.h
================================================
#pragma once

#include 
#include 
#include 
#include 
#include 

namespace turbomind::core {

// Detection idiom. The primary template is false_type. The partial
// specialization below is true_type and only matches when F<Args...> is a
// well-formed type (presumably through std::void_t in the stripped argument
// list).
template typename F, class SFINAE, class... Args>
struct is_detected: std::false_type {
};

template typename F, class... Args>
struct is_detected>, Args...>: std::true_type {
};

// `has_save_v`, `has_load_v` and `has_serdes_v` report whether an unqualified
// call `save(ar, x)`, `load(ar, x)` or `serdes(ar, x)` is well-formed for the
// given archive/value types. The archives below use them to choose a handler.
template
using save_t = decltype(save(std::declval(), std::declval()));

template
inline constexpr bool has_save_v = is_detected::value;

template
using load_t = decltype(load(std::declval(), std::declval()));

template
inline constexpr bool has_load_v = is_detected::value;

template
using serdes_t = decltype(serdes(std::declval(), std::declval()));

template
inline constexpr bool has_serdes_v = is_detected::value;

// Non-owning (pointer, element count) view over a contiguous array. Archives
// transfer it as a single raw block of `sizeof(T) * count()` bytes. The element
// type must be trivially copyable (enforced in the constructor).
template
class ArrayWrapper {
public:
    ArrayWrapper(T* t, std::size_t size): t_{t}, size_{size}
    {
        static_assert(std::is_trivially_copyable_v, "ArrayWrapper requires trivially copyable type");
    }

    T* data() const
    {
        return t_;
    }

    // Number of elements (not bytes).
    std::size_t count() const
    {
        return size_;
    }

    T* const          t_;
    const std::size_t size_;
};

// Type trait: true for ArrayWrapper specializations, false for anything else.
template
inline constexpr bool is_array_wrapper_v = std::false_type{};

template
inline constexpr bool is_array_wrapper_v> = std::true_type{};

// Base for output archives. `ar & x` dispatches, in order of preference, to:
//   1. a free `save(ar, x)`,
//   2. a free `serdes(ar, x)`,
//   3. the derived archive's `write(x)`.
// The derived type is presumably the template argument (CRTP-style); see
// BinarySizeArchive / BinaryOutputArchive.
template
struct OutputArchive {
    static constexpr bool is_loading = false;

    template
    void operator&(T&& x)
    {
        if constexpr (has_save_v) {
            save(*this, (T &&) x);
        }
        else if constexpr (has_serdes_v) {
            serdes(*this, (T &&) x);
        }
        else {
            reinterpret_cast(this)->write((T &&) x);
        }
    }
};

// Base for input archives. `ar & x` dispatches, in order of preference, to:
//   1. a free `load(ar, x)`,
//   2. a free `serdes(ar, x)`,
//   3. the derived archive's `read(x)`.
// The derived type is presumably the template argument (CRTP-style); see
// BinaryInputArchive.
template
struct InputArchive {
    static constexpr bool is_loading = true;

    template
    void operator&(T&& x)
    {
        if constexpr (has_load_v) {
            load(*this, (T &&) x);
        }
        else if constexpr (has_serdes_v) {
            serdes(*this, (T &&) x);
        }
        else {
            reinterpret_cast(this)->read((T &&) x);
        }
    }
};

// Output archive that stores nothing and only sums the number of bytes the
// same sequence of writes would take in a BinaryOutputArchive. It is
// presumably used to size that archive's buffer up front.
struct BinarySizeArchive: OutputArchive {
    size_t size_{};

    // Total bytes accumulated so far.
    size_t size()
    {
        return size_;
    }

    template
    void write(const T& x)
    {
        static_assert(std::is_trivially_copyable_v);
        size_ += sizeof(x);
    }

    template
    void write(const ArrayWrapper& arr)
    {
        static_assert(std::is_trivially_copyable_v);
        size_ += sizeof(T) * arr.count();
    }
};

// Output archive that copies the raw bytes of trivially copyable values into
// the caller-provided buffer `external_`, starting at offset `ptr_` and
// advancing it. Every write is bounds-checked against the buffer size with
// TM_CHECK_LE.
struct BinaryOutputArchive: OutputArchive {

    ArrayWrapper external_;
    size_t                  ptr_;

    BinaryOutputArchive(ArrayWrapper external): external_{external}, ptr_{} {}

    template
    void write(const T& x)
    {
        static_assert(std::is_trivially_copyable_v);
        auto data = (const std::byte*)&x;
        TM_CHECK_LE(ptr_ + sizeof(T), external_.count());
        std::copy_n(data, sizeof(T), external_.data() + ptr_);
        ptr_ += sizeof(T);
    }

    template
    void write(const ArrayWrapper& arr)
    {
        static_assert(std::is_trivially_copyable_v);
        auto data = (const std::byte*)arr.data();
        TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count());
        std::copy_n(data, sizeof(T) * arr.count(), external_.data() + ptr_);
        ptr_ += sizeof(T) * arr.count();
    }
};

// Input archive that reads raw bytes back out of the caller-provided buffer
// `external_`, starting at offset `ptr_` and advancing it. Every read is
// bounds-checked against the buffer size with TM_CHECK_LE.
struct BinaryInputArchive: InputArchive {

    ArrayWrapper external_;
    size_t                  ptr_;

    BinaryInputArchive(ArrayWrapper external): external_{external}, ptr_{} {}

    template
    void read(T& x)
    {
        static_assert(std::is_trivially_copyable_v);
        TM_CHECK_LE(ptr_ + sizeof(T), external_.count());
        std::copy_n(external_.data() + ptr_, sizeof(T), (std::byte*)&x);
        ptr_ += sizeof(T);
    }

    // Takes the wrapper by rvalue reference because callers pass temporary
    // ArrayWrapper objects (e.g. `ar & ArrayWrapper(...)`).
    template
    void read(ArrayWrapper&& arr)
    {
        static_assert(std::is_trivially_copyable_v);
        TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count());
        std::copy_n(external_.data() + ptr_, sizeof(T) * arr.count(), (std::byte*)arr.data());
        ptr_ += sizeof(T) * arr.count();
    }
};

// Serialize a std::vector as its element count followed by its elements.
// Trivially copyable elements go out as one raw block; others one by one.
template
void save(Archive& ar, const std::vector& xs)
{
    // clang-format off
    ar & xs.size();
    if constexpr (std::is_trivially_copyable_v) {
        ar & ArrayWrapper(xs.data(), xs.size());
    }
    else {
        for (const auto& x : xs) {
            ar & x;
        }
    }
    // clang-format on
}

// Deserialize a std::vector written by save(): read the count, resize `xs`,
// then fill it either as one raw block (trivially copyable elements) or
// element by element.
template
void load(Archive& ar, std::vector& xs)
{
    // clang-format off
    decltype(xs.size()) size;
    ar & size;
    xs.resize(size);

    if constexpr (std::is_trivially_copyable_v) {
        ar & ArrayWrapper(xs.data(), size);
    } else {
        for (size_t i = 0; i < size; ++i) {
            ar & xs[i];
        }
    }
    // clang-format on
}

// Serialize a std::string as its length followed by its raw characters.
template
void save(Archive& ar, const std::string& s)
{
    // clang-format off
    ar & s.size();
    ar & ArrayWrapper(s.data(), s.size());
    // clang-format on
}

// Deserialize a std::string written by save(): read the length, resize `s`,
// then read the characters directly into its storage.
template
void load(Archive& ar, std::string& s)
{
    // clang-format off
    decltype(s.size()) size;
    ar & size;
    s.resize(size);
    ar & ArrayWrapper(s.data(), size);
    // clang-format on
}

// Serialize a std::shared_ptr as a presence flag followed, when non-null, by
// the pointee.
template
void save(Archive& ar, const std::shared_ptr& p)
{
    // clang-format off
    ar & (bool)p;
    if (p) {
        ar & (*p);
    }
    // clang-format on
}

// Deserialize a std::shared_ptr written by save(): read the presence flag and,
// if it is set, allocate a fresh default-constructed T and read the pointee
// into it. If the flag is clear, `p` is reset. Otherwise an object already held
// by `p` would survive and look as if it had been deserialized.
template<class Archive, class T>
void load(Archive& ar, std::shared_ptr<T>& p)
{
    // clang-format off
    bool pred{};
    ar & pred;
    if (pred) {
        p = std::make_shared<T>();
        ar & (*p);
    }
    else {
        p.reset();
    }
    // clang-format on
}

// Symmetric (save/load) handler for std::array: N is fixed by the type, so no
// count is stored. Trivially copyable elements are transferred as one raw
// block, others element by element.
template
void serdes(Archive& ar, std::array& xs)
{
    // clang-format off
    if constexpr (std::is_trivially_copyable_v) {
        ar & ArrayWrapper(xs.data(), N);
    }
    else {
        for (size_t i = 0; i < N; ++i) {
            ar & xs[i];
        }
    }
    // clang-format on
}

// Symmetric (save/load) handler for std::tuple: applies `ar &` to each element
// in order.
template
void serdes(Archive& ar, std::tuple& tpl)
{
    std::apply([&](auto&... elems) { ((ar & elems), ...); }, tpl);
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/state.h
================================================

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/core.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/layout.h"
#include "src/turbomind/core/tensor.h"
#include 

namespace turbomind {

// Goals:
// 1. constant number of cudaMemcpy / kernel launches
// 2. single stream synchronization / iteration

struct State {

    Tensor data_[2];

    State() = default;

    State(const Layout& layout, DataType dtype, const core::Device& device)
    {
        data_[0] = {layout, dtype, device};
        data_[1] = {layout, dtype, device};
    }

    Tensor& front()
    {
        return data_[0];
    }

    Tensor& back()
    {
        return data_[1];
    }

    void Swap()
    {
        std::swap(data_[0], data_[1]);
    }
};

// Gather fixed-size rows: for every i with perm[i] < size0, copy row perm[i] of
// `a0` into row i of `b1`. Rows whose perm[i] >= size0 are left untouched.
// The row size in bytes comes from a0's outermost stride, and each transfer
// goes through `copy(src, bytes, dst)`.
template
void Warp(const Tensor& a0, int size0, const Buffer_& perm, Tensor b1, Copy& copy)
{
    auto a0_ptr = (const uint8_t*)a0.raw_data();
    auto b1_ptr = (uint8_t*)b1.raw_data();

    const auto vec_size = byte_size(a0.dtype(), a0.stride(0));

    for (int i = 0; i < perm.size(); ++i) {
        if (const int j = perm[i]; TM_LIKELY(j < size0)) {
            copy(a0_ptr + j * vec_size, vec_size, b1_ptr + i * vec_size);
        }
    }
}

// Gather fixed-size rows into `c1`, with a fallback:
//   c1[i] = perm[i] < size0 ? a0[perm[i]] : b1[i]
// The row size in bytes comes from a0's outermost stride, and each transfer
// goes through `copy(src, bytes, dst)`.
template
void Warp(const Tensor& a0, const Tensor& b1, int size0, const Buffer_& perm, Tensor c1, Copy& copy)
{
    auto a0_ptr = (const uint8_t*)a0.raw_data();
    auto b1_ptr = (const uint8_t*)b1.raw_data();
    auto c1_ptr = (uint8_t*)c1.raw_data();

    const auto vec_size = byte_size(a0.dtype(), a0.stride(0));

    for (int i = 0; i < perm.size(); ++i) {
        const uint8_t* src_ptr = TM_LIKELY(perm[i] < size0) ? a0_ptr + perm[i] * vec_size : b1_ptr + i * vec_size;
        copy(src_ptr, vec_size, c1_ptr + i * vec_size);
    }
}

// Gather variable-length segments and pack them back to back into `dst`.
// Output segment i is chosen as follows:
//   - if j = perm0[i] < size0: segment j of `src0`, [offset0[j], offset0[j + 1]);
//   - otherwise: segment i of `src1`, [offset1[i], offset1[i + 1]).
// `offsetd` receives the packed offsets: offsetd[0] = 0 and
// offsetd[i + 1] = offsetd[i] + length of segment i.
// Offsets and lengths count vectors of src0.stride(0) elements.
// NOTE(review): `offsetd` is taken by value, so its writes reach the caller
// only if Buffer_ copies share storage — confirm.
template
void Warp(const Tensor&       src0,
          const Buffer_& offset0,
          int                 size0,
          const Tensor&       src1,
          const Buffer_& offset1,
          const Buffer_& perm0,
          Tensor              dst,
          Buffer_        offsetd,
          Copy&               copy)
{
    auto p_src0 = (const uint8_t*)src0.raw_data();
    auto p_src1 = (const uint8_t*)src1.raw_data();

    const ssize_t vec_size = byte_size(src0.dtype(), src0.stride(0));

    auto p_dst = (uint8_t*)dst.raw_data();

    offsetd[0] = 0;

    for (int i = 0; i < perm0.size(); ++i) {
        const uint8_t* p_src;
        ssize_t        n;
        if (const int j = perm0[i]; TM_LIKELY(j < size0)) {
            p_src = p_src0 + offset0[j] * vec_size;
            n     = offset0[j + 1] - offset0[j];
        }
        else {
            p_src = p_src1 + offset1[i] * vec_size;
            n     = offset1[i + 1] - offset1[i];
        }
        offsetd[i + 1] = offsetd[i] + n;
        copy(p_src, n * vec_size, p_dst + offsetd[i] * vec_size);
    }
}

// d1[i] = a0[perm[i]]:b0[perm[i]] if perm[i] < size0 else c1[i]   (":" = concatenation)
// where `a0` has variable size (a0_size) with fixed stride
//       `b0` has fixed size (1)
//       `c1` has variable size, packed back to back (bounds from c1_offset)
//       `d1` has variable size (written to d1_size) with fixed stride
// Sizes count vectors of a0.stride(1) elements; rows of a0/d1 are a0.stride(0)
// elements apart. `copy(src, bytes, dst)` performs each transfer. Its return
// value is used as the destination of the b0 vector, i.e. presumably the end
// of the bytes just written.
// NOTE(review): nothing checks that a0_size[j] + 1 vectors, or a c1 segment,
// fit in one d1 row.
template
void Append(const Tensor&       a0,
            const Buffer_& a0_size,
            const Tensor&       b0,
            const Tensor&       c1,
            const Buffer_& c1_offset,
            const Buffer_& perm,
            int                 size0,
            Tensor              d1,
            Buffer_        d1_size,
            Copy&               copy)
{
    auto a0_ptr = (const uint8_t*)a0.raw_data();
    auto b0_ptr = (const uint8_t*)b0.raw_data();
    auto c1_ptr = (const uint8_t*)c1.raw_data();

    auto d1_ptr = (uint8_t*)d1.raw_data();

    TM_CHECK_EQ(a0.stride(0), d1.stride(0));

    const auto stride   = byte_size(a0.dtype(), a0.stride(0));
    const auto vec_size = byte_size(a0.dtype(), a0.stride(1));

    for (int i = 0; i < perm.size(); ++i) {
        if (const int j = perm[i]; TM_LIKELY(j < size0)) {
            uint8_t* out = copy(a0_ptr + j * stride, vec_size * a0_size[j], d1_ptr + i * stride);
            copy(b0_ptr + j * vec_size, vec_size, out);
            d1_size[i] = a0_size[j] + 1;
        }
        else {
            const auto n = c1_offset[i + 1] - c1_offset[i];
            copy(c1_ptr + c1_offset[i] * vec_size, n * vec_size, d1_ptr + i * stride);
            d1_size[i] = n;
        }
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/core/stream.cc
================================================

#include "src/turbomind/core/stream.h"
#include 

namespace turbomind::core {

// Create a Stream backed by a new StreamImpl, i.e. a fresh non-blocking CUDA
// stream with the given priority.
Stream Stream::create(int priority)
{
    Stream stream;
    stream.impl_ = std::make_shared(priority);
    return stream;
}

void StreamImpl::Wait(const Event& event)
{
    check_cuda_error(cudaStreamWaitEvent(stream_, event));
}

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/stream.h
================================================
#pragma once

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/common.h"

namespace turbomind::core {

// Owning wrapper around a non-blocking CUDA stream created with a given
// priority. The stream is destroyed with the object. The class is
// non-copyable: a copy would call cudaStreamDestroy on the same handle twice.
class StreamImpl {
public:
    StreamImpl(int priority): stream_{}
    {
        check_cuda_error(cudaStreamCreateWithPriority(&stream_, cudaStreamNonBlocking, priority));
    }

    StreamImpl(const StreamImpl&) = delete;
    StreamImpl& operator=(const StreamImpl&) = delete;

    ~StreamImpl()
    {
        // Log instead of check_cuda_error (which presumably throws) because
        // this runs in a destructor.
        if (auto ec = cudaStreamDestroy(stream_); ec != cudaSuccess) {
            TM_LOG_ERROR(cudaGetErrorString(ec));
        }
        stream_ = {};
    }

    // Block the host until all work queued on the stream has completed.
    void Sync()
    {
        check_cuda_error(cudaStreamSynchronize(stream_));
    }

    void Wait(const Event& event);

    cudaStream_t handle() const
    {
        return stream_;
    }

public:
    cudaStream_t stream_;
};

class Stream {
public:
    Stream() = default;

    static Stream create(int priority = 0);

    void Sync()
    {
        impl_->Sync();
    }

    void Wait(const Event& event)
    {
        impl_->Wait(event);
    }

    cudaStream_t handle() const
    {
        return TM_CHECK_NOTNULL(impl_)->handle();
    }

    explicit operator cudaStream_t() const
    {
        return handle();
    }

    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

    friend bool operator==(const Stream& a, const Stream& b)
    {
        return a.impl_ == b.impl_;
    }

    friend bool operator!=(const Stream& a, const Stream& b)
    {
        return !(a == b);
    }

    friend std::ostream& operator<<(std::ostream& os, const Stream& s)
    {
        os << s.impl_;
        return os;
    }

private:
    shared_ptr impl_;
};

class EventImpl {
public:
    explicit EventImpl(unsigned flags)
    {
        check_cuda_error(cudaEventCreateWithFlags(&event_, flags));
    }

    ~EventImpl()
    {
        if (auto ec = cudaEventDestroy(event_); ec != cudaSuccess) {
            TM_LOG_ERROR(cudaGetErrorString(ec));
        }
    }

    void Record(const Stream& stream)
    {
        check_cuda_error(cudaEventRecord(event_, stream.handle()));
    }

    void Sync() const
    {
        check_cuda_error(cudaEventSynchronize(event_));
    }

    cudaEvent_t handle() const
    {
        return event_;
    }

private:
    cudaEvent_t event_;
};

// Shared, nullable handle to an EventImpl; copies refer to the same CUDA event.
// Every accessor checks for null with TM_CHECK_NOTNULL.
class Event {
public:
    Event() = default;

    // Create a new event. Timing is disabled (cudaEventDisableTiming) unless
    // `timing` is true.
    static Event create(bool timing = false)
    {
        Event e{};
        e.impl_ = std::make_shared(timing ? 0 : cudaEventDisableTiming);
        return e;
    }

    // Record the event on `stream`.
    void Record(const Stream& stream)
    {
        TM_CHECK_NOTNULL(impl_)->Record(stream);
    }

    // Block the host until the recorded work has completed.
    void Sync() const
    {
        TM_CHECK_NOTNULL(impl_)->Sync();
    }

    // Implicit conversion to the raw handle, for passing directly to CUDA APIs.
    operator cudaEvent_t() const
    {
        return TM_CHECK_NOTNULL(impl_)->handle();
    }

    // True when this handle refers to an event.
    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

private:
    shared_ptr impl_;
};

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/tensor.cc
================================================

#include "src/turbomind/core/tensor.h"
#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/stream.h"

namespace turbomind::core {

// Print a tensor as "<dtype>[<layout>]@<pointer>", where the pointer is the
// buffer's `data_or(nullptr)` result.
std::ostream& operator<<(std::ostream& os, const Tensor& t)
{
    const void* ptr = t.buffer_.data_or((void*)nullptr);
    return os << t.dtype() << "[" << t.layout() << "]@" << ptr;
}

// Checked lookup: return the tensor stored under `key`. If the key is missing,
// fail through TM_CHECK with a message that lists the available keys.
Tensor& TensorMap::at(const std::string& key)
{
    auto it = find(key);
    TM_CHECK(it != end()) << get_out_of_range_msg(key);
    return it->second;
}

// Build the "missing key" diagnostic used by at(), listing every key currently
// in the map, separated by ", ".
std::string TensorMap::get_out_of_range_msg(const std::string& key) const
{
    std::ostringstream msg;
    msg << "Cannot find a tensor of name '" << key << "' in the tensor map (keys: ";
    bool first = true;
    for (const auto& entry : *this) {
        if (!first) {
            msg << ", ";
        }
        msg << entry.first;
        first = false;
    }
    msg << ")";
    return msg.str();
}

// Non-failing lookup: a pointer to the tensor stored under `key`, or nullptr
// when the key is absent.
Tensor* TensorMap::try_(const std::string& key)
{
    auto it = find(key);
    return it == end() ? nullptr : &it->second;
}

// Enqueue a copy of `src` into `dst` on `stream` as one cudaMemcpyAsync
// (cudaMemcpyDefault, so the runtime infers the direction). Both tensors must
// have the same dtype and shape and be contiguous; empty tensors are a no-op.
void Copy(const Tensor& src, Ref dst_, const Stream& stream)
{
    auto& dst = dst_.get();
    TM_CHECK(src.dtype() == dst.dtype());
    TM_CHECK(src.shape() == dst.shape());
    TM_CHECK(src.is_contiguous());
    TM_CHECK(dst.is_contiguous());
    if (auto size = src.byte_size()) {
        check_cuda_error(cudaMemcpyAsync(dst.raw_data(), src.raw_data(), size, cudaMemcpyDefault, stream.handle()));
    }
}

// Copy `src` into `dst` on the stream returned by Context::stream().
void Copy(const Tensor& src, Ref dst_)
{
    Copy(src, dst_, Context::stream());
}

// Enqueue a zero-fill of every byte of the contiguous tensor `a` on `stream`
// (cudaMemsetAsync); empty tensors are a no-op.
void Clear(Ref a_, const Stream& stream)
{
    auto& a = a_.get();
    TM_CHECK(a.is_contiguous());
    if (auto size = a.byte_size()) {
        check_cuda_error(cudaMemsetAsync(a.raw_data(), 0, size, stream.handle()));
    }
}

// Zero-fill `a` on the stream returned by Context::stream().
void Clear(Ref a_)
{
    Clear(a_, Context::stream());
}

#if 0

void Copy(const Tensor& src, Tensor& dst, Stream& stream)
{
    TM_CHECK(src.dtype() == dst.dtype());
    TM_CHECK(src.shape() == dst.shape());

    const DataType dtype = src.dtype();

    auto trivial = [&] {
        const ssize_t bytesize = get_byte_size(dtype, src.size());
        check_cuda_error(cudaMemcpyAsync(dst.raw_data(), src.raw_data(), bytesize, cudaMemcpyDefault, stream.handle()));
    };

    if (src.layout().is_contiguous() && dst.layout().is_contiguous()) {
        return trivial();
    }

    auto a = src.layout();
    auto b = dst.layout();

    vector idxs(a.rank());
    std::iota(idxs.begin(), idxs.end(), 0);
    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //
        return a.stride()[j] < a.stride()[i];
    });

    // innermost dim is not contiguous
    if (a.stride(idxs.back()) > 1 || b.stride(idxs.back()) > 1) {
        return GenericCopy(src, dst, stream);
    }

    a = a.reorder(idxs);
    b = b.reorder(idxs);

    // trivial after reorder (e.g. transposed matrices)
    if (a.is_contiguous() && b.is_contiguous()) {
        return trivial();
    }

    a = a.coalesce();
    b = b.coalesce();

    int rank = std::max(a.rank(), b.rank());

    if (rank > 3) {
        return GenericCopy(src, dst, stream);
    }

    if (a.rank() < rank) {
        a = a.view(b.shape());
    }
    else if (b.rank() < rank) {
        b = b.view(b.shape());
    }

    if (rank == 2) {
        check_cuda_error(cudaMemcpy2DAsync(dst.raw_data(),
                                           get_byte_size(dtype, b.stride(0)),
                                           src.raw_data(),
                                           get_byte_size(dtype, a.stride(0)),
                                           get_byte_size(dtype, a.shape(1)),
                                           a.shape(0),
                                           cudaMemcpyDefault,
                                           stream.handle()));
        return;
    }

    auto [a0, a1] = a.strides(0, 1);
    auto [b0, b1] = b.strides(0, 1);

    // make sure the underlying space is actually a cube [x % (y * z) == 0]
    if (rank == 3 && a0 % a1 == 0 && b0 % b1 == 0) {
        const auto xsz_a = get_byte_size(dtype, a.stride(1));
        const auto xsz_b = get_byte_size(dtype, b.stride(1));
        const auto ysz_a = a0 / a1;
        const auto ysz_b = b0 / b1;

        cudaMemcpy3DParms param{};
        param.srcPtr = make_cudaPitchedPtr((void*)src.raw_data(), xsz_a, xsz_a, ysz_a);
        param.dstPtr = make_cudaPitchedPtr((void*)dst.raw_data(), xsz_b, xsz_b, ysz_b);
        param.extent = make_cudaExtent(get_byte_size(dtype, a.shape(2)), a.shape(1), a.shape(0));
        param.kind   = cudaMemcpyDefault;

        if (auto ec = cudaMemcpy3DAsync(¶m, stream.handle()); ec == cudaSuccess) {
            TM_LOG_WARNING(cudaGetErrorString(ec));
            return;
        }
    }

    return GenericCopy(src, dst, stream);
}

void Copy(const Tensor& src, Tensor&& dst, Stream& stream)
{
    return Copy(src, dst, stream);
}

#endif

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/tensor.cu
================================================


#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/tensor.h"
#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

namespace turbomind::core {

#if 0

namespace kernel {

// This is going to be slow for transposing the innermost dim
// One thread per element. The thread's linear index is decomposed into a
// coordinate over `shape` (dim 0 varies fastest), and the element is copied
// from a[sum(coord * stride_a)] to b[sum(coord * stride_b)]. Only the first
// `ndim` of the D compile-time slots are used.
template
__global__ void GenericCopy(const T*          a,
                            T*                b,
                            Array stride_a,
                            Array stride_b,
                            Array   shape,
                            int               ndim,
                            int64_t           size)
{
    Index idx = threadIdx.x + (Index)blockIdx.x * blockDim.x;

    if (idx >= size) {
        return;
    }

    // Mixed-radix decomposition of idx into per-dim coordinates
    Array coord;
    PRAGMA_UNROLL
    for (int i = 0; i < D; ++i) {
        if (i < ndim) {
            auto div = idx / shape[i];
            auto mod = idx % shape[i];
            coord[i] = mod;
            idx      = div;
        }
    }

    int64_t idx_a = 0;
    int64_t idx_b = 0;

    PRAGMA_UNROLL
    for (int i = 0; i < D; ++i) {
        if (i < ndim) {
            idx_a += coord[i] * stride_a[i];
            idx_b += coord[i] * stride_b[i];
        }
    }

    b[idx_b] = a[idx_a];
}

}  // namespace kernel

// Launch kernel::GenericCopy to copy `src` into `dst` for arbitrary strides.
// The function:
//   1. sorts dims by ascending `src` stride (dim 0 innermost) and coalesces;
//   2. picks the widest vector type (up to 16 bytes) that every pointer,
//      stride and innermost extent is aligned to;
//   3. picks a block size and index type, then launches.
// NOTE: compiled out (#if 0). NOTE(review) comments below mark apparent bugs
// to fix before re-enabling.
void GenericCopy(const Tensor& src, Tensor& dst, Stream& stream)
{
    auto a = src.layout();
    auto b = dst.layout();

    // Sort strides ascending
    vector idxs(a.rank());
    std::iota(idxs.begin(), idxs.end(), 0);
    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //
        return a.stride()[i] < a.stride()[j];
    });

    a = a.permute(idxs);
    b = b.permute(idxs);

    a = a.coalesce();
    b = b.coalesce();

    int rank = std::max(a.rank(), b.rank());

    // NOTE(review): `b.view(b.shape())` is a no-op; presumably `a.shape()` was meant.
    if (a.rank() < rank) {
        a = a.view(b.shape());
    }
    else if (b.rank() < rank) {
        b = b.view(b.shape());
    }

    const DataType dtype = src.dtype();

    // Vector width in bytes: start at 16 and reduce by gcd with every byte
    // quantity the vectorized access must respect.
    int64_t alignment = 16;

    auto align = [&](auto v) { alignment = std::gcd(alignment, v); };

    if (a.stride(0) > 1 || b.stride(0) > 1) {
        alignment = get_byte_size(dtype);
    }

    align(get_byte_size(dtype, a.shape(0)));

    auto data_a = src.raw_data();
    auto data_b = dst.raw_data();

    align(reinterpret_cast(data_a));
    align(reinterpret_cast(data_b));

    for (int i = 1; i < rank; ++i) {
        align(get_byte_size(dtype, a.stride(i)));
        align(get_byte_size(dtype, b.stride(i)));
    }

    const auto vec_size = get_elem_num(alignment, dtype);

    // Number of vectors to copy
    const auto size = a.size() / vec_size;

    int device{};
    check_cuda_error(cudaGetDevice(&device));
    int sm_num{};
    check_cuda_error(cudaDeviceGetAttribute(&sm_num, cudaDevAttrMultiProcessorCount, device));

    auto invoke = [&](auto vec_t, auto index_t, auto d) {
        using T         = decltype(vec_t);
        using Index     = decltype(index_t);
        constexpr int D = d.value;

        Array shape;
        std::fill(shape.begin() + rank, shape.end(), 1);
        std::copy_n(a.shape().data(), rank, shape.data());

        Array stride_a{};
        Array stride_b{};
        std::copy_n(a.stride().data(), rank, stride_a.data());
        std::copy_n(b.stride().data(), rank, stride_b.data());

        // Convert element units to vector units.
        // NOTE(review): the innermost strides (i == 0) are divided too, which
        // turns a stride of 1 into 0 when vec_size > 1; presumably this loop
        // should start at i = 1 since shape[0] already absorbs vec_size.
        if (vec_size > 1) {
            shape[0] /= vec_size;
            for (int i = 0; i < rank; ++i) {
                stride_a[i] /= vec_size;
                stride_b[i] /= vec_size;
            }
        }

        auto func = kernel::GenericCopy;

        int min_waves  = INT_MAX;
        int block_size = 0;
        int grid_size  = 0;

        // Pick the block size (256..1024) that needs the fewest waves.
        // NOTE(review): `block_size` (0 on the first iteration) is used below
        // where the loop variable `threads` appears to be intended.
        for (int threads = 256; threads <= 1024; threads *= 2) {
            int blocks = cdiv(size, block_size);
            int n_active{};
            check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_active, func, block_size, 0));
            int waves = cdiv(blocks, n_active * sm_num);
            if (waves < min_waves) {
                min_waves  = waves;
                block_size = threads;
                grid_size  = blocks;
            }
        }

        // NOTE(review): the kernel's bound is `a.size()` (elements) rather than
        // `size` (vectors); padding threads of the last block can then access
        // out of range when vec_size > 1.
        func<<>>(
            (const T*)data_a, (T*)data_b, stride_a, stride_b, shape, rank, a.size());
    };

    // Choose the compile-time rank capacity D
    auto invoke_d = [&](auto vec_t, auto idx_t) {
        if (rank <= 2) {
            invoke(vec_t, idx_t, constant<2>{});
        }
        else if (rank <= 4) {
            invoke(vec_t, idx_t, constant<4>{});
        }
        else if (rank <= 8) {
            invoke(vec_t, idx_t, constant<8>{});
        }
        else {
            throw std::runtime_error("not implemented");
        }
    };

    // Use 32-bit indexing when the vector count allows it
    auto invoke_i = [&](auto vec_t) {
        if (size < INT_MAX) {
            invoke_d(vec_t, int{});
        }
        else {
            invoke_d(vec_t, int64_t{});
        }
    };

    // Map the alignment (bytes) to a vector type of that size
    switch (alignment) {
        case 16:
            return invoke_i(uint4{});
        case 8:
            return invoke_i(uint2{});
        case 4:
            return invoke_i(uint{});
        case 2:
            return invoke_i(ushort{});
        default:
            return invoke_i(char{});
    }
}

#endif

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/tensor.h
================================================
#pragma once

#include 
#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/layout.h"

namespace turbomind::core {

// A Tensor is a Buffer (typed storage on a device) viewed through a Layout
// (shape + strides). view/slice/squeeze/transpose/t build a new Tensor from
// the same Buffer handle plus a derived Layout instead of copying elements.
class Tensor {
public:
    Tensor() = default;

    // Allocate storage for `layout.cosize()` elements with the context
    // allocator for `device`.
    Tensor(Layout layout, DataType dtype, Device device): Tensor{layout, dtype, Context::alloc(device)} {}

    // Allocate storage for `layout.cosize()` elements from `alloc`.
    Tensor(Layout layout, DataType dtype, Allocator& alloc): layout_{std::move(layout)}
    {
        buffer_ = Buffer(layout_.cosize(), dtype, alloc);
    }

    // Wrap an existing buffer; only its first `layout.cosize()` elements are kept.
    // (layout_ is declared before buffer_, so it is initialized first here.)
    Tensor(Buffer buffer, Layout layout): layout_{std::move(layout)}, buffer_{buffer.slice(0, layout_.cosize())} {}

    // Wrap a whole buffer with a layout built from its element count
    // (presumably 1-D).
    Tensor(Buffer buffer): layout_{buffer.size()}, buffer_{buffer} {}

    // Wrap externally provided memory of `layout.cosize()` elements.
    Tensor(void* data, Layout layout, DataType dtype, Device device):
        Tensor{Buffer{data, layout.cosize(), dtype, device}, layout}
    {
    }

    Tensor(std::shared_ptr data, Layout layout, DataType dtype, Device device):
        Tensor{Buffer{data, layout.cosize(), dtype, device}, layout}
    {
    }

    template
    Tensor(T* data, Layout layout, Device device): Tensor{Buffer{data, layout.cosize(), device}, layout}
    {
    }

    Buffer& buffer() noexcept
    {
        return buffer_;
    }

    const Buffer& buffer() const noexcept
    {
        return buffer_;
    }

    DataType dtype() const
    {
        return buffer_.dtype();
    }

    Device device() const
    {
        return buffer_.device();
    }

    // Number of logical elements described by the layout.
    ssize_t size() const noexcept
    {
        return layout_.size();
    }

    // Byte count of `size()` elements of `dtype()`.
    ssize_t byte_size() const noexcept
    {
        return turbomind::byte_size(dtype(), size());
    }

    // True when the tensor holds a buffer.
    explicit operator bool() const noexcept
    {
        return static_cast(buffer_);
    }

    template
    T* data()
    {
        return buffer_.data();
    }

    template
    const T* data() const
    {
        return const_cast(this)->data();
    }

    void* raw_data()
    {
        return buffer_.raw_data();
    }

    const void* raw_data() const
    {
        return const_cast(this)->raw_data();
    }

    template
    T* data_or(T* other)
    {
        return buffer_.data_or(other);
    }

    template
    const T* data_or(T* other) const
    {
        return buffer_.data_or(other);
    }

    // Same buffer, layout reshaped to `shape` by Layout::view.
    Tensor view(std::vector shape) const
    {
        return Tensor{buffer_, layout_.view(std::move(shape))};
    }

    auto& layout() const noexcept
    {
        return layout_;
    }

    auto& shape() const noexcept
    {
        return layout_.shape();
    }

    auto shape(int i) const
    {
        return layout_.shape(i);
    }

    template
    auto shapes(Is&&... is) const
    {
        return layout_.shapes(((Is &&) is)...);
    }

    auto& stride() const noexcept
    {
        return layout_.stride();
    }

    auto stride(int i) const
    {
        return layout_.stride(i);
    }

    template
    auto strides(Is&&... is) const
    {
        return layout_.strides(((Is &&) is)...);
    }

    bool is_contiguous() const noexcept
    {
        return layout().is_contiguous();
    }

    // Sub-tensor of extent `shape` starting at coordinate `base`. The buffer is
    // narrowed to the elements the sliced layout can reach.
    Tensor slice(std::vector base, std::vector shape) const
    {
        auto&& [layout, offset] = layout_.slice(base, std::move(shape));
        const auto cosize       = layout.cosize();
        return Tensor{buffer_.slice(offset, cosize), std::move(layout)};
    }

    // The outermost dimension
    // Rows [base, base + size) of dim 0; all other dims are kept whole.
    Tensor slice(ssize_t base, ssize_t size = 1) const
    {
        vector bases(shape().size());
        bases.front() = base;
        vector sizes{this->shape()};
        sizes.front() = size;
        return slice(bases, sizes);
    }

    // Same layout over `buffer_.borrow()` (presumably a non-owning reference).
    Tensor borrow() const
    {
        return Tensor{buffer_.borrow(), layout_};
    }

    Tensor squeeze(int dim) const
    {
        return Tensor{buffer_, layout_.squeeze(dim)};
    }

    Tensor transpose(int a, int b) const
    {
        return Tensor{buffer_, layout_.transpose(a, b)};
    }

    // Matrix transpose; only valid for 2-D tensors (checked).
    Tensor t() const
    {
        TM_CHECK_EQ(ndim(), 2);
        return transpose(0, 1);
    }

    int ndim() const noexcept
    {
        return layout_.rank();
    }

    friend std::ostream& operator<<(std::ostream& os, const Tensor& t);

private:
    Layout layout_;
    Buffer buffer_;
};

// Allocate a new tensor with the same layout, dtype and device as `tensor`;
// the contents of `tensor` are not copied.
inline Tensor empty_like(const Tensor& tensor)
{
    const Device device = tensor.device();
    return Tensor{tensor.layout(), tensor.dtype(), device};
}

// Allocate a new tensor with the layout and dtype of `tensor`, placed on
// `device`; the contents of `tensor` are not copied.
inline Tensor empty_like(const Tensor& tensor, Device device)
{
    const auto& layout = tensor.layout();
    return Tensor{layout, tensor.dtype(), device};
}

// Allocate a new tensor with the layout and device of `tensor` but element
// type `dtype`; the contents of `tensor` are not copied.
inline Tensor empty_like(const Tensor& tensor, DataType dtype)
{
    Tensor result{tensor.layout(), dtype, tensor.device()};
    return result;
}

// Copies `src` into the destination tensor on `stream`. `Ref` presumably wraps a
// mutable reference to the destination (confirm against its definition).
void Copy(const Tensor& src, Ref dst_, const Stream& stream);

// Stream-less overload; which stream it uses is defined elsewhere
// (presumably the current context stream, TODO confirm).
void Copy(const Tensor& src, Ref dst_);

// Clears the contents of `a_` on `stream` (presumably zero-fill, confirm).
void Clear(Ref a_, const Stream& stream);

// Stream-less overload of `Clear`; stream selection is defined elsewhere.
void Clear(Ref a_);

// Disabled API sketch, kept for reference only; nothing in this region is compiled.
// NOTE(review): `Transpoe` below looks like a typo for `Transpose`.
#if 0

void Copy(const Tensor& src, Tensor&& dst, Stream& stream);

// Launch a kernel to perform the complicated copying
void GenericCopy(const Tensor& src, Tensor& dst, Stream& stream);

Tensor Reshape(const Tensor& t, vector shape);

Tensor Transpoe(const Tensor& t, int dim0, int dim1);

Tensor Permute(const Tensor& t, vector dims);

Tensor Contiguous(const Tensor& t);
#endif

// `Tensor` with a statically known element type `T`. Its dtype is always
// `data_type_v` for `T`. Conversions from an untyped `Tensor`/`Buffer` go through
// `ensure_dtype`, which `TM_CHECK`s that the runtime dtype matches.
template
struct Tensor_: public Tensor {
    Tensor_() = default;

    // The constructors below forward to the matching `Tensor` constructors,
    // supplying this type's dtype where one is required.
    Tensor_(Layout layout, Device device): Tensor{std::move(layout), data_type_v, device} {}

    Tensor_(Layout layout, Allocator& alloc): Tensor{std::move(layout), data_type_v, alloc} {}

    // Wraps an existing buffer, whose dtype must match `T`.
    Tensor_(Buffer buffer, Layout layout): Tensor{ensure_dtype(std::move(buffer)), std::move(layout)} {}

    Tensor_(T* data, Layout layout, Device device): Tensor{data, std::move(layout), device} {}

    // Shares ownership of `data`; the buffer is sized to `layout.cosize()` elements.
    Tensor_(shared_ptr data, Layout layout, Device device):
        Tensor{Buffer{std::move(data), layout.cosize(), data_type_v, device}, layout}
    {
    }

    Tensor_(const Tensor_&) = default;
    Tensor_& operator=(const Tensor_&) = default;

    Tensor_(Tensor_&&) noexcept = default;
    Tensor_& operator=(Tensor_&&) noexcept = default;

    // Converting constructors from an untyped `Tensor`. They fail the dtype check
    // when `other` does not hold `T`; per the tests, an empty `Tensor` is rejected too.
    // NOTE(review): the rvalue overload is `noexcept`. If a failing TM_CHECK throws,
    // the program terminates. Confirm that is intended.
    Tensor_(const Tensor& other)
    {
        *static_cast(this) = ensure_dtype(other);
    }
    Tensor_(Tensor&& other) noexcept
    {
        *static_cast(this) = ensure_dtype(std::move(other));
    }

    // Linear element offset of the multi-index `idxs` in this tensor's layout.
    ssize_t offset(const vector& idxs)
    {
        return layout().offset(idxs);
    }

    // Typed data accessors forwarding to the untyped `Tensor` ones.
    T* data() noexcept
    {
        return Tensor::data();
    }

    const T* data() const noexcept
    {
        return Tensor::data();
    }

    // Returns `other` instead of the data pointer, presumably when the tensor is
    // empty; see `Tensor::data_or`.
    T* data_or(T* other)
    {
        return Tensor::data_or(other);
    }

    const T* data_or(T* other) const
    {
        return Tensor::data_or(other);
    }

    // Compile-time dtype; hides `Tensor::dtype`.
    constexpr DataType dtype() const noexcept
    {
        return data_type_v;
    }

private:
    // Checks that `u` carries this type's dtype, then returns it forwarded unchanged.
    template
    static decltype(auto) ensure_dtype(U&& u)
    {
        TM_CHECK_EQ(u.dtype(), data_type_v);
        return (U &&) u;
    }
};

// Name -> tensor map with checked accessors and produce/consume helpers.
class TensorMap: public std::unordered_map {
public:
    using std::unordered_map::unordered_map;

    // Checked lookup (defined out of line). The const overload casts away const
    // and reuses the non-const one.
    Tensor& at(const std::string& key);

    const Tensor& at(const std::string& key) const
    {
        return const_cast(this)->at(key);
    }

    // Lookup returning a pointer (defined out of line), presumably null when
    // `key` is absent.
    Tensor* try_(const std::string& key);

    const Tensor* try_(const std::string& key) const
    {
        return const_cast(this)->try_(key);
    }

    bool contains(const std::string& key) const
    {
        return find(key) != end();
    }

    // Inserts `value` under `key`; TM_CHECK fails if `key` is already present.
    void produce(const std::string& key, Tensor value)
    {
        TM_CHECK(emplace(key, std::move(value)).second);
    }

    // Removes `key` and returns its tensor, or an empty Tensor if `key` is absent.
    Tensor try_consume(const std::string& key)
    {
        if (auto it = find(key); it != end()) {
            auto value = std::move(it->second);
            erase(it);
            return value;
        }
        return Tensor{};
    }

    // Like `try_consume`, but TM_CHECK fails when the result is empty. That covers
    // a missing key and also a key mapped to an empty tensor (which is still erased).
    Tensor consume(const std::string& key)
    {
        auto value = try_consume(key);
        TM_CHECK(value) << get_out_of_range_msg(key);
        return value;
    }

private:
    std::string get_out_of_range_msg(const std::string& key) const;
};

// clang-format off
// Serializes a tensor (constrained by the enable_if to tensor types, presumably
// `Tensor` and its subclasses; confirm). Only contiguous (or empty) tensors are
// accepted. The buffer is written first, then the layout.
template, int> = 0>
void save(Archive& ar, const T& tensor)
{
    TM_CHECK(tensor.size() == 0 || tensor.is_contiguous());
    ar & tensor.buffer(); // implicit convert to tensor
    ar & tensor.layout();
}

// Inverse of `save`: reads the buffer, then the layout (same order as written),
// and rebuilds `tensor` from them.
template
void load(Archive& ar, Tensor& tensor)
{
    Buffer buffer;
    Layout layout;
    ar & buffer;
    ar & layout;
    tensor = Tensor{std::move(buffer), std::move(layout)};
}


// Serializes a TensorMap as its element count followed by (key, tensor) pairs.
template
void save(Archive& ar, const TensorMap& map)
{
    ar & map.size();
    for (const auto& [k, t]: map) {
        ar & k;
        ar & t;
    }
}

template
void load(Archive& ar, TensorMap& map)
{
    map.clear();
    decltype(map.size()) size;
    ar & size;
    for (int i = 0; i < size; ++i) {
        std::string k;
        Tensor   t;
        ar & k;
        ar & t;
        map.emplace(std::move(k), std::move(t));
    }
}
// clang-format on

}  // namespace turbomind::core


================================================
FILE: src/turbomind/core/test_core.cc
================================================

#include 

#include "src/turbomind/core/core.h"

#include "catch2/catch_test_macros.hpp"

using namespace turbomind;

// Exercises the TM_CHECK* macros: passing checks, streaming extra context with
// `<<`, and TM_CHECK_NOTNULL returning its (non-null) argument for raw and smart
// pointers. Failing forms are compiled but never executed (guarded by `if (0)`).
TEST_CASE("test check", "[check]")
{
    int zero = 0;

    TM_CHECK(!zero);

    TM_CHECK_EQ(42, 42) << "Ok";
    TM_CHECK_NE(42, 24) << "Ok";
    TM_CHECK_GE(50, 42) << "Ok";
    TM_CHECK_GT(50, 42) << "Ok";
    TM_CHECK_LE(42, 50) << "Ok";
    TM_CHECK_LT(42, 50) << "Ok";

    // Compile-only: these would fail if executed
    if (0) {
        TM_CHECK(zero);
        TM_CHECK_EQ(42, 43) << "Not "
                            << "Ok";
    }

    int  x = 42;
    auto p = TM_CHECK_NOTNULL(&x);
    REQUIRE(p == &x);

    // Compile-only: null pointers would fail the check
    if (0) {
        int* y{};
        TM_CHECK_NOTNULL(y);
        TM_CHECK_NOTNULL(std::shared_ptr{});
    }

    auto y = TM_CHECK_NOTNULL(std::make_shared(42));
    REQUIRE(*y == 42);

    TM_CHECK(y);
}

TEST_CASE("test allocator", "[allocator]")
{

    using core::Allocator;

    // A default-constructed allocator is empty
    Allocator a;
    REQUIRE(!a);

    // A CPU allocator is valid, differs from the empty one, reports the CPU
    // device and is not bound to a stream
    Allocator b{kCPU};
    REQUIRE(b);
    REQUIRE(a != b);
    REQUIRE(b->device() == kCPU);
    REQUIRE(!b->stream());

    // std::vector v(1 << 20);
    // std::iota(v.begin(), v.end(), 0);

    // auto p = (int*)b->allocate(sizeof(int) * v.size());
    // std::iota(p, p + v.size(), 0);

    // REQUIRE(v == std::vector(p, p + v.size()));
}

// Checks that nested ContextGuards override the current stream and allocators
// and restore the previous ones when they go out of scope.
TEST_CASE("test context", "[context]")
{
    using core::Context;
    using core::ContextGuard;
    using core::Stream;
    using core::Allocator;

    Stream s0 = Stream::create();

    // Outermost guard: a fresh stream plus a CPU allocator
    ContextGuard g0{s0, Allocator{kCPU}};

    REQUIRE(Context::stream());
    REQUIRE(Context::stream() == s0);

    auto a0 = Context::host_alloc();

    {
        Allocator a1(Context::stream(), false);  // device allocator
        REQUIRE(a1->device().type == kDEVICE);

        // Overriding only the device allocator keeps the stream and host allocator
        ContextGuard g1{a1};

        REQUIRE(Context::stream() == s0);
        REQUIRE(Context::device_alloc() == a1);
        REQUIRE(Context::host_alloc() == a0);

        {
            // Nested guard with its own stream and device allocator shadows both
            ContextGuard g2{Stream::create(), Allocator(kDEVICE)};
            REQUIRE(Context::device_alloc() != a1);
            REQUIRE(Context::stream() != s0);
        }

        // Leaving g2's scope restores g1's view of the context
        REQUIRE(Context::stream() == s0);
        REQUIRE(Context::device_alloc() == a1);
    }

    REQUIRE(Context::stream() == s0);
}

// Buffer basics: empty buffers, the three ways to obtain storage (referencing an
// external array, shared ownership, allocation), equality, slicing, the typed
// `Buffer_` wrapper and `data_or` on empty buffers.
TEST_CASE("test basic buffer", "[buffer]")
{
    using core::Buffer;
    using core::Buffer_;
    using core::Allocator;

    Buffer a;
    REQUIRE(!a);

    Buffer b;
    REQUIRE(!b);
    REQUIRE(a == b);

    std::vector v{0, 1, 2, 3, 4, 5, 6, 7};

    // Catch2 runs the test case once per SECTION, so the checks after the
    // sections are repeated for each way of constructing `b`.
    SECTION("reference into v")
    {
        b = Buffer(v.data(), v.size(), kCPU);
        REQUIRE(b.data() == v.data());
        REQUIRE(b.raw_data() == v.data());
    }
    SECTION("shared ownership")
    {
        auto x = std::shared_ptr(new int[v.size()]);
        std::copy(v.begin(), v.end(), x.get());
        b = Buffer(x, v.size(), data_type_v, kCPU);
        REQUIRE(b.data() == x.get());
        REQUIRE(b.raw_data() == x.get());
    }
    SECTION("allocation")
    {
        Allocator alloc{kCPU};
        b = Buffer(v.size(), data_type_v, alloc);
        std::copy(v.begin(), v.end(), b.data());
    }

    REQUIRE(b);
    REQUIRE(b.size() == v.size());
    REQUIRE(b.dtype() == data_type_v);
    REQUIRE(b.byte_size() == sizeof(int) * v.size());
    auto c = b;
    REQUIRE(c == b);
    REQUIRE(b == c);
    REQUIRE(a != b);
    REQUIRE(b != a);
    REQUIRE(std::vector(b.data(), b.data() + b.size()) == v);

    // slice(offset, size) addresses elements [3, 5) of the same storage
    auto s = b.slice(3, 2);
    REQUIRE(s.size() == 2);
    REQUIRE(s.raw_data() == b.data() + 3);

    // Typed buffers: an untyped buffer converts to `Buffer_` and back
    Buffer_ x;
    Buffer_ y = Buffer{data_type_v};

    Buffer z = Buffer_(1024, kCPU);

    x = z;

    for (int i = 0; i < z.size(); ++i) {
        x[i] = i;
    }

    std::vector ref(1024);
    std::iota(ref.begin(), ref.end(), 0);
    REQUIRE(std::vector(x.begin(), x.end()) == ref);

    // `data_or` on empty buffers returns the supplied fallback (null here)
    Buffer e;
    REQUIRE(!e.data_or((void*)0));
    REQUIRE(!e.data_or(nullptr));

    Buffer_ w;
    REQUIRE(!w.data_or(nullptr));
    REQUIRE(!std::as_const(w).data_or(nullptr));

    w = {1024, kCPU};
    REQUIRE(w.raw_data());
    REQUIRE(std::as_const(w).raw_data());
}

// Slicing a borrowed buffer, then reinterpreting the slice with `view`.
TEST_CASE("test buffer view", "[buffer]")
{
    using core::Buffer;

    std::vector v{0, 1, 2, 3, 4, 5, 6, 7};

    Buffer b(v.data(), v.size(), kCPU);

    auto c = b.slice(2, 4);
    REQUIRE(c.size() == 4);
    REQUIRE(c.raw_data() == b.data() + 2);

    std::cout << c << std::endl;

    // The view element type is not visible here; the doubled size implies an
    // element half the size of the original one. It shares the same address.
    auto d = c.view();

    REQUIRE(d.size() == c.size() * 2);
    REQUIRE(d.raw_data() == c.raw_data());
}

// Layout construction, coalescing, reshaping views (with -1 inference) and
// slicing, including a view of a non-contiguous slice.
TEST_CASE("test layout", "[layout]")
{
    using core::Layout;

    Layout a;  // default ctor
    REQUIRE(a.size() == 0);
    REQUIRE(a.cosize() == 0);

    Layout b({20, 50});
    REQUIRE(b.size() == 1000);
    REQUIRE(b.cosize() == b.size());
    REQUIRE(to_string(b) == "(20,50):(50,1)");

    Layout c = b.coalesce();
    REQUIRE(c.size() == b.size());
    REQUIRE(c.cosize() == b.cosize());
    REQUIRE(to_string(c) == "(1000):(1)");

    Layout v = b.view({50, 20});
    REQUIRE(v.size() == b.size());
    REQUIRE(v.cosize() == b.cosize());
    REQUIRE(to_string(v) == "(50,20):(20,1)");

    // A -1 extent in a view is inferred from the total size
    v = b.view({25, -1});
    REQUIRE(to_string(v) == "(25,40):(40,1)");

    v = b.view({5, -1, 5});
    REQUIRE(to_string(v) == "(5,40,5):(200,5,1)");

    v = b.view({-1, 20, 10, 1});
    REQUIRE(to_string(v) == "(5,20,10,1):(200,10,1,1)");

    REQUIRE(to_string(v.coalesce()) == "(1000):(1)");

    // A -1 extent in a slice extends to the end of that dimension
    auto [s, offset] = b.slice({10, 20}, {-1, -1});
    REQUIRE(to_string(s) == "(10,30):(50,1)");
    REQUIRE(offset == 520);

    // Viewing the non-contiguous slice must preserve its element count and the
    // memory span it addresses
    v = s.view({2, -1, 3, 10});
    REQUIRE(v.size() == s.size());
    REQUIRE(v.cosize() == s.cosize());
    std::cout << v << std::endl;

    // Coalescing regroups modes without changing element count or span
    Layout vc = v.coalesce();
    REQUIRE(vc.size() == v.size());
    REQUIRE(vc.cosize() == v.cosize());
    std::cout << vc << std::endl;

    // v = s.view({30, 10});
    // std::cout << v << std::endl;
}

// Tensor basics: typed construction, outer-dimension slicing, reshaping views,
// empty typed tensors, and building a tensor from a buffer.
TEST_CASE("test tensor", "[tensor]")
{
    using core::Tensor;
    using core::Tensor_;
    using core::Allocator;

    Tensor a;
    REQUIRE(!a);

    Tensor_ b{{10, 20}, kCPU};
    // Rows [0, 5) of `b`: same base pointer, shape (5, 20)
    Tensor_ c = b.slice(0, 5);

    std::cout << b << std::endl;

    REQUIRE(c.shape() == std::vector{5, 20});
    REQUIRE(c.data() == b.data());

    auto d = b.view({2, -1, 10});
    REQUIRE(d.shape() == std::vector{2, 10, 10});

    // this is typed
    Tensor_ x = Tensor_{};
    // while being empty
    REQUIRE(!x);

    // Compile-only: converting an empty untyped Tensor would fail the dtype check
    if (0) {
        // empty Tensor has invalid type
        Tensor_ x = Tensor{};
    }
    a = {};
    x = {};

    // A buffer converts to a rank-1 tensor of the same length
    Tensor y = core::Buffer{100, kInt32, kCPU};
    REQUIRE(y.ndim() == 1);
    REQUIRE(y.shape(0) == 100);
}


================================================
FILE: src/turbomind/engine/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)

# Static library with the engine sources of this directory.
add_library(engine STATIC
    gateway.cc
    request.cc
    request_queue.cc
    model_request.cc
    model_executor.cc
    engine.cc
)
target_link_libraries(engine PRIVATE xgrammar core)

# Build as position-independent code and resolve CUDA device symbols for this target.
set_target_properties(engine PROPERTIES
    POSITION_INDEPENDENT_CODE   ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
)


================================================
FILE: src/turbomind/engine/batch.h
================================================

#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/request.h"

namespace turbomind {

// Batch-level operations dispatched to the model (via `LanguageModel::Run`).
// Each trailing comment gives the data movement and where the work happens.
// The abbreviations are not defined in this file; presumably Se = Sequence,
// Rc = request cache, B = host batch buffers, D = device batch buffers,
// St = model state, and H / D / H2D / D2H = host, device and transfers.
// TODO confirm this legend.
enum class BatchOp
{
    kAdd,      //  Se ->  Rc         H
    kSetup,    //  Rc -> (B  -> D)   H2D
    kPrepare,  // (D  ->  St)        D
    kForward,  //  St ->  St         D
    kUnprep,   // (St ->  D)         D
    kFetch,    // (D  ->  B)         D2H
    kUpdate,   //  B  ->  Rc         H
    kDel,      //  Rc ->  Se         H
};

// Se -> Rc -> (B -> D) -> St -> (D -> B) -> Rc -> Se

/*
Se -> Rc                   (add: rc)
    Rc -> B
        (B -> D)           (setup: rc, d, copy)
            (D -> St)
                St -> St   (forward)
            (St -> D)
        (D -> B)
    B -> Rc                (sync)
Rc -> Se                   (del: rc)
*/

// Per-phase batch descriptor passed to the model through a TensorMap.
struct BatchData {

    // Records the pipeline phase and creates the three synchronization events.
    explicit BatchData(int phase): self{this}, phase{phase}
    {
        ready = Event::create();
        done  = Event::create();
        next  = Event::create();
    }

    // Neither copyable nor movable: `self` stores `this`, so a copy or moved-to
    // object would carry a stale pointer.
    BatchData(const BatchData&)     = delete;
    BatchData(BatchData&&) noexcept = delete;
    BatchData& operator=(const BatchData&) = delete;
    BatchData& operator=(BatchData&&) noexcept = delete;

    // Always points at this object; exposed by `buf()`.
    BatchData* self;

    const int phase;

    // Previous and current number of active sequences (filled by Engine::Impl::Setup).
    int bs0 = 0;
    int bsz = 0;

    // Copy of the scheduler's permutation for the `bsz` active entries.
    Buffer_ perm;

    // Request caches of the active sequences in this batch.
    std::vector> rc;

    // Token count per data-parallel rank, and their sum.
    std::vector local_token_num;
    int              global_token_num = 0;

    Event ready;
    Event done;
    Event next;

    std::promise promise;

    // One-element CPU buffer holding `self`, so the batch can travel in a TensorMap.
    Buffer buf()
    {
        return Buffer{&self, 1, kCPU};
    }

    // Records `next` on the current context stream and fulfills `promise` with it.
    // A std::promise can be satisfied only once; a second call throws std::future_error.
    void Notify()
    {
        next.Record(core::Context::stream());
        promise.set_value(next);
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/engine.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 

#include "nvtx3/nvToolsExt.h"

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/engine/engine.h"
#include "src/turbomind/engine/model_executor.h"
#include "src/turbomind/engine/request.h"

#include "src/turbomind/core/copy.h"
#include "src/turbomind/models/language_model.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/metrics.h"

// #include "dbg.h"

namespace turbomind {

using std::shared_ptr;
using std::unique_ptr;
using std::vector;

// Requests gathered for one engine iteration; (de)serialized field by field by `serdes`.
struct RequestData {
    vector<shared_ptr<Request>> infer;  // incoming inference request
    vector<shared_ptr<Request>> kill;   // incoming kill request

    vector<int> cancel;         // canceled indices in current batch
    bool        abort = false;  // initialized: otherwise indeterminate unless explicitly set
};

// Single (de)serialization routine for RequestData. The direction is decided by
// the archive type (presumably `ar & x` reads or writes; confirm against the archive).
template
void serdes(Archive& ar, RequestData& r)
{
    ar& r.infer;
    ar& r.kill;
    ar& r.cancel;
    ar& r.abort;
}

// Private implementation of `Engine`. Owns the scheduling state, the sequence
// manager, the model and its executor, and the internal thread running
// `InternalThreadEntry`. The inbound/outbound queues are handed to the executor.
struct Engine::Impl {

    using Requests = vector>;
    // Deferred notification (e.g. an `UpdateState` call or a request's `end_cb`)
    // collected while processing requests.
    using Signal   = std::function;

    Impl(DataType      dtype,
         EngineParam   param,
         LanguageModel model,
         Context&      ctx,
         Gateway&      gateway,
         int           device_id,
         int           queue_id,
         int           phases);

    // Builds `seq_mgr_` and sets `session_len_trunc_`.
    void CreateSequenceManager();

    void InternalThreadEntry();

    // Marks conflicting / inconsistent / canceled incoming requests via `ec`.
    void Validate(Requests& infer_rs, Requests& kill_rs);

    // Erases the sequences of kill requests and queues their end callbacks.
    void Kill(const Requests& rs, vector& signals);

    // Indices in the current batch whose request has `cancel_flag == -1`.
    vector GetCanceled();

    void Cancel(vector& indices, vector& signals);

    // Admits new inference requests into the current batch state.
    void Accept(const Requests& rs, vector& signals);

    // Detaches a request cache from its sequence.
    void Interrupt(RequestCache& c);

    // Allocation of memory / compute resources
    void Schedule();

    // initialize RC from `Sequence`
    void Setup(BatchData& d);

    // Sync vars from batch output to RC
    void Update(BatchData& d, std::vector& signals);

    // Dispatches a batch operation to the model.
    void Run(BatchOp op, int phase, Ref env)
    {
        model_.Run(op, phase, env);
    }

    // Spawns the internal thread, then starts the executor.
    void Start()
    {
        internal_thread_ = std::thread(&Impl::InternalThreadEntry, this);
        executor_.Start();
    }

    void UpdateScheduleMetrics();

    ~Impl();

    const DataType    dtype_;
    const EngineParam param_;

    Gateway& gateway_;

    comm::HostComm& tp_group_;
    comm::HostComm& dp_group_;

    const int tp_rank_;
    const int dp_rank_;
    const int dp_size_;

    const int device_id_;
    const int queue_id_;

    // Non-zero when the engine runs with more than one phase.
    const int async_;

    // Refers to the context's warm-up flag.
    int& is_warm_up_;

    unique_ptr seq_mgr_;

    Queue> inbound_;
    Queue> outbound_;

    LanguageModel model_;
    ModelExecutor executor_;

    std::thread internal_thread_;

    // `session_len`, possibly reduced to what the KV cache can hold.
    int session_len_trunc_;

    shared_ptr metrics_;

    // Scheduling state of the running batch.
    struct State {
        vector> rc;    // request caches (null slots allowed until rescheduled)
        vector                      perm;  // new position -> previous index (set by Schedule)

        int bs0     = 0;  // active count from the previous schedule
        int active  = 0;  // active sequences, stored first in `rc`
        int finish  = 0;
        int swapout = 0;

        int size() const noexcept
        {
            return rc.size();
        }
    };

    vector states_;

    // Currently empty; one entry per phase.
    struct Data {
    };
    vector data_;

    // staging buffers
    // Pinned host buffers filled by Setup: block pointers of all active sequences,
    // and per-sequence prefix offsets into them (max_batch_size + 1 entries).
    Buffer_ block_ptrs_buf_;
    Buffer_   block_ptrs_offsets_buf_;
};

// Shuts the engine down. Both queues are closed (presumably so the internal thread
// can observe shutdown and exit), the internal thread is joined, and only then is
// the executor replaced with a default-constructed one.
Engine::Impl::~Impl()
{
    TM_LOG_INFO(__PRETTY_FUNCTION__);
    inbound_.close();
    outbound_.close();
    if (internal_thread_.joinable()) {
        internal_thread_.join();
    }
    executor_ = {};
}

Engine::Impl::Impl(DataType      dtype,
                   EngineParam   param,
                   LanguageModel model,
                   Context&      ctx,
                   Gateway&      gateway,
                   int           device_id,
                   int           queue_id,
                   int           phases):
    dtype_{dtype},
    param_{param},
    gateway_{gateway},
    tp_group_{ctx.comm.h_tp_group},
    dp_group_{ctx.comm.h_dp_group},
    tp_rank_{tp_group_->rank()},
    dp_rank_{dp_group_->rank()},
    dp_size_{dp_group_->n_ranks()},
    device_id_{device_id},
    queue_id_{queue_id},
    async_{phases > 1},
    is_warm_up_{*ctx.is_warm_up},
    model_{std::move(model)}
{
    // One scheduling state, plus one `Data` slot per pipeline phase
    states_.emplace_back();

    for (int i = 0; i < phases; ++i) {
        data_.emplace_back();
    }

    executor_ = ModelExecutor{model_, ctx, device_id_, outbound_, inbound_};

    CreateSequenceManager();  // initializes `session_len_trunc_`

    // Worst case: every batch slot holds a sequence of `session_len_trunc_` tokens.
    // Multiply in `ssize_t` so the product cannot overflow `int`.
    const ssize_t max_blocks_per_seq  = cdiv(session_len_trunc_, model_.attn_param().cache_block_seq_len);
    const ssize_t max_batch_block_num = (ssize_t)param_.max_batch_size * max_blocks_per_seq;

    // Staging buffers filled by `Setup`: block pointers of all active sequences and
    // per-sequence prefix offsets into them (hence `max_batch_size + 1` entries)
    block_ptrs_buf_         = {max_batch_block_num, kCPUpinned};
    block_ptrs_offsets_buf_ = {param_.max_batch_size + 1, kCPUpinned};
}

void Engine::Impl::CreateSequenceManager()
{
    const auto cache_block_seq_len = model_.attn_param().cache_block_seq_len;

    const auto& model_param = model_.model_param();

    const auto get_free_size = [&] {  //
        size_t free{}, total{};
        check_cuda_error(cudaMemGetInfo(&free, &total));
        return AllReduce(tp_group_, free, comm::RedOp::kMin);
    };

    seq_mgr_ = std::make_unique(model_param,
                                                 dtype_,
                                                 cache_block_seq_len,
                                                 param_.attn_tp_size,
                                                 param_.max_batch_size,
                                                 param_.cache_max_block_count,
                                                 param_.cache_chunk_size,
                                                 param_.enable_prefix_caching,
                                                 tp_rank_,
                                                 param_.attn_cp_size,
                                                 core::Context::alloc(kDEVICE),
                                                 get_free_size);

    const auto max_cached_tokens = seq_mgr_->max_block_count() * (size_t)cache_block_seq_len * param_.attn_cp_size;
    session_len_trunc_           = std::min(max_cached_tokens, (size_t)param_.session_len);
    TM_LOG_INFO("max cached tokens: %lld", max_cached_tokens);
    if (session_len_trunc_ != param_.session_len) {
        TM_LOG_WARNING("`session_len` truncated to %d due to limited KV cache memory", session_len_trunc_);
    }
}

// Marks requests that must not proceed by setting their `ec`:
//   - kConflict:      the same ID occurs more than once across the current batch,
//                     the incoming kill requests and the incoming infer requests
//   - kInconsistency: a non-final infer request while the model uses linear
//                     attention, or a request incompatible with prefix caching
//   - kCancel:        an incoming infer request canceled before it started
// Only the first failure of a request is recorded, except that a pending
// cancellation always overrides the code.
void Engine::Impl::Validate(Requests& infer_reqs, Requests& kill_reqs)
{
    std::pmr::monotonic_buffer_resource    mbr;
    std::pmr::unordered_map<uint64_t, int> occur(&mbr);

    const bool has_linear_attention = HasLinearAttention(model_.model_param());

    auto count = [&occur](const auto& reqs) {
        for (const auto& r : reqs) {
            ++occur[r->id];
        }
    };

    auto validate = [&](auto& reqs, const char* type, bool is_infer) {
        for (const auto& r : reqs) {
            if (occur[r->id] > 1) {
                TM_LOG_ERROR("Skip conflicting %s request for ID %lu", type, r->id);
                r->ec = Request::kConflict;
            }
            if (!r->ec && is_infer && has_linear_attention && !r->session.end_flag) {
                TM_LOG_ERROR("Skip inconsistent %s request for ID %lu. Linear attention only supports stateless "
                             "requests",
                             type,
                             r->id);
                r->ec = Request::kInconsistency;
            }
            // Skip requests that already failed, so their original error code
            // (e.g. kConflict) is kept and they are not reported twice
            if (!r->ec && param_.enable_prefix_caching) {
                if (r->session.step != 0) {
                    // Prefix caching is incompatible with interactive mode
                    TM_LOG_ERROR("Skip inconsistent %s request for ID %lu step %d", type, r->id, r->session.step);
                    r->ec = Request::kInconsistency;
                }
                else if (r->gen_cfg.output_logits == GenerationConfig::kAll
                         || r->gen_cfg.output_last_hidden_state == GenerationConfig::kAll) {
                    // Prefix caching is incompatible with outputting all tokens' logits or last_hidden_state
                    TM_LOG_ERROR("Skip inconsistent %s request for ID %lu. It cannot output logits or "
                                 "last_hidden_states for all tokens",
                                 type,
                                 r->id);
                    r->ec = Request::kInconsistency;
                }
            }
        }
    };

    // IDs already in flight in the current batch
    for (const auto& s : states_) {
        for (int i = 0; i < s.size(); ++i) {
            if (s.rc[i]) {
                ++occur[s.rc[i]->req->id];
            }
        }
    }

    count(kill_reqs);
    count(infer_reqs);

    validate(kill_reqs, "kill", false);
    validate(infer_reqs, "infer", true);

    // New requests that never get a chance to start
    for (auto& r : infer_reqs) {
        if (r && r->cancel_flag.load(std::memory_order_acquire) == -1) {
            r->ec = Request::kCancel;
        }
    }
}

// Returns the positions in the current batch whose request has `cancel_flag == -1`
// (presumably "cancellation requested"). Empty (null) slots are skipped.
vector Engine::Impl::GetCanceled()
{
    auto& s = states_.at(0);

    vector idxs;
    for (int i = 0; i < s.size(); ++i) {  // current batch
        const auto& r = s.rc[i];
        if (r && r->req->cancel_flag.load(std::memory_order_acquire) == -1) {
            idxs.push_back(i);
        }
    }
    return idxs;
}

// Handles kill requests. Each request that passed validation has its sequence
// erased; the code becomes kInvalid if no such sequence exists. The request's
// `end_cb`, if set, is queued in `signals` with the resulting code. Null entries
// are ignored.
void Engine::Impl::Kill(const Requests& kills, vector& signals)
{
    for (auto& r : kills) {
        if (r) {
            int ec = r->ec;
            if (!ec) {
                if (!seq_mgr_->Erase(r->id)) {
                    ec = Request::kInvalid;
                }
            }
            signals.push_back([=] { r->end_cb ? r->end_cb(ec) : void(); });
        }
    }
}

// Detaches request cache `c` from its sequence.
// - Final request (`end_flag`): calls `CacheGeneration` (skipped during warm-up
//   and for sequences already `kCached`), then erases the sequence.
// - Otherwise the sequence is kept for a follow-up request. If it carries
//   recurrent (linear-attention) state that disagrees with the KV cache length,
//   both are invalidated; if not, the sequence is updated and unlocked.
void Engine::Impl::Interrupt(RequestCache& c)
{
    auto& s = *TM_CHECK_NOTNULL(c.seq);
    if (c.req->session.end_flag) {
        if (!is_warm_up_ && s.status != Sequence::kCached) {  // At least `Locked` status is required for caching
            seq_mgr_->CacheGeneration(s);
        }
        TM_CHECK(seq_mgr_->Erase(c.req->id));
    }
    else {
        if (s.recurrent_states && c.seq_len != s.cache_len) {
            // Cast: the id's underlying type is not guaranteed to be `unsigned long long`
            TM_LOG_WARNING(
                "[Engine][Interrupt] Invalidating cache for ID %llu due to linear-state/cache mismatch (%d vs %d)",
                (unsigned long long)s.id,
                c.seq_len,
                s.cache_len);
            seq_mgr_->InvalidateStatesAndCache(s);
        }
        else {
            seq_mgr_->UpdateAndSetUnlock(s);
        }
    }
    c.seq = nullptr;
}

// Cancels the requests at `indices` in the current batch. Each is marked done,
// detached via `Interrupt`, and a kCancel status update carrying its current
// `seq_len` is queued. The slot is left null (dropped by the next `Schedule`)
// and counted in `finish`.
void Engine::Impl::Cancel(vector& indices, vector& signals)
{
    auto& s = states_.at(0);
    for (const auto& i : indices) {
        auto& c = TM_CHECK_NOTNULL(s.rc[i]);
        c->done = true;
        Interrupt(*c);
        // `req` is moved into the signal only after `Interrupt` has used it
        signals.push_back([r = std::move(c->req), l = c->seq_len] {  //
            UpdateState(*r, Request::kCancel, l);
        });
        c.reset();
        s.finish += 1;
    }
}

// Admits validated inference requests into the current batch state.
// For each request: reject early on an error code, an over-long input, or a
// missing/uncreatable sequence; resolve the resume `step`; reject on length,
// prefix-caching or linear-state mismatches; otherwise build a request cache that
// holds history + input tokens and a clamped `max_seq_len`. Then all candidates
// go through the model's `kAdd` checks. Accepted ones are appended to the state;
// the rest are interrupted and reported with their status.
void Engine::Impl::Accept(const Requests& rs, vector& signals)
{
    auto& s = states_.at(0);

    vector> incoming;
    incoming.reserve(rs.size());

    for (const auto& r : rs) {

        if (r->ec) {
            signals.push_back([r] { UpdateState(*r, r->ec, 0); });
            continue;
        }

        const int input_len = r->inputs.at("input_ids").shape(0);

        if (input_len > session_len_trunc_) {
            signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });
            continue;
        }

        // New session: create a sequence; continuation: look up the existing one
        auto ptr = r->session.start_flag ? seq_mgr_->Create(r->id) : seq_mgr_->Get(r->id);
        if (!ptr) {
            signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); });
            continue;
        }

        // Resume position: a negative step means "append after all known tokens";
        // a step beyond the known tokens is clamped (with a warning on rank 0).
        // NOTE: the local `s` here shadows the batch state `s` above.
        const int step = [&] {
            int s = r->session.step;
            if (s < 0) {
                s = ptr->tokens.size();
            }
            else if (s > ptr->tokens.size()) {
                if (tp_rank_ == 0) {
                    TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id);
                }
                s = ptr->tokens.size();
            }
            return s;
        }();

        // NOTE(review): on this and the following rejection paths, nothing visible
        // here releases a sequence just created by `Create`, or the linear-state
        // slot acquired below. Confirm SequenceManager reclaims them.
        if (step + input_len > session_len_trunc_) {
            signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });
            continue;
        }

        if (step && param_.enable_prefix_caching) {
            // step not supported in prefix-caching mode
            signals.push_back([r] { UpdateState(*r, Request::kInconsistency, 0); });
            continue;
        }

        auto& seq = *ptr;
        seq_mgr_->AcquireLinearStateSlot(seq);

        // Recurrent state cannot be rewound: it is only valid at exactly `cache_len`
        if (seq.recurrent_states) {
            if (step != seq.cache_len) {
                signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); });
                continue;
            }
        }

        auto c = std::make_unique(r, seq);

        // Rewind the sequence to `step`, dropping later tokens and cache entries
        if (step < seq.tokens.size()) {
            seq.tokens.resize(step);
            seq.cache_len = std::min(seq.cache_len, step);
        }

        c->step0 = step;

        // const int* input_ids = r->inputs.at("input_ids").data();
        auto& input_ids = r->inputs.at("input_ids");

        int* token_ids = c->token_ids = r->output_ids.data();

        // `output_ids` receives history tokens followed by the new input tokens
        /// TODO: move this somewhere else
        token_ids = std::copy_n(seq.tokens.data(), seq.tokens.size(), token_ids);
        token_ids = std::copy_n(input_ids.data(), input_len, token_ids);

        c->prompt_len = c->seq_len = token_ids - c->token_ids;  // all known tokens

        // Only prefix cache needs prompt data
        if (param_.enable_prefix_caching && input_len && r->session.start_flag) {
            seq.prompt.insert(seq.prompt.end(), input_ids.data(), input_ids.data() + input_len);
        }

        // dbg(seq.cache_len, seq.tokens.size(), input_len, c->seq_len);

        // Clamp the total length to the (possibly truncated) session length.
        // NOTE(review): the sum is computed in `int`; an extremely large
        // `max_new_tokens` could overflow before the clamp. Consider 64-bit arithmetic.
        int max_seq_len = c->prompt_len + c->gen_cfg.max_new_tokens;
        if (max_seq_len > session_len_trunc_) {
            max_seq_len = session_len_trunc_;
            if (tp_rank_ == 0) {
                const int trunc_output_len = max_seq_len - c->prompt_len;
                // clang-format off
                TM_LOG_WARNING("[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d",
                    (long)seq.id, c->prompt_len, c->gen_cfg.max_new_tokens, session_len_trunc_, trunc_output_len);
                // clang-format on
            }
        }
        c->max_seq_len = max_seq_len;

        incoming.push_back(std::move(c));
    }

    // Pass raw pointers to the candidates to the model through a CPU buffer
    Buffer_ buf(incoming.size(), kCPU);
    for (int i = 0; i < incoming.size(); ++i) {
        buf[i] = incoming[i].get();
    }

    // This includes checks from all modules handling `Add` operation
    Run(BatchOp::kAdd, -1, TensorMap{{"requests", buf}});

    // Status 0 means every module accepted the request
    for (auto& x : incoming) {
        if (x->status == 0) {
            s.rc.push_back(std::move(x));
        }
        else {
            Interrupt(*x);
            signals.push_back([r = x->req, ec = x->status] {  //
                UpdateState(*r, ec, 0);
            });
        }
    }
}

// Decides which sequences run in the next step and reorders the batch.
// 1. Snapshot every live request cache: its sequence, pre-schedule status,
//    priority, required context length and `alpha`.
// 2. Let the sequence manager materialize (allocate) resources under the token budgets.
// 3. Partition positions into active (existing | swap-in) and inactive
//    (swap-out | rest), keeping relative order (stable partitions).
// 4. Update per-request generation flags, then permute `s.rc` so active
//    sequences come first, recording the inverse mapping in `s.perm`.
// Null slots (e.g. left by Cancel) are dropped by the reordering.
void Engine::Impl::Schedule()
{
    auto& s = states_.at(0);

    vector  sequences;
    vector status;
    vector              context_length;
    vector              alpha;
    vector         priorities;
    vector    cache;
    vector              inv;

    for (int i = 0; i < s.size(); ++i) {
        // skip invalid positions
        if (const auto& c = s.rc[i]) {
            cache.push_back(c.get());
            sequences.push_back(c->seq);
            status.push_back(c->seq->status);
            priorities.push_back(c->req->unique_id);
            context_length.push_back(c->seq_len + c->beta /* plus draft tokens */);
            alpha.push_back(c->alpha);
            TM_CHECK(c->seq->status == Sequence::kActive || c->alpha == 0) << c->seq->status << " " << c->alpha;
            inv.push_back(i);
            c->input_len = c->history_len = 0;
            // dbg(c->request->id, c->seq_len, c->sequence.cache_len, c->alpha, c->beta, c->is_decoding,
            // c->is_generate);
        }
    }

    // dbg("Schedule");

    seq_mgr_->Materialize(
        sequences, context_length, alpha, priorities, param_.max_forward_token_num, param_.max_context_token_num);

    // `idxs` are positions into the compacted arrays above; `inv` maps them back to `s.rc`
    vector idxs(sequences.size());
    std::iota(idxs.begin(), idxs.end(), 0);

    subrange active{idxs.begin(), std::stable_partition(idxs.begin(), idxs.end(), [&](int i) {
                        return sequences[i]->status == Sequence::kActive;  // IS active
                    })};

    TM_CHECK(sequences.empty() || !active.empty()) << "No enough blocks";

    if (is_warm_up_) {
        // Avoid extra iteration for warm up request in async mode (force inactivate)
        active = {active.begin(), std::stable_partition(active.begin(), active.end(), [&](int i) {  //
                      return alpha[i] == 0;
                  })};
    }

    subrange inactive{active.end(), idxs.end()};

    subrange existing{active.begin(), std::stable_partition(active.begin(), active.end(), [&](int i) {
                          return status[i] == Sequence::kActive;  // WAS active in active
                      })};

    subrange swap_in{existing.end(), active.end()};

    subrange swap_out{inactive.begin(), std::stable_partition(inactive.begin(), inactive.end(), [&](int i) {
                          return status[i] == Sequence::kActive;  // WAS active in inactive
                      })};

    // |<-- existing -->|<-- swap-in -->|<- swap-out ->|
    // |<----------- active ----------->|<------- inactive ----->|

    // Newly (re)activated sequences restart with cleared generation flags
    for (auto i : swap_in) {
        cache[i]->autoregres = {};
        cache[i]->generating = {};
    }

    // First-schedule timestamp: only set if still 0, so later swap-ins keep the original
    if (param_.enable_metrics) {
        for (auto i : swap_in) {
            if (auto& m = cache[i]->req->metrics; TM_LIKELY(m)) {
                int64_t expected = 0;
                m->scheduled_time.compare_exchange_strong(
                    expected, RequestMetrics::timestamp(), std::memory_order_relaxed);
            }
        }
    }

    for (auto i : existing) {
        if (cache[i]->generating) {
            cache[i]->autoregres = true;
        }
    }

    // A sequence is generating once this step's input reaches the end of all known
    // (plus draft) tokens. NOTE: this `s` shadows the batch state `s`.
    for (auto i : active) {
        auto& s = *sequences[i];
        auto& c = *cache[i];
        if (s.cache_len + c.alpha + s.input_length == c.seq_len + c.beta) {
            c.generating = true;
        }
    }

    // move partially prefilled sequences to the back
    subrange partial{std::stable_partition(active.begin(), active.end(), [&](int i) { return cache[i]->generating; }),
                     active.end()};
    TM_CHECK_LE(partial.size(), 1);

    // dbg(inv);

    vector> rc(idxs.size());
    vector                      perm(idxs.size());
    for (int i = 0; i < idxs.size(); ++i) {
        perm[i] = inv[idxs[i]];              // inverse map to original indices
        rc[i]   = std::move(s.rc[perm[i]]);  // move the request cache to its new position
    }
    s.rc.swap(rc);
    s.perm.swap(perm);

    for (auto& c : s.rc) {
        /// ! input_length not updated for inactive seqs
        c->input_len   = c->seq->input_length;
        c->history_len = c->seq->cache_len;
        // dbg(c->request->id,
        //     c->seq_len,
        //     c->history_len,
        //     c->input_len,
        //     c->alpha,
        //     c->beta,
        //     c->is_decoding,
        //     c->is_generate);
    }

    // Remember the previous active count in `bs0` before publishing the new one
    s.bs0     = std::exchange(s.active, active.size());
    s.swapout = swap_out.size();
    s.finish  = 0;
}

void Engine::Impl::Setup(BatchData& d)
{
    auto& st = states_.at(0);

    d.rc.resize(st.active);
    std::copy_n(st.rc.begin(), st.active, d.rc.begin());

    block_ptrs_offsets_buf_[0] = 0;
    auto block_ptrs            = block_ptrs_buf_.data();
    for (int i = 0; i < st.active; ++i) {
        const auto& s                  = *st.rc[i]->seq;
        block_ptrs_offsets_buf_[i + 1] = block_ptrs_offsets_buf_[i] + s.blocks.size();
        block_ptrs = std::transform(s.blocks.cbegin(), s.blocks.cend(), block_ptrs, [&](int block_id) {
            return seq_mgr_->GetBlockPtr(block_id);
        });
    }

    d.bs0 = st.bs0;
    d.bsz = st.active;

    d.perm = {d.bsz, kCPU};
    std::copy_n(st.perm.data(), d.bsz, d.perm.data());

    // dbg(d.bs0, d.bsz, d.perm);

    BatchCopy copy{};

    TensorMap env{{"batch", d.buf()},
                  {"copy", copy.buf()},
                  {"block_ptrs_offsets", block_ptrs_offsets_buf_},
                  {"block_ptrs", block_ptrs_buf_}};

    Run(BatchOp::kSetup, d.phase, env);

    // dbg(copy);
    copy.Run();

    d.local_token_num.resize(dp_size_);
    d.local_token_num[dp_rank_] = *env.at("token_num").data();
    if (dp_size_ > 1) {
        AllGather(dp_group_, d.local_token_num.data(), 1);
    }
    d.global_token_num = std::accumulate(d.local_token_num.begin(), d.local_token_num.end(), 0);
    // dbg(dp_group_->rank(), d.local_token_num, d.global_token_num);
}

void Engine::Impl::Update(BatchData& b, std::vector& signals)
{
    auto& s = states_.at(0);

    BatchCopy copy;

    TensorMap env{{"batch", b.buf()}, {"copy", copy.buf()}};

    // Copy outputs to host buffers
    Run(BatchOp::kFetch, b.phase, env);

    copy.Run();

    core::Context::stream().Sync();

    //
    Run(BatchOp::kUpdate, b.phase, env);

    Buffer_ finished        = env.at("finished").buffer();
    Buffer_ generating      = env.at("generating").buffer();
    Buffer_  output_ids      = env.at("output_ids").buffer();
    Buffer_  sequence_length = env.at("sequence_length").buffer();

    env = {};

    vector sequences_to_cache;

    for (int i = 0; i < b.rc.size(); ++i) {
        // In async mode, `seq` may be nullptr when the request is done
        if (auto& c = *b.rc[i]; c.seq) {
            if (auto& s = *c.seq; generating[i]) {
                c.token_ids[c.seq_len] = output_ids[i];
                c.seq_len              = sequence_length[i];
                s.cache_len            = sequence_length[i] - 1;
                if (const int new_tokens = c.seq_len - s.tokens.size()) {
                    s.tokens.insert(s.tokens.end(), c.token_ids + c.seq_len - new_tokens, c.token_ids + c.seq_len);
                }
                if (TM_UNLIKELY(finished[i])) {
                    signals.push_back([r = c.req, l = c.seq_len] {  //
                        UpdateState(*r, Request::kFinish, l);
                    });
                }
                else if (c.req->stream_output) {
                    signals.push_back([r = c.req, l = c.seq_len] {  //
                        UpdateState(*r, Request::kOk, l);
                    });
                }
            }
            else {
                s.cache_len = sequence_length[i];
            }
            c.done |= finished[i];
            if (c.seq->status != Sequence::kCached) {  // At least `Locked` status is required for caching
                sequences_to_cache.push_back(c.seq);
            }
            // dbg(c.seq_len, c.sequence.cache_len, c.alpha, c.beta, c.is_decoding, c.is_generate);
        }
    }

    if (!is_warm_up_) {
        seq_mgr_->CachePrompt(sequences_to_cache, sequences_to_cache.size());
    }

    b.rc.clear();

    if (async_) {
        const int size = s.active + s.swapout;
        for (int i = 0; i < size; ++i) {
            auto& c = *s.rc[i];
            if (i < s.active) {
                c.alpha = c.input_len;
                c.beta  = c.generating;
            }
            else {
                // Just got swaped-out
                c.alpha = c.beta = 0;
            }
        }
    }

    for (auto& x : s.rc) {
        if (TM_UNLIKELY(x->done)) {
            Interrupt(*x);
            x.reset();
            s.finish += 1;
        }
    }
}

void Engine::Impl::InternalThreadEntry()
{
    check_cuda_error(cudaSetDevice(device_id_));

    auto stream = Stream::create();

    core::ContextGuard ctx{stream, Allocator(kCPU), Allocator(stream, false)};

    unique_ptr d = std::make_unique(0);

    for (unsigned i = 1; i < data_.size(); ++i) {
        inbound_.push(std::make_unique(i));
    }

    while (true) {

        shared_ptr rs;

        auto& st = states_.at(0);

        if (tp_rank_ == 0) {
            rs = std::make_shared();

            const int  n_free   = param_.max_batch_size - st.size() + st.finish;
            const bool blocking = n_free == param_.max_batch_size;

            gateway_.pop(rs->infer, rs->kill, n_free, blocking, rs->abort, dp_group_, queue_id_);

            Validate(rs->infer, rs->kill);

            rs->cancel = GetCanceled();
        }

        if (st.size() - st.finish == 0 && tp_group_->is_same_process()) {
            // Only thread comm has blocking sync
            tp_group_->Sync(true);
        }

        if (tp_group_->n_ranks() > 1) {
            Broadcast(tp_group_, rs, 0);
        }

        if (rs->abort) {
            TM_LOG_INFO("[Engine] stop requested.");
            break;
        }

        vector signals;

        Kill(rs->kill, signals);

        Accept(rs->infer, signals);

        Cancel(rs->cancel, signals);

        gateway_.notify(std::move(signals), tp_rank_ == 0);

        int n_active = st.size() - st.finish;

        TM_CHECK_GE(n_active, 0);

        n_active = AllReduce(dp_group_, n_active, comm::RedOp::kSum);

        if (n_active) {

            Schedule();

            UpdateScheduleMetrics();

            Setup(*d);

            d->ready.Record(core::Context::stream());

            // auto future = (d->promise = {}).get_future();

            outbound_.push(std::move(d));

            if (!inbound_.pop(d)) {
                break;
            }

            // Must assume `d` is not the same one as above
            TM_CHECK_NOTNULL(d);

            core::Context::stream().Wait(d->done);

            Update(*d, signals);

            gateway_.notify(std::move(signals), tp_rank_ == 0);

            // if (future.valid()) {
            //     future.get().Sync();
            // }
        }

        // dbg("=========================================================================");
    }
}

Engine::~Engine() = default;

Engine::Engine()                  = default;
Engine::Engine(Engine&&) noexcept = default;
Engine& Engine::operator=(Engine&&) noexcept = default;

Engine::Engine(DataType      dtype,
               EngineParam   param,
               LanguageModel model,
               Context&      ctx,
               Gateway&      gateway,
               int           device_id,
               int           dp_rank,
               int           phases):
    impl_{std::make_unique(dtype, param, std::move(model), ctx, gateway, device_id, dp_rank, phases)}
{
}

void Engine::Start()
{
    return impl_->Start();
}

void Engine::Impl::UpdateScheduleMetrics()
{
    if (param_.enable_metrics) {
        const auto& [total, active, cached] = seq_mgr_->seq_stats();

        auto m = std::make_shared();

        m->total_seqs   = total;
        m->active_seqs  = active;
        m->waiting_seqs = total - active;

        m->total_blocks  = seq_mgr_->total_count();
        m->active_blocks = seq_mgr_->active_count();
        m->cached_blocks = seq_mgr_->cached_count();
        m->free_blocks   = seq_mgr_->free_count();

        std::atomic_store_explicit(&metrics_, std::move(m), std::memory_order_release);
    }
}

shared_ptr Engine::GetScheduleMetrics()
{
    if (impl_->param_.enable_metrics) {
        return std::atomic_load_explicit(&impl_->metrics_, std::memory_order_acquire);
    }
    return {};
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/engine.h
================================================

#pragma once

#include 

#include "src/turbomind/engine/gateway.h"

#include "src/turbomind/models/language_model.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

struct ScheduleMetrics;

class Engine {
public:
    ~Engine();

    Engine();
    Engine(Engine&&) noexcept;
    Engine& operator=(Engine&&) noexcept;

    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

    Engine(DataType      dtype,
           EngineParam   param,
           LanguageModel model,
           Context&      ctx,
           Gateway&      gateway,
           int           device_id,
           int           queue_id,
           int           phases);

    void Start();

    std::shared_ptr GetScheduleMetrics();

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/gateway.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/engine/gateway.h"
#include "src/turbomind/engine/request_queue.h"

namespace turbomind {

Gateway::Gateway(int size, std::function()> ctx_factory):
    size_{size}, queues_(size_), dp_thr_{1}, ctx_factory_{ctx_factory}, next_{0}
{
    for (int i = 0; i < size_; ++i) {
        queues_[i] = std::make_unique();
    }

    signal_thread_ = std::thread(&Gateway::signal_thread_entry, this);
}

void Gateway::shutdown()
{
    for (auto& q : queues_) {
        q->close();
    }

    signal_buffer_.close();
    signal_thread_.join();
}

void Gateway::push(std::shared_ptr r)
{
    int rank = -1;

    if (TM_UNLIKELY(!r->session.start_flag)) {
        // route to corresponding rank
        rank = binding_.find(r->session.id);
    }
    else if (TM_LIKELY(size_)) {
        rank = next_.fetch_add(1, std::memory_order_relaxed) % size_;
    }
    else {
        TM_LOG_ERROR("[Gateway] No queues available for submitting the request");
        notify({[r = std::move(r)] { UpdateState(*r, Request::kNoQueue, 0); }});
        return;
    }

    if (TM_LIKELY(rank >= 0)) {
        queues_[rank]->push({std::move(r)});
    }
    else {
        TM_LOG_ERROR("[Gateway] Failed to find a binded queue for %lu", r->session.id);
        notify({[r = std::move(r)] { UpdateState(*r, Request::kInvalid, 0); }});
    }
}

void Gateway::pop(std::vector>& infer_reqs,
                  std::vector>& kill_reqs,
                  unsigned                               max_infer,
                  bool                                   blocking,
                  bool&                                  abort,
                  comm::HostComm&                        dp_group,
                  int                                    qid)
{
    TM_CHECK_GE(qid, 0);

    auto& q = *queues_.at(qid);

    infer_reqs.clear();
    kill_reqs.clear();

    if (dp_group->n_ranks() == 1) {
        q.pop(infer_reqs, kill_reqs, max_infer, blocking, abort);
    }
    else {
        union {
            uint16_t data[2];
            uint32_t value;
        };
        while (true) {
            q.pop(infer_reqs, kill_reqs, max_infer, false, abort);
            data[0] = !(blocking && infer_reqs.empty() && kill_reqs.empty());  // ready?
            data[1] = abort;
            value   = comm::AllReduce(dp_group, value, comm::RedOp::kSum);
            if (data[0] >= dp_thr_ || data[1]) {
                break;
            }
        }
        abort = data[1];
    }

    // Assign a monotonic increasing id for each infer request
    q.assign_unique_ids(infer_reqs);

    // Bind for stateful inference
    std::vector bind_ids;
    for (const auto& r : infer_reqs) {
        if (r->session.start_flag && !r->session.end_flag) {  // started but not ended
            bind_ids.push_back(r->session.id);
        }
    }

    /// TODO: fix qid <-> rank mapping
    if (!bind_ids.empty()) {
        binding_.bind(bind_ids, qid);
    }

    // Unbind for stateful kill
    std::vector unbind_ids;
    for (const auto& r : kill_reqs) {
        unbind_ids.push_back(r->session.id);
    }
    if (!unbind_ids.empty()) {
        binding_.unbind(unbind_ids, qid);
    }
}

void Gateway::cancel(std::shared_ptr r)
{
    // {-1: canceled, 0: queued, 1: active}
    if (r->cancel_flag.exchange(-1, std::memory_order_acq_rel) == 0) {
        notify({[r = std::move(r)] {  //
            UpdateState(*r, Request::kCancel, 0);
        }});
    }
    else {
        // request is picked up by engine
    }
}

void Gateway::kill(std::shared_ptr r)
{
    if (auto rank = binding_.find(r->session.id); rank >= 0) {
        queues_[rank]->kill(std::move(r));
    }
    else {
        TM_LOG_ERROR("[Gateway] Failed to find a binded queue for %lu", r->session.id);
        notify({[r = std::move(r)] {  //
            UpdateState(*r, Request::kInvalid, 0);
        }});
    }
}

void Gateway::notify(std::vector signals, bool pred)
{
    if (pred) {
        signal_buffer_.push(std::move(signals));
    }
}

void Gateway::signal_thread_entry() noexcept
{
    while (true) {
        bool                abort{};
        std::vector signals = signal_buffer_.take_all(abort);
        if (abort) {
            break;
        }
        else {
            auto ctx = ctx_factory_();
            for (const auto& s : signals) {
                s();
            }
        }
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/gateway.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/engine/request.h"
#include "src/turbomind/engine/request_queue.h"
#include "src/turbomind/engine/signal_buffer.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

class SequenceBinding {
public:
    int find(uint64_t seq_id)
    {
        std::lock_guard lock{mutex_};
        if (auto it = map_.find(seq_id); it != map_.end()) {
            return it->second;
        }
        return -1;
    }

    void bind(const std::vector& seq_ids, int rank)
    {
        std::lock_guard lock{mutex_};
        for (const auto& x : seq_ids) {
            if (auto [it, success] = map_.emplace(x, rank); !success) {
                TM_LOG_WARNING("[TM][Gateway] Duplicated binding for %lu, %d vs %d", x, rank, it->second);
            }
        }
    }

    void unbind(const std::vector& seq_ids, int rank)
    {
        std::lock_guard lock{mutex_};
        for (const auto& x : seq_ids) {
            auto it = map_.find(x);
            if (it == map_.end()) {
                TM_LOG_WARNING("[TM][Gateway] No entry found for unbinding %lu, %d", x, rank);
            }
            else if (it->second != rank) {
                TM_LOG_WARNING("[TM][Gateway] Mismatched entry for unbinding %lu, %d vs %d", x, rank, it->second);
            }
            else {
                map_.erase(it);
            }
        }
    }

private:
    std::mutex                        mutex_;
    std::unordered_map map_;
};

class Gateway {
public:
    Gateway(int size, std::function()> ctx_factory);

    void shutdown();

    void push(std::shared_ptr r);

    void pop(std::vector>& infer_reqs,
             std::vector>& kill_reqs,
             unsigned                               max_infer,
             bool                                   blocking,
             bool&                                  abort,
             comm::HostComm&                        dp_group,
             int                                    qid);

    void cancel(std::shared_ptr r);

    void kill(std::shared_ptr r);

    void notify(std::vector signals, bool pred = true);

    void set_threshold(int value)
    {
        TM_LOG_INFO("set threshold %d -> %d", dp_thr_, value);
        dp_thr_ = value;
    }

private:
    void signal_thread_entry() noexcept;

private:
    const int size_;

    int dp_thr_;

    std::vector> queues_;

    std::function()> ctx_factory_;

    SignalBuffer signal_buffer_;
    std::thread  signal_thread_;

    SequenceBinding binding_;

    std::atomic next_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/model_executor.cc
================================================

#include "src/turbomind/engine/model_executor.h"

#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/copy.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/models/language_model.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/anomaly_handler.h"

// #include "dbg.h"

namespace turbomind {

using std::shared_ptr;
using std::unique_ptr;

struct ModelExecutor::Impl {

    LanguageModel& model_;
    LlamaLinear&   linear_;

    const int device_id_;

    Queue>& inbound_;
    Queue>& outbound_;

    std::thread internal_thread_;

    void InternalThreadEntry()
    {
        check_cuda_error(cudaSetDevice(device_id_));

        Stream    stream  = Stream::create();
        Allocator h_alloc = Allocator(kCPU);
        Allocator d_alloc = Allocator(stream, false);

        AnomalyHandler::instance().Init(0, 1000, 0, 1000, stream.handle());

        core::ContextGuard ctx{stream, h_alloc, d_alloc};

        unique_ptr d;

        while (inbound_.pop(d)) {
            TM_CHECK_NOTNULL(d);
            core::Context::stream().Wait(d->ready);
            Run(*d);
            d->done.Record(core::Context::stream());
            outbound_.push(std::move(d));
        }
    }

    void Run(BatchData& d)
    {
        auto batch = &d;

        BatchCopy copy;
        TensorMap env{{"batch", d.buf()}, {"copy", copy.buf()}};

        model_.Run(BatchOp::kPrepare, d.phase, env);
        // dbg(copy);
        copy.Run();

        model_.Run(BatchOp::kForward, d.phase, env);

        model_.Run(BatchOp::kUnprep, d.phase, env);
        // dbg(copy);
        copy.Run();

        // TM_CHECK(0);
        AnomalyHandler::instance().Summarize([](...) {});
        AnomalyHandler::instance().Reset();
    }

    Impl(LanguageModel&                model,
         Context&                      context,
         int                           device_id,
         Queue>& inbound,
         Queue>& outbound):
        model_{model}, linear_{*context.linear}, device_id_{device_id}, inbound_{inbound}, outbound_{outbound}
    {
    }

    ~Impl()
    {
        if (internal_thread_.joinable()) {
            internal_thread_.join();
        }
    }

    void Start()
    {
        internal_thread_ = std::thread(&Impl::InternalThreadEntry, this);
    }
};

ModelExecutor::~ModelExecutor() = default;

ModelExecutor::ModelExecutor()                         = default;
ModelExecutor::ModelExecutor(ModelExecutor&&) noexcept = default;
ModelExecutor& ModelExecutor::operator=(ModelExecutor&&) noexcept = default;

ModelExecutor::ModelExecutor(LanguageModel&                model,
                             Context&                      context,
                             int                           device_id,
                             Queue>& inbound,
                             Queue>& outbound):
    impl_{std::make_unique(model, context, device_id, inbound, outbound)}
{
}

void ModelExecutor::Start()
{
    return impl_->Start();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/model_executor.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/core/core.h"

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/queue.h"
#include "src/turbomind/models/language_model.h"

#include "src/turbomind/models/llama/context.h"

namespace turbomind {

// Model executor for auto-regressive language models
class ModelExecutor {
public:
    ~ModelExecutor();

    ModelExecutor();
    ModelExecutor(ModelExecutor&&) noexcept;
    ModelExecutor& operator=(ModelExecutor&&) noexcept;

    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

    ModelExecutor(LanguageModel&                     model,
                  Context&                           context,
                  int                                device_id,
                  Queue>& inbound,
                  Queue>& outbound);

    void Start();

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/model_request.cc
================================================


#include 
#include 
#include 
#include 
#include 

#include "xgrammar/compiler.h"
#include "xgrammar/matcher.h"

#include "src/turbomind/engine/model_request.h"
#include "src/turbomind/engine/request.h"
#include "src/turbomind/utils/constant.h"
#include "src/turbomind/utils/metrics.h"

namespace turbomind {

ModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):
    gateway_{gateway},
    data_type_{data_type},
    session_len_{session_len},
    vocab_size_{vocab_size},
    hidden_dim_{hidden_dim}
{
}

void ModelRequest::Cancel()
{
    // request is finished if lock failed
    if (auto r = request_.lock()) {
        gateway_->cancel(std::move(r));
    }
}

void ModelRequest::End(std::function cb, uint64_t session_id)
{
    auto r = std::make_shared();

    r->id = r->session.id = session_id;
    r->session.kill_flag  = true;

    r->end_cb = std::move(cb);

    gateway_->kill(std::move(r));
}

auto ModelRequest::Forward(InputParam param, std::function cb) -> OutputParam
{
    inputs_  = std::make_shared();
    outputs_ = std::make_shared();

    auto add = [](auto& dest, auto key, auto dtype, auto where, auto shape, auto&&... dims) {
        Layout shape_;
        if constexpr (std::is_integral_v) {
            shape_ = {shape, dims...};
        }
        else {
            shape_ = {shape.cbegin(), shape.cend()};
        }
        dest->emplace(key, Tensor{shape_, dtype, where});
    };

    auto& inputs = *param.tensors;

    TM_CHECK_EQ(inputs.at("input_ids").ndim(), 1);

    const int input_len  = inputs.at("input_ids").shape(0);
    const int output_len = param.gen_cfg.max_new_tokens;

    // Max possible length of a sequence, this depends on `history_len` which isn't available here, so `session_len`
    // is used instead
    const int max_seq_len = session_len_ + 1;
    const int max_out_len = std::min(output_len, session_len_) + 1;
    // This does not include histroy length in interactive mode
    const int max_in_out_len = std::min(input_len + output_len, session_len_) + 1;

    for (auto& [k, v] : *param.tensors) {
        inputs_->emplace(k, v);
    }

    add(outputs_, "output_ids", data_type_v, kCPU, max_seq_len);
    add(outputs_, "sequence_length", data_type_v, kCPU, 1);

    if (param.gen_cfg.output_logits) {
        const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len;
        add(outputs_, "logits", data_type_, kCPU, len, vocab_size_);
    }

    if (param.gen_cfg.output_last_hidden_state) {
        const int len = param.gen_cfg.output_last_hidden_state == GenerationConfig::kAll ? max_in_out_len : max_out_len;
        add(outputs_, "last_hidden_state", data_type_, kCPU, len, hidden_dim_);
    }

    if (param.gen_cfg.output_logprobs) {
        add(outputs_, "logprob_vals", data_type_v, kCPU, max_out_len, kMaxLogProb);
        add(outputs_, "logprob_indexes", data_type_v, kCPU, max_out_len, kMaxLogProb);
        add(outputs_, "logprob_nums", data_type_v, kCPU, max_out_len);
    }

    auto r = std::make_shared();

    for (const auto& [k, v] : *inputs_) {
        r->inputs.emplace(k, v);
    }
    for (const auto& [k, v] : *outputs_) {
        r->outputs.emplace(k, v);
    }

    auto state = std::make_shared();

    auto metrics = param.enable_metrics ? std::make_shared() : nullptr;
    if (metrics) {
        metrics->enqueue_time.store(RequestMetrics::timestamp(), std::memory_order_relaxed);
        metrics->scheduled_time.store(0, std::memory_order_relaxed);
    }

    if (param.session.start_flag) {
        session_id_ = param.session.id;
    }

    r->id            = param.session.id;
    r->session       = param.session;
    r->gen_cfg       = param.gen_cfg;
    r->stream_output = param.stream_output;
    r->forward_cb    = std::move(cb);
    r->state         = state;
    r->metrics       = metrics;

    r->output_ids      = outputs_->at("output_ids");
    r->sequence_length = outputs_->at("sequence_length");

    if (grammar_) {
        r->grammar = std::move(grammar_);
        r->matcher = std::make_shared(*r->grammar);
    }

    // Keep a WEAK reference for canceling the request
    request_ = r;

    gateway_->push({std::move(r)});

    return OutputParam{outputs_, state, metrics};
}

void ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar)
{
    grammar_ = std::make_shared(grammar);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/model_request.h
================================================


#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/gateway.h"

namespace xgrammar {
class CompiledGrammar;
}

namespace turbomind {

class ModelRequest {
public:
    virtual ~ModelRequest() = default;

    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);

    // Cancel running request
    void Cancel();

    // Reset the channel to uninitailized state, calls `notify` when done
    void End(std::function cb, uint64_t session_id);

    struct InputParam {
        std::shared_ptr tensors;

        SessionParam     session;
        GenerationConfig gen_cfg;

        bool stream_output;
        bool enable_metrics;
    };

    struct OutputParam {
        std::shared_ptr          tensors;
        std::shared_ptr state;
        std::shared_ptr     metrics;
    };

    OutputParam Forward(InputParam param, std::function cb);

    void setGrammar(const xgrammar::CompiledGrammar& grammar);

protected:
    Gateway* const gateway_;

    const DataType data_type_;

    const int session_len_;
    const int hidden_dim_;
    const int vocab_size_;

    uint64_t session_id_;

    std::weak_ptr request_;

    std::shared_ptr inputs_;
    std::shared_ptr outputs_;

    std::shared_ptr grammar_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/queue.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

namespace turbomind {

template
class Queue {
public:
    template
    void push(X&& x)
    {
        {
            std::lock_guard lock{mutex_};
            queue_.push(std::forward(x));
        }
        cv_.notify_one();
    }

    bool pop(T& x)
    {
        std::unique_lock lock{mutex_};
        cv_.wait(lock, [&] { return !queue_.empty() || is_closed_; });
        if (is_closed_) {
            return false;
        }
        x = std::move(queue_.front());
        queue_.pop();
        return true;
    }

    void close()
    {
        {
            std::lock_guard lock{mutex_};
            is_closed_ = true;
        }
        cv_.notify_all();
    }

private:
    std::queue           queue_;
    std::mutex              mutex_;
    std::condition_variable cv_;
    bool                    is_closed_{false};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/request.cc
================================================


#include "src/turbomind/engine/request.h"

#include 

namespace turbomind {

namespace {

template
inline std::ostream& operator<<(std::ostream& os, const std::vector& vec)
{
    os << "[";
    std::copy(vec.begin(), vec.end(), std::ostream_iterator(os, ", "));
    if (!vec.empty()) {
        os.seekp(-2, std::ios_base::end);
    }
    os << "]";
    return os;
}

}  // namespace

std::ostream& operator<<(std::ostream& os, const GenerationConfig& c)
{
    os << "GenerationConfig { ";
    os << "max_new_tokens=" << c.max_new_tokens;
    os << ", min_new_tokens=" << c.min_new_tokens;
    os << ", eos_ids=" << c.eos_ids;
    os << ", stop_ids=[" << c.stop_ids[0] << ", " << c.stop_ids[1] << "]";
    os << ", bad_ids=[" << c.bad_ids[0] << ", " << c.bad_ids[1] << "]";
    os << ", top_p=" << c.top_p;
    os << ", top_k=" << c.top_k;
    os << ", min_p=" << c.min_p;
    os << ", temperature=" << c.temperature;
    os << ", repetition_penalty=" << c.repetition_penalty;
    os << ", random_seed=" << c.random_seed;
    os << ", output_logprobs=" << c.output_logprobs;
    os << ", output_hidden_states=" << c.output_last_hidden_state;
    os << ", output_logits=" << c.output_logits;
    os << " }";
    return os;
}

void UpdateState(Request& r, int status, int seq_len)
{
    try {
        auto new_state = new RequestState{status, seq_len};
        auto old_state = r.state->exchange(new_state);
        if (!old_state && r.forward_cb) {
            r.forward_cb();
        }
    }
    catch (const std::exception& e) {
        TM_LOG_ERROR("Error invoking callback for (%lu): %s", r.id, e.what());
    }
    catch (...) {
        TM_LOG_ERROR("Unknown error invoking callback for (%lu)", r.id);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/request.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/core/interval.h"
#include "src/turbomind/utils/metrics.h"

namespace xgrammar {
class GrammarMatcher;  // forward declaration
class CompiledGrammar;
}  // namespace xgrammar

namespace turbomind {

struct GenerationConfig {
    int max_new_tokens = 0;
    int min_new_tokens = 0;

    std::vector eos_ids;  // only support single token id

    std::array, 2> stop_ids;  // (token_id, offset)
    std::array, 2> bad_ids;

    int   top_k       = 1;
    float top_p       = 0.f;
    float min_p       = 0.f;
    float temperature = 1.f;

    float repetition_penalty = 1.f;

    uint64_t random_seed = 0;

    int output_logprobs = 0;

    enum OutType
    {
        kNone       = 0,
        kAll        = 1,
        kGeneration = 2
    };
    int output_last_hidden_state = 0;
    int output_logits            = 0;
};

std::ostream& operator<<(std::ostream& os, const GenerationConfig& c);

struct SessionParam {
    uint64_t id;

    int step;

    bool start_flag;
    bool end_flag;
    bool kill_flag;
};

struct RequestState {
    int status;
    int seq_len;
};

struct AtomicRequestState {

    std::atomic data_;

    static_assert(std::atomic::is_always_lock_free);

    ~AtomicRequestState()
    {
        auto data = exchange(nullptr);
    }

    std::unique_ptr exchange(RequestState* data)
    {
        return std::unique_ptr{data_.exchange(data, std::memory_order_acq_rel)};
    }
};

struct Request {
    uint64_t id;         // sequence id
    uint64_t unique_id;  // monotonic increasing

    SessionParam     session;
    GenerationConfig gen_cfg;

    bool stream_output;

    // reference to IO tensors
    TensorMap inputs;
    TensorMap outputs;
    // fast path for accessing common output buffers
    Tensor_ output_ids;
    Tensor_ sequence_length;

    std::function end_cb;

    std::atomic cancel_flag;

    std::function forward_cb;

    std::shared_ptr state;

    std::shared_ptr metrics;

    int ec = 0;  // set when disabling conflicting requests

    enum
    {
        kOk            = 0,
        kInvalid       = 1,  // Sequence not exist or both `start` & `stop` (instead of `end`) is set
        kConflict      = 2,  // Concurrent requests to the same sequence
        kBusy          = 3,  // Sequence is already running
        kInactive      = 4,  // Sequence to `stop` is not active
        kFail          = 5,  // Can't find sequence for `stop` request or internal error during inference
        kTooLong       = 6,  // history + prompt > session_len,
        kFinish        = 7,
        kCancel        = 8,
        kInconsistency = 9,   // Inconsistent request parameters, e.g. prefix caching is not allowed in interactive mode
        kNoQueue       = 10,  // No queue available for submitting the request (in current process)
    };

    std::shared_ptr grammar;
    std::shared_ptr  matcher;
};

void UpdateState(Request& r, int status, int seq_len);

class Sequence;

// Unlike `Request` which is shared by all local TP ranks, each rank has its own `RequestCache`.
struct RequestCache {
    std::shared_ptr req;
    const Sequence*          seq;  // May be NULL in `Update` (seq get erased when req is done)
    const GenerationConfig&  gen_cfg;

    RequestCache(std::shared_ptr r, const Sequence& s): req{std::move(r)}, seq{&s}, gen_cfg{req->gen_cfg} {}

    int status = Request::kOk;

    // These members may be opaque handles from individual modules (pointers to forward declared types), but we tend to
    // keep it simple as long as the complexity is manageable

    int*     token_ids    = nullptr;  // currently the `output_ids` buf of request
    uint8_t* random_state = nullptr;

    int step0       = 0;  // set at request init, constant, first prefill step
    int prompt_len  = 0;  // set at request init, constant, first decode step
    int max_seq_len = 0;  // set at request init, constant

    int hidden_states_offset = 0;  // set at request init, constant
    int logits_offset        = 0;  // set at request init, constant

    int seq_len = 0;  // set at request init, updated per step

    int input_len   = 0;  // set at schedule (set to `seq.input_len`)
    int history_len = 0;  // set at schedule (set to `seq.cache_len`)

    bool autoregres = false;  // set at schedule, `seq_len` and `input_ids` taken from the engine
    bool generating = false;  // set at schedule

    bool done = false;  // set at cancel / update, is the request finished / canceled

    int alpha = 0;  // pending growth of cache_len (draft_len + input_len)
    int beta  = 0;  // pending growth of seq_len (draft_len + {0,1})

    float rope_base = 0.f;

    Interval output_hidden_states;
    Interval output_logits;
};

template
void serdes(Archive& ar, GenerationConfig& g)
{
    // clang-format off
    ar & g.max_new_tokens;
    ar & g.min_new_tokens;
    ar & g.eos_ids;
    ar & g.stop_ids[0];
    ar & g.stop_ids[1];
    ar & g.bad_ids[0];
    ar & g.bad_ids[1];
    ar & g.top_k;
    ar & g.top_p;
    ar & g.min_p;
    ar & g.temperature;
    ar & g.repetition_penalty;
    ar & g.random_seed;
    ar & g.output_logprobs;
    ar & g.output_last_hidden_state;
    ar & g.output_logits;
    // clang-format on
}

template
void save_req_output(Archive& ar, const TensorMap& map)
{
    // clang-format off
    ar & map.size();
    for (const auto& [k, t] : map) {
        TM_CHECK(t.device().type == kCPU);
        ar & k;
        ar & t.layout();
        ar & t.dtype();
    }
    // clang-format on
}

template
void load_req_output(Archive& ar, TensorMap& map)
{
    // clang-format off
    decltype(map.size()) size;
    ar & size;
    for (int i = 0; i < size; ++i) {
        std::string k;
        Layout      layout;
        DataType    dtype;
        ar & k;
        ar & layout;
        ar & dtype;
        map.emplace(std::move(k), Tensor{layout, dtype, kCPU});
    }
    // clang-format on
}

template
void serdes(Archive& ar, Request& r)
{
    // clang-format off
    ar & r.id;
    ar & r.unique_id;
    ar & r.session;
    ar & r.gen_cfg;
    ar & r.stream_output;
    ar & r.inputs;
    if constexpr(Archive::is_loading) {
        load_req_output(ar, r.outputs);
        r.output_ids      = r.outputs.at("output_ids");
        r.sequence_length = r.outputs.at("sequence_length");
    } else {
        save_req_output(ar, r.outputs);
    }
    ar & r.ec;
    // clang-format on
}

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/request_queue.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/engine/request_queue.h"
#include "src/turbomind/engine/gateway.h"

#include "src/turbomind/engine/request.h"

namespace turbomind {

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/request_queue.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 

#include "src/turbomind/engine/request.h"

namespace turbomind {

class RequestQueue {
public:
    explicit RequestQueue(): queue_{&pool_} {}

    void push(std::shared_ptr r)
    {
        {
            std::lock_guard lock{mutex_};
            if (closed_) {
                throw std::runtime_error("Queue is closed");
            }
            queue_.push_back(std::move(r));
        }
        cv_.notify_one();
    }

    void kill(std::shared_ptr r)
    {
        {
            std::lock_guard lock{mutex_};
            if (closed_) {
                throw std::runtime_error("Queue is closed");
            }
            kill_.push_back(std::move(r));
        }
        cv_.notify_one();
    }

    void pop(std::vector>& infer_reqs,
             std::vector>& kill_reqs,
             unsigned                               max_infer,
             bool                                   blocking,
             bool&                                  abort)
    {
        std::unique_lock lock{mutex_};

        if (blocking) {
            cv_.wait(lock, [this] { return !(queue_.empty() && kill_.empty()) || closed_; });
        }

        if (closed_) {
            abort = true;
        }

        while (!queue_.empty() && infer_reqs.size() < max_infer) {
            auto& r = queue_.front();
            if (r->cancel_flag.exchange(1, std::memory_order_acq_rel) == 0) {
                infer_reqs.push_back(std::move(r));
            }
            queue_.pop_front();
        }

        kill_reqs.insert(kill_reqs.end(), kill_.begin(), kill_.end());
        kill_.clear();
    }

    void close()
    {
        {
            std::lock_guard lock(mutex_);
            closed_ = true;
        }
        cv_.notify_all();
    }

    void notify()
    {
        cv_.notify_all();
    }

    void assign_unique_ids(std::vector>& rs)
    {
        for (auto& r : rs) {
            r->unique_id = unique_id_.fetch_add(1, std::memory_order_relaxed);
        }
    }

private:
    std::atomic unique_id_{};

    std::pmr::unsynchronized_pool_resource   pool_;
    std::pmr::list> queue_;

    std::vector> kill_;

    std::mutex              mutex_;
    std::condition_variable cv_;

    bool closed_{};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/engine/signal_buffer.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

namespace turbomind {

using Signal = std::function;

class SignalBuffer {
public:
    void push(std::vector signals)
    {
        if (signals.empty()) {
            return;
        }
        {
            std::lock_guard lock{mutex_};
            signals_.insert(signals_.end(), std::move_iterator{signals.begin()}, std::move_iterator{signals.end()});
        }
        cv_.notify_one();
    }

    void close()
    {
        {
            std::lock_guard lock{mutex_};
            aborted_ = true;
        }
        cv_.notify_all();
    }

    std::vector take_all(bool& abort)
    {
        std::vector signals;
        {
            std::unique_lock lock{mutex_};
            cv_.wait(lock, [&] { return !signals_.empty() || aborted_; });
            if (aborted_) {
                abort = true;
            }
            else {
                signals.swap(signals_);
            }
        }
        return signals;
    }

private:
    std::vector signals_;

    std::mutex              mutex_;
    std::condition_variable cv_;

    bool aborted_{false};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/CMakeLists.txt
================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.11)

add_library(guided_decoding STATIC guided_decoding.cc)
target_link_libraries(guided_decoding PRIVATE
    apply_token_bitmask_inplace_cuda
    xgrammar
    core)
set_property(TARGET guided_decoding PROPERTY POSITION_INDEPENDENT_CODE ON)

add_library(generation STATIC
    generation.cc
    logits_processor.cc
    sampling.cc
    stop_criteria.cc)
set_property(TARGET generation PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET generation PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(generation PUBLIC
    ban_bad_words
    sampling_penalty_kernels
    sampling_topk_kernels
    sampling_topp_kernels
    sampling_kernels
    stop_criteria
    guided_decoding
    memory_utils
    CUDA::cudart)


================================================
FILE: src/turbomind/generation/base_param.h
================================================


#pragma once

namespace turbomind {

class BaseGenerationParam {
public:
    explicit BaseGenerationParam(int max_batch_size, int vocab_size, int vocab_size_padded):
        max_batch_size_{max_batch_size}, vocab_size_{vocab_size}, vocab_size_padded_{vocab_size_padded}
    {
    }

protected:
    int max_batch_size_;
    int vocab_size_;
    int vocab_size_padded_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/generation.cc
================================================

#include 

#include "src/turbomind/generation/generation.h"

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/copy.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/state.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/request.h"

#include "src/turbomind/generation/guided_decoding.h"
#include "src/turbomind/generation/logits_processor.h"
#include "src/turbomind/generation/sampling.h"
#include "src/turbomind/generation/stop_criteria.h"

#include "src/turbomind/kernels/sampling_topk_kernels.h"  // InitializeRandomStates

#include "src/turbomind/models/llama/llama_kernels.h"  // invokePadLastTokenIds

// #include "dbg.h"

namespace turbomind {

using std::unique_ptr;
using std::shared_ptr;
using std::vector;

struct GenerationData {
    Buffer_  random_state;
    Buffer_ random_seed;
    Buffer_     random_init;
    Buffer_      max_seq_len;
    Buffer_     token_ids_ptrs;
    Buffer_      output_ids;

    bool random_init_needed;
    int  generation_size;
};

struct Generation::Impl {

    // child modules
    unique_ptr logits_processor_;
    unique_ptr        sampling_;
    shared_ptr    stop_criteria_;
    unique_ptr  guided_decoding_;

    // persistent
    Tensor_ token_ids_;

    // scheduling states
    vector h_token_ids_ptrs_;
    vector h_token_ids_free_;

    // execution states
    State random_state_;

    // immutable states
    Buffer_ output_ids_;

    std::vector> data_;

    // staging buffers
    Buffer_  random_state_buf_;
    Buffer_ random_seed_buf_;
    Buffer_     random_init_buf_;
    Buffer_     token_ids_ptrs_buf_;
    Buffer_      token_ids_buf_;
    Buffer_      output_ids_buf_;

    const int max_batch_size_;
    const int session_len_;

    Impl(DataType              dtype,
         int                   max_batch_size,
         int                   session_len,
         int                   vocab_size,
         int                   vocab_size_padded,
         const comm::HostComm& tp_group,
         int                   phases):
        max_batch_size_{max_batch_size}, session_len_{session_len}
    {
        TM_CHECK_EQ(dtype, kFloat32);
        BaseGenerationParam base{max_batch_size, vocab_size, vocab_size_padded};
        logits_processor_ = std::make_unique(base, phases);
        sampling_         = std::make_unique(base, phases);
        stop_criteria_    = std::make_unique(base, phases);
        guided_decoding_  = std::make_unique(base, tp_group, phases);

        static_assert(sizeof(curandState_t) % alignof(curandState_t) == 0);
        random_state_ = {{max_batch_size_, (int)sizeof(curandState_t)}, kUint8, kDEVICE};
        token_ids_    = {{max_batch_size_, session_len_}, kDEVICE};
        output_ids_   = {max_batch_size_, kDEVICE};
        for (int i = 0; i < max_batch_size_; ++i) {
            h_token_ids_free_.push_back(token_ids_.data() + i * token_ids_.stride(0));
        }
        h_token_ids_ptrs_.resize(max_batch_size_);

        random_state_buf_ = {max_batch_size_ * (int)sizeof(curandState_t), kCPUpinned};
        random_seed_buf_  = {max_batch_size_, kCPUpinned};
        random_init_buf_  = {max_batch_size_, kCPUpinned};

        token_ids_ptrs_buf_ = {max_batch_size_, kCPUpinned};
        token_ids_buf_      = {max_batch_size_ * (ssize_t)session_len_, kCPUpinned};

        output_ids_buf_ = {max_batch_size_, kCPUpinned};

        for (int i = 0; i < phases; ++i) {
            auto d = std::make_unique();

            d->random_state   = empty_like(random_state_buf_, kDEVICE);
            d->random_seed    = empty_like(random_seed_buf_, kDEVICE);
            d->random_init    = empty_like(random_init_buf_, kDEVICE);
            d->token_ids_ptrs = empty_like(token_ids_ptrs_buf_, kDEVICE);
            d->output_ids     = empty_like(output_ids_, kDEVICE);

            data_.push_back(std::move(d));
        }
    }

    void Setup(int phase, TensorMap& env)
    {
        auto& d = *data_.at(phase);

        auto& b    = *env.at("batch").data()[0];
        auto& copy = *env.at("copy").data()[0];

        const auto& rc = b.rc;

        // random states
        d.random_init_needed = false;
        for (int i = 0; i < b.perm.size(); ++i) {
            const auto& c = *rc[i];
            if (TM_LIKELY(b.perm[i] < b.bs0)) {  // existing
                random_init_buf_[i] = false;
            }
            else if (c.random_state) {  // already initialized
                std::copy_n(
                    c.random_state, sizeof(curandState_t), random_state_buf_.data() + i * sizeof(curandState_t));
            }
            else {  // uninitialized
                d.random_init_needed = true;
                random_init_buf_[i]  = true;
                random_seed_buf_[i]  = rc[i]->gen_cfg.random_seed;
            }
        }
        copy(random_state_buf_, b.bsz, d.random_state);
        if (d.random_init_needed) {
            copy(random_init_buf_, b.bsz, d.random_init);
            copy(random_seed_buf_, b.bsz, d.random_seed);
        }

        vector used(b.bs0);
        for (int i = 0; i < b.bsz; ++i) {
            if (b.perm[i] < b.bs0) {
                used[b.perm[i]] = 1;
            }
        }
        for (int i = 0; i < b.bs0; ++i) {
            if (!used[i]) {  // free unused chunks
                h_token_ids_free_.push_back(h_token_ids_ptrs_[i]);
            }
        }
        // swap-in token_ids
        int* token_ids_buf = token_ids_buf_.data();
        for (int i = 0; i < rc.size(); ++i) {
            if (const auto& c = *rc[i]; TM_UNLIKELY(b.perm[i] >= b.bs0)) {
                // allocation
                TM_CHECK(!h_token_ids_free_.empty());
                token_ids_ptrs_buf_[i] = h_token_ids_free_.back();
                h_token_ids_free_.pop_back();
                // copy to staging buffer
                std::copy_n(c.token_ids, c.seq_len, token_ids_buf);
                copy(token_ids_buf, c.seq_len, token_ids_ptrs_buf_[i]);
                token_ids_buf += c.seq_len;
            }
            else {
                token_ids_ptrs_buf_[i] = h_token_ids_ptrs_[b.perm[i]];
            }
        }

        copy(token_ids_ptrs_buf_, b.bsz, d.token_ids_ptrs);

        // update `h_token_ids_ptrs_`
        std::copy_n(token_ids_ptrs_buf_.data(), b.bsz, h_token_ids_ptrs_.data());

        d.generation_size = 0;
        for (int i = 0; i < rc.size(); ++i) {
            const auto& c = *rc[i];
            d.generation_size += c.generating;
        }
        // dbg(d.generation_size);

        logits_processor_->Setup(phase, env);
        sampling_->Setup(phase, env);
        stop_criteria_->Setup(phase, env);
        guided_decoding_->Setup(phase, env);
    }

    void Prepare(int phase, TensorMap& env)
    {
        auto& d = *data_.at(phase);

        auto& b    = *env.at("batch").data()[0];
        auto& copy = *env.at("copy").data()[0];

        if (auto g = copy.group()) {
            Warp(random_state_.front(), d.random_state, b.bs0, b.perm, random_state_.back(), copy);
            random_state_.Swap();
        }
    }

    void Unprep(int phase, TensorMap& env)
    {
        auto& d    = *data_.at(phase);
        auto& b    = *env.at("batch").data()[0];
        auto& copy = *env.at("copy").data()[0];

        // state -> data
        copy(random_state_.front().buffer(), b.bsz * sizeof(curandState_t), d.random_state);
        copy(output_ids_, b.bsz, d.output_ids);
    }

    void Fetch(int phase, TensorMap& env)
    {
        auto& d    = *data_.at(phase);
        auto& copy = *env.at("copy").data()[0];

        copy(d.random_state, d.random_state.size(), random_state_buf_);
        env.produce("random_state", random_state_buf_);

        copy(d.output_ids, d.output_ids.size(), output_ids_buf_);
        env.produce("output_ids", output_ids_buf_);

        sampling_->Fetch(phase, env);
    }

    void Update(int phase, TensorMap& env)
    {
        sampling_->Update(phase, env);
    }

    void Forward(int phase, TensorMap& env)
    {
        auto& d = *data_.at(phase);
        auto& b = *env.at("batch").data()[0];

        const auto stream = core::Context::stream().handle();

        if (d.random_init_needed) {
            InitializeRandomStates((curandState_t*)random_state_.front().raw_data(),
                                   d.random_seed.data(),
                                   d.random_init.data(),
                                   b.bsz,
                                   stream);
            sync_check_cuda_error();
        }

        env.emplace("output_ids", output_ids_);              // out
        env.emplace("curand_state", random_state_.front());  // inout

        if (const int gs = d.generation_size) {

            env.emplace("token_ids_ptrs", d.token_ids_ptrs.slice(0, gs));

            auto logits = env.consume("logits");

            if (logits.dtype() != kFloat32) {
                auto tmp = empty_like(logits, kFloat32);
                invokeCastFloat2D(logits, tmp, stream);
                logits = std::move(tmp);
            }

            env.produce("logits", logits.slice(0, gs));

            Buffer_ output_pos{max_batch_size_, kDEVICE};
            Copy(env.at("sequence_length").buffer(), gs, output_pos);

            logits_processor_->Forward(phase, env);

            guided_decoding_->FillMask(phase, env);
            guided_decoding_->ApplyMask(phase, env);

            sampling_->Forward(phase, env);

            guided_decoding_->Update(phase, env);

            AppendTokenIds(d.token_ids_ptrs.data(), output_ids_.data(), output_pos.data(), gs, stream);

            stop_criteria_->Forward(phase, env);
        }
    }
};

Generation::~Generation() = default;

Generation::Generation(DataType              dtype,
                       int                   max_batch_size,
                       int                   session_len,
                       int                   vocab_size,
                       int                   vocab_size_padded,
                       const comm::HostComm& tp_group,
                       int                   phases):
    impl_{std::make_unique(dtype, max_batch_size, session_len, vocab_size, vocab_size_padded, tp_group, phases)}
{
}

void Generation::Run(BatchOp op, int phase, TensorMap& env)
{
    if (op == BatchOp::kSetup) {
        return impl_->Setup(phase, env);
    }
    else if (op == BatchOp::kPrepare) {
        return impl_->Prepare(phase, env);
    }
    else if (op == BatchOp::kForward) {
        return impl_->Forward(phase, env);
    }
    else if (op == BatchOp::kUnprep) {
        return impl_->Unprep(phase, env);
    }
    else if (op == BatchOp::kFetch) {
        return impl_->Fetch(phase, env);
    }
    else if (op == BatchOp::kUpdate) {
        return impl_->Update(phase, env);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/generation.h
================================================


#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/batch.h"

namespace turbomind {

namespace comm {
class HostComm;
}

struct GenerationData;

class Generation {
public:
    ~Generation();

    Generation(DataType              data_type,  //
               int                   max_batch_size,
               int                   session_len,
               int                   vocab_size,
               int                   vocab_size_padded,
               const comm::HostComm& tp_group,
               int                   phases);

    void Run(BatchOp op, int phase, TensorMap& env);

private:
    struct Impl;

    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/guided_decoding.cc
================================================
#include "src/turbomind/generation/guided_decoding.h"

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/allocator.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h"
#include "xgrammar/matcher.h"
#include 

namespace turbomind {

struct GuidedDecoding::Data {
    Tensor_ bitmask;
    bool             active{};

    std::vector> matchers;
};

GuidedDecoding::GuidedDecoding(const BaseGenerationParam& base, const comm::HostComm& tp_group, int phases):
    BaseGenerationParam{base},        //
    tp_group_{tp_group->Split(0, 0)}  // duplicate to avoid data race
{
    const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);

    bitmask_buf_    = {{max_batch_size_, bitmask_size}, kCPUpinned};
    output_ids_buf_ = {max_batch_size_, kCPUpinned};

    for (int i = 0; i < phases; ++i) {
        auto& d    = data_.emplace_back(std::make_shared());
        d->bitmask = empty_like(bitmask_buf_);
    }
}

void GuidedDecoding::Setup(int phase, TensorMap& env)
{
    auto& d = *data_.at(phase);
    auto& b = *env.at("batch").data()[0];

    d.matchers.clear();
    d.active = false;
    for (const auto& r : b.rc) {
        if (d.matchers.emplace_back(r->req->matcher)) {
            d.active = true;
        }
    }
}

void GuidedDecoding::FillMask(int phase, TensorMap& env)
{
    if (auto& d = *data_.at(phase); d.active) {
        static_assert(sizeof(ssize_t) == sizeof(int64_t));
        DLTensor dlbitmask{bitmask_buf_.data(),
                           DLDevice{kDLCPU, 0},
                           bitmask_buf_.ndim(),
                           xgrammar::GetBitmaskDLType(),
                           (int64_t*)bitmask_buf_.shape().data(),
                           nullptr,
                           0};
        if (tp_group_->rank() == 0) {
            for (size_t i = 0; i < d.matchers.size(); ++i) {
                if (const auto& matcher = d.matchers[i]; matcher && !matcher->IsTerminated()) {
                    matcher->FillNextTokenBitmask(&dlbitmask, i);
                }
                else {
                    std::fill_n(bitmask_buf_.data() + i * bitmask_buf_.stride(0),
                                bitmask_buf_.stride(0),
                                static_cast(-1));
                }
            }
        }
    }
}

void GuidedDecoding::ApplyMask(int phase, TensorMap& env)
{
    if (auto& d = *data_.at(phase); d.active) {
        const ssize_t numel = d.matchers.size() * bitmask_buf_.stride(0);
        if (tp_group_->n_ranks() > 1) {
            // bcast the data instead of `bitmask_buf` instance (which may avoid copying the data)
            comm::Broadcast(tp_group_, bitmask_buf_.data(), numel, 0);
        }
        Copy(bitmask_buf_.buffer(), numel, d.bitmask.buffer());
        ApplyTokenBitmaskInplace(env.at("logits"), d.bitmask.slice(0, d.matchers.size()));
    }
}

void GuidedDecoding::Update(int phase, TensorMap& env)
{
    if (auto& d = *data_.at(phase); d.active) {
        Copy(env.at("output_ids").buffer(), d.matchers.size(), output_ids_buf_);
        core::Context::stream().Sync();
        if (tp_group_->rank() == 0) {
            for (size_t i = 0; i < d.matchers.size(); ++i) {
                if (const auto& matcher = d.matchers[i]; matcher && !matcher->IsTerminated()) {
                    matcher->AcceptToken(output_ids_buf_[i]);
                }
            }
        }
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/guided_decoding.h
================================================
#pragma once

#include 

#include "src/turbomind/generation/base_param.h"

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/core.h"

namespace turbomind {

class GuidedDecoding: public BaseGenerationParam {
public:
    explicit GuidedDecoding(const BaseGenerationParam& base, const comm::HostComm& tp_group, int phases);

    void Setup(int phase, TensorMap& env);

    void FillMask(int phase, TensorMap& env);

    void ApplyMask(int phase, TensorMap& env);

    void Update(int phase, TensorMap& env);

private:
    comm::HostComm tp_group_;

    struct Data;
    std::vector> data_;

    Tensor_ bitmask_buf_;
    Buffer_     output_ids_buf_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/logits_processor.cc
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/request.h"

#include "src/turbomind/kernels/ban_bad_words.h"
#include "src/turbomind/kernels/sampling_penalty_kernels.h"

#include "src/turbomind/generation/logits_processor.h"
#include "src/turbomind/generation/utils.h"

namespace turbomind {

struct LogitsProcessor::Data {

    Data(int max_batch_size, DeviceType device)
    {
        repetition_penalty_buf = {max_batch_size, device};
        min_lengths_buf        = {max_batch_size, device};
        temperature_buf        = {max_batch_size, device};
        bad_words_buf          = {max_batch_size * 2 * kMaxStopBadWordsLen, device};
        end_ids_buf            = {max_batch_size * kMaxEndIdsSize, device};
    }

    Buffer_ repetition_penalty_buf;
    Buffer_   min_lengths_buf;
    Buffer_ temperature_buf;
    Buffer_   bad_words_buf;
    Buffer_   end_ids_buf;

    Tensor_ bad_words_ten;
    Tensor_ end_ids_ten;

    bool has_repetition_penalty{};
    bool has_bad_words_penalty{};
    bool has_min_length_penalty{};
    bool has_temperature_penalty{};
};

LogitsProcessor::LogitsProcessor(const BaseGenerationParam& base, int phases): BaseGenerationParam{base}
{
    buf_ = std::make_shared(max_batch_size_, kCPUpinned);
    for (int i = 0; i < phases; ++i) {
        data_.push_back(std::make_shared(max_batch_size_, kDEVICE));
    }
}

void LogitsProcessor::Forward(int phase, TensorMap& env)
{
    // apply repetition penalty -> ban bad words -> min length penalty -> temperature penalty
    // the order is same with transformerss
    TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    Tensor_      logits          = env.at("logits");
    const Buffer_ token_ids_ptrs  = env.at("token_ids_ptrs").buffer();
    const Buffer_  sequence_length = env.at("sequence_length").buffer();

    const auto bsz = logits.shape(0);

    auto& d = *data_.at(phase);

    auto stream = core::Context::stream().handle();

    // repetition penalty
    if (d.has_repetition_penalty) {
        ApplyRepetitionPenalty(logits, d.repetition_penalty_buf, token_ids_ptrs, sequence_length, stream);
        sync_check_cuda_error();
    }

    // ban bad words
    if (auto& bad_words = d.bad_words_ten) {
        BanBadWords(logits, token_ids_ptrs, sequence_length, bad_words, stream);
        sync_check_cuda_error();
    }

    // min length
    if (d.has_min_length_penalty) {
        invokeMinLengthPenalty(logits.data(),
                               d.min_lengths_buf.data(),
                               sequence_length.data(),
                               vocab_size_padded_,
                               bsz,
                               d.end_ids_ten.data(),
                               d.end_ids_ten.shape(1),
                               stream);
        sync_check_cuda_error();
    }

    // temperature
    if (d.has_temperature_penalty) {
        invokeBatchApplyTemperaturePenalty_v2(logits.data(),  //
                                              (float*)nullptr,
                                              d.temperature_buf.data(),
                                              bsz,
                                              vocab_size_,
                                              vocab_size_padded_,
                                              stream);
        sync_check_cuda_error();
    }

    TM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Stage the logits-processing parameters of env["batch"] for `phase`.
//
// Per-request values are gathered into the temp host buffers of `buf_`. They are then transferred
// with env["copy"] into the device buffers of the per-phase state `d`. Each penalty flag is set only
// when at least one request in the batch needs it, so Forward() can skip the corresponding kernel.
void LogitsProcessor::Setup(int phase, TensorMap& env)
{
    TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    auto& d = *data_.at(phase);

    const auto& rs   = env.at("batch").data()[0]->rc;
    auto&       copy = *env.at("copy").data()[0];

    const int bsz = rs.size();

    // host staging arrays, shared by all phases
    auto& repetition_penalty = buf_->repetition_penalty_buf;
    auto& temperature        = buf_->temperature_buf;
    auto& min_lengths        = buf_->min_lengths_buf;

    // reset the flags of this phase; they are re-derived from the current batch below
    d.has_temperature_penalty = {};
    d.has_min_length_penalty  = {};
    d.has_repetition_penalty  = {};
    d.has_bad_words_penalty   = {};  // NOTE(review): never set to true here; bad words are gated on `bad_words_ten`

    for (int i = 0; i < bsz; ++i) {
        auto& g = rs[i]->gen_cfg;

        // repetition_penalty: 1 is treated as "no penalty"
        repetition_penalty[i] = g.repetition_penalty;
        if (repetition_penalty[i] != 1.f) {
            d.has_repetition_penalty = true;
        }

        // temperature: 1 is treated as "no penalty"
        temperature[i] = g.temperature;
        if (g.temperature != 1.f) {
            d.has_temperature_penalty = true;
        }

        // min_length, as an absolute sequence length: prompt length + min_new_tokens.
        // The penalty is needed while seq_len + beta is still below it.
        // NOTE(review): `beta` is presumably the number of tokens added in this step — confirm
        min_lengths[i] = rs[i]->prompt_len + g.min_new_tokens;
        if (rs[i]->seq_len + rs[i]->beta < min_lengths[i]) {
            d.has_min_length_penalty = true;
        }
    }

    // upload only the arrays whose penalty is active
    if (d.has_temperature_penalty) {
        copy(temperature, bsz, d.temperature_buf);
    }

    if (d.has_repetition_penalty) {
        copy(repetition_penalty, bsz, d.repetition_penalty_buf);
    }

    if (d.has_min_length_penalty) {
        copy(min_lengths, bsz, d.min_lengths_buf);
    }

    sync_check_cuda_error();

    // bad words: `bad_words_ten` stays empty when no request in the batch has any
    d.bad_words_ten = {};
    init_stop_bad_words(&GenerationConfig::bad_ids,  //
                        "bad_words",
                        rs,
                        buf_->bad_words_buf.data(),
                        d.bad_words_buf.data(),
                        d.bad_words_ten,
                        copy);

    if (d.has_min_length_penalty) {  // end ids for min length
        // Pack the eos ids into a (bsz, max_length) table padded with -1. max_length is the longest
        // eos list of the batch, capped at kMaxEndIdsSize. `end_ids_ten` stays empty when no request
        // has eos ids.
        d.end_ids_ten  = {};
        int max_length = 0;
        for (int i = 0; i < bsz; ++i) {
            max_length = std::max(max_length, (int)rs[i]->gen_cfg.eos_ids.size());
        }
        if (max_length) {
            max_length     = std::min(max_length, kMaxEndIdsSize);
            int* h_end_ids = buf_->end_ids_buf.data();
            std::fill(h_end_ids, h_end_ids + max_length * bsz, -1);
            for (int i = 0; i < bsz; ++i) {
                const auto& eos_ids = rs[i]->gen_cfg.eos_ids;
                if (eos_ids.size() == 0) {
                    continue;  // row i keeps the -1 padding
                }
                if (TM_UNLIKELY(eos_ids.size() > kMaxEndIdsSize)) {
                    TM_LOG_WARNING("[InitializeSampling] [%ld] eos length (%d) exceeds %d, truncated to %d",
                                   (long)rs[i]->req->id,
                                   (int)eos_ids.size(),
                                   kMaxEndIdsSize,
                                   kMaxEndIdsSize);
                }
                // Row i is addressed explicitly. A shared cursor advanced only for non-empty lists
                // would shift the rows of every request that follows one without eos ids.
                std::copy_n(eos_ids.begin(),
                            std::min((int)eos_ids.size(), kMaxEndIdsSize),
                            h_end_ids + i * max_length);
            }
            copy(buf_->end_ids_buf, bsz * max_length, d.end_ids_buf);
            d.end_ids_ten = {d.end_ids_buf.data(), {bsz, max_length}, kDEVICE};
        }
    }

    TM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/logits_processor.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 

#include "src/turbomind/core/core.h"

#include "src/turbomind/generation/base_param.h"

namespace turbomind {

// Applies per-request logits processors to a batch of logits: repetition penalty, bad-word
// banning, min-length penalty and temperature. Parameters are staged separately for each phase.
class LogitsProcessor: public BaseGenerationParam {
public:
    // `phases`: number of phases; Setup()/Forward() select the per-phase state by `phase`
    explicit LogitsProcessor(const BaseGenerationParam& base, int phases);

    // Stage the penalty parameters of env["batch"] into the device state of `phase`
    void Setup(int phase, TensorMap& env);

    // Apply the penalties staged for `phase` to env["logits"] in place
    void Forward(int phase, TensorMap& env);

private:
    struct Data;

    std::vector> data_;  // per-phase device state, indexed by `phase`

    std::shared_ptr buf_;  // temp host buffer
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/sampling.cc
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/generation/sampling.h"

#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/request.h"

#include "src/turbomind/utils/constant.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

// Per-phase device-side sampling state. It holds the sampling parameters of the current batch
// (written by Sampling::Setup) and the sampled logprob outputs (read back by Sampling::Fetch).
struct SamplingData {

    // every buffer is sized for `max_batch_size` requests and allocated on `device`
    explicit SamplingData(int max_batch_size, DeviceType device)
    {
        top_k_buf = {max_batch_size, device};
        top_p_buf = {max_batch_size, device};
        min_p_buf = {max_batch_size, device};
        kept_buf  = {max_batch_size, device};

        // kMaxLogProb slots per request
        sampled_logprobs = {max_batch_size * (ssize_t)kMaxLogProb, device};
        sampled_indices  = {max_batch_size * (ssize_t)kMaxLogProb, device};
        sampled_nums     = {max_batch_size, device};
    }

    // Batch-wide extrema of the per-request parameters; Forward() uses them to decide
    // which sort / filter kernels to launch
    int   max_topk = 0;
    int   min_topk = 0;
    float min_topp = 0;
    float max_minp = 0;

    Buffer_   top_k_buf;
    Buffer_ top_p_buf;
    Buffer_ min_p_buf;

    // kept sample: per-request candidate count, reset to vocab_size_ by Setup()
    // (presumably narrowed by the sort / filter kernels — TODO confirm)
    Buffer_ kept_buf;

    bool output_logprobs = 0;  // true when any request in the batch asks for logprobs

    // sampled logprob outputs: kMaxLogProb value / index slots per request, and the number of
    // valid slots per request
    Buffer_ sampled_logprobs;
    Buffer_   sampled_indices;
    Buffer_   sampled_nums;
};

// Allocate pinned host staging buffers for `max_batch_size_` requests and one device-side
// SamplingData per phase.
Sampling::Sampling(const BaseGenerationParam& base, int phases): BaseGenerationParam{base}
{
    // per-request parameters, filled by Setup()
    top_k_ = {max_batch_size_, kCPUpinned};
    top_p_ = {max_batch_size_, kCPUpinned};
    min_p_ = {max_batch_size_, kCPUpinned};
    kept_  = {max_batch_size_, kCPUpinned};

    // host copies of the sampled logprob outputs, filled by Fetch()
    sampled_logprobs_buf_ = {max_batch_size_ * (ssize_t)kMaxLogProb, kCPUpinned};
    sampled_indices_buf_  = {max_batch_size_ * (ssize_t)kMaxLogProb, kCPUpinned};
    sampled_nums_buf_     = {max_batch_size_, kCPUpinned};

    // constant array: every request starts with vocab_size_ kept candidates; Setup() copies
    // the first bsz entries to the device for each batch
    std::fill_n(kept_.data(), max_batch_size_, vocab_size_);

    for (int i = 0; i < phases; ++i) {
        data_.push_back(std::make_shared(max_batch_size_, kDEVICE));
    }
}

// Filter and sample env["logits"] (modified in place as the sorted buffer) with the parameters
// staged for `phase`, writing one sampled id per request to env["output_ids"].
void Sampling::Forward(int phase, TensorMap& args)
{
    // step1:
    //  - use topk / topp_minp kernel to sort and filter the scores
    //  - softmax the left score
    // step2:
    //  - sampling from left and sorted scores

    TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    auto& d = *data_.at(phase);

    Tensor_ logits = args.at("logits");

    const auto bsz = logits.shape(0);

    // scratch for the sorted token indices of every row, allocated per call
    Buffer_ indices(bsz * vocab_size_padded_, kDEVICE);

    auto stream = core::Context::stream().handle();

    // top-k sort + filter, needed when any request uses a top-k filter (max_topk > 0)
    if (d.max_topk > 0) {
        // TODO: top_k >= 64 is much slower than torch.topk()
        TopKSortFilterParams params{};
        params.logits            = logits.data();
        params.sorted_logits     = logits.data();
        params.sorted_indices    = indices.data();
        params.kept              = d.kept_buf.data();
        params.top_ks            = d.top_k_buf.data();
        params.max_top_k         = d.max_topk;
        params.batch_size        = bsz;
        params.vocab_size        = vocab_size_;
        params.vocab_size_padded = vocab_size_padded_;
        invokeTopKSortFilter(params, stream);
    }

    // softmax + top-p sort, needed when any request skips the top-k filter (min_topk == 0)
    if (d.min_topk == 0) {
        invokeSoftmax(logits.data(), vocab_size_padded_, vocab_size_, bsz, d.kept_buf.data(), stream);

        TopPSortParams params{};
        params.logits            = logits.data();
        params.sorted_logits     = logits.data();
        params.sorted_indices    = indices.data();
        params.kept              = d.kept_buf.data();
        params.top_ks            = d.top_k_buf.data();
        params.top_ps            = d.top_p_buf.data();
        params.batch_size        = bsz;
        params.vocab_size        = vocab_size_;
        params.vocab_size_padded = vocab_size_padded_;
        invokeTopPSort(params, stream);
    }

    // top-p / min-p filter; skipped only when max_minp == 0 and min_topp == 1
    if (d.max_minp != 0.f || d.min_topp != 1.f) {
        TopPMinPFilterParams params{};
        params.sorted_logits     = logits.data();
        params.sorted_indices    = indices.data();
        params.kept              = d.kept_buf.data();
        params.top_ps            = d.top_p_buf.data();
        params.min_ps            = d.min_p_buf.data();
        params.batch_size        = bsz;
        params.vocab_size        = vocab_size_;
        params.vocab_size_padded = vocab_size_padded_;
        invokeTopPMinPFilter(params, stream);
    }

    // sample from the kept, sorted candidates of each row
    {
        SamplingParams params{};
        params.logits          = logits.data();
        params.stride          = vocab_size_padded_;
        params.indices         = indices.data();
        params.kept            = d.kept_buf.data();
        params.curandstate     = (curandState_t*)args.at("curand_state").raw_data();
        params.batch_size      = bsz;
        params.output_ids      = args.at("output_ids").data();  // (B, 1)
        params.sequence_length = args.at("sequence_length").data();

        // logprob outputs are requested only when some request in the batch asked for them
        if (d.output_logprobs) {
            params.sampled_logprobs = d.sampled_logprobs.data();
            params.sampled_indexes  = d.sampled_indices.data();
            params.sampled_nums     = d.sampled_nums.data();
        }

        invokeSampling(params, stream);
        sync_check_cuda_error();
    }

    TM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Stage the sampling parameters of env["batch"] for `phase`:
//  - fill the pinned host arrays;
//  - compute the batch-wide extrema used by Forward();
//  - copy the per-request values to the device buffers of this phase with env["copy"].
void Sampling::Setup(int phase, TensorMap& env)
{

    const auto& rc   = env.at("batch").data()[0]->rc;
    auto&       copy = *env.at("copy").data()[0];

    const auto bsz = rc.size();

    for (int i = 0; i < bsz; ++i) {
        top_k_[i] = rc[i]->gen_cfg.top_k;
        top_p_[i] = rc[i]->gen_cfg.top_p;
        min_p_[i] = rc[i]->gen_cfg.min_p;
    }

    auto& d = *data_.at(phase);

    // NOTE(review): with an empty batch (bsz == 0) the min/max_element results below would be
    // dereferenced end iterators; this assumes Setup() is only called with a non-empty batch — confirm
    d.max_topk = *std::max_element(top_k_.begin(), top_k_.begin() + bsz);
    d.min_topk = *std::min_element(top_k_.begin(), top_k_.begin() + bsz);
    d.min_topp = *std::min_element(top_p_.begin(), top_p_.begin() + bsz);
    d.max_minp = *std::max_element(min_p_.begin(), min_p_.begin() + bsz);

    copy(top_k_.data(), bsz, d.top_k_buf.data());
    copy(top_p_.data(), bsz, d.top_p_buf.data());

    copy(min_p_.data(), bsz, d.min_p_buf.data());
    // `kept_` holds the constant vocab_size_ filled in by the constructor
    copy(kept_.data(), bsz, d.kept_buf.data());

    d.output_logprobs = std::any_of(rc.begin(), rc.end(), [](auto& x) { return x->gen_cfg.output_logprobs; });
}

// Copy the sampled logprob outputs of `phase` from the device buffers to the pinned host buffers
// read by Update(). Does nothing when no request in the batch asked for logprobs.
// NOTE(review): if env["copy"] is asynchronous, Update() must only run after it completes — confirm.
void Sampling::Fetch(int phase, TensorMap& env)
{
    auto& d    = *data_.at(phase);
    auto& b    = *env.at("batch").data()[0];
    auto& copy = *env.at("copy").data()[0];

    if (d.output_logprobs) {
        copy(d.sampled_logprobs, b.bsz * kMaxLogProb, sampled_logprobs_buf_);
        copy(d.sampled_indices, b.bsz * kMaxLogProb, sampled_indices_buf_);
        copy(d.sampled_nums, b.bsz, sampled_nums_buf_);
    }
}

// Write the logprobs fetched for `phase` into the outputs of every request that asked for them.
// Request i owns the host slots [i * kMaxLogProb, i * kMaxLogProb + n_buf[i]). Its entries are
// stored at row `seq_len - prompt_len` of its "logprob_vals" / "logprob_indexes" outputs
// (kMaxLogProb entries per row), and the count goes to "logprob_nums".
void Sampling::Update(int phase, TensorMap& env)
{
    auto& d = *data_.at(phase);
    auto& b = *env.at("batch").data()[0];

    if (d.output_logprobs) {
        // host copies filled by Fetch()
        float* logprob_buf = sampled_logprobs_buf_.data();
        int*   indices_buf = sampled_indices_buf_.data();
        int*   n_buf       = sampled_nums_buf_.data();
        for (int i = 0; i < b.rc.size(); ++i) {
            if (auto& x = *b.rc[i]; x.gen_cfg.output_logprobs) {
                // output buffers
                auto logprob_out = x.req->outputs.at("logprob_vals").data();
                auto indices_out = x.req->outputs.at("logprob_indexes").data();
                auto n_out       = x.req->outputs.at("logprob_nums").data();
                // offset into output buffers
                // NOTE(review): assumes seq_len - prompt_len indexes the token sampled in this step — confirm
                const int offset = x.seq_len - x.prompt_len;
                std::copy_n(logprob_buf + i * kMaxLogProb, n_buf[i], logprob_out + offset * kMaxLogProb);
                std::copy_n(indices_buf + i * kMaxLogProb, n_buf[i], indices_out + offset * kMaxLogProb);
                n_out[offset] = n_buf[i];
            }
        }
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/sampling.h
================================================

#pragma once

#include "src/turbomind/core/core.h"
#include "src/turbomind/generation/base_param.h"

namespace turbomind {

struct SamplingData;

// Samples one token per request from processed logits, using top-k / top-p / min-p filtering
// followed by sampling, with optional logprob outputs. State is kept separately for each phase.
class Sampling: public BaseGenerationParam {
public:
    // `phases`: number of per-phase device states to allocate
    explicit Sampling(const BaseGenerationParam& base, int phases);

    // Stage the sampling parameters of env["batch"] for `phase`
    void Setup(int phase, TensorMap& env);

    // Filter and sample env["logits"], writing the sampled ids to env["output_ids"]
    void Forward(int phase, TensorMap& env);

    // Copy the sampled logprob outputs of `phase` from device to the pinned host buffers
    void Fetch(int phase, TensorMap& env);

    // Write the fetched logprobs into the outputs of the requests that asked for them
    void Update(int phase, TensorMap& env);

private:
    std::vector> data_;  // per-phase device state, indexed by `phase`

    // host buffers (pinned): per-request parameters staged by Setup()
    Buffer_   kept_;  // constant: vocab_size_ for every request
    Buffer_   top_k_;
    Buffer_ top_p_;
    Buffer_ min_p_;

    // host buffers (pinned): sampled logprob outputs filled by Fetch() and read by Update()
    Buffer_ sampled_logprobs_buf_;
    Buffer_   sampled_indices_buf_;
    Buffer_   sampled_nums_buf_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/stop_criteria.cc
================================================


#include "src/turbomind/generation/stop_criteria.h"
#include "src/turbomind/generation/utils.h"

#include "src/turbomind/kernels/stop_criteria_kernels.h"

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/request.h"

namespace turbomind {

// Per-phase device state of StopCriteria.
struct StopCriteriaData {
    // buffers sized for `batch_size` requests
    explicit StopCriteriaData(int batch_size)
    {
        // Room for batch_size * 2 * kMaxStopBadWordsLen entries: a token row and an offset row per
        // request, the largest layout init_stop_bad_words() can produce
        stop_words  = {batch_size * 2 * kMaxStopBadWordsLen, kDEVICE};
        max_seq_len = {batch_size, kDEVICE};
    }
    Buffer_ stop_words;
    Buffer_ max_seq_len;  // per-request maximum sequence length
    Tensor_ stop_words_ten;  // references into `stop_words`; empty when no request has stop words
};

// Allocate pinned host staging buffers for `max_batch_size_` requests and one device state per phase.
StopCriteria::StopCriteria(const BaseGenerationParam& base, int phases): BaseGenerationParam{base}
{
    stop_words_buf_  = {max_batch_size_ * 2 * kMaxStopBadWordsLen, kCPUpinned};
    max_seq_len_buf_ = {max_batch_size_, kCPUpinned};
    for (int i = 0; i < phases; ++i) {
        data_.push_back(std::make_shared(max_batch_size_));
    }
}

// Stage the stop conditions of env["batch"] for `phase`: the per-request max_seq_len and the
// packed stop words, transferred to the device state with env["copy"].
void StopCriteria::Setup(int phase, TensorMap& env)
{
    auto& d = *data_.at(phase);

    const auto& rs   = env.at("batch").data()[0]->rc;
    auto&       copy = *env.at("copy").data()[0];

    for (int i = 0; i < rs.size(); ++i) {
        max_seq_len_buf_[i] = rs[i]->max_seq_len;
    }
    copy(max_seq_len_buf_, rs.size(), d.max_seq_len);

    // packed stop words; `stop_words_ten` stays empty when no request has any
    d.stop_words_ten = {};
    init_stop_bad_words(&GenerationConfig::stop_ids,  //
                        "stop_words",
                        rs,
                        stop_words_buf_.data(),
                        d.stop_words.data(),
                        d.stop_words_ten,
                        copy);
}

// Evaluate the stop conditions staged for `phase` and update env["finished"]. The stop-word
// criterion runs only when some request has stop words; the max-sequence-length criterion always runs.
void StopCriteria::Forward(int phase, TensorMap& env)
{
    auto& d = *data_.at(phase);

    const Buffer_ token_ids_ptrs  = env.at("token_ids_ptrs").buffer();
    const Buffer_  sequence_length = env.at("sequence_length").buffer();

    Buffer_ finished = env.at("finished").buffer();

    // one token-id pointer per request
    const int batch_size = token_ids_ptrs.size();

    auto stream = core::Context::stream().handle();

    if (auto& stop_words = d.stop_words_ten) {
        TM_CHECK_EQ(stop_words.ndim(), 3);  // [batch, 2, len]
        size_t stop_words_len = stop_words.shape(2);
        invokeStopWordsCriterion_v2((const int**)token_ids_ptrs.data(),
                                    sequence_length.data(),
                                    stop_words.data(),
                                    finished.data(),
                                    stop_words_len,
                                    batch_size,
                                    stream);
        sync_check_cuda_error();
    }

    // compares sequence_length against the per-request max_seq_len staged by Setup()
    invokeLengthCriterion_v2(finished.data(),  //
                             sequence_length.data(),
                             d.max_seq_len.data(),
                             batch_size,
                             stream);
    sync_check_cuda_error();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/stop_criteria.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "src/turbomind/core/core.h"

#include "src/turbomind/generation/base_param.h"

namespace turbomind {

struct StopCriteriaData;

// Evaluates the stop conditions of a batch (stop words and maximum sequence length) and updates
// env["finished"]. State is kept separately for each phase.
class StopCriteria: public BaseGenerationParam {
public:
    // `phases`: number of per-phase device states to allocate
    explicit StopCriteria(const BaseGenerationParam& base, int phases);

    // Stage max_seq_len and the stop words of env["batch"] for `phase`
    void Setup(int phase, TensorMap& env);

    // Run the stop-word and length criteria for `phase`, updating env["finished"]
    void Forward(int phase, TensorMap& env);

private:
    std::vector> data_;  // per-phase device state, indexed by `phase`

    // pinned host staging buffers
    Buffer_ stop_words_buf_;
    Buffer_ max_seq_len_buf_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/generation/utils.h
================================================

#include 
#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Max number of tokens (summed over all words) kept per request for stop / bad words;
// init_stop_bad_words() truncates longer lists at a word boundary.
constexpr int kMaxStopBadWordsLen = 32;
// Max number of end (eos) ids kept per request for the min-length penalty.
constexpr int kMaxEndIdsSize      = 32;

namespace {

// Pack the per-request word lists selected by `getter` into a tensor of shape (bsz, 2, max_length).
// Typical getters are &GenerationConfig::stop_ids and &GenerationConfig::bad_ids.
//
// For request i, `std::invoke(getter, rs[i]->gen_cfg)` yields (token_ids, offsets). offsets[k] is
// the end position of word k in token_ids, so offsets.back() == token_ids.size().
// Packed layout:
//   - row [i, 0] holds the token ids and row [i, 1] the offsets; unused slots are -1;
//   - a list longer than kMaxStopBadWordsLen tokens is truncated after the last word that fits,
//     with a warning;
//   - max_length is the longest (possibly truncated) token list in the batch.
//
// `h_buf` is a host staging buffer. `copy(h_buf, n, d_buf)` transfers the packed data to `d_buf`,
// which `out` then references. `out` is left untouched when no request has any word, so callers
// reset it beforehand. `key` only names the list in the warning.
template
void init_stop_bad_words(G getter, const char* key, const Rs& rs, T* h_buf, T* d_buf, Tensor_& out, Copy& copy)
{
    const int bsz        = rs.size();
    int       max_length = 0;

    // per request: (pointer, count) of the token ids / offsets to pack; nullptr means nothing
    std::vector> copy_tokens(bsz);
    std::vector> copy_offsets(bsz);
    for (int i = 0; i < bsz; ++i) {
        const auto& [token_ids, offsets] = std::invoke(getter, rs[i]->gen_cfg);
        if (offsets.size() == 0 || token_ids.size() == 0) {
            continue;
        }
        FT_CHECK(offsets.back() == token_ids.size());
        if (offsets.back() <= kMaxStopBadWordsLen) {
            copy_tokens[i]  = std::make_pair(token_ids.data(), (int)token_ids.size());
            copy_offsets[i] = std::make_pair(offsets.data(), (int)offsets.size());
            max_length      = std::max(max_length, (int)token_ids.size());
        }
        else {
            // number of leading words whose end offset fits in kMaxStopBadWordsLen tokens
            const int trunc_offset_size =
                (int)(std::upper_bound(offsets.begin(),
                                       offsets.begin() + std::min(kMaxStopBadWordsLen, (int)offsets.size()),
                                       kMaxStopBadWordsLen)
                      - offsets.begin());
            // number of tokens kept: end offset of the last word that fits
            const int trunc_token_size = trunc_offset_size > 0 ? (int)offsets[trunc_offset_size - 1] : 0;
            // The casts make the variadic arguments match the format: %ld for the id, %d for the
            // lengths. "truncated to" reports the kept token count, the same unit as the length and
            // the limit.
            TM_LOG_WARNING("[InitializeSampling] [%ld] %s length (%d) exceeds %d, truncated to %d",
                           (long)rs[i]->req->id,
                           key,
                           (int)offsets.back(),
                           kMaxStopBadWordsLen,
                           trunc_token_size);
            if (trunc_offset_size > 0) {
                copy_tokens[i]  = std::make_pair(token_ids.data(), trunc_token_size);
                copy_offsets[i] = std::make_pair(offsets.data(), trunc_offset_size);
                max_length      = std::max(max_length, trunc_token_size);
            }
        }
    }
    if (!max_length) {
        return;
    }
    std::fill_n(h_buf, bsz * 2 * max_length, -1);
    for (int i = 0; i < bsz; ++i) {
        if (copy_tokens[i].first != nullptr) {
            std::copy_n(copy_tokens[i].first, copy_tokens[i].second, h_buf + i * 2 * max_length);
        }
        if (copy_offsets[i].first != nullptr) {
            std::copy_n(copy_offsets[i].first, copy_offsets[i].second, h_buf + i * 2 * max_length + max_length);
        }
    }
    copy(h_buf, bsz * 2 * max_length, d_buf);
    // Construct a tensor from the device buffer
    out = {d_buf, {bsz, 2, max_length}, kDEVICE};
}

}  // namespace

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/CMakeLists.txt
================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.11)

# Declare a static kernel library built from one CUDA source file, with position-independent
# code and CUDA device-symbol resolution enabled.
function(turbomind_add_static_kernel_library target_name source_file)
    add_library(${target_name} STATIC ${source_file})
    set_target_properties(${target_name} PROPERTIES
        POSITION_INDEPENDENT_CODE   ON
        CUDA_RESOLVE_DEVICE_SYMBOLS ON)
endfunction()

turbomind_add_static_kernel_library(ban_bad_words        ban_bad_words.cu)
turbomind_add_static_kernel_library(stop_criteria        stop_criteria_kernels.cu)
turbomind_add_static_kernel_library(activation_kernels   activation_kernels.cu)
turbomind_add_static_kernel_library(activation           activation.cu)
turbomind_add_static_kernel_library(quantization_kernels quantization.cu)

if (BUILD_TEST)
    add_executable(test_quantization test_quantization.cc gemm/test/test_utils.cu)
    target_link_libraries(test_quantization PRIVATE quantization_kernels core)
endif ()

turbomind_add_static_kernel_library(logprob_kernels           logprob_kernels.cu)
turbomind_add_static_kernel_library(unfused_attention_kernels unfused_attention_kernels.cu)
turbomind_add_static_kernel_library(decoding_kernels          decoding_kernels.cu)
turbomind_add_static_kernel_library(gpt_kernels               gpt_kernels.cu)
turbomind_add_static_kernel_library(sampling_topk_kernels     sampling_topk_kernels.cu)
turbomind_add_static_kernel_library(sampling_topp_kernels     sampling_topp_kernels.cu)
turbomind_add_static_kernel_library(sampling_penalty_kernels  sampling_penalty_kernels.cu)
turbomind_add_static_kernel_library(sampling_kernels          sampling_kernels.cu)
turbomind_add_static_kernel_library(apply_token_bitmask_inplace_cuda apply_token_bitmask_inplace_cuda.cu)

add_subdirectory(attention)
add_subdirectory(gemm)
add_subdirectory(norm)


================================================
FILE: src/turbomind/kernels/activation.cu
================================================

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"

namespace turbomind {

// Clamped SiLU variant selected by ActivationType::kSiluGptOss, computed in float:
//   gate = min(gate, 7), up = clamp(up, -7, 7)
//   out  = gate / (1 + exp(-1.702 * gate)) * (1 + up)  ==  gate * sigmoid(1.702 * gate) * (1 + up)
template
struct SiluGptOss {
    __device__ T operator()(T gate, T up) const noexcept
    {
        gate = __hmin((T)7.f, gate);
        up   = __hmax((T)-7.f, __hmin((T)7.f, up));
        return static_cast(fdividef((float)gate, 1.f + expf((float)-gate * 1.702f)) * (1.f + (float)up));
    }
};

// Gated SiLU selected by ActivationType::kSilu, computed in float:
//   out = gate / (1 + exp(-gate)) * up  ==  silu(gate) * up
template
struct Silu {
    __device__ T operator()(T gate, T up) const noexcept
    {
        return static_cast(fdividef((float)gate, 1.f + expf(-(float)gate)) * (float)up);
    }
};

// Element-wise gated activation, in place: gate[ti, :] = activation(gate[ti, :], up[ti, :]).
//
// Launch layout (see Activation()): blockIdx.x is the row `ti`, and threadIdx.x + blockIdx.y *
// blockDim.x indexes a vector of `vec_size` consecutive elements within the row. Rows of both
// `gate_buf` and `up_buf` are addressed with the same `stride` (in elements). `token_num` is unused.
// NOTE(review): `dim` is divided by `vec_size` with truncation, so dim % vec_size trailing elements
// per row would be left unprocessed.
template
__global__ void ActivationKernel(
    T* gate_buf, const T* __restrict__ up_buf, Activation activation, int64_t stride, int token_num, int dim)
{
    // NOTE: the guard presumably compiles the body only for dtypes supported on the target arch
    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v)) {
        const int di = threadIdx.x + blockIdx.y * blockDim.x;
        const int ti = blockIdx.x;

        // from here on, `dim` and `di` count vectors, not elements
        dim /= vec_size;

        if (di >= dim) {
            return;
        }

        using Vec = Array;

        auto p_gate = reinterpret_cast(gate_buf + ti * stride);
        auto p_up   = reinterpret_cast(up_buf + ti * stride);

        Vec gate;
        Load(gate, (const T*)&p_gate[di]);

        // `up` is read-only input, loaded with Ldg
        Vec up;
        Ldg(up, (T*)&p_up[di]);

        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            gate[i] = activation(gate[i], up[i]);
        }

        // result overwrites the gate vector
        Store((T*)&p_gate[di], gate);
    }
}

// Apply the gated activation `type` in place: gate = act(gate, up), element-wise.
//
// `gate` and `up` must be 2-D tensors of the same shape (num, dim). The kernel addresses both
// tensors with gate's row stride and processes `vec_size` elements per thread. So, for more than
// one row, the row strides must be equal, and `dim` must be divisible by `vec_size`. Both are
// checked, since a violation would read the wrong rows or leave the tail of each row unprocessed.
// Empty inputs are a no-op.
void Activation(Ref gate_, const Tensor& up, ActivationType type, cudaStream_t stream)
{
    auto& gate = gate_.get();

    TM_CHECK(gate.shape() == up.shape());

    int num, dim;
    std::tie(num, dim) = gate.shapes(0, 1);

    // A launch with a zero-sized grid is a CUDA error, and there is nothing to compute anyway
    if (num == 0 || dim == 0) {
        return;
    }

    // The kernel indexes the rows of `up` with gate's row stride
    TM_CHECK(num == 1 || gate.stride(0) == up.stride(0)) << "gate and up must have the same row stride";

    auto invoke = [&](auto t, auto act) {
        using T = decltype(t);

        constexpr int vec_size = 4;
        constexpr int threads  = 512;

        // The kernel processes whole vectors only; a remainder would be silently left untouched
        TM_CHECK_EQ(dim % vec_size, 0);

        const dim3 blocks(num, cdiv(dim, threads * vec_size));

        ActivationKernel<<>>(gate.data(),  //
                                                                   up.data(),
                                                                   act,
                                                                   gate.stride(0),
                                                                   num,
                                                                   dim);
    };

    // select the activation functor for the runtime `type`
    auto dispatch = [&](auto t) {
        using T = decltype(t);
        if (type == ActivationType::kSilu) {
            return invoke(t, Silu{});
        }
        else if (type == ActivationType::kSiluGptOss) {
            return invoke(t, SiluGptOss{});
        }
        else {
            TM_CHECK(0) << "unknown activation type: " << (int)type;
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(gate.dtype(), dispatch);
}

// Fused bias + gated activation, in place on `gate_up`. Row ti holds the gate in columns [0, dim)
// and up in columns [dim, 2 * dim):
//   gate_up[ti, di] = activation(gate + bias[gi, di], up + bias[gi, dim + di])
// where gi = group_ids[ti] when `group_ids` is non-null, and 0 otherwise.
//
// Launch layout (see Activation()): blockIdx.x is the row; each thread handles `vec_size`
// consecutive elements starting at column di. `token_num` is unused.
// NOTE(review): bias rows are addressed with gate_up's row `stride`. That is only correct when the
// bias rows have the same pitch as gate_up's rows (e.g. both contiguous) — confirm. If dim is not a
// multiple of vec_size, the last vector of a row reads past `dim`.
template
__global__ void ActivationKernel(
    T* gate_up, const T* bias, const int* group_ids, int64_t stride, Activation activation, int token_num, int dim)
{
    // NOTE: the guard presumably compiles the body only for dtypes supported on the target arch
    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v)) {
        const int di = (threadIdx.x + blockIdx.y * blockDim.x) * vec_size;  // element column
        const int ti = blockIdx.x;
        const int gi = group_ids ? group_ids[ti] : 0;  // bias row of this token row

        if (di >= dim) {
            return;
        }

        using Vec = Array;

        Vec gate_bias{}, up_bias{};
        Ldg(gate_bias, &bias[gi * stride + di]);
        Ldg(up_bias, &bias[gi * stride + dim + di]);

        Vec gate, up;
        Load(gate, &gate_up[ti * stride + di]);
        Load(up, &gate_up[ti * stride + dim + di]);

        {
            using namespace ops;
            gate = gate + gate_bias;
            up   = up + up_bias;
        }

        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            gate[i] = activation(gate[i], up[i]);
        }

        // result overwrites the gate half of the row
        Store(&gate_up[ti * stride + di], gate);
    }
}

// Apply the gated activation `type` in place on a fused tensor whose columns are [gate | up]
// (dim = gate_up.shape(1) / 2). The result is written to the gate half.
//
// Without `bias`, the two halves are activated as separate views (`group_ids` is unused then).
// With `bias` (same last dimension as gate_up), the bias is added to gate and up before the
// activation. Each token row uses the bias row selected by `group_ids` when present, and row 0
// otherwise. `dim` must be divisible by the kernel's vector width. Empty inputs are a no-op.
void Activation(Tensor&             gate_up,  //
                const Tensor&       bias,
                const Buffer_& group_ids,
                ActivationType      type,
                cudaStream_t        stream)
{
    const int num = gate_up.shape(0);
    const int dim = gate_up.shape(1) / 2;

    // A launch with a zero-sized grid is a CUDA error, and there is nothing to compute anyway
    if (num == 0 || dim == 0) {
        return;
    }

    if (!bias) {
        // no bias: activate the gate and up column halves as two views
        Activation(gate_up.slice({0, 0}, {-1, dim}),  //
                   gate_up.slice({0, dim}, {-1, -1}),
                   type,
                   stream);
        return;
    }

    TM_CHECK_EQ(gate_up.shape(-1), bias.shape(-1));

    auto invoke = [&](auto t, auto act) {
        using T = decltype(t);

        constexpr int vec_size = 4;
        constexpr int threads  = 512;

        // Vectors must not straddle the gate/up boundary or the end of a row
        TM_CHECK_EQ(dim % vec_size, 0);

        const dim3 blocks(num, cdiv(dim, threads * vec_size));

        ActivationKernel<<>>(gate_up.data(),  //
                                                                   bias.data_or((T*)nullptr),
                                                                   group_ids.data_or(nullptr),
                                                                   gate_up.stride(0),
                                                                   act,
                                                                   num,
                                                                   dim);
    };

    // select the activation functor for the runtime `type`
    auto dispatch = [&](auto t) {
        using T = decltype(t);
        if (type == ActivationType::kSilu) {
            return invoke(t, Silu{});
        }
        else if (type == ActivationType::kSiluGptOss) {
            return invoke(t, SiluGptOss{});
        }
        else {
            TM_CHECK(0) << "unknown activation type: " << (int)type;
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(gate_up.dtype(), dispatch);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/activation.h
================================================
#pragma once

#include "src/turbomind/core/core.h"

namespace turbomind {

// Gated activation applied by Activation(): out = act(gate, up)
enum class ActivationType
{
    kSilu,       // silu(gate) * up
    kSiluGptOss  // clamped variant: gate * sigmoid(1.702 * gate) * (1 + up), see activation.cu
};

// In place, element-wise: gate = act(gate, up), for 2-D `gate` and `up` of the same shape.
void Activation(Ref gate, const Tensor& up, ActivationType type, cudaStream_t stream);

// In place on a fused tensor whose columns are [gate | up]:
//   gate = act(gate + gate_bias, up + up_bias), written to the gate half.
// Without `bias`, this falls back to the overload above and `group_ids` is unused. With `bias`,
// `group_ids` (when present) selects the bias row of each token row; otherwise bias row 0 is used.
void Activation(Tensor&             gate_up,  //
                const Tensor&       bias,
                const Buffer_& group_ids,
                ActivationType      type,
                cudaStream_t        stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/activation_kernels.cu
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/core/core.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/activation_kernels.h"
#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif

namespace turbomind {

/* Gelu Activation */

// Returns `a` with the sign bit of `b` OR-ed into it. This equals copysign(a, b) only when `a`
// is non-negative (its own sign bit is clear), hence the `_pos` suffix.
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
    const unsigned sign_of_b = __float_as_int(b) & 0x80000000;
    return __int_as_float(__float_as_int(a) | sign_of_b);
}

// Fast tanh. On sm_75+ with CUDA 11+ this uses the PTX `tanh.approx.f32` instruction; otherwise
// it evaluates tanh(|x|) through __expf and restores the sign of x.
__inline__ __device__ float tanh_opt(float x)
{
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
    float y;
    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(y) : "f"(x));
    return y;
#else
    // tanh(|x|) = (1 - e^{-2|x|}) / (1 + e^{-2|x|}); the numerator is >= 0, so copysignf_pos
    // can attach the sign of x.
    const float e = __expf(-2.f * fabs(x));
    return copysignf_pos((1.0f - e) / (1.0f + e), x);
#endif
}

// GELU, tanh approximation:
//   gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where 0.7978845608028654 ~= sqrt(2/pi). The CDF factor is evaluated in float via tanh_opt.
template
struct GeluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val)
    {
        const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val))));
        return val * cdf;
    }
};

// Packed half2 specialization: x^3 is computed in half2 (__hmul2), the CDF factor per lane in
// float, then the factor is rounded to half2 and multiplied with the input in half precision.
template<>
struct GeluActivation {
    using return_type = half2;
    static __device__ __forceinline__ half2 apply(const half2& val)
    {
        half2  val_pow3 = __hmul2(val, __hmul2(val, val));
        float2 tmp_pow  = __half22float2(val_pow3);
        float2 tmp      = __half22float2(val);

        // tmp now holds the per-lane CDF factor.
        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
        return __hmul2(val, __float22half2_rn(tmp));
    }
};

#ifdef ENABLE_BF16
// Packed bf16 specialization; same scheme as half2 (cube in bf16, CDF in float, product in bf16).
template<>
struct GeluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
    {
        __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val));
        float2         tmp_pow  = bf1622float2(val_pow3);
        float2         tmp      = bf1622float2(val);

        // tmp now holds the per-lane CDF factor.
        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
        return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y));
    }
};
#endif

/* Relu Activation */

// ReLU: returns val when val > 0, otherwise 0 (a NaN input also yields 0, since the comparison
// is false).
template
struct ReluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val)
    {
        return val > static_cast(0.0f) ? val : static_cast(0.0f);
    }
};

// Packed half2 specialization: lane-wise max(x, 0) via comparisons on the half components.
template<>
struct ReluActivation {
    using return_type = half2;
    static __device__ __forceinline__ half2 apply(const half2& val)
    {
        const half zero_half = static_cast(0.0f);
        return make_half2(val.x > zero_half ? val.x : zero_half, val.y > zero_half ? val.y : zero_half);
    }
};

#ifdef ENABLE_BF16
// Packed bf16 specialization: lane-wise max(x, 0).
template<>
struct ReluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
    {
        const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
        // Device code below sm_80 uses the project's turbomind::make_bfloat162; presumably the
        // CUDA-provided one is not usable there — confirm.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        return turbomind::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#else
        return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#endif
    }
};
#endif

/* Silu Activation */

// SiLU: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), evaluated in float with the fast __expf
// and converted back to T.
template
struct SiluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val)
    {
        return (T)((float)val / (1.0f + __expf((float)-val)));
    }
};

// Packed half2 specialization. Each lane goes through the scalar SiluActivation (presumably
// instantiated for the lane type — confirm), and the two results are packed into a float2.
// Note that return_type is float2 here, not half2.
template<>
struct SiluActivation {
    using return_type = float2;
    static __device__ __forceinline__ float2 apply(const half2& val)
    {
        return make_float2(SiluActivation::apply(val.x), SiluActivation::apply(val.y));
    }
};

#ifdef ENABLE_BF16
// Packed bf16 specialization; like the half2 one, it returns float2.
template<>
struct SiluActivation<__nv_bfloat162> {
    using return_type = float2;
    static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val)
    {
        return make_float2(SiluActivation::apply(val.x), SiluActivation::apply(val.y));
    }
};
#endif  // ENABLE_BF16

/* Identity Activation (= no activation) */

// Pass-through functor: apply() returns its argument unchanged.
template
struct IdentityActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val)
    {
        return val;
    }
};

// In-place gated activation over a row-strided 2-D buffer:
//   inter_buf[ti, d] = Activation::apply(inter_buf[ti, d]) * gate_buf[ti, d]
// The result overwrites `inter_buf`; `gate_buf` is only read.
// Launch layout: blockIdx.x is the row (token) index; threadIdx.x + blockIdx.y * blockDim.x
// indexes a VecSize-wide vector within the row. Both buffers share the row `stride` (elements).
// NOTE(review): `dims` is divided by VecSize with truncation, so when dims % VecSize != 0 the
// trailing dims % VecSize elements of every row are left untouched. `token_num` is unused.
template class Activation, typename T>
__global__ void activation_kernel(T* inter_buf, const T* __restrict__ gate_buf, int64_t stride, int token_num, int dims)
{
    const int di = threadIdx.x + blockIdx.y * blockDim.x;
    const int ti = blockIdx.x;

    // From here on, `dims` counts vectors rather than elements.
    dims /= VecSize;

    if (di >= dims) {
        return;
    }

    using Vec = Array;

    // 64-bit row offset (stride is int64_t).
    auto p_inter = reinterpret_cast(inter_buf + ti * stride);
    auto p_gate  = reinterpret_cast(gate_buf + ti * stride);

    Vec inter;
    Load(inter, (T*)&p_inter[di]);

    // Read-only operand; Ldg presumably loads through the read-only data path.
    Vec gate;
    Ldg(gate, (const T*)&p_gate[di]);

    PRAGMA_UNROLL
    for (int i = 0; i < VecSize; ++i) {
        inter[i] = Activation::apply(inter[i]) * gate[i];
    }

    Store((T*)&p_inter[di], inter);
}

// Raw-pointer launcher for activation_kernel. Grid: one grid row per token (grid.x = token_num);
// blocks of 512 threads, each thread handling kVecSize = 4 elements.
// NOTE(review): this variant is not declared in activation_kernels.h nor explicitly instantiated
// in this file (only _v3 is), and it uses `ceil_div` where _v3 uses `cdiv`; possibly dead code.
template class Activation, typename T>
void invokeGenericActivation_v2(
    T* inter_buf, const T* __restrict__ gate_buf, int64_t stride, int token_num, int dims, cudaStream_t stream)
{
    constexpr int kVecSize = 4;

    constexpr int block = 512;
    const dim3    grid(token_num, ceil_div(dims, block * kVecSize));

    activation_kernel
        <<>>(inter_buf, gate_buf, stride, token_num, dims);
}

// Host entry point for the in-place gated activation on 2-D tensors:
//   inter = Activation(inter) * gate   (elementwise)
// `inter` and `gate` must be 2-D with identical shapes and identical row strides. The element
// type is dispatched over the primary dtypes (TM_DISPATCH_PRIMARY_DTYPES).
// activation_kernel processes each row in vectors of kVecSize elements, so the row length must be
// a multiple of kVecSize. This is checked explicitly instead of silently skipping the tail.
template class Activation>
void invokeGenericActivation_v3(Ref inter_, const Tensor& gate, cudaStream_t stream)
{
    auto& inter = inter_.get();
    TM_CHECK_EQ(inter.ndim(), 2);
    TM_CHECK_EQ(gate.ndim(), 2);
    TM_CHECK_EQ(inter.stride(0), gate.stride(0));

    TM_CHECK(inter.shape() == gate.shape());

    auto invoke = [&](auto t) {
        using T = decltype(t);

        const auto [num, dim] = inter.shapes(0, 1);

        constexpr int kVecSize = 4;
        constexpr int block    = 512;

        // activation_kernel computes `dims /= VecSize` with truncation, so without this check the
        // last dim % kVecSize elements of every row would be left un-activated.
        TM_CHECK(dim % kVecSize == 0) << "activation dim (" << dim << ") must be a multiple of " << kVecSize;

        const dim3 grid(num, cdiv((int)dim, block * kVecSize));

        activation_kernel
            <<>>(inter.data(), gate.data(), inter.stride(0), num, dim);
    };

    TM_DISPATCH_PRIMARY_DTYPES(inter.dtype(), invoke);
}

template void invokeGenericActivation_v3(Ref inter_, const Tensor& gate, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/activation_kernels.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Elementwise activation functors (defined in activation_kernels.cu). Each exposes a
// `return_type` and a static device `apply(val)`, and can be used as the `Activation` template
// argument of invokeGenericActivation_v3.
// clang-format off
template struct GeluActivation;
template struct ReluActivation;
template struct SiluActivation;
template struct IdentityActivation;
// clang-format on

// In-place gated activation on 2-D tensors: inter = Activation(inter) * gate, elementwise.
// `inter` and `gate` must both be 2-D with identical shapes and row strides (TM_CHECK-ed).
template class Activation>
void invokeGenericActivation_v3(Ref inter_, const Tensor& gate, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu
================================================
// Modified from xgrammar python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// clang-format off
#include 
#include 
#include 

#include "src/turbomind/core/context.h"
#include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h"
// clang-format on

using namespace std;

// Fallback +inf bit patterns in case the CUDA headers do not define them:
// 0x7C00 is +inf in IEEE fp16.
#ifndef CUDART_INF_FP16
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
#endif

// 0x7F80 is +inf in bfloat16. This is only defined for device passes with __CUDA_ARCH__ >= 800.
#if __CUDA_ARCH__ >= 800
#ifndef CUDART_INF_BF16
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif
#endif

// Number of token bits packed into one int32 bitmask word.
constexpr int32_t BITS_PER_BLOCK           = 32;
// Threads per CUDA block for LogitsBitmaskKernel (also its __launch_bounds__).
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;

// -inf of type T, used as the masked logit value. The generic version converts -INFINITY.
template
__device__ T NegativeInfinity()
{
    return -INFINITY;
}

// half: built from the raw fp16 +inf bit pattern and negated.
template<>
__device__ __half NegativeInfinity<__half>()
{
    return -CUDART_INF_FP16;
}

// bf16: only specialized where CUDART_INF_BF16 may be defined (__CUDA_ARCH__ >= 800); other
// passes fall back to the generic template.
#if __CUDA_ARCH__ >= 800
template<>
__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>()
{
    return -CUDART_INF_BF16;
}
#endif

// Returns a PackedT whose kAlignment = sizeof(PackedT) / sizeof(T) lanes are all -inf of type T,
// so a fully masked group can be written with one packed store.
template
__device__ PackedT PackedNegativeInfinity()
{
    constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
    // The staging array is read back through a PackedT pointer, so it must be aligned for
    // PackedT; a plain T array only guarantees alignof(T) (e.g. 4 bytes vs 16 for float4).
    alignas(PackedT) T packed[kAlignment];
#pragma unroll
    for (int i = 0; i < kAlignment; i++) {
        packed[i] = NegativeInfinity();
    }
    return *reinterpret_cast(packed);
}

// Sets masked logits to -inf in place.
// Grid layout:
// - blockIdx.y selects a row (remapped through `indices` when it is non-null).
// - blockIdx.x selects a chunk of THREADS_PER_THREAD_BLOCK * kBitsPerThread logits in that row.
// Each thread handles kAlignment = sizeof(PackedT) / sizeof(T) consecutive logits per iteration
// (one PackedT access). A token is masked when its bit in `bitmask` is 0: the word is inverted
// before testing, so set bits in `bitmask_val` mark tokens to mask.
template
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(T* __restrict__ logits,
                                                                                const int32_t* __restrict__ bitmask,
                                                                                const int32_t* __restrict__ indices,
                                                                                int32_t vocab_size,
                                                                                int32_t logits_stride,
                                                                                int32_t bitmask_stride)
{
    constexpr int      kAlignment  = sizeof(PackedT) / sizeof(T);
    constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

    const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

    const int      block_offset      = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
    T*             logits_gmem_ptr   = logits + batch_idx * logits_stride + block_offset;
    const int32_t* bitmask_gmem_ptr  = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
    // Index of the kAlignment-bit group within a 32-bit bitmask word that this thread's logits
    // map to.
    const int      bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
    // Staging registers for a partially masked group. They are accessed through a PackedT
    // pointer below, so align them for PackedT (a plain T array only guarantees alignof(T)).
    alignas(PackedT) T logits_reg[kAlignment];

#pragma unroll
    for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
         offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
        // NOTE(review): only the first element of the group is bounds-checked; this relies on the
        // row having room for a full PackedT access past vocab_size.
        if (block_offset + offset >= vocab_size) {
            break;
        }

        const uint32_t bitmask_val =
            (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;

        // Nothing masked in this group.
        if (bitmask_val == 0) {
            continue;
        }

        // Whole group masked: a single packed store.
        if (bitmask_val == kPackedMask) {
            *reinterpret_cast(logits_gmem_ptr + offset) = PackedNegativeInfinity();
            continue;
        }

        // Partially masked: load, patch the masked lanes, store back.
        *reinterpret_cast(logits_reg) = *reinterpret_cast(logits_gmem_ptr + offset);
#pragma unroll
        for (int i = 0; i < kAlignment; i++) {
            if (((bitmask_val >> i) & 1)) {
                logits_reg[i] = NegativeInfinity();
            }
        }
        *reinterpret_cast(logits_gmem_ptr + offset) = *reinterpret_cast(logits_reg);
    }
}

// Ceiling division (numerator + denominator - 1) / denominator, SFINAE-constrained (presumably to
// integral types). Intended for positive operands; the denominator must be non-zero, and the sum
// can overflow when the numerator is near the type's maximum.
template::value>>
constexpr auto CeilDiv(T numerator, T denominator)
{
    return (numerator + denominator - 1) / denominator;
}

// Picks kBitsPerThread (4, 8, 16 or 32 logits per thread) and launches LogitsBitmaskKernel on the
// current core::Context stream.
// - Heuristic: aim for about 2048 / 256 * 128 = 1024 blocks in total, spread over `num_rows`,
//   then take the smallest bucket whose per-thread work covers a row with that many blocks.
// - A bucket is only eligible if it is >= kAlignment, so each thread's work is a whole number of
//   PackedT accesses.
// - The grid's x dimension is then recomputed from the chosen bucket.
// NOTE(review): divides by `num_rows`, so callers must ensure num_rows > 0.
template
void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits,
                                                     const int32_t* __restrict__ bitmask,
                                                     const int32_t* __restrict__ indices,
                                                     int32_t vocab_size,
                                                     int32_t logits_stride,
                                                     int32_t bitmask_stride,
                                                     int32_t num_rows)
{
    constexpr int kAlignment          = sizeof(PackedT) / sizeof(T);
    const int32_t num_blocks_per_row  = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
    const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

    const dim3  block(THREADS_PER_THREAD_BLOCK);
    const auto& stream = turbomind::core::Context::stream();

    if (num_bits_per_thread <= 4 && kAlignment <= 4) {
        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
        LogitsBitmaskKernel
            <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
    }
    else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
        LogitsBitmaskKernel
            <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
    }
    else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
        LogitsBitmaskKernel
            <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
    }
    else {
        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
        LogitsBitmaskKernel
            <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
    }
}

// Chooses the packed access width. When the row stride is a multiple of sizeof(float4) / sizeof(T)
// elements, presumably the 16-byte (float4) path is used; otherwise the one-T-per-access path.
// NOTE(review): only the stride is checked; this assumes the logits base pointer is itself
// 16-byte aligned — TODO confirm.
template
void ApplyTokenBitmaskInplaceDispatchToPackedT(T* __restrict__ logits,
                                               const int32_t* __restrict__ bitmask,
                                               const int32_t* __restrict__ indices,
                                               int32_t vocab_size,
                                               int32_t logits_stride,
                                               int32_t bitmask_stride,
                                               int32_t num_rows)
{
    if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
        ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
            logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
    }
    else {
        ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
            logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
    }
}

namespace turbomind {
using namespace turbomind::core;

// Sets logits of disallowed tokens to -inf in place.
// - logits: [rows, vocab] or [vocab]; only float32 is supported.
// - bitmask: [rows, words] or [words] of int32; bit (v % 32) of word (v / 32) == 0 masks token v.
// - indices: if given, only the listed rows are processed (indices->shape(0) rows). Otherwise
//   logits and bitmask must have the same number of rows.
// Only the first min(logits vocab, 32 * bitmask words) columns are considered.
void ApplyTokenBitmaskInplace(Tensor logits, Tensor bitmask, std::optional indices)
{
    std::pair logits_shape =
        logits.ndim() == 2 ?
            std::make_pair(static_cast(logits.shape(0)), static_cast(logits.shape(1))) :
            std::make_pair(1, static_cast(logits.shape(0)));

    std::pair bitmask_shape =
        bitmask.ndim() == 2 ?
            std::make_pair(static_cast(bitmask.shape(0)), static_cast(bitmask.shape(1))) :
            std::make_pair(1, static_cast(bitmask.shape(0)));

    int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

    int32_t  num_rows    = logits_shape.first;
    int32_t* indices_ptr = nullptr;
    if (indices) {
        num_rows    = indices->shape(0);
        indices_ptr = indices->data();
    }
    else {
        TM_CHECK(logits_shape.first == bitmask_shape.first) << "logits and bitmask must have the same batch size.";
    }

    // Currently we use only float logits.
    TM_CHECK(logits.dtype() == kFloat32);

    // Nothing to mask. Returning here also avoids a host-side division by zero in the launch
    // heuristic (CeilDiv(..., num_rows)) and a kernel launch with an empty grid dimension.
    if (num_rows <= 0 || vocab_size <= 0) {
        return;
    }

    ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data(),
                                              bitmask.data(),
                                              indices_ptr,
                                              vocab_size,
                                              logits.stride(0),
                                              bitmask.stride(0),
                                              num_rows);
}
}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h
================================================
#pragma once

#include <optional>

#include "src/turbomind/core/tensor.h"

namespace turbomind {

// Sets logits of disallowed tokens to -inf in place.
// - logits: [rows, vocab] or [vocab]; only float32 is supported.
// - bitmask: [rows, words] or [words] of int32 words; bit (v % 32) of word (v / 32) equal to 0
//   masks token v, and 1 keeps it.
// - indices: if given, only the listed rows of logits/bitmask are processed. Otherwise logits and
//   bitmask must have the same number of rows.
// Only the first min(logits vocab, 32 * bitmask words) columns are considered.
void ApplyTokenBitmaskInplace(core::Tensor                logits,
                              core::Tensor                bitmask,
                              std::optional indices = std::nullopt);
}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

# Sources in the kernel/ subdirectory are configured by its own CMakeLists.txt.
add_subdirectory(kernel)

# Attention library built from the sources listed below.
add_library(attention STATIC
            attention.cu
            decoding.cu
            kv_cache_utils_v2.cu
            cp_utils.cu
            registry.cu
            )
# Position-independent so the static library can be linked into shared objects; device symbols
# are resolved (device-linked) at this target.
set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Flags inside the generator expression (presumably CUDA-only): fast math, relaxed constexpr,
# verbose ptxas resource usage, and up to 16 parallel nvcc compilation threads.
target_compile_options(attention PRIVATE -O3
    $<$:-use_fast_math --expt-relaxed-constexpr -Xptxas=-v --threads 16>)
target_link_libraries(attention PUBLIC $)
target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass)

if (BUILD_TEST)
    # Adds line info to the library in test builds.
    # NOTE(review): -Xptxas=-v is already passed unconditionally above.
    target_compile_options(attention PRIVATE
        $<$:-Xptxas=-v --generate-line-info>)

    # Attention test driver; presumably checks kernels against reference.cu.
    add_executable(test_attention
        test_utils.cu
        test_attention.cu
        reference.cu)
    target_compile_options(test_attention PRIVATE
        --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr)
    target_link_libraries(test_attention PRIVATE
        attention
        # flash_attention
        nvidia::cutlass::cutlass
        models
        unfused_attention_kernels
        logger
        cublas)

    # Standalone quantization test.
    add_executable(test_quant test_quant.cu test_utils.cu)
    target_compile_options(test_quant PRIVATE
        --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr)
    target_link_libraries(test_quant PRIVATE
        nvidia::cutlass::cutlass
    )
endif ()


================================================
FILE: src/turbomind/kernels/attention/arch.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind::arch {

// tags for dispatching & conditional codegen

// Base for arch tags. is_compatible(arch) is true for arch values in the half-open range
// [Begin, End); End == -1 means there is no upper bound.
template
struct Arch {
    static constexpr bool is_compatible(int arch)
    {
        return Begin <= arch && (End == -1 || arch < End);
    }
};

// Arch tags. Each inherits its compatible range from Arch and exposes `value`, the kernel arch
// number that is_arch_compatible maps to the tag.

// Compatible with arch in [700, 750).
struct Sm70: Arch<700, 750> {
    static constexpr int value = 700;
};

// Compatible with arch in [750, 800).
struct Sm75: Arch<750, 800> {
    static constexpr int value = 750;
};

// Compatible with arch >= 800 (no upper bound).
struct Sm80: Arch<800> {
    static constexpr int value = 800;
};

// Returns whether a kernel tagged with arch `karch` may run on a device of arch `darch`.
// karch == 0 is accepted on every device; tags other than 0/700/750/800 are rejected.
inline bool is_arch_compatible(int karch, int darch)
{
    if (karch == 0) {
        return true;
    }
    if (karch == Sm70::value) {
        return Sm70::is_compatible(darch);
    }
    if (karch == Sm75::value) {
        return Sm75::is_compatible(darch);
    }
    if (karch == Sm80::value) {
        return Sm80::is_compatible(darch);
    }
    return false;
}

}  // namespace turbomind::arch


================================================
FILE: src/turbomind/kernels/attention/attention.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "attention.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/attention/registry.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Looks up a prefill attention kernel in the Registry for this head dim and element type, and
// launches it. Fails a TM_CHECK when no matching kernel is registered.
template
void dispatchAttention(const AttentionParams& params)
{
    using namespace attention;

    auto&    reg = Registry::instance();
    AttnDesc desc{};
    desc.mode      = AttnDesc::kPrefill;
    desc.head_dim  = params.size_per_head;
    desc.data_type = data_type_v;

    auto* kernel = reg.Find(desc);

    TM_CHECK(kernel) << "No attention kernel found: " + to_string(desc);

    // Pass the address of the parameter block (the original text had a mis-encoded `&params`).
    kernel->Launch(&params, reg.sm_count());
}

template void dispatchAttention(const AttentionParams& params);
#if ENABLE_BF16
template void dispatchAttention(const AttentionParams& params);
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/attention.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention_params.h"

namespace turbomind {

// NOTE(review): presumably an upper bound on a kernel's CTA_S (KV tile length) — confirm against
// its users.
constexpr int MAX_CTA_S = 64;

// Finds a registered prefill attention kernel for params.size_per_head and the element type,
// and launches it. Fails a TM_CHECK if none is registered.
template
void dispatchAttention(const AttentionParams& params);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/attention_params.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "cutlass/fast_math.h"
#include 
#include 

#include "src/turbomind/models/llama/llama_rope.h"

namespace turbomind {

// 64-bit offsets may be needed
// Contiguous ("linear") KV-cache description, consumed by CreateCacheIterFactory, which forwards
// these fields (plus cu_k_len) to the cache-iterator factory.
struct LinearIteratorParams {
    const void* kv_cache;    // cache base; cast to `const Tkv*` by CreateCacheIterFactory
    int         stride_h;    // presumably the per-head stride in elements — confirm
    int         key_to_val;  // presumably the offset from K data to V data — confirm
};

// Block-based (paged) KV-cache description. Field meanings below are inferred from names and
// should be confirmed against the block iterator that consumes them.
struct BlockIteratorParams {
    char**     block_ptrs;     // presumably a table of cache block pointers
    const int* cu_block_nums;  // presumably per-sequence cumulative block counts
    int        layer_id;       // layer whose K/V is addressed within each block
    int        block_len;      // presumably tokens per block
};

using cp_post_fn = void (*)(void* context);  // context-parallel post-launch callback; receives an opaque context pointer

/// TODO: Rename to attention::Param
// Parameter block passed to the attention kernels (see dispatchAttention / invokeAttention).
template
struct AttentionParams {
    // token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
    T*      out;
    T*      q;
    T*      k;
    T*      v;
    int64_t stride;  // presumably the per-token stride (in elements) of q/k/v — confirm

    // bias, [qH, D] or [kvH, D]
    T* q_bias;
    T* k_bias;
    T* v_bias;

    // sequence-level buffers
    const int*   cu_q_len;  // presumably cumulative (prefix-sum) query lengths
    const int*   cu_k_len;  // presumably cumulative key lengths; forwarded to the cache iterator
    const bool*  finished;
    const float* rope_theta;

    const T* sinks;  // presumably per-head attention-sink values — confirm
    float    scale_sinks;

    LinearIteratorParams linear_iter_params;
    BlockIteratorParams  block_iter_params;

    // batch-level params
    int token_num;
    int batch_size;
    int max_q_len;
    int max_k_len;

    // instance-level params
    int   num_heads;
    int   num_kv_heads;  // num_heads / num_kv_heads gives the query group size
    int   size_per_head;
    float inv_sqrt_dh;
    int   window_size;  // caps the number of KV tiles considered in invokeAttention
    int   layer_id;  // for debugging

    // rotary embedding
    RopeKernelParam rope_param;

    // log(n) attention
    bool use_logn_attn;
    int  max_position_embeddings;

    int quant_policy;

    // split-K workspace: invokeAttention clamps its split count to max_split_k and merges
    // partial_O / partial_ML with invokeReduceV3 when splitting or when cp_size > 1
    int    max_split_k;
    int*   split_cnt;
    float* partial_O;
    float* partial_ML;

    // context parallel
    int                 cp_rank{0};
    cutlass::FastDivmod cp_size{1};
    int                 offset_q{0};  // decode offset
    cp_post_fn          cp_fn{nullptr};  // if set, called with cp_fn_ctx right after the launch
    void*               cp_fn_ctx{nullptr};

    int          arch;
    cudaStream_t stream;

    // debug
    float* qk;
    T*     pr;
};

// Builds a CacheIterFactory from the contiguous KV-cache description in
// `param.linear_iter_params`. The factory is brace-initialized from:
//   {kv_cache as const Tkv*, cu_k_len, stride_h, key_to_val}
// Presumably specialized elsewhere for other cache layouts (e.g. block-based) — confirm.
template
struct CreateCacheIterFactory {
    template
    static CacheIterFactory apply(const Param& param)
    {
        using Tkv = typename CacheIterFactory::Tkv;
        return {(const Tkv*)param.linear_iter_params.kv_cache,
                param.cu_k_len,
                param.linear_iter_params.stride_h,
                param.linear_iter_params.key_to_val};
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/attention_template.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention_params.h"
#include "attention_universal.h"
#include "reduce.h"
#include "utils.h"

namespace turbomind {

// Launches the attention kernel `Kernel` for `params`, optionally splitting work along the KV
// axis, in four steps:
// 1. Compute the KV tile count for the longest sequence on this context-parallel rank, capped by
//    window_size.
// 2. Choose a split count with GetSplitCount (inputs: base grid size, max_active_ctas,
//    sm_count).
// 3. Launch, then run the optional context-parallel callback.
// 4. When split_cnt > 1 or cp_size > 1, merge the partial results into params.out with
//    attention::invokeReduceV3.
// Launch errors are fatal: the CUDA error string is printed and the process aborts.
template
void invokeAttention(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas)
{
    static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage);

    // Debug hook, evaluated once per Kernel instantiation; all output is currently commented out.
    if constexpr (1) {

        [[maybe_unused]] static const int _ = [&] {
            // std::cout << __PRETTY_FUNCTION__ << std::endl;
            // std::cout << "GmemMap:\n";
            // Print(typename Kernel::Impl::ThreadMapKV{});
            // std::cout << "\nDynamic smem size: " << kSmemSize << "\n";
            return 0;
        }();
    }

    dim3 block(Kernel::kWarpCount * WARP_SIZE);

    static const auto kernel_func = &attention_kernel;

    // KV length handled per CP rank (rounded up), then the number of CTA_S-sized KV tiles (capped
    // by the sliding window). Splitting beyond the tile count is pointless.
    const int max_cp_k_len    = cdiv(params.max_k_len, (int)params.cp_size);
    const int tile_count      = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);
    const int max_split_count = std::min(params.max_split_k, tile_count);

    typename Kernel::CtaMap cta_map{
        params.max_q_len, params.batch_size, params.num_heads, Kernel::CTA_Q, Kernel::CTA_H, 1};

    // grid shape when split cnt = 1
    dim3 grid = cta_map.get_grid_shape();

    const int grid_size = grid.x * grid.y * grid.z;
    const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 8);

    // printf("max split cnt: %d, split cnt: %d\n", max_split_count, split_cnt);

    // adjust split cnt and update grid shape
    cta_map.set_split_cnt(split_cnt);
    grid = cta_map.get_grid_shape();

    auto cache_iter_factory = CreateCacheIterFactory::apply(params);

    const int q_group_size = params.num_heads / params.num_kv_heads;

    kernel_func<<>>(params,
                                                           cache_iter_factory,
                                                           cta_map,
                                                           q_group_size,
                                                           1,            // q_head_per_cta
                                                           q_group_size  // cta_per_q_group
    );

    if (auto err = cudaGetLastError(); err != cudaSuccess) {
        std::cout << cudaGetErrorString(err) << "\n";
        std::abort();
    }

    // Context-parallel hook, run after the launch and before the reduction.
    if (params.cp_fn) {
        params.cp_fn(params.cp_fn_ctx);
    }

    // Merge split-K / context-parallel partials. Output rows start at offset_q tokens into `out`.
    if (split_cnt > 1 || params.cp_size > 1) {
        attention::invokeReduceV3(params.out + params.offset_q * params.num_heads * Kernel::kHeadDim,
                                                    params.partial_ML,
                                                    params.partial_O,
                                                    split_cnt > 1 ? params.split_cnt : nullptr,
                                                    params.max_split_k,
                                                    split_cnt,
                                                    params.cp_size,
                                                    params.cp_rank,
                                                    params.token_num,
                                                    params.num_heads,
                                                    params.inv_sqrt_dh,
                                                    params.stream);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/attention_universal.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include "quantization.h"

#include "src/turbomind/kernels/attention/rotary_embedding.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/math.h"

#include "attention_params.h"

namespace turbomind {

// Forward declaration only: AttentionUniversal refers to attention::DecodingCtaMap (see its
// kProcessKV comment) without needing the definition in this header.
namespace attention {
struct DecodingCtaMap;
}  // namespace attention

// Device-side functor that computes attention for one CTA. The CTA's slice of work comes from
// the CtaMap: a tile of CTA_Q queries of sequence `batch_idx`, a group of CTA_H query heads, and
// split `split_idx` (of `split_count()`) of the KV tile range. The result is written either
// directly to params.out (StoreO) or as a partial result (StorePartial) that a separate reduction
// combines afterwards.
template
struct AttentionUniversal {

    // T: query/output element type; Tkv: KV-cache element type (presumably differs from T when
    // the cache is quantized — confirm).
    using T   = typename Mainloop::T;
    using Tkv = typename Mainloop::Tkv;

    using Impl = typename Mainloop::Impl;

    using CacheIteratorFactory = CacheIteratorFactory_;
    using CtaMap               = CtaMap_;

    using Arch = Arch_;

    static constexpr int kWarpCount = Impl::kWarpCount;

    using ParamType = AttentionParams;

    static constexpr int kHeadDim = Impl::kHeadDim;

    using FragQ = typename Impl::FragQ;
    using FragO = typename Impl::FragO;
    using FragM = typename Impl::FragM;
    using FragL = typename Impl::FragL;

    using GmemIterK = typename Mainloop::GmemIterK;
    using GmemIterV = typename Mainloop::GmemIterV;

    // Tile shape: CTA_H query heads x CTA_Q queries per CTA; CTA_S KV tokens per mainloop tile.
    static constexpr int CTA_H = Impl::CTA_H;
    static constexpr int CTA_Q = Impl::CTA_Q;
    static constexpr int CTA_S = Impl::CTA_S;

    using SharedStorage = typename Mainloop::SharedStorage;

    // Only process KV inline during decoding (DecodingCtaMap), not during context attention
    // (AttentionCtaMap), even when CTA_Q == 1 (e.g. SIMT kernels).
    static constexpr bool kProcessKV = std::is_same_v;

    // Number of query heads per KV head (kv_head_idx = head_idx / q_group_size_).
    const int q_group_size_;
    // Query heads handled by one CTA (used when CTA_H > 1).
    const int q_head_per_cta_;
    // Number of CTAs that together cover one query-head group (used when CTA_H > 1).
    const int cta_per_q_group_;

    // past-the-end hi of the CTA
    // (local head bound: stays 1 when CTA_H == 1, otherwise set in operator() from the group size)
    int hi_end_{1};

    // Returns whether local head index `hi` (relative to head_idx) is a real query head of this
    // CTA. Rows with hi >= hi_end_ are skipped by the global loads and stores below.
    __device__ bool check_h(int hi)
    {
        if constexpr (CTA_Q > 1) {
            // bypass the check for prefill kernels since `hi == 0` constantly
            return true;
        }
        else {
            return hi < hi_end_;
        }
    }

    // Adds the optional biases to the register fragments loaded by Prologue. Each bias is skipped
    // when its pointer in `params` is null. K/V biases are applied only when kProcessKV (vec_K /
    // vec_V then hold the new token), and the V bias only when HAS_V.
    //   vec_Q:       [ITER_S][ITER_C] Q rows;  vec_K / vec_V: [1][ITER_C] (one token)
    //   head_idx:    first query head of the CTA;  kv_head_idx: its KV head
    //   offset:      this thread's ThreadMapQ offset (x: head-dim offset, y: row offset)
    template
    __device__ void ApplyBias(
        VecQ& vec_Q, VecKV& vec_K, VecKV& vec_V, const ParamType& params, int head_idx, int kv_head_idx, int2 offset)
    {
        using Map = typename Impl::ThreadMapQ;

        constexpr int kVecSize = Map::kAccessC;
        constexpr int ITER_C   = Map::kIterC;
        constexpr int ITER_S   = Map::kIterS;

        // With kHeadDim == 576 no separate V is loaded (see Prologue), so no V bias is applied.
        constexpr bool HAS_V = kHeadDim != 576;

        if constexpr (kProcessKV) {
            Array bias_K[ITER_C];
            Array bias_V[ITER_C];
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                const int di    = offset.x + c * Map::kDeltaC;
                const int k_idx = kv_head_idx * kHeadDim + di;
                if (params.k_bias) {
                    Ldg(bias_K[c], ¶ms.k_bias[k_idx]);
                }
                if (params.v_bias && HAS_V) {
                    Ldg(bias_V[c], ¶ms.v_bias[k_idx]);
                }
            }
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                using namespace ops;
                if (params.k_bias) {
                    vec_K[0][c] = vec_K[0][c] + bias_K[c];
                }
                if (params.v_bias && HAS_V) {
                    vec_V[0][c] = vec_V[0][c] + bias_V[c];
                }
            }
        }

        if constexpr (CTA_H == 1) {
            // One query head per CTA: the same bias row (head_idx) is added to every row s.
            Array bias_Q[ITER_C];
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                const int di    = offset.x + c * Map::kDeltaC;
                const int q_idx = head_idx * kHeadDim + di;
                if (params.q_bias) {
                    Ldg(bias_Q[c], ¶ms.q_bias[q_idx]);
                }
            }
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    using namespace ops;
                    if (params.q_bias) {
                        vec_Q[s][c] = vec_Q[s][c] + bias_Q[c];
                    }
                }
            }
        }
        else if constexpr (CTA_Q == 1) {
            // One query per CTA, several heads: row s belongs to head head_idx + hi.
            // NOTE(review): for rows failing check_h, bias_Q[s][c] is never loaded but is still
            // added below. Those rows are not written out (StoreO/StorePartial check check_h),
            // so this is presumably harmless — confirm.
            Array bias_Q[ITER_S][ITER_C];
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                const int hi = offset.y + s * Map::kDeltaS;
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    const int di    = offset.x + c * Map::kDeltaC;
                    const int q_idx = (head_idx + hi) * kHeadDim + di;
                    if (params.q_bias && check_h(hi)) {
                        Ldg(bias_Q[s][c], ¶ms.q_bias[q_idx]);
                    }
                }
            }
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    using namespace ops;
                    if (params.q_bias) {
                        vec_Q[s][c] = vec_Q[s][c] + bias_Q[s][c];
                    }
                }
            }
        }
        else {
            // Tiles with both CTA_Q > 1 and CTA_H > 1 are rejected at compile time.
            static_assert(CTA_Q == 1 || CTA_H == 1);
        }
    }

    // Per-CTA prologue:
    //   1. Load this CTA's Q rows. Row si maps to query si / CTA_H and local head si % CTA_H.
    //      When kProcessKV, also load the K/V of kv_head_idx for the same tokens.
    //   2. Add biases, apply RoPE and (optionally) LogN scaling.
    //   3. When kProcessKV, convert K/V to the cache format and write them to the cache slot for
    //      position history_len, but only on the CP rank that owns that position.
    //   4. Stage Q in shared memory (smem_Q) and transform it into frag_Q.
    // qi_begin / qi_end: global token range of this CTA's queries.
    // query_idx: offset of the first query within its sequence.
    // history_len: number of tokens preceding this step's input in the sequence.
    template
    __device__ void Prologue(const ParamType& params,
                             T*               smem_Q,
                             FragQ&           frag_Q,
                             int              qi_begin,
                             int              qi_end,
                             int              query_idx,
                             int              head_idx,
                             int              kv_head_idx,
                             int              batch_idx,
                             int              history_len,
                             Iterator&        iterator,
                             int              warp_id,
                             int              lane_id)
    {

        using Map = typename Impl::ThreadMapQ;

        constexpr int kVecSize = Map::kAccessC;

        using Vec = Array;

        constexpr int ITER_C = Map::kIterC;
        constexpr int ITER_S = Map::kIterS;

        // With kHeadDim == 576 only K is loaded/stored; V is skipped everywhere below.
        constexpr bool HAS_V = kHeadDim != 576;

        Vec vec_Q[ITER_S][ITER_C]{};  // [QxH, D]
        Vec vec_K[1][ITER_C];
        Vec vec_V[1][ITER_C];

        const int2 offset = Map::get_offset(warp_id, lane_id);

        // Load Q
        PRAGMA_UNROLL
        for (int s = 0; s < ITER_S; ++s) {
            const int si = offset.y + s * Map::kDeltaS;
            const int hi = si % CTA_H + head_idx;
            const int qi = si / CTA_H + qi_begin;
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                const int     di    = offset.x + c * Map::kDeltaC;
                const int64_t q_idx = qi * params.stride + hi * kHeadDim + di;
                const int64_t k_idx = qi * params.stride + kv_head_idx * kHeadDim + di;
                if (qi < qi_end) {
                    if (check_h(si % CTA_H)) {
                        Ldg(vec_Q[s][c], ¶ms.q[q_idx]);
                    }
                    if constexpr (kProcessKV) {  // duplicate loads in s
                        if (s == 0) {
                            Ldg(vec_K[0][c], ¶ms.k[k_idx]);
                            if constexpr (HAS_V) {
                                Ldg(vec_V[0][c], ¶ms.v[k_idx]);
                            }
                        }
                    }
                }
            }
        }

        ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);

        // Rotary embedding. Each row's position is its query offset + query_idx + history_len;
        // the new K (row s == 0) is rotated with the same position as its row.
        FastRoPE rope(params.rope_param, batch_idx, std::integral_constant{});
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int di = offset.x + c * Map::kDeltaC;
            rope.init(di);
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
                rope.apply(vec_Q[s][c], ti);
                if constexpr (kProcessKV) {
                    if (s == 0) {
                        rope.apply(vec_K[0][c], ti);
                    }
                }
            }
        }

        // Optional LogN scaling of Q, driven by the token position and max_position_embeddings.
        if (params.use_logn_attn) {
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                const int   ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
                LogNScaling logn_scaling(ti, params.max_position_embeddings);
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    logn_scaling.apply(vec_Q[s][c]);
                }
            }
        }

        if constexpr (kProcessKV) {
            const int qi = offset.y / CTA_H;
            const int ti = history_len;

            // Split the global position into (local index, owning CP rank); presumably
            // divmod returns ti / cp_size and stores ti % cp_size into local_ti_rank.
            int local_ti, local_ti_rank;
            local_ti = params.cp_size.divmod(local_ti_rank, ti);

            Array param_K[1];
            Array param_V[1];

            // Quantization params are computed only when the is_same_v check fails
            // (presumably when the cache type Tkv differs from T — confirm).
            if constexpr (!std::is_same_v) {
                warp_stats(param_K, vec_K, bitsof);
                if constexpr (HAS_V) {
                    warp_stats(param_V, vec_V, bitsof);
                }
            }

            Array out_K[1][ITER_C];
            Array out_V[1][ITER_C];

            ConvertKvCache conv_K{param_K[0][0], param_K[0][1]};
            ConvertKvCache conv_V{param_V[0][0], param_V[0][1]};
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                out_K[0][c] = conv_K(vec_K[0][c]);
                if constexpr (HAS_V) {
                    out_V[0][c] = conv_V(vec_V[0][c]);
                }
            }

            // Write K/V (and their quant params) into the cache slot of local position local_ti;
            // ranks that do not own this position skip the write.
            iterator.block_head_.with(
                iterator.block_ptrs_, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {
                    if (local_ti_rank != params.cp_rank) {
                        return;
                    }
                    PRAGMA_UNROLL
                    for (int c = 0; c < ITER_C; ++c) {
                        const int di = offset.x + c * Map::kDeltaC;
                        if (qi < CTA_Q) {
                            Store(&k_cache[di], out_K[0][c]);
                            if constexpr (HAS_V) {
                                Store(&v_cache[di], out_V[0][c]);
                            }
                        }
                    }
                    if constexpr (!std::is_same_v) {
                        // One thread per row (offset.x == 0) stores the quant params.
                        if (qi < CTA_Q && offset.x == 0) {
                            StoreQuantParam(k_param, param_K[0]);
                            if constexpr (HAS_V) {
                                StoreQuantParam(v_param, param_V[0]);
                            }
                        }
                    }
                });

            __syncthreads();
        }

        using SmemLayoutQ = typename Impl::SmemLayoutQ;

        SmemAccessor sQ{smem_Q};

        // Store to shared memory
        PRAGMA_UNROLL
        for (int s = 0; s < ITER_S; ++s) {
            const int si = offset.y + s * Map::kDeltaS;
            const int hi = si % CTA_H;
            const int qi = si / CTA_H;
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                const int di = offset.x + c * Map::kDeltaC;
                if (qi < CTA_Q && hi < CTA_H) {
                    Store(&sQ(si, di), vec_Q[s][c]);
                }
            }
        }

        __syncthreads();

        Impl::TransformQ(smem_Q, frag_Q);
    }

    __device__ AttentionUniversal(int q_group_size, int q_head_per_cta, int cta_per_q_group):
        q_group_size_{q_group_size}, q_head_per_cta_{q_head_per_cta}, cta_per_q_group_{cta_per_q_group}
    {
    }

    // Runs attention for this CTA's (query tile, head group, sequence, split):
    // resolves head indices, returns early for finished or empty work, computes this split's KV
    // tile range and masking counts, runs the prologue and mainloop, then stores either the final
    // output or a partial result. smem_buf is reinterpreted as SharedStorage.
    __device__ void
    operator()(const ParamType& params, CacheIteratorFactory& cache_iter_factory, const CtaMap& cta_map, char* smem_buf)
    {
        // [q, h, b]
        const int query_idx = cta_map.query_idx() * CTA_Q;  // Q offset of this sequence
        const int batch_idx = cta_map.batch_idx();
        const int split_idx = cta_map.split_idx();
        const int split_cnt = cta_map.split_count();

        int head_idx;
        int kv_head_idx;

        if constexpr (CTA_H == 1) {
            head_idx    = cta_map.head_idx();
            kv_head_idx = head_idx / q_group_size_;
        }
        else {
            // Several query heads per CTA: cta_per_q_group_ CTAs share one KV head, each handling
            // q_head_per_cta_ heads starting at local_idx within the group.
            int cta_h_idx = cta_map.head_idx();
            int local_idx = cta_h_idx % cta_per_q_group_ * q_head_per_cta_;
            kv_head_idx   = cta_h_idx / cta_per_q_group_;
            head_idx      = kv_head_idx * q_group_size_ + local_idx;
            hi_end_       = q_group_size_ - local_idx;
        }

        // early exit if finished flag is set
        if (params.finished[batch_idx]) {
            return;
        }

        const int qi_begin = params.cu_q_len[batch_idx] + query_idx;  // global offset into `cu_seqlens`
        const int qi_end   = params.cu_q_len[batch_idx + 1];

        if (qi_begin >= qi_end) {
            return;
        }

        const int input_len = qi_end - (qi_begin - query_idx);

        SharedStorage& storage = *(SharedStorage*)smem_buf;

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx];
        const int history_len = context_len - input_len;

        // Number of positions in [0, length) owned by CP rank `rank`, assuming positions are
        // assigned round-robin (position % cp_size) and divmod yields quotient / remainder.
        auto get_cp_len = [&](int length, int rank) -> int {
            int local_ti, local_ti_rank;
            local_ti = params.cp_size.divmod(local_ti_rank, length);
            return (local_ti + (local_ti_rank > rank ? 1 : 0));
        };

        // KV tiles are counted in rank-0 local coordinates. Rank 0 owns the most positions, so
        // every CP rank derives the same tile_count and split boundaries.
        const int last_K = history_len + min(query_idx + CTA_Q, input_len);
        const int last_K_tile =
            (get_cp_len(last_K, 0) - 1) / CTA_S + 1;  // past-the-end index to past-the-end tile index conversion

        // First KV position visible to the first query of the tile under the attention window.
        const int first_K      = max(history_len + query_idx - (params.window_size - 1), 0);
        const int first_K_tile = get_cp_len(first_K, 0) / CTA_S;

        const int tile_count = last_K_tile - first_K_tile;

        /// FIXME: This scheme produces fewer splits than expected: ceil division can leave trailing
        ///        splits with no tiles, and those return early below.
        const int tile_per_split = cdiv(tile_count, split_cnt);
        const int iter_begin     = tile_per_split * split_idx;
        const int iter_end       = min(iter_begin + tile_per_split, tile_count);

        if (iter_begin >= iter_end) {
            return;
        }

        auto cache_iter = cache_iter_factory.Create(batch_idx, kv_head_idx);

        FragQ frag_Q;
        Prologue(params,
                 storage.Q,
                 frag_Q,
                 qi_begin,
                 qi_end,
                 query_idx,
                 head_idx,
                 kv_head_idx,
                 batch_idx,
                 history_len,
                 cache_iter,
                 warp_id,
                 lane_id);

        __align__(16) FragO frag_O{};

        FragL frag_L{};
        FragM frag_M;
        fill(frag_M, -std::numeric_limits::infinity());

        __syncthreads();

        // offset_K is the start of this split's last tile; per the mask derivation below the
        // mainloop proceeds from there toward earlier tiles.
        const int offset_Q = history_len + query_idx;
        const int offset_K = (first_K_tile + iter_end - 1) * CTA_S;

        // This is for avoiding OOB access only
        const int max_K = min(get_cp_len(context_len, params.cp_rank), (first_K_tile + iter_end) * CTA_S);

        int tile_iter = iter_end - iter_begin;

        //    min(Q) >= max(K)
        // -> offset_Q >= offset_K + CTA_S - x * CTA_S
        // -> x * CTA_S >= offset_K - offset_Q + CTA_S
        int mask_iter_back = cdiv(max(0, offset_K - offset_Q + CTA_S), CTA_S);
        //    max(Q) < min(K) + w
        // -> offset_Q + CTA_Q - 1 < offset_K - tile_iter * CTA_S + x * CTA_S + w
        // -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w
        int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S);

        // Same bounds with K offsets in rank-local coordinates (local position p on rank r is
        // global position p * cp_size + r).
        if (params.cp_size > 1) {
            mask_iter_back =
                cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S);
            mask_iter_front = cdiv(max(0,
                                       offset_Q + CTA_Q - params.window_size - params.cp_rank
                                           - params.cp_size * (offset_K - tile_iter * CTA_S)),
                                   params.cp_size * CTA_S);
        }

#if 0
        if (threadIdx.x == 0) {
            printf(
                "tile count: %d, tile per iter: %d, range_Q: [%d, %d), offset_K: %d, max_K: %d, tile_iter: %d, range_K: [%d, %d), range_K_tiles: [%d, %d), mask_iter: %d, mask_iter_front: %d\n",
                tile_count,
                tile_per_split,
                offset_Q,
                offset_Q + min(query_idx + CTA_Q, input_len),
                offset_K,
                max_K,
                tile_iter,
                first_K,
                last_K,
                first_K_tile * CTA_S,
                last_K_tile * CTA_S,
                mask_iter_back,
                mask_iter_front);
        }
#endif

        cache_iter.SetTile(first_K_tile + iter_end - 1);

        Mainloop mainloop;
        mainloop.SetCpInfo(params.cp_size, params.cp_rank);
        mainloop(frag_Q,
                 cache_iter,
                 frag_O,
                 frag_M,
                 frag_L,
                 offset_Q,
                 offset_K,
                 max_K,
                 tile_iter,
                 mask_iter_back,
                 mask_iter_front,
                 params.window_size,
                 params.inv_sqrt_dh,
                 storage,
                 StoreS(params, query_idx, head_idx, batch_idx, context_len));

        Impl::Merge(frag_O, frag_M, frag_L, params.inv_sqrt_dh, storage);

        // Attention sinks: added to the softmax denominator L as exp(sink - M * scale_sinks),
        // only by the CTA holding the final split on CP rank 0 so each row gets it exactly once.
        if (params.sinks && iter_end == tile_count && params.cp_rank == 0) {
            Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float& M, float& L) {
                if (check_h(hi) && M != -std::numeric_limits::infinity()) {
                    auto sink = (float)params.sinks[head_idx + hi];
                    L += expf(sink - M * params.scale_sinks);
                }
            });
        }

        if (split_cnt > 1 && iter_end == tile_count && head_idx == 0) {
            // Store actual split count, only used by separate reduction kernel.
            // Written once per query by the CTA that owns the last non-empty split of head 0. When
            // that is split 0 the value is 1 under CP (partials still need reducing) and 0
            // otherwise (presumably "nothing to reduce", since StoreO wrote the final output).
            for (int ti = threadIdx.x; ti < CTA_Q; ti += kWarpCount * WARP_SIZE) {
                if (qi_begin + ti < qi_end) {
                    params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : (params.cp_size > 1 ? 1 : 0);
                }
            }
        }

        // A CTA that covered every tile without CP owns the complete result; otherwise emit partials.
        if (iter_begin == 0 && iter_end == tile_count && params.cp_size == 1) {
            StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage);
        }
        else {
            StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage);
        }
    }

    // Writes the final output for valid (query, head) rows to params.out, laid out as
    // [token, num_heads, kHeadDim].
    // NOTE(review): `offset` is computed in int, unlike the int64_t indices in Prologue; confirm
    // token * num_heads * kHeadDim cannot exceed INT_MAX.
    __device__ void StoreO(FragO&           frag_O,
                           FragL&           frag_L,
                           int              qi_begin,
                           int              qi_end,
                           int              head_idx,
                           const ParamType& params,
                           SharedStorage&   storage)
    {
        Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
            if (qi_begin + qi < qi_end && check_h(hi)) {
                const int offset = (qi_begin + qi) * params.num_heads * kHeadDim + (head_idx + hi) * kHeadDim + di;
                Store(¶ms.out[offset], cast(vec));
            }
        });
    }

    // Returns the score-dump callback passed to the mainloop. For each score reported by
    // Impl::ForeachS it writes params.qk laid out as [batch, num_heads, max_q_len, max_context_len],
    // skipping out-of-range queries/keys and rows failing check_h.
    // The lambda captures by reference, which is why the indices are taken as `const int&`: the
    // captured references then refer to the caller's locals, which outlive the callback because it
    // is consumed inside operator().
    // NOTE(review): whether the mainloop invokes it only when params.qk is set is not visible here.
    __device__ auto StoreS(const ParamType& params,
                           const int&       query_idx,
                           const int&       head_idx,
                           const int&       batch_idx,
                           const int&       max_context_len)
    {
        return [&](auto& frag_S, int offset_K) {
            Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float score) {
                qi += query_idx;
                si += offset_K;
                if (qi < params.max_q_len && si < max_context_len && check_h(hi)) {
                    params.qk[batch_idx * params.num_heads * params.max_q_len * max_context_len
                              + (head_idx + hi) * params.max_q_len * max_context_len + qi * max_context_len + si] =
                        score;
                }
            });
        };
    }

    // Writes this split's partial output and per-row (M, L) softmax statistics for the reduction:
    //   partial_O:  element get_index(hi, qi) * kHeadDim + di
    //   partial_ML: {M, L} pair at get_index(hi, qi) * 2, written only for ri == 0
    // Token indices are taken relative to params.offset_q.
    __device__ void StorePartial(FragO&           frag_O,
                                 FragM&           frag_M,
                                 FragL&           frag_L,
                                 int              split_cnt,
                                 int              qi_begin,
                                 int              qi_end,
                                 int              head_idx,
                                 int              split_idx,
                                 const ParamType& params,
                                 SharedStorage&   storage)
    {
        auto get_index = [&](int hi, int qi) {
            // [B, H, k, D]
            return (qi_begin + qi - params.offset_q) * params.num_heads * params.max_split_k
                   + (head_idx + hi) * params.max_split_k + split_idx;
        };

        Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
            if (qi_begin + qi < qi_end && check_h(hi)) {
                Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec);
            }
        });

        Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) {
            const int index = get_index(hi, qi);
            if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
                Store(¶ms.partial_ML[index * 2], Array{M, L});
            }
        });
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Dynamically sized shared memory for the kernel (size supplied at launch); AttentionUniversal
// reinterprets it as its SharedStorage.
extern __shared__ char smem_buf[];

// Kernel entry point: each CTA builds a Kernel functor from the head-grouping parameters and runs
// it. The body is emitted only in device passes whose __CUDA_ARCH__ the kernel's Arch accepts; on
// every other target the kernel compiles to an empty function.
template
__global__ void attention_kernel(typename Kernel::ParamType            params,
                                 typename Kernel::CacheIteratorFactory cache_iter_factory,
                                 typename Kernel::CtaMap               cta_map,
                                 int                                   q_group_size,
                                 int                                   q_head_per_cta,
                                 int                                   cta_per_q_group)
{
#if __CUDA_ARCH__
    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {
        Kernel cta_kernel{q_group_size, q_head_per_cta, cta_per_q_group};
        cta_kernel(params, cache_iter_factory, cta_map, smem_buf);
    }
#endif
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/block.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/sub_byte_ptr.h"
#include 
#include 

namespace turbomind {

namespace block {

// Parameters of a paged KV-cache block: element bit widths and head dim are compile-time,
// head count and tokens-per-block are runtime values.
template
struct Config {
    // Number of heads stored per layer in a block.
    int head_num_;
    // Number of tokens per block.
    int block_len_;

    // Bits per quantization-param element; 0 when the is_same_v check holds (presumably an
    // unquantized cache that stores no params — confirm against the template arguments).
    TM_HOST_DEVICE constexpr int t_bits() const
    {
        if constexpr (std::is_same_v) {
            return 0;
        }
        else {
            return bitsof;
        }
    }

    // Bits per cache data element.
    TM_HOST_DEVICE constexpr int q_bits() const
    {
        return bitsof;
    }

    TM_HOST_DEVICE constexpr int head_dim() const
    {
        return HeadDim;
    }

    TM_HOST_DEVICE int head_num() const
    {
        return head_num_;
    }

    TM_HOST_DEVICE constexpr int block_len() const
    {
        return block_len_;
    }

    // Whether K and V share one storage region (see Layout::v_data / v_param).
    TM_HOST_DEVICE constexpr bool is_share_kv() const
    {
        return ShareKV;
    }
};

// Layout -> LayerId -> HeadId -> Timestep -> [Block] -> (k_data, v_data, k_param, v_param)

// View of one (layer, head) slice of the paged KV cache. Given a block base pointer and a token
// index within that block, it returns pointers to the token's K/V data and quantization params.
template
class Head {
public:
    TM_HOST_DEVICE Head(Layout layout, int layer_id, int head_id):
        layout_{layout}, layer_id_{layer_id}, head_id_{head_id}
    {
    }

    // Pointer to token `ti`'s K data in `block`. One branch wraps the address in SubBytePtr
    // (presumably for sub-byte cache element types — confirm).
    TM_HOST_DEVICE auto k_data(char* block, int ti) const
    {
        if constexpr (std::is_same_v) {
            return SubBytePtr{block + layout_.k_data(layer_id_, head_id_, ti)};
        }
        else {
            return reinterpret_cast(block + layout_.k_data(layer_id_, head_id_, ti));
        }
    }

    // Pointer to token `ti`'s V data in `block` (same typing rule as k_data).
    TM_HOST_DEVICE auto v_data(char* block, int ti) const
    {
        if constexpr (std::is_same_v) {
            return SubBytePtr{block + layout_.v_data(layer_id_, head_id_, ti)};
        }
        else {
            return reinterpret_cast(block + layout_.v_data(layer_id_, head_id_, ti));
        }
    }

    // Pointer to token `ti`'s K quantization params in `block`.
    TM_HOST_DEVICE T* k_param(char* block, int ti) const
    {
        return reinterpret_cast(block + layout_.k_param(layer_id_, head_id_, ti));
    }

    // Pointer to token `ti`'s V quantization params in `block`.
    TM_HOST_DEVICE T* v_param(char* block, int ti) const
    {
        return reinterpret_cast(block + layout_.v_param(layer_id_, head_id_, ti));
    }

    // Splits a sequence position into (block index, position within block).
    TM_HOST_DEVICE void get_block_coord(int seq_ti, int& block_idx, int& block_ti) const
    {
        block_idx = seq_ti / block_len();
        block_ti  = seq_ti % block_len();
    }

    TM_HOST_DEVICE auto block_len() const
    {
        return layout_.config().block_len();
    }

    // Locates sequence position `ti` through the block table `block_ptrs` and invokes
    // func(k_data, v_data, k_param, v_param) for that token, returning func's result.
    template
    TM_HOST_DEVICE auto with(char** block_ptrs, int ti, Func&& func) const
    {
        int block_id;
        int block_ti;
        get_block_coord(ti, block_id, block_ti);

        char* block = block_ptrs[block_id];

        return ((Func &&) func)(
            k_data(block, block_ti), v_data(block, block_ti), k_param(block, block_ti), v_param(block, block_ti));
    }

private:
    Layout layout_;

    int layer_id_;
    int head_id_;
};

// Byte layout of one cache block. Each layer occupies layer_size() bytes: first the token data of
// every head (per head: K, then V unless K/V are shared), then the quantization params of every
// head in the same order.
// Compact form: L(H2SDQ+H2S2T)
template
struct Layout {

    using Config = Config_;

    Config config_;

    // Trivial converting ctor; it exists so class template argument deduction can infer Config.
    TM_HOST_DEVICE Layout(Config config): config_{config} {}

    TM_HOST_DEVICE const Config& config() const
    {
        return config_;
    }

    // Whether K and V alias the same storage.
    TM_HOST_DEVICE constexpr bool is_share_kv() const
    {
        return config_.is_share_kv();
    }

    // Distinct caches stored per head: 1 when K/V are shared, 2 otherwise.
    TM_HOST_DEVICE constexpr int kv_num() const
    {
        if (is_share_kv()) {
            return 1;
        }
        return 2;
    }

    // Bytes of cache data for one token of one head.
    TM_HOST_DEVICE int token_data_size() const
    {
        const int bits = config_.q_bits() * config_.head_dim();
        return bits / 8;
    }

    // Bytes of quantization params for one token of one head (a scale and a zero).
    TM_HOST_DEVICE int token_param_size() const
    {
        const int bits = 2 * config_.t_bits();
        return bits / 8;
    }

    // Bytes of data for one head across the block's tokens.
    TM_HOST_DEVICE int head_data_size() const
    {
        return token_data_size() * config_.block_len();
    }

    // Bytes of params for one head across the block's tokens.
    TM_HOST_DEVICE int head_param_size() const
    {
        return token_param_size() * config_.block_len();
    }

    // Bytes for one layer: data region followed by param region.
    TM_HOST_DEVICE int layer_size() const
    {
        // TODO: enforce alignment
        const int caches = config_.head_num() * kv_num();
        return caches * head_data_size() + caches * head_param_size();
    }

    TM_HOST_DEVICE int block_size(int layer_num) const
    {
        return layer_num * layer_size();
    }

    // Byte offset of the K data of (layer, head, token) within a block.
    TM_HOST_DEVICE int k_data(int layer, int head, int token) const
    {
        const int base = layer_data(layer) + head_data(head);
        return base + token_data(token);
    }

    // V data directly follows the head's K data, or aliases it when K/V are shared.
    TM_HOST_DEVICE int v_data(int layer, int head, int token) const
    {
        int offset = k_data(layer, head, token);
        if (!is_share_kv()) {
            offset += head_data_size();
        }
        return offset;
    }

    // Byte offset of the K params of (layer, head, token) within a block.
    TM_HOST_DEVICE int k_param(int layer, int head, int token) const
    {
        const int base = layer_param(layer) + head_param(head);
        return base + token_param(token);
    }

    // V params directly follow the head's K params, or alias them when K/V are shared.
    TM_HOST_DEVICE int v_param(int layer, int head, int token) const
    {
        int offset = k_param(layer, head, token);
        if (!is_share_kv()) {
            offset += head_param_size();
        }
        return offset;
    }

    TM_HOST_DEVICE int layer_data(int layer) const
    {
        return layer * layer_size();
    }

    // The param region starts right after the data of all heads of the layer.
    TM_HOST_DEVICE int layer_param(int layer) const
    {
        const int data_region = head_data(config_.head_num());
        return layer_data(layer) + data_region;
    }

    TM_HOST_DEVICE int head_data(int head) const
    {
        return head * kv_num() * head_data_size();
    }

    TM_HOST_DEVICE int head_param(int head) const
    {
        return head * kv_num() * head_param_size();
    }

    TM_HOST_DEVICE int token_data(int ti) const
    {
        return token_data_size() * ti;
    }

    TM_HOST_DEVICE int token_param(int ti) const
    {
        return token_param_size() * ti;
    }
};

// Debugging aid: prints the layout's configuration and derived sizes to stdout, one per line.
template
void dump(const Layout& layout)
{
    const auto& cfg = layout.config();
    std::cout << "head_dim: " << cfg.head_dim() << "\n"
              << "head_num: " << cfg.head_num() << "\n"
              << "block_len: " << cfg.block_len() << "\n"
              << "q_bits: " << cfg.q_bits() << "\n"
              << "t_bits: " << cfg.t_bits() << "\n"
              << "token_data_size: " << layout.token_data_size() << "\n"
              << "token_param_size: " << layout.token_param_size() << "\n"
              << "head_data_size: " << layout.head_data_size() << "\n"
              << "head_param_size: " << layout.head_param_size() << "\n"
              << "layer_size: " << layout.layer_size() << "\n";
}

}  // namespace block

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/block_iterator.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention_params.h"
#include "block.h"

namespace turbomind {

// Tile-granular cursor over one (layer, head) of a sequence's paged KV cache. It tracks the block
// containing the current tile's first token and that token's index within the block.
template
struct BlockIterator {

    // (layer, head) view used for address computation.
    BlockHead block_head_;
    // This sequence's block table.
    char**    block_ptrs_;

    char* block_{};
    int   block_id_{};
    int   block_ti_{};

    __device__ BlockIterator(BlockHead block_head, char** block_ptrs): block_head_{block_head}, block_ptrs_{block_ptrs}
    {
    }

    // Positions the cursor at tile `iter`, i.e. at token iter * CTA_S.
    __device__ void SetTile(int iter)
    {
        block_head_.get_block_coord(iter * CTA_S, block_id_, block_ti_);
        block_ = block_ptrs_[block_id_];
    }

    // Moves the cursor back by one tile (CTA_S tokens), crossing into the previous block when needed.
    // NOTE(review): at most one block is stepped back, which presumably requires CTA_S <= block_len.
    // Once block_id_ goes negative, block_ keeps its old value so block_ptrs_[-1] is never read.
    __device__ void Advance()
    {
        block_ti_ -= CTA_S;
        if (block_ti_ < 0) {
            block_ti_ += block_head_.block_len();
            block_id_ -= 1;
        }
        if (block_id_ >= 0) {
            block_ = block_ptrs_[block_id_];
        }
    }

    // Pointer to the current tile's first token, advanced by `offset`:
    // Index 0 = K data, 1 = V data, 2 = K params, 3 = V params.
    template
    __device__ auto OffsetPtr(int offset) const
    {
        if constexpr (Index == 0) {
            return block_head_.k_data(block_, block_ti_) + offset;
        }
        else if constexpr (Index == 1) {
            return block_head_.v_data(block_, block_ti_) + offset;
        }
        else if constexpr (Index == 2) {
            return block_head_.k_param(block_, block_ti_) + offset;
        }
        else if constexpr (Index == 3) {
            return block_head_.v_param(block_, block_ti_) + offset;
        }
        else {
            static_assert(Index != Index, "invalid index");
        }
    }
};

// Creates BlockIterators for a (sequence, KV head) pair at a fixed layer. The block tables of all
// sequences are stored back to back in block_ptrs_; cu_block_nums_[b] is the index of sequence b's
// first entry (presumably the cumulative block count of sequences before b).
template
struct BlockIteratorFactory {
    using BlockLayout = BlockLayout_;

    BlockLayout_ block_layout_;
    char**       block_ptrs_;
    const int*   cu_block_nums_;
    int          layer_idx_;

    __device__ auto Create(int batch_idx, int head_idx)
    {
        block::Head head{block_layout_, layer_idx_, head_idx};

        // Slice of the flat block table that belongs to sequence `batch_idx`.
        char** block_ptrs = block_ptrs_ + cu_block_nums_[batch_idx];

        return BlockIterator, CTA_S>{head, block_ptrs};
    }
};

// Builds a BlockIteratorFactory from the attention params. The block layout is configured with the
// number of KV heads and the block length; the block pointer table, per-sequence offsets
// (cu_block_nums) and layer id are taken from param.block_iter_params.
template
struct CreateCacheIterFactory> {
    template
    static CacheIterFactory apply(const Param& param)
    {
        using BlockLayout = typename CacheIterFactory::BlockLayout;
        using BlockConfig = typename BlockLayout::Config;

        return {
            BlockLayout{BlockConfig{param.num_kv_heads, param.block_iter_params.block_len}},
            param.block_iter_params.block_ptrs,
            param.block_iter_params.cu_block_nums,
            param.block_iter_params.layer_id,
        };
    }
};

// Convenience alias: block-based cache iterator factory using the block::Layout/Config types,
// iterating in tiles of CTA_S tokens.
template
using GetBlockIterFactory =
    BlockIteratorFactory>, CTA_S>;

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/cp_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/cp_utils.h"

namespace turbomind {

// Host-side hook for attention context parallelism (CP), run after the decoding
// kernel (presumably installed as params.cp_fn with a CpPostContext as
// params.cp_fn_ctx).
//
// partial_ML is treated as one slot of `count` floats per rank in the CP group;
// this rank's slot starts at partial_ML + cp_rank * count. The call is an
// in-place all-gather over attn_cp_group, enqueued on ctx->stream, after which
// every rank's slot is present in partial_ML.
void CpPost(void* context)
{
    // The context arrives type-erased; static_cast is the idiomatic conversion
    // from void* back to the object pointer type.
    auto ctx = static_cast<CpPostContext*>(context);

    ctx->d_comm->AllGather(ctx->partial_ML + ctx->cp_rank * ctx->count,  // this rank's slot (input)
                           ctx->partial_ML,                              // whole gather buffer (output)
                           ctx->count,                                   // floats contributed per rank
                           DataType::kFloat,
                           ctx->attn_cp_group,
                           ctx->stream);
    sync_check_cuda_error();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/cp_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Arguments for CpPost(). The communicator and group are fixed at construction;
// the remaining fields describe the buffer to all-gather and are filled in per
// launch. They are value-initialized so a context used before they are assigned
// never carries indeterminate values.
struct CpPostContext {

    CpPostContext(comm::DeviceCommImpl* d_comm, int attn_cp_group): d_comm(d_comm), attn_cp_group(attn_cp_group) {}

    comm::DeviceCommImpl* d_comm;         // device communicator performing the all-gather
    int                   attn_cp_group;  // communicator group used for attention CP

    int          cp_rank{};     // this rank's slot index inside partial_ML
    int          count{};       // number of floats per rank slot
    float*       partial_ML{};  // gather buffer holding one slot per rank
    cudaStream_t stream{};      // stream the collective is enqueued on
};

// All-gathers the per-rank partial results described by `context`, which must
// point to a CpPostContext. Takes void* so it can be used as a type-erased hook.
void CpPost(void* context);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/cta_map.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind::attention {

#if 1
// Maps a CTA of the attention kernel to (query tile, head tile, batch, split).
// Grid layout: x = query tile, y = batch, z = split * h_cta_cnt_ + head tile
// (the head tile varies fastest along z).
struct AttentionCtaMap {

    int q_cta_cnt_;   // query tiles of size cta_q (rounded up)
    int h_cta_cnt_;   // head tiles of size cta_h (truncating division)
    int batch_size_;
    int split_cnt_;   // number of KV splits

    __host__ __device__
    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt):
        q_cta_cnt_{(max_q_len + cta_q - 1) / cta_q},
        h_cta_cnt_{head_num / cta_h},
        batch_size_{batch_size},
        split_cnt_{split_cnt}
    {
    }

    __host__ __device__ void set_split_cnt(int value)
    {
        this->split_cnt_ = value;
    }

    __host__ dim3 get_grid_shape() const
    {
        const int packed_z = split_cnt_ * h_cta_cnt_;
        return dim3(q_cta_cnt_, batch_size_, packed_z);
    }

    __device__ auto query_idx() const -> int
    {
        return blockIdx.x;
    }

    __device__ auto head_idx() const -> int
    {
        const auto z = blockIdx.z;
        return z % h_cta_cnt_;
    }

    __device__ auto batch_idx() const -> int
    {
        return blockIdx.y;
    }

    __device__ auto split_idx() const -> int
    {
        const auto z = blockIdx.z;
        return z / h_cta_cnt_;
    }

    __device__ auto split_count() const -> int
    {
        return split_cnt_;
    }
};
#else
// Alternative CTA mapping (the inactive branch of the surrounding #if/#else).
// Grid layout: x = query tile, y = head tile, z = split * batch_size_ + batch
// (the batch index varies fastest along z).
struct AttentionCtaMap {

    int q_cta_cnt_;   // query tiles of size cta_q (rounded up)
    int h_cta_cnt_;   // head tiles of size cta_h (truncating division)
    int batch_size_;
    int split_cnt_;   // number of KV splits

    __host__ __device__
    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt):
        q_cta_cnt_{(max_q_len + cta_q - 1) / cta_q},
        h_cta_cnt_{head_num / cta_h},
        batch_size_{batch_size},
        split_cnt_{split_cnt}
    {
    }

    __host__ __device__ void set_split_cnt(int value)
    {
        this->split_cnt_ = value;
    }

    __host__ dim3 get_grid_shape() const
    {
        const int packed_z = split_cnt_ * batch_size_;
        return dim3(q_cta_cnt_, h_cta_cnt_, packed_z);
    }

    __device__ auto query_idx() const -> int
    {
        return blockIdx.x;
    }

    __device__ auto head_idx() const -> int
    {
        return blockIdx.y;
    }

    __device__ auto batch_idx() const -> int
    {
        const auto z = blockIdx.z;
        return z % batch_size_;
    }

    __device__ auto split_idx() const -> int
    {
        const auto z = blockIdx.z;
        return z / batch_size_;
    }

    __device__ auto split_count() const -> int
    {
        return split_cnt_;
    }
};
#endif

// CTA mapping for decoding kernels.
// Grid layout: x spans kv_head_num * cta_per_q_group CTAs, y = batch, z = KV split.
// head_idx() returns the raw x index; the kernel is presumably responsible for
// splitting it into a KV head and a CTA within that head's query group.
struct DecodingCtaMap {
    static __host__ dim3 get_grid_shape(int kv_head_num, int batch_size, int split_count, int cta_per_q_group)
    {
        const int ctas_along_heads = kv_head_num * cta_per_q_group;
        return dim3(ctas_along_heads, batch_size, split_count);
    }

    // Decoding handles a single query position, so the query index is always 0.
    __device__ auto query_idx() const -> int
    {
        return 0;
    }

    __device__ auto head_idx() const -> int
    {
        return blockIdx.x;
    }

    __device__ auto batch_idx() const -> int
    {
        return blockIdx.y;
    }

    __device__ auto split_idx() const -> int
    {
        return blockIdx.z;
    }

    // The split count is not stored; it is read back from the launched grid.
    __device__ auto split_count() const -> int
    {
        return gridDim.z;
    }
};

// CTA mapping for the split reduction kernel.
// Grid layout: x = head, y = query, z = group of cta_k splits (rounded up).
struct ReduceCtaMap {
    static __host__ dim3 get_grid_shape(int query_num, int head_num, int max_split_cnt, int cta_k)
    {
        const int split_groups = (max_split_cnt + cta_k - 1) / cta_k;  // ceil(max_split_cnt / cta_k)
        return dim3(head_num, query_num, split_groups);
    }

    static __device__ auto query_idx() -> int
    {
        return blockIdx.y;
    }

    static __device__ auto head_idx() -> int
    {
        return blockIdx.x;
    }

    static __device__ auto split_idx() -> int
    {
        return blockIdx.z;
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/decoding.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "decoding.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/attention/registry.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Looks up and launches the registered decoding-attention kernel for `params`.
// The lookup key (AttnDesc) is built from the head size, the element data type,
// the KV-cache quantization derived from params.quant_policy (0 = none,
// 8 = int8, 4 = int4) and the query-group size (num_heads / num_kv_heads).
// Fails via TM_CHECK when no registered kernel matches.
template
void dispatchDecoding(const AttentionParams& params)
{
    using namespace attention;

    const bool is_kv_int8     = params.quant_policy & QuantPolicy::kCacheKVInt8;
    const bool is_kv_int4     = params.quant_policy & QuantPolicy::kCacheKVInt4;
    const int  query_group_sz = params.num_heads / params.num_kv_heads;

    // int4 and int8 KV-cache quantization are mutually exclusive.
    FT_CHECK(!(is_kv_int4 && is_kv_int8));

    // Quantization expressed as bit width; 0 means an unquantized cache.
    int kv_quant = is_kv_int4 ? 4 : (is_kv_int8 ? 8 : 0);

    AttnDesc desc{};
    desc.mode           = AttnDesc::kDecoding;
    desc.head_dim       = params.size_per_head;
    desc.data_type      = data_type_v;
    desc.kv_quant       = kv_quant;
    desc.query_group_sz = query_group_sz;

    auto& reg    = Registry::instance();
    auto* kernel = reg.Find(desc);

    TM_CHECK(kernel) << "No decoding kernel found: " + to_string(desc);

    // The kernel receives a pointer to the params and the device's SM count.
    kernel->Launch(&params, reg.sm_count());
}

// Explicit instantiations of dispatchDecoding; the second is compiled only when
// ENABLE_BF16 is set (presumably the bfloat16 variant — confirm).
template void dispatchDecoding(const AttentionParams& params);
#if ENABLE_BF16
template void dispatchDecoding(const AttentionParams& params);
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/decoding.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention_params.h"

namespace turbomind {

// Selects a registered decoding-attention kernel matching `params` and launches it.
template
void dispatchDecoding(const AttentionParams& params);

}


================================================
FILE: src/turbomind/kernels/attention/decoding_template.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention_params.h"
#include "attention_universal.h"
#include "reduce.h"
#include "src/turbomind/kernels/core/thread_map.h"
#include "utils.h"
namespace turbomind {

// Launches a decoding attention kernel with an automatically chosen number of
// KV splits, then reduces the partial results when needed.
//
//   1. The split count is bounded by params.max_split_k and by the number of
//      CTA_S-sized KV tiles this rank covers (max_k_len divided across cp_size
//      ranks, capped at window_size).
//   2. The grid is first sized with one split; GetSplitCount then picks
//      split_cnt from that grid size, max_active_ctas and sm_count, and the grid
//      is rebuilt with it.
//   3. The kernel is launched. A launch error prints the CUDA error string and
//      aborts the process. The optional params.cp_fn hook then runs, and
//      invokeReduceV3 merges partials if there is more than one split or more
//      than one CP rank.
// Always returns true.
template
bool invokeDecoding(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas)
{
    // Shared-memory footprint of one CTA; presumably passed as the dynamic smem
    // size in the launch configuration below.
    static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage);

    // One-time (per template instantiation) debug hook. All diagnostics are
    // commented out, so it currently only evaluates to 0.
    if constexpr (1) {
        [[maybe_unused]] static const int _ = [&] {
            // std::cout << __PRETTY_FUNCTION__ << std::endl;
            // std::cout << "GmemMap:\n";
            // Print(typename Kernel::Impl::ThreadMapKV{});
            // std::cout << "\nDynamic smem size: " << kSmemSize << "\n";
            return 0;
        }();
    }

    // KV length handled by this CP rank, the number of CTA_S tiles in it (bounded
    // by the attention window), and hence the largest useful split count.
    const int max_cp_k_len    = cdiv(params.max_k_len, (int)params.cp_size);
    const int tile_count      = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);
    const int max_split_count = std::min(params.max_split_k, tile_count);

    using CtaMap = typename Kernel::CtaMap;

    dim3 block(Kernel::kWarpCount * WARP_SIZE);

    auto kernel_func = &attention_kernel;

    // Query heads sharing one KV head, and how many of them a single CTA handles.
    const int q_group_size   = params.num_heads / params.num_kv_heads;
    const int q_head_per_cta = std::min(q_group_size, Kernel::CTA_H);

    // cta needed to process one query group
    const int cta_per_q_group = (q_group_size + q_head_per_cta - 1) / q_head_per_cta;

    // std::cout << "CTA_H: " << Kernel::CTA_H << ", head_per_cta: " << q_head_per_cta
    //           << ", cta_per_q_group: " << cta_per_q_group << "\n";

    // Grid without splitting, used only to size the split decision.
    dim3 grid = CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, 1, cta_per_q_group);

    const int grid_size = grid.x * grid.y * grid.z;
    const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 4);

    grid = CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, split_cnt, cta_per_q_group);

    // Print(typename Kernel::Impl::ThreadMapKVp{});

    // std::cout << "split count: " << split_cnt << "\n";

    auto cache_iter_factory = CreateCacheIterFactory::apply(params);

    kernel_func<<>>(
        params, cache_iter_factory, CtaMap{}, q_group_size, q_head_per_cta, cta_per_q_group);

    // Launch failures are fatal: report and abort rather than return false.
    if (auto err = cudaGetLastError(); err != cudaSuccess) {
        std::cout << cudaGetErrorString(err) << "\n";
        std::abort();
    }

    // Optional context-parallel hook (e.g. exchanging partial results across
    // CP ranks) that must run before the reduction.
    if (params.cp_fn) {
        params.cp_fn(params.cp_fn_ctx);
    }

    // Partials need merging whenever the KV range was split, either across CTAs
    // (split_cnt > 1) or across CP ranks. The per-sequence split-count buffer is
    // passed only in the former case.
    if (split_cnt > 1 || params.cp_size > 1) {
        attention::invokeReduceV3(params.out,
                                                    params.partial_ML,
                                                    params.partial_O,
                                                    split_cnt > 1 ? params.split_cnt : nullptr,
                                                    params.max_split_k,
                                                    split_cnt,
                                                    params.cp_size,
                                                    params.cp_rank,
                                                    params.token_num,
                                                    params.num_heads,
                                                    params.inv_sqrt_dh,
                                                    params.stream);
    }

    return true;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/desc.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/data_type.h"
#include 
#include 
#include 

namespace turbomind::attention {

// Description of an attention request, used as the key when looking up a
// registered kernel (e.g. Registry::Find in dispatchDecoding).
struct AttnDesc {
    enum Mode
    {
        kPrefill,
        kDecoding
    };
    Mode     mode;
    int      head_dim;        // size of one attention head
    DataType data_type;       // element data type (set from data_type_v in dispatchDecoding)
    int      kv_quant;        // 0=none, 8=int8, 4=int4
    int      query_group_sz;  // num_heads/num_kv_heads for decoding; 0 for prefill
};

// Formats an AttnDesc as "<mode>_d<head_dim>_<dtype>", with decoding requests
// additionally suffixed by "_kvint8"/"_kvint4" (when quantized) and
// "_gs<query_group_sz>".
inline std::string to_string(const AttnDesc& d)
{
    std::ostringstream os;
    os << (d.mode == AttnDesc::kPrefill ? "prefill" : "decode")  //
       << "_d" << d.head_dim                                      //
       << "_" << to_string(d.data_type);

    if (d.mode != AttnDesc::kDecoding) {
        return os.str();
    }

    switch (d.kv_quant) {
        case 8:
            os << "_kvint8";
            break;
        case 4:
            os << "_kvint4";
            break;
        default:  // unquantized cache: no suffix
            break;
    }
    os << "_gs" << d.query_group_sz;
    return os.str();
}

// Description of one compiled attention kernel (the registry-side counterpart
// of AttnDesc).
struct KernelDesc {
    AttnDesc::Mode mode;
    int            arch;  // 700, 750, 800
    int            head_dim;
    DataType       data_type;
    int            kv_quant;  // 0=none, 8=int8, 4=int4
    int            qh;        // query heads per CTA (1 for prefill)
};

// Launch-related properties recorded for a compiled kernel.
struct KernelInfo {
    int                dynamic_smem_size;  // dynamic shared memory requested per CTA (bytes, presumably)
    int                max_active_ctas;    // presumably the occupancy limit per SM — TODO confirm
    int                num_warps;
    std::string        name;
    cudaFuncAttributes attr;               // attributes reported by the CUDA runtime for the kernel function
};

// Formats a KernelDesc as "<mode>_sm<arch/10>_d<head_dim>_<dtype>", with decoding
// kernels additionally suffixed by "_kvint8"/"_kvint4" (when quantized) and
// "_qh<qh>"; e.g. arch 800 yields "_sm80".
inline std::string to_string(const KernelDesc& d)
{
    std::ostringstream os;
    os << (d.mode == AttnDesc::kPrefill ? "prefill" : "decode")  //
       << "_sm" << d.arch / 10                                    //
       << "_d" << d.head_dim                                      //
       << "_" << to_string(d.data_type);

    if (d.mode != AttnDesc::kDecoding) {
        return os.str();
    }

    switch (d.kv_quant) {
        case 8:
            os << "_kvint8";
            break;
        case 4:
            os << "_kvint4";
            break;
        default:  // unquantized cache: no suffix
            break;
    }
    os << "_qh" << d.qh;
    return os.str();
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

namespace attention {

// Tag types selecting the matrix-multiply instruction shape an attention Impl is
// built on. The names presumably encode the m/n/k shape (MMA_16816 is used with
// mma_m16n8k16_row_col, MMA_1688 with mma_m16n8k8_row_col); MMA_SIMT presumably
// selects a non-tensor-core path.
struct MMA_16816 {
};

struct MMA_81616 {
};  // MMA_16816 transposed

struct MMA_1688 {
};

struct MMA_884 {
};

struct MMA_SIMT {
};

// Primary template is intentionally empty; each tag above is served by a
// partial specialization (impl_16816.h, impl_1688.h, impl_81616.h, ...).
template
struct Impl {
};

}  // namespace attention

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/impl_16816.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_m16n8.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/core/thread_map.h"

namespace turbomind::attention {

// Attention building blocks for the m16n8k16 tensor-core path (MMA_16816).
// Both S = Q*K^T and O += P*V are issued as mma_m16n8k16_row_col, with operands
// read from shared memory through ldmatrix (ldsm_x4 / ldsm_x4_trans). Softmax,
// S->P conversion, O storage and the S/O/M/L fragment types come from
// Impl_m16k8 (`Base`).
template
struct Impl:
    Impl_m16k8 {

    using Base = Impl_m16k8;

    // Flags the HeadDim == 576 configuration; presumably it selects the
    // alternative shared-memory layouts in the std::conditional_t aliases below.
    static constexpr bool MLA = HeadDim == 576;

    using Base::OP_M;
    using Base::OP_N;
    using Base::K_M;
    using Base::K_N;
    using Base::V_M;
    using Base::V_N;

    using typename Base::FragS;
    using typename Base::FragO;
    using typename Base::FragM;
    using typename Base::FragL;

    using Base::ForeachS;
    using Base::Softmax;
    using Base::ConvertStoP;
    using Base::StoreO;

    using T   = T_;
    using Tkv = T_;

    static constexpr int kHeadDim = HeadDim;

    static constexpr int CTA_H = CTA_H_;
    static constexpr int CTA_Q = CTA_Q_;
    static constexpr int CTA_S = CTA_S_;

    // Warps split the CTA's CTA_Q * CTA_H query rows into WARP_Q-row chunks and
    // its CTA_S keys into WARP_S-key chunks.
    static constexpr int kWarpCntQ  = CTA_Q * CTA_H / WARP_Q;
    static constexpr int kWarpCntS  = CTA_S / WARP_S;
    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;

    static constexpr int OP_K = 16;

    static constexpr int K_K = HeadDim / OP_K;  // 128 / 16 = 8
    static constexpr int V_K = WARP_S / OP_K;   //  64 / 16 = 4  -> S4

    using FragQ = Array[K_K][K_M];  // ((q8, d4), (Dk, Qm), (d2, q2, d2))
                                          //    1   2    16  16     8   8   1
    using FragK = Array[K_K][K_N];  // ((s8, d4), (Dk, Sn), (d2, d2))
                                          //    1   2    16   8     8   1
    using FragP = Array[V_M][V_K];  // ((q8, s4), (Qm, Sk), (s2, q2, s2))
                                          //    1   2    16  16     8   8   1
    using FragV = Array[V_K][V_N];  // ((d8, s4), (Sk, Dn), (s2, s2))
                                          //    1   2    16   8     8   1

    // P is S converted to half-width elements (presumably float accumulators to
    // 16-bit T), so it must occupy exactly half the bytes of FragS.
    static_assert(sizeof(FragS) / 2 == sizeof(FragP));

    using SmemLayoutQ = std::conditional_t>,
                                           SmemLayoutV2>>;
    using SmemLayoutK = std::conditional_t>,
                                           SmemLayoutV2>>;
    using SmemLayoutV = std::conditional_t>,
                                           SmemLayoutV2>>;

    using SmemLayoutKVp = void;

    static constexpr bool kUseSmemQ = false;
    static constexpr bool kUseSmemP = false;

    static_assert(!kUseSmemQ, "current smemQ impl yields inconsistent outputs");

    // Q staging and the K/V pipeline buffers alias each other (union), so Q has to
    // be moved into registers (TransformQ) before K/V tiles overwrite it.
    // NOTE(review): the KV buffer holds Stages * (K + V) / 2 elements — confirm
    // how the caller's pipeline shares it between K and V stages.
    union SharedStorage {
        __align__(16) T KV[Stages * (SmemLayoutK::kSize + SmemLayoutV::kSize) / 2];
        __align__(16) T Q[SmemLayoutQ::kSize];
    };

    using ThreadMapQ  = RakedThreadMap;
    using ThreadMapKV = RakedThreadMap;

    using ThreadMapKVp = void;

    // Capped at 4; presumably the number of K/V global-memory iterations issued
    // per batch by the main loop — TODO confirm.
    static constexpr int kBatchK = std::min(4, ThreadMapKV::kIterS);
    static constexpr int kBatchV = kBatchK;

    // CTA-wide barrier.
    __device__ static void Sync()
    {
        __syncthreads();
    }

    // Points the K and V global-memory iterators at shared memory. K always starts
    // at storage.KV; V starts at the same address unless offset_kv is set, in
    // which case it is placed one K tile (SmemLayoutK::kSize) further.
    template
    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)
    {
        int pred = offset_kv;
        gmem_K.SetSmem(storage.KV);
        gmem_V.SetSmem(storage.KV + pred * SmemLayoutK::kSize);
    }

    // Loads this warp's WARP_Q query rows from shared memory into frag_Q in the
    // m16n8k16 A-operand layout: one ldsm_x4 per (16 q x 16 d) tile.
    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        if constexpr (!kUseSmemQ) {
            __syncwarp();

            SmemAccessor sQ{smem_Q};

            // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; ++k) {
                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;
                    const int di = lane_id / 16 * 8 + k * 16;
                    ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));
                }
            }
        }

        // Disabled path (if constexpr (0)): would write frag_Q back to smem in a
        // linear per-thread order. Kept for reference only.
        if constexpr (0) {
            __syncthreads();

            // Rearrange Q in smem so that swizzling is not needed for later LDSMs
            constexpr int THREADS = kWarpCount * WARP_SIZE;
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    constexpr int kVecSize = 8;
                    Store(&smem_Q[(k * K_M * THREADS + m * THREADS + threadIdx.x) * kVecSize], frag_Q[k][m]);
                }
            }
        }
    }

    // Register state for S = Q*K^T: a copy of the Q fragments (when !kUseSmemQ)
    // and the K fragments, filled one head-dim slice at a time by Load().
    struct StateQK {
        SmemAccessor smem_K;
        T*                           smem_Q;

        FragQ frag_Q;
        FragK frag_K;

        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.KV}
        {
            if constexpr (!kUseSmemQ) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; ++k) {
                    PRAGMA_UNROLL
                    for (int m = 0; m < K_M; ++m) {
                        frag_Q[k][m] = frag_Q_[k][m];
                    }
                }
            }
            else {
                smem_Q = storage.Q;
            }
        }

        // Loads the K fragments of head-dim slice k (16 columns) from pipeline
        // stage pipe_iter (offset pipe_iter * SmemLayoutK::kSize). Each ldsm_x4
        // covers a 16(s) x 16(d) tile, i.e. two n-fragments, hence n += 2.
        __device__ void Load(int k, int pipe_iter)
        {
            const int lane_id       = threadIdx.x % WARP_SIZE;
            const int group_id      = lane_id / 16;
            const int group_lane_id = lane_id % 16;
            const int offset_s      = group_lane_id % 8 + group_id * 8;
            const int offset_c      = group_lane_id / 8 * 8;
            const int offset        = pipe_iter * SmemLayoutK::kSize;
            if constexpr (kUseSmemQ) {
                const int                    warp_id = threadIdx.x / WARP_SIZE;
                SmemAccessor sQ{smem_Q};
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;
                    const int di = lane_id / 16 * 8 + k * 16;
                    ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));
                }
            }
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; n += 2) {  // Load (s16,d16) tiles
                const int s = n * 8 + offset_s;
                const int c = k * 16 + offset_c;
                ldsm_x4((Array&)frag_K[k][n], cast_smem_ptr_to_uint(&smem_K(s, c, offset)));
            }
        }

        __device__ void Transform(int k) {}
    };

    // Accumulates S += Q*K^T over all K_K head-dim slices, software pipelined:
    // slice k+1 is loaded before the MMAs of slice k are issued (slice 0 is
    // presumably loaded by the caller). preload() runs in place of the last
    // load. prefetch(k) follows each of the first K_K-1 slices, and
    // prefetch(K_K-1) is additionally issued at k == K_K-2, so every index
    // 0..K_K-1 is prefetched once (assuming K_K >= 2).
    // MMAs over n run in an XOR-permuted order (n^1 for 2 stages, n^2 otherwise).
    // NOTE(review): that stays in range only if K_N is a multiple of 2 / 4.
    // state_QK is taken by value, so the caller's copy is not updated.
    template
    __device__ static void
    ComputeQK(StateQK state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int k = 0; k < K_K; ++k) {
            if (k < K_K - 1) {
                state_QK.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    const int nn = (Stages == 2) ? (n ^ 1) : (n ^ 2);
                    mma_m16n8k16_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);
                }
            }
            if (k < K_K - 1) {
                ((Prefetch &&) prefetch)(k);
            }
            if (k == K_K - 2) {
                ((Prefetch &&) prefetch)(K_K - 1);
            }
        }
    }

    // Register state for O += P*V. smem_V starts one K tile past storage.KV when
    // `offset` is set (matching SetSmemKV's offset_kv).
    struct StatePV {
        SmemAccessor smem_V;

        FragP frag_P;
        FragV frag_V;

        __device__ StatePV(SharedStorage& storage, bool offset = false):
            smem_V{storage.KV + (offset ? SmemLayoutK::kSize : 0)}
        {
        }

        // Loads the V fragments of key slice k from pipeline stage pipe_iter with
        // transposing ldmatrix; each ldsm_x4_trans covers two n-fragments.
        __device__ void Load(int k, int pipe_iter)
        {
            const int lane_id  = threadIdx.x % WARP_SIZE;
            const int offset_s = lane_id % 16;
            const int offset_c = lane_id / 16 * 8;
            const int offset   = pipe_iter * SmemLayoutV::kSize;
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; n += 2) {  // Load (d16,s16) tiles
                const int s = k * 16 + offset_s;
                const int c = n * 8 + offset_c;
                ldsm_x4_trans((Array&)frag_V[k][n], cast_smem_ptr_to_uint(&smem_V(s, c, offset)));
            }
        }

        __device__ void Transform(int k) {}
    };

    // With a single warp along S (kWarpCntS == 1) there is nothing to merge across
    // warps; the only work is finishing the row sums L across the 4 lanes that
    // share a row (xor 1 and xor 2) when Base::kDeferReduceL postponed it.
    // frag_O, frag_M, qk_scale and storage are unused here.
    template
    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)
    {
        static_assert(kWarpCntS == 1);

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                if constexpr (Base::kDeferReduceL) {
                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 1);
                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 2);
                }
            }
        }
    }

    // Accumulates O += P*V over the V_K key slices with the same load / preload /
    // prefetch pipelining as ComputeQK.
    template
    __device__ static void
    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int k = 0; k < V_K; ++k) {
            if (k < V_K - 1) {
                state_PV.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < V_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    const int nn = n ^ 0;
                    mma_m16n8k16_row_col(frag_O[m][nn], state_PV.frag_P[m][k], state_PV.frag_V[k][nn], frag_O[m][nn]);
                }
            }
            if (k < V_K - 1) {
                ((Prefetch &&) prefetch)(k);
            }
            if (k == V_K - 2) {
                ((Prefetch &&) prefetch)(V_K - 1);
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl_1688.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_m16n8.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/core/thread_map.h"

namespace turbomind::attention {

// Attention building blocks for the m16n8k8 tensor-core path (MMA_1688).
// S = Q*K^T and O += P*V are issued as mma_m16n8k8_row_col with operands read
// from shared memory through ldmatrix. Softmax, S->P conversion, O storage and
// the S/O/M/L fragment types come from Impl_m16k8 (`Base`). Unlike the k16
// variant, K and V live in separate, single-buffered shared-memory arrays.
template
struct Impl:
    Impl_m16k8 {

    using Base = Impl_m16k8;

    // Flags the HeadDim == 576 configuration; presumably it selects the
    // alternative shared-memory layouts in the std::conditional_t aliases below.
    static constexpr bool MLA = HeadDim == 576;

    using Base::OP_M;
    using Base::OP_N;
    using Base::K_M;
    using Base::K_N;
    using Base::V_M;
    using Base::V_N;

    using typename Base::FragS;
    using typename Base::FragO;
    using typename Base::FragM;
    using typename Base::FragL;

    using Base::ForeachS;
    using Base::Softmax;
    using Base::ConvertStoP;
    using Base::StoreO;

    using T   = T_;
    using Tkv = T_;

    static constexpr int kHeadDim = HeadDim;

    static constexpr int CTA_H = CTA_H_;
    static constexpr int CTA_Q = CTA_Q_;
    static constexpr int CTA_S = CTA_S_;

    // Warps split the CTA's CTA_Q * CTA_H query rows into WARP_Q-row chunks and
    // its CTA_S keys into WARP_S-key chunks.
    static constexpr int kWarpCntQ  = CTA_Q * CTA_H / WARP_Q;
    static constexpr int kWarpCntS  = CTA_S / WARP_S;
    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;

    static constexpr int OP_K = 8;

    static constexpr int K_K = HeadDim / OP_K;  // e.g. 128 / 8 = 16
    static constexpr int V_K = WARP_S / OP_K;   // e.g.  64 / 8 = 8

    using FragQ = Array[K_K][K_M];  // ((q8, d4), (Dk, Qm), (q2, d2))
                                          //    1   2     8  16     8   1
    using FragK = Array[K_K][K_N];  // ((s8, d4), (Dk, Sn), (d2))
                                          //    1   2     8   8     1
    using FragP = Array[V_M][V_K];  // ((q8, s4), (Qm, Sk), (q2, s2))
                                          //    1   2    16   8     8   1
    using FragV = Array[V_K][V_N];  // ((d8, s4), (Sk, Dn), (s2))
                                          //    1   2     8   8     1

    using SmemLayoutQ = std::conditional_t>,
                                           SmemLayoutV2>>;
    using SmemLayoutK = std::conditional_t>,
                                           SmemLayoutV2>>;
    using SmemLayoutV = std::conditional_t>,
                                           SmemLayoutV2>>;

    using SmemLayoutKVp = void;

    // Q staging aliases the K/V buffers (union), so Q has to be moved into
    // registers (TransformQ) before K/V tiles are written.
    union SharedStorage {
        __align__(16) T Q[SmemLayoutQ::kSize];
        struct {
            __align__(16) Tkv K[SmemLayoutK::kSize];
            __align__(16) Tkv V[SmemLayoutV::kSize];
        };
    };

    static constexpr bool kUseSmemQ = false;

    using ThreadMapQ  = RakedThreadMap;
    using ThreadMapKV = RakedThreadMap;

    using ThreadMapKVp = void;

    // CTA-wide barrier.
    __device__ static void Sync()
    {
        __syncthreads();
    }

    // K and V always go to their own arrays; offset_kv is ignored in this variant.
    template
    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)
    {
        gmem_K.SetSmem(storage.K);
        gmem_V.SetSmem(storage.V);
    }

    // Loads this warp's WARP_Q query rows from shared memory into frag_Q. Each
    // ldsm_x4 reads a 16(q) x 16(d) tile, i.e. two OP_K = 8 slices, hence k += 2.
    // NOTE(review): the four registers are written starting at frag_Q[k][m]; they
    // land in slices k and k+1 only if frag_Q[k+1][m] directly follows
    // frag_Q[k][m] in memory (K_M == 1) — confirm for all instantiations.
    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        __syncwarp();

        SmemAccessor sQ{smem_Q};
        if constexpr (!kUseSmemQ) {
            // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; k += 2) {
                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;
                    const int di = lane_id / 16 * 8 + k * 8;
                    ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));
                }
            }
        }
        else {
            static_assert(!std::is_same_v, "not supported");
        }
    }

    // Register state for S = Q*K^T: a copy of the Q fragments and the K fragments,
    // filled one head-dim slice at a time by Load().
    struct StateQK {
        SmemAccessor smem_K;

        FragQ frag_Q;
        FragK frag_K;

        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.K}
        {
            static_assert(!kUseSmemQ, "not implemented");
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    frag_Q[k][m] = frag_Q_[k][m];
                }
            }
        }

        // Loads the K fragments of head-dim slice k (8 columns). Each ldsm_x4
        // covers 32 keys (four n-fragments), hence n += 4. pipe_iter is unused:
        // this variant keeps a single K buffer.
        __device__ void Load(int k, int pipe_iter)
        {
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; n += 4) {  // Load (s32,d8) tiles
                const int s = n * 8 + lane_id;
                const int c = k * 8;
                ldsm_x4((Array&)frag_K[k][n], cast_smem_ptr_to_uint(&smem_K(s, c)));
            }
        }

        __device__ void Transform(int k) {}
    };

    // Accumulates S += Q*K^T over all K_K slices; slice k+1 is loaded before the
    // MMAs of slice k, and preload() runs in place of the last load. MMAs over n
    // run in n^2 order (in range only if K_N is a multiple of 4).
    // NOTE(review): `prefetch` is accepted but never invoked in this variant —
    // confirm the caller does not depend on it being called.
    template
    __device__ static void
    ComputeQK(StateQK& state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int k = 0; k < K_K; ++k) {
            if (k < K_K - 1) {
                state_QK.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    const int nn = n ^ 2;
                    mma_m16n8k8_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);
                }
            }
        }
    }

    // Register state for O += P*V; V always lives in storage.V, so `offset` must
    // be true (asserted).
    struct StatePV {
        SmemAccessor smem_V;

        FragP frag_P;
        FragV frag_V;

        __device__ StatePV(SharedStorage& storage, bool offset = true): smem_V{storage.V}
        {
            assert(offset);
        }

        // Loads the V fragments of key slice k (8 keys) with transposing
        // ldmatrix; each ldsm_x4_trans covers 32 head-dim columns (four
        // n-fragments), hence n += 4. pipe_iter is unused.
        __device__ void Load(int k, int pipe_iter)
        {
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; n += 4) {  // Load (d32,s8) tiles
                const int si = k * 8 + lane_id % 8;
                const int di = n * 8 + lane_id / 8 * 8;
                ldsm_x4_trans((Array&)frag_V[k][n], cast_smem_ptr_to_uint(&smem_V(si, di)));
            }
        }

        __device__ void Transform(int k) {}
    };

    // Accumulates O += P*V over the V_K key slices with the same load/preload
    // pipelining as ComputeQK.
    // NOTE(review): as in ComputeQK, `prefetch` is never invoked here.
    template
    __device__ static void
    ComputePV(StatePV& state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int k = 0; k < V_K; ++k) {
            if (k < V_K - 1) {
                state_PV.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < V_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    mma_m16n8k8_row_col(frag_O[m][n], state_PV.frag_P[m][k], state_PV.frag_V[k][n], frag_O[m][n]);
                }
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl_81616.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/quantization.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/core/thread_map.h"
#include 

namespace turbomind::attention {

// Attention tile implementation built on the m16n8k16 tensor-core MMA
// (`mma_m16n8k16_row_col`) for a single query position per CTA (CTA_Q == 1,
// presumably the decode path).
//
// Both products are computed in transposed form. The 16-wide MMA M dimension
// carries key positions (S) or the head dimension (D), and the 8-wide N
// dimension carries attention heads (H):
//   QK: S^T[s, h] += K[s, d]   * Q^T[d, h]   (M = S, N = H, K = D)
//   PV: O^T[d, h] += V^T[d, s] * P^T[s, h]   (M = D, N = H, K = S)
//
// kQuantKV may be set, meaning the KV storage type is not the compute type.
// In that case K/V tiles are loaded from shared memory in their stored
// format, converted with ConvertKvCache, and dequantized in registers as
// `x * param[0] + param[1]`.
template
struct Impl {
    using T   = T_;
    using Tkv = Tkv_;

    // Non-zero when the KV storage type differs from the compute type (the
    // is_same_v check, presumably T vs Tkv). Enables the convert/dequantize
    // paths below.
    static constexpr int kQuantKV = !std::is_same_v;

    // A head dimension of 576 is flagged as MLA.
    static constexpr bool MLA = HeadDim == 576;

    static constexpr int CTA_H = CTA_H_;
    static constexpr int CTA_Q = CTA_Q_;
    static constexpr int CTA_S = CTA_S_;

    // Only one query position per CTA is supported by this implementation.
    static_assert(CTA_Q == 1);

    static constexpr int WARP_H = WARP_H_;

    static constexpr int kHeadDim = HeadDim;

    // Warp grid of the CTA: warps along heads (H), queries (Q) and key positions (S).
    static constexpr int kWarpCntH = CTA_H / WARP_H;
    static constexpr int kWarpCntQ = CTA_Q / WARP_Q;
    static constexpr int kWarpCntS = CTA_S / WARP_S;

    static constexpr int kWarpCount = kWarpCntH * kWarpCntQ * kWarpCntS;

    // Shape of one mma_m16n8k16 instruction.
    static constexpr int OP_M = 16;
    static constexpr int OP_N = 8;
    static constexpr int OP_K = 16;

    // Per-warp MMA tile counts for QK: M = key positions, N = heads (rounded up
    // to a multiple of OP_N), K = head dimension. Trailing values are examples.
    static constexpr int K_M = WARP_S / OP_M;               // 1
    static constexpr int K_N = (WARP_H + OP_N - 1) / OP_N;  // 1
    static constexpr int K_K = HeadDim / OP_K;              // 8

    // Per-warp MMA tile counts for PV: M = head dimension, N = heads (rounded
    // up), K = key positions.
    static constexpr int V_M = HeadDim / OP_M;              // 8
    static constexpr int V_N = (WARP_H + OP_N - 1) / OP_N;  // 1
    static constexpr int V_K = WARP_S / OP_K;               // 1

    // Register fragments. The trailing annotations appear to describe each
    // fragment as (lane layout) (array indices) (per-thread element order).
    // The second row gives the matching strides.
    using FragK = Array[K_K][K_M];      // (s8,d4) (Dk,Sm) (d2,s2,d2)
                                              //   1  2   16 16    8  8  1
    using FragQ = Array[K_N][K_K];      // (q8,d4) (Qn,Dk) (d2,d2)
                                              //   1  2    8,16    8  1
    using FragS = Array[K_M][K_N];  // (s8,q4) (Sm,Qn) (s2,q2)
                                              //   1  2   16  8    8  1
    using FragV = Array[V_M][V_K];      // (d8,s4) (Dm,Sk) (s2,d2,s2)
                                              //   1  2   16 16    8  8  1
    using FragP = Array[V_K][V_N];      // (q8,s4) (Sk,Qn) (s2,s2)
                                              //   1  2   16  8    8  1
    using FragO = Array[V_M][V_N];  // (d8,q4) (Dm,Qn) (d2,q2)
                                              //   1  2   16  8    8  1
    using FragM = Array[K_N];       // (_8,q4)    (Qn)    (q2)
                                              //      2       8       1

    // Number of 16-bit MMA k-slices covered by one load of raw KV data:
    // 16 / element bit width (bit width presumably of Tkv). For example, this
    // is 1 for 16-bit, 2 for 8-bit and 4 for 4-bit elements. Raw data is
    // therefore loaded and converted only every X slices.
    static constexpr int X = 16 / bitsof;

    // Raw (not yet converted) K/V data, one entry per X k-slices, plus
    // per-position dequantization params. Index [..][0] is the multiplier and
    // [..][1] the addend (see the __hfma calls in Transform).
    using DataK = Array[K_K / X][K_M];  // {s8,d4} [Dk/x,Sm] (d2,s2,dx,d2)
                                                    //   1 2x    16x 16   8x  8  2  1
    using ParamK = Array[K_M][2];             // {s8,_4} [     Sm] (   s2      )
                                                    //   1  0        16       8
    using DataV = Array[V_M / X][V_K];  // {s8,d4} [Dm/x,Sk] (s2,d2,dx,d2)
                                                    //   1 2x    16x 16    8 8x  2  1
    using ParamV = Array[V_K][2];             // {s8,_4} [     Sk] (s2         )
                                                    //   1  0        16    8

    using FragL = FragM;

    // Shared staging for Merge's cross-warp (S-split) reduction. SmemM holds
    // the per-warp row max / row sum written by lanes 0-3. SmemO holds every
    // lane's output fragments per warp.
    using SmemM = Array[K_N][kWarpCntH][kWarpCntS][4];

    using SmemO = Array[V_M][V_N][kWarpCntH][kWarpCntS][WARP_SIZE];

    // Q and P are kept in registers; no shared-memory staging is used for them.
    static constexpr bool kUseSmemQ = false;
    static constexpr bool kUseSmemP = false;

    // CTA_H rounded up to a multiple of OP_N. This is the row count of the O1
    // output staging buffer.
    static constexpr int CTA_H1 = (CTA_H + OP_N - 1) / OP_N * OP_N;

    // K/V shared-memory layout, selected by overload on the element bit width
    // (the integral_constant tag). These functions are only used inside
    // decltype below. NOTE(review): the tag values and layout parameters are
    // not visible in this copy; confirm against the header.
    static constexpr auto _SmemLayoutKV(std::integral_constant)
    {
        return SmemLayoutV2>{};
    }
    static constexpr auto _SmemLayoutKV(std::integral_constant)
    {
        return SmemLayoutV2>{};
    }
    static constexpr auto _SmemLayoutKV(std::integral_constant)
    {
        return std::conditional_t>,
                                  SmemLayoutV2>>{};
    }

    using SmemLayoutQ = SmemLayoutV2>;
    using SmemLayoutK = decltype(_SmemLayoutKV(bitsof));
    using SmemLayoutV = decltype(_SmemLayoutKV(bitsof));

    // Layout of the per-position K/V quantization params (read at column 0 by Load).
    using SmemLayoutKVp = SmemLayoutV2;

    using PointerKV = get_pointer_type;

    // Shared-memory storage. It is a union, so the members alias one another
    // and are used in separate phases:
    //   - Q: query staging (read by TransformQ)
    //   - KV / KVp: `Stages`-deep K/V tiles and their quantization params
    //   - M / L / O: cross-warp reduction buffers used by Merge
    //   - O1: head-major fp32 output staging used by StoreO
    // The phases must be separated by barriers; see Merge and StoreO.
    union SharedStorage {
        __align__(16) T Q[SmemLayoutQ::kSize];

        struct {
            __align__(16) Array KV;
            __align__(16) T KVp[Stages * SmemLayoutKVp::kSize];
        };

        struct {
            __align__(16) SmemM M;
            __align__(16) SmemM L;
            __align__(16) SmemO O;
        };

        __align__(16) float O1[CTA_H1][kHeadDim];
    };

    // Thread maps for the Q tile, the K/V tiles and the K/V quantization params.
    using ThreadMapQ  = RakedThreadMap;
    using ThreadMapKV = RakedThreadMap, kWarpCount>;
    // `WARP_SIZE / WARP_S` is chosen to achieve minimum kIterS w/o introducing partial S iter
    using ThreadMapKVp = RakedThreadMap<2, CTA_S, 2, kWarpCount, WARP_SIZE / WARP_S>;

    static constexpr int kBatchK = ThreadMapKV::kIterS;
    static constexpr int kBatchV = ThreadMapKV::kIterS;

    // When true, Softmax accumulates per-lane partial row sums (L) without a
    // shuffle reduction. The intra-warp reduction of L is then done once, in Merge.
    static constexpr bool kDeferReduceL = true;

    // Barrier used between pipeline phases. A block-wide barrier is needed when
    // several warps split the heads; otherwise a warp barrier is enough for the
    // quantized path.
    __device__ static void Sync()
    {
        if constexpr (kWarpCntH > 1) {
            __syncthreads();
        }
        else if constexpr (kQuantKV) {  // Thread layout of KV & KVp is different within warp boundary
            __syncwarp();
        }
    }

    // Points the K/V copy iterators at their shared-memory destinations. With
    // offset_kv, V (and its params) is placed directly after K's region.
    // Without it, K and V share the same region.
    template
    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)
    {
        int pred = offset_kv;
        if constexpr (kQuantKV) {
            gmem_K.SetSmem(storage.KV.data(), storage.KVp);
            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize, storage.KVp + pred * SmemLayoutKVp::kSize);
        }
        else {
            gmem_K.SetSmem(storage.KV.data());
            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize);
        }
    }

    // Returns this warp's position in the CTA warp grid:
    // x = warp index along S (key positions), y = warp index along H (heads).
    static __device__ int2 get_warp_ids()
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        if constexpr (kWarpCntH > 1) {
            return {warp_id % kWarpCntS, warp_id / kWarpCntS};
        }
        else {
            return {warp_id, 0};
        }
    }

    // Calls func(hi, qi, si, ri, value) for every element of an S-shaped
    // fragment held by this thread. hi is the head index and si the key
    // position within the CTA tile; qi and ri are always 0. Rows come from
    // lane_id / 4 (+8) and column pairs from lane_id % 4, following the
    // m16n8 accumulator layout.
    template
    __device__ static void ForeachS(Fragment& S, Func&& func)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        const int si = m * OP_M + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;
                        const int hi = n * OP_N + lane_id % 4 * 2 + q * 1 + warp_ids.y * WARP_H;
                        ((Func &&) func)(hi, /*qi*/ 0, si, /*ri*/ 0, S[m][n][s * 2 + q]);
                    }
                }
            }
        }
    }

    // Calls func(hi, qi, ri, M, L) for each per-head row-max / row-sum value
    // held by this thread. ri = lane_id / 4 tells apart the 8 lanes that hold
    // values for the same head.
    template
    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {  // Q
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                const int hi = lane_id % 4 * 2 + n * OP_N + q * 1 + warp_ids.y * WARP_H;
                const int ri = lane_id / 4 * 1;
                ((Func &&) func)(hi, /*qi*/ 0, ri, frag_M[n][q], frag_L[n][q]);
            }
        }
    }

    // Loads the query tile from shared memory (smem_Q) into the MMA B-operand
    // fragments frag_Q.
    //  - Unquantized: each ldsm_x4 reads a 16-wide head-dim span, which fills
    //    two consecutive k-slices, so K_K must be even.
    //  - Quantized: elements are gathered pairwise with a permuted head-dim
    //    index (strides 2*X, 2 and 8*X). NOTE(review): this presumably matches
    //    the element order ConvertKvCache produces for K; confirm.
    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)
    {
        static_assert(K_K % 2 == 0);
        SmemAccessor sQ{smem_Q};

        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        if constexpr (!kQuantKV) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; k += 2) {  // 16x16 tile
                    const int hi = n * OP_N + lane_id % 8 + warp_ids.y * WARP_H;
                    const int di = k * OP_K + lane_id / 8 * 8;
                    ldsm_x4((Array&)frag_Q[n][k], cast_smem_ptr_to_uint(&sQ(hi, di)));
                }
            }
        }
        else {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; k += X) {
                    PRAGMA_UNROLL
                    for (int x = 0; x < X; ++x) {
                        PRAGMA_UNROLL
                        for (int d = 0; d < 2; ++d) {  // (s8,d8)
                            const int hi = n * OP_N + lane_id / 4 + warp_ids.y * WARP_H;
                            const int di = k * OP_K + lane_id % 4 * 2 * X + x * 2 + d * 8 * X;
                            Load((Array&)frag_Q[n][k + x][d * 2], &sQ(hi, di));
                        }
                    }
                }
            }
        }
    }

    // Per-warp register state for QK: shared-memory pointers to the K tiles
    // and params, a copy of the Q fragments, and raw plus converted K fragments.
    struct StateQK {
        PointerKV smem_K;
        T*        smem_K_param;
        FragQ     frag_Q;
        ParamK    param_K;
        DataK     data_K;
        FragK     frag_K;

        // K data starts at the beginning of storage.KV / storage.KVp.
        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_)
        {
            smem_K       = storage.KV.data();
            smem_K_param = storage.KVp;
            static_assert(!kUseSmemQ, "not implemented");
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; ++k) {
                    frag_Q[n][k] = frag_Q_[n][k];
                }
            }
        }

        // Loads the shared-memory data for k-slice `k` of pipeline stage
        // `pipe_iter`:
        //  - quantized only, at k == 0: the scale/zero params of this thread's
        //    two key rows;
        //  - every X slices: one ldsm_x4 covering 16 key rows and X k-slices
        //    of raw K.
        __device__ void Load(int k, int pipe_iter)
        {
            const auto warp_ids = get_warp_ids();
            const int  lane_id  = threadIdx.x % WARP_SIZE;

            if (kQuantKV && k == 0) {
                static_assert(K_M == 1);
                const int m = 0;
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    const int si = m * 16 + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;
                    Lds(param_K[m][s], &smem_K_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);
                }
            }

            if (k % X == 0) {
                const int offset_s = lane_id % 16 * 1 + warp_ids.x * WARP_S;
                const int offset_c = lane_id / 16 * 8 * X;
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    const int s = m * 16 + offset_s;  // Q
                    const int c = k * 16 + offset_c;  // D
                    static_assert(sizeof(data_K[k / X][m]) == 16);
                    ldsm_x4((Array&)data_K[k / X][m],
                            cast_smem_ptr_to_uint(&smem_K[pipe_iter * SmemLayoutK::kSize + SmemLayoutK::apply(s, c)]));
                }
            }
        }

        // Produces frag_K[k] from the loaded data.
        //  - Unquantized: plain copy.
        //  - Quantized: every X slices, convert the raw data to T for slices
        //    k..k+X-1. Then dequantize slice k as x * param[0] + param[1].
        __device__ void Transform(int k)
        {
            if constexpr (!kQuantKV) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    frag_K[k][m] = data_K[k][m];
                }
            }
            else {  // this also covers non-quantized case, but it's too convoluted to read
                static_assert(K_M == 1);
                if (k % X == 0) {
                    using Converter = ConvertKvCache;
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        PRAGMA_UNROLL
                        for (int d = 0; d < 2; ++d) {
                            auto dx_d2 =
                                Converter::convert((Array&)data_K[k / X][0][d * 4 * X + s * 2 * X]);
                            PRAGMA_UNROLL
                            for (int x = 0; x < X; ++x) {
                                (Array&)frag_K[k + x][0][d * 4 + s * 2] = (Array&)dx_d2[x * 2];
                            }
                        }
                    }
                }
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 2; ++d) {
                        auto& d2 = (Array&)frag_K[k][0][d * 4 + s * 2];
                        PRAGMA_UNROLL
                        for (int i = 0; i < 2; ++i) {
                            d2[i] = __hfma(d2[i], param_K[0][s][0], param_K[0][s][1]);
                        }
                    }
                }
            }
        }
    };

    // Accumulates frag_S += K * Q^T for stage `offset`, software-pipelined over
    // the K_K head-dim slices.
    //  - On entry, slice 0 (and, when quantized, the params) must already be
    //    in state_QK; the loop itself only loads slices 1..K_K-1.
    //  - While slice k is transformed and multiplied, slice k + 1 is loaded.
    //    On the last slice, preload() runs instead.
    //  - prefetch(i) is called exactly once for every i in [0, K_K): the
    //    last two calls both happen in iteration K_K - 2. When K_K == 1,
    //    prefetch(0) runs before the loop.
    //  - state_QK is taken by value; the loads update only the local copy.
    template
    __device__ static void
    ComputeQK(StateQK state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        if constexpr (K_K == 1) {
            ((Prefetch &&) prefetch)(0);
        }

        PRAGMA_UNROLL
        for (int k = 0; k < K_K; ++k) {
            if (k < K_K - 1) {
                state_QK.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }

            state_QK.Transform(k);

            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    mma_m16n8k16_row_col(frag_S[m][n], state_QK.frag_K[k][m], state_QK.frag_Q[n][k], frag_S[m][n]);
                }
            }
            if (k < K_K - 1) {
                ((Prefetch &&) prefetch)(k);
            }
            if (k == K_K - 2) {
                ((Prefetch &&) prefetch)(K_K - 1);
            }
        }
    }

    // Per-warp register state for PV: shared-memory pointers to the V tiles
    // and params, the P fragments, and raw plus converted V fragments.
    struct StatePV {
        PointerKV smem_V;
        T*        smem_V_param;
        ParamV    param_V;
        DataV     data_V;
        FragP     frag_P;
        FragV     frag_V;

        // With offset, V (and its params) is read from right after K's region,
        // matching SetSmemKV(..., offset_kv = true).
        __device__ StatePV(SharedStorage& storage, bool offset = false)
        {
            smem_V       = storage.KV.data() + (offset ? SmemLayoutK::kSize : 0);
            smem_V_param = storage.KVp + (offset ? SmemLayoutKVp::kSize : 0);
        }

        // Loads the shared-memory data for head-dim slice `m` of pipeline
        // stage `pipe_iter`:
        //  - quantized only, at m == 0: params for this thread's two key rows;
        //  - every X slices: raw V. Unquantized data is loaded transposed
        //    (ldsm_x4_trans). Quantized data is loaded as stored and
        //    transposed in registers in Transform.
        __device__ void Load(int m, int pipe_iter)
        {
            const auto warp_ids = get_warp_ids();
            const int  lane_id  = threadIdx.x % WARP_SIZE;

            if (kQuantKV && m == 0) {
                static_assert(V_K == 1);
                const int k = 0;
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    const int si = k * 16 + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;
                    Lds(param_V[k][s], &smem_V_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);
                }
            }

            if (m % X == 0) {
                const int offset_s = lane_id / 16 * 8 + lane_id % 8 + warp_ids.x * WARP_S;
                const int offset_c = lane_id % 16 / 8 * 8 * X;
                PRAGMA_UNROLL
                for (int k = 0; k < V_K; ++k) {
                    const int s = k * 16 + offset_s;
                    const int c = m * 16 + offset_c;
                    static_assert(sizeof(data_V[m / X][k]) == 16);
                    if constexpr (!kQuantKV) {
                        ldsm_x4_trans(
                            (Array&)data_V[m / X][k],
                            cast_smem_ptr_to_uint(&smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(s, c)]));
                    }
                    else {
                        ldsm_x4(
                            (Array&)data_V[m / X][k],
                            cast_smem_ptr_to_uint(&smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(s, c)]));
                    }
                }
            }
        }

        // Produces frag_V[m] from the loaded data.
        //  - Unquantized: plain copy.
        //  - Quantized: every X slices, convert the raw data for slices
        //    m..m+X-1. Then dequantize slice m (x * param[0] + param[1]) and
        //    transpose each 8x8 b16 sub-tile (transpose_m8n8_b16) into the
        //    A-operand layout.
        __device__ void Transform(int m)
        {
            if constexpr (!kQuantKV) {
                PRAGMA_UNROLL
                for (int k = 0; k < V_K; ++k) {
                    frag_V[m][k] = data_V[m][k];
                }
            }
            else {
                static_assert(V_K == 1);
                if (m % X == 0) {
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        PRAGMA_UNROLL
                        for (int d = 0; d < 2; ++d) {
                            auto dx_d2 = ConvertKvCache::convert(
                                (Array&)data_V[m / X][0][s * 4 * X + d * 2 * X]);
                            PRAGMA_UNROLL
                            for (int x = 0; x < X; ++x) {
                                (Array&)frag_V[m + x][0][s * 4 + d * 2] = (Array&)dx_d2[x * 2];
                            }
                        }
                    }
                }
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 2; ++d) {
                        auto& d2 = (Array&)frag_V[m][0][s * 4 + d * 2];
                        PRAGMA_UNROLL
                        for (int i = 0; i < 2; ++i) {
                            d2[i] = __hfma(d2[i], param_V[0][s][0], param_V[0][s][1]);
                        }
                        (uint32_t&)d2 = transpose_m8n8_b16((uint32_t&)d2);
                    }
                }
            }
        }
    };

    // Accumulates frag_O += V^T * P^T for stage `offset`, pipelined over the
    // V_M head-dim slices with the same load / preload / prefetch schedule as
    // ComputeQK. Slice 0 must already be in state_PV on entry.
    template
    __device__ static void
    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            if (m < V_M - 1) {
                state_PV.Load(m + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }

            state_PV.Transform(m);

            PRAGMA_UNROLL
            for (int k = 0; k < V_K; ++k) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    mma_m16n8k16_row_col(frag_O[m][n], state_PV.frag_V[m][k], state_PV.frag_P[k][n], frag_O[m][n]);
                }
            }
            if (m < V_M - 1) {
                ((Prefetch &&) prefetch)(m);
            }
            if (m == V_M - 2) {
                ((Prefetch &&) prefetch)(V_M - 1);
            }
        }
    }

    // Online-softmax update for one S tile, per head column. Exponentials are
    // base 2 (exp2f) and scaled by qk_scale, so qk_scale presumably already
    // folds in log2(e); confirm at the call site.
    //  1. Update the running max M, reduced across the lanes that share
    //     lane_id % 4 (shuffle xor 4/8/16).
    //  2. Rescale the running sum L and the output O by exp2((M_prev - M) * qk_scale).
    //  3. Overwrite frag_S with P = exp2((S - M) * qk_scale) and add P to L.
    //     L stays a per-lane partial sum while kDeferReduceL is set.
    // Under is_residue, a head whose max is still -inf has its factors forced
    // to 0, avoiding the NaN of (-inf) - (-inf).
    template
    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)
    {
        FragM prev_M;
        copy(frag_M, prev_M);

        // Per-thread max over this thread's key rows.
        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {  // h
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {  // s
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        frag_M[n][q] = fmaxf(frag_M[n][q], frag_S[m][n][s * 2 + q]);
                    }
                }
            }
        }

        // Reduce the max across the 8 lanes that hold the same head columns.
        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 4));
                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 8));
                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 16));
            }
        }

        FragM expdiff_M;
        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                expdiff_M[n][q] = exp2f((prev_M[n][q] - frag_M[n][q]) * qk_scale);
                if (is_residue && frag_M[n][q] == -std::numeric_limits::infinity()) {
                    expdiff_M[n][q] = 0.f;
                }
                frag_L[n][q] *= expdiff_M[n][q];
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                PRAGMA_UNROLL
                for (int d = 0; d < 2; ++d) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        frag_O[m][n][d * 2 + q] *= expdiff_M[n][q];  // Rescale previous output
                    }
                }
            }
        }

        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                float tmp_L{};
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        float p = exp2f(frag_S[m][n][s * 2 + q] * qk_scale - frag_M[n][q] * qk_scale);
                        if (is_residue && frag_M[n][q] == -std::numeric_limits::infinity()) {
                            p = 0.f;
                        }
                        tmp_L += p;
                        frag_S[m][n][s * 2 + q] = p;
                    }
                }
                if constexpr (!kDeferReduceL) {
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 4);
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 8);
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 16);
                }
                frag_L[n][q] += tmp_L;  // update L
            }
        }
    }

    // Converts the fp32 probabilities in frag_S to T and transposes each 8x8
    // sub-tile in registers. The accumulator layout (key rows, head columns)
    // becomes the B-operand layout consumed by ComputePV. No shared memory is used.
    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage&)
    {
        static_assert(K_M == V_K);

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    Array tmp_P;
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        tmp_P[q] = static_cast(frag_S[m][n][s * 2 + q]);
                    }
                    // (s8,q4),(s2,q2) -> (q8,s4),(s2,s2)
                    //   1  2    8  1       1  2    8  1
                    (uint32_t&)tmp_P = transpose_m8n8_b16((uint32_t&)tmp_P);

                    (Array&)frag_P[m][n][s * 2] = tmp_P;
                }
            }
        }
    }

    // Combines the partial softmax states of the kWarpCntS warps that split
    // the key axis.
    //  1. Publish each warp's M and take the global max per head.
    //  2. Rescale O and L to that max. Finish the deferred intra-warp
    //     reduction of L, then publish O and L.
    //  3. Every S-warp sums O and L over all S-warps. Afterwards all S-warps
    //     hold the same reduced result.
    // SharedStorage is a union, so M/L/O alias the K/V buffers. The leading
    // barrier makes sure all warps are done with K/V before overwriting them.
    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, SharedStorage& storage)
    {
        if constexpr (kWarpCntS == 1 && !kDeferReduceL) {
            __syncthreads();
            return;
        }

        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        FragM prev_M;
        copy(frag_M, prev_M);

        __syncthreads();

        /////////////////////////////////////////////////////////////////////////
        //  global max
        // After Softmax, M is uniform across lanes sharing lane_id % 4, so
        // lanes 0-3 store one copy per head pair.
        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            if (lane_id < 4) {
                Store((float*)&storage.M[n][warp_ids.y][warp_ids.x][lane_id], frag_M[n]);
            }
        }

        __syncthreads();

        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            // Compute global maximum
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                PRAGMA_UNROLL
                for (int w = 0; w < kWarpCntS - 1; ++w) {
                    const int src_warp = (warp_ids.x + w + 1) % kWarpCntS;
                    frag_M[n][q]       = fmaxf(frag_M[n][q], storage.M[n][warp_ids.y][src_warp][lane_id % 4][q]);
                }
                // if (lane_id < 4) {
                //     printf("M %d %d %f\n", lane_id % 4 * 2 + q, warp_id, frag_M[n][q]);
                // }
            }
        }

        // if (threadIdx.x == 0) {
        //     printf("M %d %f\n", 0, frag_M[0][0]);
        // }

        ///////////////////////////////////////////////////////////////////////////
        //  rescale & global sum

        FragM expdiff_M;
        PRAGMA_UNROLL
        for (int n = 0; n < V_N; ++n) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                expdiff_M[n][q] = exp2f((prev_M[n][q] - frag_M[n][q]) * qk_scale);
                // Fully masked head: avoid NaN from (-inf) - (-inf).
                if (frag_M[n][q] == -std::numeric_limits::infinity()) {
                    expdiff_M[n][q] = 0.f;
                }
            }
            PRAGMA_UNROLL
            for (int m = 0; m < V_M; ++m) {
                PRAGMA_UNROLL
                for (int d = 0; d < 2; ++d) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        frag_O[m][n][d * 2 + q] *= expdiff_M[n][q];
                    }
                }
                Store((float*)&storage.O[m][n][warp_ids.y][warp_ids.x][lane_id], frag_O[m][n]);
            }
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                frag_L[n][q] *= expdiff_M[n][q];
                // Deferred from Softmax: sum L across the 8 lanes holding the same head.
                if constexpr (kDeferReduceL) {
                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 4);
                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 8);
                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 16);
                }
            }
            if (lane_id < 4) {
                Store((float*)&storage.L[n][warp_ids.y][warp_ids.x][lane_id], frag_L[n]);
            }
        }

        __syncthreads();

        clear(frag_O);
        clear(frag_L);

        // Each warp accumulates the contributions of all S-warps (including its own).
        PRAGMA_UNROLL
        for (int n = 0; n < V_N; ++n) {
            PRAGMA_UNROLL
            for (int w = 0; w < kWarpCntS; ++w) {
                using namespace ops;
                PRAGMA_UNROLL
                for (int m = 0; m < V_M; ++m) {
                    Array tmp_O;
                    Load(tmp_O, storage.O[m][n][warp_ids.y][w][lane_id].data());
                    frag_O[m][n] = frag_O[m][n] + tmp_O;
                }
                frag_L[n] = frag_L[n] + storage.L[n][warp_ids.y][w][lane_id % 4];
            }
            // PRAGMA_UNROLL
            // for (int q = 0; q < 2; ++q) {
            //     if (lane_id < 4) {
            //         printf("L %d %d %f\n", lane_id % 4 * 2 + q, warp_id, frag_L[n][q]);
            //     }
            // }

            // if (threadIdx.x == 0) {
            //     printf("L %d %f\n", 0, frag_L[0][0]);
            // }
        }
    }

    // Writes the output tile through the callback func(hi, qi, di, vec).
    //  1. When is_norm is set, normalize O by 1 / L (fdividef).
    //  2. The S-warps with warp_ids.x == 0 stage O into the head-major fp32
    //     buffer O1. After Merge all S-warps hold the same O, so one copy is
    //     enough.
    //  3. After a barrier, all threads read O1 with a raked thread map and call
    //     func(hi, 0, di, tmp_O) on each vector.
    // O1 aliases the other SharedStorage members, hence the leading barrier.
    template
    __device__ static void StoreO(FragO& frag_O, const FragL& frag_L, SharedStorage& storage, Func&& func)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        FragL inv_L;
        PRAGMA_UNROLL
        for (int n = 0; n < V_N; ++n) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                inv_L[n][q] = fdividef(1.f, frag_L[n][q]);
            }
        }

        __syncthreads();

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; m += X) {
            PRAGMA_UNROLL
            for (int x = 0; x < X; ++x) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 2; ++d) {
                        if constexpr (is_norm) {
                            using namespace ops;
                            (Array&)frag_O[m + x][n][d * 2] =
                                (Array&)frag_O[m + x][n][d * 2] * inv_L[n];
                        }
                        PRAGMA_UNROLL
                        for (int q = 0; q < 2; ++q) {
                            const int hi = n * OP_N + lane_id % 4 * 2 + q * 1 + warp_ids.y * WARP_H;
                            // [43][2][10]
                            //   2  1
                            //   4  1
                            //   8  1
                            // For X == 1 this reduces to the accumulator row
                            // m * 16 + lane_id / 4 + d * 8. For X > 1 the head-dim
                            // index is interleaved, presumably undoing the order
                            // used by the quantized load paths.
                            const int di = m * OP_M + lane_id / 4 % 2 + d * 8 * X + x * 2 + lane_id / 8 * X * 2;
                            if (warp_ids.x == 0) {
                                storage.O1[hi][di] = frag_O[m + x][n][d * 2 + q];
                                // if (hi == 0) {
                                //     printf("O %4d %4d %f\n", hi, di, frag_O[m][n][d * 2 + q]);
                                // }
                            }
                        }
                    }
                }
            }
        }

        __syncthreads();

        // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE
        using Map = std::conditional_t,
                                       RakedThreadMap>;
        Array tmp_O[Map::kIterS][Map::kIterC];

        const int  warp_id = threadIdx.x / WARP_SIZE;
        const int2 offset  = Map::get_offset(warp_id, lane_id);

        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                const int hi = offset.y + s * Map::kDeltaS;
                const int di = offset.x + c * Map::kDeltaC;
                Load(tmp_O[s][c], &storage.O1[hi][di]);
                ((Func &&) func)(hi, 0, di, tmp_O[s][c]);
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl_884.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/core/thread_map.h"

#include 
#include 

namespace turbomind::attention {

// Attention compute policy built on the m8n8k4 MMA instructions: `mma_m8n8k4_row_col` computes S = Q * K^T and
// `mma_m8n8k4_row_row` accumulates O += P * V. Q and K/V share one element type (Tkv == T), warps are stacked only
// along the query dimension (rows are offset by warp_id * WARP_Q, and Merge asserts kWarpCntS == 1).
// NOTE(review): m8n8k4 is the sm_70 (Volta) HMMA shape, so this impl presumably serves that architecture — confirm.
template
struct Impl {
    using T   = T_;
    using Tkv = T_;

    static constexpr bool MLA = false;

    static constexpr int CTA_H    = CTA_H_;
    static constexpr int CTA_Q    = CTA_Q_;
    static constexpr int CTA_S    = CTA_S_;
    static constexpr int kHeadDim = HeadDim;

    static constexpr int kWarpCntQ  = CTA_Q / WARP_Q;
    static constexpr int kWarpCntS  = CTA_S / WARP_S;
    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;

    // Warp tile handled by one mma_m8n8k4 call per (m, n, k) fragment index: 16 x 16 x 4
    static constexpr int OP_M = 16;
    static constexpr int OP_N = 16;
    static constexpr int OP_K = 4;

    // Fragment counts for S = Q * K^T (trailing values are for WARP_Q = 16, WARP_S = 64, HeadDim = 128)
    static constexpr int K_M = WARP_Q / OP_M;   // 1
    static constexpr int K_N = WARP_S / OP_N;   // 4
    static constexpr int K_K = HeadDim / OP_K;  // 32

    // Fragment counts for O = P * V (same example configuration)
    static constexpr int V_M = WARP_Q / OP_M;   // 1
    static constexpr int V_N = HeadDim / OP_N;  // 8
    static constexpr int V_K = WARP_S / OP_K;   // 16

    //  +---+---+
    //  | 0 | 1 |
    //  +---+---+
    //  | 2 | 3 |
    //  +---+---+
    //
    // Layout legend for the fragment comments below:
    //   1st group: lane-id bits from bit 4 down to bit 0, each with the row/column stride it contributes
    //              (`x` / `_` = the bit does not affect the coordinate),
    //   2nd group: the fragment array indices and their strides,
    //   3rd group: the per-thread element index, outermost mode first.
    using FragQ = Array[K_K][K_M];   //    (q2,q2,x2,q4) (Dk,Qm) (d4)
                                              //      4  8  0  1    4 16    1
    using FragK = Array[K_K][K_N];   //    (s2,x2,s2,s4) (Dk,Sn) (d4)
                                              //      4  0  8  1    4 16    1
    using FragS = Array[K_M][K_N];  // (q2,q2,s2,s2,q2) (Qm,Sn) (s2,q2,s2)
                                              //   4  8  8  2  1   16 16    4  2  1
    using FragP = Array[V_K][V_M];   //    (q2,q2,x2,q4) (Sk,Qm) (s4)
                                              //      4  8  0  1    4 16    1
    using FragV = Array[V_K][V_N];   //    (d2,x2,d2,s4) (Sk,Dn) (d4)       [row major]
                                              //      4  0  8  1    4 16    1
    using FragO = Array[V_M][V_N];  // (q2,q2,d2,d2,q2) (Qm,Dn) (d2,q2,d2)
                                              //   4  8  8  2  1   16 16    4  2  1
    using FragM = Array[V_M];       // (q2,q2,_2,_2,q2) (Qm)    (q2)
    using FragL = FragM;

    // using Swizzle = Identity;

    // Shared-memory offset swizzle for the V tile.
    struct SwizzleV {

        __device__ static int apply(int offset)
        {
            // Rearrange for LDS.128 (also avoid bank-conflict along C)
            // 6543210
            // dDDDDdd
            offset = ((offset & 8) << 2) ^ offset;                                     // x[5] ^= x[3]
            offset = ((offset & ~20) | (((offset & 16) >> 2) | ((offset & 4) << 2)));  // swap(x[4], x[2])

            // Shuffle C according to S to avoid bank-conflict
            // ssssSSdDDddd
            offset = ((offset & (0x3 << 6)) >> 3) ^ offset;
            return offset;
        }

        __device__ int operator()(int offset)
        {
            return apply(offset);
        }
    };

    using SmemLayoutQ = SmemLayoutV2;
    using SmemLayoutP = SmemLayoutV2;
    using SmemLayoutK = SmemLayoutV2;
    using SmemLayoutV = SmemLayoutV2;

    using SmemLayoutKVp = void;  // no separate KV-parameter buffer (K/V are not quantized here)

    // The Q buffer overlaps the K/V/P buffers.
    // NOTE(review): presumably safe because TransformQ copies Q into registers before K/V/P are written — confirm
    // against the caller.
    struct SharedStorage {
        union {
            __align__(16) T Q[SmemLayoutQ::kSize];
            struct {
                __align__(16) T K[SmemLayoutK::kSize];
                __align__(16) T V[SmemLayoutV::kSize];
                __align__(16) T P[SmemLayoutP::kSize];
            };
        };
    };

    static constexpr bool kUseSmemQ = false;  // Q operand is kept in registers (FragQ)
    static constexpr bool kUseSmemP = false;  // P operand is reloaded into registers (FragP) in ConvertStoP

    // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE
    using ThreadMapQ  = std::conditional_t,
                                          RakedThreadMap>;
    using ThreadMapKV = std::conditional_t,
                                           RakedThreadMap>;

    using ThreadMapKVp = void;

    // Softmax keeps a per-thread partial row sum L; the cross-lane reduction happens once, in Merge.
    static constexpr bool kDeferReduceL = true;

    // CTA-wide barrier.
    __device__ static void Sync()
    {
        __syncthreads();
    }

    // Calls func(0, qi, si, /*ri*/ 0, s) for every score held by this thread, where qi is the CTA query row
    // (warp_id * WARP_Q + row in the warp tile) and si is the key column in the warp tile. The leading argument
    // is always 0 in this implementation.
    template
    __device__ static void ForeachS(Fragment& S, Func&& func)
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int s1 = 0; s1 < 2; ++s1) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        PRAGMA_UNROLL
                        for (int s0 = 0; s0 < 2; ++s0) {
                            const int qi = m * OP_M + (lane_id & 8) + (lane_id & 1) + lane_id / 16 * 4 + q * 2;
                            const int si = n * OP_N + (lane_id & 4) * 2 + (lane_id & 2) + s1 * 4 + s0;
                            ((Func &&) func)(0, warp_id * WARP_Q + qi, si, /*ri*/ 0, S[m][n][s1 * 4 + q * 2 + s0]);
                        }
                    }
                }
            }
        }
    }

    // Loads this warp's Q rows from shared memory into the register fragment layout FragQ.
    __device__ static void TransformQ(const T* smem_Q, FragQ& frag_Q)
    {
        if constexpr (!kUseSmemQ) {
            const int warp_id = threadIdx.x / WARP_SIZE;
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    const int qi = m * OP_M + (lane_id & 8) + lane_id % 4 + lane_id / 16 * 4 + warp_id * WARP_Q;
                    const int di = k * 4;
                    Lds(frag_Q[k][m], &smem_Q[SmemLayoutQ::apply(qi, di)]);
                }
            }
        }
    }

    // Points the K and V global-memory iterators at their shared-memory tiles.
    // `offset_kv` is ignored: K and V always have separate buffers in SharedStorage.
    template
    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)
    {
        gmem_K.SetSmem(storage.K);
        gmem_V.SetSmem(storage.V);
    }

    // Register state for S = Q * K^T: a copy of the Q fragments plus the current K fragments.
    struct StateQK {
        SmemAccessor smem_K;

        FragQ frag_Q;
        FragK frag_K;

        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.K}
        {
            static_assert(!kUseSmemQ, "not implemented");
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    frag_Q[k][m] = frag_Q_[k][m];
                }
            }
        }

        // Loads the k-th 4-wide D slice of K for every N tile (`pipe_iter` is unused).
        __device__ void Load(int k, int pipe_iter)
        {
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                const int s = n * 16 + lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;
                const int c = k * 4;
                Lds(frag_K[k][n], &smem_K(s, c));
            }
        }

        __device__ void Transform(int k) {}
    };

    // Accumulates S += Q * K^T over the D dimension. The slice for k + 1 is loaded before the MMAs for slice k are
    // issued; on the last slice `preload` runs instead. `prefetch` is not used by this implementation.
    template
    __device__ static void
    ComputeQK(StateQK& state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        // The N tiles are visited in swapped pairs via `n ^ 1`; an odd K_N would index frag_S / frag_K
        // out of bounds on the last tile, so reject such configurations at compile time.
        static_assert(K_N % 2 == 0, "ComputeQK requires an even number of N tiles (K_N)");

        PRAGMA_UNROLL
        for (int k = 0; k < K_K; ++k) {
            if (k < K_K - 1) {
                state_QK.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    // Visit order 1,0,3,2,...; the reason is not visible here (presumably an
                    // instruction-scheduling choice) — confirm before changing.
                    const int nn = n ^ 1;
                    mma_m8n8k4_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);
                }
            }
        }
    }

    // Register state for O = P * V: the P fragments plus the current V fragments, with per-thread shared-memory
    // offsets of V precomputed once.
    struct StatePV {
        T* smem_V;

        static_assert(V_N % 2 == 0);
        Array idxs_;  // smem offset of the V slice for each pair of N tiles (n, n + 1)

        FragP frag_P;
        FragV frag_V;

        __device__ StatePV(SharedStorage& storage, bool offset): smem_V{storage.V}
        {
            assert(offset);
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; n += 2) {
                const int s  = 0 * 4 + lane_id % 4;
                const int c  = n * 16 + lane_id / 16 * 4 + (lane_id & 4) * 2;
                idxs_[n / 2] = SmemLayoutV::apply(s, c);
            }
        }

        // Loads the k-th 4-row S slice of V by shifting the precomputed offsets by k * 4 * SmemLayoutV::C0.
        // NOTE(review): the wider cast presumably fills frag_V[k][n] and frag_V[k][n + 1] in one Lds — confirm.
        __device__ void Load(int k, int pipe_iter)
        {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; n += 2) {
                const int idx = idxs_[n / 2] + k * 4 * SmemLayoutV::C0;
                Lds((Array&)frag_V[k][n], &smem_V[idx]);
            }
        }

        __device__ void Transform(int k) {}
    };

    // Accumulates O += P * V over the S dimension, with the same one-slice-ahead load pattern as ComputeQK.
    template
    __device__ static void
    ComputePV(StatePV& state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        PRAGMA_UNROLL
        for (int k = 0; k < V_K; ++k) {
            if (k < V_K - 1) {
                state_PV.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }
            PRAGMA_UNROLL
            for (int m = 0; m < V_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    mma_m8n8k4_row_row(frag_O[m][n], state_PV.frag_P[k][m], state_PV.frag_V[k][n], frag_O[m][n]);
                }
            }
        }
    }

    // Online-softmax step for one tile of scores.
    //  1. Update the running row max M. Lanes differing in bits 1 and 2 share rows, so reduce with xor 2 / xor 4.
    //  2. Rescale O and L by exp2((M_prev - M) * qk_scale).
    //  3. Overwrite S with unnormalized probabilities exp2(S * qk_scale - M * qk_scale).
    //  4. Accumulate the thread-local row sums into L. The cross-lane sum is deferred to Merge (kDeferReduceL).
    // When `is_residue` and a row is entirely -inf, the scale and probabilities are forced to 0 to avoid NaN.
    // NOTE(review): exp2f is used, so qk_scale presumably already includes log2(e) — confirm with the caller.
    template
    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)
    {
        FragM prev_M;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            prev_M[m] = frag_M[m];
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int s1 = 0; s1 < 2; ++s1) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        PRAGMA_UNROLL
                        for (int s0 = 0; s0 < 2; ++s0) {
                            frag_M[m][q] =
                                fmaxf(frag_M[m][q], frag_S[m][n][s1 * 4 + q * 2 + s0]);  // reduce over local quad
                        }
                    }
                }
            }
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {  // reduce over thread group within warp (within warp tiles)
                frag_M[m][q] = fmaxf(frag_M[m][q], __shfl_xor_sync(uint32_t(-1), frag_M[m][q], 2));
                frag_M[m][q] = fmaxf(frag_M[m][q], __shfl_xor_sync(uint32_t(-1), frag_M[m][q], 4));
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                // exp(M - M'), isinf(frag_M) => isnan(expdiff_M)
                float expdiff_M = exp2f((prev_M[m][q] - frag_M[m][q]) * qk_scale);
                if (is_residue && frag_M[m][q] == -std::numeric_limits::infinity()) {
                    expdiff_M = 0.f;
                }
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    PRAGMA_UNROLL
                    for (int s1 = 0; s1 < 2; ++s1) {
                        PRAGMA_UNROLL
                        for (int s0 = 0; s0 < 2; ++s0) {
                            frag_O[m][n][s1 * 4 + q * 2 + s0] *= expdiff_M;  // Rescale previous output
                        }
                    }
                }
                frag_L[m][q] *= expdiff_M;
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                float tmp_L{};
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    PRAGMA_UNROLL
                    for (int s1 = 0; s1 < 2; ++s1) {
                        PRAGMA_UNROLL
                        for (int s0 = 0; s0 < 2; ++s0) {
                            // unnormalized prob, optimized to FFMA
                            float p = exp2f(frag_S[m][n][s1 * 4 + q * 2 + s0] * qk_scale - frag_M[m][q] * qk_scale);
                            if (is_residue && frag_M[m][q] == -std::numeric_limits::infinity()) {
                                p = 0.f;
                            }
                            tmp_L += p;
                            frag_S[m][n][s1 * 4 + q * 2 + s0] = p;
                        }
                    }
                }
                if constexpr (!kDeferReduceL) {
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 2);
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 4);
                }
                frag_L[m][q] = frag_L[m][q] + tmp_L;  // update L
            }
        }
    }

    // Converts the probabilities in frag_S into the P operand layout. Each value is stored to shared memory as
    // `half` through ForeachS, then this warp's rows are reloaded in the FragP (A-operand) layout.
    // NOTE(review): lanes read values written by other lanes, with no __syncwarp() between the stores and the
    // loads below. Unless the caller synchronizes, this relies on implicit warp-synchronous execution — confirm.
    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage& storage)
    {
        ForeachS(frag_S,
                 [&](int, int qi, int si, int ri, float p) { storage.P[SmemLayoutP::apply(qi, si)] = half(p); });

        if constexpr (!kUseSmemP) {
            const int warp_id = threadIdx.x / WARP_SIZE;
            const int lane_id = threadIdx.x % WARP_SIZE;
            PRAGMA_UNROLL
            for (int k = 0; k < V_K; ++k) {
                PRAGMA_UNROLL
                for (int m = 0; m < V_M; ++m) {
                    const int qi = m * OP_M + lane_id / 16 * 4 + (lane_id & 8) + lane_id % 4 + warp_id * WARP_Q;
                    const int si = k * OP_K;
                    Lds(frag_P[k][m], &storage.P[SmemLayoutP::apply(qi, si)]);
                }
            }
        }
    }

    // Calls func(0, qi, ri, M, L) for each pair of row statistics held by this thread. qi is the CTA query row;
    // ri in [0, 4) identifies the lane (bits 1 and 2) among the lanes that hold partial values for the same row.
    template
    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)
    {
        /// FIXME: implement this
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {  // Q,16
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {  // Q,2
                const int qi = (lane_id & 1) * 1 + (lane_id & 16) / 4 + (lane_id & 8) + m * OP_M + q * 2;
                const int ri = (lane_id & 2) / 2 + (lane_id & 4) / 2;
                ((Func &&) func)(0, warp_id * WARP_Q + qi, ri, frag_M[m][q], frag_L[m][q]);
            }
        }
    }

    // Finishes the row statistics. With a single warp along S there is nothing to merge across warps; only the
    // deferred cross-lane sum of L (lanes xor 2 / xor 4) is performed. frag_O, frag_M, qk_scale and storage are
    // unused here.
    template
    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)
    {
        static_assert(kWarpCntS == 1);

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                if constexpr (kDeferReduceL) {
                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 2);
                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 4);
                }
            }
        }
    }

    // Emits the output tile: func(0, qi, di, pair) is called once per pair of adjacent D values. The pair is
    // first scaled by 1 / (L + 1e-8f) when `is_norm`; the 1e-8f bias keeps the reciprocal finite when L is 0.
    template
    __device__ static void StoreO(FragO& frag_O, FragL& frag_L, SharedStorage& storage, Func&& func)
    {
        FragL inv_L;
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                inv_L[m][q] = fdividef(1.f, frag_L[m][q] + 1e-8f);
            }
        }

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        // Lane contribution to the row (mm) and column (nn) coordinates, see the FragO layout above
        const int mm = lane_id / 16 * 4 + (lane_id & 8) + (lane_id & 1);
        const int nn = (lane_id & 4) * 2 + (lane_id & 2);

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                PRAGMA_UNROLL
                for (int d1 = 0; d1 < 2; ++d1) {
                    PRAGMA_UNROLL
                    for (int q = 0; q < 2; ++q) {
                        const int qi = m * OP_M + mm + q * 2 + warp_id * WARP_Q;
                        const int di = n * OP_N + nn + d1 * 4;
                        if constexpr (is_norm) {
                            PRAGMA_UNROLL
                            for (int d0 = 0; d0 < 2; ++d0) {
                                frag_O[m][n][d1 * 4 + q * 2 + d0] *= inv_L[m][q];
                            }
                        }
                        ((Func &&) func)(0, qi, di, (Array&)frag_O[m][n][d1 * 4 + q * 2]);
                    }
                }
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl_m16n8.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"

namespace turbomind::attention {

// Softmax, fragment-iteration and output helpers for MMA layouts whose S and O accumulators use an OP_M x OP_N =
// 16 x 8 tile per instruction (the type is named m16k8, but the tile is m16 x n8, matching the file name).
// No ComputeQK/ComputePV is defined here.
// NOTE(review): presumably combined with an arch-specific Impl that supplies the MMA loops — confirm.
//
// Layout legend for the fragment comments: ((lane modes), (fragment indices), (per-thread element modes)), with
// the stride of each mode on the following line. Lane modes: lane_id / 4 selects the row (q8) and lane_id % 4 the
// column pair (s4 / d4). Consequently the 4 lanes of a quad share rows, which is why reductions use xor 1 / xor 2.
template
struct Impl_m16k8 {

    static constexpr int OP_M = 16;
    static constexpr int OP_N = 8;

    static constexpr int K_M = WARP_Q / OP_M;  //  16 / 16 = 1
    static constexpr int K_N = WARP_S / OP_N;  //  64 /  8 = 8

    static constexpr int V_M = WARP_Q / OP_M;   //  16 / 16 = 1
    static constexpr int V_N = HeadDim / OP_N;  // 128 /  8 = 16 -> D16

    template
    using FragS_ = Array[K_M][K_N];     // ((q8, s4), (Qm, Sn), (q2, s2))
                                              //    1   2    16   8     8   1
    using FragO = Array[V_M][V_N];  // ((q8, d4), (Qm, Dn), (q2, d2))
                                              //    1   2    16   8     8   1
    using FragM = Array[V_M];       // ((q8, _4), Qm, q2) => FragS with all S dim reduced
                                              //    1   0   16   8
    using FragS = FragS_;
    using FragL = FragM;

    // L is fully reduced across the quad inside every Softmax call.
    static constexpr bool kDeferReduceL = false;

    // Calls func(qi % WARP_H, qi / WARP_H, ri, M, L) for each row statistic held by this thread. qi is the
    // CTA row index (warp_id * WARP_Q + row in the warp tile) and ri = lane_id % 4 is the lane within the quad.
    // NOTE(review): the (qi % WARP_H, qi / WARP_H) split presumably yields (head, query) — confirm with callers.
    template
    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {  // Q
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;
                const int ri = lane_id % 4 * 1;
                ((Func &&) func)(qi % WARP_H, qi / WARP_H, ri, frag_M[m][q], frag_L[m][q]);
            }
        }
    }

    // Calls func(qi % WARP_H, qi / WARP_H, ki, /*ri*/ 0, s) for every score held by this thread, where ki is
    // the key column within the warp tile.
    template
    __device__ static void ForeachS(Fragment& S, Func&& func)
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {  // Q
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {  // KV
                PRAGMA_UNROLL
                for (int q = 0; q < 2; ++q) {
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;
                        const int ki = lane_id % 4 * 2 + n * OP_N + s * 1;
                        ((Func &&) func)(qi % WARP_H, qi / WARP_H, ki, /*ri*/ 0, S[m][n][q * 2 + s]);
                    }
                }
            }
        }
    }

    // Online-softmax step. Updates the running row max M (quad reduction via xor 1 / xor 2) and computes
    // expdiff = exp2((M_prev - M) * qk_scale), forced to 0 for fully -inf rows when `is_residue`. L is rescaled,
    // S is replaced by unnormalized probabilities exp2(S * qk_scale - M * qk_scale), the quad-reduced row sums
    // are added to L, and finally the previous O is rescaled by expdiff.
    template
    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)
    {
        FragM prev_M;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            prev_M[m] = frag_M[m];
        }

        // maximum
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {  // Q
            auto& row_M = frag_M[m];
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {  // KV
                auto& C = frag_S[m][n];
                PRAGMA_UNROLL
                for (int q = 0; q < 2; ++q) {
                    row_M[q] = fmaxf(row_M[q], fmaxf(C[q * 2 + 0], C[q * 2 + 1]));  // reduce over local pair
                }
            }
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {  // reduce over thread group within warp (within warp tiles)
                row_M[q] = fmaxf(row_M[q], __shfl_xor_sync(uint32_t(-1), row_M[q], 1));
                row_M[q] = fmaxf(row_M[q], __shfl_xor_sync(uint32_t(-1), row_M[q], 2));
            }
        }

        FragM expdiff_M;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                // exp(M - M'), isinf(frag_M) => isnan(expdiff_M)
                expdiff_M[m][q] = exp2f((prev_M[m][q] - frag_M[m][q]) * qk_scale);
                if (is_residue && frag_M[m][q] == -std::numeric_limits::infinity()) {
                    expdiff_M[m][q] = 0.f;
                }
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                frag_L[m][q] *= expdiff_M[m][q];
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                float tmp_L{};
                PRAGMA_UNROLL
                for (int n = 0; n < K_N; ++n) {
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        // unnormalized prob
                        float p = exp2f(frag_S[m][n][q * 2 + s] * qk_scale - frag_M[m][q] * qk_scale);
                        if (is_residue && frag_M[m][q] == -std::numeric_limits::infinity()) {
                            p = 0.f;
                        }
                        tmp_L += p;
                        frag_S[m][n][q * 2 + s] = p;
                    }
                }
                if constexpr (!kDeferReduceL) {
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 1);
                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 2);
                }
                frag_L[m][q] += tmp_L;  // update L
            }
        }

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                PRAGMA_UNROLL
                for (int q = 0; q < 2; ++q) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 2; ++d) {
                        frag_O[m][n][q * 2 + d] *= expdiff_M[m][q];  // Rescale previous output
                    }
                }
            }
        }
    }

    // Converts the float probabilities in frag_S element-wise into frag_P, which is viewed with the same
    // register layout as FragS. Registers only; shared memory is not touched.
    template
    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, Storage&)
    {
        FragS_& frag_Ps = (FragS_&)frag_P;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int q = 0; q < 2; ++q) {
                    PRAGMA_UNROLL
                    for (int s = 0; s < 2; ++s) {
                        frag_Ps[m][n][q * 2 + s] = static_cast(frag_S[m][n][q * 2 + s]);
                    }
                }
            }
        }
    }

    // No cross-warp merge is performed by this helper; L is already quad-reduced in Softmax.
    template
    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)
    {
    }

    // Emits the output tile: func(qi % WARP_H, qi / WARP_H, di, pair) is called once per pair of adjacent D values.
    // When `is_norm`, the pair is first scaled by 1 / (L + 1e-8f). The bias keeps the reciprocal finite for a row
    // whose L is 0 (Softmax zeroes every probability of a fully masked row), so that row yields 0 instead of
    // 0 * inf = NaN. It matches the normalization used by the m8n8k4 Impl.
    template
    __device__ static void StoreO(FragO& frag_O, FragL& frag_L, Storage& storage, Func&& func)
    {
        FragL inv_L;
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                inv_L[m][q] = fdividef(1.f, frag_L[m][q] + 1e-8f);
            }
        }

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int q = 0; q < 2; ++q) {
                const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    if constexpr (is_norm) {
                        PRAGMA_UNROLL
                        for (int d = 0; d < 2; ++d) {
                            frag_O[m][n][q * 2 + d] *= inv_L[m][q];
                        }
                    }
                    const int di = n * 8 + lane_id % 4 * 2;
                    ((Func &&) func)(qi % WARP_H, qi / WARP_H, di, (Array&)frag_O[m][n][q * 2]);
                }
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/impl_simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/thread_map.h"

#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/quantization.h"

namespace turbomind::attention {

template
struct Impl {

    using T   = T_;
    using Tkv = Tkv_;

    static constexpr int kQuantKV = !std::is_same_v;

    static constexpr bool MLA = HeadDim == 576;

    static constexpr int CTA_H = CTA_H_;
    static constexpr int CTA_Q = CTA_Q_;
    static constexpr int CTA_S = CTA_S_;

    static constexpr int WARP_H = WARP_H_;

    static constexpr int kHeadDim = HeadDim;

    static constexpr int kWarpCntH = CTA_H / WARP_H;
    static constexpr int kWarpCntQ = CTA_Q / WARP_Q;
    static constexpr int kWarpCntS = CTA_S / WARP_S;

    static constexpr int kWarpCount = kWarpCntH * kWarpCntQ * kWarpCntS;

    static_assert(kWarpCntQ == 1);

    static constexpr int VEC = 8;

    static constexpr int T_D = 8;                // warp thread C
    static constexpr int T_S = WARP_SIZE / T_D;  // warp thread S

    // warp footprint (1x4x64)
    static constexpr int OP_H = 1;
    static constexpr int OP_S = T_S;
    static constexpr int OP_D = VEC * T_D;

    static constexpr int K_M = WARP_H / OP_H;   // 1
    static constexpr int K_N = WARP_S / OP_S;   // 4
    static constexpr int K_K = HeadDim / OP_D;  // 2

    static constexpr int V_M = K_M;  // 1
    static constexpr int V_N = K_K;  // 2
    static constexpr int V_K = K_N;  // 4

    static_assert(WARP_H % OP_H == 0);
    static_assert(WARP_S % OP_S == 0);
    static_assert(HeadDim % OP_D == 0);

    using Tqk = std::conditional_t;
    using Tpv = Tqk;

    // Thread map where each lane owns one contiguous run of VEC * K_K elements along D. Lane (lane_id % T_D)
    // starts at d = lane * S_D_thr, and its k-th VEC-wide vector sits at + k * S_D. Along S, adjacent lane groups
    // (lane_id / T_D) are adjacent rows (S_S_thr = 1), and each fragment index n advances T_S rows (S_S).
    struct RakedD {
        static constexpr int S_D_thr = VEC * K_K;
        static constexpr int S_S_thr = 1;
        static constexpr int S_D     = VEC;
        static constexpr int S_S     = T_S;
        // Number of consecutive k-vectors combined into one Lds. It always divides K_K.
        // NOTE(review): the 16-byte numerator presumably caps a single shared-memory load at 128 bits — confirm.
        static constexpr int LDS     = std::gcd(16 / sizeof(Array), K_K);
    };

    // Thread map where the T_D lanes of a row cover VEC * T_D contiguous elements along D for every k. Lane
    // (lane_id % T_D) starts at d = lane * VEC, and the k-th vector is at + k * VEC * T_D. A lane's consecutive
    // k-vectors are therefore not contiguous, so each Lds moves a single vector (LDS = 1).
    struct LinearD {
        static constexpr int S_D_thr = VEC;
        static constexpr int S_S_thr = 1;
        static constexpr int S_D     = VEC * T_D;
        static constexpr int S_S     = T_S;
        static constexpr int LDS     = 1;
    };

    using ThreadMap = std::conditional_t;

    // Strides of thread index
    static constexpr int S_D_thr = ThreadMap::S_D_thr;
    static constexpr int S_S_thr = ThreadMap::S_S_thr;
    // Strides of array index
    static constexpr int S_D = ThreadMap::S_D;
    static constexpr int S_S = ThreadMap::S_S;
    // LDS vec count
    static constexpr int LDS_K = ThreadMap::LDS;
    static constexpr int LDS_V = ThreadMap::LDS;

    static_assert(LDS_K <= K_K);

    using FragQ = Array[K_M][K_K];      // (q4, d8), (Qm, Dk), (d8)
    template                          //   0  16     1   8     1
    using FragK_ = Array[K_N][K_K];    // (s4, d8), (Sn, Dk), (d8)
                                                //   4  16     1   8     1
    using FragS = Array[K_M][K_N];    // (s4, d8), (Qm, Sn)
                                                //   4  16     1   1
                                                // (s4, _8), (Qm, Sn)       [after redsum]
                                                //   4   0     1   1
    using FragM = Array[K_M];         // (_4, _8), (Qm)
                                                //   0   0     1
    using FragP = Array[V_M][V_K];      // (s4, _8), (Qm, Sk), (s1)
    template                          //   4   0     1   1     1
    using FragV_ = Array[V_K][V_N];    // (s4, d8), (Sk, Dn), (d8)
                                                //   4  16     1   8     1
    using FragO = Array[V_M][V_N];  // (s4, d8), (Qm, Dn), (d8)
                                                //   1  16     1   8     1
    using ParamK = Array[K_N];            // (s4, x8), (Sn)
                                                //   4   0     1
    using ParamV = Array[V_K];            // (s4, x8), (Sk)
                                                //   4   0     1
    using FragSp = Array[K_M][K_N];

    static_assert(sizeof(FragP) == sizeof(FragSp));

    using DataK = FragK_;
    using DataV = FragV_;

    using FragK = FragK_;
    using FragV = FragV_;

    using FragL = FragM;

    using SmemLayoutQ = SmemLayoutV2;
    using SmemLayoutP = SmemLayoutV2;
    using SmemLayoutK = SmemLayoutV2;
    using SmemLayoutV = SmemLayoutV2;

    using SmemLayoutKVp = SmemLayoutV2;

    using SmemM = float[K_M][kWarpCntH][kWarpCntS];
    using SmemL = float[K_M][kWarpCntH][kWarpCntS];
    using SmemO = Array[V_M][V_N][2][kWarpCntH][kWarpCntS][T_D];  // (Qm, Dn, d2, Hw, Sw, d8), (d4)
                                                                            //   1  64   4  WH  WS   8     1

    using PointerKV = get_pointer_type;

    union SharedStorage {
        __align__(16) T Q[SmemLayoutQ::kSize];

        struct {
            __align__(16) Array KV;
            __align__(16) T KVp[Stages * SmemLayoutKVp::kSize];
        };

        struct {
            __align__(16) SmemM M;
            __align__(16) SmemL L;
            __align__(16) SmemO O;
        };
    };

    static constexpr bool kUseSmemQ = false;
    static constexpr bool kUseSmemP = false;

    using ThreadMapQ  = RakedThreadMap;
    using ThreadMapKV = RakedThreadMap, kWarpCount>;
    // `WARP_SIZE / WARP_S` is chosen to achieve minimum kIterS w/o introducing partial S iter
    using ThreadMapKVp = RakedThreadMap<2, CTA_S, 2, kWarpCount, WARP_SIZE / WARP_S>;

    static constexpr int kBatchK = ThreadMapKV::kIterS;
    static constexpr int kBatchV = ThreadMapKV::kIterS;

    // Barrier between shared-memory phases. It compiles to nothing when neither condition below holds.
    __device__ static void Sync()
    {
        // A CTA-wide barrier is only used when the CTA has more than one warp along H.
        constexpr bool kNeedCtaBarrier = kWarpCntH > 1;
        // KV and KVp are accessed with different thread layouts inside a warp, so quantized KV needs a warp barrier.
        constexpr bool kNeedWarpBarrier = kQuantKV;

        if constexpr (kNeedCtaBarrier) {
            __syncthreads();
        }

        if constexpr (kNeedWarpBarrier) {
            __syncwarp();
        }
    }

    template
    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)
    {
        // V's buffer starts at K's buffer when offset_kv is false, and one K tile after it otherwise.
        const int kv_off = offset_kv ? SmemLayoutK::kSize : 0;

        if constexpr (!kQuantKV) {
            gmem_K.SetSmem(storage.KV.data());
            gmem_V.SetSmem(storage.KV.data() + kv_off);
        }
        else {
            // Quantized KV also carries a parameter buffer (KVp), offset by one KVp tile under the same rule.
            const int kvp_off = offset_kv ? SmemLayoutKVp::kSize : 0;
            gmem_K.SetSmem(storage.KV.data(), storage.KVp);
            gmem_V.SetSmem(storage.KV.data() + kv_off, storage.KVp + kvp_off);
        }
    }

    // Returns {x: warp index along S, y: warp index along H}. With a single warp along H, every warp indexes S.
    static __device__ int2 get_warp_ids()
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        if constexpr (kWarpCntH <= 1) {
            return {warp_id, 0};
        }
        else {
            const int s_idx = warp_id % kWarpCntS;
            const int h_idx = warp_id / kWarpCntS;
            return {s_idx, h_idx};
        }
    }

    template
    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)
    {
        // Head offset of this warp, and this thread's index among the WARP_SIZE * kWarpCntS threads that share
        // the same H-warp index.
        const int head_base = get_warp_ids().y * WARP_H;
        const int ri        = threadIdx.x % (WARP_SIZE * kWarpCntS);

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {  // Q
            ((Func &&) func)(head_base + m * OP_H, 0, ri, frag_M[m][0], frag_L[m][0]);
        }
    }

    template
    __device__ static void ForeachS(Fragment& S, Func&& func)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        // Per-thread invariants. lane_id % T_D is the lane's position along D (reported as ri), lane_id / T_D
        // selects its S row, and the warp's S / H offsets come from get_warp_ids().
        const int ri     = lane_id % T_D;
        const int s_base = lane_id / T_D * S_S_thr + warp_ids.x * WARP_S;
        const int h_base = warp_ids.y * WARP_H;

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                ((Func &&) func)(h_base + m * OP_H, /*qi*/ 0, s_base + n * S_S, ri, S[m][n][0]);
            }
        }
    }

    // Loads this warp's Q rows (one per head m) from shared memory into FragQ, using the ThreadMap D layout.
    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        // CTA-wide barrier before reading smem_Q (presumably so every writer of Q has finished — confirm).
        __syncthreads();

        SmemAccessor sQ{smem_Q};

        const int head_base = warp_ids.y * WARP_H;
        const int d_base    = lane_id % T_D * S_D_thr;

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                Lds(frag_Q[m][k], &sQ(head_base + m, d_base + k * S_D));
            }
        }
    }

    // Per-call register state for the Q·K^T product (see ComputeQK).
    //   smem_K / smem_K_param: bases of storage.KV and storage.KVp. Each call selects a
    //       pipeline stage by adding pipe_iter * SmemLayoutK::kSize
    //       (resp. SmemLayoutKVp::kSize).
    //   frag_Q:  this thread's Q fragment. The constructor copies it in only when
    //       !kUseSmemQ.
    //   data_K:  K elements as read by Load(). frag_K: the same after Transform().
    //   param_K: per-sub-tile conversion parameters, read by Load() only when kQuantKV.
    struct StateQK {
        PointerKV smem_K;
        T*        smem_K_param;
        FragQ     frag_Q;
        FragK     frag_K;
        DataK     data_K;
        ParamK    param_K;

        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_)
        {
            smem_K       = storage.KV.data();
            smem_K_param = storage.KVp;
            // With kUseSmemQ, frag_Q is not initialized here. It is presumably populated
            // from shared Q elsewhere — confirm.
            if constexpr (!kUseSmemQ) {
                PRAGMA_UNROLL
                for (int m = 0; m < K_M; ++m) {
                    PRAGMA_UNROLL
                    for (int k = 0; k < K_K; ++k) {
                        frag_Q[m][k] = frag_Q_[m][k];
                    }
                }
            }
        }

        // Read K sub-tile `n` of pipeline stage `pipe_iter` into data_K[n].
        // With kQuantKV, the call for n == 0 also reads param_K for all K_N sub-tiles.
        // Load(0, ...) must therefore run before Transform() of any sub-tile.
        __device__ void Load(int n, int pipe_iter)
        {
            const auto warp_ids = get_warp_ids();
            const int  lane_id  = threadIdx.x % WARP_SIZE;

            const int offset_s = lane_id / T_D * S_S_thr + warp_ids.x * WARP_S;
            const int offset_c = lane_id % T_D * S_D_thr;

            if (kQuantKV && n == 0) {
                // Loop index is `tile`, not `n`, so it does not shadow the parameter.
                PRAGMA_UNROLL
                for (int tile = 0; tile < K_N; ++tile) {
                    const int si = tile * S_S + offset_s;
                    Lds(param_K[tile], &smem_K_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);
                }
            }

            // Step by LDS_K: each Lds presumably moves LDS_K consecutive k-fragments at once.
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; k += LDS_K) {
                const int si = n * S_S + offset_s;
                const int di = k * S_D + offset_c;
                Lds((Array&)data_K[n][k],
                    &smem_K[pipe_iter * SmemLayoutK::kSize + SmemLayoutK::apply(si, di)]);
            }
        }

        // Convert data_K[n] into frag_K[n] using a ConvertKvCache built from
        // param_K[n][0] and param_K[n][1] (presumably scale and zero point — confirm).
        // Without kQuantKV, param_K is never loaded, so the converter presumably
        // ignores these arguments.
        __device__ void Transform(int n)
        {
            PRAGMA_UNROLL
            for (int k = 0; k < K_K; ++k) {
                ConvertKvCache convert(param_K[n][0], param_K[n][1]);
                frag_K[n][k] = convert(data_K[n][k]);
            }
        }
    };

    // Accumulate this thread's Q·K^T scores for one KV tile into frag_S.
    //
    // The work is software-pipelined over the K_N sub-tiles. While sub-tile n is
    // transformed and multiplied, sub-tile n + 1 is loaded from pipeline stage `offset`.
    // Sub-tile 0 is never loaded here, so the caller must already have issued
    // state_QK.Load(0, ...). `state_QK` is taken by value; the caller's copy is not
    // updated.
    //
    // Caller hooks:
    //   - `prefetch(i)` is called exactly once for each i in [0, K_N).
    //   - `preload()` is called once, in the last iteration, in place of a K load.
    //
    // On return, frag_S[m][n][0] holds the complete dot product in every lane of each
    // group of T_D consecutive lanes (XOR butterfly; assumes T_D is a power of two).
    template
    __device__ static void
    ComputeQK(StateQK state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        // With a single sub-tile the in-loop prefetch calls never fire, so issue it here.
        if constexpr (K_N == 1) {
            ((Prefetch &&) prefetch)(0);
        }

        PRAGMA_UNROLL
        for (int n = 0; n < K_N; ++n) {
            // Overlap: fetch the next sub-tile's K before computing on the current one.
            if (n < K_N - 1) {
                state_QK.Load(n + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }

            state_QK.Transform(n);

            // Per-thread partial dot product over this thread's slice of the head dim.
            // Each product is formed in Tqk before being added to the score.
            PRAGMA_UNROLL
            for (int m = 0; m < K_M; ++m) {
                PRAGMA_UNROLL
                for (int k = 0; k < K_K; ++k) {
                    PRAGMA_UNROLL
                    for (int c = 0; c < 8; ++c) {
                        frag_S[m][n][0] += static_cast((Tqk)state_QK.frag_Q[m][k][c] * state_QK.frag_K[n][k][c]);
                    }
                }
            }

            // prefetch(n) for n < K_N - 1, plus prefetch(K_N - 1) on the second-to-last
            // iteration, together cover every index exactly once.
            if (n < K_N - 1) {
                ((Prefetch &&) prefetch)(n);
            }
            if (n == K_N - 2) {
                ((Prefetch &&) prefetch)(K_N - 1);
            }
        }

        // Sum the partial dot products across the T_D lanes that split the head dim.
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                PRAGMA_UNROLL
                for (int mask = 1; mask < T_D; mask *= 2) {
                    frag_S[m][n][0] += __shfl_xor_sync(uint32_t(-1), frag_S[m][n][0], mask);
                }
            }
        }
    }

    // Per-call register state for the P·V product (see ComputePV).
    //   smem_V / smem_V_param: bases of the V region of storage.KV and of storage.KVp.
    //       When `offset` is true they are advanced by SmemLayoutK::kSize /
    //       SmemLayoutKVp::kSize. Each Load() selects a pipeline stage by adding
    //       pipe_iter * SmemLayoutV::kSize (resp. SmemLayoutKVp::kSize).
    //   frag_P:  probabilities for the current tile. Not set by this struct; presumably
    //       filled by the caller, e.g. via ConvertStoP.
    //   data_V:  V elements as read by Load(). frag_V: the same after Transform().
    //   param_V: per-sub-tile conversion parameters, read by Load() only when kQuantKV.
    struct StatePV {
        PointerKV smem_V;
        T*        smem_V_param;
        FragP     frag_P;
        FragV     frag_V;
        DataV     data_V;
        ParamV    param_V;

        __device__ StatePV(SharedStorage& storage, bool offset = false)
        {
            smem_V       = storage.KV.data() + (offset ? SmemLayoutK::kSize : 0);
            smem_V_param = storage.KVp + (offset ? SmemLayoutKVp::kSize : 0);
        }

        // Read V sub-tile `k` of pipeline stage `pipe_iter` into data_V[k].
        // With kQuantKV, the call for k == 0 also reads param_V for all V_K sub-tiles.
        // Load(0, ...) must therefore run before Transform() of any sub-tile.
        __device__ void Load(int k, int pipe_iter)
        {
            const auto warp_ids = get_warp_ids();
            const int  lane_id  = threadIdx.x % WARP_SIZE;

            const int offset_s = lane_id / T_D * S_S_thr + warp_ids.x * WARP_S;
            const int offset_c = lane_id % T_D * S_D_thr;

            if (kQuantKV && k == 0) {
                // Loop index is `tile`, not `k`, so it does not shadow the parameter.
                PRAGMA_UNROLL
                for (int tile = 0; tile < V_K; ++tile) {
                    const int si = tile * S_S + offset_s;
                    Lds(param_V[tile], &smem_V_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);
                }
            }

            // Step by LDS_V: each Lds presumably moves LDS_V consecutive n-fragments at once.
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; n += LDS_V) {
                const int si = k * S_S + offset_s;
                const int di = n * S_D + offset_c;
                Lds((Array&)data_V[k][n],
                    &smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(si, di)]);
            }
        }

        // Convert data_V[k] into frag_V[k] using a ConvertKvCache built from
        // param_V[k][0] and param_V[k][1] (presumably scale and zero point — confirm).
        // Without kQuantKV, param_V is never loaded, so the converter presumably
        // ignores these arguments.
        __device__ void Transform(int k)
        {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                ConvertKvCache convert(param_V[k][0], param_V[k][1]);
                frag_V[k][n] = convert(data_V[k][n]);
            }
        }
    };

    // Accumulate P·V for one KV tile into frag_O:
    //   frag_O[m][n][d] += P[m][k] * V[k][n][d]
    //
    // Software-pipelined over the V_K sub-tiles in the same way as ComputeQK. While
    // sub-tile k is transformed and multiplied, sub-tile k + 1 is loaded from pipeline
    // stage `offset`. Sub-tile 0 must already have been loaded by the caller.
    // `state_PV` is taken by value.
    //
    // Caller hooks:
    //   - `prefetch(i)` is called exactly once for each i in [0, V_K).
    //   - `preload()` is called once, in the last iteration.
    template
    __device__ static void
    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)
    {
        // With a single sub-tile the in-loop prefetch calls never fire, so issue it here.
        if constexpr (V_K == 1) {
            ((Prefetch &&) prefetch)(0);
        }

        PRAGMA_UNROLL
        for (int k = 0; k < V_K; ++k) {
            // Overlap: fetch the next sub-tile's V before computing on the current one.
            if (k < V_K - 1) {
                state_PV.Load(k + 1, offset);
            }
            else {
                ((Preload &&) preload)();
            }

            state_PV.Transform(k);

            // Each product is formed in Tpv before being added to the output.
            PRAGMA_UNROLL
            for (int m = 0; m < V_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < V_N; ++n) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 8; ++d) {
                        frag_O[m][n][d] += static_cast((Tpv)state_PV.frag_P[m][k][0] * state_PV.frag_V[k][n][d]);
                    }
                }
            }

            // prefetch(k) for k < V_K - 1, plus prefetch(V_K - 1) on the second-to-last
            // iteration, together cover every index exactly once.
            if (k < V_K - 1) {
                ((Prefetch &&) prefetch)(k);
            }
            if (k == V_K - 2) {
                ((Prefetch &&) prefetch)(V_K - 1);
            }
        }
    }

    // Online-softmax update for one tile of scores.
    //   1. frag_M <- max(previous frag_M, this tile's scores).
    //   2. frag_O and frag_L are rescaled by exp2((M_prev - M_new) * qk_scale).
    //   3. Each score becomes p = exp2(S * qk_scale - M_new * qk_scale), and frag_L
    //      accumulates the p values.
    // exp2f is used, so qk_scale presumably already folds in log2(e) — confirm at the
    // caller.
    //
    // When `is_residue` is set, a row whose max is still -inf gets rescale factor 0 and
    // p = 0, instead of the NaN that (-inf) - (-inf) would produce.
    //
    // NOTE(review): the rescale loop runs over V_M while the score loops run over K_M.
    // This assumes both index the same head rows.
    template
    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragL& frag_L, FragO& frag_O, float qk_scale)
    {
        FragM prev_M;
        copy(frag_M, prev_M);

        // New running max per head row.
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                frag_M[m][0] = fmaxf(frag_M[m][0], frag_S[m][n][0]);
            }
        }

        // Rescale the accumulated output and sum to the new max.
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            float expdiff_M = exp2f((prev_M[m][0] - frag_M[m][0]) * qk_scale);
            if (is_residue && frag_M[m][0] == -std::numeric_limits::infinity()) {
                expdiff_M = 0.f;
            }
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                using namespace ops;
                frag_O[m][n] = frag_O[m][n] * expdiff_M;
            }
            frag_L[m][0] *= expdiff_M;
        }

        // Replace scores by unnormalized probabilities and add them to the running sum.
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            float tmp_L{};
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                float p = exp2f(frag_S[m][n][0] * qk_scale - frag_M[m][0] * qk_scale);
                if (is_residue && frag_M[m][0] == -std::numeric_limits::infinity()) {
                    p = 0.f;
                }
                tmp_L += p;
                frag_S[m][n][0] = p;
            }
            frag_L[m][0] += tmp_L;
        }
    }

    // Write the probabilities held in frag_S into frag_P, viewing frag_P's storage as
    // FragSp and converting each element with static_cast.
    // The SharedStorage argument is unused here (presumably kept for signature parity
    // with other Impl variants).
    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage&)
    {
        FragSp& frag_Sp = (FragSp&)frag_P;
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < K_N; ++n) {
                frag_Sp[m][n][0] = static_cast(frag_S[m][n][0]);
            }
        }
    }

    // Combine the partial softmax state held by the lane groups and warps that cover the
    // same head rows: running max M, running sum L, and unnormalized output O.
    //
    // Two levels are merged:
    //   - Within a warp, across the lane groups selected by XOR masks
    //     T_D, 2*T_D, ..., WARP_SIZE/2.
    //   - Across the kWarpCntS warps sharing one H slot (warp_ids.y), exchanged through
    //     storage.M, storage.L and storage.O.
    //
    // On return every thread holds the merged O (not yet divided by L) for its
    // lane_id % T_D column block, and the merged L.
    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, SharedStorage& storage)
    {
        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        FragM prev_M;
        copy(frag_M, prev_M);  // pre-merge max, needed for the rescale factor below

        // NOTE(review): presumably ensures no thread is still using `storage` from an
        // earlier phase before it is overwritten below — confirm against the caller.
        __syncthreads();

        /////////////////////////////////////////////////////////////////////////
        //  global max
        // Reduce the max across lane groups in the warp; lane 0 publishes the warp's max.
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {
                frag_M[m][0] = fmaxf(frag_M[m][0], __shfl_xor_sync(uint32_t(-1), frag_M[m][0], mask));
            }
            if (lane_id == 0) {
                // printf("warp M %d %f\n", warp_id, frag_M[m][0]);
                storage.M[m][warp_ids.y][warp_ids.x] = frag_M[m][0];
            }
        }

        __syncthreads();

        // Fold in the maxima published by the other kWarpCntS - 1 warps of this H slot.
        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            PRAGMA_UNROLL
            for (int w = 0; w < kWarpCntS - 1; ++w) {
                frag_M[m][0] = fmaxf(frag_M[m][0], storage.M[m][warp_ids.y][(warp_ids.x + w + 1) % kWarpCntS]);
            }
            // if (threadIdx.x == 0) {
            //     printf("M %d %f\n", m * OP_H + blockIdx.x * CTA_H, frag_M[m][0]);
            // }
        }

        ///////////////////////////////////////////////////////////////////////////
        //  rescale & global sum
        // Rescale this thread's O and L to the global max (factor 0 for an all -inf row),
        // then sum across lane groups in the warp and publish the warp's partials.
        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            float expdiff_M = exp2f((prev_M[m][0] - frag_M[m][0]) * qk_scale);
            if (frag_M[m][0] == -std::numeric_limits::infinity()) {
                expdiff_M = 0.f;
            }
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                PRAGMA_UNROLL
                for (int d = 0; d < 8; ++d) {
                    frag_O[m][n][d] = frag_O[m][n][d] * expdiff_M;
                    PRAGMA_UNROLL
                    for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {
                        frag_O[m][n][d] += __shfl_xor_sync(uint32_t(-1), frag_O[m][n][d], mask);
                    }
                }
                // Lanes 0..T_D-1 write the warp's summed O in two 4-element halves.
                PRAGMA_UNROLL
                for (int d = 0; d < 8; d += 4) {
                    if (lane_id < T_D) {
                        Store(storage.O[m][n][d / 4][warp_ids.y][warp_ids.x][lane_id].data(),
                              (Array&)frag_O[m][n][d]);
                    }
                }
            }
            frag_L[m][0] *= expdiff_M;
            PRAGMA_UNROLL
            for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {
                frag_L[m][0] += __shfl_xor_sync(uint32_t(-1), frag_L[m][0], mask);
            }
            if (lane_id == 0) {
                storage.L[m][warp_ids.y][warp_ids.x] = frag_L[m][0];
            }
        }

        __syncthreads();

        // Rebuild O from scratch as the sum over all kWarpCntS warps' published partials.
        clear(frag_O);

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                // Disabled alternative: each lane reads one warp's partial
                // (s + lane_id / 8) and the lanes are then combined with shuffles.
#if 0
                static_assert(kWarpCntS % 4 == 0);
                PRAGMA_UNROLL
                for (int s = 0; s < kWarpCntS; s += 4) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 8; d += 4) {
                        Array tmp_O;
                        Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s + lane_id / 8][lane_id % T_D].data());
                        using namespace ops;
                        (Array&)frag_O[m][n][d] = (Array&)frag_O[m][n][d] + tmp_O;
                    }
                }
                PRAGMA_UNROLL
                for (int d = 0; d < 8; ++d) {
                    PRAGMA_UNROLL
                    for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {
                        frag_O[m][n][d] += __shfl_xor_sync(uint32_t(-1), frag_O[m][n][d], mask);
                    }
                }
#else
                // Active path: every thread reads all kWarpCntS partials for its column block.
                PRAGMA_UNROLL
                for (int s = 0; s < kWarpCntS; ++s) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 8; d += 4) {
                        Array tmp_O;
                        Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s][lane_id % T_D].data());
                        using namespace ops;
                        (Array&)frag_O[m][n][d] = (Array&)frag_O[m][n][d] + tmp_O;
                    }
                }
#endif
            }
            // frag_L already holds this warp's sum; add the other warps' sums.
            PRAGMA_UNROLL
            for (int w = 0; w < kWarpCntS - 1; ++w) {
                frag_L[m][0] += storage.L[m][warp_ids.y][(warp_ids.x + w + 1) % kWarpCntS];
            }
            // if (threadIdx.x == 0) {
            //     printf("L %d %f\n", m * OP_H + blockIdx.x * CTA_H, frag_L[m][0]);
            // }
        }
    }

    // Emit the final output through `func(hi, qi, di, O)`.
    // When is_norm is set, O is first multiplied by 1 / L, computed with fdividef
    // (fast approximate division).
    // Only warps with warp_ids.x == 0 write, and within them only lanes < T_D, each
    // passing its own column block di. The other warps/lanes presumably hold duplicate
    // data after Merge — confirm.
    template
    __device__ static void StoreO(FragO& frag_O, const FragL& frag_L, SharedStorage& storage, Func&& func)
    {
        FragL inv_L;

        PRAGMA_UNROLL
        for (int m = 0; m < K_M; ++m) {
            inv_L[m][0] = fdividef(1.f, frag_L[m][0]);
        }

        const auto warp_ids = get_warp_ids();
        const int  lane_id  = threadIdx.x % WARP_SIZE;

        if (warp_ids.x != 0) {
            return;
        }

        PRAGMA_UNROLL
        for (int m = 0; m < V_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < V_N; ++n) {
                if constexpr (is_norm) {
                    PRAGMA_UNROLL
                    for (int d = 0; d < 8; ++d) {
                        frag_O[m][n][d] *= inv_L[m][0];
                    }
                }

                if (lane_id < T_D) {
                    const int hi = m * OP_H + warp_ids.y * WARP_H;
                    const int di = n * S_D + lane_id * S_D_thr;
                    // for (int i = 0; i < 8; ++i) {
                    //     printf("O %4d %4d %f\n", hi + blockIdx.x * CTA_H, di + i, frag_O[m][n][i]);
                    // }
                    ((Func &&) func)(hi, 0, di, frag_O[m][n]);
                }
            }
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/iterator.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/pipe_iter.h"
#include 

namespace turbomind {

// Common per-thread state for iterators that copy tiles from global memory into a
// staging buffer.
// `Map` supplies this thread's offsets (Map::get_offset) and the access pattern:
// kIterS x kIterC accesses, strided by kDeltaS / kDeltaC. Each access moves one
// AccessType vector.
template
struct BaseGmemIterator {
    using ElementType = T;
    using AccessType  = Array;
    using Pointer     = get_pointer_type;

    static constexpr int kElementSize = sizeof(ElementType);
    static constexpr int kAccessSize  = sizeof(AccessType);
    static constexpr int kIterCount   = Map::kIterS * Map::kIterC;

    // Register storage for one tile: one vector per (s, c) access.
    using Fragment = Array[Map::kIterS][Map::kIterC];

    Pointer smem_;  // destination buffer, set via SetSmem()

    int src_offset_;  // offset of this thread's first element in a row-major tile of row length Map::kDimC
    int offset_c_;    // column (C) offset of this thread's first access
    int offset_s_;    // row (S) offset of this thread's first access

    // Compile-time flag. When set, pred_c_ records whether this thread's column offset
    // lies inside Map::kDimC.
    static constexpr std::integral_constant partial_c_{};

    // Column predicate. Only assigned in the constructor when partial_c_ is set, yet
    // ClearSmem tests it unconditionally. For !partial_c_ its type is therefore
    // presumably a constant-true type (e.g. std::true_type) — confirm.
    std::conditional_t pred_c_;

    // Derive this thread's offsets from its warp and lane index.
    __device__ BaseGmemIterator()
    {
        int  warp_id = threadIdx.x / WARP_SIZE;
        int  lane_id = threadIdx.x % WARP_SIZE;
        int2 offsets = Map::get_offset(warp_id, lane_id);
        src_offset_  = offsets.x + offsets.y * Map::kDimC;
        offset_c_    = offsets.x;
        offset_s_    = offsets.y;
        if constexpr (partial_c_) {
            pred_c_ = offset_c_ < Map::kDimC;
        }
    }

    __device__ void SetSmem(Pointer smem)
    {
        smem_ = smem;
    }

    // Write zero vectors to every slot this thread owns in pipeline stage `pipe_iter`
    // (base offset pipe_iter * SmemLayout::kSize). Column-predicated slots are skipped.
    __device__ void ClearSmem(int pipe_iter = 0)
    {
        SmemAccessor data{smem_};
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                if (pred_c_) {
                    Store(&data(offset_s_ + s * Map::kDeltaS,
                                offset_c_ + c * Map::kDeltaC,
                                pipe_iter * SmemLayout::kSize),
                          Array{});
                }
            }
        }
    }
};

// Minimal base for shared-memory iterators. It holds the buffer pointer, the element
// size in bytes, and an Accessor alias (an SmemAccessor over the buffer).
template
struct BaseSmemIterator {
    static constexpr int kElemSize = sizeof(T);

    using Accessor = SmemAccessor;
    T* smem_;

    __device__ explicit BaseSmemIterator(T* smem): smem_{smem} {}
};

// Drives two iterators in lockstep. Every operation is forwarded to iterator0_ and then
// iterator1_ with the same arguments. The exception is SetSmem, which gives each
// iterator its own pointer. Fragment bundles one fragment of each.
template
struct CombinedIterator {
    Iterator0 iterator0_;
    Iterator1 iterator1_;

    struct Fragment {
        typename Iterator0::Fragment frag0;
        typename Iterator1::Fragment frag1;
    };

    // NOTE: can't use reference type here, nvcc does not support variadic templates well in device code
    template
    __device__ void Prefetch(Args... args)
    {
        iterator0_.Prefetch(args...);
        iterator1_.Prefetch(args...);
    }

    /// TODO: Load(bool_constant, CacheIter&) -> Fragment
    template
    __device__ void Load(const CacheIter& cache_iter, Fragment& frag, int max_s)
    {
        iterator0_.Load(cache_iter, frag.frag0, max_s);
        iterator1_.Load(cache_iter, frag.frag1, max_s);
    }

    __device__ void Save(const Fragment& frag)
    {
        iterator0_.Save(frag.frag0);
        iterator1_.Save(frag.frag1);
    }

    __device__ void ClearSmem(int pipe_iter = 0)
    {
        iterator0_.ClearSmem(pipe_iter);
        iterator1_.ClearSmem(pipe_iter);
    }

    // p0 goes to iterator0_, p1 to iterator1_.
    template
    __device__ void SetSmem(P0 p0, P1 p1)
    {
        iterator0_.SetSmem(p0);
        iterator1_.SetSmem(p1);
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/iterator_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "iterator.h"
#include "src/turbomind/kernels/core/array_ops.h"

namespace turbomind {

// Global-memory tile iterator that stages tiles through registers.
//   Load(): reads this thread's slots with Ldg into a register Fragment.
//   Save(): writes that Fragment to the staging buffer at smem_.
template
struct Sm70GmemIterator: BaseGmemIterator {
    using Base = BaseGmemIterator;

    using typename Base::AccessType;
    using typename Base::Fragment;

    using Base::src_offset_;
    using Base::offset_c_;
    using Base::offset_s_;
    using Base::smem_;

    using Base::partial_c_;
    using Base::pred_c_;

    using Base::Base;

    // Read this thread's slots of the tile at `tile_iter` into registers `rmem`.
    // Every fragment is zero-filled first, so skipped slots read as zero:
    //   - partial_c_ (quant params): slots whose column is out of range (!pred_c_).
    //   - is_residue: rows at or beyond `max_s`.
    // The row bound uses offset_s_, which the base constructor already derived from
    // Map::get_offset for this thread (the same value the old code recomputed here).
    template
    __device__ void Load(const TileIter& tile_iter, Fragment& rmem, int max_s)
    {
        auto src_data = tile_iter.OffsetPtr(src_offset_);
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                copy(Array{}, rmem[s][c]);
                auto src = &src_data[s * Map::kDeltaS * Map::kDimC + c * Map::kDeltaC];
                if constexpr (partial_c_) {  // Only quant params is partial C
                    if (pred_c_) {
                        Ldg(rmem[s][c], src);
                    }
                }
                else if (!is_residue || offset_s_ + s * Map::kDeltaS < max_s) {
                    Ldg(rmem[s][c], src);
                }
            }
        }
    }

    // Write `rmem` to this thread's slots in the buffer at smem_. No pipeline-stage
    // offset is applied. Column-predicated slots are skipped when partial_c_ is set.
    __device__ void Save(const Fragment& rmem)
    {
        SmemAccessor data{smem_};
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                if (!partial_c_ || pred_c_) {
                    Store(&data(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC), rmem[s][c]);
                }
            }
        }
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/iterator_sm80.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "iterator.h"
#include "src/turbomind/kernels/core/smem.h"
#include 
#include 

namespace turbomind {

// PTX `.L2::128B` prefetch-size qualifier for cp.async, which needs CUDA 11.4+
// (PTX ISA 7.4). The check accepts every major version above 11 outright. Testing the
// minor version unconditionally (MAJOR >= 11 && MINOR >= 4) would wrongly disable
// the hint on CUDA 12.0-12.3. When not compiling with nvcc, the hint expands to nothing.
#if defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__)                                                 \
    && ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif

// Global-memory tile iterator that copies straight into the staging buffer with
// asynchronous cp.async, one AccessType vector per instruction.
template
struct Sm80GmemIterator: BaseGmemIterator {

    using Base = BaseGmemIterator;

    using typename Base::AccessType;

    using Base::Base;
    using Base::kElementSize;
    using Base::src_offset_;
    using Base::offset_c_;
    using Base::offset_s_;
    using Base::smem_;

    using Base::partial_c_;
    using Base::pred_c_;

    // Issue cp.async for this thread's slots in rows [s_begin, min(s_begin + s_count, kIterS)),
    // writing into pipeline stage `pipe_iter` (offset pipe_iter * SmemLayout::kSize).
    // The predicate depends on the case:
    //   - partial_c_: always predicated, on pred_c_ (column in range).
    //   - otherwise: `partial_s` selects the predicated (std::true_type) or unpredicated
    //     (std::false_type) overload, and the row predicate is
    //     offset_s_ + s * kDeltaS < max_s.
    template
    __device__ void
    Prefetch(PartialS partial_s, const TileIter& tile_iter, int s_begin, int s_count, int max_s, int pipe_iter)
    {
        // `src_data` may be `SubBytePtr`
        auto src_data = tile_iter.OffsetPtr(src_offset_);

        SmemAccessor dst_data{smem_};

        PRAGMA_UNROLL
        for (int s = s_begin; s < s_begin + s_count && s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                auto dst = cast_smem_ptr_to_uint(&dst_data(offset_s_ + s * Map::kDeltaS,  //
                                                           offset_c_ + c * Map::kDeltaC,
                                                           pipe_iter * SmemLayout::kSize));
                auto src = &src_data[s * Map::kDeltaS * Map::kDimC + c * Map::kDeltaC];

                if constexpr (partial_c_) {
                    CpAsync(std::true_type{}, dst, (const T*)src, pred_c_);
                }
                else {
                    CpAsync(partial_s, dst, (const T*)src, offset_s_ + s * Map::kDeltaS < max_s);
                }
            }
        }
    }

    // Prefetch every row of the tile.
    template
    __device__ void Prefetch(Partial partial, const TileIter& tile_iter, int max_s, int pipe_iter)
    {
        Prefetch(partial, tile_iter, 0, Map::kIterS, max_s, pipe_iter);
    }

    // Predicated copy. When `mask` is false the guarded (@p) cp.async does not execute,
    // so the destination is left unchanged rather than zero-filled.
    // 16-byte accesses use .cg (PTX only allows cp-size 16 with .cg); smaller ones use .ca.
    // In builds without TURBOMIND_ARCH_SM80 the body is only an assert.
    __device__ void CpAsync(std::true_type, int ptr, const T* __restrict__ src, bool mask)
    {
#if TURBOMIND_ARCH_SM80
        constexpr int size = sizeof(AccessType);
        // clang-format off
        if constexpr (size == 16) {
            asm volatile("{\n"
                        "  .reg .pred p;\n"
                        "  setp.ne.b32 p, %0, 0;\n"
                        "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                        "}\n" ::"r"((int)mask),
                        "r"(ptr),
                        "l"(src),
                        "n"(size));
        } else {
            asm volatile("{\n"
                        "  .reg .pred p;\n"
                        "  setp.ne.b32 p, %0, 0;\n"
                        "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                        "}\n" ::"r"((int)mask),
                        "r"(ptr),
                        "l"(src),
                        "n"(size));
        }
        // clang-format on
#else
        assert(TURBOMIND_ARCH_SM80);
#endif
    }

    // Unpredicated copy. The trailing bool is ignored.
    // Cache-operator choice is the same as in the predicated overload.
    __device__ void CpAsync(std::false_type, int ptr, const T* __restrict__ src, bool)
    {
#if TURBOMIND_ARCH_SM80
        constexpr int size = sizeof(AccessType);
        if constexpr (size == 16) {
            asm volatile(
                "cp.async.cg.shared.global" L2_CACHEHINT(128) " [%0], [%1], %2;\n" ::"r"(ptr), "l"(src), "n"(size));
        }
        else {
            asm volatile(
                "cp.async.ca.shared.global" L2_CACHEHINT(128) " [%0], [%1], %2;\n" ::"r"(ptr), "l"(src), "n"(size));
        }
#else
        assert(TURBOMIND_ARCH_SM80);
#endif
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/kernel/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

# Static library bundling the attention_* and decoding_* kernel translation units.
# Each unit covers one architecture and one size, as encoded in its file name
# (e.g. attention_sm80_128.cu).
add_library(attention_kernels STATIC
            ../utils.cc
            ../reduce.cu
            attention_sm70_64.cu
            attention_sm70_128.cu
            attention_sm70_256.cu
            attention_sm70_576.cu
            attention_sm75_64.cu
            attention_sm75_128.cu
            attention_sm75_256.cu
            attention_sm75_576.cu
            attention_sm80_64.cu
            attention_sm80_128.cu
            attention_sm80_192.cu
            attention_sm80_256.cu
            attention_sm80_576.cu
            decoding_sm70_64.cu
            decoding_sm70_128.cu
            decoding_sm70_256.cu
            decoding_sm70_576.cu
            decoding_sm75_64.cu
            decoding_sm75_128.cu
            decoding_sm75_256.cu
            decoding_sm75_576.cu
            decoding_sm80_64.cu
            decoding_sm80_128.cu
            decoding_sm80_192.cu
            decoding_sm80_256.cu
            decoding_sm80_576.cu
            )
# Position-independent code so this static library can be linked into shared objects.
set_property(TARGET attention_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
# Perform CUDA device-symbol resolution (device link) for this target.
set_property(TARGET attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# -O3 applies to all sources. The generator expression adds nvcc-only flags (presumably
# restricted to CUDA sources): fast math, relaxed constexpr, verbose ptxas resource
# report, and 8 parallel compilation threads.
target_compile_options(attention_kernels PRIVATE -O3
    $<$:-use_fast_math --expt-relaxed-constexpr  -Xptxas=-v --threads 8>)
target_link_libraries(attention_kernels PRIVATE nvidia::cutlass::cutlass)


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm70_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_884.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// Compile-time tile parameters for this translation unit. They are presumably consumed
// by the template arguments of KT below — confirm against the full argument list.
constexpr int kHeadDim = 128;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm70 tag, the Mainloop from mainloop_sm70.h,
// LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm70,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm70_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_884.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// Compile-time tile parameters for this translation unit. They are presumably consumed
// by the template arguments of KT below — confirm against the full argument list.
constexpr int kHeadDim = 256;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm70 tag, the Mainloop from mainloop_sm70.h,
// LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm70,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm70_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_884.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// HeadDim=576 on Sm70: kCTA_S=32, WARP_S=kCTA_S to fit within V100's 96 KB shared memory limit
// (kCTA_S is 64 for the smaller head dims.)
constexpr int kHeadDim = 576;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 32;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm70 tag, the Mainloop from mainloop_sm70.h,
// LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm70,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm70_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_884.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// Compile-time tile parameters for this translation unit. They are presumably consumed
// by the template arguments of KT below — confirm against the full argument list.
constexpr int kHeadDim = 64;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm70 tag, the Mainloop from mainloop_sm70.h,
// LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm70,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm75_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_1688.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// Compile-time tile parameters for this translation unit. They are presumably consumed
// by the template arguments of KT below — confirm against the full argument list.
constexpr int kHeadDim = 128;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm75 tag, the Mainloop from mainloop_sm70.h (this unit includes
// impl_1688.h), LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm75,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm75_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_1688.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// Compile-time tile parameters for this translation unit. They are presumably consumed
// by the template arguments of KT below — confirm against the full argument list.
constexpr int kHeadDim = 256;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// KT combines the arch::Sm75 tag, the Mainloop from mainloop_sm70.h (this unit includes
// impl_1688.h), LinearIteratorFactory and AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm75,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

// Hands the kernel type(s) to the Collector. The Registrar lives in an anonymous
// namespace, so each translation unit has its own instance.
namespace {
Registrar reg([](Collector& c) { c.add>(); });
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm75_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_1688.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 attention (non-decoding) kernel configuration for head_dim = 576.
// kCTA_S is 32 here, versus 64 for the 64/128/256 SM75 variants.
// NOTE(review): presumably the smaller key/value tile keeps the wide head within
// the shared-memory budget — confirm; constant meanings are inferred from names.
constexpr int kHeadDim = 576;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 32;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm75 tag + Mainloop, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm75,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar: its callback adds this file's KT instantiation to the
// Collector (presumably invoked from Registrar's constructor at static init).
Registrar reg([](Collector& c) { c.add>(); });
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm75_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_1688.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 attention (non-decoding) kernel configuration for head_dim = 64.
// NOTE(review): kCTA_Q / kCTA_S / kWARP_Q / kStages are presumably the CTA query
// tile, CTA key/value tile, per-warp query tile and pipeline depth used by
// Mainloop / Impl — confirm against those headers.
constexpr int kHeadDim = 64;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm75 tag + Mainloop, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm75,
    Mainloop>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar: its callback adds this file's KT instantiation to the
// Collector (presumably invoked from Registrar's constructor at static init).
Registrar reg([](Collector& c) { c.add>(); });
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm80_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_16816.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 attention (non-decoding) kernel configuration for head_dim = 128.
// NOTE(review): kCTA_Q / kCTA_S / kWARP_Q / kStages are presumably the CTA query
// tile, CTA key/value tile, per-warp query tile and pipeline depth used by
// Mainloop / Impl — confirm against those headers.
constexpr int kHeadDim = 128;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm80 tag + Mainloop over an Impl, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar. The first registration is unconditional; the second is
// compiled only when ENABLE_BF16 is set (presumably the bf16 instantiation).
Registrar reg([](Collector& c) {
    c.add>();
#if ENABLE_BF16
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm80_192.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_16816.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 attention (non-decoding) kernel configuration for head_dim = 192.
// NOTE(review): kCTA_Q / kCTA_S / kWARP_Q / kStages are presumably the CTA query
// tile, CTA key/value tile, per-warp query tile and pipeline depth used by
// Mainloop / Impl — confirm against those headers.
constexpr int kHeadDim = 192;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm80 tag + Mainloop over an Impl, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar. The first registration is unconditional; the second is
// compiled only when ENABLE_BF16 is set (presumably the bf16 instantiation).
Registrar reg([](Collector& c) {
    c.add>();
#if ENABLE_BF16
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm80_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_16816.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 attention (non-decoding) kernel configuration for head_dim = 256.
// NOTE(review): kCTA_Q / kCTA_S / kWARP_Q / kStages are presumably the CTA query
// tile, CTA key/value tile, per-warp query tile and pipeline depth used by
// Mainloop / Impl — confirm against those headers.
constexpr int kHeadDim = 256;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm80 tag + Mainloop over an Impl, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar. The first registration is unconditional; the second is
// compiled only when ENABLE_BF16 is set (presumably the bf16 instantiation).
Registrar reg([](Collector& c) {
    c.add>();
#if ENABLE_BF16
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm80_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_16816.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 attention (non-decoding) kernel configuration for head_dim = 576.
// kCTA_S is 32 here, versus 64 for the 128/192/256 SM80 variants.
// NOTE(review): presumably the smaller key/value tile keeps the wide head within
// the shared-memory budget — confirm; constant meanings are inferred from names.
constexpr int kHeadDim = 576;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 32;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm80 tag + Mainloop over an Impl, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar. The first registration is unconditional; the second is
// compiled only when ENABLE_BF16 is set (presumably the bf16 instantiation).
Registrar reg([](Collector& c) {
    c.add>();
#if ENABLE_BF16
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/attention_sm80_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_16816.h"
#include "src/turbomind/kernels/attention/linear_iterator.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 attention (non-decoding) kernel configuration for head_dim = 64.
// HeadDim=64 special case: kCTA_S=128, WARP_S=kCTA_S
// (every other SM80 attention variant in this directory uses kCTA_S = 64 or 32).
constexpr int kHeadDim = 64;
constexpr int kCTA_Q   = 64;
constexpr int kCTA_S   = 128;
constexpr int kWARP_Q  = 16;
constexpr int kStages  = 2;

// Kernel type: arch::Sm80 tag + Mainloop over an Impl, K/V accessed through
// LinearIteratorFactory, CTAs mapped to work by AttentionCtaMap.
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    LinearIteratorFactory,
    AttentionCtaMap>;

namespace {
// File-local registrar. The first registration is unconditional; the second is
// compiled only when ENABLE_BF16 is set (presumably the bf16 instantiation).
Registrar reg([](Collector& c) {
    c.add>();
#if ENABLE_BF16
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM70 decoding kernel configuration for head_dim = 128. The only Impl header
// included by this file is impl_simt.h; K/V are read through
// GetBlockIterFactory and CTAs are mapped by DecodingCtaMap.
constexpr int kHeadDim = 128;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 2;

// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)
// i.e. kH is the largest of {3, 2, 1} that divides Qh.
// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Nine registrations in three groups of three.
// NOTE(review): presumably one group per KV-cache type and one entry per
// kH ∈ {1, 2, 3} — confirm against the full template argument lists.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM70 decoding kernel configuration for head_dim = 256. kCTA_S is 32 here,
// versus 64 for the SM70 head_dim 64/128 decoding variants.
constexpr int kHeadDim = 256;
constexpr int kCTA_S   = 32;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 2;

// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)
// i.e. kH is the largest of {3, 2, 1} that divides Qh.
// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Nine registrations in three groups of three.
// NOTE(review): presumably one group per KV-cache type and one entry per
// kH ∈ {1, 2, 3} — confirm against the full template argument lists.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM70 decoding kernel for head_dim = 576. The tile sizes are written inline in
// the KT alias rather than as named constants.
constexpr int kHeadDim = 576;

// CTA_H=2, CTA_S=16, WARP_H=1, WARP_S=8, Stages=2
template
using KT = AttentionUniversal>,
                              GetBlockIterFactory,
                              DecodingCtaMap>;

namespace {
// File-local registrar with three KT registrations.
// NOTE(review): presumably one per KV-cache type — confirm.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM70 decoding kernel configuration for head_dim = 64. The only Impl header
// included by this file is impl_simt.h; K/V are read through
// GetBlockIterFactory and CTAs are mapped by DecodingCtaMap.
constexpr int kHeadDim = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 2;

// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)
// i.e. kH is the largest of {3, 2, 1} that divides Qh.
// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Nine registrations in three groups of three.
// NOTE(review): presumably one group per KV-cache type and one entry per
// kH ∈ {1, 2, 3} — confirm against the full template argument lists.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();

    c.add>();
    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 decoding kernel configuration for head_dim = 128. The Impl header
// included by this file is impl_81616.h; K/V are read through
// GetBlockIterFactory and CTAs are mapped by DecodingCtaMap.
constexpr int kHeadDim = 128;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 2;

// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16
// (the requested head count Qh_ is rounded up to a multiple of 8)
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Six registrations in three pairs.
// NOTE(review): presumably one pair per KV-cache type, covering Qh=8 and Qh=16
// — confirm against the full template argument lists.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();

    c.add>();
    c.add>();

    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 decoding kernel configuration for head_dim = 256. kStages is 3 here,
// versus 2 for the SM75 head_dim 64/128 decoding variants.
constexpr int kHeadDim = 256;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 3;

// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16
// For head_dim 256 the registrations below request Qh_=1 (rounded up to Qh=8)
// and Qh_=9 (rounded up to Qh=16).
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Six registrations in three pairs.
// NOTE(review): presumably one pair per KV-cache type — confirm.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();  // Qh=9 maps to 16

    c.add>();
    c.add>();  // Qh=9 maps to 16

    c.add>();
    c.add>();  // Qh=9 maps to 16
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 decoding kernel for head_dim = 576. The tile sizes are written inline in
// the KT alias rather than as named constants.
constexpr int kHeadDim = 576;

// MLA config for all Tkv: CTA_H=16, CTA_S=16, WARP_H=8, WARP_S=16, Stages=2
template
using KT = AttentionUniversal>,
                              GetBlockIterFactory,
                              DecodingCtaMap>;

namespace {
// File-local registrar with three KT registrations.
// NOTE(review): presumably one per KV-cache type (Tkv) — confirm.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm70.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM75 decoding kernel configuration for head_dim = 64. The Impl header
// included by this file is impl_81616.h; K/V are read through
// GetBlockIterFactory and CTAs are mapped by DecodingCtaMap.
constexpr int kHeadDim = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 2;

// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16
// (the requested head count Qh_ is rounded up to a multiple of 8)
template
using KT =
    AttentionUniversal>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Six registrations in three pairs.
// NOTE(review): presumably one pair per KV-cache type, covering Qh=8 and Qh=16
// — confirm against the full template argument lists.
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();

    c.add>();
    c.add>();

    c.add>();
    c.add>();
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 decoding kernels for head_dim = 128, in two flavours defined below:
// a SIMT variant (small Qh) and an MMA_81616 variant (larger Qh).
constexpr int kHeadDim = 128;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;

// Common base alias shared by Decoding_SIMT and Decoding_MMA.
template
using KT = AttentionUniversal;

// T==Tkv, Qh<=2: SIMT, stages=3
template
using Decoding_SIMT = KT, Impl>,
                         GetBlockIterFactory>;

// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv
// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16
template
using Decoding_MMA =
    KT, Impl>,
       GetBlockIterFactory>;

namespace {
// File-local registrar. Eight unconditional registrations, followed by eight
// more compiled only when ENABLE_BF16 is set (presumably the bf16 counterparts).
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();

#if ENABLE_BF16
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 decoding kernel configuration for head_dim = 192.
constexpr int kHeadDim = 192;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;
constexpr int kStages  = 3;
constexpr int kQh      = 1;

// HeadDim=192 uses SIMT+kStages for all Tkv (incl. uint8_t), kQh=1 only
template
using KT = AttentionUniversal<
    arch::Sm80,
    Mainloop, Impl>,
    GetBlockIterFactory,
    DecodingCtaMap>;

namespace {
// File-local registrar. Two unconditional registrations, plus two more compiled
// only when ENABLE_BF16 is set (presumably the bf16 counterparts).
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();

#if ENABLE_BF16
    c.add>();
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 decoding kernels for head_dim = 256, in two flavours defined below:
// a SIMT variant (small Qh) and an MMA_81616 variant (larger Qh).
constexpr int kHeadDim = 256;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;

// Common base alias shared by Decoding_SIMT and Decoding_MMA.
template
using KT = AttentionUniversal;

// T==Tkv, Qh<=2: SIMT, stages=3
template
using Decoding_SIMT = KT, Impl>,
                         GetBlockIterFactory>;

// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv
// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16
template
using Decoding_MMA =
    KT, Impl>,
       GetBlockIterFactory>;

namespace {
// File-local registrar. Eight unconditional registrations, followed by eight
// more compiled only when ENABLE_BF16 is set (presumably the bf16 counterparts).
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();

#if ENABLE_BF16
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 decoding kernels for head_dim = 576, with separate tile configurations
// for unquantized (Decoding_F) and quantized (Decoding_Q) KV caches.
constexpr int kHeadDim = 576;

// Non-quant MLA config: CTA_H=16, CTA_S=32, WARP_H=8, WARP_S=16, Stages=2
template
using Decoding_F =
    AttentionUniversal, Impl>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

// Quant config: CTA_H=8, CTA_S=64, WARP_H=8, WARP_S=16, Stages=5
template
using Decoding_Q =
    AttentionUniversal, Impl>,
                       GetBlockIterFactory,
                       DecodingCtaMap>;

namespace {
// File-local registrar. Three unconditional registrations, plus three more
// compiled only when ENABLE_BF16 is set (presumably the bf16 counterparts).
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();

#if ENABLE_BF16
    c.add>();
    c.add>();
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/attention_universal.h"
#include "src/turbomind/kernels/attention/block_iterator.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/impl.h"
#include "src/turbomind/kernels/attention/impl_81616.h"
#include "src/turbomind/kernels/attention/impl_simt.h"
#include "src/turbomind/kernels/attention/mainloop.h"
#include "src/turbomind/kernels/attention/mainloop_sm80.h"
#include "src/turbomind/kernels/attention/registrar.h"

namespace turbomind::attention {

// SM80 decoding kernels for head_dim = 64, in two flavours defined below:
// a SIMT variant (small Qh) and an MMA_81616 variant (larger Qh).
constexpr int kHeadDim = 64;
constexpr int kCTA_S   = 64;
constexpr int kWARP_S  = 16;

// Common base alias shared by Decoding_SIMT and Decoding_MMA.
template
using KT = AttentionUniversal;

// T==Tkv, Qh<=2: SIMT, stages=3
template
using Decoding_SIMT = KT, Impl>,
                         GetBlockIterFactory>;

// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv
// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16
template
using Decoding_MMA =
    KT, Impl>,
       GetBlockIterFactory>;

namespace {
// File-local registrar. Eight unconditional registrations, followed by eight
// more compiled only when ENABLE_BF16 is set (presumably the bf16 counterparts).
Registrar reg([](Collector& c) {
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();

#if ENABLE_BF16
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
    c.add>();
#endif
});
}  // namespace

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/desc.h"

namespace turbomind::attention {

// Type-erased handle to one compiled attention / decoding kernel.
//
// Subclasses populate `desc_` (what the kernel computes: mode, arch, head dim,
// data type, KV quantization, qh) and `info_` (function attributes, dynamic
// shared-memory size, warp count, occupancy, printable name) and implement
// `Launch`.
class Kernel {
public:
    Kernel() = default;

    virtual ~Kernel() = default;

    // Launch the kernel. `params` points at the concrete kernel's parameter
    // struct; `sm_count` is forwarded to the launcher for grid sizing.
    // Returns a success flag.
    virtual bool Launch(const void* params, int sm_count) const = 0;

    // Descriptor used to select this kernel.
    const KernelDesc& desc() const noexcept { return desc_; }

    // Runtime properties gathered when the kernel was set up.
    const KernelInfo& info() const noexcept { return info_; }

    // Target architecture recorded in the descriptor.
    int arch() const noexcept { return desc_.arch; }

    // Shared memory per CTA: static bytes reported by the function attributes
    // plus the dynamic shared-memory request.
    int smem_size() const noexcept
    {
        const auto static_bytes  = info_.attr.sharedSizeBytes;
        const auto dynamic_bytes = info_.dynamic_smem_size;
        return static_bytes + dynamic_bytes;
    }

    // Human-readable kernel name.
    const std::string& name() const { return info_.name; }

protected:
    KernelDesc desc_{};
    KernelInfo info_{};
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kernel_impl.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/attention/attention_template.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/attention/decoding_template.h"
#include "src/turbomind/kernels/attention/kernel.h"
#include "src/turbomind/kernels/core/common.h"

namespace turbomind::attention {

// Map a KV-cache storage type to the `kv_quant` value stored in KernelDesc:
// 8 for the first type tested below, 4 for the second, 0 for anything else.
// NOTE(review): the non-zero values presumably denote the quantization bit
// width (8-bit / 4-bit KV cache) — confirm against desc.h.
template
constexpr int kv_quant_from_type()
{
    if constexpr (std::is_same_v) {
        return 8;
    }
    else if constexpr (std::is_same_v) {
        return 4;
    }
    else {
        return 0;
    }
}

// Concrete Kernel for the kernel type K. The constructor derives the
// descriptor from K's static members and queries the CUDA runtime for function
// attributes and occupancy. Launch forwards to the decoding or attention
// launcher.
template
class KernelImpl: public Kernel {
    // Selects decoding vs. prefill handling everywhere below.
    // NOTE(review): presumably tests K's CTA map against DecodingCtaMap — confirm.
    static constexpr bool kIsDecoding = std::is_same_v;

public:
    KernelImpl()
    {
        desc_.mode      = kIsDecoding ? AttnDesc::kDecoding : AttnDesc::kPrefill;
        desc_.arch      = K::Arch::value;
        desc_.head_dim  = K::kHeadDim;
        desc_.data_type = data_type_v;

        if constexpr (kIsDecoding) {
            // Decoding kernels record KV quantization and the per-CTA head count.
            desc_.kv_quant = kv_quant_from_type();
            desc_.qh       = K::CTA_H;
        }
        else {
            // Prefill kernels are registered as unquantized with qh = 1.
            desc_.kv_quant = 0;
            desc_.qh       = 1;
        }

        auto func               = &attention_kernel;
        info_.dynamic_smem_size = sizeof(typename K::SharedStorage);

        // NOTE(review): the return codes of the three CUDA runtime calls below
        // are not checked; on failure `info_` is left partially populated.
        cudaFuncGetAttributes(&info_.attr, func);

        // Dynamic shared memory above 48 KiB requires an explicit opt-in via
        // cudaFuncAttributeMaxDynamicSharedMemorySize.
        if (info_.dynamic_smem_size > (48 << 10)) {
            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);
        }

        // Occupancy is queried after the opt-in; the result is cached and later
        // passed to the launchers in Launch().
        info_.num_warps = K::kWarpCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &info_.max_active_ctas, func, info_.num_warps * WARP_SIZE, info_.dynamic_smem_size);

        info_.name = to_string(desc_);
    }

    // `params` is cast to the kernel's parameter type. Decoding returns the
    // launcher's result; the attention (prefill) path always reports true.
    bool Launch(const void* params, int sm_count) const override
    {
        const auto& p = *static_cast(params);
        if constexpr (kIsDecoding) {
            return invokeDecoding(p, sm_count, info_.max_active_ctas);
        }
        else {
            invokeAttention(p, sm_count, info_.max_active_ctas);
            return true;
        }
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/kv_cache_utils_v2.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/kernels/attention/block.h"
#include "src/turbomind/kernels/attention/kv_cache_utils_v2.h"
#include "src/turbomind/kernels/attention/quantization.h"
#include "src/turbomind/kernels/attention/rotary_embedding.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/thread_map.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

using cutlass::FastDivmod;

// Write the new K/V tokens of each sequence into the paged (block) KV cache.
//
// Launch shape (from the blockIdx uses below): blockIdx.x selects a tile of CTA_S
// new tokens of a sequence, blockIdx.y the head, blockIdx.z the sequence;
// __launch_bounds__ caps the CTA at 128 threads. For every token in the tile:
//   1. load K (and V, unless the block layout reports is_share_kv()),
//   2. add k_bias / v_bias when the pointers are non-null,
//   3. apply RoPE to K when rope_param.type != RopeType::kNull,
//   4. when the cache type differs from T, compute per-token quantization
//      parameters with warp_stats and convert the values,
//   5. store values (and quantization parameters) into the cache, but only for
//      tokens owned by this context-parallel rank.
//
// Parameters:
//   blocks        block-pointer table; advanced by cu_block_num[batch] to reach
//                 this sequence's entries
//   k, v          input keys / values; element offset is
//                 (b*stride_b + qi_beg*stride_c + qi*stride_s + h*stride_h) * HeadDim + di,
//                 so all strides are in units of HeadDim-sized rows
//   cu_q_len      prefix sums of new-token counts: sequence b owns
//                 [cu_q_len[b], cu_q_len[b+1])
//   cu_k_len      prefix sums of total KV lengths; history_len = k_len - q_len is
//                 the number of tokens already cached before this call
//   cp_rank/cp_size  context parallelism: timestep ti is written only by rank
//                 ti % cp_size, at local timestep ti / cp_size (FastDivmod::divmod)
//   layer_id, block_layout  forwarded to block::Head to locate cache rows
template
__global__ void __launch_bounds__(128) ProcessKV_v2(char**          blocks,
                                                    const T*        k,
                                                    const T*        v,
                                                    const T*        k_bias,
                                                    const T*        v_bias,
                                                    const int*      cu_q_len,
                                                    const int*      cu_k_len,
                                                    const int*      cu_block_num,
                                                    RopeKernelParam rope_param,
                                                    int64_t         stride_b,
                                                    int64_t         stride_c,
                                                    int64_t         stride_h,
                                                    int64_t         stride_s,
                                                    int             layer_id,
                                                    int             cp_rank,
                                                    FastDivmod      cp_size,
                                                    BlockLayout     block_layout)
{

    // Each thread moves 16-byte (uint4) vectors of T.
    constexpr int kVecSize = sizeof(uint4) / sizeof(T);

    using Vec = Array;
    using Map = RakedThreadMap;

    constexpr int ITER_C = Map::kIterC;
    constexpr int ITER_S = Map::kIterS;

    // When the layout shares K and V storage, only K is loaded, biased and stored.
    constexpr bool HAS_V = !(typename BlockLayout::Config{}.is_share_kv());

    const int token_idx = blockIdx.x * CTA_S;  // local offset into `input_length`
    const int head_idx  = blockIdx.y;
    const int batch_idx = blockIdx.z;

    const int qi_beg = cu_q_len[batch_idx];
    const int qi_end = cu_q_len[batch_idx + 1];
    const int q_len  = qi_end - qi_beg;

    const int k_len       = cu_k_len[batch_idx + 1] - cu_k_len[batch_idx];
    const int history_len = k_len - q_len;

    if (qi_beg + token_idx >= qi_end) {  // empty tile
        return;
    }

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // offset.x: column (head-dim) offset, offset.y: row (token) offset of this thread.
    const int2 offset = Map::get_offset(warp_id, lane_id);

    Vec __align__(16) vec_K[ITER_S][ITER_C];
    Vec __align__(16) vec_V[ITER_S][ITER_C];

    Vec bias_V[ITER_C];
    Vec bias_K[ITER_C];

    // Biases depend only on the head-dim column, so they are loaded once per column.
    if (k_bias) {
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int di = offset.x + c * Map::kDeltaC;
            Ldg(bias_K[c], &k_bias[head_idx * HeadDim + di]);
        }
    }
    if (v_bias && HAS_V) {
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int di = offset.x + c * Map::kDeltaC;
            Ldg(bias_V[c], &v_bias[head_idx * HeadDim + di]);
        }
    }

    // Rows with qi >= q_len are not loaded; their registers stay uninitialized.
    // They still flow through the bias / RoPE / warp_stats steps below but are
    // never stored (see the qi < q_len guard in the store loop).
    // NOTE(review): harmless only if warp_stats reduces per row — confirm.
    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int     qi = offset.y + s * Map::kDeltaS + token_idx;  // sequence local
            const int     di = offset.x + c * Map::kDeltaC;
            const int64_t index =
                (batch_idx * stride_b + qi_beg * stride_c + qi * stride_s + head_idx * stride_h) * HeadDim + di;
            if (qi < q_len) {
                Ldg(vec_K[s][c], &k[index]);
                if constexpr (HAS_V) {
                    Ldg(vec_V[s][c], &v[index]);
                }
            }
        }
    }

    if (k_bias) {
        using namespace ops;
        PRAGMA_UNROLL
        for (int s = 0; s < ITER_S; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                vec_K[s][c] = vec_K[s][c] + bias_K[c];
            }
        }
    }
    if (v_bias && HAS_V) {
        using namespace ops;
        PRAGMA_UNROLL
        for (int s = 0; s < ITER_S; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < ITER_C; ++c) {
                vec_V[s][c] = vec_V[s][c] + bias_V[c];
            }
        }
    }

    // RoPE is applied to K only, at the token's absolute timestep
    // (history_len + position within the new tokens).
    if (rope_param.type != RopeType::kNull) {
        FastRoPE rope(rope_param, batch_idx, std::integral_constant{});
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int di = offset.x + c * Map::kDeltaC;
            rope.init(di);
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx;  // sequence local
                rope.apply(vec_K[s][c], ti);
            }
        }
    }

    // Per-row quantization parameters (two values per row, passed to
    // ConvertKvCache below; presumably scale and zero-point — confirm).
    // They are only computed when the cache type differs from T.
    Array param_K[ITER_S];
    Array param_V[ITER_S];

    if constexpr (!std::is_same_v) {
        warp_stats(param_K, vec_K, bitsof);
        if constexpr (HAS_V) {
            warp_stats(param_V, vec_V, bitsof);
        }
    }

    Array out_K[ITER_S][ITER_C];
    Array out_V[ITER_S][ITER_C];

    // Convert every vector to the cache storage type using its row's parameters.
    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        ConvertKvCache conv_K{param_K[s][0], param_K[s][1]};
        ConvertKvCache conv_V{param_V[s][0], param_V[s][1]};
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            out_K[s][c] = conv_K(vec_K[s][c]);
            if constexpr (HAS_V) {
                out_V[s][c] = conv_V(vec_V[s][c]);
            }
        }
    }

    int local_ti, local_ti_rank;

    // Advance to this sequence's entries in the block-pointer table.
    blocks += cu_block_num[batch_idx];

    block::Head block_head{block_layout, layer_id, head_idx};

    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        const int qi = offset.y + s * Map::kDeltaS + token_idx;  // local offset into `input_length`
        const int ti = history_len + qi;                         // timestep
        // local_ti = ti / cp_size, local_ti_rank = ti % cp_size
        local_ti     = cp_size.divmod(local_ti_rank, ti);
        // Store only valid rows owned by this context-parallel rank.
        if (qi < q_len && local_ti_rank == cp_rank) {
            // NOTE(review): block_head.with presumably resolves the cache row for
            // local_ti and passes pointers to the K/V data and their quantization
            // parameter slots — confirm in block.h.
            block_head.with((char**)blocks, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    int di = offset.x + c * Map::kDeltaC;
                    Store(&k_cache[di], out_K[s][c]);
                    if constexpr (HAS_V) {
                        Store(&v_cache[di], out_V[s][c]);
                    }
                }
                if constexpr (!std::is_same_v) {
                    // Only the lane(s) at column offset 0 write the row's parameters.
                    if (offset.x == 0) {
                        StoreQuantParam(k_param, param_K[s]);
                        if constexpr (HAS_V) {
                            StoreQuantParam(v_param, param_V[s]);
                        }
                        // if (ti == history_len) {
                        // printf("src %d %f %f\n", ti, (float)param_K[s][0], (float)param_K[s][1]);
                        // }
                    }
                }
            });
        }
    }
}

// Host-side launcher for the ProcessKV_v2 kernel.
//
// Picks the KV-cache element type from `quant_policy` and a compile-time head dimension from
// `head_dim`, then launches ProcessKV_v2 on `stream`.
// Cache element type:
//   * uint8_t if the kCacheKVInt8 bit is set. This is checked first, so it wins if both
//     quantization bits are set.
//   * uint4_t if the kCacheKVInt4 bit is set.
//   * otherwise the activation type T (unquantized cache).
// Supported head dimensions: 64, 128, 192, 256 and 576. Any other value trips FT_CHECK(0).
template
void invokeProcessKV_v2(char**                 blocks,
                        const T*               k,
                        const T*               v,
                        const T*               k_bias,
                        const T*               v_bias,
                        const int*             cu_q_len,
                        const int*             cu_k_len,
                        const int*             cu_block_num,
                        const RopeKernelParam& rope_param,
                        int64_t                stride_b,
                        int64_t                stride_c,
                        int64_t                stride_h,
                        int64_t                stride_s,
                        int                    block_seq_len,
                        int                    layer_id,
                        int                    cp_rank,
                        FastDivmod             cp_size,
                        int                    max_q_len,
                        int                    head_num,
                        int                    head_dim,
                        int                    batch_size,
                        int                    quant_policy,
                        cudaStream_t           stream)
{

    // Launches the kernel for one (cache type, head dim) combination.
    // `tkv` is a value of the cache element type, used only to deduce `Tkv`.
    // `dim` carries the head dimension as a compile-time constant.
    auto invoke = [&](auto tkv, const auto dim) {
        using Tkv = decltype(tkv);

        constexpr int  kHeadDim = dim;
        // head_dim 576 is the only size treated as the shared-KV configuration here.
        // It uses a smaller token tile (32 instead of 64).
        constexpr bool kShareKV = kHeadDim == 576;

        constexpr int WARPS = 4;
        constexpr int CTA_S = kShareKV ? 32 : 64;

        // One CTA per (tile of CTA_S query tokens, head, batch entry).
        // cdiv is presumably ceiling division, matching the explicit form in invokeFlattenKV_v2.
        int  block = WARPS * WARP_SIZE;
        dim3 grid(cdiv(max_q_len, CTA_S), head_num, batch_size);

        // Defensive check: dispatch() below only calls invoke with a matching `dim`.
        TM_CHECK_EQ(head_dim, kHeadDim);

        // Paged-cache layout, built from the head count and the number of tokens per block.
        block::Layout block_layout{block::Config{head_num, block_seq_len}};

        ProcessKV_v2<<>>(blocks,
                                                                              k,
                                                                              v,
                                                                              k_bias,
                                                                              v_bias,
                                                                              cu_q_len,
                                                                              cu_k_len,
                                                                              cu_block_num,
                                                                              rope_param,
                                                                              stride_b,
                                                                              stride_c,
                                                                              stride_h,
                                                                              stride_s,
                                                                              layer_id,
                                                                              cp_rank,
                                                                              cp_size,
                                                                              block_layout);
    };

    // Maps the runtime `head_dim` to a compile-time constant. Unsupported sizes abort.
    auto dispatch = [&](auto tkv) {
        if (0) {}
        else if (head_dim == 64) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 128) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 192) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 256) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 576) {
            return invoke(tkv, std::integral_constant{});
        }
        FT_CHECK(0);
    };

    // The int8 bit is tested before the int4 bit, so int8 takes precedence if both are set.
    if (quant_policy & QuantPolicy::kCacheKVInt8) {
        dispatch(uint8_t{});
    }
    else if (quant_policy & QuantPolicy::kCacheKVInt4) {
        dispatch(uint4_t{});
    }
    else {
        dispatch(T{});
    }
}

// Explicitly instantiates invokeProcessKV_v2 for one activation type `type`.
// No comments are placed inside the macro body, because a `//` comment there would swallow the
// spliced continuation lines.
#define INSTANTIATE_invokeProcessKV_v2(type)                                                                           \
    template void invokeProcessKV_v2(char**                 blocks,                                                    \
                                     const type*            k,                                                         \
                                     const type*            v,                                                         \
                                     const type*            k_bias,                                                    \
                                     const type*            v_bias,                                                    \
                                     const int*             cu_q_len,                                                  \
                                     const int*             cu_k_len,                                                  \
                                     const int*             cu_block_num,                                              \
                                     const RopeKernelParam& rope_param,                                                \
                                     int64_t                stride_b,                                                  \
                                     int64_t                stride_c,                                                  \
                                     int64_t                stride_h,                                                  \
                                     int64_t                stride_s,                                                  \
                                     int                    block_seq_len,                                             \
                                     int                    layer_id,                                                  \
                                     int                    cp_rank,                                                   \
                                     FastDivmod             cp_size,                                                   \
                                     int                    max_q_len,                                                 \
                                     int                    head_num,                                                  \
                                     int                    head_dim,                                                  \
                                     int                    batch_size,                                                \
                                     int                    quant_policy,                                              \
                                     cudaStream_t           stream);

// half is always instantiated. nv_bfloat16 is instantiated only when ENABLE_BF16 is set.
INSTANTIATE_invokeProcessKV_v2(half);
#if ENABLE_BF16
INSTANTIATE_invokeProcessKV_v2(nv_bfloat16);
#endif

// Copies K, and V unless the block layout shares K/V, for every sequence in the batch out of the
// paged block cache into linear `k` / `v` buffers.
// Cache elements (Tkv) are converted to T with ConvertKvCache.
// When `rope_param.type != RopeType::kNull`, RoPE is applied to K.
//
// Grid mapping:
//   * blockIdx.x selects a tile of CTA_S tokens.
//   * blockIdx.y selects a head.
//   * blockIdx.z selects a batch entry.
// Per-sequence token ranges come from the cumulative lengths `cu_k_len`, rebased so that
// cu_k_len[0] maps to 0.
//
// Context parallelism: sequence-local token `si` is owned by rank `si % cp_size` and stored in
// slot `si / cp_size`. Only tokens owned by `cp_rank` are read and written.
//
// The output address of (token si, head h, channel di) is the same in `k` and `v`:
//   (batch_idx * stride_b + ti_beg * stride_c + (si / cp_size) * stride_s + h * stride_h) * HeadDim + di
template
__global__ void __launch_bounds__(128) flattenKV_v2(T*              k,
                                                    T*              v,
                                                    const Tkv**     blocks,
                                                    const int*      cu_k_len,
                                                    const int*      cu_block_num,
                                                    RopeKernelParam rope_param,
                                                    int64_t         stride_b,
                                                    int64_t         stride_c,
                                                    int64_t         stride_h,
                                                    int64_t         stride_s,
                                                    int             layer_id,
                                                    int             cp_rank,
                                                    FastDivmod      cp_size,
                                                    BlockLayout     block_layout)
{
    // Number of T elements in one 16-byte (uint4) vector.
    constexpr int kVecSize = sizeof(uint4) / sizeof(T);

    using Map = RakedThreadMap;

    constexpr int ITER_C = Map::kIterC;
    constexpr int ITER_S = Map::kIterS;

    // Layouts that share K and V have no separate V data, so all V work is compiled out.
    constexpr bool HAS_V = !(typename BlockLayout::Config{}.is_share_kv());

    const int token_idx = blockIdx.x * CTA_S;  // first sequence-local token of this tile
    const int head_idx  = blockIdx.y;
    const int batch_idx = blockIdx.z;

    // Rebase the cumulative lengths so the first sequence in the batch starts at 0.
    const int ti_0   = cu_k_len[0];
    const int ti_beg = cu_k_len[batch_idx] - ti_0;
    const int ti_end = cu_k_len[batch_idx + 1] - ti_0;

    const int seq_len = ti_end - ti_beg;

    if (ti_beg + token_idx >= ti_end) {  // empty tile
        return;
    }

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // Per-thread position within the tile:
    //   offset.x is the channel offset (di).
    //   offset.y is the token offset, added to token_idx.
    const int2 offset = Map::get_offset(warp_id, lane_id);

    // Raw cache vectors, as read from the blocks.
    Array __align__(16) vec_K[ITER_S][ITER_C];
    Array __align__(16) vec_V[ITER_S][ITER_C];

    // The same vectors after conversion to T (and RoPE for K).
    Array __align__(16) out_K[ITER_S][ITER_C];
    Array __align__(16) out_V[ITER_S][ITER_C];

    // Move to this sequence's list of block pointers.
    blocks += cu_block_num[batch_idx];

    // Addressing helper for this (layer, head) inside each cache block.
    block::Head block_head{block_layout, layer_id, head_idx};

    // Per-token conversion parameters. Each holds two values, which are passed to ConvertKvCache.
    // They are presumably scale and zero-point for quantized caches (confirm).
    Array param_K[ITER_S];
    Array param_V[ITER_S];

    int local_ti, local_ti_rank;

    // Load phase. For each token owned by this rank, read the cache vectors and, when the cache
    // type differs from T, the quantization parameters.
    // Entries for tokens that are out of range or not owned are left untouched. They are never
    // stored below.
    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        const int si = offset.y + s * Map::kDeltaS + token_idx;
        local_ti     = cp_size.divmod(local_ti_rank, si);  // quotient = local slot, remainder = owner rank
        if (si < seq_len && local_ti_rank == cp_rank) {
            block_head.with((char**)blocks, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {
                PRAGMA_UNROLL
                for (int c = 0; c < ITER_C; ++c) {
                    int di = offset.x + c * Map::kDeltaC;
                    Ldg(vec_K[s][c], &k_cache[di]);
                    if constexpr (HAS_V) {
                        Ldg(vec_V[s][c], &v_cache[di]);
                    }
                }
                if constexpr (!std::is_same_v) {
                    Ldg(param_K[s], k_param);
                    if constexpr (HAS_V) {
                        Ldg(param_V[s], v_param);
                    }
                }
            });
        }
    }

    // Convert cache elements to T.
    // For an unquantized cache, param_* are not loaded. They are presumably ignored by the
    // identity conversion (confirm).
    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        ConvertKvCache conv_K{param_K[s][0], param_K[s][1]};
        ConvertKvCache conv_V{param_V[s][0], param_V[s][1]};
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            out_K[s][c] = conv_K(vec_K[s][c]);
            if constexpr (HAS_V) {
                out_V[s][c] = conv_V(vec_V[s][c]);
            }
        }
    }

    // Apply RoPE to K only. The position is the sequence-local token index, not the cp-local slot.
    if (rope_param.type != RopeType::kNull) {
        FastRoPE rope(rope_param, batch_idx, std::integral_constant{});
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int di = offset.x + c * Map::kDeltaC;
            rope.init(di);
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
                const int ti = offset.y + s * Map::kDeltaS + token_idx;  // sequence local
                rope.apply(out_K[s][c], ti);
            }
        }
    }

    // Store phase. Write the tokens owned by this rank into the linear buffers at their cp-local slot.
    PRAGMA_UNROLL
    for (int s = 0; s < ITER_S; ++s) {
        PRAGMA_UNROLL
        for (int c = 0; c < ITER_C; ++c) {
            const int si = offset.y + s * Map::kDeltaS + token_idx;
            const int di = offset.x + c * Map::kDeltaC;
            local_ti     = cp_size.divmod(local_ti_rank, si);
            if (si < seq_len && local_ti_rank == cp_rank) {
                // Strides count HeadDim-element rows. int64_t strides keep the product out of int range.
                const int64_t index =
                    (batch_idx * stride_b + ti_beg * stride_c + local_ti * stride_s + head_idx * stride_h) * HeadDim
                    + di;
                Store(&k[index], out_K[s][c]);
                if constexpr (HAS_V) {
                    Store(&v[index], out_V[s][c]);
                }
            }
        }
    }
}

// Host-side launcher for the flattenKV_v2 kernel.
//
// The kernel copies cached K/V out of the paged block cache into linear `k` / `v` buffers.
// The cache element type is chosen from `quant_policy`:
//   * uint8_t if the kCacheKVInt8 bit is set. This bit takes precedence.
//   * uint4_t if the kCacheKVInt4 bit is set.
//   * otherwise T.
// Supported head dimensions: 64, 128, 192, 256 and 576. Any other value trips FT_CHECK(0).
// The grid covers ceil(max_seq_len / CTA_S) token tiles x head_num x batch_size.
template
void invokeFlattenKV_v2(T*                     k,
                        T*                     v,
                        char**                 blocks,
                        const int*             cu_k_len,
                        const int*             cu_block_num,
                        const RopeKernelParam& rope_param,
                        int64_t                stride_b,
                        int64_t                stride_c,
                        int64_t                stride_h,
                        int64_t                stride_s,
                        int                    block_seq_len,
                        int                    layer_id,
                        int                    cp_rank,
                        FastDivmod             cp_size,
                        int                    max_seq_len,
                        int                    head_num,
                        int                    head_dim,
                        int                    batch_size,
                        int                    quant_policy,
                        cudaStream_t           stream)
{

    // Launches the kernel for one (cache type, head dim) combination.
    // `tkv` only carries the cache element type. `dim` is the compile-time head dimension.
    auto invoke = [&](auto tkv, const auto dim) {
        using Tkv = decltype(tkv);

        constexpr int  kHeadDim = dim;
        // head_dim 576 is treated as the shared-KV configuration and uses a 32-token tile.
        constexpr bool kShareKV = kHeadDim == 576;

        constexpr int kWarpCnt = 4;
        constexpr int CTA_S    = kShareKV ? 32 : 64;

        // 4 warps per CTA. The kernel is declared with __launch_bounds__(128).
        constexpr int block = kWarpCnt * WARP_SIZE;
        const dim3    grid((max_seq_len + CTA_S - 1) / CTA_S, head_num, batch_size);

        // Defensive check: dispatch() only calls invoke with a matching `dim`.
        TM_CHECK_EQ(head_dim, kHeadDim);

        // Paged-cache layout, built from the head count and the number of tokens per block.
        block::Layout block_layout{block::Config{head_num, block_seq_len}};

        flattenKV_v2<<>>(k,
                                                                            v,
                                                                            (const Tkv**)blocks,
                                                                            cu_k_len,
                                                                            cu_block_num,
                                                                            rope_param,
                                                                            stride_b,
                                                                            stride_c,
                                                                            stride_h,
                                                                            stride_s,
                                                                            layer_id,
                                                                            cp_rank,
                                                                            cp_size,
                                                                            block_layout);
    };

    // Maps the runtime `head_dim` to a compile-time constant. Unsupported sizes abort.
    auto dispatch = [&](auto tkv) {
        if (0) {}
        else if (head_dim == 64) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 128) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 192) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 256) {
            return invoke(tkv, std::integral_constant{});
        }
        else if (head_dim == 576) {
            return invoke(tkv, std::integral_constant{});
        }
        FT_CHECK(0);
    };

    // The int8 bit is tested first, so it wins if both quantization bits are set.
    if (quant_policy & QuantPolicy::kCacheKVInt8) {
        dispatch(uint8_t{});
    }
    else if (quant_policy & QuantPolicy::kCacheKVInt4) {
        dispatch(uint4_t{});
    }
    else {
        dispatch(T{});
    }
}

// Explicitly instantiates invokeFlattenKV_v2 for one activation type `type`.
// No comments are placed inside the macro body, because a `//` comment there would swallow the
// spliced continuation lines.
#define INSTANTIATE_invokeFlattenKV_v2(type)                                                                           \
    template void invokeFlattenKV_v2(type*                  k,                                                         \
                                     type*                  v,                                                         \
                                     char**                 blocks,                                                    \
                                     const int*             cu_k_len,                                                  \
                                     const int*             cu_block_num,                                              \
                                     const RopeKernelParam& rope_param,                                                \
                                     int64_t                stride_b,                                                  \
                                     int64_t                stride_c,                                                  \
                                     int64_t                stride_h,                                                  \
                                     int64_t                stride_s,                                                  \
                                     int                    block_seq_len,                                             \
                                     int                    layer_id,                                                  \
                                     int                    cp_rank,                                                   \
                                     FastDivmod             cp_size,                                                   \
                                     int                    max_seq_len,                                               \
                                     int                    head_num,                                                  \
                                     int                    head_dim,                                                  \
                                     int                    batch_size,                                                \
                                     int                    quant_policy,                                              \
                                     cudaStream_t           stream);

// half is always instantiated. nv_bfloat16 is instantiated only when ENABLE_BF16 is set.
INSTANTIATE_invokeFlattenKV_v2(half);
#if ENABLE_BF16
INSTANTIATE_invokeFlattenKV_v2(nv_bfloat16);
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/kv_cache_utils_v2.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/attention/attention_params.h"

namespace turbomind {

// Launches the ProcessKV_v2 kernel, which stores the new K/V tokens of each sequence into the
// paged block cache.
// `quant_policy` selects the cache element type (int8, int4 or T).
// `head_dim` must be one of 64, 128, 192, 256 or 576.
// `cp_rank` / `cp_size` select which tokens this context-parallel rank owns.
// Defined (and explicitly instantiated for half and, optionally, nv_bfloat16) in the .cu file.
template
void invokeProcessKV_v2(char**                 blocks,
                        const T*               k,
                        const T*               v,
                        const T*               k_bias,
                        const T*               v_bias,
                        const int*             cu_q_len,
                        const int*             cu_k_len,
                        const int*             cu_block_num,
                        const RopeKernelParam& rope_param,
                        int64_t                stride_b,
                        int64_t                stride_c,
                        int64_t                stride_h,
                        int64_t                stride_s,
                        int                    block_seq_len,
                        int                    layer_id,
                        int                    cp_rank,
                        cutlass::FastDivmod    cp_size,
                        int                    max_q_len,
                        int                    head_num,
                        int                    head_dim,
                        int                    batch_size,
                        int                    quant_policy,
                        cudaStream_t           stream = {});

// Convenience overload that takes every launch argument from `params`.
// K/V inputs are addressed with:
//   * a zero batch stride,
//   * a unit head stride,
//   * `params.stride / params.size_per_head` for both the c and s strides.
// NOTE(review): strides appear to be counted in size_per_head-element rows; confirm against ProcessKV_v2.
template
void invokeProcessKV_v2_(const AttentionParams& params)
{
    invokeProcessKV_v2((char**)params.block_iter_params.block_ptrs,
                       params.k,
                       params.v,
                       params.k_bias,
                       params.v_bias,
                       params.cu_q_len,
                       params.cu_k_len,
                       params.block_iter_params.cu_block_nums,
                       params.rope_param,
                       0,                                     // stride b
                       params.stride / params.size_per_head,  // stride c
                       1,                                     // stride h
                       params.stride / params.size_per_head,  // stride s
                       params.block_iter_params.block_len,
                       params.block_iter_params.layer_id,
                       params.cp_rank,
                       params.cp_size,
                       params.max_q_len,
                       params.num_kv_heads,
                       params.size_per_head,
                       params.batch_size,
                       params.quant_policy,
                       params.stream);
}

// Launches the flattenKV_v2 kernel.
// The kernel copies cached K/V of each sequence from the paged block cache into the linear
// buffers `k` / `v`, converts to T, and applies RoPE to K when `rope_param.type != kNull`.
// The grid is sized from `max_seq_len`.
// `head_dim` must be one of 64, 128, 192, 256 or 576.
template
void invokeFlattenKV_v2(T*                     k,
                        T*                     v,
                        char**                 blocks,
                        const int*             cu_k_len,
                        const int*             cu_block_num,
                        const RopeKernelParam& rope_param,
                        int64_t                stride_b,
                        int64_t                stride_c,
                        int64_t                stride_h,
                        int64_t                stride_s,
                        int                    block_seq_len,
                        int                    layer_id,
                        int                    cp_rank,
                        cutlass::FastDivmod    cp_size,
                        int                    max_seq_len,
                        int                    head_num,
                        int                    head_dim,
                        int                    batch_size,
                        int                    quant_policy,
                        cudaStream_t           stream = {});

/// TODO: remove `sum_k_len`
// Flattens the cached K/V for all sequences into `params.linear_iter_params.kv_cache`:
//   * K is written at the buffer start.
//   * V is written `key_to_val` elements later.
// Strides passed to the kernel:
//   * batch stride 0,
//   * c and s strides 1, so tokens are packed back to back from each sequence's cu_k_len offset,
//   * head stride `linear_iter_params.stride_h / size_per_head` rows.
// RoPE is not requested: a default-constructed RopeKernelParam is passed, presumably with type kNull (confirm).
// `sum_k_len` is not read by this function body.
template
void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len)
{
    // blocks -> [H, 2, sum_k_len, D]
    invokeFlattenKV_v2((T*)params.linear_iter_params.kv_cache,
                       (T*)params.linear_iter_params.kv_cache + params.linear_iter_params.key_to_val,
                       (char**)params.block_iter_params.block_ptrs,
                       params.cu_k_len,
                       params.block_iter_params.cu_block_nums,
                       RopeKernelParam{},
                       0,
                       1,
                       params.linear_iter_params.stride_h / params.size_per_head,
                       1,
                       params.block_iter_params.block_len,
                       params.block_iter_params.layer_id,
                       params.cp_rank,
                       params.cp_size,
                       params.max_k_len,
                       params.num_kv_heads,
                       params.size_per_head,
                       params.batch_size,
                       params.quant_policy,
                       params.stream);
}

// Declared here. The definition is not in this header.
// NOTE(review): presumably returns the byte size of one paged-cache block for the given
// activation/cache dtypes, layer count, head count, head dim and tokens per block — confirm.
size_t
get_cache_block_size(DataType dtype, DataType kvtype, int layer_num, int head_num, int head_dim, int block_seq_len);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/linear_iterator.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

template
struct LinearIterator {

    // Start of this (sequence, head) slice in a linear K/V buffer. Tile 0 begins here.
    const T* kv_cache_;
    // Element distance from a key row to the corresponding value row.
    int      key_to_val_;

    // Start of the current key tile, and that tile's index.
    const T* key_ptr_{};
    int      tile_id_{};

    __device__ LinearIterator(const T* kv_cache, int key_to_val): kv_cache_{kv_cache}, key_to_val_{key_to_val} {}

    // Positions the iterator on tile `tile_id`. Each tile spans CTA_S tokens of HeadDim elements.
    __device__ void SetTile(int tile_id)
    {
        tile_id_ = tile_id;
        key_ptr_ = kv_cache_ + tile_id_ * CTA_S * HeadDim;
    }

    // Steps one tile toward the start of the sequence.
    // Once the index drops below 0, the pointer stays on tile 0.
    __device__ void Advance()
    {
        if (--tile_id_ >= 0) {
            key_ptr_ -= CTA_S * HeadDim;
        }
    }

    // Returns a pointer `offset` elements into the current tile.
    // Index 0 selects the key rows, Index 1 the value rows. Any other index fails to compile.
    template
    __device__ const T* OffsetPtr(int offset) const
    {
        static_assert(Index == 0 || Index == 1, "invalid index");
        const T* row = key_ptr_ + offset;
        if constexpr (Index == 1) {
            row += key_to_val_;
        }
        return row;
    }
};

template
struct LinearIteratorFactory {
    using Tkv = Tkv_;

    const Tkv* kv_cache_;    // base of the linear K/V buffer
    const int* cu_ctx_len_;  // cumulative context lengths, indexed by batch entry
    int        stride_h_;    // elements between consecutive heads
    int        key_to_val_;  // elements from a key row to its value row

    // Builds an iterator positioned at the start of (batch_idx, head_idx).
    __device__ auto Create(int batch_idx, int head_idx)
    {
        // Token offset of this batch entry, relative to the first entry in the buffer.
        const int first_token = cu_ctx_len_[batch_idx] - cu_ctx_len_[0];

        // Widen before multiplying: `head_idx * stride_h_` may be larger than `INT_MAX`.
        const int64_t head_offset = head_idx * (int64_t)stride_h_;

        const Tkv* kv_cache = kv_cache_ + head_offset;
        kv_cache += first_token * HeadDim;

        return LinearIterator{kv_cache, key_to_val_};
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/mainloop.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind::attention {

// Primary template for the attention main loop.
// It is intentionally empty: concrete loops are partial specializations, e.g. in
// mainloop_sm70.h and mainloop_sm80.h.
// NOTE(review): the specialization is presumably chosen by an arch/pipeline tag argument,
// with the Impl type last — confirm.
template
struct Mainloop {
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/mainloop_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "arch.h"
#include "iterator_sm70.h"
#include "mainloop.h"

namespace turbomind::attention {

// SM70 attention main loop.
// K/V tiles are read from global memory with Sm70GmemIterator, loaded into a register fragment,
// and then saved into the shared storage configured by Impl::SetSmemKV.
// Tiles are visited from the last one toward the first: offset_K decreases by CTA_S each step,
// and cache_iter.Advance() moves to the previous tile.
template
struct Mainloop {

    using Impl = Impl_;

    using T   = typename Impl::T;
    using Tkv = typename Impl::Tkv;

    using ThreadMapKV = typename Impl::ThreadMapKV;

    using GmemIterK_ = Sm70GmemIterator;
    using GmemIterV_ = Sm70GmemIterator;

    /// TODO: hide this behind a SFINAE gate so that `*KVp` stuff won't be needed for non-quantized impls
    using CombinedIterK =
        CombinedIterator>;
    using CombinedIterV =
        CombinedIterator>;

    // Plain iterators when the type check holds (presumably T == Tkv, unquantized).
    // Otherwise the combined data + quant-param iterators.
    using GmemIterK = std::conditional_t, GmemIterK_, CombinedIterK>;
    using GmemIterV = std::conditional_t, GmemIterV_, CombinedIterV>;

    using FragQ = typename Impl::FragQ;
    using FragS = typename Impl::FragS;
    using FragO = typename Impl::FragO;
    using FragM = typename Impl::FragM;
    using FragL = typename Impl::FragL;

    using SharedStorage = typename Impl::SharedStorage;

    static constexpr int CTA_S = Impl::CTA_S;

    // Context-parallel layout used by the causal mask.
    // Local key index `k` corresponds to global position `k * cp_size_ + cp_rank_`.
    // The defaults mean no context parallelism.
    int cp_size_{1};
    int cp_rank_{0};

    __device__ void SetCpInfo(int cp_size, int cp_rank)
    {
        cp_size_ = cp_size;
        cp_rank_ = cp_rank;
    }

    // Runs `tile_iter` K/V tiles in three phases:
    //   1. max(1, mask_iter_back) masked "residue" tiles. V loads are bounded by
    //      `max_step - offset_K` instead of CTA_S.
    //   2. Unmasked full tiles, while tile_iter > mask_iter_front.
    //   3. The remaining tiles, masked but not residue.
    // The loop is pipelined: while QK of the current tile is computed, V of the same tile is
    // loaded, and the next K tile is prefetched before the PV step.
    // NOTE(review): `store_S` is not referenced in this body.
    template
    __device__ void operator()(FragQ&         frag_Q,
                               CacheIter&     cache_iter,
                               FragO&         frag_O,
                               FragM&         frag_M,
                               FragL&         frag_L,
                               int            offset_Q,
                               int            offset_K,
                               int            max_step,
                               int            tile_iter,
                               int            mask_iter_back,
                               int            mask_iter_front,
                               int            window_size,
                               float          qk_scale,
                               SharedStorage& storage,
                               const StoreS&  store_S)
    {
        GmemIterK gmem_K{};
        GmemIterV gmem_V{};

        Impl::SetSmemKV(gmem_K, gmem_V, storage, true);

        typename GmemIterK::Fragment tmp_K;

        typename Impl::StateQK state_QK{storage, frag_Q};
        typename Impl::StatePV state_PV{storage, true};

        Impl::Sync();

        // Prologue: stage the first K tile, bounded by `max_step - offset_K` valid tokens.
        gmem_K.Load(cache_iter, tmp_K, max_step - offset_K);
        gmem_K.Save(tmp_K);

        constexpr auto nop = [](int) {};

        // One tile step.
        // `is_residue` bounds the V load. `is_mask` enables the causal/window mask.
        auto loop = [&](auto is_residue, auto is_mask) {
            typename GmemIterV::Fragment tmp_V;

            gmem_V.Load(cache_iter, tmp_V, is_residue ? max_step - offset_K : CTA_S);
            cache_iter.Advance();

            FragS frag_S{};

            Impl::Sync();
            state_QK.Load(0, 0);

            Impl::ComputeQK(state_QK, frag_S, 0, nop, [&] {});

            gmem_V.Save(tmp_V);

            // Prefetch the next K tile into registers. It is saved to smem after the PV step.
            if (tile_iter > 0) {
                gmem_K.Load(cache_iter, tmp_K, CTA_S);
            }

            if constexpr (is_mask) {
                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);
            }

            Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale);

            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);

            Impl::Sync();
            state_PV.Load(0, 0);

            Impl::ComputePV(state_PV, frag_O, 0, nop, [&] {});

            gmem_K.Save(tmp_K);

            offset_K -= CTA_S;
        };

        // Phase 1: at least one masked residue tile, if any tiles remain.
        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
            loop(std::true_type{}, std::true_type{});
        }

        // Phase 2: unmasked full tiles.
        PRAGMA_NO_UNROLL
        for (; tile_iter > mask_iter_front; --tile_iter) {
            loop(std::false_type{}, std::false_type{});
        }

        // Phase 3: the remaining masked tiles.
        for (; tile_iter > 0; --tile_iter) {
            loop(std::false_type{}, std::true_type{});
        }
    }

    // Causal + sliding-window mask (the name "Casual" is kept for existing callers).
    // A score survives only if 0 <= w < window_size, where
    //   w = query position - global key position,
    //   global key position = (offset_K + si) * cp_size_ + cp_rank_.
    // Any other score is pushed to -infinity.
    __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size)
    {
        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
            int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_);
            if (0 <= w && w < window_size) {}
            else {
                score -= std::numeric_limits::infinity();
            }
        });
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/mainloop_sm80.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "iterator_sm80.h"
#include "mainloop.h"
#include "src/turbomind/kernels/core/pipe_iter.h"
#include 
#include 

namespace turbomind::attention {

// Empty tag type that selects the SM80 Mainloop specialization below.
// The name suggests cp.async-based pipelining.
// NOTE(review): the template parameter is presumably the pipeline stage count — confirm.
template
struct Sm80_CpAsync {
};

// Sm80 attention mainloop built on cp.async pipelines (__pipeline_commit /
// __pipeline_wait_prior).
//
// Each Run overload does the following per KV tile:
//   - Impl::ComputeQK            -> frag_S (optionally ApplyCasualMask)
//   - Impl::Softmax              -> updates frag_M / frag_L / frag_O
//   - Impl::ConvertStoP          -> state_PV.frag_P
//   - Impl::ComputePV            -> into frag_O
//
// Meanwhile, K/V tiles for later iterations are prefetched asynchronously into
// the pipeline stages configured by Impl::SetSmemKV. offset_K decreases by
// CTA_S every iteration. Tiles are visited in three phases:
//   1. up to mask_iter_back tiles with masking,
//   2. unmasked tiles while tile_iter > mask_iter_front,
//   3. the remaining tiles with masking.
template
struct Mainloop, Impl_> {

    using Impl = Impl_;

    using T   = typename Impl::T;
    using Tkv = typename Impl::Tkv;

    // Tag constants. They are passed as the first argument of GmemIter::Prefetch
    // and as the is_residue / is_mask / is_last flags of the loop lambdas below.
    static constexpr std::false_type false_c{};
    static constexpr std::true_type  true_c{};

    static constexpr int CTA_S = Impl::CTA_S;

    using ThreadMapKV = typename Impl::ThreadMapKV;

    using GmemIterK_ = Sm80GmemIterator;
    using GmemIterV_ = Sm80GmemIterator;

    /// TODO: hide this behind a SFINAE gate so that `*KVp` stuff won't be needed for non-quantized impls
    using CombinedIterK =
        CombinedIterator>;
    using CombinedIterV =
        CombinedIterator>;

    // Choose between the plain iterator and the CombinedIterator. Per the TODO
    // above, the combined form is only needed for quantized impls.
    using GmemIterK = std::conditional_t, GmemIterK_, CombinedIterK>;
    using GmemIterV = std::conditional_t, GmemIterV_, CombinedIterV>;

    using FragQ = typename Impl::FragQ;
    using FragS = typename Impl::FragS;
    using FragO = typename Impl::FragO;
    using FragM = typename Impl::FragM;
    using FragL = typename Impl::FragL;

    using SharedStorage = typename Impl::SharedStorage;

    // Rank layout used by ApplyCasualMask. A local key index j is treated as
    // position j * cp_size_ + cp_rank_ (presumably context parallelism with keys
    // interleaved across ranks). The defaults describe a single rank.
    int cp_size_{1};
    int cp_rank_{0};

    // Sets the rank layout consumed by ApplyCasualMask.
    __device__ void SetCpInfo(int cp_size, int cp_rank)
    {
        cp_size_ = cp_size;
        cp_rank_ = cp_rank;
    }

    // Entry point. Forwards all arguments to the Run overload selected by the
    // Sm80_CpAsync tag and the two integral_constant tags (the second of which
    // matches the std::false_type / std::true_type "is MLA" parameter of Run).
    template
    __device__ void operator()(Args&&... args)
    {
        Run(Sm80_CpAsync{},
            std::integral_constant{},
            std::integral_constant{},
            ((Args &&) args)...);
    }

    // Compile-time choice between two references. Returns b when Idx is true,
    // otherwise a, preserving the value category of the chosen argument.
    template
    __device__ static decltype(auto) Select(A&& a, B&& b)
    {
        if constexpr (Idx) {
            return (B &&) b;
        }
        else {
            return (A &&) a;
        }
    }

    // Issues batch k of a tile prefetch into pipeline slot `pipe_iter`:
    // iterations [k * Batch, k * Batch + Batch) of ThreadMapKV::kIterS.
    // It is meant to be called once per k from the ComputeQK/ComputePV prefetch
    // callbacks. On the batch that ends exactly at kIterS, it optionally advances
    // block_iter and commits the copy group.
    // NOTE(review): the commit only fires if kIterS is a multiple of Batch;
    // confirm that holds for every Impl. The template parameter is spelled
    // "Advnace" (sic).
    template
    __device__ static void Prefetch(GmemIter gmem_iter, BlockIter& block_iter, int k, int pipe_iter)
    {
        const int begin = k * Batch;
        if (begin < ThreadMapKV::kIterS) {
            gmem_iter.Prefetch(false_c, block_iter, begin, Batch, CTA_S, pipe_iter);
        }
        if (begin + Batch == ThreadMapKV::kIterS) {
            if constexpr (Advnace) {
                block_iter.Advance();
            }
            __pipeline_commit();
        }
    }

    // Generic multi-stage (Stages-deep ring) schedule, non-MLA.
    template
    __device__ void Run(Sm80_CpAsync,
                        std::integral_constant,
                        std::false_type,  // is MLA
                        FragQ&         frag_Q,
                        CacheIter&     cache_iter,
                        FragO&         frag_O,
                        FragM&         frag_M,
                        FragL&         frag_L,
                        int            offset_Q,
                        int            offset_K,
                        int            max_step,
                        int            tile_iter,
                        int            mask_iter_back,
                        int            mask_iter_front,
                        int            window_size,
                        float          qk_scale,
                        SharedStorage& storage,
                        const StoreS&  store_S)
    {
        // multi-stage: pipe_iter * size
        //   two-stage: constant offset

        GmemIterK gmem_K{};
        GmemIterV gmem_V{};

        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);

        PipeIter pipe_iter;

        // Zero every pipeline stage before any asynchronous copy lands in it.
        PRAGMA_UNROLL
        for (int i = 0; i < Stages; ++i) {
            gmem_K.ClearSmem((++pipe_iter).w);
        }

        Impl::Sync();

        // Prologue: fill Stages - 1 slots in the order K0, V0, K1, V1, ...
        // Only the first tile is prefetched with the residue flag (bounded by
        // max_step - offset_K).

        // 0
        gmem_K.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);
        __pipeline_commit();

        // 1
        gmem_V.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);
        __pipeline_commit();

        cache_iter.Advance();

        PRAGMA_UNROLL
        for (int stages = 2; stages < Stages - 2; stages += 2) {
            // 2 + 2X
            gmem_K.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);
            __pipeline_commit();
            // 3 + 2X
            gmem_V.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);
            __pipeline_commit();

            cache_iter.Advance();
        }

        // With an even stage count, one more K tile fills the last prologue slot.
        if constexpr (Stages % 2 == 0) {
            // 2 + 2Y
            gmem_K.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);
            __pipeline_commit();
        }

        // gmem_0 is prefetched while ComputeQK runs and gmem_1 while ComputePV
        // runs. Which of K/V each one is depends on Stages parity (the next tile
        // after the prologue), and kBatch0 / kBatch1 follow the same parity.
        auto& gmem_0 = Select(gmem_V, gmem_K);
        auto& gmem_1 = Select(gmem_K, gmem_V);

        constexpr auto kBatch0 = Stages % 2 ? Impl::kBatchV : Impl::kBatchK;
        constexpr auto kBatch1 = Stages % 2 ? Impl::kBatchK : Impl::kBatchV;

        typename Impl::StateQK state_QK{storage, frag_Q};
        typename Impl::StatePV state_PV{storage};

        Wait();
        state_QK.Load(0, (++pipe_iter).r);

        auto loop = [&](auto is_mask) {
            __align__(16) FragS frag_S{};

            // pipe_iter is captured by value so the write slot stays fixed while
            // the completion callback below advances the shared pipe_iter.
            auto prefetch_0 = [&, pipe_iter](int k) {
                Prefetch(gmem_0, cache_iter, k, pipe_iter.w);
            };

            Impl::ComputeQK(state_QK, frag_S, pipe_iter.r, prefetch_0, [&] {
                Wait();
                state_PV.Load(0, (++pipe_iter).r);
            });

            if constexpr (is_mask) {
                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);
            }

            Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale);

            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);

            auto prefetch_1 = [&, pipe_iter](int k) {
                Prefetch(gmem_1, cache_iter, k, pipe_iter.w);
            };

            Impl::ComputePV(state_PV, frag_O, pipe_iter.r, prefetch_1, [&] {
                Wait();
                state_QK.Load(0, (++pipe_iter).r);
            });

            offset_K -= CTA_S;
        };

        // The first tile's residue was handled by the prologue, so unlike the
        // two-stage variants no iteration needs to be forced here.
        for (int mask_iter = mask_iter_back; tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
            loop(true_c);
        }

        PRAGMA_NO_UNROLL
        for (; tile_iter > mask_iter_front; --tile_iter) {
            loop(false_c);
        }

        for (; tile_iter > 0; --tile_iter) {
            loop(true_c);
        }

        // Drain all outstanding asynchronous copies before returning.
        __pipeline_commit();
        __pipeline_wait_prior(0);
    }

    // Two-stage schedule: slot 0 holds K, slot 1 holds V. V of the current tile
    // is prefetched before QK, and K of the next tile is prefetched after QK.
    // #if 1
    template
    __device__ void Run(Sm80_CpAsync<2>,
                        std::integral_constant,
                        std::false_type,  // is MLA
                        FragQ&         frag_Q,
                        CacheIter&     cache_iter,
                        FragO&         frag_O,
                        FragM&         frag_M,
                        FragL&         frag_L,
                        int            offset_Q,
                        int            offset_K,
                        int            max_step,
                        int            tile_iter,
                        int            mask_iter_back,
                        int            mask_iter_front,
                        int            window_size,
                        float          qk_scale,
                        SharedStorage& storage,
                        const StoreS&  store_S)
    {
        GmemIterK gmem_K{};
        GmemIterV gmem_V{};

        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);

        PRAGMA_UNROLL
        for (int i = 0; i < Stages; ++i) {
            gmem_K.ClearSmem(i);
        }

        // First K tile, bounded by max_step - offset_K.
        gmem_K.Prefetch(true_c, cache_iter, max_step - offset_K, 0);
        __pipeline_commit();

        typename Impl::StateQK state_QK{storage, frag_Q};
        typename Impl::StatePV state_PV{storage};

        Wait();
        state_QK.Load(0, 0);

        // No-op prefetch callback for ComputeQK / ComputePV.
        constexpr auto _ = [](int) {};

        auto loop = [&](auto is_residue, auto is_mask) {
            __align__(16) FragS frag_S{};

            auto prefetch_V = [&](int k) {
                if (k == 0) {
                    gmem_V.Prefetch(is_residue, cache_iter, max_step - offset_K, 1);
                    __pipeline_commit();
                }
            };
            prefetch_V(0);

            Impl::ComputeQK(state_QK, frag_S, 0, _, [&] {
                Wait();
                state_PV.Load(0, 1);
            });

            cache_iter.Advance();

            auto prefetch_K = [&](int k) {
                if (k == 0) {
                    gmem_K.Prefetch(false_c, cache_iter, CTA_S, 0);
                    __pipeline_commit();
                }
            };
            prefetch_K(0);

            if constexpr (is_mask) {
                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);
            }

            Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale);

            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);

            Impl::ComputePV(state_PV, frag_O, 1, _, [&] {
                Wait();
                state_QK.Load(0, 0);
            });

            offset_K -= CTA_S;
        };

        // At least one iteration runs here. This is the only call with
        // is_residue = true, which bounds the first tile's V prefetch by
        // max_step - offset_K.
        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
            loop(true_c, true_c);
        }

        PRAGMA_NO_UNROLL
        for (; tile_iter > mask_iter_front; --tile_iter) {
            loop(false_c, false_c);
        }

        for (; tile_iter > 0; --tile_iter) {
            loop(false_c, true_c);
        }

        // Drain all outstanding asynchronous copies before returning.
        __pipeline_commit();
        __pipeline_wait_prior(0);
    }

#if 1
    // Load      : K0,K1 | V0,K2,V1,K3 ...
    // Compute   :    K0 | K1,V0,K2,V1 ...
    // - more register consumption
    // - more interleaved HMMA and FMA
    // - slight performance gain
    //
    // Slot usage: K0 goes to slot 0, every later K tile to slot 1, and V always
    // to slot 0. S for tile i+1 is computed while tile i's softmax/PV is in
    // flight, which is why frag_S lives outside the loop.
    template
    __device__ void Run(Sm80_CpAsync<2>,
                        std::integral_constant,
                        std::false_type,  // is MLA
                        FragQ&         frag_Q,
                        CacheIter&     cache_iter_,
                        FragO&         frag_O,
                        FragM&         frag_M,
                        FragL&         frag_L,
                        int            offset_Q,
                        int            offset_K,
                        int            max_step,
                        int            tile_iter,
                        int            mask_iter_back,
                        int            mask_iter_front,
                        int            window_size,
                        float          qk_scale,
                        SharedStorage& storage,
                        const StoreS&  store_S)
    {
        GmemIterK gmem_K{};
        GmemIterV gmem_V{};

        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);

        gmem_K.ClearSmem(0);
        gmem_K.ClearSmem(1);

        // K runs one tile ahead of V, so each gets its own cursor.
        auto cache_iter_K = cache_iter_;
        auto cache_iter_V = cache_iter_;

        gmem_K.Prefetch(true_c, cache_iter_K, max_step - offset_K, 0);
        __pipeline_commit();
        cache_iter_K.Advance();

        typename Impl::StateQK state_QK{storage, frag_Q};
        typename Impl::StatePV state_PV{storage};

        Wait();
        state_QK.Load(0, 0);

        // Prologue: S for the first tile, overlapped with the prefetch of K1.
        FragS frag_S{};
        auto  _ = [&](int k) {
            if (k == 0) {
                gmem_K.Prefetch(false_c, cache_iter_K, CTA_S, 1);
                __pipeline_commit();
            }
        };
        Impl::ComputeQK(state_QK, frag_S, 0, _, [&] {
            Wait();
            state_QK.Load(0, 1);
        });
        cache_iter_K.Advance();

        auto loop = [&](auto is_residue, auto is_mask, auto is_last) {
            if constexpr (is_mask) {
                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);
            }

            Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale);

            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);

            auto prefetch_V = [&](int k) {
                if (k == 0) {
                    gmem_V.Prefetch(is_residue, cache_iter_V, max_step - offset_K, 0);
                    __pipeline_commit();
                }
            };
            if constexpr (!is_last) {
                // frag_S has been consumed by ConvertStoP; reuse it for the next tile.
                clear(frag_S);
                Impl::ComputeQK(state_QK, frag_S, 1, prefetch_V, [&] {
                    Wait();
                    state_PV.Load(0, 0);
                });
                cache_iter_V.Advance();
            }
            else {
                // No next tile: only fetch V for the current one.
                prefetch_V(0);
                Wait();
                state_PV.Load(0, 0);
            }

            auto prefetch_K = [&](int k) {
                if (k == 0) {
                    gmem_K.Prefetch(false_c, cache_iter_K, CTA_S, 1);
                    __pipeline_commit();
                }
            };
            Impl::ComputePV(state_PV, frag_O, 0, prefetch_K, [&] {
                Wait();
                state_QK.Load(0, 1);
            });
            cache_iter_K.Advance();

            offset_K -= CTA_S;
        };

        // At least one iteration runs here, because it is the only call with
        // is_residue = true (bounding the first V prefetch by max_step - offset_K).
        // NOTE(review): if this loop consumes the final tile, that tile runs with
        // is_last = false and an extra QK on the following tile is computed and
        // discarded; confirm this wasted work is acceptable.
        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
            loop(true_c, true_c, false_c);
        }

        // Keep at least one tile for the is_last path below.
        mask_iter_front = max(1, mask_iter_front);

        PRAGMA_NO_UNROLL
        for (; tile_iter > mask_iter_front; --tile_iter) {
            loop(false_c, false_c, false_c);
        }

        for (; tile_iter > 1; --tile_iter) {
            loop(false_c, true_c, false_c);
        }

        if (tile_iter > 0) {
            loop(false_c, true_c, true_c);
        }

        // Drain all outstanding asynchronous copies before returning.
        __pipeline_commit();
        __pipeline_wait_prior(0);
    }
#endif

    // Simplified MLA implementation
    // K and V share one buffer per stage: the same iterator is registered for
    // both, and PV reads the same slot that QK just read.
    template
    __device__ void Run(Sm80_CpAsync,
                        std::integral_constant,
                        std::true_type,  // is MLA
                        FragQ&         frag_Q,
                        CacheIter&     cache_iter,
                        FragO&         frag_O,
                        FragM&         frag_M,
                        FragL&         frag_L,
                        int            offset_Q,
                        int            offset_K,
                        int            max_step,
                        int            tile_iter,
                        int            mask_iter_back,
                        int            mask_iter_front,
                        int            window_size,
                        float          qk_scale,
                        SharedStorage& storage,
                        const StoreS&  store_S)
    {
        GmemIterK gmem_KV{};

        Impl::SetSmemKV(gmem_KV, gmem_KV, storage, false);

        PipeIter pipe_iter;

        PRAGMA_UNROLL
        for (int i = 0; i < Stages; ++i) {
            gmem_KV.ClearSmem((++pipe_iter).w);
        }

        Impl::Sync();

        // Prologue: Stages - 1 tiles. Only the first one uses the residue bound.
        gmem_KV.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);
        __pipeline_commit();
        cache_iter.Advance();

        PRAGMA_UNROLL
        for (int stages = 1; stages < Stages - 1; ++stages) {
            gmem_KV.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);
            __pipeline_commit();
            cache_iter.Advance();
        }

        typename Impl::StateQK state_QK{storage, frag_Q};
        typename Impl::StatePV state_PV{storage};

        Wait();
        state_QK.Load(0, (++pipe_iter).r);

        auto loop = [&](auto is_mask) {
            __align__(16) FragS frag_S{};

            // Whole-tile prefetch up front; no interleaving with the MMAs.
            gmem_KV.Prefetch(false_c, cache_iter, CTA_S, pipe_iter.w);
            __pipeline_commit();
            cache_iter.Advance();

            Impl::ComputeQK(
                state_QK, frag_S, pipe_iter.r, [](int) {}, [] {});

            if constexpr (is_mask) {
                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);
            }

            Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale);

            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);

            state_PV.Load(0, pipe_iter.r);
            Impl::ComputePV(
                state_PV, frag_O, pipe_iter.r, [](int) {}, [] {});

            Wait();
            state_QK.Load(0, (++pipe_iter).r);

            offset_K -= CTA_S;
        };

        for (int mask_iter = mask_iter_back; tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
            loop(true_c);
        }

        PRAGMA_NO_UNROLL
        for (; tile_iter > mask_iter_front; --tile_iter) {
            loop(false_c);
        }

        for (; tile_iter > 0; --tile_iter) {
            loop(true_c);
        }

        // Drain all outstanding asynchronous copies before returning.
        __pipeline_commit();
        __pipeline_wait_prior(0);
    }

    // Block until at most Stages - 2 committed copy groups are still pending
    // (so the oldest outstanding stage has landed), then Impl::Sync() so the data
    // is visible before it is read.
    __device__ void Wait()
    {
        __pipeline_wait_prior(Stages - 2);
        Impl::Sync();
    }

    // Window/causal mask. With query position offset_Q + qi and key position
    // (offset_K + si) * cp_size_ + cp_rank_, a score is kept only when
    // 0 <= q - k < window_size; otherwise -inf is subtracted from it.
    // ("Casual" is a historical spelling of "causal".)
    __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size)
    {
        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
            int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_);
            if (0 <= w && w < window_size) {}
            else {
                score -= std::numeric_limits::infinity();
            }
        });
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/quantization.h
================================================
#pragma once

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/data_type.h"

#include 
#include 
#include 
#include 

namespace turbomind {

// Selects the rounding implementation used by round() and warp_stats():
//   1 -> cvt.rni (round to nearest, ties to even)
//   0 -> add 0.5, then cvt.rzi (truncate)
#define TM_ROUND_USE_CVT_RNI 1

// NOTE(review): not referenced in this part of the file; presumably toggles a
// fused u4 -> f16 dequantization path elsewhere -- confirm at its use site.
inline constexpr bool kFuseU4F16Dequant  = false;
// When true, warp_stats() rounds the zero point to an integer multiple of the scale.
inline constexpr bool kForceIntZeroPoint = false;

// Returns +infinity in type T, built from the IEEE bit patterns:
//   0x7C00     -> half
//   0x7F80     -> bfloat16 (only compiled for __CUDA_ARCH__ >= 800)
//   0x7f800000 -> float
// Any other T falls through to T{} (zero). That includes bfloat16 on archs
// below sm_80, so min/max seeds built from it degrade to 0 there.
template
__device__ T Infinity()
{
    if constexpr (std::is_same_v) {
        return __ushort_as_half((unsigned short)0x7C00U);
    }

#if __CUDA_ARCH__ >= 800
    if constexpr (std::is_same_v) {
        return __ushort_as_bfloat16((unsigned short)0x7F80U);
    }
#endif

    if constexpr (std::is_same_v) {
        return __int_as_float(0x7f800000U);
    }

    return T{};
}

// Type-dispatched maximum:
//   __hmax -> half, and bfloat16 on sm_80+
//   fmaxf  -> float
//   max    -> presumably the integer case
// Unsupported T returns T{}.
template
__device__ constexpr T Max(T a, T b)
{
    if constexpr (std::is_same_v) {
        return __hmax(a, b);
    }

#if __CUDA_ARCH__ >= 800
    if constexpr (std::is_same_v) {
        return __hmax(a, b);
    }
#endif

    if constexpr (std::is_same_v) {
        return fmaxf(a, b);
    }

    if constexpr (std::is_same_v) {
        return max(a, b);
    }

    return T{};
}

// Type-dispatched minimum:
//   __hmin -> half, and bfloat16 on sm_80+
//   fminf  -> float
//   min    -> presumably the integer case
// Unsupported T returns T{}.
template
__device__ constexpr T Min(T a, T b)
{
    if constexpr (std::is_same_v) {
        return __hmin(a, b);
    }

#if __CUDA_ARCH__ >= 800
    if constexpr (std::is_same_v) {
        return __hmin(a, b);
    }
#endif

    if constexpr (std::is_same_v) {
        return fminf(a, b);
    }

    if constexpr (std::is_same_v) {
        return min(a, b);
    }

    return T{};
}

// Converts 4 packed u8 values to 4 f16 values without int->float instructions.
// __byte_perm places each byte in the low mantissa byte of an f16 whose high
// byte is 0x64, which yields the value 1024 + x. With `norm`, 1024 is
// subtracted so each result is exactly x; otherwise the +1024 bias is left to
// the caller. Output order equals input order (bytes 0,1 -> dst[0];
// bytes 2,3 -> dst[1]).
template
inline __device__ Array cvt_f16x4_u8(const Array& src)
{
    static constexpr uint32_t f16_magic = 0x64000000;
    // 01234567 01234567
    // SEEEEEMM MMMMMMMM
    //      1MM XXXXXXXX
    // (1 + x/2^10) * 2^(e-15) -> e-15=10 -> e=25=16+8+1 -> 01100100b -> 0x64
    Array dst;
    dst[0] = __byte_perm((uint32_t&)src, f16_magic, 0x7170);
    dst[1] = __byte_perm((uint32_t&)src, f16_magic, 0x7372);
    if constexpr (norm) {
        for (int i = 0; i < 4; ++i) {
            ((Array&)dst)[i] -= __ushort_as_half(0x6400U);
        }
    }
    return (Array&)dst;
}

// Converts 4 packed u8 values to 4 f16 values using the 0x64-high-byte magic
// number (each element becomes 1024 + x, and `norm` subtracts the 1024).
// Unlike the in-order variant, the pairs are transposed: dst[0] holds bytes
// 0 and 2, and dst[1] holds bytes 1 and 3 (selectors 0x7270 / 0x7371).
template
inline __device__ Array cvt_f16x2x2_u8_trans(const Array& src)
{
    static constexpr uint32_t f16_magic = 0x64000000;
    // 01234567 01234567
    // SEEEEEMM MMMMMMMM
    //      1MM XXXXXXXX
    // (1 + x/2^10) * 2^(e-15) -> e-15=10 -> e=25=16+8+1 -> 01100100b -> 0x64
    Array dst;
    dst[0] = __byte_perm((uint32_t&)src, f16_magic, 0x7270);
    dst[1] = __byte_perm((uint32_t&)src, f16_magic, 0x7371);
    if constexpr (norm) {
        for (int i = 0; i < 4; ++i) {
            ((Array&)dst)[i] -= __ushort_as_half(0x6400U);
        }
    }
    return (Array&)dst;
}

// Converts 4 packed u8 values to bfloat16. Each byte is placed in bits 8..15 of
// a float whose top byte is 0x47, which gives 32768 + x exactly. The 32768 is
// subtracted in float and the result rounded to bf16 with __float2bfloat16.
// Output order equals input order.
inline __device__ Array cvt_bf16x4_u8(const Array& src)
{
    // 01234567 01234567 01234567 01234567
    // SEEEEEEE EMMMMMMM MMMMMMMM MMMMMMMM
    //          1MM...   XXXXXXXX
    // (1 + x/2^15) * 2^(e-127) -> e-127=15 -> e=142 -> 01000111 -> 0x47
    static constexpr uint32_t f32_magic = 0x47000000;  // 32768

    Array tmp;
    tmp[0] = __byte_perm((uint32_t&)src, f32_magic, 0x7604);
    tmp[1] = __byte_perm((uint32_t&)src, f32_magic, 0x7614);
    tmp[2] = __byte_perm((uint32_t&)src, f32_magic, 0x7624);
    tmp[3] = __byte_perm((uint32_t&)src, f32_magic, 0x7634);

    auto& vec = (Array&)tmp;

    Array dst;
    PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
        dst[i] = __float2bfloat16(vec[i] - 32768.f);
    }
    return dst;
}

// Converts 4 packed u8 values to float with the 32768 magic-number trick:
// bytes go to bits 8..15 under a 0x47 top byte (value 32768 + x), and 32768 is
// then subtracted. Output order equals input order.
inline __device__ Array cvt_f32x4_u8(const Array& src)
{
    // 01234567 01234567 01234567 01234567
    // SEEEEEEE EMMMMMMM MMMMMMMM MMMMMMMM
    //          1MM...   XXXXXXXX
    // (1 + x/2^15) * 2^(e-127) -> e-127=15 -> e=142 -> 01000111 -> 0x47
    static constexpr uint32_t f32_magic = 0x47000000;  // 32768

    Array tmp;
    tmp[0] = __byte_perm((uint32_t&)src, f32_magic, 0x7604);
    tmp[1] = __byte_perm((uint32_t&)src, f32_magic, 0x7614);
    tmp[2] = __byte_perm((uint32_t&)src, f32_magic, 0x7624);
    tmp[3] = __byte_perm((uint32_t&)src, f32_magic, 0x7634);

    auto& vec = (Array&)tmp;
    PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
        vec[i] -= 32768.f;
    }
    return vec;
}

// Converts 8 packed 4-bit unsigned values (one 32-bit word) to bfloat16.
// lop3 computes (nibbles & MASK) | TEMPLATE, which puts one nibble in the low
// mantissa bits of each bf16 with value 128, giving 128 + n. Each h[j] holds
// nibbles j and j + 4, so the output order is 0,4,1,5,2,6,3,7. With `norm`,
// 128 is subtracted so each result is exactly n. Requires sm_80+; on older
// archs it returns a value-initialized array.
template
inline __device__ Array cvt_bf16x8_u4(const Array& src)
{
#if __CUDA_ARCH__ >= 800
    // 01234567 01234567
    // SEEEEEEE EMMMMMMM
    //          1...XXXX
    // (1 + x/2^7) * 2^(e-127) -> e-127=7 -> e=134 -> 0100 0011 -> 0x43
    static constexpr uint32_t TEMPLATE = 0x43004300;  // nv_bfloat162(128, 128)
    static constexpr uint32_t MASK     = 0x000f000f;
    static constexpr uint32_t immLut   = (0xf0 & 0xcc) | 0xaa;

    Array h;

    static_assert(sizeof(Array) == sizeof(Array));

    uint32_t const& i4s    = reinterpret_cast(src);
    const uint32_t  i4s_4  = i4s >> 4;
    const uint32_t  i4s_8  = i4s >> 8;
    const uint32_t  i4s_12 = i4s >> 12;

    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(TEMPLATE), "n"(immLut));
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s_4), "n"(MASK), "n"(TEMPLATE), "n"(immLut));
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(i4s_8), "n"(MASK), "n"(TEMPLATE), "n"(immLut));
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(i4s_12), "n"(MASK), "n"(TEMPLATE), "n"(immLut));

    if constexpr (norm) {
        auto result = reinterpret_cast(h.data());
        PRAGMA_UNROLL
        for (int i = 0; i < 8; ++i) {
            result[i] -= nv_bfloat16(128.f);
        }
    }
    return (Array&)h;
#else
    return {};
#endif
}

#if TM_ROUND_USE_CVT_RNI

// Rounds a float to the integer type T: round to nearest with ties to even
// (PTX cvt.rni), saturated (.sat) to the destination range. The branches match
// the u8 / u16 / u32 / s32 cvt variants. Any other T hits the static_assert.
template
inline __device__ T round(float x)
{
    uint32_t y{};
    if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u16.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u32.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else {
        static_assert(!std::is_same_v, "not implemented");
    }
    return y;
}

// Half-precision overload: same ties-to-even, saturating conversion (cvt.rni)
// applied directly to an f16 value. The half's bit pattern is passed through
// the 16-bit "h" register constraint.
template
inline __device__ T round(half x)
{
    uint32_t y{};
    if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u8.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u16.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.u32.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rni.sat.s32.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else {
        static_assert(!std::is_same_v, "not implemented");
    }
    return y;
}

#else

// Fallback rounding (TM_ROUND_USE_CVT_RNI == 0): add 0.5, then truncate toward
// zero with saturation (cvt.rzi.sat). This is round-half-up for non-negative
// x; negative inputs saturate to 0. Only unsigned targets (u8/u16/u32) are
// provided here.
template
inline __device__ T round(float x)
{
    x += .5f;

    uint32_t y{};
    if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u16.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u32.f32 %0, %1;\n" : "=r"(y) : "f"(x));
    }
    else {
        static_assert(!std::is_same_v, "not implemented");
    }
    return y;
}

// Half-precision fallback: add 0.5 in f16 precision, then truncate with
// saturation (cvt.rzi.sat). Unsigned targets only.
template
inline __device__ T round(half x)
{
    x += half(.5f);

    uint32_t y{};
    if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u8.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u16.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else if constexpr (std::is_same_v) {
        asm("cvt.rzi.sat.u32.f16 %0, %1;\n" : "=r"(y) : "h"((uint16_t&)x));
    }
    else {
        static_assert(!std::is_same_v, "not implemented");
    }
    return y;
}

#endif

// Rounds x to an integer code of type To via round(). When n_bits is narrower
// than To, the result is clamped to 2^n_bits - 1 so sub-byte codes stay in
// range; the lower bound comes from round()'s saturation. n_bits is used in
// `if constexpr`, so it must be a compile-time constant (callers pass
// std::integral_constant).
template
inline __device__ To quant(Ti x, B n_bits)
{
    auto y = round(x);
    if constexpr (n_bits < sizeof(To) * 8) {  // saturate operation for sub-byte type
        return min(y, To((1 << n_bits) - 1));
    }
    else {
        return y;
    }
}

// Folds the C values of x into stats = {min, max}. It then reduces stats
// across groups of WarpThreadC lanes with an XOR butterfly (__shfl_xor_sync
// with offsets WarpThreadC/2, ..., 1); WarpThreadC is presumably a power of two
// <= 32. The caller must initialize stats (e.g. {+inf, -inf}). The full-warp
// mask means all 32 lanes must call this together. For 2-byte T, both stats
// are exchanged in a single 32-bit shuffle.
template
__device__ inline void warp_minmax(Array& stats, const Array& x)
{
    PRAGMA_UNROLL
    for (int i = 0; i < C; ++i) {
        stats[0] = Min(stats[0], x[i]);
        stats[1] = Max(stats[1], x[i]);
    }
    if constexpr (sizeof(T) == 2) {
        PRAGMA_UNROLL
        for (int mask = WarpThreadC / 2; mask > 0; mask /= 2) {
            Array tmp;
            (uint32_t&)tmp = __shfl_xor_sync(uint32_t(-1), (uint32_t&)stats, mask);
            stats[0]       = Min(stats[0], tmp[0]);
            stats[1]       = Max(stats[1], tmp[1]);
        }
    }
    else {
        PRAGMA_UNROLL
        for (int mask = WarpThreadC / 2; mask > 0; mask /= 2) {
            stats[0] = Min(stats[0], __shfl_xor_sync(uint32_t(-1), stats[0], mask));
            stats[1] = Max(stats[1], __shfl_xor_sync(uint32_t(-1), stats[1], mask));
        }
    }
}

// Computes asymmetric quantization parameters for each of the S rows.
// For row s, the min/max over x[s][0..C) -- reduced across lanes by
// warp_minmax -- give:
//   param[s][0] = scale = (max - min) / (2^n_bits - 1)
//   param[s][1] = zero  = min
// With kForceIntZeroPoint, the zero point is rounded to an integer multiple of
// the scale, using the rounding mode selected by TM_ROUND_USE_CVT_RNI. A
// constant row (scale == 0) keeps the exact minimum, because dividing by a zero
// scale would turn the zero point into inf/NaN.
template
__device__ void warp_stats(Array (&param)[S], const Array (&x)[S][C], B n_bits)
{
    PRAGMA_UNROLL
    for (int s = 0; s < S; ++s) {
        // Seed with (+inf, -inf) so the first value always replaces both.
        Array stats{Infinity(), -Infinity()};
        PRAGMA_UNROLL
        for (int c = 0; c < C; ++c) {
            warp_minmax(stats, x[s][c]);
        }
        const float inv_q_max = fdividef(1.f, float((1 << n_bits) - 1));
        const float scale     = ((float)stats[1] - (float)stats[0]) * inv_q_max;
        param[s][0]           = (P)scale;
        param[s][1]           = (P)stats[0];

        if constexpr (kForceIntZeroPoint) {
            // Guard against a zero scale (all values equal): stats[0] / scale
            // would be inf or NaN and poison the zero point.
            if (scale > 0.f) {
#if TM_ROUND_USE_CVT_RNI
                // rintf -> cvt.rni.f32.f32
                param[s][1] = (P)(rintf((float)stats[0] / scale) * scale);
#else
                // roundf -> cvt.rzi.f32.f32(x + 0.5)
                param[s][1] = (P)(roundf((float)stats[0] / scale) * scale);
#endif
            }
        }
    }
}

// Quantizes src into dst using the per-row (scale, zero) pairs in params, as
// produced by warp_stats:
//   dst[s][c][i] = quant((src[s][c][i] - zero) * (1 / scale), n_bits)
// Arithmetic is done in the parameter type P. The reciprocal of the scale is
// computed once per row.
template
__device__ void
quantize(Array (&dst)[S][C], const Array (&src)[S][C], const Array (&params)[S], B n_bits)
{
    PRAGMA_UNROLL
    for (int s = 0; s < S; ++s) {
        P inv_scale = (P)fdividef(1.f, (float)params[s][0]);
        P zero      = params[s][1];
        PRAGMA_UNROLL
        for (int c = 0; c < C; ++c) {
            PRAGMA_UNROLL
            for (int i = 0; i < N; ++i) {
                const auto v = ((P)src[s][c][i] - zero) * inv_scale;
                dst[s][c][i] = quant(v, n_bits);
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////////////////////////////

// generic case for floating point -> floating point / integer -> integer conversion
// Element-wise cast to To. The (float, float) constructor ignores its
// arguments, so this specialization can be built with the same two-argument
// (scale, zero) form as the quantizing specializations.
template
struct ConvertKvCache {
    __device__ __host__ ConvertKvCache(float, float) {}
    // Casts each of the N elements of vi to To.
    template
    __device__ static auto convert(const Array& vi)
    {
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = (To)vi[i];
        }
        return vo;
    }
    template
    inline __device__ auto operator()(const Array& vi) const -> Array
    {
        return convert(vi);
    }
};

// generic case for converting to same type, bypass
// The conversion is the identity: v is returned unchanged. The constructor
// arguments are ignored.
template
struct ConvertKvCache {
    __device__ __host__ ConvertKvCache(float, float) {}
    template
    __device__ static auto convert(const Array& v)
    {
        return v;
    }
    template
    inline __device__ auto operator()(const Array& v) const -> Array
    {
        return convert(v);
    }
};

//  floating point -> u8
// Affine quantization: q = quant((v - zero) * (1 / scale)). The reciprocal of
// scale is computed once in the constructor and stored in inv_scale_.
template
struct ConvertKvCache {
    T          inv_scale_;
    T          zero_;
    __device__ ConvertKvCache(T scale, T zero): zero_{zero}
    {
        // NVCC complains if we put this in the member init list
        inv_scale_ = (T)fdividef(1.f, (float)scale);
    }

    // Quantizes each of the N elements independently.
    template
    __device__ auto operator()(const Array& vi) const
    {
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = quant((vi[i] - zero_) * inv_scale_, std::integral_constant{});
        }
        return vo;
    }
};

// Floating point -> packed sub-byte codes (presumably u4). Each value is
// quantized as (v - zero) * (1 / scale), clamped by quant's n_bits, and every 8
// codes are then packed into one 32-bit word by pack(). N must be a multiple
// of 8.
template
struct ConvertKvCache {
    T          inv_scale_;
    T          zero_;
    __device__ ConvertKvCache(T scale, T zero): zero_{zero}
    {
        // NVCC complains if we put this in the member init list
        inv_scale_ = (T)fdividef(1.f, (float)scale);
    }

    // Packs 8 byte-sized codes into one 32-bit word. It relies on every code
    // fitting in a nibble, so that `ui |= ui >> 12` does not collide. Nibbles of
    // the result, from low to high, hold elements 0,2,4,6,1,3,5,7.
    static __device__ Array pack(const Array& vi)
    {
        Array ui = (Array&)vi;

        ui[0] |= (ui[0] >> 12);
        ui[1] |= (ui[1] >> 12);

        //  7 6 5 4 3 2 1 0
        // _7_67564_3_23120
        uint32_t uo = __byte_perm(ui[0], ui[1], 0x5140);

        return (Array&)uo;
    }

    /// TODO: try cvt.pack.sat.u4
    template
    __device__ auto operator()(const Array& vi) const
    {
        static_assert(N % 8 == 0);
        Array tmp;
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            tmp[i] = quant((vi[i] - zero_) * inv_scale_, std::integral_constant{});
        }
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            (Array&)vo[i] = pack((Array&)tmp[i]);
        }
        return vo;
    }
};
// u4 -> half dequantization: y = u * scale + zero.
// Input words hold 8 codes in the interleaved nibble order produced by the fp -> u4 `pack`
// (low -> high: v0 v2 v4 v6 v1 v3 v5 v7).
template<>
struct ConvertKvCache {

    half scale_;
    half zero_;

    __device__ ConvertKvCache(half scale, half zero)
    {
        scale_ = scale;
        zero_  = zero;
    }

    // Unpacks 8 codes (one 32-bit word) into 8 halves with the "magic number" trick.
    // OR-ing a nibble into the mantissa of 1024.0 (0x6400, ulp 1) gives 1024 + u. For a nibble
    // sitting at bits 4..7, OR-ing into 64.0 (0x5400, ulp 1/16) gives 64 + u.
    // Subtracting the magic constant then recovers u exactly.
    // immLut = (0xf0 & 0xcc) | 0xaa makes lop3 compute (a & b) | c, i.e. (i4s & MASK) | MAGIC.
    // h[0] <- nibbles 0,4 and h[1] <- nibbles 1,5; h[2]/h[3] do the same on i4s shifted right one byte.
    static __device__ Array cvt_f16x8_u4(const Array& vi)
    {
        Array            result;
        uint32_t*                 h           = reinterpret_cast(&result);
        uint32_t const&           i4s         = reinterpret_cast(vi);
        static constexpr uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOT_MASK    = 0x000f000f;
        static constexpr uint32_t TOP_MASK    = 0x00f000f0;
        static constexpr uint32_t MAGIC_NUM_0 = 0x64006400;  // `1024`
        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;  // `64`
        // const uint32_t            top_i4s     = i4s >> 8;
        uint32_t top_i4s = __byte_perm(i4s, 0, 0x4321);
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
        return result;
    }

    // Same unpacking, but it skips the subtraction: every output equals 64 + u.
    // BOT lanes are built with MAGIC_NUM_1 >> 4 and shifted left by 4, so they also land on the 64.0 base.
    // The +64 bias must be cancelled through the zero point (see StoreQuantParam<half>,
    // which subtracts 64 * scale when kFuseU4F16Dequant is set).
    static __device__ Array cvt_f16x8_u4_biased(const Array& vi)
    {
        Array            result;
        uint32_t*                 h           = reinterpret_cast(&result);
        uint32_t const&           i4s         = reinterpret_cast(vi);
        static constexpr uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOT_MASK    = 0x000f000f;
        static constexpr uint32_t TOP_MASK    = 0x00f000f0;
        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`
        static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4
        const uint32_t            top_i4s     = i4s >> 8;
        // uint32_t top_i4s = __byte_perm(i4s, 0, 0x4321);
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        // Move the nibble and the 0x0540 base up to bits 4..7 / 0x5400 (bits 12..15 were zero, so nothing leaks across halves).
        h[0] <<= 4;
        h[2] <<= 4;
        return result;
    }

    // Converts N codes (N a multiple of 8) to half without applying scale/zero.
    // When kFuseU4F16Dequant is enabled, the outputs carry the +64 bias described above.
    template
    __device__ static auto convert(const Array& vi)
    {
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            auto& v = (Array&)vo[i];
            if constexpr (kFuseU4F16Dequant) {
                v = cvt_f16x8_u4_biased((Array&)vi[i]);
            }
            else {
                v = cvt_f16x8_u4((Array&)vi[i]);
            }
        }
        return vo;
    }

    // Full dequantization: convert, then y = u * scale + zero.
    template
    __device__ auto operator()(const Array& vi) const
    {
        auto vo = convert(vi);
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = vo[i] * scale_ + zero_;
        }
        return vo;
    }
};

// u4 -> bf16 dequantization: y = u * scale + zero.
template<>
struct ConvertKvCache {

    nv_bfloat16 scale_;
    nv_bfloat16 zero_;

    __device__ ConvertKvCache(nv_bfloat16 scale, nv_bfloat16 zero)
    {
        scale_ = scale;
        zero_  = zero;
    }

    // Converts N codes, 8 at a time, via cvt_bf16x8_u4 (defined elsewhere).
    // NOTE(review): operator() applies scale/zero directly, so cvt_bf16x8_u4 presumably returns
    // unbiased values (no kFuseU4F16Dequant handling here) — confirm against its definition.
    template
    __device__ static Array convert(const Array& vi)
    {
        Array vo{};
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            auto& v = (Array&)vo[i];
            auto  u = cvt_bf16x8_u4((Array&)vi[i]);
            v       = (Array&)u;
        }
        return vo;
    }

    // Full dequantization: convert, then y = u * scale + zero.
    template
    __device__ Array operator()(const Array& vi) const
    {
        auto vo = convert(vi);
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = vo[i] * scale_ + zero_;
        }
        return (Array&)vo;
    }
};

// u4 -> float dequantization.
// The active (#if 1) path dequantizes through the u4 -> half converter and then casts to
// float, so the arithmetic (scale * u + zero) happens in half precision.
template<>
struct ConvertKvCache {

#if 1
    ConvertKvCache impl_;  // u4 -> half converter; scale/zero are narrowed when passed in

    __device__ ConvertKvCache(float scale, float zero): impl_{scale, zero} {}

    template
    __device__ auto operator()(const Array& vi) const
    {
        return cast(impl_(vi));
    }
#else
    // Disabled alternative: unpack to biased half (64 + u), cast to float, and apply the
    // scale/zero in float. The constructor folds the +64 bias into the zero point.
    static __device__ Array cvt_f16x8_u4_biased(const Array& vi)
    {
        Array result;
        uint32_t* h = reinterpret_cast(&result);
        uint32_t const& i4s = reinterpret_cast(vi);
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOT_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`
        static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4
        const uint32_t top_i4s = i4s >> 8;
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        h[0] <<= 4;
        h[2] <<= 4;
        return result;
    }
    float scale_;
    float zero_;
    __device__ ConvertKvCache(float scale, float zero)
    {
        scale_ = scale;
        zero_ = zero - scale * 64.f;  // cancels the +64 bias of cvt_f16x8_u4_biased
    }
    template
    __device__ auto operator()(const Array& vi) const
    {
        auto vo = cast(cvt_f16x8_u4_biased(vi));
        using namespace ops;
        return vo * scale_ + zero_;
    }
#endif
};

// u8 -> f32/f16/bf16
// Dequantizes 8-bit codes: y = u * scale + zero.
template
struct ConvertKvCache {
    T          scale_;
    T          zero_;
    __device__ ConvertKvCache(T scale, T zero): scale_{scale}, zero_{zero} {}

    // Widens N codes (4 per step) to T without applying scale/zero, using the
    // cvt_*x4_u8 helpers defined elsewhere.
    // NOTE(review): if T matches none of the branches (e.g. bf16 when __CUDA_ARCH__ < 800),
    // `uo` is never written and the result is undefined. The e2m1/e4m3 converters below
    // reject unsupported types with a static_assert instead.
    template
    __device__ static auto convert(const Array& vi)
    {
        Array vo;
        PRAGMA_UNROLL
        for (int n = 0; n < N; n += 4) {
            auto& ui = (const Array&)vi[n];
            auto& uo = (Array&)vo[n];

            if constexpr (std::is_same_v) {
                uo = cvt_f16x4_u8(ui);
            }
            else if constexpr (std::is_same_v) {
                uo = cvt_f32x4_u8(ui);
            }
#if __CUDA_ARCH__ >= 800
            else if constexpr (std::is_same_v) {
                uo = cvt_bf16x4_u8(ui);
            }
#endif
        }
        return vo;
    }

    // Full dequantization: convert, then y = u * scale + zero.
    template
    __device__ auto operator()(const Array& vi) const
    {
        auto vo = convert(vi);
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = vo[i] * scale_ + zero_;
        }
        return vo;
    }
};

// e2m1 (FP4: 1 sign, 2 exponent, 1 mantissa bit) -> bf16/half, no scale/zero applied.
template
struct ConvertKvCache {

    // Decodes 8 e2m1 nibbles from one 32-bit word into 8 bf16 values.
    // Each nibble's sign goes to bit 15 of its 16-bit lane, and its E/M bits go to bits 6..8
    // (top mantissa bit + two low exponent bits). Multiplying by 2^(127 - 1) then rebiases the
    // exponent; e2m1 subnormals fall out naturally as bf16 subnormals.
    // vo[k] holds nibble k (low lane) and nibble k + 4 (high lane), so the output order is
    // n0 n4 n1 n5 n2 n6 n3 n7.
    __device__ static Array cvt_bf16x8_e2m1(const Array& vi)
    {
        const uint32_t& x = (const uint32_t&)vi;

        constexpr uint32_t S  = 0x80008000U;
        constexpr uint32_t EM = 0x01C001C0U;

        Array vo;

        // clang-format off
        vo[0] = (x << 12 & S) | (x << 6 & EM);
        vo[1] = (x <<  8 & S) | (x << 2 & EM);
        vo[2] = (x <<  4 & S) | (x >> 2 & EM);
        vo[3] = (x <<  0 & S) | (x >> 6 & EM);
        // clang-format on

        constexpr uint32_t e  = (127U - 1U + 127U) << 7U;  // bf16 bit pattern of 2^(127 - 1)
        constexpr uint32_t ee = e << 16U | e;

        PRAGMA_UNROLL
        for (int i = 0; i < 4; ++i) {
            // mul.bf16x2 is only used on SM90 builds; other targets compute x * ee + 0 with fma.
#if TURBOMIND_ARCH_SM90
            asm("mul.rn.bf16x2 %0, %1, %2;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee));
#else
            asm("fma.rn.bf16x2 %0, %1, %2, %3;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee), "r"(0));
#endif
        }

        return (Array&)vo;
    }

    // Half variant: E/M bits go to bits 9..11 and the rebias factor is 2^(15 - 1).
    // Output order matches cvt_bf16x8_e2m1.
    __device__ static Array cvt_f16x8_e2m1(const Array& vi)
    {
        const uint32_t& x = (const uint32_t&)vi;

        constexpr uint32_t S  = 0x80008000U;
        constexpr uint32_t EM = 0x0E000E00U;

        Array vo;

        // clang-format off
        vo[0] = (x << 12 & S) | (x << 9 & EM);
        vo[1] = (x <<  8 & S) | (x << 5 & EM);
        vo[2] = (x <<  4 & S) | (x << 1 & EM);
        vo[3] = (x <<  0 & S) | (x >> 3 & EM);
        // clang-format on

        constexpr uint32_t e  = (15U - 1U + 15U) << 10U;  // half bit pattern of 2^(15 - 1)
        constexpr uint32_t ee = e << 16U | e;

        PRAGMA_UNROLL
        for (int i = 0; i < 4; ++i) {
            asm volatile("mul.f16x2 %0, %1, %2;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee));
        }

        return (Array&)vo;
    }

    // Converts N packed e2m1 values (8 per 32-bit word); unsupported T fails to compile.
    template
    __device__ static auto convert(const Array& vi)
    {
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            auto& v = (Array&)vo[i];
            if constexpr (std::is_same_v) {
                v = cvt_bf16x8_e2m1((Array&)vi[i]);
            }
            else if constexpr (std::is_same_v) {
                v = cvt_f16x8_e2m1((Array&)vi[i]);
            }
            else {
                static_assert(N != N, "not implemented");
            }
        }
        return vo;
    }
};

// Decodes 4 FP8 e4m3 values (one 32-bit word) into 4 bf16 values.
// Each byte's sign moves to bit 15 of a 16-bit lane and its 7 E/M bits to bits 4..10.
// Multiplying by 2^(127 - 7) then rebiases the exponent.
// vo[0] = bytes (0, 2) and vo[1] = bytes (1, 3), so the output order is b0 b2 b1 b3.
// There is no NaN special-casing: the e4m3 NaN pattern (S.1111.111) decodes to +/-480.
__device__ inline Array cvt_bf16x4_e4m3(const Array& vi)
{
    const uint32_t& x = (const uint32_t&)vi;

    //    0   7   C   0
    // SEEEEEEEEMMMMMMMSEEEEEEEEMMMMMMM
    // SEEEEMMM        SEEEEMMM
    //         SEEEEMMM        SEEEEMMM

    constexpr uint32_t S  = 0x80008000U;
    constexpr uint32_t EM = 0x07F007F0U;

    Array vo;

    vo[0] = (x << 8 & S) | (x << 4 & EM);
    vo[1] = (x << 0 & S) | (x >> 4 & EM);

    constexpr uint32_t e  = (127U - 7U + 127U) << 7U;  // bf16 bit pattern of 2^(127 - 7)
    constexpr uint32_t ee = e << 16U | e;

    PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
        // mul.bf16x2 is only used on SM90 builds; other targets compute x * ee + 0 with fma.
#if TURBOMIND_ARCH_SM90
        asm("mul.rn.bf16x2 %0, %1, %2;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee));
#else
        asm("fma.rn.bf16x2 %0, %1, %2, %3;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee), "r"(0));
#endif
    }

    return (Array&)vo;
}

// Decodes 4 FP8 e4m3 values (one 32-bit word) into 4 half values.
// Each byte's sign moves to bit 15 and its 7 E/M bits to bits 7..13. Multiplying by
// 2^(15 - 7) then rebiases the exponent. The output order matches cvt_bf16x4_e4m3 (b0 b2 b1 b3).
// There is no NaN special-casing: the e4m3 NaN pattern (S.1111.111) decodes to +/-480.
__device__ inline Array cvt_f16x4_e4m3(const Array& vi)
{
    const uint32_t& x = (const uint32_t&)vi;

    //    3   F   8   0
    // SEEEEEMMMMMMMMMMSEEEEEMMMMMMMMMM
    // SEEEEMMM        SEEEEMMM
    //         SEEEEMMM        SEEEEMMM

    constexpr uint32_t S  = 0x80008000U;
    constexpr uint32_t EM = 0x3F803F80U;

    Array vo;

    vo[0] = (x << 8 & S) | (x << 7 & EM);
    vo[1] = (x << 0 & S) | (x >> 1 & EM);

    constexpr uint32_t e  = (15U - 7U + 15U) << 10U;  // half bit pattern of 2^(15 - 7)
    constexpr uint32_t ee = e << 16U | e;

    PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
        asm("mul.rn.f16x2 %0, %1, %2;" : "=r"(vo[i]) : "r"(vo[i]), "r"(ee));
    }

    return (Array&)vo;
}

// FP8 e4m3 -> bf16/half. This converter has no scale/zero members; it only exposes `convert`.
template
struct ConvertKvCache {

    // Converts N e4m3 values, 4 per 32-bit word; unsupported T fails to compile.
    template
    __device__ static auto convert(const Array& vi)
    {
        static_assert(N % 4 == 0);
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 4) {
            auto& v = (Array&)vo[i];
            if constexpr (std::is_same_v) {
                v = cvt_bf16x4_e4m3((Array&)vi[i]);
            }
            else if constexpr (std::is_same_v) {
                v = cvt_f16x4_e4m3((Array&)vi[i]);
            }
            else {
                static_assert(N != N, "not implemented");
            }
        }
        return vo;
    }
};

// Stores a quantization-parameter pair (presumably {scale, zero}) to `dst` unchanged.
template
inline __device__ void StoreQuantParam(T* dst, Array src)
{
    Store(dst, src);
}

// half specialization. When kFuseU4F16Dequant is enabled, the u4 -> half path decodes codes
// as 64 + u, so the stored zero point is pre-shifted by -64 * scale (0x5400 is 64.0 in fp16):
// (64 + u) * s + (z - 64 * s) = u * s + z.
// Assumes src[0] is the scale and src[1] the zero point — TODO confirm against the writer.
template<>
inline __device__ void StoreQuantParam(half* dst, Array src)
{
    if constexpr (kFuseU4F16Dequant) {
        src[1] = src[1] - src[0] * __ushort_as_half(0x5400);
    }
    Store(dst, src);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/reduce.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "cutlass/fast_math.h"
#include "src/turbomind/kernels/attention/cta_map.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/thread_map.h"
#include "src/turbomind/utils/cuda_utils.h"

#include 

namespace turbomind::attention {

// Merges split-K (and, on the first pass, context-parallel) partial attention results.
//
// Each CTA handles one (query, head, chunk) as given by ReduceCtaMap. In a CTA, lane `l`
// covers split index `offset_k + l * stride_k`.
//   partial_ML: per-split (M = running max logit, L = sum) pairs, indexed as
//               [cp, query, head, split, 2] with `max_split_cnt` as the split stride.
//   partial_O : per-split outputs of THIS rank only, indexed as [query, head, split, HeadDim].
//
// First == true: the M/L pairs of all CP ranks are merged, so the normalizer is global.
// Only this rank's partial_O is accumulated, and it is scaled by this rank's own M.
// First == false: only this rank's region of partial_ML (offset_r) is read.
//
// When gridDim.z > 1, the merged M/L and unnormalized O are written back into split slot
// `offset_k` for a further pass. Otherwise O is divided by L and written to `out`.
template
__global__ void reduce(T*         out,
                       float*     partial_ML,
                       float*     partial_O,
                       const int* split_cnt_,
                       int        max_split_cnt,
                       int        query_num,
                       int        head_num,
                       float      exp_scale,
                       int        cp_rank,
                       int        stride_k,
                       int        offset_k)
{
    __shared__ float s_out[WarpCnt][HeadDim];
    __shared__ float s_ML[WarpCnt][2];
    __shared__ float s_scale[CTA_K];

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    const int head_idx  = ReduceCtaMap::head_idx();
    const int query_idx = ReduceCtaMap::query_idx();
    const int chunk_idx = ReduceCtaMap::split_idx();

    offset_k *= chunk_idx;
    const int split_cnt = (split_cnt_ != nullptr) ? split_cnt_[query_idx] : 1;
    if (offset_k >= split_cnt) {  // out of bound (whole CTA exits before any __syncthreads)
        return;
    }

    // merge cp and k for the first time and merge k thereafter.
    constexpr int kCpUb     = First ? CP : 1;
    constexpr int kWarpIter = First ? (CP + WarpCnt - 1) / WarpCnt : 1;
    float         ML[kWarpIter][2];

    // frag_M of this cp_rank and lane
    float frag_M = -std::numeric_limits::infinity();

    const int offset_r = cp_rank * query_num * head_num * max_split_cnt * 2;
    const int offset_m = First ? 0 : offset_r;
    const int warp_m   = First ? cp_rank % WarpCnt : 0;
    // cp slot (among the slots loaded below) holding this rank's own M.
    const int cp_m     = First ? cp_rank : 0;

    PRAGMA_UNROLL
    for (int i = 0; i < kWarpIter; ++i) {
        int        cp_i = warp_id + i * WarpCnt;
        int        ki   = lane_id * stride_k + offset_k;
        const bool mask = cp_i < kCpUb && ki < split_cnt;  // cp, q, h, k, 2
        const int  index =
            offset_m + ((cp_i * query_num * head_num + (query_idx * head_num + head_idx)) * max_split_cnt + ki) * 2;

        Array temp_ML = {-std::numeric_limits::infinity(), 0.f};
        if (mask) {
            Load(temp_ML, &partial_ML[index]);
        }
        Store(&ML[i][0], temp_ML);

        // Select by cp slot, not just by warp. When CP > WarpCnt, a warp visits slots
        // warp_id, warp_id + WarpCnt, ...; only slot `cp_m` matches this rank's partial_O.
        frag_M = (mask && cp_i == cp_m) ? ML[i][0] : frag_M;
    }

    float block_M = -std::numeric_limits::infinity();
    float block_L = 0.f;
    PRAGMA_UNROLL
    for (int i = 0; i < kWarpIter; ++i) {
        block_M = fmaxf(block_M, ML[i][0]);
    }

    // Warp-wide max of M.
    PRAGMA_UNROLL
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask));
    }

    // Rescale each L to the common max; empty slots (M == -inf) contribute nothing.
    PRAGMA_UNROLL
    for (int i = 0; i < kWarpIter; ++i) {
        block_L += (ML[i][0] == -std::numeric_limits::infinity()) ?
                       0.0f :
                       exp2f((ML[i][0] - block_M) * exp_scale) * ML[i][1];
    }

    // Warp-wide sum of L.
    PRAGMA_UNROLL
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask);
    }

    // With several CP ranks, the warps hold different cp slots: combine them through shared memory.
    if constexpr (First && CP > 1) {
        if (lane_id == 0) {
            Store(&s_ML[warp_id][0], Array{block_M, block_L});
        }
        __syncthreads();

        if (warp_id == 0 && lane_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < WarpCnt; ++i) {
                block_M = fmaxf(block_M, s_ML[i][0]);
            }

            block_L = 0.f;
            PRAGMA_UNROLL
            for (int i = 0; i < WarpCnt; ++i) {
                // Same empty-slot guard as above: avoids (-inf) - (-inf) = NaN when all warps are empty.
                block_L += (s_ML[i][0] == -std::numeric_limits::infinity()) ?
                               0.0f :
                               exp2f((s_ML[i][0] - block_M) * exp_scale) * s_ML[i][1];
            }

            Store(&s_ML[0][0], Array{block_M, block_L});
        }
        __syncthreads();

        block_M = s_ML[0][0];
        block_L = s_ML[0][1];
    }

    // Further passes follow: write the merged M/L into this chunk's slot of this rank's region.
    if (gridDim.z > 1 && warp_id == 0) {
        int        ki    = lane_id * stride_k + offset_k;
        const bool mask  = ki < split_cnt;  // q, h, k, 2
        const int  index = offset_r + ((query_idx * head_num + head_idx) * max_split_cnt + ki) * 2;
        if (mask) {
            Store(&partial_ML[index], Array{block_M, block_L});
        }
    }

    // Per-split weight for this rank's partial_O. Normalize by L only on the final pass.
    if (warp_id == warp_m) {
        const float divisor = gridDim.z == 1 ? block_L : 1.0f;
        s_scale[lane_id] =
            frag_M == -std::numeric_limits::infinity() ? 0.0f : exp2f((frag_M - block_M) * exp_scale) / divisor;
    }

    __syncthreads();

    // HeadDim / WARP_SIZE
    // 128     -> 4
    // 64, 192 -> 2
    constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2;

    using Map = RakedThreadMap;
    static_assert(Map::kIterS == 1);

    constexpr int C = Map::kIterC;

    using Vec = Array;

    Vec accu_O[C]{};
    Vec frag_O[C];

    const int2 d = Map::get_offset(warp_id, lane_id);

    auto for_each = [&](auto fn) {
        const int ki = d.y;
        PRAGMA_UNROLL
        for (int c = 0; c < C; ++c) {
            const int di = d.x + c * Map::kDeltaC;
            fn(c, ki, di);
        }
    };

    // Weighted accumulation of partial_O over the CTA_K splits of this chunk.
    PRAGMA_UNROLL
    for (int k = 0; k < CTA_K; k += WarpCnt) {
        for_each([&](int c, int ki, int di) {
            using namespace ops;
            ki += k;
            const int  split_idx = offset_k + stride_k * ki;
            const bool mask      = split_idx < split_cnt;
            const int  index     = (query_idx * head_num + head_idx) * max_split_cnt + split_idx;
            const int  offset    = index * HeadDim + di;
            if (mask) {
                Load(frag_O[c], &partial_O[offset]);
                accu_O[c] = accu_O[c] + frag_O[c] * s_scale[ki];
            }
        });
    }

    for_each([&](int c, int ki, int di) {
        Store(&s_out[ki][di], accu_O[c]);  //
    });

    // Tree reduction of the per-row accumulators in shared memory.
    PRAGMA_UNROLL
    for (int w = WarpCnt / 2; w > 0; w /= 2) {
        __syncthreads();
        for_each([&](int c, int ki, int di) {
            using namespace ops;
            if (ki < w) {
                (Vec&)s_out[ki][di] = (Vec&)s_out[ki][di] + (Vec&)s_out[w + ki][di];
            }
        });
    }

    for_each([&](int c, int ki, int di) {
        if (ki == 0) {
            if (gridDim.z == 1) {
                const int offset = (query_idx * head_num + head_idx) * HeadDim + di;
                Store(&out[offset], cast((Vec&)s_out[ki][di]));
            }
            else {
                const int offset = ((query_idx * head_num + head_idx) * max_split_cnt + offset_k) * HeadDim + di;
                Store(&partial_O[offset], (Vec&)s_out[ki][di]);
            }
        }
    });
}

// Host launcher for the split-K / context-parallel attention reduction.
//
// The first launch (First = true) merges the M/L statistics of all `cp_size` ranks and up to
// CTA_K splits per CTA. While more than CTA_K partial results remain, further launches
// (First = false) merge CTA_K of them at a time, multiplying the stride by CTA_K each pass.
// Presumably the grid's z extent reaches 1 once max_split_cnt <= CTA_K, and that pass
// writes `out` — TODO confirm against ReduceCtaMap::get_grid_shape.
// `partial_len` is passed to the kernel as its split stride; the local `max_split_cnt`
// only shrinks to size each pass's grid.
template
void invokeReduceV3(T*           out,
                    float*       partial_ML,
                    float*       partial_O,
                    const int*   split_cnt,
                    int          partial_len,
                    int          max_split_cnt,
                    int          cp_size,
                    int          cp_rank,
                    int          query_num,
                    int          head_num,
                    float        exp_scale,
                    cudaStream_t stream)
{
    constexpr int CTA_K = 32;  // warp size

    // Must match the kernel's s_out[WarpCnt][HeadDim], s_ML[WarpCnt][2] and s_scale[CTA_K].
    constexpr int    kWarpCnt  = 4;
    constexpr size_t kSmemSize = sizeof(float) * (kWarpCnt * HeadDim + kWarpCnt * 2 + CTA_K);
    static_assert(kSmemSize < (48 << 10), "shared memory usage exceeds 48KB per block");

    // The kernel adds cp_rank offsets itself, so hand it the start of cp_rank 0's region.
    partial_ML -= cp_rank * query_num * head_num * partial_len * 2;  // begin address of cp_rank0

    auto invoke = [&](auto cp, auto is_first, int stride_k) {
        const dim3 block = kWarpCnt * WARP_SIZE;
        const dim3 grid  = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K);

        reduce<<>>(  //
            out,
            partial_ML,
            partial_O,
            split_cnt,
            partial_len,
            query_num,
            head_num,
            exp_scale,
            cp_rank,
            stride_k,
            stride_k * CTA_K);

        sync_check_cuda_error();
    };

    // Maps the runtime cp_size to a compile-time CP argument (powers of two up to 32).
    auto dispatch_cp = [&](int stride_k, auto is_first) {
        switch (cp_size) {
#define LAUNCH_INVOKE(n)                                                                                               \
    case n:                                                                                                            \
        invoke(std::integral_constant{}, is_first, stride_k);                                                  \
        break;
            LAUNCH_INVOKE(1);
            LAUNCH_INVOKE(2);
            LAUNCH_INVOKE(4);
            LAUNCH_INVOKE(8);
            LAUNCH_INVOKE(16);
            LAUNCH_INVOKE(32);
            default:
                TM_CHECK(false) << "reduce does not support cp_size = " << cp_size;
#undef LAUNCH_INVOKE
        }
    };

    int stride_k = 1;

    dispatch_cp(stride_k, std::true_type{});
    // Each pass leaves one result per CTA_K-sized chunk; keep reducing until one chunk remains.
    while (max_split_cnt > CTA_K) {
        max_split_cnt = (max_split_cnt + CTA_K - 1) / CTA_K;
        stride_k *= CTA_K;
        dispatch_cp(stride_k, std::false_type{});
    }
}

// Explicit instantiations of invokeReduceV3 for each supported head dimension and output type.
#define INSTANTIATE_invokeReduceV3(dim, type)                                                                          \
    template void invokeReduceV3(type * out,                                                                      \
                                      float*       partial_ML,                                                         \
                                      float*       partial_O,                                                          \
                                      const int*   split_cnt,                                                          \
                                      int          partial_len,                                                        \
                                      int          max_split_cnt,                                                      \
                                      int          cp_size,                                                            \
                                      int          cp_rank,                                                            \
                                      int          query_num,                                                          \
                                      int          head_num,                                                           \
                                      float        exp_scale,                                                          \
                                      cudaStream_t stream);

INSTANTIATE_invokeReduceV3(64, half);
INSTANTIATE_invokeReduceV3(128, half);
INSTANTIATE_invokeReduceV3(192, half);
INSTANTIATE_invokeReduceV3(256, half);
INSTANTIATE_invokeReduceV3(576, half);

// bf16 variants are only built when ENABLE_BF16 is set.
#if ENABLE_BF16
INSTANTIATE_invokeReduceV3(64, nv_bfloat16);
INSTANTIATE_invokeReduceV3(128, nv_bfloat16);
INSTANTIATE_invokeReduceV3(192, nv_bfloat16);
INSTANTIATE_invokeReduceV3(256, nv_bfloat16);
INSTANTIATE_invokeReduceV3(576, nv_bfloat16);
#endif

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/reduce.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "cta_map.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/thread_map.h"
#include 
#include 
#include 

namespace turbomind::attention {

// Reduces split-K (and context-parallel) partial attention results into `out`.
// partial_ML holds per-split (max, sum) pairs and partial_O per-split outputs; `partial_len`
// is the allocated split dimension of those buffers, and `max_split_cnt` the largest number
// of splits actually used. Work is enqueued asynchronously (see reduce.cu for the pass structure).
template
void invokeReduceV3(T*           out,
                    float*       partial_ML,
                    float*       partial_O,
                    const int*   split_cnt,
                    int          partial_len,
                    int          max_split_cnt,
                    int          cp_size,
                    int          cp_rank,
                    int          query_num,
                    int          head_num,
                    float        exp_scale,
                    cudaStream_t stream);
}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/reference.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "reference.h"
#include "src/turbomind/kernels/attention/rotary_embedding.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"

namespace turbomind {

template
__global__ void
createCausalMasks(T* mask, const int* q_lens, const int* k_lens, int64_t max_q_len, int64_t max_k_len, int window_size)
{
    // One CTA per batch entry (blockIdx.x). Its threads stride over that entry's
    // [max_q_len, max_k_len] tile, writing 1 for visible (query, key) pairs and 0 otherwise.
    // Queries are aligned to the end of the keys. A pair is visible when both indices are in
    // range and the query lies 0..window_size-1 positions after the key.
    // Null length pointers mean "use the max length" for every entry.
    const int     b      = blockIdx.x;
    const int64_t len_q  = (q_lens == nullptr) ? max_q_len : q_lens[b];
    const int64_t len_k  = (k_lens == nullptr) ? max_k_len : k_lens[b];
    const int     shift  = len_k - len_q;  // keys that precede the first query
    const int64_t tile   = max_q_len * max_k_len;
    T*            mask_b = mask + b * tile;

    for (int64_t pos = threadIdx.x; pos < tile; pos += blockDim.x) {
        const int row  = pos / max_k_len;
        const int col  = pos % max_k_len;
        const int dist = row - (col - shift);

        bool visible = row < len_q && col < len_k;
        if (visible) {
            visible = 0 <= dist && dist < window_size;
        }
        mask_b[pos] = visible ? T{1.} : T{0.};
    }
}

// [B, H, S, D]
// Applies rotary position embedding in place to a key cache indexed as
// [batch, head, max_k_len, head_dim] (see `idx`). Grid: x = token, y = head, z = batch.
// Each thread handles kVecSize = 2 consecutive elements and strides over head_dim.
// Token `ti` uses position `history + ti`, with history fixed at 0. The rotation itself is
// delegated to RotaryEmbedding (defined elsewhere).
template
__global__ void
applyRotaryEmbedding(T* k_cache, int max_k_len, int head_num, int head_dim, float rope_base, int rope_dim)
{
    const int    ti = blockIdx.x;
    const size_t hi = blockIdx.y;
    const size_t bi = blockIdx.z;

    constexpr int kVecSize = 2;
    const int     history  = 0;

    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
        const size_t idx =
            bi * head_num * max_k_len * head_dim + hi * max_k_len * head_dim + (history + ti) * head_dim + d;

        Array vec_K;

        Load(vec_K, &k_cache[idx]);

        RotaryEmbedding rope(rope_base, rope_dim, history + ti, {d, 0});

        rope.apply(vec_K);

        Store(&k_cache[idx], vec_K);
    }
}

// Launches applyRotaryEmbedding with one 128-thread CTA per (token, head, batch) triple,
// covering every position up to max_k_len.
template
void invokeApplyRotaryEmbedding(T*           k_cache,
                                int          max_k_len,
                                int          head_num,
                                int          head_dim,
                                float        rope_base,
                                int          rope_dim,
                                int          batch_size,
                                cudaStream_t stream)
{
    int  threads = 128;
    dim3 blocks(max_k_len, head_num, batch_size);

    applyRotaryEmbedding<<>>(k_cache, max_k_len, head_num, head_dim, rope_base, rope_dim);
}

// Explicit instantiations: half always, bf16 only when ENABLE_BF16 is set.
template void invokeApplyRotaryEmbedding(half*        k_cache,
                                         int          max_k_len,
                                         int          head_num,
                                         int          head_dim,
                                         float        rope_base,
                                         int          rope_dim,
                                         int          batch_size,
                                         cudaStream_t stream);
#if ENABLE_BF16
template void invokeApplyRotaryEmbedding(nv_bfloat16* k_cache,
                                         int          max_k_len,
                                         int          head_num,
                                         int          head_dim,
                                         float        rope_base,
                                         int          rope_dim,
                                         int          batch_size,
                                         cudaStream_t stream);
#endif

// Splits a fused QKV tensor into Q and the K/V caches for the reference implementation.
// Grid: x = query token, y = head, z = batch. Every block writes its Q head; only blocks with
// hi < kv_head_num also write K and V. Optional bias ([Q; K; V] heads) is added.
// When rope_theta is non-zero, rotary embedding is applied to Q and K at position history + ti,
// where history = max_k_len - max_q_len, so new tokens land at the end of the key cache.
template
__global__ void processQKV(T*       q_out,     // [B, H, s, D]
                           T*       k_cache,   // [B, H, S, D]
                           T*       v_cache,   // [B, H, S, D]
                           const T* qkv,       // [B, s, H, D]
                           const T* qkv_bias,  // [Q; K; V]
                           int      max_q_len,
                           int      max_k_len,
                           int      head_num,
                           int      head_dim,
                           int      kv_head_num,
                           float    rope_theta,
                           int      rope_dim)
{
    const int    ti = blockIdx.x;
    const size_t hi = blockIdx.y;
    const size_t bi = blockIdx.z;

    const int history = max_k_len - max_q_len;

    size_t qkv_head_num = head_num + 2 * kv_head_num;

    // Per-token row of the fused input: head_num Q heads, then kv_head_num K heads, then V heads.
    auto q = qkv + (bi * max_q_len + ti) * qkv_head_num * head_dim;
    auto k = q + head_num * head_dim;
    auto v = k + kv_head_num * head_dim;

    // Bias pointers are already offset to head `hi` within the Q, K and V sections.
    auto q_bias = qkv_bias ? qkv_bias + hi * head_dim : nullptr;
    auto k_bias = qkv_bias ? q_bias + head_num * head_dim : nullptr;
    auto v_bias = qkv_bias ? k_bias + kv_head_num * head_dim : nullptr;

    constexpr int kVecSize = 2;

    using namespace ops;

    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
        const auto         idx = bi * head_num * max_q_len * head_dim + hi * max_q_len * head_dim + ti * head_dim + d;
        Array vec;
        Ldg(vec, &q[hi * head_dim + d]);
        if (qkv_bias) {
            Array bias;
            Load(bias, &q_bias[d]);
            vec = vec + bias;
        }
        if (rope_theta) {
            RotaryEmbedding rope(rope_theta, rope_dim, history + ti, {d, 0});
            rope.apply(vec);
        }

        Store(&q_out[idx], vec);
    }

    if (hi >= kv_head_num) {
        return;
    }

    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
        const auto idx =
            bi * kv_head_num * max_k_len * head_dim + hi * max_k_len * head_dim + (history + ti) * head_dim + d;
        Array vec_K;
        Array vec_V;
        Ldg(vec_K, &k[hi * head_dim + d]);
        Ldg(vec_V, &v[hi * head_dim + d]);
        if (qkv_bias) {
            Array bias_K;
            Array bias_V;
            Load(bias_K, &k_bias[d]);
            Load(bias_V, &v_bias[d]);
            vec_K = vec_K + bias_K;
            vec_V = vec_V + bias_V;
        }
        if (rope_theta) {
            RotaryEmbedding rope(rope_theta, rope_dim, history + ti, {d, 0});
            rope.apply(vec_K);
        }
        Store(&k_cache[idx], vec_K);
        Store(&v_cache[idx], vec_V);
    }
}

template
__global__ void RepeatKVKernel(T*       keys,
                               T*       vals,
                               const T* k_cache,
                               const T* v_cache,
                               int      head_num,
                               int      max_k_len,
                               int      head_dim,
                               int      kv_head_num,
                               int      n_reps)
{
    // Replicates each KV head n_reps times: query head `head` copies from KV head `head / n_reps`.
    // Grid: x = token, y = query head, z = batch. Threads stride over head_dim.
    // Both source and destination are indexed as [batch, head, max_k_len, head_dim].
    const int64_t token   = blockIdx.x;
    const int64_t head    = blockIdx.y;
    const int64_t batch   = blockIdx.z;
    const int64_t kv_head = head / n_reps;

    const int64_t head_stride = (int64_t)max_k_len * head_dim;
    const int64_t dst_base    = (batch * head_num + head) * head_stride + token * head_dim;
    const int64_t src_base    = (batch * kv_head_num + kv_head) * head_stride + token * head_dim;

    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        keys[dst_base + d] = k_cache[src_base + d];
        vals[dst_base + d] = v_cache[src_base + d];
    }
}

// Creates the cuBLAS handle used by this reference implementation and binds it to `stream`.
// NOTE(review): the cuBLAS status codes are not checked.
template
Reference::Reference(cudaStream_t stream): stream_(stream)
{
    cublasCreate(&cublas_);
    cublasSetStream(cublas_, stream);
}

// (Re)sizes the reference buffers for the given problem shape, logs the shape and the QK
// buffer size, and rebuilds the mask. The mask is built without per-sequence lengths
// (nullptr), so every batch entry uses max_q_len / max_k_len.
// The shape parameters are cached for Execute().
template
void Reference::Reshape(size_t max_q_len,
                           size_t max_k_len,
                           size_t head_num,
                           size_t head_dim,
                           size_t kv_head_num,
                           size_t batch_size,
                           int    window_size)
{
    std::cout << max_q_len << " " << max_k_len << " " << head_num << " " << head_dim << " " << batch_size << "\n";

    q_.resize(batch_size * head_num * max_q_len * head_dim);
    mask_.resize(batch_size * max_q_len * max_k_len);

    std::cout << "size of QK buf: "
              << ((batch_size * head_num * max_q_len * max_k_len * sizeof(float)) / float(1 << 30)) << " GB\n";
    qk_.resize(batch_size * head_num * max_q_len * max_k_len);
    pr_.resize(batch_size * head_num * max_q_len * max_k_len);
    out_.resize(batch_size * max_q_len * head_num * head_dim);

    // K/V are expanded to head_num heads (see RepeatKVKernel), not kv_head_num.
    keys_.resize(batch_size * head_num * max_k_len * head_dim);
    vals_.resize(batch_size * head_num * max_k_len * head_dim);

    cudaStreamSynchronize(0);

    createCausalMasks<<>>(
        mask_.data().get(), nullptr, nullptr, max_q_len, max_k_len, window_size);

    max_q_len_   = max_q_len;
    max_k_len_   = max_k_len;
    head_num_    = head_num;
    head_dim_    = head_dim;
    kv_head_num_ = kv_head_num;
    batch_size_  = batch_size;
    window_size_ = window_size;
}

template
void Reference::Execute(
    T* output, T* k_cache, T* v_cache, const T* qkv, const T* qkv_bias, const T* sinks, float rope_base, int rope_dim)
{
    {
        int  threads = 128;
        dim3 blocks(max_q_len_, head_num_, batch_size_);
        cudaDeviceSynchronize();

        processQKV<<>>(q_.data().get(),  //
                                                    k_cache,
                                                    v_cache,
                                                    qkv,
                                                    qkv_bias,
                                                    max_q_len_,
                                                    max_k_len_,
                                                    head_num_,
                                                    head_dim_,
                                                    kv_head_num_,
                                                    rope_base,
                                                    rope_dim);

        // std::cout << head_num_ << " " << kv_head_num_ << " " << head_dim_ / kv_head_num_ << "\n";

        blocks.x = max_k_len_;
        RepeatKVKernel<<>>(keys_.data().get(),
                                                        vals_.data().get(),
                                                        k_cache,
                                                        v_cache,
                                                        head_num_,
                                                        max_k_len_,
                                                        head_dim_,
                                                        kv_head_num_,
                                                        head_num_ / kv_head_num_);

        cudaDeviceSynchronize();
    }

    const cudaDataType data_type = std::is_same_v ? CUDA_R_16F : CUDA_R_16BF;

    float alpha = 1.f / sqrtf((float)head_dim_);
    float beta  = 0.f;
    cublasGemmStridedBatchedEx(cublas_,
                               CUBLAS_OP_T,              // trans A
                               CUBLAS_OP_N,              // trans B
                               max_k_len_,               // m
                               max_q_len_,               // n
                               head_dim_,                // k
                               &alpha,                   // alpha
                               keys_.data().get(),       // A
                               data_type,                // A type
                               head_dim_,                // lda
                               max_k_len_ * head_dim_,   // strideA
                               q_.data().get(),          // B
                               data_type,                // B type
                               head_dim_,                // ldb
                               max_q_len_ * head_dim_,   // stride B
                               &beta,                    // beta
                               qk_.data().get(),         // C
                               CUDA_R_32F,               // C type
                               max_k_len_,               // ldc
                               max_q_len_ * max_k_len_,  // stride C
                               batch_size_ * head_num_,  // batch count
                               CUBLAS_COMPUTE_32F,       // compute type
                               CUBLAS_GEMM_DEFAULT);

    MaskedSoftmaxParam params{};
    params.attention_score = pr_.data().get();
    params.qk              = qk_.data().get();
    params.attention_mask  = mask_.data().get();
    params.batch_size      = batch_size_;
    params.q_length        = max_q_len_;
    params.k_length        = max_k_len_;
    params.num_heads       = head_num_;
    params.sinks           = sinks;
    invokeMaskedSoftmax(params, stream_);

    alpha = 1.f;
    cublasGemmStridedBatchedEx(cublas_,
                               CUBLAS_OP_N,              // trans A
                               CUBLAS_OP_N,              // trans B
                               head_dim_,                // m
                               max_q_len_,               // n
                               max_k_len_,               // k
                               &alpha,                   // alpha
                               vals_.data().get(),       // A
                               data_type,                // A type
                               head_dim_,                // lda
                               max_k_len_ * head_dim_,   // strideA
                               pr_.data().get(),         // B
                               data_type,                // B type
                               max_k_len_,               // ldb
                               max_q_len_ * max_k_len_,  // stride B
                               &beta,                    // beta
                               out_.data().get(),        // C [b, h, q, d]
                               data_type,                // C type
                               head_dim_,                // ldc
                               max_q_len_ * head_dim_,   // stride C
                               batch_size_ * head_num_,  // batch count
                               CUBLAS_COMPUTE_32F,       // compute type
                               CUBLAS_GEMM_DEFAULT);

    // [B, H, Q, D] -> [B, Q, H, D]
    invokeTransposeAttentionOutRemovePadding(out_.data().get(),
                                             output,
                                             batch_size_ * max_q_len_,
                                             batch_size_,
                                             max_q_len_,
                                             head_num_,
                                             head_dim_,
                                             nullptr,
                                             nullptr,
                                             0,
                                             stream_);
}

template class Reference;

#if ENABLE_BF16
template class Reference;
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/reference.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include 

#include "src/turbomind/kernels/unfused_attention_kernels.h"

namespace turbomind {

template
void invokeApplyRotaryEmbedding(T*           k_cache,
                                int          max_k_len,
                                int          head_num,
                                int          head_dim,
                                float        rope_base,
                                int          rope_dim,
                                int          batch_size,
                                cudaStream_t stream = {});

template
class Reference {
public:
    explicit Reference(cudaStream_t stream);

    void Reshape(size_t max_q_len,
                 size_t max_k_len,
                 size_t head_num,
                 size_t head_dim,
                 size_t kv_head_num,
                 size_t batch_size,
                 int    window_size);

    void Execute(T*       output,
                 T*       k_cache,
                 T*       v_cache,
                 const T* qkv,
                 const T* qkv_bias,
                 const T* sinks,
                 float    rope_base,
                 int      rope_dim);

    const float* qk() const
    {
        return qk_.data().get();
    }

    const T* pr() const
    {
        return pr_.data().get();
    }

    const T* mask() const
    {
        return mask_.data().get();
    }

private:
    cudaStream_t                    stream_;
    cublasHandle_t                  cublas_;
    thrust::universal_vector     mask_;
    thrust::universal_vector qk_;
    thrust::universal_vector     pr_;
    thrust::universal_vector     q_;
    thrust::universal_vector     out_;

    thrust::universal_vector keys_;
    thrust::universal_vector vals_;

    int max_q_len_{};
    int max_k_len_{};
    int head_num_{};
    int head_dim_{};
    int kv_head_num_{};
    int batch_size_{};
    int window_size_{};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/registrar.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include "src/turbomind/kernels/attention/kernel_impl.h"

namespace turbomind::attention {

class Collector {
public:
    template
    void add()
    {
        kernels_.emplace_back(std::make_unique>());
        // std::cout << "add kernel: " << to_string(kernels_.back()->desc()) << std::endl;
    }

    std::vector> release()
    {
        return std::move(kernels_);
    }

private:
    std::vector> kernels_;
};

using RegisterFn = std::function;

inline std::vector& gKernelFactories()
{
    static std::vector v;
    return v;
}

struct Registrar {
    explicit Registrar(RegisterFn fn)
    {
        gKernelFactories().push_back(std::move(fn));
    }
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/registry.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/attention/registry.h"

#include 
#include 
#include 
#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/kernels/attention/arch.h"
#include "src/turbomind/kernels/attention/registrar.h"
#include "src/turbomind/kernels/core/math.h"

namespace turbomind::attention {

namespace {

constexpr float kMaxWasteRatio = 1.f;

}  // namespace

Registry::Registry(std::shared_ptr device_prop):
    device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10}
{
    for (auto& register_fn : gKernelFactories()) {
        Collector collector;
        register_fn(collector);
        for (auto& k : collector.release()) {
            Add(std::move(k));
        }
    }
}

bool Registry::Add(std::unique_ptr kernel)
{
    bool is_valid = true;

    if (!arch::is_arch_compatible(kernel->arch(), arch_)) {
        is_valid = false;
    }

    if ((int)device_prop_->sharedMemPerBlockOptin < kernel->smem_size()) {
        is_valid = false;
    }

    if (is_valid) {
        ptrs_.push_back(kernels_.emplace_back(std::move(kernel)).get());
    }

    return is_valid;
}

const Kernel* Registry::Find(const AttnDesc& desc) const
{
    const int threshold = static_cast(kMaxWasteRatio * desc.query_group_sz);

    const Kernel*             best = nullptr;
    std::tuple cost{};

    for (const auto* k : ptrs_) {
        const auto& d = k->desc();
        if (d.mode != desc.mode || d.head_dim != desc.head_dim  //
            || d.data_type != desc.data_type || d.kv_quant != desc.kv_quant) {
            continue;
        }
        if (desc.mode == AttnDesc::kDecoding) {
            const int ctas  = cdiv(desc.query_group_sz, d.qh);
            const int waste = d.qh * ctas - desc.query_group_sz;

            const auto v = std::make_tuple(waste > threshold, ctas, waste);
            if (!best || v < cost) {
                best = k;
                cost = v;
            }
        }
        else {  // attention, return on first match
            return k;
        }
    }
    return best;
}

Registry& Registry::instance()
{
    struct DeviceState {
        std::unique_ptr registry;
        std::once_flag            flag;
    };

    static std::vector> states = [] {
        int count{};
        TM_CHECK_EQ(cudaGetDeviceCount(&count), cudaSuccess);
        std::vector> vec(count);
        for (auto& s : vec) {
            s = std::make_unique();
        }
        return vec;
    }();

    int device_id{};
    TM_CHECK_EQ(cudaGetDevice(&device_id), cudaSuccess);

    auto& state = *states.at(device_id);

    std::call_once(state.flag, [&]() {
        auto prop = std::make_shared();
        TM_CHECK_EQ(cudaGetDeviceProperties(prop.get(), device_id), cudaSuccess);
        state.registry = std::make_unique(std::move(prop));
    });

    return *TM_CHECK_NOTNULL(state.registry);
}

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/registry.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include "src/turbomind/kernels/attention/kernel_impl.h"

namespace turbomind::attention {

class Registry {
public:
    explicit Registry(std::shared_ptr device_prop);

    template
    [[maybe_unused]] bool Add()
    {
        return Add(std::make_unique>());
    }

    const Kernel* Find(const AttnDesc& desc) const;

    [[nodiscard]] const std::vector& kernels() const
    {
        return ptrs_;
    }

    int sm_count() const noexcept
    {
        return device_prop_->multiProcessorCount;
    }

    static Registry& instance();

private:
    bool Add(std::unique_ptr kernel);

    std::shared_ptr      device_prop_;
    int                                  arch_;
    std::vector> kernels_;
    std::vector                 ptrs_;
};

}  // namespace turbomind::attention


================================================
FILE: src/turbomind/kernels/attention/rotary_embedding.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/models/llama/llama_rope.h"

namespace turbomind {

template
__device__ void init_default(Array& inv_freq, int idx, RopeKernelParam& param)
{
    auto scale_factor = param.scale_factor;
    auto inv_factor   = param.inv_factor;
    PRAGMA_UNROLL
    for (int i = 0; i < N; i += 2) {
        inv_freq[i / 2] = inv_factor * exp2f((idx + i) * scale_factor);
    }
}

template
__device__ void init_yarn(Array& inv_freq, int idx, RopeKernelParam& param)
{
    auto scale_factor            = param.scale_factor;
    auto inv_factor              = param.inv_factor;
    auto ramp_inv_factor_div_2   = param.yarn.ramp_inv_factor_div_2;
    auto ramp_inv_factor_mul_min = param.yarn.ramp_inv_factor_mul_min;

    PRAGMA_UNROLL
    for (int i = 0; i < N; i += 2) {
        auto freq       = exp2f((idx + i) * scale_factor);
        auto alpha      = (idx + i) * ramp_inv_factor_div_2 - ramp_inv_factor_mul_min;
        alpha           = fmaxf(0.f, fminf(1.f, alpha));
        inv_freq[i / 2] = freq - freq * alpha * (1.f - inv_factor);
    }
}

template
__device__ void init_llama3(Array& inv_freq, int idx, RopeKernelParam& param)
{
    auto scale_factor = param.scale_factor;
    auto inv_factor   = param.inv_factor;
    auto alpha        = param.llama3.alpha;
    auto beta         = param.llama3.beta;

    PRAGMA_UNROLL
    for (int i = 0; i < N; i += 2) {
        auto freq       = exp2f((idx + i) * scale_factor);
        auto smooth     = fmaxf(0.f, fminf(1.f, alpha * freq - beta));
        inv_freq[i / 2] = (1 - smooth) * freq * inv_factor + smooth * freq;
    }
}

template
struct FastRoPE {

    static_assert(N % 2 == 0);

    RopeKernelParam     param_;
    Array inv_freq_;
    bool                is_valid_;
    float               attention_scaling_{1.f};
    int                 idx_;

    typedef void (*Func)(Array&, int, RopeKernelParam&);
    Func fill_func_;

    __device__ FastRoPE(const RopeKernelParam& param, int batch_idx, std::integral_constant): param_(param)
    {

        if (param_.type == RopeType::kDynamic) {
            float base          = param_.base[batch_idx];
            param_.scale_factor = -log2f(base) / param_.dim;
        }
        else if (param_.type == RopeType::kYarn) {
            attention_scaling_ = param_.yarn.attention_factor;
        }
        else if (param_.type == RopeType::kMrope) {
            param_.mrope.position_ids += batch_idx * param_.mrope.stride;
            param_.mrope.position_delta += batch_idx;
            param_.mrope.length += batch_idx;
        }
    }

    __device__ void init(int idx)
    {
        is_valid_ = idx < param_.dim;
        idx_      = idx;
        switch (param_.type) {
            case RopeType::kDefault:
            case RopeType::kLinear:
            case RopeType::kDynamic:
            case RopeType::kMrope:
                init_default(inv_freq_, idx, param_);
                break;
            case RopeType::kYarn:
                init_yarn(inv_freq_, idx, param_);
                break;
            case RopeType::kLlama3:
                init_llama3(inv_freq_, idx, param_);
                break;
        }
    }

    template
    __device__ void apply(Array& x, float timestep)
    {
        if (param_.type == RopeType::kMrope) {
            return apply_mrope(x, timestep);
        }
        // Most models apply rotary embedding in half precision
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 2) {
            float c, s;
            sincosf(timestep * inv_freq_[i / 2], &s, &c);
            s *= attention_scaling_;
            c *= attention_scaling_;
            T tmp0 = (T)c * x[i] - (T)s * x[i + 1];
            T tmp1 = (T)c * x[i + 1] + (T)s * x[i];
            if (is_valid_) {
                x[i]     = tmp0;
                x[i + 1] = tmp1;
            }
        }
    }

    template
    __device__ void apply_mrope(Array& x, float timestep)
    {
        int  tt, th, tw;
        int3 section = param_.mrope.section;
        if (timestep < *param_.mrope.length) {
            const int* t = param_.mrope.position_ids + 3 * (int)timestep;
            tt           = t[0];
            th           = t[1];
            tw           = t[2];
        }
        else {
            tt = th = tw = (int)timestep + (*param_.mrope.position_delta);
        }

        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 2) {
            if (i + idx_ < section.x) {
                timestep = (float)tt;
            }
            else if (i + idx_ < section.y) {
                timestep = (float)th;
            }
            else {
                timestep = (float)tw;
            }
            float c, s;
            sincosf(timestep * inv_freq_[i / 2], &s, &c);
            T tmp0 = (T)c * x[i] - (T)s * x[i + 1];
            T tmp1 = (T)c * x[i + 1] + (T)s * x[i];
            if (is_valid_) {
                x[i]     = tmp0;
                x[i + 1] = tmp1;
            }
        }
    }
};

template
struct RotaryEmbedding {

    static_assert(N % 2 == 0);

    Array cs_;

    bool is_valid_;

    __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
    {
        const int idx = offset.x;
        is_valid_     = idx < dims;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 2) {
            const float2 tmp = get_coefficient(idx + i, dims, base, timestep);
            cs_[i]           = tmp.x;
            cs_[i + 1]       = tmp.y;
        }
    }

    // ! depending on the context, this function may generate different result when inlined
    static __device__ __noinline__ float2 get_coefficient(int idx, int dims, float base, int timestep)
    {
        const float inv_freq = timestep / powf(base, idx / (float)dims);
        float2      cs;
        sincosf(inv_freq, &cs.y, &cs.x);
        return cs;
    }

    template
    __device__ void apply(Array& x)
    {

        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 2) {
            auto tmp0 = (T)cs_[i] * x[i] - (T)cs_[i + 1] * x[i + 1];
            auto tmp1 = (T)cs_[i] * x[i + 1] + (T)cs_[i + 1] * x[i];
            if (is_valid_) {
                x[i]     = (T)tmp0;
                x[i + 1] = (T)tmp1;
            }
        }
    }
};
template
__device__ void ApplyRotaryEmbedding(Array& x, float base, int dims, int ti, int di)
{
    PRAGMA_UNROLL
    for (int d1 = 0; d1 < 2; ++d1) {
        int    d        = d1 * 8 + di;
        float  inv_freq = ti / powf(base, d / (float)dims);
        float2 cs;
        sincosf(inv_freq, &cs.y, &cs.x);
        C x1          = (C)cs.x * (C)x[d1 * 2 + 0] - (C)cs.y * (C)x[d1 * 2 + 1];
        C x2          = (C)cs.x * (C)x[d1 * 2 + 1] + (C)cs.y * (C)x[d1 * 2 + 0];
        x[d1 * 2 + 0] = (T)x1;
        x[d1 * 2 + 1] = (T)x2;
    }
}

template
struct RoPE {
    Array inv_freqs_;

    RoPE() = default;
    __device__ RoPE(float idx, float base, float dims)
    {
        for (int i = 0; i < N; ++i) {
            inv_freqs_[i] = powf(base, idx / dims + (C / dims) * i);
        }
    }

    template
    __device__ void apply(Array& x, float timestep)
    {
        for (int i = 0; i < N; ++i) {
            const float inv_freq = timestep * inv_freqs_[i];
            float2      cs;
            sincosf(inv_freq, &cs.y, &cs.x);
            float tmp0   = cs.x * (float)x[i * 2] - cs.y * (float)x[i * 2 + 1];
            float tmp1   = cs.x * (float)x[i * 2 + 1] + cs.y * (float)x[i * 2];
            x[i * 2]     = (T)tmp0;
            x[i * 2 + 1] = (T)tmp1;
        }
    }
};

struct LogNScaling {

    float scale_;

    __device__ static float get_scale(int seq_len, int max_position_embeddings)
    {
        if (seq_len <= max_position_embeddings) {
            return 1.f;
        }
        else {
            return log2f(seq_len) / log2f(max_position_embeddings);
        }
    }

    __device__ LogNScaling(int seq_len, int max_position_embeddings)
    {
        scale_ = get_scale(seq_len, max_position_embeddings);
    }

    template
    __device__ void apply(Array& x) const
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            x[i] = (T)((float)x[i] * scale_);
        }
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/test_attention.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "attention.h"
#include "block.h"
#include "decoding.h"
#include "kv_cache_utils_v2.h"
#include "src/turbomind/kernels/attention/attention_params.h"
#include "src/turbomind/kernels/attention/reference.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "test_utils.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace turbomind;

// [b, h, s, d] : current -> stride_h=s, stride_s=1, stride_b=hs
// [cu_q, h, d] : qkvgemm -> stride_h=1, stride_s=h, stride_b=0
// [h, cu_s, d] : prefill -> stride_h=s, stride_s=1, stride_b=0

template
struct Config {
    int head_dim_;
    int head_num_;
    int block_len_;

    TM_HOST_DEVICE constexpr int t_bits() const
    {
        if constexpr (std::is_same_v) {
            return 0;
        }
        else {
            return bitsof;
        }
    }

    TM_HOST_DEVICE constexpr int q_bits() const
    {
        return bitsof;
    }

    TM_HOST_DEVICE constexpr int head_dim() const
    {
        return head_dim_;
    }

    TM_HOST_DEVICE int head_num() const
    {
        return head_num_;
    }

    TM_HOST_DEVICE constexpr int block_len() const
    {
        return block_len_;
    }

    TM_HOST_DEVICE constexpr bool is_share_kv() const
    {
        return false;
    }
};

// [S/S, H, S, D] <-> [S/b, H, b, D]
template
void TestBlocks(const thrust::universal_vector& k_cache,        // [B, H, S, D]
                const thrust::universal_vector& v_cache,        // [B, H, S, D]
                thrust::universal_vector&    blocks,         // block data
                thrust::universal_vector&   k_ptrs,         // block ptrs
                thrust::universal_vector&     cu_block_cnts,  // cumulative block counts
                const size_t                       head_num,
                const size_t                       head_dim,
                const size_t                       block_seq_len,
                const size_t                       batch_size,
                const int                          rope_dim,
                int                                quant_policy)
{
    const size_t seq_len  = k_cache.size() / (head_dim * head_num * batch_size);
    const size_t n_blocks = (seq_len + block_seq_len - 1) / block_seq_len;

    Config config{(int)head_dim, (int)head_num, (int)block_seq_len};
    block::Layout  layout{config};

    dump(layout);

    const size_t kHSD = head_num * seq_len * head_dim;

    std::cout << "batch_size = " << batch_size << ", seq_len = " << seq_len << ", block_size = " << block_seq_len
              << ", block_num = " << n_blocks << "\n";

    thrust::universal_vector kv_cache(k_cache.size() * 2);  // [B, 2, H, S, D]

    {  // interleave K/V
        auto k_src = k_cache.begin();
        auto v_src = v_cache.begin();
        auto dst   = kv_cache.begin();
        for (size_t i = 0; i < batch_size; ++i) {
            dst = thrust::copy_n(k_src, kHSD, dst);
            dst = thrust::copy_n(v_src, kHSD, dst);
            k_src += kHSD;
            v_src += kHSD;
        }
    }

    // const int kHsD = head_num * block_seq_len * head_dim;

    // [B, S/s, 2, H, s, D]
    // blocks.resize(batch_size * n_blocks * 2 * kHsD);
    blocks.resize(batch_size * n_blocks * layout.block_size(1));
    thrust::fill(blocks.begin(), blocks.end(), NAN);
    k_ptrs.resize(batch_size * n_blocks + 1);  // +1 padding

    std::vector idxs(batch_size * n_blocks);
    std::iota(idxs.begin(), idxs.end(), 0);

    std::random_device rd;
    std::mt19937       g(rd());
    std::shuffle(idxs.begin(), idxs.end(), g);

    for (size_t i = 0; i < idxs.size(); ++i) {
        // k_ptrs[i] = blocks.data().get() + idxs[i] * 2 * kHsD;
        k_ptrs[i] = blocks.data().get() + idxs[i] * layout.block_size(1);
    }

    thrust::universal_vector seq_lens(batch_size);
    thrust::universal_vector cu_seq_lens(batch_size + 1);
    thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);
    for (size_t i = 0; i <= batch_size; ++i) {
        cu_seq_lens[i] = i * seq_len;
    }

    std::vector n_blocks_vec(batch_size + 1, n_blocks);
    cu_block_cnts.resize(batch_size + 1);
    std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);

    cudaDeviceSynchronize();

    // [B, 2H, S, D] -> [B, S/s] x [2H, s, D]
    for (int i = 0; i < 1; ++i) {
        // (B, 2, H, S, D) -> blocks
        invokeProcessKV_v2(k_ptrs.data().get(),
                           kv_cache.data().get(),
                           kv_cache.data().get() + head_num * seq_len * head_dim,
                           (T*)nullptr,
                           (T*)nullptr,
                           cu_seq_lens.data().get(),
                           cu_seq_lens.data().get(),
                           cu_block_cnts.data().get(),
                           RopeKernelParam{},
                           2 * head_num * seq_len,
                           0,
                           seq_len,
                           1,
                           block_seq_len,
                           0,  // layer_id
                           0,  // cp_rank
                           1,  // cp_size
                           seq_len,
                           head_num,
                           head_dim,
                           batch_size,
                           quant_policy);
    }

    thrust::universal_vector kv_cache_2(kv_cache.size());

    // round trip test
    for (int i = 0; i < 1; ++i) {
        // kv_cache_2 is [B, 2, H, S, D]
        invokeFlattenKV_v2(kv_cache_2.data().get(),
                           kv_cache_2.data().get() + head_num * seq_len * head_dim,
                           k_ptrs.data().get(),
                           cu_seq_lens.data().get(),
                           cu_block_cnts.data().get(),
                           RopeKernelParam{},
                           2 * head_num * seq_len,
                           0,
                           seq_len,
                           1,
                           block_seq_len,
                           0,  // layer_id
                           0,  // cp_rank
                           1,  // cp_size
                           seq_len,
                           head_num,
                           head_dim,
                           batch_size,
                           quant_policy);
    }

    cudaDeviceSynchronize();

    if (0) {
        std::cout << ">>> Compare\n";
        Compare(
            kv_cache_2.data().get(), kv_cache.data().get(), head_dim, head_dim, batch_size * 2 * head_num * seq_len, 0);
        std::cout << "<<< Compare\n";
    }
}

double get_memory_bandwidth()  // -> GB/s
{
    int clock_rate_khz{};
    int bus_width_bits{};
    cudaDeviceGetAttribute(&clock_rate_khz, cudaDevAttrMemoryClockRate, 0);
    cudaDeviceGetAttribute(&bus_width_bits, cudaDevAttrGlobalMemoryBusWidth, 0);
    return 2. * (double)clock_rate_khz / 1e6 * (double)bus_width_bits / 8.;
}

#define KV_INT8 0

#define KV_INT4 0

#define DECODING 0

#define SINK 5

// Benchmark + correctness test for the attention kernels.
//
// Builds random Q/K/V inputs and a blocked KV cache (TestBlocks), computes a reference
// result with `Reference`, then runs the kernel under test kTestIter times:
// dispatchDecoding when DECODING, otherwise invokeProcessKV_v2_ + invokeFlattenKV_v2_
// followed by dispatchAttention. Only the dispatch call is bracketed by CUDA events.
// Afterwards it prints per-iteration time / bandwidth / TFLOPS, checks run-to-run
// consistency, compares the output with the reference, and compares the KV cache written
// by the kernel (flattened back by invokeFlattenKV_v2) with the reference cache.
//
// Returns -1 as soon as cudaGetLastError reports an error, 0 otherwise. The Compare calls
// only print statistics; they do not influence the return value.
template
int test_attention()
{
    AttentionParams params{};

    constexpr size_t kHeadDim    = 128;
    // Larger than every kContextLen configured below, so no sliding window is in effect.
    constexpr int    kWindowSize = 128 << 20;

#if DECODING
    // constexpr size_t kHeadNum   = 32;
    // constexpr size_t kBatchSize = 64;
    constexpr size_t kHeadNum   = 64;
    constexpr size_t KvHeadNum  = kHeadNum / 8;
    constexpr size_t kBatchSize = 256;
    constexpr size_t kInputLen  = 1;

    constexpr size_t kSequenceLen = 1000;
    // constexpr size_t kSequenceLen = 4095;
    // constexpr size_t kSequenceLen = 511;
    // constexpr size_t kSequenceLen = 2047;
    // constexpr size_t kSequenceLen = 4095;
    // constexpr size_t kSequenceLen = 8 * 1024 - 1;
    // constexpr size_t kSequenceLen = 32767;
    // constexpr size_t kSequenceLen = 65535;
    // constexpr size_t kSequenceLen = 131071;
    // constexpr size_t kSequenceLen = 200000;
    // constexpr size_t kSequenceLen = 262143;
    // constexpr size_t kSequenceLen = (1 << 20) - 1;  // 1M
    // constexpr size_t kSequenceLen = (1 << 22) - 1;  // 4M
    // constexpr size_t kSequenceLen = (1 << 24) - 1;  // 16M
    // constexpr int kSequenceLen = 2047;
    constexpr int kBlockSz   = 64;
    constexpr int kMaxSplitK = 128;
#else

    // append
    // constexpr size_t kHeadNum     = 32;
    // constexpr size_t KvHeadNum    = kHeadNum;
    // constexpr size_t kBatchSize   = 1;
    // constexpr size_t kInputLen    = 128;
    // constexpr size_t kSequenceLen = 65536;
    // constexpr int    kMaxSplitK   = 128;

    // constexpr size_t kHeadNum     = 1;
    // constexpr size_t KvHeadNum    = kHeadNum;
    // constexpr size_t kBatchSize   = 1;
    // constexpr size_t kInputLen    = 64;
    // constexpr size_t kSequenceLen = 65536;
    // constexpr int    kMaxSplitK   = 1;

    // prefill
    constexpr size_t kHeadNum     = 16;
    constexpr size_t KvHeadNum    = kHeadNum / 8;
    constexpr size_t kBatchSize   = 2;
    constexpr size_t kInputLen    = 8192;
    constexpr size_t kSequenceLen = 0;
    constexpr int    kMaxSplitK   = 1;

    constexpr int kBlockSz     = 64;

#endif

#if KV_INT8
    using Tkv                  = uint8_t;
    constexpr int kQuantPolicy = QuantPolicy::kCacheKVInt8;
#elif KV_INT4
    using Tkv                  = uint4_t;
    constexpr int kQuantPolicy = QuantPolicy::kCacheKVInt4;
#else
    using Tkv                  = T;
    constexpr int kQuantPolicy = 0;
#endif

    static_assert(KvHeadNum > 0);

    // Each sequence = kSequenceLen cached tokens followed by kInputLen new tokens.
    constexpr size_t kContextLen = kSequenceLen + kInputLen;
    constexpr size_t kTokenNum   = kBatchSize * kInputLen;
    constexpr int    kTestIter   = 10;

    constexpr float kRoPEBase = 10000.f;
    constexpr int   kRoPEDim  = kHeadDim / 2;
    // When non-zero, QK scores are captured (qk_buf / pr_buf) and printed.
    constexpr int   kDump     = 0;

    RNG rng{};

    thrust::universal_vector k_cache(kBatchSize * KvHeadNum * kContextLen * kHeadDim);
    thrust::universal_vector v_cache(kBatchSize * KvHeadNum * kContextLen * kHeadDim);

    // flattened floating-point KV cache
    thrust::device_vector kv_cache(KvHeadNum * 2 * (kBatchSize * kContextLen + MAX_CTA_S) * kHeadDim);

    thrust::universal_vector qkv(kBatchSize * kInputLen * (kHeadNum + KvHeadNum * 2) * kHeadDim);
    thrust::universal_vector output(kBatchSize * kInputLen * kHeadNum * kHeadDim);

    thrust::universal_vector  finished(kBatchSize);
    thrust::universal_vector   sequence_length(kBatchSize);
    thrust::universal_vector   input_length(kBatchSize);
    thrust::universal_vector   context_length(kBatchSize);
    thrust::universal_vector rope_base(kBatchSize);
    thrust::universal_vector   cu_seqlens(kBatchSize + 1);
    thrust::universal_vector   cu_kv_lens(kBatchSize + 1);

    // Split-K scratch: presumably per-split (max, sum) pairs and partial outputs that are
    // reduced after the main kernel — TODO confirm against the kernel.
    thrust::device_vector partial_ML(kTokenNum * kHeadNum * kMaxSplitK * 2);
    thrust::device_vector partial_O(kTokenNum * kHeadNum * kMaxSplitK * kHeadDim);
    thrust::device_vector   split_cnt(kTokenNum);

    // Zero-sized unless kDump is set.
    thrust::universal_vector qk_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen);
    thrust::universal_vector     pr_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen);

    thrust::universal_vector sinks(kHeadNum);

    rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);

    rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
    rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);

    if (SINK) {
        rng.GenerateUniform(sinks.data().get(), sinks.size(), 2 * SINK, -SINK);
    }

    if (0) {
        // Set input range to zero
        // (BH, SD)
        cudaMemset2DAsync(k_cache.data().get() + kSequenceLen * kHeadDim,
                          sizeof(T) * kContextLen * kHeadDim,
                          0,
                          sizeof(T) * kInputLen * kHeadDim,
                          kBatchSize * KvHeadNum);
        cudaMemset2DAsync(v_cache.data().get() + kSequenceLen * kHeadDim,
                          sizeof(T) * kContextLen * kHeadDim,
                          0,
                          sizeof(T) * kInputLen * kHeadDim,
                          kBatchSize * KvHeadNum);
    }

    invokeApplyRotaryEmbedding(k_cache.data().get(), kContextLen, KvHeadNum, kHeadDim, kRoPEBase, kRoPEDim, kBatchSize);

    // Reference copies taken after RoPE has been applied to K.
    thrust::universal_vector k_cache_ref = k_cache;
    thrust::universal_vector v_cache_ref = v_cache;

    thrust::universal_vector  blocks;
    thrust::universal_vector k_ptrs;
    thrust::universal_vector   cu_block_cnts;

    TestBlocks(k_cache,
                    v_cache,
                    blocks,
                    k_ptrs,
                    cu_block_cnts,
                    KvHeadNum,
                    kHeadDim,
                    kBlockSz,
                    kBatchSize,
                    kRoPEDim,
                    kQuantPolicy);

    thrust::universal_vector     output_ref = output;
    thrust::universal_vector k_cache_ref_ptrs(kBatchSize);
    thrust::universal_vector v_cache_ref_ptrs(kBatchSize);

    thrust::universal_vector bias_QKV(kHeadNum * kHeadDim + 2 * KvHeadNum * kHeadDim);

    rng.GenerateNormal(bias_QKV.data().get(), bias_QKV.size(), 0.1f, 0.f);

    cudaDeviceSynchronize();

    // Exclusive prefix sums of the (uniform) query and KV lengths.
    for (size_t i = 0; i <= kBatchSize; ++i) {
        cu_seqlens[i] = i * kInputLen;
        cu_kv_lens[i] = i * kContextLen;
    }

    for (size_t i = 0; i < kBatchSize; ++i) {
        input_length[i]     = kInputLen;
        sequence_length[i]  = kSequenceLen;
        context_length[i]   = kContextLen;
        k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
        v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
        rope_base[i]        = kRoPEBase;
    }

    // getchar();

    // The reference writes into output_ref; params.out is redirected to `output` later.
    params.out = output_ref.data().get();
    params.q   = qkv.data().get();
    params.k   = params.q + kHeadNum * kHeadDim;
    params.v   = params.k + KvHeadNum * kHeadDim;

    params.q_bias = bias_QKV.data().get();
    params.k_bias = params.q_bias + kHeadNum * kHeadDim;
    params.v_bias = params.k_bias + KvHeadNum * kHeadDim;

    // Q, K and V of one token are packed contiguously in `qkv`.
    params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;

    params.token_num  = kTokenNum;
    params.batch_size = kBatchSize;
    params.max_q_len  = kInputLen;
    params.max_k_len  = kContextLen;

    params.block_iter_params = BlockIteratorParams{k_ptrs.data().get(),  //
                                                   cu_block_cnts.data().get(),
                                                   0,
                                                   kBlockSz};

    params.linear_iter_params = LinearIteratorParams{kv_cache.data().get(),  //
                                                     int(2 * kBatchSize * kContextLen * kHeadDim),
                                                     int(kBatchSize * kContextLen * kHeadDim)};

    params.quant_policy = kQuantPolicy;

    params.finished   = finished.data().get();
    params.rope_theta = rope_base.data().get();
    params.cu_q_len   = cu_seqlens.data().get();
    params.cu_k_len   = cu_kv_lens.data().get();

    params.num_heads     = kHeadNum;
    params.num_kv_heads  = KvHeadNum;
    params.size_per_head = kHeadDim;
    params.window_size   = kWindowSize;
    // 1/sqrt(d) pre-multiplied by log2(e); presumably the kernels use exp2 — TODO confirm.
    params.inv_sqrt_dh   = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head);

    if (SINK) {
        params.sinks       = sinks.data().get();
        params.scale_sinks = 1. / std::sqrt((float)params.size_per_head);
    }

    float scale_factor = -std::log2f(kRoPEBase) / kRoPEDim;
    params.rope_param  = RopeKernelParam{RopeType::kDefault, nullptr, kRoPEDim, scale_factor, 1.f};

    params.split_cnt  = split_cnt.data().get();
    params.partial_ML = partial_ML.data().get();
    params.partial_O  = partial_O.data().get();

    params.max_split_k = kMaxSplitK;
    params.arch        = getSMVersion();

    params.qk = qk_buf.data().get();
    params.pr = pr_buf.data().get();

    Reference reference({});
    reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize, kWindowSize);

    for (int i = 0; i < 1; ++i) {
        reference.Execute(params.out,  //
                          k_cache_ref.data().get(),
                          v_cache_ref.data().get(),
                          qkv.data().get(),
                          bias_QKV.data().get(),
                          SINK ? sinks.data().get() : nullptr,
                          kRoPEBase,
                          kRoPEDim);
    }

    cudaDeviceSynchronize();

    if constexpr (kDump) {
        for (size_t b = 0; b < kBatchSize; ++b) {
            for (size_t h = 0; h < kHeadNum; ++h) {
                for (size_t q = 0; q < kInputLen; ++q) {
                    auto qk = reference.qk() + b * kHeadNum * kInputLen * kContextLen + h * kInputLen * kContextLen
                              + q * kContextLen;
                    for (size_t k = 0; k < kContextLen; ++k) {
                        std::cout << qk[k] * params.inv_sqrt_dh << " ";
                    }
                    std::cout << "\n";
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    if (auto err = cudaGetLastError(); err != cudaSuccess) {
        std::cout << cudaGetErrorString(err) << "\n";
        return -1;
    }
    std::cout << "---------------------------------------------------\n";

    // From here on the kernel under test writes into `output`.
    params.out = output.data().get();

    std::vector> outputs;

    // NOTE(review): these events are never released with cudaEventDestroy in this function.
    std::vector ev_start(kTestIter);
    std::vector ev_end(kTestIter);

    for (int i = 0; i < kTestIter; ++i) {
        cudaEventCreate(&ev_start[i]);
        cudaEventCreate(&ev_end[i]);
    }

    // NOTE(review): the bound is max(kTestIter, 1) but ev_start/ev_end hold kTestIter
    // entries; kTestIter == 0 would index past their end.
    for (int i = 0; i < std::max(kTestIter, 1); ++i) {

#if DECODING
        cudaEventRecord(ev_start[i]);
        dispatchDecoding(params);
        cudaEventRecord(ev_end[i]);
#else
        // input -> blocked
        invokeProcessKV_v2_(params);
        // blocked -> linear
        invokeFlattenKV_v2_(params, cu_kv_lens[kBatchSize]);

        cudaEventRecord(ev_start[i]);
        dispatchAttention(params);
        cudaEventRecord(ev_end[i]);
#endif

        if (auto err = cudaGetLastError(); err != cudaSuccess) {
            std::cout << cudaGetErrorString(err) << "\n";
            return -1;
        }
        // Keep a copy of every iteration's output for the consistency check below.
        if (1) {
            outputs.push_back(output);
        }
    }

    if (kDump) {
        cudaDeviceSynchronize();
        for (size_t b = 0; b < kBatchSize; ++b) {
            for (size_t h = 0; h < kHeadNum; ++h) {
                for (size_t q = 0; q < kInputLen; ++q) {
                    auto ref = reference.qk() + b * kHeadNum * kInputLen * kContextLen + h * kInputLen * kContextLen
                               + q * kContextLen;
                    auto data = qk_buf.data().get() + b * kHeadNum * kInputLen * kContextLen
                                + h * kInputLen * kContextLen + q * kContextLen;
                    for (size_t k = 0; k < kContextLen; ++k) {
                        // std::cout << std::max(0.f, std::abs(data[k] - (float)ref[k]) - 1e-5f) << " ";
                        std::cout << data[k] * params.inv_sqrt_dh << " ";
                        // std::cout << (float)data[k] << " ";
                    }
                    std::cout << "\n";
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    // Flatten the blocked cache written by the kernel back into k_cache / v_cache
    // ([B, H, S, D]) so it can be compared with the reference cache below.
    invokeFlattenKV_v2(k_cache.data().get(),  // [B, H, S, D]
                       v_cache.data().get(),
                       k_ptrs.data().get(),
                       cu_kv_lens.data().get(),
                       cu_block_cnts.data().get(),
                       RopeKernelParam{},  // DECODING ? nullptr : params.rope_theta,
                       KvHeadNum * kContextLen,
                       0,
                       kContextLen,
                       1,
                       kBlockSz,
                       0,  // layer_id
                       0,  // cp_rank
                       1,  // cp_size
                       kContextLen,
                       KvHeadNum,
                       kHeadDim,
                       kBatchSize,
                       kQuantPolicy);
    cudaDeviceSynchronize();

    // Bytes / FLOPs per iteration used for the throughput numbers. nbytes is derived from
    // the size of `blocks` (element type not visible here — presumably bytes; TODO confirm).
    const size_t nbytes = blocks.size() / kContextLen * std::min(kContextLen, (size_t)params.window_size);
    const size_t ops =
        2 * kInputLen * std::min(kContextLen, (size_t)params.window_size) * kHeadDim * kHeadNum * kBatchSize;

    const float peak_bw = get_memory_bandwidth();

    std::cout << "Device peak global memory bandwidth: " << peak_bw << " GB/s\n";

    for (int i = 0; i < kTestIter; ++i) {
        float ms{};
        cudaEventElapsedTime(&ms, ev_start[i], ev_end[i]);
        const float bw      = nbytes / 1e9f / ms * 1000.f;
        const float flops   = ops / 1e12f / ms * 1000.f;
        const float percent = bw / peak_bw * 100.f;
        printf("time %.3f ms, bw %.3f GB/s, %.3f %%, tflops %.3f \n", ms, bw, percent, flops);
    }

    // Run-to-run determinism: rtol = atol = 0 demands exact equality.
    // NOTE(review): only kHeadNum rows of kHeadDim (the first kHeadNum * kHeadDim elements,
    // i.e. the first token) are compared, not the whole output.
    if (outputs.size() > 1) {
        std::cout << "Evaluating consistency..." << std::endl;
        for (size_t i = 1; i < outputs.size(); ++i) {
            Compare(outputs[i].data().get(), outputs[i - 1].data().get(), kHeadDim, kHeadDim, kHeadNum, 0, 0, 0);
        }
    }

    std::cout << "---------------------------------------------------\n";

    // [B, S, H, D]
    Compare(output.data().get(),  //
            output_ref.data().get(),
            kHeadNum * kHeadDim,
            kHeadNum * kHeadDim,
            kBatchSize * kInputLen,
            0);

    // Only the newly appended kInputLen tokens of each (batch, kv-head) row are compared.
    // [BH, SD]
    Compare(k_cache.data().get() + kSequenceLen * kHeadDim,
            k_cache_ref.data().get() + kSequenceLen * kHeadDim,
            kContextLen * kHeadDim,
            kInputLen * kHeadDim,
            kBatchSize * KvHeadNum,
            0);
    Compare(v_cache.data().get() + kSequenceLen * kHeadDim,
            v_cache_ref.data().get() + kSequenceLen * kHeadDim,
            kContextLen * kHeadDim,
            kInputLen * kHeadDim,
            kBatchSize * KvHeadNum);

    return 0;
}

// Entry point of the attention benchmark.
//
// test_attention() returns -1 when a CUDA error is detected and 0 otherwise; that status is
// turned into the process exit code so scripts and CI can detect failures. Numeric
// mismatches reported by Compare are informational and do not affect the exit code.
int main(int argc, char* argv[])
{
    const int status = test_attention();

    // test_attention();

    return status == 0 ? 0 : 1;
}


================================================
FILE: src/turbomind/kernels/attention/test_quant.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "quantization.h"
#include "src/turbomind/kernels/attention/test_utils.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/macro.h"
#include 
#include 
#include 

using namespace turbomind;

// Element-wise conversion kernel used by round_trip_test.
//
// Reads `n` elements of `src` as vectors of kVecSize, converts each vector with
// ConvertKvCache{scale, zero}, and stores the result to `dst`. The loop strides by
// blockDim.x * gridDim.x (grid-stride), so any launch shape covers the whole range.
// NOTE(review): `v_n = n / kVecSize` drops a tail of n % kVecSize elements; the callers in
// this file pass n = 1 << 20, presumably a multiple of kVecSize — confirm if reused.
// NOTE(review): `v_n` and the loop index are `int`, so n / kVecSize must fit in an int.
template
__global__ void convert(T1* dst, const T0* src, size_t n, float scale, float zero)
{
    auto v_src = (Array*)src;
    auto v_dst = (Array*)dst;

    const int v_n = n / kVecSize;

    ConvertKvCache converter{scale, zero};

    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < v_n; i += blockDim.x * gridDim.x) {
        Array vi;
        Array vo;
        Load(vi, (T0*)v_src[i].data());
        vo = converter(vi);
        Store((T1*)v_dst[i].data(), vo);
    }
}

// Round-trip check for ConvertKvCache.
//
// Fills `src` with small random non-negative integers (below 1 << bitsof of one of the
// element types — presumably the intermediate type, so values are exactly representable;
// TODO confirm). It converts src -> tmp with (s1, z1) and tmp -> dst with (s2, z2), then
// prints the Compare statistics of dst against src. `tmp` holds src.size() / kVecSize
// vectors. rand() is never seeded here, so inputs are identical from run to run.
// NOTE(review): Compare takes `int m`, so n must fit in an int.
template
void round_trip_test(size_t n, float s1 = 1., float z1 = 0., float s2 = 1., float z2 = 0.)
{
    std::cout << __PRETTY_FUNCTION__ << std::endl;

    using namespace thrust;

    universal_vector src(n);
    universal_vector dst(src.size());

    universal_vector> tmp(src.size() / kVecSize);

    for (size_t i = 0; i < src.size(); ++i) {
        src[i] = T0(float(rand() % (1 << bitsof)));
    }

    convert<<<256, 256>>>((T1*)tmp.data().get(), src.data().get(), n, s1, z1);
    convert<<<256, 256>>>(dst.data().get(), (const T1*)tmp.data().get(), n, s2, z2);

    cudaDeviceSynchronize();

    // Whole buffer compared as a single row.
    Compare(dst.data().get(), src.data().get(), src.size(), src.size(), 1);
}

// Runs round_trip_test on 1 << 20 elements for several type combinations. The BF16 cases are
// compiled only when ENABLE_BF16 is set. The second group passes explicit (scale, zero)
// pairs, e.g. zero = -64 for the back-conversion.
int main(int argc, char* argv[])
{
    round_trip_test(1 << 20);
    round_trip_test(1 << 20);
#if ENABLE_BF16
    round_trip_test(1 << 20);
#endif

    round_trip_test(1 << 20, 1, 0, 1, -64);
    round_trip_test(1 << 20, 1, 0, 1, -64);
#if ENABLE_BF16
    round_trip_test(1 << 20, 1, 0, 1, 0);
#endif

    return 0;
}


================================================
FILE: src/turbomind/kernels/attention/test_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "test_utils.h"
#include 
#include 
#include 
#include 
#include 

#define _CG_ABI_EXPERIMENTAL
#include 
#include 
#include 

namespace turbomind {

// Process-wide cuBLAS handle and stream for the test helpers. They are value-initialized
// here and not created or used elsewhere in this file; presumably a reference
// implementation in another translation unit sets them up — TODO confirm.
cublasHandle_t cublas_handle{};
cudaStream_t   cublas_stream{};

// Prints summary statistics comparing `src` against the reference `ref`.
//
// Both buffers are viewed as `n` rows with a pitch of `stride` elements; the first `m`
// elements of each row are compared. An element is an outlier when
// |src - ref| > atol + rtol * |ref|. The test is written negated, so NaNs count as outliers.
// With `show` set, each outlier is printed as "row,col<TAB>src<TAB>ref".
//
// The summary line reports, averaged over rows: the mean absolute difference, the mean
// relative difference |src - ref| / (|ref| + 1e-6), and the outlier count per row.
template<typename T>
void Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)
{
    // With no rows or columns the averages below would be 0 / 0 (NaN); report zeros instead.
    if (m <= 0 || n <= 0) {
        std::cout << "abs_diff = " << 0.f << " rel_diff = " << 0.f << " outliers = " << 0.f << std::endl;
        return;
    }
    float asums{};
    float rsums{};
    int   outliers{};
    for (int nn = 0; nn < n; ++nn) {
        const T* src_row = src + (size_t)nn * stride;
        const T* ref_row = ref + (size_t)nn * stride;
        float    abs_diff_sum{};
        float    rel_diff_sum{};
        for (int mm = 0; mm < m; ++mm) {
            const float x        = float(src_row[mm]);
            const float y        = float(ref_row[mm]);
            const float abs_diff = std::abs(x - y);
            // The epsilon belongs outside abs(): abs(y + 1e-6f) is 0 for y == -1e-6f and would
            // turn a perfect match into 0 / 0 = NaN.
            const float rel_diff = abs_diff / (std::abs(y) + 1e-6f);
            if (!(abs_diff <= atol + rtol * std::abs(y))) {
                ++outliers;
                if (show) {
                    std::cout << nn << "," << mm << "\t" << x << "\t" << y << std::endl;
                }
            }
            abs_diff_sum += abs_diff;
            rel_diff_sum += rel_diff;
        }
        asums += abs_diff_sum / m;
        rsums += rel_diff_sum / m;
    }
    std::cout << "abs_diff = " << asums / n << " rel_diff = " << rsums / n << " outliers = " << outliers / (float)n
              << std::endl;
}

// Explicit instantiations: Compare is defined in this .cu file, so every element type used
// by callers must be instantiated here (nv_bfloat16 only when ENABLE_BF16 is set).
template void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
template void
Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
#if ENABLE_BF16
template void
Compare(const nv_bfloat16* src, const nv_bfloat16* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
#endif

// Reads `size` bytes from the binary file at `path` into `dst`.
//
// Aborts the process if the file cannot be opened. A mismatch between `size` and the file's
// length is only reported on stderr. If fewer than `size` bytes could be read (short file),
// a second warning states how many bytes were actually copied; the remainder of `dst` is
// left untouched.
void LoadBinary(const std::string& path, size_t size, void* dst)
{
    std::ifstream ifs(path, std::ios::binary | std::ios::in);
    if (!ifs.is_open()) {
        std::cerr << "failed to open " << path << "\n";
        std::abort();
    }
    ifs.seekg(0, ifs.end);
    const std::streamoff actual_size_in_bytes = ifs.tellg();
    ifs.seekg(0, ifs.beg);
    // Compare as signed offsets (tellg() reports -1 on failure) instead of mixing
    // signed and unsigned operands.
    if (static_cast<std::streamoff>(size) != actual_size_in_bytes) {
        std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
                  << " bytes is requested\n";
    }
    ifs.read((char*)dst, size);
    // A short read sets failbit without any other signal; surface the real byte count.
    const std::streamsize bytes_read = ifs.gcount();
    if (static_cast<size_t>(bytes_read) != size) {
        std::cerr << "[warning] only " << bytes_read << " of " << size << " bytes were read from " << path << "\n";
    }
    std::cerr << "[info] " << path << " " << size << "\n";
}

namespace cg = cooperative_groups;

// Seeds one curandState per launched thread: fixed seed, subsequence = global thread rank,
// offset 0. `state` must hold at least as many entries as the grid has threads.
__global__ void curand_init(curandState* state)
{
    const auto rank = cg::this_grid().thread_rank();
    curand_init(0xe4c45822e90461ddULL, rank, 0, &state[rank]);
}

// Grid-stride fill: result[i] = T(scale * u + shift) for i in [0, count), where u is drawn by
// curand_uniform from the calling thread's own generator state.
template<typename T>
__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
{
    auto         grid     = cg::this_grid();
    curandState* my_state = state + grid.thread_rank();
    for (auto idx = grid.thread_rank(); idx < count; idx += grid.size()) {
        const float u = curand_uniform(my_state);
        result[idx]   = T(scale * u + shift);
    }
}

// Grid-stride fill: result[i] = T(scale * z + shift) for i in [0, count), where z is drawn by
// curand_normal from the calling thread's own generator state.
template<typename T>
__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
{
    auto         grid     = cg::this_grid();
    curandState* my_state = state + grid.thread_rank();
    for (auto idx = grid.thread_rank(); idx < count; idx += grid.size()) {
        const float z = curand_normal(my_state);
        result[idx]   = T(scale * z + shift);
    }
}

// Grid-stride fill of result[0, count) with raw curand() draws; each thread advances only
// its own generator state.
__global__ void curand_bytes(curandState* state, size_t count, uint* result)
{
    auto         grid     = cg::this_grid();
    curandState* my_state = state + grid.thread_rank();
    for (auto idx = grid.thread_rank(); idx < count; idx += grid.size()) {
        result[idx] = curand(my_state);
    }
}

// Device-side state behind RNG: one curandState per thread of a fixed kBlocks x kThreads
// launch. Every generator kernel indexes `states` by its global thread rank, so all launches
// below must use exactly the shape the states were allocated for — hence the shared constants.
struct RNG::Impl {

    static constexpr int kBlocks  = 64;
    static constexpr int kThreads = 64;

    curandState* states{};

    // Allocates the states and seeds them with a kernel on the default stream.
    // NOTE(review): cudaMalloc and launch errors are not checked.
    Impl()
    {
        cudaMalloc(&states, sizeof(curandState) * kBlocks * kThreads);
        curand_init<<<kBlocks, kThreads>>>(states);
    }

    ~Impl()
    {
        cudaFree(states);
    }

    // Owns a raw device allocation: a copy would free it twice.
    Impl(const Impl&)            = delete;
    Impl& operator=(const Impl&) = delete;

    void GenerateUInt(uint* out, size_t count)
    {
        curand_bytes<<<kBlocks, kThreads>>>(states, count, out);
    }

    template<typename T>
    void GenerateUniform(T* out, size_t count, float scale, float shift)
    {
        curand_uniform<<<kBlocks, kThreads>>>(states, count, out, scale, shift);
    }

    template<typename T>
    void GenerateNormal(T* out, size_t count, float scale, float shift)
    {
        curand_normal<<<kBlocks, kThreads>>>(states, count, out, scale, shift);
    }
};

// Allocates and seeds the device generator states (RNG::Impl).
RNG::RNG(): impl_{std::make_unique<Impl>()} {}

// Defined here, where RNG::Impl is a complete type, so unique_ptr<Impl> can destroy it.
RNG::~RNG() = default;

// Fills out[0, count) with raw curand() values; the kernel launch is not synchronized.
void RNG::GenerateUInt(uint* out, size_t count)
{
    Impl& impl = *impl_;
    impl.GenerateUInt(out, count);
}

// Fills out[0, count) with T(scale * u + shift), u drawn by curand_uniform on the device.
// The kernel launch is not synchronized. No logging, matching RNG::GenerateNormal.
template<typename T>
void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
{
    impl_->GenerateUniform(out, count, scale, shift);
}

// Fills out[0, count) with T(scale * z + shift), z drawn by curand_normal on the device.
// Thin forwarder into the pimpl; the kernel launch is not synchronized.
template<typename T>
void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
{
    impl_->GenerateNormal<T>(out, count, scale, shift);
}

// Explicit instantiations: the RNG member templates are defined in this file, so each element
// type callers use must be instantiated here (nv_bfloat16 only when ENABLE_BF16 is set).
template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
#if ENABLE_BF16
template void RNG::GenerateUniform(nv_bfloat16* out, size_t count, float scale, float shift);
#endif

template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
#if ENABLE_BF16
template void RNG::GenerateNormal(nv_bfloat16* out, size_t count, float scale, float shift);
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/test_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "attention.h"
#include "src/turbomind/macro.h"
#include 
#include 

namespace turbomind {

// Prints mean absolute / relative difference and the per-row outlier count of `src` against
// `ref`, viewed as `n` rows of pitch `stride` of which the first `m` elements are compared.
// An element is an outlier when |src - ref| > atol + rtol * |ref|; `show` prints each one.
// Explicitly instantiated for half, float and (with ENABLE_BF16) nv_bfloat16.
template
void Compare(
    const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);

// Reads `size` bytes of the file at `path` into `dst`. Aborts if the file cannot be opened;
// a size mismatch with the file is reported on stderr and is not fatal.
void LoadBinary(const std::string& path, size_t size, void* dst);

// Device-side random number generator for tests.
//
// State lives behind a pimpl (RNG::Impl). Every Generate* call launches a kernel that writes
// `count` elements of the device-accessible buffer `out` and returns without synchronizing.
// The templates are explicitly instantiated for half, float and (with ENABLE_BF16) nv_bfloat16.
class RNG {
public:
    RNG();
    ~RNG();
    // out[i] = raw curand() value.
    void GenerateUInt(uint* out, size_t count);

    // out[i] = scale * u + shift, u drawn by curand_uniform.
    template
    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);

    // out[i] = scale * z + shift, z drawn by curand_normal.
    template
    void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);

private:
    struct Impl;
    std::unique_ptr impl_;
};

// Reference attention taking per-sample K/V cache pointers and sequence lengths.
// NOTE(review): not defined in test_utils.cu; presumably defined in another test
// translation unit, or a stale declaration — TODO confirm.
template
void mmha_ft_reference(const AttentionParams& params,
                       T**                       per_sample_k_cache,
                       T**                       per_sample_v_cache,
                       const int*                sequence_length,
                       int                       max_memory_len,
                       cudaStream_t              st);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/utils.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "utils.h"
#include 
#include 
#include 
#include 

namespace turbomind {

// Picks the split-K factor for a launch of `grid_size` CTAs.
//
// A full wave is sm_count * max_active_ctas resident CTAs. Splitting by s is modeled as
// needing ceil(base_waves * s) waves at a cost of (alpha / s + beta) per wave. Candidates
// s = 1..max_split_cnt are scanned in order, and the scan stops at the first s that needs
// more than max_wave_cnt waves (s == 1 is always admissible). The cheapest candidate wins;
// ties go to the smaller s. Returns 0 when max_split_cnt < 1.
int GetSplitCount(
    int max_split_cnt, int grid_size, int max_active_ctas, int sm_count, int max_wave_cnt, float alpha, float beta)
{
    // Fractional number of waves the un-split grid occupies.
    const float base_waves = (float)grid_size / (sm_count * max_active_ctas);

    // (cost, fractional waves, split) — compared lexicographically.
    using Candidate = std::tuple<float, float, int>;

    constexpr float kInf = std::numeric_limits<float>::infinity();

    Candidate best{kInf, 0.f, 0};

    for (int split = 1; split <= max_split_cnt; ++split) {
        const float waves = std::ceil(base_waves * split);
        const float cost  = (split == 1 || waves <= max_wave_cnt) ? (alpha / split + beta) * waves : kInf;
        if (std::isinf(cost)) {
            break;
        }
        const Candidate candidate{cost, base_waves * split, split};
        if (candidate < best) {
            best = candidate;
        }
    }

    return std::get<2>(best);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/attention/utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

// Returns the split count s in [1, max_split_cnt] minimizing
// (alpha / s + beta) * ceil(s * grid_size / (sm_count * max_active_ctas)).
// The search stops at the first s needing more than max_wave_cnt waves (s == 1 is always
// allowed); ties go to the smaller s. Returns 0 when max_split_cnt < 1.
int GetSplitCount(int   max_split_cnt,
                  int   grid_size,
                  int   max_active_ctas,
                  int   sm_count,
                  int   max_wave_cnt,
                  float alpha = 1,
                  float beta  = 1e-3);

}


================================================
FILE: src/turbomind/kernels/ban_bad_words.cu
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/kernels/ban_bad_words.h"
#include 
// #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
// #include "src/turbomind/utils/cuda_utils.h"
#include 
#include 

namespace turbomind {

// Largest finite value representable in T; BanBadWordsKernel writes its negation into
// banned logits. Only the specializations below are defined.
template
__device__ inline T getMaxValue();

template<>
__device__ inline float getMaxValue()
{
    return FLT_MAX;
}

// 0x7BFF is the bit pattern of the largest finite half (65504).
template<>
__device__ inline half getMaxValue()
{
    return __ushort_as_half((unsigned short)0x7BFFU);
}

#ifdef ENABLE_BF16
// 0x7F7F is the bit pattern of the largest finite bfloat16.
// NOTE(review): when __CUDA_ARCH__ < 800 (or in the host pass) this falls through to
// `return {};`, i.e. zero, so banning would have no effect for bf16 logits there. BanBadWords
// in this file only dispatches the float instantiation.
template<>
__device__ inline __nv_bfloat16 getMaxValue<__nv_bfloat16>()
{
#if __CUDA_ARCH__ >= 800
    return __ushort_as_bfloat16((unsigned short)0x7F7FU);
#endif
    return {};
}
#endif

// Bans "bad words" by overwriting their logits with -getMaxValue<T>().
//
// Thread mapping: x covers entries of the bad-word list
// (id = blockIdx.x * blockDim.x + threadIdx.x); blockIdx.y selects the batch item.
//
// Per batch item, `bad_words` holds 2 * bad_words_len ints. The first half is the concatenated
// token ids of all entries; the second half is the exclusive end offset of each entry (entry
// `id` spans [offsets[id - 1], offsets[id]), the first one starting at 0). A negative offset
// marks an unused slot.
//
// A single-token entry is always banned. A multi-token entry is banned only when its leading
// (size - 1) tokens equal the last (size - 1) tokens of the sequence in token_ids_ptrs. The
// banned id is always the entry's last token.
// NOTE(review): the range check `0 < banned_token` leaves token id 0 unbannable — confirm
// that this is intended.
template
__global__ void BanBadWordsKernel(T*                logits,
                                  const int* const* token_ids_ptrs,
                                  const int*        sequence_length,
                                  const int*        bad_words,
                                  size_t            bad_words_len,
                                  int               vocab_size)
{
    const int id        = blockIdx.x * blockDim.x + threadIdx.x;
    const int batch_idx = blockIdx.y;

    const int* base_bad_words         = bad_words + batch_idx * 2 * bad_words_len;
    const int* base_bad_words_offsets = base_bad_words + bad_words_len;

    if (id >= bad_words_len || base_bad_words_offsets[id] < 0) {
        return;
    }

    const int item_end   = base_bad_words_offsets[id];
    const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
    const int item_size  = item_end - item_start;

    const int  seq_len   = sequence_length[batch_idx];
    const int* token_ids = token_ids_ptrs[batch_idx];

    /* The single-token case unconditionally bans the token */
    bool should_ban = item_size == 1;

    /* Multi-token case and enough previously generated tokens to look for a match */
    if (item_size > 1 && seq_len >= item_size - 1) {
        should_ban = true;
        // Walk the entry's prefix backwards against the sequence tail.
        for (int token_idx = item_size - 2, offset = seq_len - 1; token_idx >= 0; token_idx--, offset--) {
            if (token_ids[offset] != base_bad_words[item_start + token_idx]) {
                should_ban = false;
                break;
            }
        }
    }

    logits += batch_idx * (int64_t)vocab_size;
    if (should_ban) {
        int banned_token = base_bad_words[item_end - 1];
        if (0 < banned_token && banned_token < vocab_size) {
            logits[banned_token] = -getMaxValue();
        }
    }
}

// Host launcher for BanBadWordsKernel.
//
// `logits` is read as [bsz, vocab_size] (dims 0 and 1); dim 2 of `bad_words` gives
// bad_words_len. The kernel runs with blocks of min(round_up(bad_words_len, WARP_SIZE), 256)
// threads and a (cdiv(bad_words_len, block), bsz) grid. Only float logits are dispatched.
// NOTE(review): with bad_words_len == 0, `block` is presumably 0, so cdiv(bad_words_len, block)
// divides by zero and the launch would be invalid; callers presumably skip the call in that
// case — TODO confirm.
void BanBadWords(Tensor&             logits,
                 const Buffer_ token_ids_ptrs,
                 const Buffer_& sequence_length,
                 const Tensor_& bad_words,
                 cudaStream_t        stream)
{

    auto invoke = [&](auto dtype) {
        using T = decltype(dtype);

        const auto [bsz, vocab_size] = logits.shapes(0, 1);
        const int bad_words_len      = bad_words.shape(2);

        const int  block = std::min(round_up(bad_words_len, WARP_SIZE), 256);
        const dim3 grid(cdiv(bad_words_len, block), bsz);

        BanBadWordsKernel<<>>(logits.data(),
                                                      token_ids_ptrs.data(),
                                                      sequence_length.data(),
                                                      bad_words.data(),
                                                      bad_words_len,
                                                      vocab_size);
    };

    invoke(float{});
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/ban_bad_words.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Sets the logits of banned tokens to the most negative finite value, per batch item.
// `bad_words` packs, per batch item, the concatenated token ids of all entries followed by
// their end offsets (negative = unused). A multi-token entry bans its last token only when
// its prefix matches the tail of the sequence given by token_ids_ptrs / sequence_length.
void BanBadWords(Tensor&             logits,
                 const Buffer_ token_ids_ptrs,
                 const Buffer_& sequence_length,
                 const Tensor_& bad_words,
                 cudaStream_t        stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/array.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/sub_byte_ptr.h"

namespace turbomind {

// Fixed-size array usable from host and device (all members are TM_HOST_DEVICE), modeled on
// std::array. It is an aggregate with a single public member `__a[N]` and offers element
// access, front/back, data, begin/end, and a size() returned as a std::integral_constant, so
// the size is available at compile time. N must be positive.
template
struct Array {
    using value_type      = T;
    using size_type       = int;
    using difference_type = int;
    using reference       = value_type&;
    using const_reference = const value_type&;
    using pointer         = value_type*;
    using const_pointer   = const value_type*;
    using iterator        = pointer;
    using const_iterator  = const_pointer;

    static_assert(N > 0);

    T __a[N];

    TM_HOST_DEVICE constexpr reference operator[](size_type i) noexcept
    {
        return __a[i];
    }

    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept
    {
        return __a[i];
    }

    TM_HOST_DEVICE constexpr reference front() noexcept
    {
        return *begin();
    }

    TM_HOST_DEVICE constexpr const_reference front() const noexcept
    {
        return *begin();
    }

    TM_HOST_DEVICE constexpr reference back() noexcept
    {
        return *(end() - 1);
    }

    TM_HOST_DEVICE constexpr const_reference back() const noexcept
    {
        return *(end() - 1);
    }

    TM_HOST_DEVICE constexpr pointer data() noexcept
    {
        return &__a[0];
    }

    TM_HOST_DEVICE constexpr const_pointer data() const noexcept
    {
        return &__a[0];
    }

    TM_HOST_DEVICE constexpr iterator begin() noexcept
    {
        return data();
    }

    TM_HOST_DEVICE constexpr const_iterator begin() const noexcept
    {
        return data();
    }

    TM_HOST_DEVICE constexpr iterator end() noexcept
    {
        return data() + N;
    }

    TM_HOST_DEVICE constexpr const_iterator end() const noexcept
    {
        return data() + N;
    }

    TM_HOST_DEVICE static constexpr std::integral_constant size() noexcept
    {
        return {};
    }

    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept
    {
        return {};
    }
};

// Packed specialization: elements are stored eight per 32-bit word
// (`detail::__uint4_t` wraps a single uint32_t), i.e. 4 bits per element.
// NOTE(review): the template header and specialization arguments are missing
// from this copy (presumably a 4-bit element type with an `int N` parameter)
// -- confirm against upstream.
template
struct Array {
    using value_type      = detail::__uint4_t;
    using size_type       = int;
    using difference_type = int;
    using reference       = value_type&;
    using const_reference = const value_type&;
    using pointer         = SubBytePtr;
    using const_pointer   = SubBytePtr;

    // static_assert(N % 8 == 0);
    // NOTE(review): with the check above disabled, an N that is not a multiple
    // of 8 silently truncates storage to N / 8 words (N < 8 gives zero words).

    detail::__uint4_t __a[N / 8];

    // Returns the 32-bit storage word that *contains* element i (8 elements per
    // word), not the 4-bit element itself.
    TM_HOST_DEVICE constexpr reference operator[](size_type i) noexcept
    {
        return __a[i / 8];
    }

    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept
    {
        return __a[i / 8];
    }

    // Compile-time element count (integral_constant arguments missing from this
    // copy; presumably N, the number of packed elements).
    TM_HOST_DEVICE static constexpr std::integral_constant size() noexcept
    {
        return {};
    }

    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept
    {
        return {};
    }

    // Byte address of the packed storage wrapped in a SubBytePtr. Only a
    // non-const overload exists.
    TM_HOST_DEVICE constexpr pointer data() noexcept
    {
        return {(char*)&__a[0]};
    }
};

// Packed arrays occupy 4 bytes per 8 elements. These presumably check the 8-,
// 16-, 24- and 32-element cases of a packed specialization.
// NOTE(review): the template arguments are missing from this copy -- confirm.
static_assert(sizeof(Array) == 4);
static_assert(sizeof(Array) == 8);
static_assert(sizeof(Array) == 12);
static_assert(sizeof(Array) == 16);

// Second packed specialization with the same storage scheme as the one above:
// eight 4-bit elements per 32-bit `detail::__uint4_t` word.
// NOTE(review): the specialization arguments are missing from this copy
// (presumably another 4-bit element type) -- confirm against upstream.
template
struct Array {
    using value_type      = detail::__uint4_t;
    using size_type       = int;
    using difference_type = int;
    using reference       = value_type&;
    using const_reference = const value_type&;
    using pointer         = SubBytePtr;
    using const_pointer   = SubBytePtr;

    // static_assert(N % 8 == 0);
    // NOTE(review): with the check above disabled, an N that is not a multiple
    // of 8 silently truncates storage to N / 8 words (N < 8 gives zero words).

    detail::__uint4_t __a[N / 8];

    // Returns the 32-bit storage word that *contains* element i (8 elements per
    // word), not the 4-bit element itself.
    TM_HOST_DEVICE constexpr reference operator[](size_type i) noexcept
    {
        return __a[i / 8];
    }

    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept
    {
        return __a[i / 8];
    }

    // Compile-time element count (integral_constant arguments missing from this
    // copy; presumably N).
    TM_HOST_DEVICE static constexpr std::integral_constant size() noexcept
    {
        return {};
    }

    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept
    {
        return {};
    }

    // Byte address of the packed storage wrapped in a SubBytePtr (non-const only).
    TM_HOST_DEVICE constexpr pointer data() noexcept
    {
        return {(char*)&__a[0]};
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/array_ops.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include 
#include 

namespace turbomind {

namespace ops {

// Element-wise arithmetic on Array values (device-only). Template parameter
// lists are missing from this copy of the file (presumably `<class T>` for the
// functors and `<class T, int N, class Op>`-style lists for the helpers).

// Binary functor: a + b.
template
struct plus {
    __device__ T operator()(T a, T b)
    {
        return a + b;
    }
};

// Binary functor: a - b.
template
struct minus {
    __device__ T operator()(T a, T b)
    {
        return a - b;
    }
};

// Binary functor: a * b.
template
struct multiplies {
    __device__ T operator()(T a, T b)
    {
        return a * b;
    }
};

// Vector-vector: c[i] = op(a[i], b[i]).
template
inline __device__ Array binary_op_vv(const Array& a, const Array& b, Op op)
{
    Array c;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        c[i] = op(a[i], b[i]);
    }
    return c;
}

// Scalar-vector: c[i] = op(a, b[i]).
template
inline __device__ Array binary_op_sv(const T& a, const Array& b, Op op)
{
    Array c;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        c[i] = op(a, b[i]);
    }
    return c;
}

// Vector-scalar: c[i] = op(a[i], b).
template
inline __device__ Array binary_op_vs(const Array& a, const T& b, Op op)
{
    Array c;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        c[i] = op(a[i], b);
    }
    return c;
}

// Operator overloads. Only these combinations are provided: vector + vector,
// vector * vector, vector * scalar, vector + scalar, vector - scalar.
template
inline __device__ Array operator+(const Array& a, const Array& b)
{
    return binary_op_vv(a, b, plus{});
}

template
inline __device__ Array operator*(const Array& a, const Array& b)
{
    return binary_op_vv(a, b, multiplies{});
}

template
inline __device__ Array operator*(const Array& a, const T& b)
{
    return binary_op_vs(a, b, multiplies{});
}

template
inline __device__ Array operator+(const Array& a, const T& b)
{
    return binary_op_vs(a, b, plus{});
}

template
inline __device__ Array operator-(const Array& a, const T& b)
{
    return binary_op_vs(a, b, minus{});
}

}  // namespace ops

// Element-wise conversion: dst[i] = (To)src[i]. Presumably maps an array of
// one element type to an array of `To` with the same length (template
// arguments are missing from this copy).
template
inline __device__ Array cast(const Array& src)
{
    Array dst;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        dst[i] = (To)src[i];
    }
    return dst;
}

// Sets every element of an Array to `val`.
template
inline __device__ void fill(Array& x, T val)
{
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        x[i] = val;
    }
}

// Sets every element of each of the M arrays in `x` to `val`.
template
inline __device__ void fill(Array (&x)[M], T val)
{
    PRAGMA_UNROLL
    for (int i = 0; i < M; ++i) {
        fill(x[i], val);
    }
}

// Sets every element to T(0).
template
inline __device__ void clear(Array& x)
{
    fill(x, T(0));
}

// Clears each of the M arrays in a 1-D C array of Arrays.
template
inline __device__ void clear(Array (&x)[M])
{
    PRAGMA_UNROLL
    for (int i = 0; i < M; ++i) {
        clear(x[i]);
    }
}

// Clears each array in an M1 x M0 C array of Arrays.
template
inline __device__ void clear(Array (&x)[M1][M0])
{
    PRAGMA_UNROLL
    for (int m1 = 0; m1 < M1; ++m1) {
        PRAGMA_UNROLL
        for (int m0 = 0; m0 < M0; ++m0) {
            clear(x[m1][m0]);
        }
    }
}

// Copies one Array by plain assignment (note the argument order: src, dst).
template
inline __device__ void copy(const Array& src, Array& dst)
{
    dst = src;
}

// Copies M arrays element-wise from `src` to `dst`.
template
inline __device__ void copy(const Array (&src)[M], Array (&dst)[M])
{
    PRAGMA_UNROLL
    for (int m = 0; m < M; ++m) {
        dst[m] = src[m];
    }
}

// Writes the whole array to `dst`. Arrays of exactly 16, 8, 4, 2 or 1 byte(s)
// are written with one vector store. Arrays whose size is a multiple of 16 bytes
// are written as consecutive 16-byte stores; this path requires whole-byte
// element types (static_assert on `bitsof`).
// Any other size fails to compile via the always-false static_assert in the
// final branch. The template arguments of `Array`, `bitsof` and `is_same_v` are
// missing from this copy of the file.
// `dst` must be aligned to the width of the chosen store; this is not checked.
template
inline __device__ void Store(T* dst, const Array& src)
{
    if constexpr (sizeof(Array) == sizeof(uint4)) {
        *(uint4*)dst = (const uint4&)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        *(uint2*)dst = (const uint2&)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint1)) {
        *(uint1*)dst = (const uint1&)src;
    }
    else if constexpr (sizeof(Array) == sizeof(ushort)) {
        *(ushort*)dst = (const ushort&)src;
    }
    else if constexpr (sizeof(Array) == sizeof(char)) {
        *(char*)dst = (const char&)src;
    }
    else if constexpr (sizeof(Array) % sizeof(uint4) == 0) {  //  uncoalesced
        static_assert(bitsof % 8 == 0, "raw pointer arithmetic of sub-byte types");
        constexpr int M = sizeof(Array) / sizeof(uint4);
        PRAGMA_UNROLL
        for (int i = 0; i < M; ++i) {
            *((uint4*)dst + i) = *((uint4*)&src + i);
        }
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Stores the array with the CUDA `__stcs` intrinsic (cache-streaming store
// hint). Supports arrays of exactly 16, 8, 4, 2 or 1 byte(s); larger arrays are
// rejected at compile time. `dst` must be aligned to the array size (not checked).
template
inline __device__ void Stcs(T* __restrict__ dst, const Array& src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        __stcs((uint4*)dst, (const uint4&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        __stcs((uint2*)dst, (const uint2&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint1)) {
        __stcs((uint*)dst, (const uint&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        __stcs((uint16_t*)dst, (const uint16_t&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        __stcs((uint8_t*)dst, (const uint8_t&)src);
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Same as Stcs but uses the `__stcg` intrinsic (cache-global store hint).
template
inline __device__ void Stcg(T* __restrict__ dst, const Array& src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        __stcg((uint4*)dst, (const uint4&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        __stcg((uint2*)dst, (const uint2&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint1)) {
        __stcg((uint*)dst, (const uint&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        __stcg((uint16_t*)dst, (const uint16_t&)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        __stcg((uint8_t*)dst, (const uint8_t&)src);
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Loads the whole array from `src` with the CUDA `__ldg` intrinsic (read-only
// data path). Supports arrays of exactly 16, 8, 4, 2 or 1 byte(s); larger arrays
// are rejected at compile time. `src` must be aligned to the array size (not checked).
template
inline __device__ void Ldg(Array& dst, const T* src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        (uint4&)dst = __ldg((const uint4*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        (uint2&)dst = __ldg((const uint2*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        (uint&)dst = __ldg((const uint*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        (uint16_t&)dst = __ldg((const uint16_t*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        (uint8_t&)dst = __ldg((const uint8_t*)src);
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Same as Ldg but uses the `__ldcs` intrinsic (cache-streaming load hint).
template
inline __device__ void Ldcs(Array& dst, const T* src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        (uint4&)dst = __ldcs((const uint4*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        (uint2&)dst = __ldcs((const uint2*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        (uint&)dst = __ldcs((const uint*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        (uint16_t&)dst = __ldcs((const uint16_t*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        (uint8_t&)dst = __ldcs((const uint8_t*)src);
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Same as Ldg but uses the `__ldcg` intrinsic (cache-global load hint).
template
inline __device__ void Ldcg(Array& dst, const T* src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        (uint4&)dst = __ldcg((const uint4*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        (uint2&)dst = __ldcg((const uint2*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        (uint&)dst = __ldcg((const uint*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        (uint16_t&)dst = __ldcg((const uint16_t*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        (uint8_t&)dst = __ldcg((const uint8_t*)src);
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Reads the whole array from `src`. This mirrors Store:
// - Arrays of exactly 16, 8, 4, 2 or 1 byte(s) are read with one vector load.
// - Arrays whose size is a multiple of 16 bytes are read as consecutive 16-byte
//   loads (whole-byte element types only).
// - Any other size fails to compile.
// `src` must be aligned to the width of the chosen load; this is not checked.
template
inline __device__ void Load(Array& dst, const T* src)
{
    if constexpr (sizeof(Array) == sizeof(uint4)) {
        (uint4&)dst = *(const uint4*)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        (uint2&)dst = *(const uint2*)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        (uint1&)dst = *(const uint1*)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint16_t)) {
        (uint16_t&)dst = *(const uint16_t*)src;
    }
    else if constexpr (sizeof(Array) == sizeof(uint8_t)) {
        (uint8_t&)dst = *(const uint8_t*)src;
    }
    else if constexpr (sizeof(Array) % sizeof(uint4) == 0) {  //  uncoalesced
        static_assert(bitsof % 8 == 0, "raw pointer arithmetic of sub-byte types");
        constexpr int M = sizeof(Array) / sizeof(uint4);
        PRAGMA_UNROLL
        for (int i = 0; i < M; ++i) {
            *((uint4*)&dst + i) = *((uint4*)src + i);
        }
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Thin alias for Load, presumably used at call sites that read shared memory
// through a generic pointer.
template
inline __device__ void Lds(Array& dst, const T* src)
{
    Load(dst, src);
}

// Loads 4, 8 or 16 bytes from a shared-memory address given as a 32-bit
// `.shared` window address, using `ld.shared` into 1, 2 or 4 32-bit registers.
// Other sizes fail to compile.
// NOTE(review): the asm statements declare no "memory" clobber. Ordering against
// ordinary C++ accesses to the same shared memory therefore depends on
// surrounding synchronization -- confirm this is intended.
template
inline __device__ void LdShared(Array& dst, uint32_t uintptr)
{
    static_assert(sizeof(Array) <= sizeof(uint4));
    if constexpr (sizeof(Array) == sizeof(uint4)) {
        uint4& p = (uint4&)dst;
        // clang-format off
        asm volatile("ld.shared.v4.b32 {%0,%1,%2,%3}, [%4];\n" : "=r"(p.x), "=r"(p.y), "=r"(p.z), "=r"(p.w) : "r"(uintptr));
        // clang-format on
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        uint2& p = (uint2&)dst;
        asm volatile("ld.shared.v2.b32 {%0,%1}, [%2];\n" : "=r"(p.x), "=r"(p.y) : "r"(uintptr));
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        uint& p = (uint&)dst;
        asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(p) : "r"(uintptr));
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Stores 4, 8 or 16 bytes of `src` to a shared-memory address given as a 32-bit
// `.shared` window address, using `st.shared`. Other sizes fail to compile.
// `src` is taken by non-const reference although it is only read.
// NOTE(review): no "memory" clobber is declared on the asm, so the compiler is
// not told that shared memory is written -- confirm that callers rely on
// explicit synchronization for ordering.
template
inline __device__ void StShared(uint32_t uintptr, Array& src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));
    if constexpr (sizeof(Array) == sizeof(uint4)) {
        uint4& p = (uint4&)src;
        // clang-format off
        asm volatile("st.shared.v4.b32 [%0], {%1,%2,%3,%4};\n" :: "r"(uintptr), "r"(p.x), "r"(p.y), "r"(p.z), "r"(p.w) );
        // clang-format on
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        uint2& p = (uint2&)src;
        asm volatile("st.shared.v2.b32 [%0], {%1,%2};\n" ::"r"(uintptr), "r"(p.x), "r"(p.y));
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        uint& p = (uint&)src;
        asm volatile("st.shared.b32  [%0], %1;\n" ::"r"(uintptr), "r"(p));
    }
    else {
        static_assert(!std::is_same_v);
    }
}

// Block-wide sum of each element of `val`; every calling thread gets the totals.
// Step 1: xor-shuffle reduction inside each warp. Lane 0 then writes the warp's
//         partial to smem_red[i * kWarpCount + warp_id].
// Step 2: after __syncthreads(), each warp loads the kWarpCount partials into
//         its first kWarpCount lanes (other lanes use T{}), reduces them with
//         another xor-shuffle, and broadcasts lane 0's result to the warp.
// Requirements (assumed, not checked):
// - every thread of the block calls this (full-mask shuffles + __syncthreads);
// - kWarpCount is the number of warps and a power of two <= WARP_SIZE (the
//   second xor-reduction only combines lanes correctly for powers of two);
// - smem_red has room for N * kWarpCount elements.
// NOTE(review): there is no __syncthreads() after the reads from smem_red, so
// reusing that buffer immediately after this call may race with other warps.
template
inline __device__ Array blockSum(Array val, T* smem_red, int warp_id, int lane_id)
{
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        PRAGMA_UNROLL
        for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
        }
        if (lane_id == 0) {
            smem_red[i * kWarpCount + warp_id] = val[i];
        }
    }

    __syncthreads();

    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};
        PRAGMA_UNROLL
        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
        }
        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
    }

    return val;
}

// Issues an asynchronous copy of one Array (sizeof(Array) bytes) from global
// `src` to shared `dst` with `cp.async.ca` (SM80+ only). On older architectures
// no copy is issued and the assert fires (when asserts are enabled).
// NOTE(review): per the PTX spec, cp.async.ca accepts only 4-, 8- or 16-byte
// copies. This helper issues no commit/wait, so the caller must complete the
// async group before reading `dst`.
template
__device__ void CpAsync(T* dst, const Array* __restrict__ src)
{
    const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
    constexpr int cp_size      = sizeof(Array);
#if TURBOMIND_ARCH_SM80
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_int_ptr), "l"(src), "n"(cp_size));
#else
    assert(TURBOMIND_ARCH_SM80);
#endif
}

// Transposes an 8x8 matrix of 16-bit elements spread over a full warp, using
// two shuffles and no shared memory.
// Assumes the mma-style fragment layout: lane l holds row l / 4, columns
// 2 * (l % 4) and 2 * (l % 4) + 1, packed in one 32-bit word (low half first).
// On return each lane holds the same positions of the transposed matrix.
// All 32 lanes must call this, because the shuffles use the full mask.
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value)
{
    const int lane = threadIdx.x % WARP_SIZE;

    // This lane needs column `lane / 4` of rows 2 * (lane % 4) and
    // 2 * (lane % 4) + 1, held by lanes 8 * (lane % 4) + lane / 8 and that + 4.
    const int src = lane % 4 * 8 + lane / 8;

    uint even_row = __shfl_sync(0xffffffff, value, src);
    uint odd_row  = __shfl_sync(0xffffffff, value, src + 4);

    // The wanted column sits in the low half of each fetched word when it is
    // even, i.e. when lane % 8 < 4.
    const bool want_low = (lane % 8) < 4;

    short2 out;
    out.x = want_low ? ((short2&)even_row).x : ((short2&)even_row).y;
    out.y = want_low ? ((short2&)odd_row).x : ((short2&)odd_row).y;
    return (uint&)out;
}

// `movmatrix` needs PTX ISA 7.8 (CUDA 11.8) or newer. Compare the full
// (major, minor) version: testing `MINOR >= 8` by itself rejects CUDA 12.0-12.7.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
// Transposes an 8x8 b16 matrix held across the warp with a single `movmatrix`
// instruction (SM75+). On older targets it asserts and returns 0; use
// transpose_m8n8_b16, which only calls this where it is valid.
__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
{
#if TURBOMIND_ARCH_SM75
    uint d;
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
    return d;
#else
    assert(TURBOMIND_ARCH_SM75);
    return 0;
#endif
}
#endif

// Transposes an 8x8 matrix of 16-bit elements distributed across the warp
// (one 32-bit word per lane). Uses `movmatrix` only when both the toolkit
// (>= 11.8) and the target architecture (>= SM75) support it. Otherwise it
// uses the two-shuffle implementation, so pre-SM75 targets still get a correct
// result instead of the movmatrix wrapper's assert-and-return-0 path.
__inline__ __device__ uint32_t transpose_m8n8_b16(uint32_t a)
{
#if ((__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) && TURBOMIND_ARCH_SM75
    return transpose_m8n8_b16_movmatrix(a);
#else
    return transpose_m8n8_b16_warp_shuffle(a);
#endif
}

// Transposes an 8x8 matrix of 32-bit elements held across the warp, with two
// values (x[0], x[1]) per lane. Array arguments are missing from this copy;
// presumably Array<uint32_t, 2>.
// Each 32-bit value is split into 16-bit halves. The tile of low halves and the
// tile of high halves are each transposed with transpose_m8n8_b16, then the
// halves are re-packed. Assumes the fragment layout used by transpose_m8n8_b16
// (x[0], x[1] are adjacent columns of one row). All 32 lanes must participate.
__inline__ __device__ Array transpose_m8n8_b32(const Array& x)
{
    // lo = {x[0].lo16, x[1].lo16}, hi = {x[0].hi16, x[1].hi16}
    uint32_t lo = __byte_perm(x[0], x[1], 0x5410);
    uint32_t hi = __byte_perm(x[0], x[1], 0x7632);

    lo = transpose_m8n8_b16(lo);
    hi = transpose_m8n8_b16(hi);

    // Re-join each transposed element from its low and high halves.
    Array y;
    y[0] = __byte_perm(lo, hi, 0x5410);
    y[1] = __byte_perm(lo, hi, 0x7632);

    return y;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/common.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

// Architecture feature flags. TURBOMIND_ARCH_SMxx is 1 only in a device
// compilation pass whose target has __CUDA_ARCH__ >= xx0. It is 0 in the host
// pass (where __CUDA_ARCH__ is undefined) and for older targets.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
#define TURBOMIND_ARCH_SM70 1
#else
#define TURBOMIND_ARCH_SM70 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define TURBOMIND_ARCH_SM75 1
#else
#define TURBOMIND_ARCH_SM75 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define TURBOMIND_ARCH_SM80 1
#else
#define TURBOMIND_ARCH_SM80 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define TURBOMIND_ARCH_SM90 1
#else
#define TURBOMIND_ARCH_SM90 0
#endif

// Native bf16 is tied to SM80+ and fp8 to SM90+.
#define TURBOMIND_ARCH_HAS_BF16 TURBOMIND_ARCH_SM80

#define TURBOMIND_ARCH_HAS_FP8 TURBOMIND_ARCH_SM90

// Each *_GUARD(type) is true when `type` may be used on the current target:
// either the target supports it, or `type` is not that data type.
#define TURBOMIND_ARCH_BF16_GUARD(type) (TURBOMIND_ARCH_HAS_BF16 || type != ::turbomind::kBfloat16)

#define TURBOMIND_ARCH_FP8_GUARD(type)                                                                                 \
    (TURBOMIND_ARCH_HAS_FP8 || (type != ::turbomind::kFloat8_e4m3 && type != ::turbomind::kFloat8_e5m2))

#define TURBOMIND_ARCH_DTYPE_GUARD(type) (TURBOMIND_ARCH_BF16_GUARD(type) && TURBOMIND_ARCH_FP8_GUARD(type))

// Loop-unrolling hints. They expand to nothing outside device compilation and
// under IntelliSense. NVRTC and clang-CUDA get the standard _Pragma operator.
// The remaining device compilers get a `#pragma unroll` token sequence produced
// by macro expansion, which relies on that compiler accepting it.
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL _Pragma("unroll")
#define PRAGMA_UNROLL_4 _Pragma("unroll 4")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")

#else
#define PRAGMA_UNROLL #pragma unroll
#define PRAGMA_UNROLL_4 #pragma unroll 4
#define PRAGMA_NO_UNROLL #pragma unroll 1

#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_UNROLL_4
#define PRAGMA_NO_UNROLL
#endif

// Function qualifiers: force-inlined __host__/__device__ variants under a CUDA
// compiler, and plain `inline` otherwise so these headers also compile as
// ordinary C++.
#if defined(__CUDACC__)
#define TM_HOST_DEVICE __forceinline__ __host__ __device__
#define TM_DEVICE __forceinline__ __device__
#define TM_HOST __forceinline__ __host__
#else
#define TM_HOST_DEVICE inline
#define TM_DEVICE inline
#define TM_HOST inline
#endif

// Number of threads per warp assumed by the kernels in this directory.
constexpr int WARP_SIZE = 32;

// Short unsigned aliases. NOTE: `#ifndef uint` / `#ifndef ushort` only detect a
// *macro* of that name, not an existing typedef (e.g. from system headers).
// Re-declaring an alias for the same type is still valid C++.
#ifndef uint
using uint = unsigned int;
#endif

#ifndef ushort
using ushort = unsigned short int;
#endif


================================================
FILE: src/turbomind/kernels/core/data_type.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#if ENABLE_BF16
#include 
#endif

#include 

#include "src/turbomind/core/data_type.h"

namespace turbomind {

namespace detail {

// Opaque 32-bit storage word. The packed sub-byte Array specializations use it
// as their value_type and storage element.
// NOTE(review): identifiers starting with a double underscore are reserved for
// the implementation; kept as-is for compatibility.
struct __uint4_t {
    uint32_t x;
};

}  // namespace detail

// Maps an element type to the pointer type used to address it. The primary
// template yields a plain `T*`. Template arguments are missing from this copy;
// specializations (e.g. for sub-byte types) are presumably defined elsewhere --
// not visible here.
template
struct get_pointer_type_t {
    using type = T*;
};

template
using get_pointer_type = typename get_pointer_type_t::type;

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/floating_point.h
================================================
#pragma once

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/common.h"

namespace turbomind {

// Compile-time description of a small binary floating-point format with 1 sign
// bit, E exponent bits and M mantissa bits. Also provides device-side
// conversions from and to float.
// NOTE(review): the template header (presumably `<int E, int M>` or similar) is
// missing from this copy. The limits below treat every exponent value as finite
// (no Inf/NaN encodings), which matches the E2M1/E3M2/E2M3 checks after this
// struct. For E4M3 the formula would give 480, not the OCP E4M3 maximum of 448.
template
struct FloatingPoint {
    static constexpr unsigned exponent_bits = E;
    static constexpr unsigned mantissa_bits = M;
    // (2^E - 1) / 2 == 2^(E-1) - 1, the conventional IEEE-style bias.
    static constexpr unsigned exponent_bias = ((1 << exponent_bits) - 1) / 2;

    static constexpr unsigned bits = 1 + exponent_bits + mantissa_bits;

    static constexpr unsigned exponent_mask = (1 << exponent_bits) - 1;
    static constexpr unsigned mantissa_mask = (1 << mantissa_bits) - 1;

    // clang-format off
    // For `reinterpret_cast` is not constexpr yet
    static constexpr float exp2(unsigned e) { float x = 1; for (; e > 0; --e) { x *= 2; } return x; }
    // clang-format on

    // Largest finite magnitude: all-ones mantissa at the top exponent (bias + 1).
    static constexpr float max_normal =
        ((1U << (mantissa_bits + 1U)) - 1U) * exp2(exponent_bias + 1) / exp2(mantissa_bits);
    static constexpr float min_normal   = 1 / exp2(exponent_bias - 1);
    static constexpr float max_denormal = mantissa_mask / exp2(exponent_bias - 1 + mantissa_bits);
    static constexpr float min_denormal = 1 / exp2(exponent_bias - 1 + mantissa_bits);

    // Converts `x` to this format; the encoding is returned in the low `bits` bits.
    // - Rounding is round-to-nearest-even, or stochastic when `stochastic` is
    //   true. `stochastic` tests R against a type whose name is missing from
    //   this copy; in that mode `rbits` supplies the random bits.
    // - Magnitudes > max_normal saturate to the largest finite encoding.
    // - Magnitudes <= min_denormal / 2 flush to a signed zero.
    // NOTE(review): NaN is not special-cased. Every comparison is false, so NaN
    // reaches the denormal branch, where the shift amount is out of range.
    // Modified from `__nv_cvt_double_to_fp8` in 
    template
    __device__ static unsigned from_f32(float x, R rbits)
    {
        constexpr bool stochastic = std::is_same_v;

        // 1/2 LSB of the target format, positioned in single precision mantissa
        constexpr int half_ulp = 1U << (23U - mantissa_bits - 1U);

        auto absx = fabsf(x);

        unsigned xbits = __float_as_uint(x);

        // Sign moved to the top bit of the target format; exponent re-biased.
        unsigned sign     = (xbits >> 31U) << (bits - 1);
        unsigned exp      = ((xbits >> 23U) & 0xFFU) - 127U + exponent_bias;
        unsigned mantissa = (xbits >> (23U - mantissa_bits)) & mantissa_mask;

        unsigned res;

        if (absx <= min_denormal / 2.) {  // underflow
            res = 0;
        }
        else if (absx > max_normal) {  // overflow
            res = (exponent_mask << mantissa_bits) | mantissa_mask;
        }
        else if (absx >= min_normal) {  // normal
            res = (exp << mantissa_bits) | mantissa;

            unsigned round_mask = (half_ulp << 1U) - 1U;
            // rounded-off bits
            unsigned round = xbits & round_mask;
            if constexpr (stochastic) {
                // stochastic rounding (.rs) adjustment
                if (round + (rbits & round_mask) > round_mask) {
                    res += 1U;
                }
            }
            else {
                // round-to-nearest-even (.rn) adjustment
                if ((round > half_ulp) || ((round == half_ulp) && (mantissa & 1U))) {
                    res += 1U;
                }
            }
        }
        else {  // denormal
            unsigned shift = 1U - exp;
            // add implicit leading bit
            mantissa |= 1U << mantissa_bits;
            // additional round-off due to denormalization
            res = mantissa >> shift;

            unsigned round_mask = (half_ulp << (shift + 1U)) - 1U;
            // rounded-off bits, including implicit leading bit
            unsigned round = (xbits | (1U << 23U)) & round_mask;
            if constexpr (stochastic) {
                // stochastic rounding (.rs) adjustment
                if (round + (rbits & round_mask) > round_mask) {
                    res += 1U;
                }
            }
            else {
                // round-to-nearest-even (.rn) adjustment
                if ((round > (half_ulp << shift)) || ((round == (half_ulp << shift)) && (res & 1U))) {
                    res += 1U;
                }
            }
        }

        res |= sign;  // preserve sign

        return res;
    }

    // Decodes an encoding of this format to float. Sign, exponent and mantissa
    // bits are placed directly into a float bit pattern `u`, whose exponent is
    // still biased by exponent_bias instead of 127. Multiplying by the float
    // 2^(127 - exponent_bias) corrects the exponent. This also handles the
    // format's denormals, provided the multiply does not flush denormal inputs.
    __device__ static float to_f32(unsigned x)
    {
        unsigned u = (x >> (bits - 1U)) << 31U;
        u |= (x & ((1U << (bits - 1U)) - 1U)) << (23U - mantissa_bits);

        unsigned e = (127U - exponent_bias + 127U) << 23U;

        float res;
        /// ! force non-FTZ multiplication
        asm("mul.f32 %0, %1, %2;" : "=f"(res) : "r"(u), "r"(e));

        return res;
    }
};

// E2M1 (4-bit): magnitudes 0.5 ... 6.
static_assert(FloatingPoint<2, 1>::max_normal == 6);
static_assert(FloatingPoint<2, 1>::min_normal == 1);
static_assert(FloatingPoint<2, 1>::max_denormal == .5);
static_assert(FloatingPoint<2, 1>::min_denormal == .5);

// E3M2 (6-bit): magnitudes 0.0625 ... 28.
static_assert(FloatingPoint<3, 2>::max_normal == 28.0);
static_assert(FloatingPoint<3, 2>::min_normal == 0.25);
static_assert(FloatingPoint<3, 2>::max_denormal == 0.1875);
static_assert(FloatingPoint<3, 2>::min_denormal == 0.0625);

// E2M3 (6-bit): magnitudes 0.125 ... 7.5.
static_assert(FloatingPoint<2, 3>::max_normal == 7.5);
static_assert(FloatingPoint<2, 3>::min_normal == 1.0);
static_assert(FloatingPoint<2, 3>::max_denormal == 0.875);
static_assert(FloatingPoint<2, 3>::min_denormal == 0.125);

// E4M3 / E5M2 are presumably left unchecked because those formats reserve
// encodings for NaN/Inf, which the max_normal formula does not model.
// FloatingPoint<4, 3>::max_normal;
// FloatingPoint<4, 3>::min_normal;
// FloatingPoint<4, 3>::max_denormal;
// FloatingPoint<4, 3>::min_denormal;

// FloatingPoint<5, 2>::max_normal;
// FloatingPoint<5, 2>::min_normal;
// FloatingPoint<5, 2>::max_denormal;
// FloatingPoint<5, 2>::min_denormal;

// Disabled (#if 0) reference implementation of a saturating float -> E2M1
// conversion. The thresholds are the midpoints between adjacent E2M1
// magnitudes, with ties going to the even encoding. The 3-bit magnitude code is
// combined with the float's sign in bit 3. Not compiled; kept for reference.
#if 0
__device__ int cvt_rn_sat_e2m1_f32(float x)
{
    // 0000  0.0
    // 0001  0.5
    // 0010  1.0
    // 0011  1.5
    // 0100  2.0
    // 0101  3.0
    // 0110  4.0
    // 0111  6.0

    float z = fabs(x);
    //   0.25  0.75   1.25  1.75  2.5   3.5    5.0
    // 0.0   0.5   1.0   1.5   2.0   3.0   4.0   6.0
    // 0000  0001  0010  0011  0100  0101  0110  0111
    //   *           *           *           *
    auto f = [](float z) {
        if (z <= .25f) {
            return 0;
        }
        else if (z < .75f) {
            return 1;  // 0.5
        }
        else if (z <= 1.25f) {
            return 2;  // 1.0
        }
        else if (z < 1.75f) {
            return 3;  // 1.5
        }
        else if (z <= 2.5) {
            return 4;  // 2.0
        }
        else if (z < 3.5f) {
            return 5;  // 3.0
        }
        else if (z <= 5.f) {
            return 6;  // 4.0
        }
        else {
            return 7;  // 6.0
        }
    };

    return f(z) | ((__float_as_uint(x) >> 31) << 3);
}
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/layout.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/data_type.h"
namespace turbomind {

// XOR swizzle of a linear offset: offset ^ ((offset & yyy_mask) >> shift).
// NOTE(review): the template header and the integral_constant arguments that
// define bit_mask / yyy_mask / shift are missing from this copy, so the actual
// mask and shift values cannot be read here -- confirm against upstream.
template
struct Swizzle {

    using bit_mask = std::integral_constant;
    using yyy_mask = std::integral_constant;
    using shift    = std::integral_constant;

    template
    __host__ __device__ constexpr static auto apply(Offset offset)
    {
        return offset ^ ((offset & yyy_mask{}) >> shift{});
    }

    template
    __host__ __device__ constexpr auto operator()(Offset offset)
    {
        return apply(offset);
    }
};

// No-op swizzle: offsets pass through unchanged. AdvanceS ignores s0/s1 and
// returns `offset` as-is. Member-template headers are missing from this copy.
struct Identity {

    template
    __device__ constexpr static auto apply(Offset offset)
    {
        return offset;
    }

    template
    __device__ Offset operator()(Offset offset)
    {
        return apply(offset);
    }

    template
    __device__ int AdvanceS(int offset, int s0, int s1)
    {
        return offset;
    }
};

// Blocked 2-D shared-memory layout for an S x C index space.
// - The space is split into S0 x C0 tiles. A negative S0_/C0_ means "use the
//   whole extent".
// - Within a tile, the index s0 * C0 + c0 is row-major and passed through the
//   Swizzle.
// - Tiles are ordered row-major over the S1 x C1 tile grid, kSize0 elements
//   each.
// The template header is missing from this copy.
template
struct SmemLayoutV2 {

    // (C0,S0),(   C1,       S1)
    // ( 1,C0),(C0*S0, C0*S0*C1)

    static constexpr int S = S_;
    static constexpr int C = C_;

    static constexpr int S0 = S0_ < 0 ? S : S0_;
    static constexpr int C0 = C0_ < 0 ? C : C0_;

    static_assert(S % S0 == 0);
    static_assert(C % C0 == 0);

    static constexpr int S1 = S / S0;
    static constexpr int C1 = C / C0;

    static constexpr int kSize = S * C;

    static constexpr int kSize0 = S0 * C0;
    static constexpr int kSize1 = S1 * C1;

    using Swizzle = Swizzle_;

    // True when there is a single tile and no swizzle (plain row-major layout).
    static constexpr int kIsTrivial = S == S0 && C == C0 && std::is_same_v;

    // Linear element offset of (s, c), plus the caller-supplied `offset`.
    __forceinline__ __device__ static int apply(int s, int c, int offset = 0)
    {
        int s1 = s / S0;
        int s0 = s % S0;
        int c1 = c / C0;
        int c0 = c % C0;
        //            variable             | uniform |         constant
        // return Swizzle::apply(s0 * C0 + c0) + offset + (s1 * C1 + c1) * kSize0;

        // return offset + Swizzle::apply(s0 * C0 + c0) + (s1 * C1 + c1) * kSize0;

        return Swizzle::apply(s0 * C0 + c0) + (s1 * C1 + c1) * kSize0 + offset;
    }

    __forceinline__ __device__ int operator()(int s, int c, int offset = 0)
    {
        return apply(s, c, offset);
    }
};

// Wrapper around an int offset. The constructor is explicit, so a plain int
// never converts to Offset implicitly. Calling the object returns a reference
// to the stored value.
struct Offset {
    int value_;  // public; callers may also access it directly

    __device__ explicit Offset(int v): value_{v} {}

    // Read-only access for const objects.
    __device__ const int& operator()() const
    {
        return value_;
    }

    // Mutable access, e.g. `off() += n` updates the stored offset in place.
    __device__ int& operator()()
    {
        return value_;
    }
};

// Pairs a pointer with a Layout: operator()(s, c[, offset]) indexes ptr_ at
// layout_(s, c[, offset]), and operator()(idx) indexes ptr_ directly.
// layout_ is default-constructed. The template header (presumably element type
// T and a Layout type) is missing from this copy.
template
struct SmemAccessor {
    using Pointer = get_pointer_type;
    Pointer ptr_;
    Layout  layout_;

    __device__ SmemAccessor(Pointer ptr): ptr_{ptr} {}

    __device__ T& operator()(int s, int c)
    {
        return ptr_[layout_(s, c)];
    }

    __device__ T& operator()(int s, int c, int offset)
    {
        return ptr_[layout_(s, c, offset)];
    }

    __device__ T& operator()(int idx)
    {
        return ptr_[idx];
    }
};

// Linear 2-D index map: (i0, i1) -> v0 * i0 + v1 * i1. The user-declared
// constructor allows class template argument deduction (see the "CTAD" note).
// Template headers are missing from this copy.
template
struct Stride {
    T0 v0;
    T1 v1;

    // CTAD
    __host__ __device__ Stride(T0 v0, T1 v1): v0{v0}, v1{v1} {}

    template
    __host__ __device__ constexpr auto operator()(I0 i0, I1 i1) const
    {
        return v0 * i0 + v1 * i1;
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/math.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include 
#include 
#include 

namespace turbomind {

// Integer ceiling division ceil(a / b), for a >= 0 and b > 0. Assumes
// a + b - 1 does not overflow T; not checked.
template
TM_HOST_DEVICE constexpr T ceil_div(T a, T b)
{
    return (a + b - 1) / b;
}

// Identical to ceil_div (short alias).
template
TM_HOST_DEVICE constexpr T cdiv(T a, T b)
{
    return (a + b - 1) / b;
}

// Smallest multiple of b that is >= a, for a >= 0 and b > 0.
template
TM_HOST_DEVICE constexpr T round_up(T a, T b)
{
    return (a + b - 1) / b * b;
}

// floor(log2(x)) for x >= 1, computed by repeated halving.
// NOTE(review): for x <= 0 the loop never reaches 1 and does not terminate (a
// compile error in constant evaluation, a hang at run time). Callers must pass
// x >= 1.
template
TM_HOST_DEVICE constexpr T log2(T x)
{
    T n = 0;
    while (x != 1) {
        x /= 2;
        ++n;
    }
    return n;
}

// static_assert(log2(65536) == 16);
// static_assert(log2(32) == 5);
// static_assert(log2(1) == 0);

// Isolates the lowest set bit of x using the two's-complement identity s & -s,
// computed in the signed counterpart of T. Returns 0 for x == 0.
// NOTE(review): if s is the minimum value of the signed type (e.g. x ==
// 0x80000000 for 32-bit T), -s overflows -- confirm callers never pass that.
template
TM_HOST_DEVICE constexpr T lowbit(T x)
{
    const std::make_signed_t s = x;
    return static_cast(s & -s);
}

// https://arxiv.org/abs/1902.01961
// Division/modulo by a runtime-constant divisor using one precomputed
// multiplier. The primary template is empty; only the 16-bit specialization
// below is usable.
template
struct FastDivMod {
};

// 16-bit specialization (its argument, presumably uint16_t, is missing from
// this copy).
// - c_ = ceil(2^32 / d).
// - a / d is computed as (a * c_) >> 32.
// - a % d is computed as ((low 32 bits of a * c_) * d) >> 32.
// Both are exact when a and d fit in 16 bits.
// NOTE(review): for d == 1, `0xFFFFFFFF / d + 1` wraps to 0 in 32-bit unsigned
// arithmetic, so c_ == 0 and `a / FastDivMod{1}` returns 0 instead of a
// (a % 1 is still 0). d == 0 divides by zero. Fixing d == 1 needs a wider c_,
// which changes the struct layout -- confirm the allowed divisor range.
template<>
struct FastDivMod {
    uint32_t c_;  // cdiv(2^32,d) = (2^32+d-1)/d = (2^32-1)/d+1
    uint32_t d_;

    TM_HOST_DEVICE constexpr FastDivMod(uint16_t d): c_{0xFFFFFFFF / d + 1}, d_{d} {}

    template
    TM_HOST_DEVICE friend constexpr uint16_t operator/(T a, FastDivMod b)
    {
        return (a * (uint64_t)b.c_) >> 32;
    }

    template
    TM_HOST_DEVICE friend constexpr uint16_t operator%(T a, FastDivMod b)
    {
        uint64_t lowbits = (a * (uint64_t)b.c_) & 0xFFFFFFFF;
        return (lowbits * b.d_) >> 32;
    }

    // Implicit conversion back to the divisor.
    TM_HOST_DEVICE constexpr operator uint16_t() const noexcept
    {
        return d_;
    }
};

// Compile-time sanity checks: 32 / 5 == 6 and 32 % 5 == 2.
static_assert(32 / FastDivMod{5} == 6);
static_assert(32 % FastDivMod{5} == 2);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/meta.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

// Small metaprogramming helpers. All template headers are missing from this
// copy of the file.

// Type tag carrying T as `type`.
template
struct basic_type {
    using type = T;
};

// Variable-template instance of the type tag (presumably `type_c<T>`).
template
constexpr basic_type type_c{};

// Compile-time value wrapper, like std::integral_constant but with the value
// type deduced from `v` (presumably an `auto` non-type parameter).
template
struct constant {
    using type       = constant;
    using value_type = decltype(v);

    static constexpr value_type value = v;

    constexpr value_type operator()() const noexcept
    {
        return v;
    }
    constexpr operator value_type() const noexcept
    {
        return v;
    }
};

// Empty compile-time pair of values; first()/second() recover them.
template
struct pair {
};

template
constexpr auto first(pair)
{
    return u;
}

template
constexpr auto second(pair)
{
    return v;
}

// Empty compile-time triple of values.
template
struct triplet {
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/mma.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include 

namespace turbomind {

// d = a * b + c using the SM70 `mma.sync.aligned.m8n8k4` instruction with f16
// inputs and f32 accumulators (A row-major, B column-major). Per thread: a and
// b each supply two 32-bit registers (4 halves); c and d hold 8 floats.
// NOTE(review): unlike the m16n8k8 wrappers there is no #else assert, so on
// pre-SM70 targets this silently does nothing and `d` is left unwritten.
__inline__ __device__ void
mma_m8n8k4_row_col(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM70
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    // clang-format off
    asm volatile(
        "mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32"
        "{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10, %11},"
        "{%12, %13, %14, %15, %16, %17, %18, %19};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3]), "=f"(d[4]), "=f"(d[5]), "=f"(d[6]), "=f"(d[7])
        : "r"(A[0]), "r"(A[1]),
          "r"(B[0]), "r"(B[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]));
// clang-format on
#endif
}

// Same as mma_m8n8k4_row_col but with B in row-major layout (`.row.row`).
// Same NOTE(review) about the missing #else applies.
__inline__ __device__ void
mma_m8n8k4_row_row(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM70
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    // clang-format off
    asm volatile(
        "mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32"
        "{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10, %11},"
        "{%12, %13, %14, %15, %16, %17, %18, %19};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3]), "=f"(d[4]), "=f"(d[5]), "=f"(d[6]), "=f"(d[7])
        : "r"(A[0]), "r"(A[1]),
          "r"(B[0]), "r"(B[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]));
// clang-format on
#endif
}

// d = a * b + c using `mma.sync.aligned.m16n8k8` (SM75+) with f16 inputs and
// f32 accumulators. Per thread:
// - a: two 32-bit registers (4 halves);
// - b: one register (2 halves);
// - c and d: 4 floats each.
// `c` is only read despite being a non-const reference. On older targets the
// assert fires instead.
__inline__ __device__ void
mma_m16n8k8_row_col(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM75
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    float const*    C = reinterpret_cast(&c);
    float*          D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, "
                 "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
                 : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
                 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16`: f16 inputs with f16
// accumulation. Per thread: `a` is two 32-bit registers, `b` is one, and `c` / `d` are two
// 32-bit registers each (presumably two packed halves per register).
// Requires TURBOMIND_ARCH_SM75; on other builds it asserts.
__inline__ __device__ void
mma_m16n8k8_row_col(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM75
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    uint32_t const* C = reinterpret_cast(&c);
    uint32_t*       D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16  {%0,%1}, "
                 "{%2,%3}, {%4}, {%5,%6};\n"
                 : "=r"(D[0]), "=r"(D[1])
                 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32`: bf16 inputs with f32
// accumulators. Per thread: `a` is two 32-bit registers, `b` is one, and `c` / `d` are four
// floats each.
// Requires TURBOMIND_ARCH_SM80; on other builds it asserts.
__inline__ __device__ void mma_m16n8k8_row_col(Array&             d,
                                               const Array& a,
                                               const Array& b,
                                               Array&             c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    float const*    C = reinterpret_cast(&c);
    float*          D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32  {%0,%1,%2,%3}, "
                 "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
                 : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
                 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    assert(TURBOMIND_ARCH_SM80);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k8.row.col.bf16.bf16.bf16.bf16`: bf16 inputs with
// bf16 accumulation. Per thread: `a` is two 32-bit registers, `b` is one, and `c` / `d` are
// two 32-bit registers each (presumably two packed bf16 values per register).
// Requires TURBOMIND_ARCH_SM80; on other builds it asserts.
// NOTE(review): the PTX ISA restricts the bf16 mma variants to an f32 accumulator, so this
// instruction string may be rejected by ptxas. Verify that it assembles.
__inline__ __device__ void mma_m16n8k8_row_col(Array&       d,
                                               const Array& a,
                                               const Array& b,
                                               Array&       c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    uint32_t const* C = reinterpret_cast(&c);
    uint32_t*       D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k8.row.col.bf16.bf16.bf16.bf16  {%0,%1}, "
                 "{%2,%3}, {%4}, {%5,%6};\n"
                 : "=r"(D[0]), "=r"(D[1])
                 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
#else
    assert(TURBOMIND_ARCH_SM80);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`: d = a * b + c with f16
// inputs and f32 accumulators. Per thread: `a` is four 32-bit registers, `b` is two, and
// `c` / `d` are four floats each.
// With TURBOMIND_ARCH_SM80 the native k16 instruction is used. Otherwise the K=16 product is
// emulated with two m16n8k8 MMAs:
//   - the first consumes the first halves of `a` and `b` and accumulates onto `c`;
//   - the second consumes the second halves and accumulates onto the result already in `d`.
// The f16 m16n8k8 wrapper itself requires SM75.
__inline__ __device__ void
mma_m16n8k16_row_col(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    float const*    C = reinterpret_cast(&c);
    float*          D = reinterpret_cast(&d);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, "
        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    // View `a` and `b` as two half-sized fragments each (element types not visible in this view).
    const Array* _a = (const Array*)&a;
    const Array* _b = (const Array*)&b;
    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16`: f16 inputs with f16
// accumulation. Per thread: `a` is four 32-bit registers, `b` is two, and `c` / `d` are two
// 32-bit registers each.
// Without TURBOMIND_ARCH_SM80 it is emulated with two f16 m16n8k8 MMAs (first halves onto
// `c`, then second halves onto `d`); that wrapper itself requires SM75.
__inline__ __device__ void
mma_m16n8k16_row_col(Array& d, const Array& a, const Array& b, Array& c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    uint32_t const* C = reinterpret_cast(&c);
    uint32_t*       D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16  {%0,%1}, "
                 "{%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
                 : "=r"(D[0]), "=r"(D[1])
                 : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
#else
    // View `a` and `b` as two half-sized fragments each (element types not visible in this view).
    const Array* _a = (const Array*)&a;
    const Array* _b = (const Array*)&b;
    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`: bf16 inputs with f32
// accumulators. Per thread: `a` is four 32-bit registers, `b` is two, and `c` / `d` are four
// floats each.
// The non-SM80 path splits the product into two m16n8k8 MMAs. Presumably overload resolution
// picks the bf16 k8 overloads, which themselves require SM80, so on older architectures this
// path ends in their assert — TODO confirm.
__inline__ __device__ void mma_m16n8k16_row_col(Array&             d,
                                                const Array& a,
                                                const Array& b,
                                                Array&             c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    float const*    C = reinterpret_cast(&c);
    float*          D = reinterpret_cast(&d);
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32  {%0,%1,%2,%3}, "
        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    // View `a` and `b` as two half-sized fragments each (element types not visible in this view).
    const Array* _a = (const Array*)&a;
    const Array* _b = (const Array*)&b;
    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}

// Warp-level MMA `mma.sync.aligned.m16n8k16.row.col.bf16.bf16.bf16.bf16`: bf16 inputs with
// bf16 accumulation. Per thread: `a` is four 32-bit registers, `b` is two, and `c` / `d` are
// two 32-bit registers each.
// The non-SM80 fallback splits the product into two m16n8k8 MMAs (first halves onto `c`, then
// second halves onto `d`).
// NOTE(review): the PTX ISA restricts the bf16 mma variants to an f32 accumulator, so this
// instruction string may be rejected by ptxas. Verify that it assembles.
__inline__ __device__ void mma_m16n8k16_row_col(Array&       d,
                                                const Array& a,
                                                const Array& b,
                                                Array&       c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast(&a);
    uint32_t const* B = reinterpret_cast(&b);
    uint32_t const* C = reinterpret_cast(&c);
    uint32_t*       D = reinterpret_cast(&d);
    asm volatile("mma.sync.aligned.m16n8k16.row.col.bf16.bf16.bf16.bf16  {%0,%1}, "
                 "{%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
                 : "=r"(D[0]), "=r"(D[1])
                 : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
#else
    // View `a` and `b` as two half-sized fragments each (element types not visible in this view).
    const Array* _a = (const Array*)&a;
    const Array* _b = (const Array*)&b;
    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/pipe_iter.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

template
struct PipeIter {
    // Cyclic cursor over `Stages` positions spaced `Step` apart (presumably software-pipeline
    // stage buffers). `r` starts at position 0 and `w` at the last position. Each increment
    // hands the old `r` to `w` and advances `r` by one position, wrapping to 0 when it reaches
    // Stages * Step.
    static constexpr int kMaxStep = Stages * Step;

    int r = 0;
    int w = kMaxStep - Step;

    __inline__ __device__ PipeIter& operator++()
    {
        const int next = r + Step;
        w              = r;
        r              = (next == kMaxStep) ? 0 : next;
        return *this;
    }
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/smem.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include 

namespace turbomind {

// Convert a generic pointer into a 32-bit shared-state-space address, the operand form taken
// by `.shared` PTX instructions such as ldmatrix. Presumably `ptr` must point into shared
// memory.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    return static_cast<uint32_t>(smem_addr);
}

// Warp-collective `ldmatrix.sync.aligned.m8n8.x4.shared.b16`: loads four 8x8 matrices of
// 16-bit elements from shared memory. Each thread receives one 32-bit register per matrix
// (d0..d3). `smem_int_ptr` is a 32-bit shared-space address (see cast_smem_ptr_to_uint).
// Requires TURBOMIND_ARCH_SM75; on other builds it asserts.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Transposing variant of ldmatrix_m8n8_x4_b16 (`ldmatrix...m8n8.x4.trans.shared.b16`): loads
// four 8x8 b16 matrices from shared memory, transposed. Requires TURBOMIND_ARCH_SM75;
// otherwise it asserts.
__inline__ __device__ void ldsm_x4_trans(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Warp-collective `ldmatrix.sync.aligned.m8n8.x2.shared.b16`: loads two 8x8 b16 matrices from
// shared memory, one 32-bit register per matrix per thread. Requires TURBOMIND_ARCH_SM75;
// otherwise it asserts.
__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Transposing variant of ldmatrix_m8n8_x2_b16 (`ldmatrix...m8n8.x2.trans.shared.b16`).
// Requires TURBOMIND_ARCH_SM75; otherwise it asserts.
__inline__ __device__ void ldsm_x2_trans(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0,%1}, [%2];\n"
                 : "=r"(d0), "=r"(d1)
                 : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Warp-collective `ldmatrix.sync.aligned.m8n8.x1.shared.b16`: loads one 8x8 b16 matrix from
// shared memory into one 32-bit register per thread. Requires TURBOMIND_ARCH_SM75; otherwise
// it asserts.
__inline__ __device__ void ldmatrix_m8n8_x1_b16(uint& d0, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 %0, [%1];\n" : "=r"(d0) : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Transposing variant of ldmatrix_m8n8_x1_b16 (`ldmatrix...m8n8.x1.trans.shared.b16`).
// Requires TURBOMIND_ARCH_SM75; otherwise it asserts.
__inline__ __device__ void ldsm_x1_trans(uint& d0, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 %0, [%1];\n" : "=r"(d0) : "r"(smem_int_ptr));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

// Array-based convenience overloads. Each forwards the elements d[0..N-1] of the destination
// fragment to the matching scalar ldmatrix wrapper (x4 / x2 / x1, with or without .trans).
__inline__ __device__ void ldsm_x4(Array& d, uint32_t smem_int_ptr)
{
    ldmatrix_m8n8_x4_b16(d[0], d[1], d[2], d[3], smem_int_ptr);
}

__inline__ __device__ void ldsm_x2(Array& d, uint32_t smem_int_ptr)
{
    ldmatrix_m8n8_x2_b16(d[0], d[1], smem_int_ptr);
}

__inline__ __device__ void ldsm_x1(Array& d, uint32_t smem_int_ptr)
{
    ldmatrix_m8n8_x1_b16(d[0], smem_int_ptr);
}

// Transposing variants.
__inline__ __device__ void ldsm_x4_trans(Array& d, uint32_t smem_int_ptr)
{
    ldsm_x4_trans(d[0], d[1], d[2], d[3], smem_int_ptr);
}

__inline__ __device__ void ldsm_x2_trans(Array& d, uint32_t smem_int_ptr)
{
    ldsm_x2_trans(d[0], d[1], smem_int_ptr);
}

__inline__ __device__ void ldsm_x1_trans(Array& d, uint32_t smem_int_ptr)
{
    ldsm_x1_trans(d[0], smem_int_ptr);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/sub_byte_ptr.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/data_type.h"

namespace turbomind {

// Pointer wrapper that stores a raw byte pointer (`char* ptr_`). Indexing and `+` scale the
// element index by a ratio of bit widths (`i * bitsof<...> / bitsof<...>`). The template
// arguments are not visible in this view; presumably the ratio is element bits over char bits,
// so sub-byte element types advance by fractional bytes — TODO confirm.
template
struct SubBytePtr {

    // NOTE(review): defaulted, so `ptr_` is left uninitialized.
    constexpr SubBytePtr() = default;

    // Explicit conversion from a typed pointer.
    constexpr __host__ __device__ explicit SubBytePtr(T* ptr): ptr_((char*)ptr) {}

    // Implicit conversion from a raw byte pointer.
    constexpr __host__ __device__ SubBytePtr(char* ptr): ptr_(ptr) {}

    // Integer division in the offset: when the element is narrower than a byte, several
    // consecutive `i` map to the same byte address.
    __host__ __device__ T& operator[](int i)
    {
        return *reinterpret_cast(ptr_ + i * bitsof / bitsof);
    }

    friend __host__ __device__ SubBytePtr operator+(const SubBytePtr a, int n)
    {
        return SubBytePtr{a.ptr_ + n * bitsof / bitsof};
    }

    friend __host__ __device__ SubBytePtr operator+(int n, const SubBytePtr a)
    {
        return a + n;
    }

    // Compares byte addresses only.
    friend __host__ __device__ bool operator==(const SubBytePtr& a, const SubBytePtr& b)
    {
        return a.ptr_ == b.ptr_;
    }

    __host__ __device__ explicit operator T*() const
    {
        return (T*)ptr_;
    }

    char* ptr_;
};

// Partial specialization of get_pointer_type_t that selects SubBytePtr as the pointer type
// when a bit width is not a multiple of 8 (`... % 8 != 0`). The rest of the condition is not
// visible in this view; presumably it is the element's bit width — TODO confirm.
template
struct get_pointer_type_t % 8 != 0>> {
    using type = SubBytePtr;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/sync.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

// Read the semaphore word at `lock` when `pred` is set; otherwise return 0 without touching
// memory. On sm_70+ this is an acquire load at GPU scope (`ld.global.acquire.gpu`). Older
// architectures fall back to `ld.global.cg`, which caches at L2 and bypasses L1.
__inline__ __device__ int sem_fetch(int* lock, bool pred)
{
    int state{};
    if (pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
        asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
    }
    return state;
}

// Block-wide spin-wait until the polled value at `lock` equals `status`.
// Only threads with `pred` set actually poll (sem_fetch yields 0 for the others). The loop
// continues while every thread's last observed value differs from `status`, so a single
// polling thread seeing `status` releases the whole block.
// Collective (__syncthreads_and): every thread of the block must call it.
// The first check uses the initial value 0, so status == 0 returns without polling.
__inline__ __device__ void sem_wait(int* lock, int status, bool pred)
{
    for (int observed = 0; __syncthreads_and(observed != status);) {
        observed = sem_fetch(lock, pred);
    }

    __syncthreads();  // memory fence
}

// Block-wide spin-wait until exactly `count` threads of the block hold a non-zero value read
// through their own `lock` pointer. Only threads with `pred` set poll; the others contribute 0.
// Collective (__syncthreads_count): every thread of the block must call it.
__inline__ __device__ void sem_wait_many(int* lock, int count, bool pred)
{
    int observed = 0;
    while (true) {
        if (__syncthreads_count(observed) == count) {
            break;
        }
        observed = sem_fetch(lock, pred);
    }

    __syncthreads();  // memory fence
}

// Signal a semaphore. A block-wide barrier runs first, then threads with `pred` set store
// `status` to `lock`. On sm_70+ the store is a release at GPU scope
// (`st.global.release.gpu`); older architectures use `st.global.cg`.
// Every thread of the block must reach this because of the __syncthreads().
__inline__ __device__ void sem_post(int* lock, int status, bool pred)
{
    __syncthreads();  // memory fence

    if (pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
        asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/core/thread_map.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"

#include 

namespace turbomind {

// Thread map over a (C, S) tile. Each thread accesses kAccessC consecutive elements along C.
// One warp spans the full C extent with C / kAccessC lanes, and the remaining lanes of the
// warp cover kWarpThreadS positions along S. Warps are stacked only along S (kWarpC == 1),
// and each thread iterates kIterS (at least 1) times along S, stepping kDeltaS.
// NOTE(review): kIterS uses floor division (vs cdiv in RakedThreadMap). If S is not a multiple
// of kWarpAccessS * kWarpS, trailing rows are not covered — confirm callers guarantee this.
template
struct ThreadMapQ {
    static constexpr int kWarpCount = WarpCount;
    static constexpr int kAccessC   = AccessC;

    static constexpr int kWarpThreadC = C / kAccessC;
    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;

    static_assert(kWarpThreadC <= WARP_SIZE);

    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;  // C
    static constexpr int kWarpAccessS = kWarpThreadS;

    static constexpr int kWarpIterC = C / kWarpAccessC;  // 1
    static constexpr int kWarpIterS = S / kWarpAccessS;

    static constexpr int kWarpC = 1;
    static constexpr int kWarpS = kWarpCount;

    static constexpr int kIterC = kWarpIterC / kWarpC;  // 1
    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);

    static constexpr int kFootprintC = kWarpAccessC * kIterC;  // C
    static constexpr int kFootprintS = kWarpAccessS * kIterS;

    static constexpr int kDeltaC = kWarpAccessC;
    static constexpr int kDeltaS = kWarpAccessS;

    // Starting (c, s) coordinate of this thread within the CTA tile. Each warp owns a
    // contiguous kFootprintS-long band along S.
    __device__ static int2 get_offset(int warp_id, int lane_id)
    {
        int warp_offset_c = warp_id % kWarpC;
        int warp_offset_s = warp_id / kWarpC;

        int warp_thread_offset_c = lane_id % kWarpThreadC;
        int warp_thread_offset_s = lane_id / kWarpThreadC;

        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;

        return {cta_thread_offset_c, cta_thread_offset_s};
    }
};

// Thread map over a (kDimC, kDimS) tile.
// - Each thread accesses kAccessC consecutive elements along C.
// - A warp is laid out as kWarpThreadC x kWarpThreadS lanes.
// - Warps are arranged as kWarpC x kWarpS.
// - Per-thread iteration counts along C/S are ceil-divided, so the footprint may exceed the
//   tile. A partial C extent is allowed only when there is a single C iteration; that case is
//   flagged by kPartialC.
template
struct RakedThreadMap {
    static constexpr int kDimC = DimC;
    static constexpr int kDimS = DimS;

    static constexpr int kWarpCount = WarpCount;
    static constexpr int kAccessC   = AccessC;

    static constexpr int kWarpThreadC = WarpThreadC;
    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;

    static_assert(WARP_SIZE % kWarpThreadC == 0);

    // Elements covered by one warp in a single access step.
    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
    static constexpr int kWarpAccessS = kWarpThreadS;

    // Steps needed for one warp to cover the whole tile.
    static constexpr int kWarpIterC = cdiv(kDimC, kWarpAccessC);
    static constexpr int kWarpIterS = cdiv(kDimS, kWarpAccessS);

    static constexpr int kWarpC = WarpC;
    static constexpr int kWarpS = kWarpCount / kWarpC;

    static_assert(kWarpCount % kWarpC == 0);

    // Steps per thread once the work is split across warps.
    static constexpr int kIterC = cdiv(kWarpIterC, kWarpC);
    static constexpr int kIterS = cdiv(kWarpIterS, kWarpS);

    // Allow partial tile when there is ONLY 1 iteration
    static_assert(kDimC % kWarpAccessC == 0 || kIterC == 1);

    static constexpr bool kPartialC = kDimC % kWarpAccessC != 0;

    // Extent owned by one warp.
    static constexpr int kFootprintC = kWarpAccessC * kIterC;
    static constexpr int kFootprintS = kWarpAccessS * kIterS;

    // Stride between a thread's successive iterations.
    static constexpr int kDeltaC = kWarpAccessC;
    static constexpr int kDeltaS = kWarpAccessS;

    // static constexpr int kDeltaC = kWarpAccessC * kWarpC;
    // static constexpr int kDeltaS = kWarpAccessS * kWarpS;

    // Starting (c, s) coordinate of this thread within the CTA tile. Warp bases are offset by
    // whole footprints, i.e. each warp owns a contiguous sub-tile (the commented-out lines
    // below are an alternative that interleaves warps by kWarpAccess*).
    __device__ static int2 get_offset(int warp_id, int lane_id)
    {
        int warp_offset_c = warp_id % kWarpC;
        int warp_offset_s = warp_id / kWarpC;

        int warp_thread_offset_c = lane_id % kWarpThreadC;
        int warp_thread_offset_s = lane_id / kWarpThreadC;

        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;

        // int cta_thread_offset_c = kWarpAccessC * warp_offset_c + warp_thread_offset_c * kAccessC;
        // int cta_thread_offset_s = kWarpAccessS * warp_offset_s + warp_thread_offset_s;

        return {cta_thread_offset_c, cta_thread_offset_s};
    }
};

// Host-side debug helper in an unnamed namespace, so each translation unit that includes this
// header gets its own copy.
namespace {

// Dumps a thread map's compile-time constants to std::cout. It reads kDimC, kDimS and
// kPartialC, which RakedThreadMap defines but ThreadMapQ does not, so it only instantiates for
// RakedThreadMap-like maps.
template
void Print(TMap)
{
    std::cout << "     warps: " << TMap::kWarpCount << "\n";
    std::cout << "     shape: (" << TMap::kDimC << ", " << TMap::kDimS << ")\n";
    std::cout << "    access: (" << TMap::kAccessC << ", " << 1 << ")\n";
    std::cout << "warpThread: (" << TMap::kWarpThreadC << ", " << TMap::kWarpThreadS << ")\n";
    std::cout << "warpAccess: (" << TMap::kWarpAccessC << ", " << TMap::kWarpAccessS << ")\n";
    std::cout << "  warpIter: (" << TMap::kWarpIterC << ", " << TMap::kWarpIterS << ")\n";
    std::cout << "      warp: (" << TMap::kWarpC << ", " << TMap::kWarpS << ")\n";
    std::cout << "      iter: (" << TMap::kIterC << ", " << TMap::kIterS << ")\n";
    std::cout << " footprint: (" << TMap::kFootprintC << ", " << TMap::kFootprintS << ")\n";
    std::cout << "     delta: (" << TMap::kDeltaC << ", " << TMap::kDeltaS << ")\n";
    std::cout << "  partialC: " << TMap::kPartialC << "\n";
}

}  // namespace

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/decoding_kernels.cu
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Embedding lookup fused with an absolute position encoding.
//
// Writes local_token_num * hidden_units elements of `from_tensor`, treating the flat index as
// (row, col) = (index / hidden_units, index % hidden_units):
//   from_tensor[index] = embedding_table[all_ids[id_offset + row] * hidden_units + col] * scale
//                        + position_encoding[step_offset * hidden_units + col]
// where:
//   - id_offset = step * token_num + ite * local_token_num;
//   - step_offset starts at `step` and is reduced by padding_count[row] (if padding_count is
//     non-null), or else by max_input_length - input_lengths[row] (if input_lengths is non-null).
// `all_ids` is dereferenced unconditionally. This kernel has no prompt-tuning path;
// invokeEmbeddingLookupPosEncoding rejects prompt-tuning parameters before launching it.
template
__global__ void embeddingLookupPosEncoding(T*            from_tensor,
                                           const T*      embedding_table,
                                           const T*      position_encoding,
                                           const int*    all_ids,
                                           const int*    padding_count,
                                           const int*    input_lengths,
                                           const int     local_token_num,
                                           const int64_t hidden_units,
                                           const int     step,
                                           const int     max_input_length,
                                           const int     token_num,
                                           const int     ite,
                                           const T       scale)
{
    // 1. lookup from embedding table
    // 2. multiply scale
    // 3. add the position encoding
    const int id_offset = step * token_num + ite * local_token_num;

    const bool use_padding_count = padding_count != nullptr;
    const bool use_input_len     = input_lengths != nullptr;

    // Grid-stride loop over every output element.
    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
         index += blockDim.x * gridDim.x) {
        const int row_index = index / hidden_units;
        const int col_index = index % hidden_units;
        // Kept in 64 bits: the row offset `step_offset * hidden_units` into position_encoding
        // can exceed INT_MAX, and an `int` here would silently truncate it on the `*=` below.
        int64_t step_offset = step;
        if (use_padding_count) {
            step_offset -= padding_count[row_index];
        }
        else if (use_input_len) {
            step_offset -= max_input_length - input_lengths[row_index];
        }
        step_offset *= hidden_units;

        T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale;
        val   = val + position_encoding[step_offset + col_index];

        from_tensor[index] = val;
    }
}

// No absolute position embedding
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
//
// Writes local_token_num * hidden_units elements of `from_tensor` (row = index / hidden_units,
// col = index % hidden_units) as `embedding * scale`. The token id is
// all_ids[id_offset + row], or the row itself when `all_ids` is null.
// If PROMPT_SRC > 0 and the id is at or above p_prompt_tuning_id_start, the embedding comes
// from the prompt tables, indexed by batch = row / seq_len and prompt_id:
//   - PROMPT_SRC == 1: the per-batch loaded table;
//   - PROMPT_SRC == 2: request_prompt_embedding.
// Otherwise it is read from embedding_table.
template
__global__ void embeddingLookup(T*                    from_tensor,
                                const T*              embedding_table,
                                const int*            all_ids,
                                pPromptTuningParam prompt_param,
                                const int             local_token_num,
                                const int64_t         hidden_units,
                                const int             step,
                                const int             token_num,
                                const int             ite,
                                const int             seq_len,
                                const T               scale)
{
    // 1. lookup from embedding table
    // 2. multiply scale
    const int id_offset = step * token_num + ite * local_token_num;

    // Grid-stride loop over every output element.
    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
         index += blockDim.x * gridDim.x) {

        const int word_index     = index / hidden_units;
        const int word_index_row = word_index / seq_len;  // batch_id
        const int col_index      = index % hidden_units;
        const int input_id       = all_ids == nullptr ? word_index : all_ids[id_offset + word_index];
        const int prompt_id      = input_id - prompt_param.p_prompt_tuning_id_start;
        T         embedding      = (T)0.0f;
        if (PROMPT_SRC > 0 && prompt_id >= 0) {
            if (PROMPT_SRC == 1) {
                // from loaded prompt embedding tables
                embedding =
                    prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index];
            }
            else {
                // from request prompt embedding
                embedding =
                    prompt_param
                        .request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units
                                                  + prompt_id * hidden_units + col_index];
            }
        }
        else {
            embedding = embedding_table[input_id * hidden_units + col_index];
        }
        from_tensor[index] = embedding * scale;
    }
}

// Launches embeddingLookup specialized on PROMPT_SRC. It is expanded inside
// invokeEmbeddingLookupPosEncoding and refers to that function's locals (from_tensor, ...,
// scale) rather than to macro parameters. The launch configuration between the angle brackets
// is not visible in this view.
#define EMBEDDING_LOOKUP(PROMPT_SRC)                                                                                   \
    embeddingLookup<<>>(from_tensor,                                            \
                                                               embedding_table,                                        \
                                                               all_ids,                                                \
                                                               prompt_param,                                           \
                                                               local_token_num,                                        \
                                                               hidden_units,                                           \
                                                               step,                                                   \
                                                               token_num,                                              \
                                                               ite,                                                    \
                                                               seq_len,                                                \
                                                               scale);

/* Adapter function for invokeEmbeddingLookupPosEncoding{PadCount,InputLen} */
// When `position_encoding` is non-null, launches embeddingLookupPosEncoding; prompt tuning must
// be disabled in that case, otherwise FT_CHECK_WITH_INFO fails. Otherwise launches
// embeddingLookup specialized on the prompt source:
//   - 2 if use_request_p_prompt_embedding;
//   - else 1 if p_prompt_tuning_batch_weights is non-null;
//   - else 0.
// Launch shape: one block per token (capped at 65536 blocks) and min(hidden_units, 1024)
// threads; both kernels grid-stride over any remaining elements. Empty input (no tokens or a
// zero hidden size) launches nothing.
template
void invokeEmbeddingLookupPosEncoding(T*                    from_tensor,
                                      const T*              embedding_table,
                                      const T*              position_encoding,
                                      const int*            all_ids,
                                      const int*            padding_count,
                                      const int*            input_lengths,
                                      pPromptTuningParam prompt_param,
                                      const int             local_token_num,
                                      const int             hidden_units,
                                      const T               scale,
                                      const int             step,
                                      const int             max_input_length,
                                      const int             token_num,
                                      const int             ite,
                                      const int             seq_len,
                                      cudaStream_t          stream)
{
    if (position_encoding != nullptr) {
        FT_CHECK_WITH_INFO(prompt_param.use_request_p_prompt_embedding == false
                               && prompt_param.p_prompt_tuning_batch_weights == nullptr,
                           fmtstr("embeddingLookupPosEncoding still not support prompt tuning"));
    }

    // Nothing to compute. A zero-sized grid or block is an invalid launch configuration, so
    // skip the launch instead of leaving a CUDA launch error behind.
    if (local_token_num <= 0 || hidden_units <= 0) {
        return;
    }

    dim3 grid(min(local_token_num, 65536));
    dim3 block(min(hidden_units, 1024));
    if (position_encoding != nullptr) {
        embeddingLookupPosEncoding<<>>(from_tensor,
                                                                  embedding_table,
                                                                  position_encoding,
                                                                  all_ids,
                                                                  padding_count,
                                                                  input_lengths,
                                                                  local_token_num,
                                                                  hidden_units,
                                                                  step,
                                                                  max_input_length,
                                                                  token_num,
                                                                  ite,
                                                                  scale);
    }
    else {
        if (prompt_param.use_request_p_prompt_embedding) {
            EMBEDDING_LOOKUP(2);
        }
        else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
            EMBEDDING_LOOKUP(1);
        }
        else {
            EMBEDDING_LOOKUP(0);
        }
    }
}

#undef EMBEDDING_LOOKUP

// Pad-count variant: position offsets are shifted by the per-row `pad_count`. Forwards to
// invokeEmbeddingLookupPosEncoding with input_lengths = nullptr and max_input_length = 0.
template
void invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,
                                              const T*              embedding_table,
                                              const T*              position_encoding,
                                              const int*            all_ids,
                                              const int*            pad_count,
                                              pPromptTuningParam prompt_param,
                                              const int             local_token_num,
                                              const int             hidden_units,
                                              const T               scale,
                                              const int             step,
                                              const int             token_num,
                                              const int             ite,
                                              const int             seq_len,
                                              cudaStream_t          stream)
{
    invokeEmbeddingLookupPosEncoding(from_tensor,
                                        embedding_table,
                                        position_encoding,
                                        all_ids,
                                        pad_count,
                                        nullptr,
                                        prompt_param,
                                        local_token_num,
                                        hidden_units,
                                        scale,
                                        step,
                                        0,
                                        token_num,
                                        ite,
                                        seq_len,
                                        stream);
}

// Explicit instantiations of invokeEmbeddingLookupPosEncodingPadCount:
//   - float only when ENABLE_FP32 is defined;
//   - half always;
//   - __nv_bfloat16 when ENABLE_BF16 is defined.
#define INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(T)                                                                   \
    template void invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,                          \
                                                           const T*              embedding_table,                      \
                                                           const T*              position_encoding,                    \
                                                           const int*            all_ids,                              \
                                                           const int*            pad_count,                            \
                                                           pPromptTuningParam prompt_param,                         \
                                                           const int             local_token_num,                      \
                                                           const int             hidden_units,                         \
                                                           const T               scale,                                \
                                                           const int             step,                                 \
                                                           const int             token_num,                            \
                                                           const int             ite,                                  \
                                                           const int             seq_len,                              \
                                                           cudaStream_t          stream)
#ifdef ENABLE_FP32
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(float);
#endif
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(half);
#ifdef ENABLE_BF16
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);
#endif
#undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT

template
__global__ void paddingEmbedding(T*            padded_embedding_kernel,
                                 T*            padded_embedding_bias,
                                 const T*      embedding_kernel,
                                 const T*      embedding_bias,
                                 const int64_t hidden_unit,
                                 const int64_t vocab_size,
                                 const int64_t vocab_size_padded)
{
    // Copies a row-major [hidden_unit, vocab_size] matrix into a row-major
    // [hidden_unit, vocab_size_padded] matrix, writing zero to every column
    // index >= vocab_size. Each thread strides by the total launched thread count.
    const int64_t padded_elems = hidden_unit * vocab_size_padded;
    for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < padded_elems; idx += blockDim.x * gridDim.x) {
        const int row = idx / vocab_size_padded;
        const int col = idx % vocab_size_padded;
        padded_embedding_kernel[idx] = (col < vocab_size) ? embedding_kernel[row * vocab_size + col] : (T)(0.0f);
    }

    // The bias vector gets the same treatment: keep the first vocab_size
    // entries and zero-fill up to vocab_size_padded.
    for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < vocab_size_padded; idx += blockDim.x * gridDim.x) {
        padded_embedding_bias[idx] = (idx < vocab_size) ? embedding_bias[idx] : (T)(0.0f);
    }
}

template
void invokePaddingEmbedding(T*           padded_embedding_kernel,
                            T*           padded_embedding_bias,
                            const T*     embedding_kernel,
                            const T*     embedding_bias,
                            const int    hidden_unit,
                            const int    vocab_size,
                            const int    vocab_size_padded,
                            cudaStream_t stream)
{
    // Launches `paddingEmbedding` on `stream`. That kernel:
    //   - copies the [hidden_unit, vocab_size] kernel into a [hidden_unit, vocab_size_padded]
    //     buffer, zero-filling the extra columns;
    //   - copies the bias into vocab_size_padded entries, zero-filling the tail.
    //
    // The element count is formed in 64-bit. `hidden_unit * vocab_size_padded` in `int`
    // overflows (UB) for large hidden sizes / vocabularies and produced a bogus grid size.
    const int64_t num_kernel_elems = static_cast<int64_t>(hidden_unit) * vocab_size_padded;
    // The bias loop spans vocab_size_padded entries. Size the work by the larger range so
    // the bias is still padded when the kernel matrix is empty.
    const int64_t num_work = num_kernel_elems > vocab_size_padded ? num_kernel_elems : vocab_size_padded;
    if (num_work <= 0) {
        // Nothing to pad, and a zero-sized grid would be an invalid launch configuration.
        return;
    }
    dim3 block(512);
    // Both loops in the kernel advance by blockDim.x * gridDim.x, so capping the grid at the
    // x-dimension limit still covers every element.
    const int64_t kMaxGridX  = 2147483647;
    const int64_t num_blocks = (num_work + block.x - 1) / block.x;
    dim3 grid(static_cast<unsigned int>(num_blocks < kMaxGridX ? num_blocks : kMaxGridX));
    paddingEmbedding<<>>(padded_embedding_kernel,
                                                 padded_embedding_bias,
                                                 embedding_kernel,
                                                 embedding_bias,
                                                 hidden_unit,
                                                 vocab_size,
                                                 vocab_size_padded);
}

// template void invokePaddingEmbedding(float*       padded_embedding_kernel,
//                                      float*       padded_embedding_bias,
//                                      const float* embedding_kernel,
//                                      const float* embedding_bias,
//                                      const int    hidden_unit,
//                                      const int    vocab_size,
//                                      const int    vocab_size_padded,
//                                      cudaStream_t stream);

// template void invokePaddingEmbedding(half*        padded_embedding_kernel,
//                                      half*        padded_embedding_bias,
//                                      const half*  embedding_kernel,
//                                      const half*  embedding_bias,
//                                      const int    hidden_unit,
//                                      const int    vocab_size,
//                                      const int    vocab_size_padded,
//                                      cudaStream_t stream);
// #ifdef ENABLE_BF16
// template void invokePaddingEmbedding(__nv_bfloat16*       padded_embedding_kernel,
//                                      __nv_bfloat16*       padded_embedding_bias,
//                                      const __nv_bfloat16* embedding_kernel,
//                                      const __nv_bfloat16* embedding_bias,
//                                      const int            hidden_unit,
//                                      const int            vocab_size,
//                                      const int            vocab_size_padded,
//                                      cudaStream_t         stream);
// #endif

template
__global__ void paddingEmbeddingKernel(T*        padded_embedding_kernel,
                                       const T*  embedding_kernel,
                                       const int hidden_unit,
                                       const int vocab_size,
                                       const int vocab_size_padded)
{
    // Copies a row-major [vocab_size, hidden_unit] matrix into a row-major
    // [vocab_size_padded, hidden_unit] buffer, zero-filling rows >= vocab_size.
    //
    // All indexing is 64-bit. The flattened size hidden_unit * vocab_size_padded can exceed
    // INT_MAX for large vocabularies/hidden sizes, and `int` arithmetic would then overflow.
    const int64_t total  = static_cast<int64_t>(hidden_unit) * vocab_size_padded;
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t id = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x; id < total; id += stride) {
        const int64_t row_id = id / hidden_unit;
        const int64_t col_id = id % hidden_unit;
        if (row_id < vocab_size) {
            padded_embedding_kernel[id] = embedding_kernel[row_id * hidden_unit + col_id];
        }
        else {
            padded_embedding_kernel[id] = (T)(0.0f);
        }
    }
}

template
void invokePaddingEmbeddingKernel(T*           padded_embedding_kernel,
                                  const T*     embedding_kernel,
                                  const int    hidden_unit,
                                  const int    vocab_size,
                                  const int    vocab_size_padded,
                                  cudaStream_t stream)
{
    // Launches `paddingEmbeddingKernel` on `stream` with 512-thread blocks. It pads a
    // [vocab_size, hidden_unit] matrix to vocab_size_padded rows, with zero-filled tail rows.
    //
    // The element count is computed in 64-bit, because `hidden_unit * vocab_size_padded`
    // in `int` overflows for large embeddings.
    const int64_t num_elems = static_cast<int64_t>(hidden_unit) * vocab_size_padded;
    if (num_elems <= 0) {
        // Nothing to pad, and a zero-sized grid would be an invalid launch configuration.
        return;
    }
    dim3 block(512);
    // The kernel loop advances by blockDim.x * gridDim.x, so capping the grid at the
    // x-dimension limit still covers every element.
    const int64_t kMaxGridX  = 2147483647;
    const int64_t num_blocks = (num_elems + block.x - 1) / block.x;
    dim3 grid(static_cast<unsigned int>(num_blocks < kMaxGridX ? num_blocks : kMaxGridX));
    paddingEmbeddingKernel<<>>(
        padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded);
}

// template void invokePaddingEmbeddingKernel(float*       padded_embedding_kernel,
//                                            const float* embedding_kernel,
//                                            const int    hidden_unit,
//                                            const int    vocab_size,
//                                            const int    vocab_size_padded,
//                                            cudaStream_t stream);

// template void invokePaddingEmbeddingKernel(half*        padded_embedding_kernel,
//                                            const half*  embedding_kernel,
//                                            const int    hidden_unit,
//                                            const int    vocab_size,
//                                            const int    vocab_size_padded,
//                                            cudaStream_t stream);

// #ifdef ENABLE_BF16
// template void invokePaddingEmbeddingKernel(__nv_bfloat16*       padded_embedding_kernel,
//                                            const __nv_bfloat16* embedding_kernel,
//                                            const int            hidden_unit,
//                                            const int            vocab_size,
//                                            const int            vocab_size_padded,
//                                            cudaStream_t         stream);
// #endif

template
__global__ void plusScalar(T* buf, const T val, const int size)
{
    // Adds `val` in place to buf[0, size). Each thread starts at its global
    // thread index and advances by the total number of launched threads.
    const unsigned int step = blockDim.x * gridDim.x;
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    while (idx < size) {
        buf[idx] += val;
        idx += step;
    }
}

template
void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream)
{
    // Adds `val` to each of the first `size` elements of `buf` on `stream`.
    //
    // An empty (or negative) range is a no-op. Otherwise `min(256, size)` would produce a
    // zero-thread (or wrapped-around) block size, which is an invalid launch configuration.
    if (size <= 0) {
        return;
    }
    dim3 block(size < 256 ? size : 256);
    // ceil(size / 256) in integer arithmetic; written this way so it cannot overflow near INT_MAX.
    dim3 grid(size / 256 + (size % 256 != 0 ? 1 : 0));
    plusScalar<<>>(buf, val, size);
}

template void invokePlusScalar(int* buf, const int val, const int size, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/decoding_kernels.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "gpt_kernels.h"
#include 
#include 

namespace turbomind {

// Reads the token for `step` from `all_ids`, then looks that token up in the
// embedding table and writes the result to `from_tensor`.
// NOTE(review): the roles of position_encoding, padding_count, scale, ite and
// seq_len are defined by the kernel in the implementation file (not visible
// here). Confirm them there before relying on them.
// Explicit instantiations in the implementation: float (ENABLE_FP32 only),
// half, and __nv_bfloat16 (ENABLE_BF16 only).
template
void invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,
                                              const T*              embedding_table,
                                              const T*              position_encoding,
                                              const int*            all_ids,
                                              const int*            padding_count,
                                              pPromptTuningParam prompt_param,
                                              const int             local_token_num,
                                              const int             hidden_units,
                                              const T               scale,
                                              const int             step,
                                              const int             token_num,
                                              const int             ite,
                                              const int             seq_len,
                                              cudaStream_t          stream);

// Convenience overload without prompt tuning. It forwards to the full overload with:
//   - an empty prompt-tuning parameter {nullptr, 0, 0, false, nullptr};
//   - seq_len = 0.
// All other arguments are passed through unchanged.
template
void invokeEmbeddingLookupPosEncodingPadCount(T*           from_tensor,
                                              const T*     embedding_table,
                                              const T*     position_encoding,
                                              const int*   all_ids,
                                              const int*   padding_count,
                                              const int    local_token_num,
                                              const int    hidden_units,
                                              const T      scale,
                                              const int    step,
                                              const int    token_num,
                                              const int    ite,
                                              cudaStream_t stream)
{
    invokeEmbeddingLookupPosEncodingPadCount(from_tensor,
                                             embedding_table,
                                             position_encoding,
                                             all_ids,
                                             padding_count,
                                             {(const T**)nullptr, 0, 0, false, nullptr},  // no prompt tuning
                                             local_token_num,
                                             hidden_units,
                                             scale,
                                             step,
                                             token_num,
                                             ite,
                                             0,  // seq_len
                                             stream);
}

// Pads a row-major [hidden_unit, vocab_size] kernel to [hidden_unit, vocab_size_padded]
// and the bias to vocab_size_padded entries. The extra columns/entries are zero-filled.
// The work is enqueued on `stream`.
// NOTE(review): the explicit instantiations in the implementation file are
// commented out. Confirm every caller's element type is actually instantiated somewhere.
template
void invokePaddingEmbedding(T*           padded_embedding_kernel,
                            T*           padded_embedding_bias,
                            const T*     embedding_kernel,
                            const T*     embedding_bias,
                            const int    hidden_unit,
                            const int    vocab_size,
                            const int    vocab_size_padded,
                            cudaStream_t stream);

// Pads a row-major [vocab_size, hidden_unit] matrix to vocab_size_padded rows.
// Rows >= vocab_size are zero-filled. The work is enqueued on `stream`.
// NOTE(review): the explicit instantiations in the implementation file are
// commented out. Confirm every caller's element type is actually instantiated somewhere.
template
void invokePaddingEmbeddingKernel(T*           padded_embedding_kernel,
                                  const T*     embedding_kernel,
                                  const int    hidden_unit,
                                  const int    vocab_size,
                                  const int    vocab_size_padded,
                                  cudaStream_t stream);

// Adds `val` in place to each of the first `size` elements of `buf`, on `stream`.
// The implementation file only instantiates this for T = int.
template
void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

# `gemm2` library, built from:
#   - the top-level GEMM sources;
#   - the sources under tuner/;
#   - the per-architecture kernel translation units under kernel/ (sm70 884, sm75/sm80/sm90 16816, sm90 64n32);
#   - cuBLAS and MoE utilities.
# test/test_utils.cu is compiled into the library as well.
add_library(gemm2
        gemm.cu
        kernel.cu
        registry.cu
        dispatch_cache.cu
        gpu_metric.cu
        convert_v3.cu
        cast.cu
        unpack.cu
        context.cu
        tma.cu
        tuner/cache_utils.cu
        tuner/measurer.cu
        tuner/sampler.cu
        tuner/stopping_criterion.cc
        tuner/params.cc
        kernel/sm90_16816_4.cu
        kernel/sm90_16816_8.cu
        kernel/sm90_16816_16.cu
        kernel/sm80_16816_4.cu
        kernel/sm80_16816_8.cu
        kernel/sm80_16816_16.cu
        kernel/sm75_16816_4.cu
        kernel/sm75_16816_8.cu
        kernel/sm75_16816_16.cu
        kernel/sm70_884_4.cu
        kernel/sm70_884_8.cu
        kernel/sm70_884_16.cu
        kernel/sm90_64n32_8.cu
        cublas.cu
        moe_utils_v2.cu
        test/test_utils.cu
)

# Private dependencies: the project's `parser` target, CUTLASS, and the CUDA driver API.
target_link_libraries(gemm2 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)

# target_compile_definitions expects bare macro names; CMake would strip a leading -D anyway.
target_compile_definitions(gemm2 PRIVATE CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

# nvcc options for gemm2. The generator expression presumably restricts them to
# CUDA sources; confirm. The options are:
#   - pass -v to ptxas (verbose resource-usage output);
#   - emit line info;
#   - compile with up to 16 nvcc threads.
target_compile_options(gemm2 PRIVATE
        $<$:
                -Xptxas=-v
                --generate-line-info
                --threads 16>
)
# Position-independent code, so the library can be linked into shared objects.
set_property(TARGET gemm2 PROPERTY POSITION_INDEPENDENT_CODE ON)
# Perform CUDA device-symbol resolution (device link) for this library target.
set_property(TARGET gemm2 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)



# GEMM test executables; only configured when BUILD_TEST is enabled.
if (BUILD_TEST)
        # End-to-end GEMM test, compiled together with the llama linear-layer sources.
        add_executable(test_gemm_v2
                test/test_gemm_v2.cc
                ../../models/llama/LlamaLinear.cu
                ../../models/llama/LlamaDenseWeight.cc
                test/reference.cu)
        target_link_libraries(test_gemm_v2 PRIVATE gemm2 core cublas quantization_kernels gpt_kernels)

        # NOTE(review): test/test_utils.cu is already a source of gemm2. Compiling it
        # again here may define its symbols twice; confirm this is intended.
        add_executable(test_moe_utils test/test_moe_utils.cu test/test_utils.cu)
        target_link_libraries(test_moe_utils PRIVATE gemm2 core cublas)

        # Disabled: nvbench-based GEMM benchmark (fetched via FetchContent, non-MSVC only).
        # if (NOT MSVC)
        #         FetchContent_Declare(
        #         repo-nvbench
        #         GIT_REPOSITORY https://github.com/NVIDIA/nvbench.git
        #         GIT_TAG        d8dced8a64d9ce305add92fa6d274fd49b569b7e
        #         )

        #         set(NVBench_ENABLE_EXAMPLES OFF)
        #         set(NVBench_ENABLE_TESTING OFF)
        #         set(BUILD_SHARED_LIBS OFF)

        #         FetchContent_MakeAvailable(repo-nvbench)

        #         add_executable(gemm_bench
        #                 test/gemm_bench.cu
        #                 # test/test_utils.cu
        #                 test/quantization.cu
        #                 test/reference.cu)
        #         target_link_libraries(gemm_bench PRIVATE gemm2 core nvbench::nvbench cublas)
        # endif ()
endif ()


================================================
FILE: src/turbomind/kernels/gemm/arch/config_simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/mma_simt.h"
#include "src/turbomind/kernels/gemm/arch/operand_simt.h"
#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/gemm_universal.h"
#include "src/turbomind/kernels/gemm/iterator_sm70.h"
#include "src/turbomind/kernels/gemm/mainloop_sm70.h"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/tiled_mma.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

namespace simt {

// GEMM configuration built on the SIMT MMA atom (`MMA_SIMT`), the SM70 iterators, and
// the SM70 mainloop. The nested `Type` template assembles the concrete kernel type.
template
struct Sm75_Simt {

    // A and B shared-memory copy atoms must agree on their K extent.
    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);

    // Shared-memory tile extents per copy-atom fragment.
    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_K = A::SmemCopyAtom::K;

    template
    struct Type {

        // (TM, TN, TK) = R(MMA_Atom, SmemCopy_Atom)
        using MMA_Atom = MMA_SIMT;

        static constexpr int TM = MMA_Atom::M;
        static constexpr int TN = MMA_Atom::N;
        static constexpr int TK = MMA_Atom::K;

        using Partition = Blocked;

        using MMA_Map = MMA_Map;
        using MMA     = Tiled_MMA_v2;

        // Alternative thread-group mapping, currently unused:
        // using MMA_Map = RakedThreadGroupMap;

        // The trailing `true` is presumably FusePrefetch_, as in the SM70 s884 config.
        using Mainloop = MainloopSm70,
                                      TransformA,
                                      U,
                                      GroupSizeU,
                                      B,
                                      IteratorSm70,
                                      TransformB,
                                      V,
                                      GroupSizeV,
                                      Stages,
                                      true>;

        // A value of -1 means "use the CTA tile extent" for the epilogue's C tile.
        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;
        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? CTA_N : TILE_C_N_;

        // NOTE(review): this config passes `mode_C` (from the template parameters), while
        // the tensor-core configs derive MODE_C from group_axis. Confirm this is intended.
        using Epilogue = gemm::Epilogue_,
                                         Operand_C,
                                         mode_C,
                                         SplitK>;

        using Kernel = GemmUniversal;
    };
};

}  // namespace simt

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/config_sm70_s884.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/mma_sm70.h"
#include "src/turbomind/kernels/gemm/arch/operand_sm70_s884.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/gemm_universal.h"
#include "src/turbomind/kernels/gemm/iterator_sm70.h"
#include "src/turbomind/kernels/gemm/mainloop_sm70.h"
#include "src/turbomind/kernels/gemm/scheduler_sm70.cuh"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/tiled_mma.h"
#include "src/turbomind/kernels/gemm/transform.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm::sm70_s884 {

// GEMM configuration built on the `SM70_MMA_884` MMA atom, the SM70 iterators and
// mainloop, and the SM70 scheduler. The nested `Type` template assembles the concrete kernel.
template
struct Sm70_s884 {

    // A and B shared-memory copy atoms must agree on their K extent.
    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);

    // Shared-memory tile extents per copy-atom fragment.
    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_K = A::SmemCopyAtom::K;

    // Striding modes derived from group_axis:
    //   group_axis <  0 -> every operand uses kFlat;
    //   group_axis >= 0 -> kBlocked, except the operand on the grouped axis
    //                      (0 -> A, 1 -> B), which uses kIndexed.
    // NOTE(review): presumably used for grouped (e.g. MoE) GEMMs; confirm.
    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;

    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_B = group_axis == 1 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_C = MODE_;

    template
    struct Type {

        // (TM, TN, TK) = R(MMA_Atom, SmemCopy_Atom)
        using MMA_Atom = SM70_MMA_884;

        using Partition = Blocked;
        using MMA_Map   = MMA_Map;

        using MMA = Tiled_MMA_v2;

        using Mainloop = MainloopSm70,
                                      TransformA,
                                      U,
                                      GroupSizeU,
                                      B,
                                      IteratorSm70,
                                      TransformB,
                                      V,
                                      GroupSizeV,
                                      Stages,
                                      true>;  // FusePrefetch_

        // K chunk: least common multiple of the U/V group sizes and CTA_K.
        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);

        using Scheduler = SchedulerSm70;

        // A value of -1 means "use the CTA tile extent" for the epilogue's C tile.
        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;
        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? CTA_N : TILE_C_N_;

        // C uses the striding mode derived from group_axis above.
        using Epilogue = gemm::Epilogue_,
                                         Operand_C,
                                         MODE_C,
                                         SplitK>;

        using Kernel = GemmUniversal;
    };
};

// SM70 config with a packed B operand and HMMA-SIMT transform for B (presumably
// 4-bit weights, per the name). group_axis = -1 selects flat striding for all operands.
template
using Config_U4_d = Sm70_s884::Operand,  // A
                              Transform_Default,                   // transform A
                              VoidOperand,                         // U
                              typename GetOperand::Operand,        // B
                              Transform_HMMA_SIMT_B,               // transform B
                              typename GetOperand::Operand,        // V
                              kRowMajor,                           // order_C
                              half,                                // Tc
                              raster_order,                        // raster order
                              -1>;                                 // group axis

// Same operand layout as Config_U4_d, but with group_axis = 0. A is kIndexed and the
// other operands use kBlocked striding.
template
using Config_U4_g = Sm70_s884,           // A
                              Transform_Default,         // transform A
                              VoidOperand,               // U
                              Operand_B_Pack,   // B
                              Transform_HMMA_SIMT_B,     // transform B
                              Operand_V_Pack,  // V
                              kRowMajor,                 // order_C
                              half,                      // Tc
                              raster_order,              // raster order
                              0>;                        // group axis

// SM70 config for packed B (presumably MX FP4 weights, per the name) with the
// HMMA-SIMT B transform. group_axis is forwarded from the alias parameters.
template
using Config_MXF4 = Sm70_s884,             // A
                              Transform_Default,           // transform A
                              VoidOperand,                 // U
                              Operand_B_Pack,  // B
                              Transform_HMMA_SIMT_B,       // transform B
                              Operand_V_Pack,     // V
                              kRowMajor,                   // order_C
                              half,                        // Tc
                              raster_order,                // raster order
                              group_axis>;                 // group axis

// SM70 config for packed B (presumably FP8 E4M3 weights, per the name) with the
// HMMA-SIMT B transform. group_axis is forwarded from the alias parameters.
template
using Config_E4M3 = Sm70_s884,             // A
                              Transform_Default,           // transform A
                              VoidOperand,                 // U
                              Operand_B_Pack,  // B
                              Transform_HMMA_SIMT_B,       // transform B
                              Operand_V_Pack,    // V
                              kRowMajor,                   // order_C
                              half,                        // Tc
                              raster_order,                // raster order
                              group_axis>;                 // group axis

// SM70 config with no per-group scale operands (U and V are VoidOperand) and default
// transforms on both A and B.
template
using Config_F16 = Sm70_s884,       // A
                             Transform_Default,     // transform A
                             VoidOperand,           // U
                             Operand_B_Pack,  // B
                             Transform_Default,     // transform B
                             VoidOperand,           // V
                             kRowMajor,             // order_C
                             half,                  // Tc
                             raster_order,          // raster order
                             group_axis>;           // group axis

}  // namespace turbomind::gemm::sm70_s884


================================================
FILE: src/turbomind/kernels/gemm/arch/config_sm75_s16816.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/mma_sm80.h"
#include "src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/gemm_universal.h"
#include "src/turbomind/kernels/gemm/iterator_sm70.h"
#include "src/turbomind/kernels/gemm/mainloop_sm70.h"
#include "src/turbomind/kernels/gemm/scheduler_sm70.cuh"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/tiled_mma.h"
#include "src/turbomind/kernels/gemm/transform.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

namespace sm75_s16816 {

using namespace sm80_s16816;

// SM75 GEMM configuration. It reuses the SM80 s16816 operands/MMA (via the
// `using namespace sm80_s16816` above), paired with the SM70 iterators, mainloop and
// scheduler. The nested `Type` template assembles the concrete kernel type.
template
struct Sm75_s16816 {

    // A and B shared-memory copy atoms must agree on their K extent.
    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);

    // Shared-memory tile extents per copy-atom fragment.
    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_K = A::SmemCopyAtom::K;

    // Striding modes derived from group_axis:
    //   group_axis <  0 -> every operand uses kFlat;
    //   group_axis >= 0 -> kBlocked, except the operand on the grouped axis
    //                      (0 -> A, 1 -> B), which uses kIndexed.
    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;

    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_B = group_axis == 1 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_C = MODE_;

    template
    struct Type {
        // The raked partition doesn't support `Pack_M > 1`, so a blocked partition is used.
        using Partition = Blocked;
        using MMA_Map   = MMA_Map;
        using MMA       = Tiled_MMA_v2, MMA_Map, mma_iter_order>;

        using Mainloop = MainloopSm70,
                                      TransformA,
                                      U,
                                      GroupSizeU,
                                      B,
                                      IteratorSm70,
                                      TransformB,
                                      V,
                                      GroupSizeV,
                                      Stages,
                                      true>;  // FusePrefetch_

        // K chunk: least common multiple of the U/V group sizes and CTA_K.
        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);

        using Scheduler = SchedulerSm70;

        // A value of -1 means "use the CTA tile extent" for the epilogue's C tile.
        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;
        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? CTA_N : TILE_C_N_;

        using Epilogue = gemm::Epilogue_,
                                         Operand_C,
                                         MODE_C,
                                         SplitK>;

        using Kernel = GemmUniversal;
    };
};

// mma_iter_order has no effect yet

// SM75 config with packed B and packed V and the HMMA_16816<1, 0> B transform
// (presumably 4-bit weights with scales, per the name). group_axis = -1 selects flat striding.
template  // kColMajor
using Config_U4_d = Sm75_s16816,  // A
                                Transform_Default,           // transform A
                                VoidOperand,                 // U
                                Operand_B_Pack,              // B
                                Transform_HMMA_16816<1, 0>,  // transform B
                                Operand_UV_Pack,             // V
                                kRowMajor,                   // order_C
                                half,                        // Tc
                                raster_order,                // raster order
                                -1>;                         // group axis

// Same operand layout as Config_U4_d, but with group_axis = 0. A is kIndexed and the
// other operands use kBlocked striding.
template  // kColMajor
using Config_U4_g = Sm75_s16816,             // A
                                Transform_Default,                      // transform A
                                VoidOperand,                            // U
                                Operand_B_Pack,  // B
                                Transform_HMMA_16816<1, 0>,             // transform B
                                Operand_UV_Pack,        // V
                                kRowMajor,                              // order_C
                                half,                                   // Tc
                                raster_order,                           // raster order
                                0>;                                     // group axis

// SM75 config (presumably MX FP4, per the name). Unlike the U4 configs, the
// transformed/packed operand is on the A side (scales as U) and C is column-major.
// NOTE(review): presumably this computes the transposed product; confirm.
template
using Config_MXF4 = Sm75_s16816,  // A
                                Transform_HMMA_16816<0, 1>,                // transform A
                                Operand_UV_Pack,           // U
                                Operand_B,              // B
                                Transform_Default,                         // transform B
                                VoidOperand,                               // V
                                kColMajor,                                 // order_C
                                half_t,                                    // Tc
                                raster_order,                              // raster order
                                group_axis>;                               // group axis

// SM75 config (presumably FP8 E4M3, per the name). Like Config_MXF4, the transformed
// operand is on the A side (scales as U) and C is column-major.
template
using Config_E4M3 = Sm75_s16816,  // A
                                Transform_HMMA_16816<0, 1>,                // transform A
                                Operand_UV_Pack,          // U
                                Operand_B,              // B
                                Transform_Default,                         // transform B
                                VoidOperand,                               // V
                                kColMajor,                                 // order_C
                                half_t,                                    // Tc
                                raster_order,                              // raster order
                                group_axis>;                               // group axis

// SM75 config with no per-group scale operands (U and V are VoidOperand) and default
// transforms on both A and B.
template
using Config_F16 = Sm75_s16816,          // A
                               Transform_Default,                   // transform A
                               VoidOperand,                         // U
                               Operand_B_Pack,  // B
                               Transform_Default,                   // transform B
                               VoidOperand,                         // V
                               kRowMajor,                           // order_C
                               half,                                // Tc
                               raster_order,                        // raster order
                               group_axis>;                         // group axis

}  // namespace sm75_s16816

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/config_sm80_s16816.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/mma_sm80.h"
#include "src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/gemm_universal.h"
#include "src/turbomind/kernels/gemm/iterator_sm80.h"
#include "src/turbomind/kernels/gemm/mainloop_sm80_v2.h"
#include "src/turbomind/kernels/gemm/scheduler_sm70.cuh"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/tiled_mma.h"
#include "src/turbomind/kernels/gemm/transform.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm::sm80_s16816 {

// SM80 GEMM configuration built on the SM80 s16816 MMA, the SM80 iterators, and
// MainloopSm80_v2 (with a configurable prefetch-fusion flag), plus the SM70 scheduler.
// The nested `Type` template assembles the concrete kernel type.
template
struct Sm80_s16816 {

    // A and B shared-memory copy atoms must agree on their K extent.
    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);

    // Shared-memory tile extents per copy-atom fragment.
    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;
    static constexpr int SMEM_K = A::SmemCopyAtom::K;

    // Striding modes derived from group_axis:
    //   group_axis <  0 -> every operand uses kFlat;
    //   group_axis >= 0 -> kBlocked, except the operand on the grouped axis
    //                      (0 -> A, 1 -> B), which uses kIndexed.
    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;

    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_B = group_axis == 1 ? Striding::kIndexed : MODE_;
    static constexpr auto MODE_C = MODE_;

    // NOTE(review): `FusePrefecth` (sic) is the name of a template parameter of `Type`.
    // Renaming it requires updating its declaration too.
    template

    struct Type {

        // The raked partition doesn't support `Pack_M > 1`, so a blocked partition is used.
        using Partition = Blocked;
        using MMA_Map   = MMA_Map;
        using MMA       = Tiled_MMA_v2, MMA_Map, mma_iter_order>;

        using Mainloop = MainloopSm80_v2,
                                         TransformA,
                                         U,
                                         GroupSizeU,
                                         B,
                                         IteratorSm80,
                                         TransformB,
                                         V,
                                         GroupSizeV,
                                         Stages,
                                         FusePrefecth>;

        // K chunk: least common multiple of the U/V group sizes and CTA_K.
        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);

        using Scheduler = SchedulerSm70;

        // A value of -1 means "use the CTA tile extent" for the epilogue's C tile.
        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;
        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? CTA_N : TILE_C_N_;

        using Epilogue = gemm::Epilogue_,
                                         Operand_C,
                                         MODE_C,
                                         SplitK>;

        using Kernel = GemmUniversal;
    };
};

// SM80 config with packed B and packed V and the HMMA_16816<1, 0> B transform
// (presumably 4-bit weights with scales, per the name). group_axis = -1 selects flat striding.
template  // kColMajor
using Config_U4_d = Sm80_s16816,             // A
                                Transform_Default,                      // transform A
                                VoidOperand,                            // U
                                Operand_B_Pack,  // B
                                Transform_HMMA_16816<1, 0>,             // transform B
                                Operand_UV_Pack,        // V
                                kRowMajor,                              // order_C
                                half,                                   // Tc
                                raster_order,                           // raster order
                                -1>;                                    // group axis

// Same operand setup as Config_U4_d, except:
// - the output type is the template parameter `T`;
// - the group axis is 0, so A uses indexed striding.
template  // kColMajor
using Config_U4_g = Sm80_s16816,                // A
                                Transform_Default,                      // transform A
                                VoidOperand,                            // U
                                Operand_B_Pack,  // B
                                Transform_HMMA_16816<1, 0>,             // transform B
                                Operand_UV_Pack,        // V
                                kRowMajor,                              // order_C
                                T,                                      // Tc
                                raster_order,                           // raster order
                                0>;                                     // group axis

// Sm80_s16816 configuration:
// - A: transformed by Transform_HMMA_16816<0, 1>, with a packed U operand.
// - B: plain, default transform, no V operand.
// - C: column-major, output type `T`.
// - Group axis: taken from the template parameter.
template
using Config_MXF4 = Sm80_s16816,  // A
                                Transform_HMMA_16816<0, 1>,                // transform A
                                Operand_UV_Pack,           // U
                                Operand_B,                // B
                                Transform_Default,                         // transform B
                                VoidOperand,                               // V
                                kColMajor,                                 // order_C
                                T,                                         // Tc
                                raster_order,                              // raster order
                                group_axis>;                               // group axis

// Sm80_s16816 configuration with the same shape as Config_MXF4:
// - A: transformed by Transform_HMMA_16816<0, 1>, with a packed U operand.
// - B: plain, no V operand.
// - C: column-major, output type `T`.
// - Group axis: taken from the template parameter.
template
using Config_E4M3 = Sm80_s16816,  // A
                                Transform_HMMA_16816<0, 1>,                // transform A
                                Operand_UV_Pack,          // U
                                Operand_B,                // B
                                Transform_Default,                         // transform B
                                VoidOperand,                               // V
                                kColMajor,                                 // order_C
                                T,                                         // Tc
                                raster_order,                              // raster order
                                group_axis>;                               // group axis

// Sm80_s16816 configuration:
// - A and B: default transforms, B packed, no U/V operands.
// - C: row-major, output type `T`.
// - Group axis: 0, so A uses indexed striding.
template
using Config_F16_g = Sm80_s16816,          // A
                                 Transform_Default,                // transform A
                                 VoidOperand,                      // U
                                 Operand_B_Pack,  // B
                                 Transform_Default,                // transform B
                                 VoidOperand,                      // V
                                 kRowMajor,                        // order_C
                                 T,                                // Tc
                                 raster_order,                     // raster order
                                 0>;                               // group axis

}  // namespace turbomind::gemm::sm80_s16816


================================================
FILE: src/turbomind/kernels/gemm/arch/mma_simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/simt.h"

namespace turbomind::gemm {

// SIMT (CUDA-core) stand-in for a tensor-core MMA atom. Each lane produces one output
// element as a K-long dot product of its A and B fragments.
template
struct MMA_SIMT {
    static constexpr int M = simt::OP_M;
    static constexpr int N = simt::OP_N;
    static constexpr int K = simt::OP_K;

    static constexpr int kThreadCount = 32;

    static constexpr auto kOpClass = OpClass::kSIMT;

    using FragA = Array;
    using FragB = Array;
    using FragC = Array;

    using OffsetC = Array;
    using FragC_  = FragC[1];

    // Computes d[0] = c[0] + sum_k float(a[k]) * float(b[k]).
    //
    // The sum is built in a local accumulator of C's element type and written to `d`
    // once. The previous loop recomputed `d[0] = c[0] + a[k] * b[k]` on every iteration.
    // That only produced the full sum when `d` and `c` were the same fragment; for
    // distinct fragments every term except the last was dropped. When `d` aliases `c`,
    // this performs exactly the same sequence of additions (and roundings) as before.
    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)
    {
        auto acc = c[0];
        PRAGMA_UNROLL
        for (int k = 0; k < K; ++k) {
            acc = acc + float(a[k]) * float(b[k]);
        }
        d[0] = acc;
    }

    // No per-fragment offsets: returns a value-initialized OffsetC.
    __device__ static constexpr OffsetC static_offset_C()
    {
        return {};
    }

    // Lane l writes output element (m, n) = (l / N, l % N) of the warp's C tile.
    __device__ static int2 thread_offset_C()  // -> (m,n)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;
        return {lane_id / N, lane_id % N};
    }

    // The C fragment is already in its final shape; wrap it in a one-element array.
    __device__ static void ReshapeC(const FragC& c, FragC_& c_)
    {
        c_[0] = c;
    }

    // One MMA group per warp.
    __device__ static int get_group_id(int thread_idx)
    {
        return thread_idx / WARP_SIZE;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/mma_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/gemm/desc.h"

namespace turbomind::gemm {

// Volta (SM70) MMA atom. It issues `mma_m8n8k4_row_col` twice to cover an 8x32x8
// (M x N x K) warp tile.
struct SM70_MMA_884 {
    // Alternative 16x16 tile shape, left commented out:
    // static constexpr int M = 16;
    // static constexpr int N = 16;
    static constexpr int M = 8;
    static constexpr int N = 32;
    static constexpr int K = 8;

    static constexpr int kThreadCount = 32;

    static constexpr auto kOpClass = OpClass::kMMA_s884;

    using FragA = Array;
    using FragB = Array;
    using FragC = Array;

    using OffsetC = Array;
    using FragC_  = Array[4];

    // d = c + A * B over K = 8, done as two k=4 steps:
    // - the first step reads fragment elements starting at index 0 and accumulates onto `c`;
    // - the second step reads from index 4 and accumulates onto the first step's result in `d`.
    // K is fixed at 8 here, so the `if constexpr` guard always takes the second step.
    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)
    {
        mma_m8n8k4_row_col(d, (const Array&)a[0], (const Array&)b[0], (FragC&)c);
        if constexpr (K == 8) {
            mma_m8n8k4_row_col(d, (const Array&)a[4], (const Array&)b[4], (FragC&)d);
        }
    }

    // Offsets of the four reshaped C pieces: entry n * 2 + m is (m * 2, n * 4),
    // for m, n in {0, 1}.
    __device__ static constexpr OffsetC static_offset_C()
    {
        OffsetC r{};
        PRAGMA_UNROLL
        for (int n = 0; n < 2; ++n) {
            PRAGMA_UNROLL
            for (int m = 0; m < 2; ++m) {
                r[n * 2 + m] = int2{m * 2, n * 4};
            }
        }
        return r;
    }

    // Per-lane base (m, n) of the C fragment:
    //   m: lane bit 0 -> +1, lane bit 4 -> +4
    //   n: lane bit 1 -> +2, lane bits 2-3 -> +8 / +16
    __device__ static int2 thread_offset_C()  // -> (m,n)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;
        // return {
        //     (lane_id & 8) * 1 + (lane_id & 1) + lane_id / 16 * 4,
        //     (lane_id & 4) * 2 + (lane_id & 2),
        // };
        return {(lane_id & 1) + (lane_id / 16) * 4,  //
                (lane_id & 2) + (lane_id & 12) * 2};
    }

    // Split the accumulator into four consecutive pieces, starting at c[0], c[2], c[4], c[6].
    __device__ static void ReshapeC(const FragC& c, FragC_& c_)
    {
        PRAGMA_UNROLL
        for (int m = 0; m < 4; ++m) {
            c_[m] = (Array&)c[m * 2];
        }
    }

    // One MMA group per warp.
    __device__ static int get_group_id(int thread_idx)
    {
        return thread_idx / WARP_SIZE;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/mma_sm80.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/gemm/desc.h"

namespace turbomind::gemm {

// Ampere (SM80) tensor-core MMA atom wrapping `mma_m16n8k16_row_col`: a 16x8x16
// (M x N x K) warp-level multiply-accumulate.
template
struct SM80_MMA_16x8x16_F32_F16_F16_F32_TN {
    static constexpr int M = 16;
    static constexpr int N = 8;
    static constexpr int K = 16;

    static constexpr int kThreadCount = 32;

    static constexpr auto kOpClass = OpClass::kMMA_s16816;

    using FragA = Array;
    using FragB = Array;
    using FragC = Array;

    using OffsetC = Array;  // (m, n)
    using FragC_  = Array[2];

    // d = a * b + c. The cast only drops const so `c` matches the helper's signature.
    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)
    {
        mma_m16n8k16_row_col(d, a, b, (FragC&)c);
    }

    // The second reshaped C piece sits 8 rows below the first.
    __device__ static constexpr OffsetC static_offset_C()
    {
        return {int2{0, 0}, int2{8, 0}};
    }

    // Lane l starts at row l / 4 and column (l % 4) * 2.
    __device__ static int2 thread_offset_C()  // -> (m,n)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;
        return {lane_id / 4, lane_id % 4 * 2};
    }

    // Split the accumulator into two pieces, starting at c[0] and c[2].
    __device__ static void ReshapeC(const FragC& c, FragC_& c_)
    {
        PRAGMA_UNROLL
        for (int m = 0; m < 2; ++m) {
            c_[m] = (Array&)c[m * 2];
        }
    }

    // One MMA group per warp.
    __device__ static int get_group_id(int thread_idx)
    {
        return thread_idx / WARP_SIZE;
    }
};

// This is not used yet.
// Turing (SM75) 16x8x8 variant of the SM80 atom. It reuses the base class's C-fragment
// offsets, reshaping and group id, and overrides the shape, the fragment types and `fma`,
// which calls `mma_m16n8k8_row_col`.
template
struct SM75_MMA_16x8x8_F32_F16_F16_F32_TN: SM80_MMA_16x8x16_F32_F16_F16_F32_TN {
    static constexpr int M = 16;
    static constexpr int N = 8;
    static constexpr int K = 8;

    using FragA = Array;
    using FragB = Array;
    using FragC = Array;

    // d = a * b + c over K = 8.
    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)
    {
        mma_m16n8k8_row_col(d, a, b, (FragC&)c);
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/operand_simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/arch/smem_copy_simt.h"
#include "src/turbomind/kernels/gemm/iterator.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/simt.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Operand descriptors for the SIMT (CUDA-core) GEMM path. Each descriptor bundles:
// - the element type (Dtype),
// - the packing tag (kPack) and storage order (kOrder),
// - the shared-memory copy atom,
// - the smem layout factory and the global-memory iterator factory.
namespace simt {

// Shared-memory layout factory used by the SIMT A and B operands.
struct GetSmemLayout {
    template
    static constexpr auto apply(pair)
    {
        return SmemLayoutV2{};
    }
};

// A operand: row-major, unpacked, loaded with SmemCopy_MMA_SIMT_A.
template
struct Operand_A {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kRowMajor;

    using SmemCopyAtom = SmemCopy_MMA_SIMT_A;

    using GetSmemLayout = GetSmemLayout;
    using GetGmemIter   = GetGmemIter;
};

// B operand: row-major, unpacked, loaded with SmemCopy_MMA_SIMT_B.
template
struct Operand_B {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kRowMajor;

    using SmemCopyAtom = SmemCopy_MMA_SIMT_B;

    using GetSmemLayout = GetSmemLayout;
    using GetGmemIter   = GetGmemIter;
};

// Shared-memory layout for the C tile. `mk2cs` maps the (M, N) extents to the
// order-dependent (c, s) pair used by the layout.
template
struct _GetSmemLayoutC {
    template
    static constexpr auto apply(pair)
    {
        constexpr auto cs = mk2cs(M, N);
        return SmemLayoutV2{};
    }
};

// Thread map for moving the (M, N) C tile with THREADS threads (WARPS warps).
template
struct _GetThreadMapC {
    template
    static constexpr auto apply(pair, constant)
    {
        constexpr auto cs    = mk2cs(M, N);
        constexpr int  WARPS = THREADS / WARP_SIZE;

        return ThreadMap_V2{};
    }
};

// C operand descriptor with a caller-chosen storage order.
template
struct Operand_C {
    using Dtype = T;

    static constexpr Order kOrder = order;

    using GetSmemLayout = _GetSmemLayoutC;
    using GetThreadMap  = _GetThreadMapC;
};

// V operand: column-major, unpacked, loaded with SmemCopy_MMA_SIMT_V. Presumably the
// per-group parameters associated with B; TODO confirm.
template
struct Operand_V {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kColMajor;

    using SmemCopyAtom = SmemCopy_MMA_SIMT_V;

    struct GetSmemLayout {  // m-major
        template
        static constexpr auto apply(pair)
        {
            return SmemLayoutV2{};
        }
    };

    using GetGmemIter = GetGmemIter;
};

// Shared-memory layout factory for pre-packed operands.
struct GetSmemLayout_Pack {
    template
    static constexpr auto apply(pair)
    {
        return SmemLayoutV2{};
    }
};

// Packed B operand. Pack tag is HMMA_SIMT | OPERAND_B | Pack_M, with Pack_M = 1.
template
struct Operand_B_Pack {
    using Dtype = T;

    static constexpr int Pack_M = 1;

    static constexpr Pack  kPack  = HMMA_SIMT | OPERAND_B | Pack_M;
    static constexpr Order kOrder = kRowMajor;

    using SmemCopyAtom  = SmemCopyAtom_Pack_v3::SmemCopyAtom, kRowMajor, Pack_M>;
    using GetSmemLayout = GetSmemLayout_Pack;
    using GetGmemIter   = GetGmemIter;
};

// Packed V operand. Pack tag is HMMA_SIMT | OPERAND_V | Pack_M, with Pack_M = 1;
// storage is column-major.
template
struct Operand_V_Pack {
    using Dtype = T;

    static constexpr int Pack_M = 1;

    static constexpr Pack  kPack  = HMMA_SIMT | OPERAND_V | Pack_M;
    static constexpr Order kOrder = kColMajor;

    using SmemCopyAtom = SmemCopyAtom_Pack_v3, kColMajor, Pack_M>;

    struct GetSmemLayout {  // m-major
        template
        static constexpr auto apply(pair)
        {
            return SmemLayoutV2{};
        }
    };

    using GetGmemIter = GetGmemIter;
};

}  // namespace simt

// Registrations of the SIMT operand descriptors with the `GetOperand` trait.
// Each specialization derives from std::true_type and exposes the descriptor as `Operand`.
template
struct GetOperand: std::true_type {
    using Operand = simt::Operand_A;
};

template
struct GetOperand: std::true_type {
    using Operand = simt::Operand_B;
};

template
struct GetOperand: std::true_type {
    using Operand = simt::Operand_V;
};

template
struct GetOperand: std::true_type {
    using Operand = simt::Operand_B_Pack;
};

template
struct GetOperand: std::true_type {
    using Operand = simt::Operand_V_Pack;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/operand_sm70_s884.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/arch/smem_copy_sm70.h"
#include "src/turbomind/kernels/gemm/iterator.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Operand descriptors for the SM70 `mma.m8n8k4` (s884) GEMM path.
namespace sm70_s884 {

// Shared-memory layout factory: converts the (M, K) extents to the order-dependent
// (c, s) pair via `mk2cs`.
template
struct GetSmemLayout {
    template
    static constexpr auto apply(pair)
    {
        constexpr int2 cs = mk2cs(M, K);
        return SmemLayoutV2{};
    }
};

// A operand: row-major, unpacked, loaded with SmemCopy_MMA_884_A.
template
struct Operand_A {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kRowMajor;

    using SmemCopyAtom = SmemCopy_MMA_884_A;

    using GetSmemLayout = GetSmemLayout;
    using GetGmemIter   = GetGmemIter;
};

// B operand: row-major over (n, k), unpacked, loaded with SmemCopy_MMA_884_B.
template
struct Operand_B {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kRowMajor;  // (n,k)

    using SmemCopyAtom = SmemCopy_MMA_884_B;

    using GetSmemLayout = GetSmemLayout;
    using GetGmemIter   = GetGmemIter;
};

// V operand: column-major, unpacked, loaded with SmemCopy_MMA_884_V.
template
struct Operand_V {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kColMajor;  // (n,k)

    using SmemCopyAtom = SmemCopy_MMA_884_V;

    struct GetSmemLayout {  // m-major
        template
        static constexpr auto apply(pair)
        {
            return SmemLayoutV2{};
        }
    };

    using GetGmemIter = GetGmemIter;
};

// Shared-memory layout for the C tile, with (M, N) mapped to (c, s) by `mk2cs`.
template
struct _GetSmemLayoutC {
    template
    static constexpr auto apply(pair)
    {
        constexpr auto cs = mk2cs(M, N);
        return SmemLayoutV2{};
    }
};

// Thread map for moving the (M, N) C tile with THREADS threads (WARPS warps).
template
struct _GetThreadMapC {
    template
    static constexpr auto apply(pair, constant)
    {
        constexpr auto cs    = mk2cs(M, N);
        constexpr int  WARPS = THREADS / WARP_SIZE;

        return ThreadMap_V2{};
    }
};

// C operand descriptor with a caller-chosen storage order.
template
struct Operand_C {
    using Dtype = T;

    static constexpr Order kOrder = order;

    using GetSmemLayout = _GetSmemLayoutC;
    using GetThreadMap  = _GetThreadMapC;
};

// Packed B operand. Pack tag is HMMA_884 | OPERAND_B | Pack_M, with Pack_M = 1.
template
struct Operand_B_Pack {
    using Dtype = T;

    static constexpr int Pack_M = 1;

    static constexpr Pack  kPack  = HMMA_884 | OPERAND_B | Pack_M;
    static constexpr Order kOrder = kRowMajor;

    using SmemCopyAtom = SmemCopyAtom_Pack_v3, kOrder, Pack_M>;

    using GetSmemLayout = GetSmemLayout;
    using GetGmemIter   = GetGmemIter;
};

// Packed V operand. Pack tag is HMMA_884 | OPERAND_V | Pack_M, with Pack_M = 1;
// storage is column-major.
template
struct Operand_V_Pack {
    using Dtype = T;

    static constexpr int Pack_M = 1;

    static constexpr Pack  kPack  = HMMA_884 | OPERAND_V | Pack_M;
    static constexpr Order kOrder = kColMajor;

    using SmemCopyAtom = SmemCopyAtom_Pack_v3, kColMajor, Pack_M>;

    struct GetSmemLayout {  // m-major
        template
        static constexpr auto apply(pair)
        {
            return SmemLayoutV2{};
        }
    };

    using GetGmemIter = GetGmemIter;
};

}  // namespace sm70_s884

// Registrations of the SM70 s884 operand descriptors with the `GetOperand` trait.
// Each specialization derives from std::true_type and exposes the descriptor as `Operand`.
template
struct GetOperand: std::true_type {
    using Operand = sm70_s884::Operand_A;
};

template
struct GetOperand: std::true_type {
    using Operand = sm70_s884::Operand_B;
};

template
struct GetOperand: std::true_type {
    using Operand = sm70_s884::Operand_V;
};

template
struct GetOperand: std::true_type {
    using Operand = sm70_s884::Operand_B_Pack;
};

template
struct GetOperand: std::true_type {
    using Operand = sm70_s884::Operand_V_Pack;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/arch/smem_copy_sm80.h"
#include "src/turbomind/kernels/gemm/iterator.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"
#include 

namespace turbomind::gemm {

// Operand descriptors for the SM80 `mma.m16n8k16` (s16816) GEMM path.
namespace sm80_s16816 {

namespace detail {

// Swizzled shared-memory layout factory over (c, s) extents.
// - The strided tile S0 is fixed at 8.
// - The contiguous tile C0 is 64 when C >= 64, 32 when C >= 32, and 16 otherwise.
// - `_Small` and `Swizzle` select the swizzle functor from the tile extents.
struct GetSmemLayout {
    template
    static constexpr auto apply(pair)
    {
        // constexpr int S0 = S >= 16 ? 16 : 8;
        constexpr int S0 = 8;
        constexpr int C0 = C >= 64 ? 64 : (C >= 32 ? 32 : 16);
        using _Small     = std::conditional_t, Swizzle<1, 3, 3>>;
        using Swizzle    = std::conditional_t, _Small>;
        return SmemLayoutV2{};
    }
};

}  // namespace detail

// Converts (M, K) to order-dependent (c, s) extents and forwards to
// detail::GetSmemLayout.
template
struct GetSmemLayoutV2 {
    template
    static constexpr auto apply(pair)
    {
        constexpr int2 cs = mk2cs(M, K);
        return detail::GetSmemLayout::apply(pair{});
    }
};

// A operand, indexed as (m, k): unpacked, caller-chosen order, loaded with LDSM_SM75_8x8.
template
struct Operand_A {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = order;

    // Earlier copy-atom choices, kept for reference:
    // using SmemCopyAtom =
    //     std::conditional_t, SmemCopy_MMA_16816_B>;

    // using SmemCopyAtom = std::conditional_t,
    //                                         LDSM_SM75_8x8>;

    using SmemCopyAtom = LDSM_SM75_8x8;

    using GetSmemLayout = GetSmemLayoutV2;
    using GetGmemIter   = GetGmemIter;
};

// B operand, indexed as (n, k): unpacked, caller-chosen order, loaded with LDSM_SM75_8x8.
template
struct Operand_B {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = order;

    // Earlier copy-atom choices, kept for reference:
    // using SmemCopyAtom =
    //     std::conditional_t, SmemCopy_MMA_16816_A>;
    // using SmemCopyAtom = std::conditional_t,
    //                                         LDSM_SM75_8x8>;

    using SmemCopyAtom = LDSM_SM75_8x8;

    using GetSmemLayout = GetSmemLayoutV2;
    using GetGmemIter   = GetGmemIter;
};

// Swizzled shared-memory layout for the C tile. The two storage orders use different
// swizzle parameters; the bit diagrams below show how address bits are assigned.
template
struct _GetSmemLayoutC {
    template
    static constexpr auto apply(pair)
    {
        if constexpr (order == kRowMajor) {
            // x01  23
            // cccccss
            //                                    bits base shift
            return SmemLayoutV2>{};
        }
        else {
            // 234  x01
            // 23401x
            // cccccsss
            // so that x is not part of swizzling
            return SmemLayoutV2>{};
        }
    }
};

// Thread map for moving the (M, N) C tile with THREADS threads (WARPS warps).
template
struct _GetThreadMapC {
    template
    static constexpr auto apply(pair, constant)
    {
        constexpr auto cs    = mk2cs(M, N);
        constexpr int  WARPS = THREADS / WARP_SIZE;

        return ThreadMap_V2{};
    }
};

// C operand descriptor with a caller-chosen storage order.
template
struct Operand_C {
    using Dtype = T;

    static constexpr Order kOrder = order;

    using GetSmemLayout = _GetSmemLayoutC;
    using GetThreadMap  = _GetThreadMapC;
};

// U/V operand (per-group parameters paired with A or B): column-major, unpacked,
// loaded with SmemCopy_MMA_16816_U.
template
struct Operand_UV {
    using Dtype = T;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = kColMajor;

    using SmemCopyAtom = SmemCopy_MMA_16816_U;

    struct GetSmemLayout {
        template
        static constexpr auto apply(pair)
        {
            return SmemLayoutV2{};
        }
    };
    using GetGmemIter = GetGmemIter;
};

// Shared-memory layout factory for pre-packed operands, with (M, K) mapped to (c, s)
// by `mk2cs`.
template
struct GetSmemLayout_Pack {
    template
    static constexpr auto apply(pair)
    {
        constexpr int2 CS = mk2cs(M, K);
        return SmemLayoutV2{};
    }
};

// Packed A operand.
// - Pack tag: HMMA_16816 | OPERAND_A | Pack_M.
// - Copy atom: SmemCopyAtom_Pack_v3 wrapping Operand_A's copy atom (_SCp).
template
struct Operand_A_Pack {
    using Dtype = T;

    static constexpr int Pack_M = Pack_M_;

    static constexpr Pack  kPack  = HMMA_16816 | OPERAND_A | Pack_M;
    static constexpr Order kOrder = order;

    // using SmemCopyAtom = SmemCopyAtom_Pack_v2;
    using _SCp         = typename Operand_A::SmemCopyAtom;
    using SmemCopyAtom = SmemCopyAtom_Pack_v3;

    using GetSmemLayout = GetSmemLayout_Pack;
    using GetGmemIter   = GetGmemIter;
};

// Packed B operand.
// - Pack tag: HMMA_16816 | OPERAND_B | Pack_M.
// - Copy atom: SmemCopyAtom_Pack_v2.
template
struct Operand_B_Pack {
    using Dtype = T;

    static constexpr int Pack_M = Pack_M_;

    static constexpr Pack  kPack  = HMMA_16816 | OPERAND_B | Pack_M;
    static constexpr Order kOrder = order;

    using SmemCopyAtom = SmemCopyAtom_Pack_v2;

    using GetSmemLayout = GetSmemLayout_Pack;
    using GetGmemIter   = GetGmemIter;
};

// Packed U or V operand, selected by `is_V`.
// - Pack_M is 1 and storage is column-major.
// - Copy atom: SmemCopyAtom_Pack_v3 wrapping Operand_UV's copy atom (_SCp).
template
struct Operand_UV_Pack {
    using Dtype = T;

    static constexpr int Pack_M = 1;

    static constexpr Pack  kPack  = HMMA_16816 | (is_V ? OPERAND_V : OPERAND_U) | Pack_M;
    static constexpr Order kOrder = Order::kColMajor;

    using _SCp         = typename Operand_UV::SmemCopyAtom;
    using SmemCopyAtom = SmemCopyAtom_Pack_v3;

    using GetSmemLayout = GetSmemLayout_Pack;
    using GetGmemIter   = GetGmemIter;
};

}  // namespace sm80_s16816

// Registrations of the SM80 s16816 operand descriptors with the `GetOperand` trait.
// Each specialization derives from std::true_type and exposes the descriptor as `Operand`.
template
struct GetOperand: std::true_type {
    using Operand = sm80_s16816::Operand_A;
};

template
struct GetOperand: std::true_type {
    using Operand = sm80_s16816::Operand_B;
};

template
struct GetOperand: std::true_type {
    using Operand = sm80_s16816::Operand_UV;
};

template
struct GetOperand: std::true_type {
    using Operand = sm80_s16816::Operand_UV;
};

// The packed-operand registrations below are currently disabled (commented out).

// template
// struct GetOperand: std::true_type {
//     using Operand = sm80_s16816::Operand_A_Pack;
// };

// template
// struct GetOperand: std::true_type {
//     using Operand = sm80_s16816::Operand_B_Pack;
// };

// template<>
// struct GetOperand: std::true_type {
//     using Operand = sm80_s16816::Operand_U_Pack;
// };

// template<>
// struct GetOperand: std::true_type {
//     using Operand = sm80_s16816::Operand_U_Pack;
// };

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/smem_copy_simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/gemm/simt.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Shared-memory -> register copy atom for the SIMT A operand. Lane l loads the
// fragment at row l / OP_N, so the OP_N lanes with the same quotient load the same row.
template
struct SmemCopy_MMA_SIMT_A {
    static constexpr int M = simt::OP_M;
    static constexpr int K = simt::OP_K;

    static constexpr int OP_N = simt::OP_N;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // (m, k) offset of this lane's fragment: row lane_id / OP_N, column 0.
    __device__ static int2 get_offset(int thread_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {lane_id / OP_N, 0};
    }

    // Plain shared-memory load of one fragment; the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)  // -> (m, k)
    {
        Lds(*(Frag*)dst_ptr, (S &&) src_ptr);
    }

    // Returns (unique id, repeat id):
    // - unique id: the row index offset by pack_idx * M;
    // - repeat id: which of the OP_N lanes sharing that row this lane is.
    __device__ static int2 unique(int thread_idx, int pack_idx)  // -> (unique id, repeat id)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {pack_idx * M + lane_id / OP_N, lane_id % OP_N};
    }
};

// Shared-memory -> register copy atom for the SIMT B operand. Lane l loads the
// fragment at row l % OP_N, so lanes congruent modulo OP_N load the same row.
template
struct SmemCopy_MMA_SIMT_B {
    static constexpr int M = simt::OP_N;
    static constexpr int K = simt::OP_K;

    static constexpr int OP_N = simt::OP_N;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // (m, k) offset of this lane's fragment: row lane_id % OP_N, column 0.
    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {lane_id % OP_N, 0};
    }

    // Plain shared-memory load of one fragment; the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        Lds(*(Frag*)dst_ptr, (S &&) src_ptr);
    }

    // Returns (unique id, repeat id):
    // - unique id: the row index offset by pack_idx * OP_N;
    // - repeat id: lane_id / OP_N.
    __device__ static int2 unique(int thread_idx, int pack_idx)  // -> (unique id, repeat id)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {pack_idx * OP_N + lane_id % OP_N, lane_id / OP_N};
    }
};

// Shared-memory -> register copy atom for the SIMT V operand. It uses the same lane
// mapping as SmemCopy_MMA_SIMT_B, with a caller-chosen K extent (K_).
template
struct SmemCopy_MMA_SIMT_V {
    static constexpr int M = simt::OP_N;
    static constexpr int K = K_;

    static constexpr int OP_N = simt::OP_N;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Returns (unique id, repeat id):
    // - unique id: row lane_id % OP_N, offset by pack_idx * OP_N;
    // - repeat id: lane_id / OP_N.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {pack_idx * OP_N + lane_id % OP_N, lane_id / OP_N};
    }

    // (m, k) offset of this lane's fragment: the unique row of pack 0, column 0.
    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)
    {
        return {unique(thread_idx, 0).x, 0};
    }

    // Plain shared-memory load of one fragment; `mask` is not used.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool mask)
    {
        Lds(*(Frag*)dst_ptr, src_ptr);
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/smem_copy_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"

namespace turbomind::gemm {

// Shared-memory -> register copy atom for the SM70 s884 A operand (8 x 8 per atom).
// Lanes map to 8 distinct rows; lane bits 2-3 give a repeat id 0..3, so each row is
// loaded by 4 lanes.
template
struct SmemCopy_MMA_884_A {
    // Alternative 16-row shape, left commented out:
    // static constexpr int M = 16;
    // static constexpr int K = 8;
    static constexpr int M = 8;
    static constexpr int K = 8;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Returns (unique id, repeat id). Row m = (lane bit 4) * 4 + (lane bits 0-1),
    // offset by pack_idx * M.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        //                   4                3               01
        // const int m = lane_id / 16 * 4 + (lane_id & 8) + lane_id % 4;
        // return {pack_idx * M + m, (lane_id & 4) >> 2};

        //                   4                01
        const int m = lane_id / 16 * 4 + lane_id % 4;
        return {pack_idx * M + m, (lane_id & 12) >> 2};
    }

    // (m, k) offset of this lane's fragment: its unique row of pack 0, column 0.
    __device__ static int2 get_offset(int thread_idx)
    {
        return int2{unique(thread_idx, 0).x, 0};
    }

    // Plain shared-memory load of one fragment; the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        Lds(*(Frag*)dst_ptr, src_ptr);
    }
};

// Shared-memory -> register copy atom for the SM70 s884 B operand (32 x 8 per atom).
// Every lane loads a distinct row, so the repeat id is always 0.
template
struct SmemCopy_MMA_884_B {
    // Alternative 16-row shape, left commented out:
    // static constexpr int M = 16;
    // static constexpr int K = 8;
    static constexpr int M = 32;
    static constexpr int K = 8;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Returns (unique id, repeat id). Row m is built from lane bits:
    //   bits 0-1 -> +0..3, bit 4 -> +4, bits 2-3 -> +8 / +16.
    // The row is offset by pack_idx * M.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        //                4                     2                 01
        // const int m = lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;
        // return {pack_idx * M + m, (lane_id & 8) >> 3};

        //                  4                  23                  01
        const int m = lane_id / 16 * 4 + (lane_id & 12) * 2 + lane_id % 4;
        return {pack_idx * M + m, 0};
    }

    // (m, k) offset of this lane's fragment: its unique row of pack 0, column 0.
    __device__ static int2 get_offset(int thread_idx)
    {
        return int2{unique(thread_idx, 0).x, 0};
    }

    // Plain shared-memory load of one fragment; the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        Lds(*(Frag*)dst_ptr, src_ptr);
    }
};

// Shared-memory -> register copy atom for the SM70 s884 V operand. It uses the same
// 32-row lane mapping as SmemCopy_MMA_884_B, with a caller-chosen K extent (K_).
template
struct SmemCopy_MMA_884_V {
    // Alternative 16-row shape, left commented out:
    // static constexpr int M = 16;
    static constexpr int M = 32;
    static constexpr int K = K_;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Returns (unique id, repeat id): a distinct row per lane, offset by pack_idx * M;
    // the repeat id is always 0.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        //                4                     2                 01
        // const int m = lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;
        // return {pack_idx * 16 + m, (lane_id & 8) >> 3};

        const int m = lane_id / 16 * 4 + (lane_id & 12) * 2 + lane_id % 4;
        return {pack_idx * M + m, 0};
    }

    // (m, k) offset of this lane's fragment: its unique row of pack 0, column 0.
    __device__ static int2 get_offset(int thread_idx)
    {
        return int2{unique(thread_idx, 0).x, 0};
    }

    // Plain shared-memory load of one fragment; the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        Lds(*(Frag*)dst_ptr, src_ptr);
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch/smem_copy_sm80.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Loads four matrices from shared memory into the registers at `dst_ptr`.
// - The source pointer is first converted to a 32-bit shared-memory address.
// - With `trans`, `ldsm_x4_trans` is called; otherwise `ldsm_x4`.
// The `ldsm_*` helpers presumably wrap PTX `ldmatrix ... .x4`.
template
struct LDSM_x4 {
    template
    __device__ static void apply(S src_ptr, D dst_ptr)
    {
        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);
        if constexpr (trans) {
            ldsm_x4_trans(*(Array*)dst_ptr, uint_ptr);
        }
        else {
            ldsm_x4(*(Array*)dst_ptr, uint_ptr);
        }
    }
};

// Two-matrix counterpart of LDSM_x4.
// - With `trans`, `ldsm_x2_trans` is called; otherwise `ldsm_x2`.
// - The `ldsm_*` helpers presumably wrap PTX `ldmatrix ... .x2`.
template
struct LDSM_x2 {
    template
    __device__ static void apply(S src_ptr, D dst_ptr)
    {
        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);
        if constexpr (trans) {
            ldsm_x2_trans(*(Array*)dst_ptr, uint_ptr);
        }
        else {
            ldsm_x2(*(Array*)dst_ptr, uint_ptr);
        }
    }
};

// Single-matrix counterpart of LDSM_x4.
// - With `trans`, `ldsm_x1_trans` is called; otherwise `ldsm_x1`.
// - The `ldsm_*` helpers presumably wrap PTX `ldmatrix ... .x1`.
template
struct LDSM_x1 {
    template
    __device__ static void apply(S src_ptr, D dst_ptr)
    {
        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);
        if constexpr (trans) {
            ldsm_x1_trans(*(Array*)dst_ptr, uint_ptr);
        }
        else {
            ldsm_x1(*(Array*)dst_ptr, uint_ptr);
        }
    }
};

// Copy atom for a 16x16 A tile, loaded with one LDSM_x4. Lanes 0-15 address the 16
// rows of columns 0-7; lanes 16-31 address the same rows for columns 8-15.
template
struct SmemCopy_MMA_16816_A {
    static constexpr int M = 16;
    static constexpr int K = 16;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Per-lane start address of an 8-element row segment, as (m, k). The pair is
    // swapped when `trans` is set.
    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)
    {
        const int lane_id = thread_idx % WARP_SIZE;

        const int c = lane_id / 16 * 8;
        const int s = lane_id % 16;

        return trans ? int2{c, s} : int2{s, c};
    }

    // One x4 load (transposed when `trans`); the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        LDSM_x4::apply((S &&) src_ptr, (D &&) dst_ptr);
    }

    // Every lane of every pack is unique: id = pack_idx * WARP_SIZE + lane, repeat id 0.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};
    }
};

// Copy atom for a 16x16 B tile, loaded with one LDSM_x4. Lane groups address:
//   lanes 0-7   -> rows 0-7,  cols 0-7
//   lanes 8-15  -> rows 0-7,  cols 8-15
//   lanes 16-23 -> rows 8-15, cols 0-7
//   lanes 24-31 -> rows 8-15, cols 8-15
template
struct SmemCopy_MMA_16816_B {
    static constexpr int M = 16;
    static constexpr int K = 16;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Per-lane start address as (m, k); the pair is swapped when `trans` is set.
    __device__ static int2 get_offset(int thread_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;

        const int c = lane_id / 8 * 8 % 16;
        const int s = lane_id % 8 + lane_id / 16 * 8;

        return trans ? int2{c, s} : int2{s, c};
    }

    // One x4 load (transposed when `trans`); the mask argument is ignored.
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        LDSM_x4::apply((S &&) src_ptr, (D &&) dst_ptr);
    }

    // Every lane of every pack is unique: id = pack_idx * WARP_SIZE + lane, repeat id 0.
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};
    }
};

// Generic copy atom for an M x K tile made of 8x8 matrices (iM x iK of them).
// The per-lane fragment size selects LDSM_x4 (16 bytes), LDSM_x2 (8 bytes)
// or LDSM_x1 (4 bytes).
template
struct LDSM_SM75_8x8 {
    static constexpr int M = M_;
    static constexpr int K = K_;

    static constexpr int iM = M / 8;
    static constexpr int iK = K / 8;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    // Returns this lane's (m, k) source offset, derived from (c, s) via cs2mk
    __device__ static int2 get_offset(int thread_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        int       c, s;
        if constexpr (mat_order == kColMajor) {
            // lanes 0..15 -> s = 0..15 at c = 0; lanes 16..31 -> same s at c = 8
            s = lane_id % 16;
            c = lane_id / 16 * 8;
        }
        else {
            // s = 0..7 for lanes 0..15, 8..15 for lanes 16..31; c = 0 or 8 (bit 3 of lane)
            s = lane_id / 16 * 8 + lane_id % 8;
            c = lane_id & 8;
        }
        int2 mk = cs2mk(c, s);
#if __CUDA_ARCH__ <= 750  // wrap ptrs around for sm_75
        // Keeps every lane's address inside the M x K tile; presumably because lanes not
        // consumed by the x1/x2 variants must still form in-bounds addresses — TODO confirm
        mk.x %= M;
        mk.y %= K;
#endif
        return mk;
    }

    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)
    {
        // Transposed load whenever the thread order is not row-major
        constexpr bool trans = thr_order != kRowMajor;
        if constexpr (sizeof(Frag) == 16) {
            LDSM_x4::apply((S &&) src_ptr, (D &&) dst_ptr);
        }
        else if constexpr (sizeof(Frag) == 8) {
            LDSM_x2::apply((S &&) src_ptr, (D &&) dst_ptr);
        }
        else if constexpr (sizeof(Frag) == 4) {
            LDSM_x1::apply((S &&) src_ptr, (D &&) dst_ptr);
        }
        else {
            // Dependent-false assertion: only fires if this branch is instantiated
            static_assert(sizeof(S) != sizeof(S), "not implemented");
        }
    }

    // Every lane holds distinct data (repeat index always 0)
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};
    }
};

// Copy atom for a per-row parameter operand U laid out as (M, K) = (16, 1).
// Each lane reads two values: row lane_id / 4 and the row 8 below it.
template
struct SmemCopy_MMA_16816_U {  // (M, K)
    static constexpr int M = 16;
    static constexpr int K = 1;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    __device__ static int2 get_offset(int thread_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        // Note: this forbids sub-tile group sizes
        return {lane_id / 4, 0};
    }

    // `mask` is ignored; both halves are always loaded (offsets 0 and 8)
    template
    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool mask)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < 2; ++i) {
            Lds(*((Array*)dst_ptr + i), src_ptr + i * 8);
        }
    }

    // Four consecutive lanes hold the same data: unique id is pack_idx * 8 + lane / 4,
    // and lane % 4 is the repeat index (consumers store only repeat 0).
    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        const int lane_id = thread_idx % WARP_SIZE;
        return {pack_idx * 8 + lane_id / 4, lane_id % 4};
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/arch.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind::gemm {

// tags for dispatching & conditional codegen

// Half-open range [Begin, End) of device arch numbers a tag is compatible
// with; End == -1 means no upper bound (Sm90 below passes only Begin,
// presumably relying on a default End of -1 — TODO confirm).
template
struct Arch {
    static constexpr bool is_compatible(int arch)
    {
        return Begin <= arch && (End == -1 || arch < End);
    }
};

// Architecture dispatch tags. Each tag is compatible with a range of device
// arch numbers (via Arch); `value` is the tag's own arch number.
struct Sm70: Arch<700, 750> { static constexpr int value = 700; };

struct Sm75: Arch<750, 800> { static constexpr int value = 750; };

struct Sm80: Arch<800, 900> { static constexpr int value = 800; };

// Lower bound only: presumably open-ended (every arch from 900 up)
struct Sm90: Arch<900> { static constexpr int value = 900; };

// Whether a kernel tagged with arch `karch` may run on a device of arch `darch`.
// karch == 0 is accepted on every device; an unrecognised karch on none.
inline bool is_arch_compatible(int karch, int darch)
{
    if (karch == 0) {
        return true;
    }
    if (karch == Sm70::value) {
        return Sm70::is_compatible(darch);
    }
    if (karch == Sm75::value) {
        return Sm75::is_compatible(darch);
    }
    if (karch == Sm80::value) {
        return Sm80::is_compatible(darch);
    }
    if (karch == Sm90::value) {
        return Sm90::is_compatible(darch);
    }
    return false;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/cast.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/cast.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"

namespace turbomind {

// Generic element-wise conversion: each of the N elements is converted with
// static_cast. The specializations below handle packed 4-bit data and the
// same-type case.
template
struct Cast {
    template
    __device__ static Array apply(const Array& vi)
    {
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            vo[i] = static_cast(vi[i]);
        }
        return vo;
    }
};

// Packing conversion (presumably 8-bit -> 4-bit, used by compact_to_u4): every
// group of 8 input values becomes one 32-bit word, element i + 0 in the lowest
// nibble. Assumes inputs already fit in 4 bits; larger values would corrupt
// the nibbles of neighbouring elements.
template
struct Cast {
    template
    __device__ static Array apply(const Array& vi)
    {
        static_assert(N % 8 == 0);
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            // Write the packed word directly into the output storage at element i
            uint32_t& v = (uint32_t&)vo[i];
            v           = 0;
            PRAGMA_UNROLL
            for (int j = 7; j >= 0; --j) {
                v = (v << 4) | vi[i + j];
            }
        }
        return vo;
    }
};

// Unpacking conversion (presumably 4-bit -> wider integer): each 32-bit word at
// element i yields 8 outputs, lowest nibble first.
template
struct Cast {
    template
    __device__ static Array apply(const Array& vi)
    {
        static_assert(N % 8 == 0);
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            uint32_t v = (const uint32_t&)vi[i];
            PRAGMA_UNROLL
            for (int j = 0; j < 8; ++j) {
                vo[i + j] = (v & 0xf);
                v >>= 4;
            }
        }
        return vo;
    }
};

// Full specialization: input returned unchanged (no conversion needed).
template<>
struct Cast {
    template
    __device__ static Array apply(const Array& vi)
    {
        return vi;
    }
};

// Grid-stride conversion of `n` elements from src to dst, VecSize elements per
// step via Cast::apply. Only the first n / VecSize * VecSize elements are
// processed; callers presumably pass n divisible by VecSize — TODO confirm.
template
__global__ void cast_kernel(To* dst, const Ti* src, size_t n)
{
    n /= VecSize;

    auto p_src = (const Array*)src;
    auto p_dst = (Array*)dst;

    for (size_t p = threadIdx.x + blockDim.x * blockIdx.x; p < n; p += blockDim.x * gridDim.x) {
        Array vi;
        Ldg(vi, (const Ti*)&p_src[p]);

        Array vo = Cast::apply(vi);

        Store((To*)&p_dst[p], vo);
    }
}

// Launches cast_kernel with a fixed 256 x 256 configuration on `st`; the
// kernel's grid-stride loop covers any n.
template
void invokeCast(To* dst, const Ti* src, size_t n, cudaStream_t st)
{
    cast_kernel<<<256, 256, 0, st>>>(dst, src, n);
}

// Unpacks n 4-bit values into one byte each (vector width 8).
void extend_to_u8(uint8_t* dst, const uint4_t* src, size_t n, cudaStream_t st)
{
    invokeCast<8>(dst, src, n, st);
}

// Packs n byte values into 4-bit storage (vector width 8); inputs must fit in 4 bits.
void compact_to_u4(uint4_t* dst, const uint8_t* src, size_t n, cudaStream_t st)
{
    invokeCast<8>(dst, src, n, st);
}

// Unpacks n 4-bit values into 16-bit storage (vector width 8).
void extend_to_u16(uint16_t* dst, const uint4_t* src, size_t n, cudaStream_t st)
{
    invokeCast<8>(dst, src, n, st);
}

namespace {

// One thread per element: widens each 8-bit value of src into 16-bit dst.
__global__ void extend_u16_u8(uint16_t* dst, const uint8_t* src, size_t n)
{
    const int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;  // tail threads of the last block
    }
    dst[i] = src[i];
}

}  // namespace

// Widens n 8-bit values to 16 bits on stream `st`, one thread per element.
// n == 0 returns immediately: a zero-block grid is an invalid launch
// configuration and would leave a pending error for the next CUDA error check.
void extend_to_u16(uint16_t* dst, const uint8_t* src, size_t n, cudaStream_t st)
{
    if (n == 0) {
        return;
    }
    constexpr size_t kBlock = 512;
    extend_u16_u8<<<(n + kBlock - 1) / kBlock, kBlock, 0, st>>>(dst, src, n);
}

// Grid-stride kernel that interleaves scales and zeros into `fused` as
// (s, -z * s) pairs, so fused holds 2 * n values. `zeros` may be null, in
// which case z is taken as 0 (vz is zero-initialized). `zeros` is only read.
// Only n / VecSize * VecSize elements are processed.
template
__global__ void fuse_scales_and_zeros_kernel(T* fused, const T* scales, T* zeros, size_t n)
{
    n /= VecSize;

    auto p_scales = (const Array*)scales;
    auto p_zeros  = (const Array*)zeros;

    auto p_fused = (Array*)fused;

    for (size_t p = threadIdx.x + blockDim.x * blockIdx.x; p < n; p += blockDim.x * gridDim.x) {
        Array vs;
        Ldg(vs, (const T*)&p_scales[p]);
        Array vz{};
        if (zeros) {
            Ldg(vz, (const T*)&p_zeros[p]);
        }
        Array vf;
        PRAGMA_UNROLL
        for (int i = 0; i < VecSize; ++i) {
            vf[i * 2]     = vs[i];
            vf[i * 2 + 1] = -vz[i] * vs[i];
        }
        Store((T*)&p_fused[p], vf);
    }
}

// Half-precision launcher: vector width 4, fixed 256 x 256 grid (grid-stride kernel).
void fuse_scales_and_zeros(half* fused, const half* scales, half* zeros, size_t n, cudaStream_t st)
{
    fuse_scales_and_zeros_kernel<4><<<256, 256, 0, st>>>(fused, scales, zeros, n);
}

// For each of the k rows (one per blockIdx.y), interleaves the m elements of
// `a` and `b` element-wise into `fused` (row length 2 * m): a0 b0 a1 b1 ...
// The x dimension is a grid-stride loop over VecSize-element chunks; m is
// presumably a multiple of VecSize (any remainder is skipped) — TODO confirm.
// NOTE(review): row offsets ki * m and ki * m * 2 are computed in int.
template
__global__ void
interleave_output_dims_kernel(T* __restrict__ fused, const T* __restrict__ a, const T* __restrict__ b, int m, int k)
{
    using Vec1 = Array;

    const int ki = blockIdx.y;

    auto p_a = reinterpret_cast(a + ki * m);
    auto p_b = reinterpret_cast(b + ki * m);

    using Vec2 = Array;

    auto p_f = reinterpret_cast(fused + ki * m * 2);

    m /= VecSize;

    const int tidx = threadIdx.x + blockIdx.x * blockDim.x;

    for (int64_t mi = tidx; mi < m; mi += blockDim.x * gridDim.x) {
        Vec1 va;
        Vec1 vb;
        Ldg(va, (const T*)&p_a[mi]);
        Ldg(vb, (const T*)&p_b[mi]);
        Vec2 vc;
        PRAGMA_UNROLL
        for (int i = 0; i < VecSize; ++i) {
            vc[i * 2]     = va[i];
            vc[i * 2 + 1] = vb[i];
        }
        Store((T*)&p_f[mi], vc);
    }
}

// Host launcher for interleave_output_dims_kernel. The vector width is chosen so
// the interleaved output vector (2 * kVecSize elements) is at most 128 bits,
// capped at 8 elements. Grid is (1, k): a single block sweeps each row.
// NOTE(review): gridDim.y is limited by CUDA (65535), which bounds k.
template
void interleave_output_dims_impl(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st)
{
    constexpr int kVecSize = std::min(8, 128 / (bitsof * 2));

    constexpr int block = 256;
    const dim3    grid(1, k);  // x is a grid stride loop

    interleave_output_dims_kernel<<>>(fused, a, b, m, k);
}

// Explicit instantiations for the 8/16/32-bit carrier types that
// interleave_output_dims (cast.h) dispatches to.
template void
interleave_output_dims_impl(uint8_t* fused, const uint8_t* a, const uint8_t* b, int m, int k, cudaStream_t st);
template void
interleave_output_dims_impl(uint16_t* fused, const uint16_t* a, const uint16_t* b, int m, int k, cudaStream_t st);
template void
interleave_output_dims_impl(uint32_t* fused, const uint32_t* a, const uint32_t* b, int m, int k, cudaStream_t st);

// One thread per element: re-biases an 8-bit exponent from bias 127 to the
// half-precision bias 15, clamped to [0, 30].
__global__ void adjust_ue8m0_scale_for_half_kernel(uint8_t* data, int n)
{
    const int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    constexpr int kSrcBias = 127;
    constexpr int kDstBias = 15;
    /// TODO: saturate the quantized values accordingly
    const int rebiased = (int)data[i] - kSrcBias + kDstBias;
    data[i]            = min(30, max(0, rebiased));  // exponent 31 is INF in half
}

// Re-biases n UE8M0 exponents in place for use as half exponents, one thread
// per element. Non-positive n is a no-op: cdiv(n, block) would not form a valid
// (positive) grid, so the launch would fail with an invalid-configuration error.
void AdjustUe8m0ScaleForHalf(uint8_t* data, int n, cudaStream_t st)
{
    if (n <= 0) {
        return;
    }
    constexpr int block = 512;
    const int     grid  = cdiv(n, block);
    adjust_ue8m0_scale_for_half_kernel<<>>(data, n);
}

// One thread per output element: dst[idx] = src[idx / block_size], i.e. each
// source scale is repeated block_size times along the flattened output.
template
__global__ void BlockscaleToGroupscale_Kernel(T1* dst, const T0* src, int64_t n, int block_size)
{
    int64_t idx = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
    if (idx < n) {
        dst[idx] = (T1)src[idx / block_size];
    }
}

// Expands float32 block scales of shape (s0, s1) into a (s0, s1 * block_size)
// device tensor of `data_type` (half_t or bfloat16_t), repeating each scale
// block_size times along dim 1. Runs on the current core::Context stream.
// NOTE(review): an empty `scales` would launch a zero-block grid.
Tensor BlockscaleToGroupscale(const Tensor& scales, DataType data_type, int block_size)
{
    TM_CHECK_EQ(scales.dtype(), kFloat32);

    Tensor ret{{scales.shape(0), scales.shape(1) * block_size}, data_type, kDEVICE};

    auto stream = core::Context::stream().handle();

    auto invoke = [&](auto t) {
        using T = decltype(t);
        BlockscaleToGroupscale_Kernel<<<(ret.size() + 511) / 512, 512, 0, stream>>>(
            ret.data(), scales.data(), ret.size(), block_size);
    };

    TM_DISPATCH_DTYPES(data_type, invoke, half_t, bfloat16_t);

    return ret;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/cast.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/kernels/core/data_type.h"

namespace turbomind {

// Host launchers for the element-format conversion kernels in cast.cu.
// All of them enqueue work on `st` and return without synchronizing.

void extend_to_u8(uint8_t* dst, const uint4_t* src, size_t n, cudaStream_t st = {});

void extend_to_u16(uint16_t* dst, const uint4_t* src, size_t n, cudaStream_t st = {});

// Default stream argument added for consistency with the other launchers here.
void extend_to_u16(uint16_t* dst, const uint8_t* src, size_t n, cudaStream_t st = {});

void compact_to_u4(uint4_t* dst, const uint8_t* src, size_t n, cudaStream_t st = {});

void transpose_u4(uint4_t* dst, const uint4_t* src, int s, int c, cudaStream_t st = {});

// Writes (scale, -zero * scale) pairs; `zeros` may be null (treated as 0).
void fuse_scales_and_zeros(half* fused, const half* scales, half* zeros, size_t n, cudaStream_t st = {});

template
void interleave_output_dims_impl(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st);

// Interleaves rows of `a` and `b` element-wise into `fused`. T is reinterpreted
// as an unsigned integer of the same width (8/16/32 bits) to reach one of the
// explicit instantiations in cast.cu; other widths fall through and do nothing.
template
inline void interleave_output_dims(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st)
{
    auto dispatch = [&](auto u) {
        using U = decltype(u);
        return interleave_output_dims_impl((U*)fused, (const U*)a, (const U*)b, m, k, st);
    };
    if constexpr (bitsof == 8) {
        return dispatch(uint8_t{});
    }
    else if constexpr (bitsof == 16) {
        return dispatch(uint16_t{});
    }
    else if constexpr (bitsof == 32) {
        return dispatch(uint32_t{});
    }
}

void AdjustUe8m0ScaleForHalf(uint8_t* data, int n, cudaStream_t st);

Tensor BlockscaleToGroupscale(const Tensor& scales, DataType data_type, int block_size);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/context.cu
================================================

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/context.h"
#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include "src/turbomind/utils/monotonic.h"
#include 
#include 
#include 
#include 

namespace turbomind::gemm {

// Builds a GemmDesc from the operand layouts. A is (m, k), B is (k, n) and
// D is (m, n); the dimensions and batch counts (`num`) of A, B and D must agree,
// otherwise the mismatching values are printed to stderr and an empty optional
// is returned. Only `pack` is taken from Udesc/Vdesc; Cdesc is not used here.
static std::optional get_gemm_desc(const Operation&    operation,
                                             const MatrixLayout& Adesc,
                                             const MatrixLayout& Udesc,
                                             const MatrixLayout& Bdesc,
                                             const MatrixLayout& Vdesc,
                                             const MatrixLayout& Cdesc,
                                             const MatrixLayout& Ddesc,
                                             int                 arch)
{

    // Constant dimensions are set to the exact value
    // Variable dimensions are set to sum of the values

    const int m0 = Adesc.rows, k0 = Adesc.cols;
    const int k1 = Bdesc.rows, n0 = Bdesc.cols;
    const int m1 = Ddesc.rows, n1 = Ddesc.cols;

    const int l0 = Adesc.num, l1 = Bdesc.num, l2 = Ddesc.num;

    if (m0 != m1 || n0 != n1 || k0 != k1 || l0 != l1 || l0 != l2) {
        fprintf(stderr, "%d %d %d %d %d %d %d %d %d\n", m0, m1, n0, n1, k0, k1, l0, l1, l2);
        return {};
    }

    // Positional init; field order follows GemmDesc's declaration. The trailing -1 is
    // presumably the initial group_axis ("not grouped") — TODO confirm in desc.h.
    GemmDesc desc{arch,
                  Adesc.type,
                  Bdesc.type,
                  Ddesc.type,
                  Adesc.order,
                  Bdesc.order,
                  Ddesc.order,
                  get_mode(Adesc),
                  get_mode(Bdesc),
                  get_mode(Ddesc),
                  Adesc.pack,
                  Bdesc.pack,
                  Udesc.pack,
                  Vdesc.pack,
                  operation.quant_a,
                  operation.quant_b,
                  operation.epilogue,
                  operation.batch_dim,
                  -1};

    desc.m   = m0;
    desc.n   = n0;
    desc.k   = k0;
    desc.num = std::max(l0, 1);

    // Grouped problems: groups are laid out along the batch dimension
    if (desc.num > 1) {
        desc.group_axis = operation.batch_dim;
    }

    return desc;
}

// Returns one copy of `spec` per requested swizzle value that does not exceed
// the kernel's maximum for `shape`. Duplicates are dropped and the first-seen
// order is kept.
std::vector get_swizzle(const int4& shape, const LaunchSpec& spec, const std::vector& swizzle)
{
    std::vector vec;
    const int        max_swizzle = spec.kernel->GetMaxSwizzle(shape);
    for (const auto& s : swizzle) {
        if (s <= max_swizzle && std::find(vec.begin(), vec.end(), s) == vec.end()) {
            vec.push_back(s);
        }
    }
    std::vector ret;
    for (const auto& s : vec) {
        auto tmp    = spec;
        tmp.swizzle = s;
        ret.push_back(tmp);
    }
    return ret;
}

// Caches the device's arch number (major * 100 + minor * 10) and SM count.
Context::Context(const cudaDeviceProp& prop):
    arch_{prop.major * 100 + prop.minor * 10}, sm_count_{prop.multiProcessorCount}
{
}

// Validates the operand layouts and stores the problem descriptor, plus its
// transpose for kernels that run on the transposed problem. Returns false
// (leaving the previous descriptors untouched) when the shapes disagree.
bool Context::Init(const Operation&    operation,
                   const MatrixLayout& Adesc,
                   const MatrixLayout& Udesc,
                   const MatrixLayout& Bdesc,
                   const MatrixLayout& Vdesc,
                   const MatrixLayout& Cdesc,
                   const MatrixLayout& Ddesc)
{
    const auto built = get_gemm_desc(operation, Adesc, Udesc, Bdesc, Vdesc, Cdesc, Ddesc, arch_);
    if (built) {
        desc_       = *built;
        desc_trans_ = transpose(desc_);
    }
    return built.has_value();
}

// Selects candidate kernels for the current problem:
//  1. keep kernels whose is_feasible() accepts their (possibly transposed) desc;
//  2. if some feasible kernel's batch tile (cta_tile.y when batch_dim is set,
//     cta_tile.x otherwise) covers the whole batch, drop every kernel with a
//     larger batch tile than the smallest covering one. If none covers it,
//     all feasible kernels are kept.
std::vector Context::Filter(const std::vector& kernels) const
{
    std::vector> feasible;
    auto get_batch_dim  = [](auto k, auto& g) { return g.batch_dim ? k->desc().cta_tile.y : k->desc().cta_tile.x; };
    int  max_batch_size = 0;  // max batch size of single CTA tile

    for (auto& k : kernels) {
        auto& g = get_desc(*k);
        if (k->is_feasible(g)) {
            auto bsz = get_batch_dim(k, g);
            feasible.emplace_back(k, bsz);
            max_batch_size = std::max(bsz, max_batch_size);
        }
    }

    // Batch size of the GEMM problem
    const int batch_size = desc_.batch_dim ? desc_.n : desc_.m;
    // std::cout << "BATCH SIZE: " << batch_size << "\n";

    // Find smallest kernel the problem can fit into (may not exist)
    for (const auto& [k, bsz] : feasible) {
        if (bsz >= batch_size) {
            max_batch_size = std::min(max_batch_size, bsz);
        }
    }

    const auto pred = [&](auto k) {  //
        return k.second > max_batch_size;
    };
    feasible.erase(std::remove_if(feasible.begin(), feasible.end(), pred), feasible.end());

    std::vector ret;
    for (auto& [k, bsz] : feasible) {
        // std::cout << "KERNEL: " << k->name() << ", BSZ: " << bsz << std::endl;
        ret.push_back(k);
    }

    return ret;
}

// Enumerates split-k candidates (1 .. min(param.max_splits, kernel max)) for
// `kernel` on the current problem. Each candidate gets a LaunchSpec with a rough
// cost estimate {memory-IO bytes, MMA work} that accounts for tile and wave
// quantization. Enumeration stops early once a multi-split candidate exceeds
// param.max_waves. Kernels with a non-zero desc().backend get a single spec.
// NOTE(review): if m or n is 0 the tiled shape is 0, so splits_per_wave is
// infinite; presumably empty problems never reach here — TODO confirm.
std::vector Context::Populate(const Kernel& kernel, const PopulateParam& param) const
{
    // early exit for cuBLAS backend
    if (kernel.desc().backend) {
        return {LaunchSpec{const_cast(&kernel), 0, 1}};
    }

    const auto& gemm = get_desc(kernel);

    const int m = gemm.m, n = gemm.n, k = gemm.k, num = std::max(1, gemm.num);

    const auto& desc = kernel.desc();
    const auto& info = kernel.info();

    // For grouped problems the group count is folded into the tile size along group_axis
    const int64_t tiled_shape_m = cdiv(m, desc.cta_tile.x * (desc.group_axis == 0 ? num : 1));
    const int64_t tiled_shape_n = cdiv(n, desc.cta_tile.y * (desc.group_axis == 1 ? num : 1));
    const int     chunk_cnt_k   = cdiv(k, kernel.chunk_size_k());

    // Despite we only have sm_count * constant tensor cores, this is the granularity for scheduling
    const int   concurrency     = sm_count_ * kernel.info().max_active_ctas;
    const float waves_per_split = float(tiled_shape_m * tiled_shape_n) / concurrency;
    const float splits_per_wave = 1.f / waves_per_split;

    // Tile quantization
    const int64_t ceil_m = tiled_shape_m * desc.cta_tile.x;
    const int64_t ceil_n = tiled_shape_n * desc.cta_tile.y;

    // int max_splits = kernel.GetMaxSplits(m, n, k, param.barriers_size, param.partials_size);
    int max_splits = kernel.GetMaxSplits({m, n, k, num}, 0, param.barriers_size, param.partials_size);

    // std::cout << "max_splits: " << max_splits << std::endl;

    max_splits = std::min(param.max_splits, max_splits);

    std::vector specs;

    /// TODO: revise this according to the lastest scheduler
    for (int splits = 1; splits <= max_splits; ++splits) {
        // Split quantization, penalize uneven splits
        const int64_t split_ceil_k = cdiv(chunk_cnt_k, splits) * kernel.chunk_size_k();
        // Footprint for single split
        const int64_t split_mma_cost = ceil_m * ceil_n * split_ceil_k;
        // Footprint for single wave
        const int64_t wave_mma_cost = split_mma_cost * splits_per_wave;

        // Wave quantization
        // const int waves = (int)std::ceil(wave_per_split * splits);

        // Bold simulation of thread block scheduling
        // (grid_size is computed in 64-bit and narrowed to int)
        const int   grid_size    = tiled_shape_m * tiled_shape_n * splits * num;
        const int   full_waves   = grid_size / concurrency;
        const int   residue      = grid_size % concurrency;
        const float partial_wave = (float)cdiv(residue, sm_count_) / info.max_active_ctas;
        const float waves        = full_waves + partial_wave;

        if (splits > 1 && waves > param.max_waves) {
            break;
        }
        // ceil(tiled_mn / C * splits) * C / tiled_mn * ceil_m * ceil_n * split_ceil_k
        const int64_t mma_cost = wave_mma_cost * waves;

        // IO has less severe quantization effect
        const int64_t mio_cost_a = byte_size(desc.type_a, tiled_shape_n * m * split_ceil_k) * splits * num;
        const int64_t mio_cost_b = byte_size(desc.type_b, tiled_shape_m * n * split_ceil_k) * splits * num;
        /// TODO: read type from `desc_.accum` when added
        // Partial results of extra splits are written and read back once each
        const int64_t mio_cost_c = byte_size(desc.type_c, (int64_t)m * n) * (splits - 1) * 2 * num;
        const int64_t mio_cost   = mio_cost_a + mio_cost_b + mio_cost_c;

        // std::cout << kernel.name() << " " << splits << " " << waves << " " << (float)mio_cost << " " <<
        // (float)mma_cost
        //           << "\n";

        // metrics.emplace_back(splits, KernelMetric{mio_cost, mma_cost});

        LaunchSpec spec{};
        spec.kernel    = const_cast(&kernel);
        spec.splits    = splits;
        spec.swizzle   = param.swizzle;
        spec.estimated = {mio_cost, mma_cost};
        specs.push_back(spec);
    }

    return specs;
}

// Expands `spec` into one LaunchSpec per admissible swizzle value for the
// problem shape seen by spec.kernel (transposed or not).
std::vector Context::Swizzle(const LaunchSpec& spec, const std::vector& swizzle) const
{
    auto& desc = get_desc(*spec.kernel);
    return get_swizzle({desc.m, desc.n, desc.k, desc.num}, spec, swizzle);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/context.h
================================================
#pragma once

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/types.h"
#include 

namespace turbomind::gemm {

// Tuning knobs for Context::Populate.
struct PopulateParam {
    int    max_splits;     // upper bound on split-k candidates
    int    max_waves;      // multi-split candidates needing more waves than this stop enumeration
    int    swizzle;        // copied into every produced LaunchSpec
    size_t barriers_size;  // forwarded to Kernel::GetMaxSplits
    size_t partials_size;  // forwarded to Kernel::GetMaxSplits
};

// Per-device GEMM dispatch context: holds the device arch and SM count plus the
// descriptor of the current problem (and its transpose), and provides kernel
// filtering, split-k enumeration and swizzle expansion for that problem.
class Context {
public:
    explicit Context(const cudaDeviceProp& prop);

    // Builds the problem descriptor from operand layouts; false if shapes mismatch.
    bool Init(const Operation&    operation,
              const MatrixLayout& Adesc,
              const MatrixLayout& Udesc,
              const MatrixLayout& Bdesc,
              const MatrixLayout& Vdesc,
              const MatrixLayout& Cdesc,
              const MatrixLayout& Ddesc);

    std::vector Filter(const std::vector& kernels) const;

    std::vector Populate(const Kernel& kernel, const PopulateParam& param) const;

    std::vector Swizzle(const LaunchSpec& spec, const std::vector& swizzle) const;

    const GemmDesc& desc() const
    {
        return desc_;
    }

    // Problem descriptor as seen by `kernel` (transposed if the kernel is)
    const GemmDesc& get_desc(const Kernel& kernel) const
    {
        return kernel.desc().transpose ? desc_trans_ : desc_;
    }

    // Alignment
    // (align_m, align_n, align_k) -> is_aligned
    //  gcd_mnk need to be part of gemm desc

    // Max splits
    // (max_mn_tiles, max_k_tiles) -> max_splits

    // CTA Swizzling
    // - GemmScheduler: return get_log_tile
    // - DynamicScheduler: bypass

    // Cost estimation
    //

protected:
    int arch_{};
    int sm_count_{};

    GemmDesc desc_{};
    GemmDesc desc_trans_{};
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/convert.cuh
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/math.h"

#include "src/turbomind/kernels/attention/quantization.h"

#include "src/turbomind/kernels/gemm/cp_async.h"
#include "src/turbomind/kernels/gemm/format.h"
#include "src/turbomind/kernels/gemm/iterator_sm70.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

// Debug helper: prints the compiler's pretty function signature (which names
// the deduced T) from thread 0 of every block.
template
__device__ void print_type(T)
{
    if (threadIdx.x == 0) {
        printf("%s\n", __PRETTY_FUNCTION__);
    }
}

namespace turbomind::gemm {

// CTA-level body of the operand-conversion kernel. It re-lays a source operand
// matrix into the fragment order produced by the GEMM's SmemCopy atom:
// each K-tile is staged into shared memory through GmemIter, read back with
// SmemCopy as a GEMM mainloop would, converted/packed by Converter, and stored
// to global memory at the atom's unique (pack, lane) index.
// One CTA handles one M-tile (blockIdx.x) and walks all K-tiles; only a single
// warp per CTA is supported for now (WARP_CNT = 1).
template
struct ConvertOperand {

    static constexpr int M = M_;
    static constexpr int K = K_;

    using Operand = MakeOperand, M_, K_, 1>;

    using Ts         = typename Operand::Dtype;
    using SmemLayout = typename Operand::SmemLayout;
    using GmemIter   = typename Operand::GmemIter;

    using Atom = typename Operand::SmemCopyAtom;

    using SmemCopy = SmemCopy;

    using Accessor = SmemAccessor;

    static constexpr auto kOrderS = Operand::kOrder;

    // Atom steps needed to cover one CTA K-tile
    static constexpr int ITER_K = ceil_div(K, Atom::K);

    /// TODO: generalize this
    static constexpr int WARP_CNT = 1;

    using PtrD = get_pointer_type;

    struct Param {
        int         m;
        int         k;
        MatrixParam src;
        MatrixParam dst;
    };

    using SharedStorage = Array;

    // Compile-time helpers: elements per fragment (N) and number of fragments (M)
    // of a C-array of Arrays.
    template
    static constexpr int get_fragment_size(Array (&)[M])
    {
        return N;
    }

    template
    static constexpr int get_fragment_num(Array (&)[M])
    {
        return M;
    }

    // (m, k) -> (c, s) in this operand's storage order
    __device__ constexpr int2 _mk2cs(int m, int k)
    {
        return mk2cs(m, k);
    }

    __device__ void operator()(const Param& param, char* smem_buf)
    {
        Ts* smem = (Ts*)smem_buf;

        const int cta_cnt_m = ceil_div(param.m, M);
        const int cta_cnt_k = ceil_div(param.k, K);

        const int cta_idx_m = blockIdx.x;

        const int cta_offset_m = cta_idx_m * M;
        const int residue_m    = min(M, param.m - cta_offset_m);

        const int warp_id = threadIdx.x / WARP_SIZE;

        const int warp_offset_m = 0;

        Converter converter{};

        typename SmemCopy::Frag data;

        constexpr int kFragSize = get_fragment_size(data);
        constexpr int kFragNum  = get_fragment_num(data);
        constexpr int kPackSize = kFragSize * Pack_M;

        // Number of packs along k and m over the whole matrix
        const int pack_cnt_k = ceil_div(param.k, Atom::K);
        const int pack_cnt_m = ceil_div(param.m, Atom::M * Pack_M);

        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
            // printf("m=%d, k=%d, lds = %d\n", param.m, param.k, param.lds);
            // printf(
            //     "CTA_M=%d, CTA_K=%d, cta_cnt_m=%d, cta_cnt_k=%d, cta_idx_m=%d, ITER_K=%d, pack_cnt_m=%d,
            //     pack_cnt_k=%d\n", M_, K_, cta_cnt_m, cta_cnt_k, cta_idx_m, ITER_K, pack_cnt_m, pack_cnt_k);
            // printf("frag_size=%d, frag_num=%d, pack_size=%d\n", kFragSize, kFragNum, kPackSize);
        }

        // The last (possibly partial) K-tile is processed first
        const int cta_offset_k = (cta_cnt_k - 1) * K;
        const int residue_k    = min(K, param.k - cta_offset_k);

        const auto mat_S = resolve(param.src, 0);
        const auto mat_D = resolve(param.dst, 0);

        // Handle residue k first
        GmemIter gmem{mat_S, {cta_offset_m, cta_offset_k}, {residue_m, residue_k}};

        // Clear smem first, presumably so elements outside the residue read as zero
        gmem.smem_data_ = smem;
        gmem.ClearSmem();

        __syncthreads();

        // gmem.Prefetch(true);

        typename GmemIter::Fragments fragments{};
        gmem.Fetch(fragments, true);
        gmem.Store(fragments);

        // Rest full k tiles
        gmem            = GmemIter{mat_S, {cta_offset_m, 0}, {residue_m, K}};
        gmem.smem_data_ = smem;

        SmemCopy smem_copy({warp_offset_m, 0});

        // last, 0, 1, 2, 3, ..., last - 1
        int cta_idx_k = cta_cnt_k - 1;

        get_pointer_type mat_D_ptr{(Td*)mat_D.ptr.ptr};

        // Each stage converts the tile currently in smem, then loads the next full tile
        for (int k_stage = 0; k_stage < cta_cnt_k; ++k_stage) {
            __syncthreads();

            PRAGMA_UNROLL
            for (int k = 0; k < ITER_K; ++k) {
                // Assuming `SmemCopy` is a warp-level operation
                // Load from smem as we are doing GEMMs
                // SmemCopy::copy(smem, data, int2{warp_offset_m, 0}, k);
                smem_copy(smem, data, k);

                PRAGMA_UNROLL
                for (int m = 0; m < kFragNum; m += Pack_M) {
                    // Convert and pack rmem data
                    Array packed = converter((Array&)data[m]);

                    // Logical pack coords
                    const int pack_idx_k = cta_idx_k * ITER_K + k;
                    const int pack_idx_m = ((cta_idx_m * WARP_CNT + warp_id) * kFragNum + m) / Pack_M;

                    // Linear pack index
                    const int pack_index = cs2idx(_mk2cs(pack_idx_m, pack_idx_k),  //
                                                  _mk2cs(pack_cnt_m, pack_cnt_k).x);

                    auto [unique_id, repeat_id] = Atom::unique(threadIdx.x, pack_index);

                    // Store in [pack_id, lane_id], static cast is needed to decay SubBytePtr to T*
                    auto dst_ptr = static_cast(mat_D_ptr + unique_id * kPackSize);

                    // Skip out-of-range packs and lanes that duplicate another lane's data
                    if (pack_idx_m < pack_cnt_m && pack_idx_k < pack_cnt_k && repeat_id == 0) {
                        Store(dst_ptr, packed);
                    }
                }
            }

            __syncthreads();

            if (k_stage == cta_cnt_k - 1) {
                break;
            }

            // gmem.Prefetch(true);
            gmem.Fetch(fragments, true);
            gmem.Store(fragments);
            gmem.Advance();

            cta_idx_k = k_stage;
        }
    }

    // Debug helpers: no-op fallback, and a printer for the first four values of a fragment
    __device__ void print(...) {}

    __device__ void print(Array _x)
    {
        auto& x = (const Array&)_x;
        printf("tidx=%d, %f %f %f %f\n", (int)threadIdx.x, (float)x[0], (float)x[1], (float)x[2], (float)x[3]);
    }
};

// Dynamic shared memory buffer, sized at launch time.
extern __shared__ char smem_buf[];

// Thin __global__ entry point: runs one Kernel instance per CTA on smem_buf.
template
__global__ void convert_kernel(typename Kernel::Param param)
{
    Kernel kernel;
    kernel(param, smem_buf);
}

// True for the matrix operand tags A and B. Every other tag counts as a
// U/V operand (see is_UV).
constexpr bool is_AB(Op_Tag op)
{
    return op == OPERAND_A || op == OPERAND_B;
}

// True for any operand tag that is neither A nor B (i.e. the U/V operands).
constexpr bool is_UV(Op_Tag op)
{
    return !(op == OPERAND_A || op == OPERAND_B);
}

// Per-type unit size: 1 for the generic case; the two specific overloads
// return 4 and 8. Presumably the packing unit for particular element types —
// TODO confirm against callers.
template
constexpr int unit_size(basic_type)
{
    return 1;
}

constexpr int unit_size(basic_type)
{
    return 4;
}

constexpr int unit_size(basic_type)
{
    return 8;
}

// MMA     : H_16816, H_1688, H_884, H_SIMT
// Operand : A, B, U, V
// Order   : row, col
// Dtype   : u16, u8, u4 (u6, u3)
// PackNum : 1, 2, 4

// Launch configuration for operand conversion: 64 x 32 (M x K) CTA tiles
// processed by a single warp (32 threads), with the matching ConvertOperand kernel.
template
struct Config {
    static constexpr int CTA_M = 64;
    static constexpr int CTA_K = 32;

    static constexpr int BLOCK_SIZE = 32;

    using Stype = typename Operand::Dtype;
    using Dtype = Dtype_;

    using Kernel = ConvertOperand>;
};

// Launches the conversion kernel for source S (layout Sdesc) into D (layout Ddesc)
// on `stream`: one CTA of BLOCK_SIZE threads per CTA_M rows of S. Opts in to more
// than 48 KiB of dynamic shared memory when the kernel's SharedStorage needs it.
// NOTE(review): the cudaFuncSetAttribute result is not checked.
template
void Convert_v2_Impl(const void* S, const MatrixLayout& Sdesc, void* D, const MatrixLayout& Ddesc, cudaStream_t stream)
{
    using Kernel = typename Config::Kernel;
    using Stype  = typename Config::Stype;
    using Dtype  = typename Config::Dtype;

    constexpr int CTA_M = Config::CTA_M;

    static constexpr int kSmemSize = sizeof(typename Kernel::SharedStorage);

    if (kSmemSize > (48 << 10)) {
        cudaFuncSetAttribute(convert_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
    }

    typename Kernel::Param param{Sdesc.rows, Sdesc.cols, to_param((void*)S, Sdesc), to_param((void*)D, Ddesc)};

    constexpr int threads = Config::BLOCK_SIZE;
    const int     blocks  = ceil_div(Sdesc.rows, CTA_M);

    // std::cout << __PRETTY_FUNCTION__ << std::endl;
    // std::cout << __PRETTY_FUNCTION__ << "\nThreadMap:\n";
    // Print(typename Kernel::GmemIter::ThreadMap{});

    convert_kernel<<>>(param);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/convert.h
================================================

#include 
#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

struct LayoutConverter {

    Order order;
    Pack  pack;

    virtual int Convert(const void*         S,  //
                        const MatrixLayout& Sdesc,
                        void*               D,
                        MatrixLayout&       Ddesc,
                        cudaStream_t        stream) const = 0;
};

// Pointers to singletons
std::array GetConverters(DataType data_type,
                                                    DataType weight_type,  //
                                                    DataType input_type,
                                                    bool     grouped,
                                                    int      sm);

// Build device-side pointer tables from host (pointer, stride) pairs, allocated stream-ordered on
// `stream`. The caller owns the returned buffer: free it with `cudaFree`.
// MakeStridedPtrs keeps both pointer and stride; MakeBlockedPtrs keeps only the pointers.
void* MakeStridedPtrs(const std::vector>& ptrs, cudaStream_t stream);
void* MakeBlockedPtrs(const std::vector>& ptrs, cudaStream_t stream);

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/convert_v3.cu
================================================

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/convert.cuh"
#include "src/turbomind/kernels/gemm/convert.h"
#include "src/turbomind/kernels/gemm/types.h"

#include "src/turbomind/kernels/gemm/arch/operand_simt.h"
#include "src/turbomind/kernels/gemm/arch/operand_sm70_s884.h"
#include "src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h"

namespace turbomind::gemm {

// Concrete LayoutConverter for one (operand, MMA tag, pack count, order) combination. Converts
// through Convert_v2_Impl and reports the resulting packed leading dimension in `Ddesc.ld`.
template
struct LayoutConverterImpl: public LayoutConverter {

    LayoutConverterImpl(): LayoutConverter{}
    {
        this->order = order_;
        this->pack  = mma_tag | op_tag | pack_num;
    }

    int Convert(const void*         S,
                const MatrixLayout& Sdesc_,  // (m,k) / (n,k)
                void*               D,
                MatrixLayout&       Ddesc,  // (m,k) / (n,k)
                cudaStream_t        stream) const override
    {
        // TM_CHECK_EQ(Sdesc.pack, 0U) << "Source must be non-packed format";

        // B / V operands arrive as (k, n) and are viewed transposed so all operands are (mn, k).
        const bool trans = op_tag == OPERAND_B || op_tag == OPERAND_V;
        // (k, n) -> (n, k)
        MatrixLayout Sdesc = trans ? transpose(Sdesc_) : Sdesc_;
        // MatrixLayout Ddesc = trans ? transpose(Ddesc_) : Ddesc_;

        TM_CHECK_NOTNULL(S);
        TM_CHECK_NOTNULL(D);

        using Operand = typename GetOperand::Operand;

        // NOTE(review): the kernel is launched with the caller's Ddesc (including its ld); ld is
        // only rewritten below, after the launch. Confirm the packed kernel does not rely on it.
        Convert_v2_Impl>(S, Sdesc, D, Ddesc, stream);

        constexpr Pack pack = mma_tag | op_tag | pack_num;

        // Update leading dimension
        Ddesc.ld = mk2cs(Packing_v2::apply({Sdesc.rows, Sdesc.cols})).x;

        return 0;
    }
};

// Returns the process-wide singleton converter for `pack`, decoding its MMA tag, operand tag and
// pack count at compile time. The instance is a function-local static, so the pointer stays valid.
template
static LayoutConverter* GetImpl()
{
    constexpr auto mma      = get_mma_tag(pack);
    constexpr auto operand  = get_operand_tag(pack);
    constexpr auto pack_num = get_pack_num(pack);

    static LayoutConverterImpl impl{};

    return &impl;
}

// Small function object used by GetConverters: given an architecture tag, a destination order and
// a pack descriptor (all as compile-time constants), returns the matching singleton converter.
template
struct Cvt {
    template
    LayoutConverter* operator()(Arch, constant, constant) const
    {
        return GetImpl();
    }
};

// Compile-time MMA tags, combined with operand tags and pack counts through operator| below.
constexpr constant<(Pack)HMMA_16816> s16816h{};
constexpr constant<(Pack)HMMA_884>   s884h{};

// Bitwise-or of two compile-time constants, yielding another compile-time constant.
template
constexpr auto operator|(constant, constant)
{
    return constant{};
}

// Selects the weight converter (first entry) and the scale/zero converter (second entry) for a
// weight type, grouping mode and SM version. Dense f16/bf16 non-grouped weights need no conversion
// and get two null entries. Unsupported combinations fail TM_CHECK.
std::array GetConverters(DataType data_type,
                                                    DataType weight_type,  //
                                                    DataType input_type,
                                                    bool     grouped,
                                                    int      sm)
{
    constexpr constant kRow{};
    constexpr constant kCol{};

    constexpr constant A{};
    constexpr constant B{};
    constexpr constant U{};
    constexpr constant V{};

    constexpr constant<1> _1{};
    constexpr constant<2> _2{};

    constexpr Arch<80> sm8_{};
    constexpr Sm75     sm75{};
    constexpr Sm70     sm70{};

    if (weight_type == kHalf || weight_type == kBfloat16) {
        constexpr Cvt W;
        if (grouped) {
            // clang-format off
            if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _1), {}};
            if (sm == 75) return {W(sm75, kRow, s16816h | B | _1), {}};
            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), {}};
            // clang-format on
        }
        else {
            return {};  //  trivial case: dense floating point
        }
    }

    // For performance reasons, u4 use different layouts for grouped/non-grouped GEMM
    if (weight_type == kUint4) {
        constexpr Cvt  W;  // u4       weight
        constexpr Cvt S;  // f16/bf16 scales&zeros
        if (grouped) {
            // clang-format off
            if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _2), S(sm8_, kCol, s16816h | V | _1)};
            if (sm == 75) return {W(sm75, kRow, s16816h | B | _2), S(sm75, kCol, s16816h | V | _1)};
            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};
            // clang-format on
        }
        else {
            // clang-format off
            if (sm >= 80) return {W(sm8_, kCol, s16816h | B | _2), S(sm8_, kCol, s16816h | V | _1)};
            if (sm == 75) return {W(sm75, kCol, s16816h | B | _2), S(sm75, kCol, s16816h | V | _1)};
            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};
            // clang-format on
        }
    }

    if (weight_type == kFloat4_e2m1) {
        constexpr Cvt W;  // e2m1  weight
        constexpr Cvt  S;  // ue8m0 scales
        // clang-format off
        if (sm >= 80) return {W(sm8_, kCol, s16816h | A | _1), S(sm8_, kCol, s16816h | U | _1)};
        if (sm == 75) return {W(sm75, kCol, s16816h | A | _1), S(sm75, kCol, s16816h | U | _1)};
        if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};
        // clang-format on
    }

    if (weight_type == kFloat8_e4m3) {
        constexpr Cvt  W;  // e4m3     weight
        constexpr Cvt S;  // f16/bf16 scales
        // clang-format off
        if (sm >= 80) return {W(sm8_, kCol, s16816h | A | _1), S(sm8_, kCol, s16816h | U | _1)};
        if (sm == 75) return {W(sm75, kCol, s16816h | A | _1), S(sm75, kCol, s16816h | U | _1)};
        if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};
        // clang-format on
    }

    // Reached for unknown weight types and for SM versions below 70.
    TM_CHECK(0) << "Invalid combination: " << sm << " " << data_type << " " << weight_type << " " << input_type << " "
                << grouped;

    return {};
}

namespace {

// Kernel argument bundle: up to N (pointer, stride) entries passed by value, plus the destination
// array slice `ptr` and the number of valid entries `n`.
template
struct Param {
    StridedPtr  data[N];
    StridedPtr* ptr;
    int         n;
};

// Copies the first `param.n` entries of `param.data` into `param.ptr`, one entry per thread.
template
__global__ void fill_strided_ptrs(Param param)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < param.n) {
        param.ptr[idx] = param.data[idx];
    }
}

}  // namespace

// Builds a device array of `StridedPtr`, one per (pointer, stride) pair in `ptrs`. The array is
// allocated stream-ordered with cudaMallocAsync and filled by single-block launches of
// `fill_strided_ptrs`, each carrying up to N = 64 entries by value in its kernel parameters.
// The caller is responsible for freeing the returned buffer.
// NOTE(review): errors from cudaMallocAsync and the launches are not checked.
void* MakeStridedPtrs(const std::vector>& ptrs, cudaStream_t stream)
{
    constexpr int N = 64;
    Param      param{};
    static_assert(sizeof(param) <= 4096);  // max parameter size for cuda11
    StridedPtr* ptr{};
    cudaMallocAsync(&ptr, sizeof(StridedPtr) * ptrs.size(), stream);
    param.ptr = ptr;
    for (int i = 0; i < (int)ptrs.size(); i += N) {
        // Entries in this batch: N except possibly for the final batch.
        const int n = std::min(ptrs.size() - i, N);
        for (int j = 0; j < n; ++j) {
            auto& [p, s]  = ptrs[i + j];
            param.data[j] = StridedPtr{p, s};
        }
        param.n = n;
        fill_strided_ptrs<<<1, N, 0, stream>>>(param);
        // Advance by a full batch; only the last batch can be short, so this never misaligns.
        param.ptr += N;
    }
    return ptr;
}

namespace {

// Copies the first `n` pointers of the by-value array `src` into device array `dst`,
// one pointer per thread.
template
__global__ void fill_blocked_ptrs(Array src, void** dst, int n)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < n) {
        dst[idx] = src[idx];
    }
}

}  // namespace

// Builds a device array holding only the pointer half of each (pointer, stride) pair in `ptrs`.
// Allocation is stream-ordered (cudaMallocAsync on `stream`). Filling happens in batches of up to
// N = 64 pointers, passed by value to single-block `fill_blocked_ptrs` launches. Returns the start
// of the array; the caller is responsible for freeing it.
// NOTE(review): errors from cudaMallocAsync and the launches are not checked.
void* MakeBlockedPtrs(const std::vector>& ptrs, cudaStream_t stream)
{
    constexpr int   N = 64;
    Array src{};
    static_assert(sizeof(src) <= 4096);  // max parameter size for cuda11
    void** dst{};
    cudaMallocAsync(&dst, sizeof(void*) * ptrs.size(), stream);
    for (int i = 0; i < (int)ptrs.size(); i += N) {
        const int n = std::min(ptrs.size() - i, N);
        for (int j = 0; j < n; ++j) {
            auto& [p, s] = ptrs[i + j];
            src[j]       = p;
        }
        fill_blocked_ptrs<<<1, N, 0, stream>>>(src, dst, n);
        dst += n;
    }
    // `dst` was advanced past every written entry; rewind to the start of the allocation.
    return dst - ptrs.size();
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/cp_async.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

// Optional `.L2::<size>B` prefetch-size qualifier for cp.async, enabled for nvcc 11.4 and newer.
// The version is compared as a (major, minor) pair: the previous `MAJOR >= 11 && MINOR >= 4` test
// wrongly disabled the qualifier on CUDA 12.0-12.3, whose minor version is below 4.
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif

namespace turbomind {

// Cache operator used for cp.async copies. Enumerator order defines the values; do not reorder.
// NOTE(review): kAlways / kGlobal presumably map to the `.ca` / `.cg` forms of CP_ASYNC.
enum class CacheOp
{
    kDefault,  // use global when possible
    kAlways,
    kGlobal,
};

// Resolves a requested CacheOp to the one actually used. Explicit choices pass through unchanged.
// The specializations map kDefault to kGlobal or kAlways depending on template arguments that are
// not visible in this copy (presumably the access size, since `.cg` only supports 16-byte copies).
template
struct GetCacheOp {
    static constexpr auto value = cache_op;
};

template<>
struct GetCacheOp {
    static constexpr auto value = CacheOp::kGlobal;
};

template
struct GetCacheOp {
    static constexpr auto value = CacheOp::kAlways;
};

// L2 eviction priority requested for a memory stream. Enumerator order defines the values.
enum class EvictPolicy
{
    kEvictNormal,
    kEvictFirst,
    kEvictLast,
};

namespace cache_policy {

// Named (cache operator, L2 eviction policy) bundles for operand loads.

// Default caching: kDefault cache operator, normal eviction.
struct Default {
    static constexpr auto kCacheOp     = CacheOp::kDefault;
    static constexpr auto kEvictPolicy = EvictPolicy::kEvictNormal;
};

// Data read once: kDefault cache operator, but marked evict-first in L2.
struct Stream {
    static constexpr auto kCacheOp     = CacheOp::kDefault;
    static constexpr auto kEvictPolicy = EvictPolicy::kEvictFirst;
};

// Data expected to be re-read: kAlways cache operator, normal eviction.
struct Reuse {
    static constexpr auto kCacheOp     = CacheOp::kAlways;
    static constexpr auto kEvictPolicy = EvictPolicy::kEvictNormal;
};

}  // namespace cache_policy

// Primary template is deliberately empty: only the specializations below provide `apply`, so
// calling it with an unsupported combination of template arguments fails to compile.
template
struct CP_ASYNC {
};

// 16-byte `cp.async.cg` (cached in L2, bypassing L1) with no L2 prefetch-size qualifier.
// The copy is issued only when `mask` is non-zero; otherwise shared memory is left untouched.
// `smem_ptr` is a shared-memory address, `src` a global pointer. The second overload also passes
// a 64-bit `cache_policy` through `.L2::cache_hint`.
template
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global [%1], [%2], 16;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global.L2::cache_hint [%1], [%2], 16, %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "l"(cache_policy));
    }
    // clang-format on
};

// 16-byte `cp.async.cg` copy with a 64-byte L2 prefetch hint (`L2_CACHEHINT(64)`, which expands
// to nothing on older nvcc). Issued only when `mask` is non-zero. The second overload also passes
// `cache_policy` through `.L2::cache_hint`.
template<>
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global" L2_CACHEHINT(64) " [%1], [%2], 16;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global.L2::cache_hint" L2_CACHEHINT(64) " [%1], [%2], 16, %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "l"(cache_policy));
    }
    // clang-format on
};

// 16-byte `cp.async.cg` copy with a 128-byte L2 prefetch hint (`L2_CACHEHINT(128)`, which expands
// to nothing on older nvcc). Issued only when `mask` is non-zero. The second overload also passes
// `cache_policy` through `.L2::cache_hint`.
template<>
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], 16;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global.L2::cache_hint" L2_CACHEHINT(128) " [%1], [%2], 16, %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "l"(cache_policy));
    }
    // clang-format on
};

// 16-byte `cp.async.cg` copy with a 256-byte L2 prefetch hint (`L2_CACHEHINT(256)`, which expands
// to nothing on older nvcc). Issued only when `mask` is non-zero. The second overload also passes
// `cache_policy` through `.L2::cache_hint`.
template<>
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global" L2_CACHEHINT(256) " [%1], [%2], 16;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.cg.shared.global.L2::cache_hint" L2_CACHEHINT(256) " [%1], [%2], 16, %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "l"(cache_policy));
    }
    // clang-format on
};

// `cp.async.ca` (cached at all levels) copy of `size` bytes, where `size` is a compile-time
// immediate ("n" constraint). No L2 prefetch-size qualifier. Issued only when `mask` is non-zero.
// The second overload also passes `cache_policy` through `.L2::cache_hint`.
template
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global [%1], [%2], %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global.L2::cache_hint [%1], [%2], %3, %4;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size), "l"(cache_policy));
    }
    // clang-format on
};

// `cp.async.ca` copy of `size` bytes (compile-time immediate) with a 64-byte L2 prefetch hint
// (`L2_CACHEHINT(64)`, which expands to nothing on older nvcc). Issued only when `mask` is
// non-zero. The second overload also passes `cache_policy` through `.L2::cache_hint`.
template
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(64) " [%1], [%2], %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global.L2::cache_hint" L2_CACHEHINT(64) " [%1], [%2], %3, %4;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size), "l"(cache_policy));
    }
    // clang-format on
};

// `cp.async.ca` copy of `size` bytes (compile-time immediate) with a 128-byte L2 prefetch hint
// (`L2_CACHEHINT(128)`, which expands to nothing on older nvcc). Issued only when `mask` is
// non-zero. The second overload also passes `cache_policy` through `.L2::cache_hint`.
template
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global.L2::cache_hint" L2_CACHEHINT(128) " [%1], [%2], %3, %4;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size), "l"(cache_policy));
    }
    // clang-format on
};

// `cp.async.ca` copy of `size` bytes (compile-time immediate) with a 256-byte L2 prefetch hint
// (`L2_CACHEHINT(256)`, which expands to nothing on older nvcc). Issued only when `mask` is
// non-zero. The second overload also passes `cache_policy` through `.L2::cache_hint`.
template
struct CP_ASYNC {
    // clang-format off
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(256) " [%1], [%2], %3;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size));
    }
    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)
    {
        asm volatile("{\n  .reg .pred p;\n  setp.ne.b32 p, %0, 0;\n"
                     "  @p cp.async.ca.shared.global.L2::cache_hint" L2_CACHEHINT(256) " [%1], [%2], %3, %4;\n"
                     "}\n" ::"r"((int)mask), "r"(smem_ptr), "l"(src), "n"(size), "l"(cache_policy));
    }
    // clang-format on
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/cta_map.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Chooses log2 of the swizzle group width (0..5) for a dimension of `size` tiles, capped by
// `tile_size`. The widest width 2^k that `tile_size` allows is picked only if `size` fills at
// least 3/4 of it (24/32, 12/16, 6/8, 3/4); width 2 instead needs size >= 2.
TM_HOST_DEVICE constexpr int get_log_tile(int size, int tile_size)
{
    for (int log_tile = 5; log_tile > 0; --log_tile) {
        const int width    = 1 << log_tile;
        const int min_size = log_tile == 1 ? 2 : 3 << (log_tile - 2);
        if (tile_size >= width && size >= min_size) {
            return log_tile;
        }
    }
    return 0;
}

// Number of (cta_m x cta_n) tiles needed to cover an m x n output, rounding partial tiles up.
TM_HOST_DEVICE constexpr int2 get_tiled_shape(int m, int n, int cta_m, int cta_n)
{
    const int tiles_m = ceil_div(m, cta_m);
    const int tiles_n = ceil_div(n, cta_n);
    return {tiles_m, tiles_n};
}

// Static CTA-to-tile mapping with swizzling along n. Groups of 2^log_tile consecutive blockIdx.x
// values cover adjacent n-tiles of the same m-tile.
struct CtaMap_ {

    // (ceil(m / cta_m), ceil(n / cta_n), split_cnt). `k` is unused.
    TM_HOST_DEVICE static int3 get_tiled_shape(int m, int n, int k, int cta_m, int cta_n, int split_cnt)
    {
        return {(m + cta_m - 1) / cta_m, (n + cta_n - 1) / cta_n, split_cnt};
    }

    // Swizzle group width (log2) derived from the number of n-tiles.
    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int N)
    {
        return gemm::get_log_tile(tiled_mn.y, N);
    }

    // Grid: x covers m-tiles times the group width, y covers n-tiles in groups, z covers splits.
    TM_HOST_DEVICE static dim3 get_grid_shape(int3 tiled_shape, int log_tile)
    {
        int tile = 1 << log_tile;
        return {static_cast(tiled_shape.x * tile),
                static_cast((tiled_shape.y + tile - 1) / tile),
                static_cast(tiled_shape.z)};
    }

    // Inverse of get_grid_shape for the current block: returns (m-tile, n-tile, split index).
    TM_DEVICE static int3 get_tile_offset(int log_tile)
    {
        int block_idx_x = blockIdx.x;
        int block_idx_y = blockIdx.y;
        int block_idx_z = blockIdx.z;
        return {(block_idx_x >> log_tile),  //
                (block_idx_y << log_tile) + (block_idx_x & ((1 << log_tile) - 1)),
                block_idx_z};
    }
};

// Static GEMM scheduler. Maps a CTA's (blockIdx.x, y, z) to a tile offset (m-tile, n-tile, split)
// and a half-open range of k iterations. Tiles are swizzled in groups of 2^log_tile along n
// (kColMajor) or along m (otherwise).
template
class GemmScheduler {

    static constexpr auto order = order_;

    int4 gemm_shape_;   // problem shape; .z is treated as the K extent
    int4 tiled_shape_;  // (m-tiles, n-tiles, splits, unused)
    int  log_tile_;

    int chunk_offset_;      // first split index that receives one extra chunk
    int chunks_per_split_;  // chunks every split receives at minimum
    int iter_k_per_chunk_;  // k iterations (of cta_k each) per chunk

    int4 tile_offset_;
    int2 iter_k_range_;

public:
    // K is cut into ceil(K / chunk_size) chunks and distributed over `splits` splits: each split
    // gets chunk_cnt / splits chunks, and the last (chunk_cnt % splits) splits get one more.
    // NOTE(review): assumes splits > 0 and chunk_size is a multiple of cta_k (integer division).
    TM_HOST_DEVICE
    GemmScheduler(int4 gemm_shape, int2 tiled_mn, int splits, int log_tile, int cta_k, int chunk_size):
        gemm_shape_{gemm_shape}, tiled_shape_{tiled_mn.x, tiled_mn.y, splits}, log_tile_{log_tile}
    {
        const int chunk_cnt = cdiv(gemm_shape_.z, chunk_size);

        iter_k_per_chunk_ = chunk_size / cta_k;
        chunks_per_split_ = chunk_cnt / splits;
        chunk_offset_     = splits - chunk_cnt % splits;
    }

    // Swizzle width (log2) chosen from the dimension that is grouped: n for kColMajor, m otherwise.
    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)
    {
        return gemm::get_log_tile(order == kColMajor ? tiled_mn.y : tiled_mn.x, tile_size);
    }

    // Grid x spans the non-grouped dimension times the group width, y spans groups, z spans splits.
    TM_HOST_DEVICE static dim3 get_grid_shape(int4 tiled_shape, int log_tile)
    {
        const int tile = 1 << log_tile;
        if constexpr (order == kColMajor) {
            return {(unsigned)(tiled_shape.x * tile), (unsigned)(cdiv(tiled_shape.y, tile)), (unsigned)(tiled_shape.z)};
        }
        else {
            return {(unsigned)(tiled_shape.y * tile), (unsigned)(cdiv(tiled_shape.x, tile)), (unsigned)(tiled_shape.z)};
        }
    }

    TM_HOST_DEVICE dim3 get_grid_shape() const
    {
        return get_grid_shape(tiled_shape_, log_tile_);
    }

    // Computes tile_offset_ and iter_k_range_ for the given block index. Always succeeds.
    TM_HOST_DEVICE std::true_type init(int block_idx_x, int block_idx_y, int block_idx_z)
    {
        // Undo the swizzle: low log_tile bits of x select the tile within a group along the
        // grouped dimension; the remaining x bits index the other dimension.
        if constexpr (order == kColMajor) {
            tile_offset_ = {(block_idx_x >> log_tile_),
                            (block_idx_y << log_tile_) + (block_idx_x & ((1 << log_tile_) - 1)),
                            (block_idx_z)};
        }
        else {
            tile_offset_ = {(block_idx_y << log_tile_) + (block_idx_x & ((1 << log_tile_) - 1)),
                            (block_idx_x >> log_tile_),
                            (block_idx_z)};
        }
        tile_offset_.w       = 0;
        // First chunk of this split: all preceding splits' base chunks plus the extra chunks given
        // to preceding splits at or beyond chunk_offset_.
        const int chunk_id   = tile_offset_.z * chunks_per_split_ + max(tile_offset_.z - chunk_offset_, 0);
        const int iter_k_beg = chunk_id * iter_k_per_chunk_;
        const int iter_k_cnt = (chunks_per_split_ + int(tile_offset_.z >= chunk_offset_)) * iter_k_per_chunk_;
        iter_k_range_        = {iter_k_beg, iter_k_beg + iter_k_cnt};

        return {};
    }

    TM_DEVICE std::true_type init()
    {
        return init(blockIdx.x, blockIdx.y, blockIdx.z);
    }

    TM_DEVICE int4 gemm_shape() const
    {
        return gemm_shape_;
    }

    TM_DEVICE int4 tiled_shape() const
    {
        return tiled_shape_;
    }

    TM_DEVICE int4 tile_offset() const
    {
        return tile_offset_;
    }

    TM_DEVICE int2 iter_k_range() const
    {
        return iter_k_range_;
    }

    // Linear tile index: m-tile * n-tiles + n-tile.
    TM_DEVICE int tile_id() const
    {
        return tile_offset_.x * tiled_shape_.y + tile_offset_.y;
    }
};

// Tape-driven GEMM scheduler. Every CTA of a 1-D grid (`tape.ctas` blocks) reads its precomputed
// tile offset, k-iteration range and tile id from device arrays referenced by the Tape. Per-group
// shapes are looked up by the group id stored in the tile offset.
template
class DynamicScheduler {

    static constexpr auto order = order_;

    int ctas_;

    const int4* __restrict__ gemm_shapes_;    // [group_num]
    const int4* __restrict__ tiled_shapes_;   // [group_num]
    const int2* __restrict__ offsets_mn_;     // [group_num]
    const int4* __restrict__ tile_offsets_;   // [ctas]
    const int2* __restrict__ iter_k_ranges_;  // [ctas]
    const int* __restrict__ tile_ids_;        // [ctas]

    int4 gemm_shape_;
    int4 tiled_shape_;
    int4 tile_offset_;
    int2 iter_k_range_;
    int2 base_mn_;

public:
    DynamicScheduler(const Tape& tape):
        ctas_{tape.ctas},
        gemm_shapes_{tape.gemm_shapes},
        tiled_shapes_{tape.tiled_shapes},
        offsets_mn_{},  // not supplied by the tape; previously left uninitialized but read in init()
        tile_offsets_{tape.tile_offsets},
        iter_k_ranges_{tape.iter_k_ranges},
        tile_ids_{tape.tile_ids}
    {
    }

    // Swizzle width (log2) from the grouped dimension: n for kColMajor, m otherwise.
    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)
    {
        return gemm::get_log_tile(order == kColMajor ? tiled_mn.y : tiled_mn.x, tile_size);
    }

    // One block per scheduled CTA. Const, matching GemmScheduler::get_grid_shape.
    TM_HOST_DEVICE dim3 get_grid_shape() const
    {
        return {(unsigned)ctas_, 1, 1};
    }

    // Loads this CTA's assignment. Returns false when the tape marks the CTA idle (negative group id).
    TM_DEVICE bool init()
    {
        const int block_idx = blockIdx.x;

        const auto [cta_m_id, cta_n_id, cta_k_id, group_id] = __ldg(tile_offsets_ + block_idx);

        if (group_id < 0) {
            return false;
        }

        gemm_shape_  = __ldg(gemm_shapes_ + group_id);
        tiled_shape_ = __ldg(tiled_shapes_ + group_id);
        // Only dereference the per-group base offsets when a table was provided; otherwise this
        // would read through an uninitialized pointer.
        base_mn_ = offsets_mn_ ? __ldg(offsets_mn_ + group_id) : int2{};

        tile_offset_ = {cta_m_id, cta_n_id, cta_k_id, group_id};

        iter_k_range_ = __ldg(iter_k_ranges_ + block_idx);

        return true;
    }

    TM_DEVICE int4 gemm_shape() const
    {
        return gemm_shape_;
    }

    TM_DEVICE int4 tiled_shape() const
    {
        return tiled_shape_;
    }

    TM_DEVICE int4 tile_offset() const
    {
        return tile_offset_;
    }

    TM_DEVICE int2 iter_k_range() const
    {
        return iter_k_range_;
    }

    TM_DEVICE int tile_id() const
    {
        return tile_ids_[blockIdx.x];
    }
};

// Type trait: true for DynamicScheduler specializations, false for any other scheduler type.
template
struct is_dynamic_scheduler: std::false_type {
};

template
struct is_dynamic_scheduler>: std::true_type {
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/cublas.cu
================================================
#include 

#include "src/turbomind/core/cuda_data_type.h"
#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// GEMM "kernel" backed by cuBLAS (cublasGemmEx). It accepts only flat, unpacked, unquantized,
// ungrouped, single-problem GEMMs with no epilogue and a column-major D (see is_feasible).
class CublasKernel: public Kernel {
public:
    explicit CublasKernel(): cublas_{}
    {
        // NOTE(review): the cublasCreate status is ignored; a failure presumably only surfaces
        // later as a non-zero return from Launch.
        cublasCreate(&cublas_);
        // Disabled switch: would forbid reduced-precision reductions inside cuBLAS.
        if (0) {
            cublasSetMathMode(cublas_, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
        }

        desc_.backend    = 1;   // backend id distinguishing this kernel (meaning defined elsewhere)
        desc_.group_axis = -1;  // no grouping axis

        info_.chunk_size_k      = 1;
        info_.dynamic_smem_size = 0;

        info_.name = GetName();
    }

    ~CublasKernel() override
    {
        cublasDestroy(cublas_);
        cublas_ = {};
    }

    // Computes D = alpha * op(A) * op(B) + beta * D with FP32 compute. op() is a transpose for
    // row-major operands. The U/V operands, swizzle and splits are ignored. C must be null or
    // alias D. Returns 0 on CUBLAS_STATUS_SUCCESS, 1 otherwise.
    int Launch(const Operation&    operation,
               float               alpha,
               const void*         A,
               const MatrixLayout& Adesc,
               const void*         U,
               const MatrixLayout& Udesc,
               const void*         B,
               const MatrixLayout& Bdesc,
               const void*         V,
               const MatrixLayout& Vdesc,
               float               beta,
               const void*         C,
               const MatrixLayout& Cdesc,
               void*               D,
               const MatrixLayout& Ddesc,
               int                 swizzle,
               int                 splits,
               Workspace&          workspace,
               cudaStream_t        stream) override
    {
        cublasOperation_t transa = Adesc.order == kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T;
        cublasOperation_t transb = Bdesc.order == kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T;

        const int m = Adesc.rows;
        const int n = Bdesc.cols;
        const int k = Adesc.cols;

        TM_CHECK_EQ(Bdesc.rows, k);
        TM_CHECK_EQ(Ddesc.rows, m);
        TM_CHECK_EQ(Ddesc.cols, n);

        TM_CHECK(C == nullptr || C == D);

        // Re-bind the handle only when the stream / workspace actually changed.
        if (stream_ != stream) {
            cublasSetStream(cublas_, stream);
            stream_ = stream;
        }

        if (workspace_ != workspace.partials || workspace_size_ != workspace.partials_size) {
            cublasSetWorkspace(cublas_, workspace.partials, workspace.partials_size);
            workspace_      = workspace.partials;
            workspace_size_ = workspace.partials_size;
        }

        auto ec = cublasGemmEx(cublas_,
                               transa,
                               transb,
                               m,
                               n,
                               k,
                               &alpha,
                               A,
                               to_cuda_dtype(Adesc.type),
                               Adesc.ld,
                               B,
                               to_cuda_dtype(Bdesc.type),
                               Bdesc.ld,
                               &beta,
                               D,
                               to_cuda_dtype(Ddesc.type),
                               Ddesc.ld,
                               CUDA_R_32F,
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);

        return ec == CUBLAS_STATUS_SUCCESS ? 0 : 1;
    }

    // Accepts only problems Launch can express as a single cublasGemmEx call: flat striding,
    // no packing, no epilogue, one problem, no quantization, no grouping, column-major C, A/B of
    // the same half/bf16/float type, and C of that type or float.
    bool is_feasible(const GemmDesc& desc) const noexcept override
    {
        constexpr std::tuple flat3{Striding::kFlat, Striding::kFlat, Striding::kFlat};

        if (std::tie(desc.striding_a, desc.striding_b, desc.striding_c) != flat3) {
            return false;
        }
        if (std::tie(desc.pack_a, desc.pack_b, desc.pack_u, desc.pack_v) != std::tuple{0, 0, 0, 0}) {
            return false;
        }
        if (desc.epilogue != Epilogue::kNone) {
            return false;
        }
        if (desc.num > 1) {
            return false;
        }
        if (desc.quant_a || desc.quant_b) {
            return false;
        }
        if (desc.group_axis >= 0) {
            return false;
        }
        if (desc.order_c != kColMajor) {
            return false;
        }
        if (desc.type_a != kHalf && desc.type_a != kBfloat16 && desc.type_a != kFloat) {
            return false;
        }
        if (desc.type_b != desc.type_a) {
            return false;
        }
        if (desc.type_c != desc.type_a && desc.type_c != kFloat) {
            return false;
        }
        return true;
    }

    // cuBLAS does its own tiling: no swizzle and no split-k are offered.
    int GetMaxSwizzle(const int4&) const override
    {
        return 0;
    }

    int GetMaxSplits(const int4&, int, size_t, size_t) const override
    {
        return 1;
    }

private:
    cublasHandle_t cublas_{};
    // Last stream / workspace bound to the handle, cached to skip redundant cuBLAS calls.
    cudaStream_t   stream_{};
    void*          workspace_{};
    size_t         workspace_size_{};
};

// Registers the cuBLAS-backed fallback kernel for dense floating-point GEMMs.
void Registry::cublas_float()
{
    Add(std::make_unique());
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/desc.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// aggregate that uniquely identifies a GEMM problem
// Field order is part of the aggregate-initialization interface; do not reorder.
struct GemmDesc {
    int       arch;  // SM version (printed as "sm" << arch / 10)
    DataType  type_a;
    DataType  type_b;
    DataType  type_c;
    Order     order_a;
    Order     order_b;
    Order     order_c;
    Striding  striding_a;
    Striding  striding_b;
    Striding  striding_c;
    Pack      pack_a;
    Pack      pack_b;
    Pack      pack_u;  // packing of the A-side scale operand (swapped with pack_v on transpose)
    Pack      pack_v;  // packing of the B-side scale operand
    QuantDesc quant_a;
    QuantDesc quant_b;
    Epilogue  epilogue;
    int       batch_dim;   // 0 or 1; flipped by transpose()
    int       group_axis;  // negative means ungrouped; otherwise 0 or 1, flipped by transpose()
    int       m;
    int       n;
    int       k;
    int       num;  // NOTE(review): presumably the number of problems/groups — confirm
};

// Descriptors are copied freely (e.g. by value into transpose), so they must stay trivially copyable.
static_assert(std::is_trivially_copyable_v);

// Describes the transposed problem D^T = B^T * A^T: operand-specific fields of A and B (and of
// their scale operands U and V) trade places, m and n swap, every operand's major-ness flips,
// and the batch / group axes (when grouped) move to the other dimension.
inline GemmDesc transpose(GemmDesc d)
{
    std::swap(d.type_a, d.type_b);
    std::swap(d.striding_a, d.striding_b);
    std::swap(d.pack_a, d.pack_b);
    std::swap(d.pack_u, d.pack_v);
    std::swap(d.quant_a, d.quant_b);
    std::swap(d.m, d.n);

    const Order new_order_a = ~d.order_b;
    const Order new_order_b = ~d.order_a;
    d.order_a               = new_order_a;
    d.order_b               = new_order_b;
    d.order_c               = ~d.order_c;

    d.batch_dim  = 1 - d.batch_dim;
    d.group_axis = d.group_axis >= 0 ? 1 - d.group_axis : d.group_axis;
    return d;
}

inline std::string to_string(const GemmDesc& d)
{
    std::stringstream ss;
    ss << "sm" << d.arch / 10;
    ss << "_" << to_string(d.type_a);  //
    if (d.quant_a) {
        ss << to_string(d.quant_a);
    }
    ss << "_" << to_string(d.type_b);  //
    if (d.quant_b) {
        ss << to_string(d.quant_b);
    }
    ss << "_" << to_string(d.type_c);
    ss << "_"                                    //
       << (d.order_a == kColMajor ? 'n' : 't')   //
       << (d.order_b == kColMajor ? 'n' : 't')   //
       << (d.order_c == kColMajor ? 'n' : 't');  //
    ss << "_"                                    //
       << to_string(d.striding_a)                //
       << to_string(d.striding_b)                //
       << to_string(d.striding_c);
    ss << "_" << d.m << "x" << d.n << "x" << d.k;
    ss << "_" << d.num;
    return ss.str();
}

// Instruction class a kernel is built on: plain SIMT, MMA with m8n8k4 / m16n8k16 shapes, or a
// GMMA s64n16 variant. Enumerator order defines the values; do not reorder.
enum class OpClass
{
    kSIMT,
    kMMA_s884,
    kMMA_s16816,
    kGMMA_s64n16
};

// Short name for an instruction class, used when composing kernel names. Every enumerator has a
// name (kGMMA_s64n16 previously fell through to "unknown_op_cls"). The default branch only
// catches out-of-range values.
inline const char* to_string(OpClass op)
{
    switch (op) {
        case OpClass::kSIMT:
            return "simt";
        case OpClass::kMMA_s884:
            return "s884";
        case OpClass::kMMA_s16816:
            return "s16816";
        case OpClass::kGMMA_s64n16:
            return "s64n16";
        default:
            return "unknown_op_cls";
    }
}

// aggregate that uniquely identifies a kernel
// Field order is part of the aggregate-initialization interface; do not reorder.
struct KernelDesc {
    int       arch;
    OpClass   op_class;
    DataType  type_a;
    DataType  type_b;
    DataType  type_c;
    Order     order_a;
    Order     order_b;
    Order     order_c;
    Striding  striding_a;
    Striding  striding_b;
    Striding  striding_c;
    Pack      pack_a;
    Pack      pack_b;
    Pack      pack_u;
    Pack      pack_v;
    QuantDesc quant_a;
    QuantDesc quant_b;
    int       policy_a;
    int       policy_b;
    int3      cta_tile;       // NOTE(review): presumably (m, n, k) CTA tile; transpose swaps x/y
    int3      mma_tile;
    int2      cluster_shape;
    int3      align;
    int2      c_tile;
    int       stages;
    bool      split_k;
    int       group_axis;
    int       backend;  // 0 presumably native; CublasKernel sets 1
    bool      transpose;
};

static_assert(std::is_trivially_copyable_v);

// Runtime properties of a compiled kernel (as opposed to the identifying KernelDesc).
struct KernelInfo {
    int dynamic_smem_size;  // bytes of dynamic shared memory requested at launch
    int max_active_ctas;    // NOTE(review): presumably the occupancy per SM — confirm
    int chunk_size_k;       // granularity in k used when distributing work across splits

    std::string name;

    cudaFuncAttributes attr;
};

// Descriptor of the same kernel applied to the transposed problem. Per-operand fields of A/B and
// of the scale operands U/V trade places, the three orders flip, and the x/y (m/n) components of
// all tile-like shapes swap. Fields not mentioned (arch, op_class, type_c, striding_c, stages,
// split_k, group_axis, backend, transpose, and the z components) are copied unchanged.
inline KernelDesc transpose(const KernelDesc& d)
{
    KernelDesc t = d;

    std::swap(t.type_a, t.type_b);
    std::swap(t.striding_a, t.striding_b);
    std::swap(t.pack_a, t.pack_b);
    std::swap(t.pack_u, t.pack_v);
    std::swap(t.quant_a, t.quant_b);
    std::swap(t.policy_a, t.policy_b);

    t.order_a = ~d.order_b;
    t.order_b = ~d.order_a;
    t.order_c = ~d.order_c;

    auto swap_mn = [](auto& shape) { std::swap(shape.x, shape.y); };
    swap_mn(t.cta_tile);
    swap_mn(t.mma_tile);
    swap_mn(t.cluster_shape);
    swap_mn(t.align);
    swap_mn(t.c_tile);

    return t;
}

class Kernel;
// A concrete launch choice: a kernel plus its swizzle and split-k settings, with timing data.
struct LaunchSpec {
    Kernel* kernel;
    int     swizzle;
    int     splits;
    float   measured;  // NOTE(review): presumably measured run time — units not visible here

    std::array estimated;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/dispatch_cache.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/dispatch_cache.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/types.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include 

// Component-wise equality for the CUDA int3/int2 vector types, so that the tile-shape members of
// KernelDesc can take part in the std::tie based comparison of descriptors.
static inline bool operator==(const int3& a, const int3& b)
{
    if (a.x != b.x || a.y != b.y) {
        return false;
    }
    return a.z == b.z;
}

static inline bool operator==(const int2& a, const int2& b)
{
    return !(a.x != b.x || a.y != b.y);
}

namespace turbomind::gemm {

// Tuple of references to every KernelDesc field, used to implement operator==(KernelDesc, KernelDesc).
// NOTE: keep this list in sync with KernelDesc. A missing field would make two different kernels compare
// equal, e.g. resolving an imported cache record to the wrong kernel. Element order does not affect the
// equality result (`group_axis` is listed last here, unlike in the struct declaration).
static inline decltype(auto) as_tuple(const KernelDesc& d)
{
    return std::tie(d.arch,
                    d.op_class,
                    d.type_a,
                    d.type_b,
                    d.type_c,
                    d.order_a,
                    d.order_b,
                    d.order_c,
                    d.striding_a,
                    d.striding_b,
                    d.striding_c,
                    d.pack_a,
                    d.pack_b,
                    d.pack_u,
                    d.pack_v,
                    d.quant_a,
                    d.quant_b,
                    d.policy_a,
                    d.policy_b,
                    d.cta_tile,
                    d.mma_tile,
                    d.cluster_shape,
                    d.align,
                    d.c_tile,
                    d.stages,
                    d.split_k,
                    d.backend,
                    d.transpose,
                    d.group_axis);
}

// Two quantization descriptors match when both the quantized type and the group size agree.
static inline bool operator==(const QuantDesc& a, const QuantDesc& b)
{
    const bool same_type = a.type == b.type;
    return same_type && a.group_size == b.group_size;
}

// Field-by-field equality through the tuple view. The int2/int3/QuantDesc operator== overloads defined in
// this file supply the comparisons for the non-scalar members.
static inline bool operator==(const KernelDesc& a, const KernelDesc& b)
{
    const auto lhs = as_tuple(a);
    const auto rhs = as_tuple(b);
    return lhs == rhs;
}

namespace {

// On-disk layout of one dispatch-cache entry. ExportDispatchCache / ImportDispatchCache write and read
// this struct as raw bytes, so its size and layout (including the embedded GemmDesc and KernelDesc)
// *are* the cache file format: a cache is only readable by a build with an identical layout.
struct Record {
    GemmDesc   gemm;    // problem descriptor, with the batch size filled in
    KernelDesc kernel;  // kernel identity; matched against the available kernels on import

    int swizzle;
    int splits;
};

}  // namespace

void ExportDispatchCache(std::ostream& os, const std::vector>& entries)
{

    for (const auto& [g, spec] : entries) {
        Record record{};
        record.gemm    = g;
        record.kernel  = spec.kernel->desc();
        record.splits  = spec.splits;
        record.swizzle = spec.swizzle;

        os.write((const char*)&record, sizeof(record));
    }
}

void ImportDispatchCache(std::istream&                                 is,
                         std::vector>& entries,
                         const std::vector&                   kernels)
{
    // Reads the raw `Record`s written by ExportDispatchCache. For each record whose KernelDesc matches a
    // kernel in `kernels`, one (GemmDesc, LaunchSpec) entry is appended to `entries`. Records whose kernel
    // is not available are reported and skipped. The stream must be seekable, because its total size
    // determines the number of records.
    is.seekg(0, is.end);
    const std::streamoff size_in_bytes = is.tellg();
    is.seekg(0, is.beg);

    // tellg() yields -1 when the stream is in a failed state or cannot be positioned.
    if (size_in_bytes < 0) {
        std::cerr << "Failed to determine the size of the dispatch cache, no records imported.\n";
        return;
    }

    // A size that is not a whole number of records means the data was written with a different `Record`
    // layout, or was truncated. Decoding it would yield misaligned, garbage descriptors, so reject it.
    if (size_in_bytes % sizeof(Record)) {
        std::cerr << "File size is not a multiple of record size, failed to import records.\n";
        return;
    }

    const int n = size_in_bytes / sizeof(Record);

    for (int i = 0; i < n; ++i) {
        Record record;
        if (!is.read((char*)&record, sizeof(Record))) {
            std::cerr << "Failed to read entry " << i << ", stopping import.\n";
            break;
        }

        LaunchSpec spec{};
        spec.splits  = record.splits;
        spec.swizzle = record.swizzle;

        // Re-resolve the kernel pointer from the stored descriptor
        for (const auto& p : kernels) {
            if (p->desc() == record.kernel) {
                spec.kernel = p;
                break;
            }
        }
        if (spec.kernel) {
            entries.emplace_back(record.gemm, spec);
        }
        else {
            std::cerr << "No kernel found for entry " << i << "\n";
        }
    }
}

namespace {

// Tuple view of the GemmDesc fields that take part in ordering (used by operator<(GemmDesc, GemmDesc)).
// The element order determines the lexicographic order of descriptors. That is the iteration order of the
// GemmDesc-keyed map in DispatchCache::Impl (presumably std::map<GemmDesc, ...>), and hence the order in
// which entries are exported. QuantDesc members are compared via (type, group_size).
inline decltype(auto) as_tuple(const GemmDesc& d)
{
    return std::tie(d.arch,
                    d.type_a,
                    d.type_b,
                    d.type_c,
                    d.order_a,
                    d.order_b,
                    d.order_c,
                    d.striding_a,
                    d.striding_b,
                    d.striding_c,
                    d.pack_a,
                    d.pack_b,
                    d.pack_u,
                    d.pack_v,
                    d.quant_a.type,
                    d.quant_a.group_size,
                    d.quant_b.type,
                    d.quant_b.group_size,
                    d.batch_dim,
                    d.group_axis,
                    d.m,
                    d.n,
                    d.k,
                    d.num);
    // Note: `d.epilogue` is not used yet
}

}  // namespace

// Lexicographic strict weak ordering over the fields listed in as_tuple(const GemmDesc&);
// this is what lets GemmDesc serve as an ordered-map key in the dispatch cache.
inline bool operator<(const GemmDesc& a, const GemmDesc& b)
{
    const auto lhs = as_tuple(a);
    const auto rhs = as_tuple(b);
    return lhs < rhs;
}

int extract_batch_size(GemmDesc& desc)
{
    // The batched dimension is `m` when batch_dim == 0 and `n` otherwise. Its value is returned and the
    // field is reset to 0, so descriptors that differ only in batch size become the same cache key.
    auto& batch_field = (desc.batch_dim == 0) ? desc.m : desc.n;
    const auto batch_size = batch_field;
    batch_field           = 0;
    return batch_size;
}

void set_batch_size(GemmDesc& desc, int batch_size)
{
    // Inverse of extract_batch_size: store the batch size back into the batched dimension.
    if (desc.batch_dim == 0) {
        desc.m = batch_size;
    }
    else {
        desc.n = batch_size;
    }
}

// Storage of the dispatch cache.
// Entries are grouped by GemmDesc with the batched dimension zeroed (see extract_batch_size).
// Within a group, `Flat` holds (batch size, index into `specs`) pairs, kept sorted by batch size,
// together with the launch specs they point to.
struct DispatchCache::Impl {

    struct Flat {
        std::vector> idxs;  // (batch_size, index into specs), sorted by batch_size, unique keys
        std::vector          specs;
    };

    const std::vector kernels_;  // candidate kernels, used to resolve descriptors on import
    std::map   cache_;

    Impl(std::vector kernels): kernels_(std::move(kernels)) {}

    // Look up a launch spec for `desc`.
    // With `exact`, only an entry recorded for the same batch size matches. Otherwise the entry with the
    // smallest recorded batch size >= the requested one is returned. Returns an empty optional on miss.
    std::optional Find(GemmDesc desc, bool exact) const
    {
        const int batch_size = extract_batch_size(desc);
        // std::cerr << batch_size << " " << desc.m << " " << desc.n << " " << desc.k << " " << std::boolalpha << exact
        //           << "\n";
        const auto it = cache_.find(desc);
        if (it != cache_.end()) {
            const auto& [idxs, specs] = it->second;
            // Find index via key (first recorded batch size >= batch_size)
            const auto p =
                std::lower_bound(idxs.begin(), idxs.end(), std::make_pair(batch_size, 0), [](auto& a, auto& b) {  //
                    return a.first < b.first;
                });
            // std::cout << it->second.specs.size() << std::endl;
            if (p != idxs.end() && (!exact || p->first == batch_size)) {
                // std::cerr << p->first << " " << p->second << "\n";
                return specs[p->second];
            }
        }
        return {};
    }

    // Add `spec` for `desc`. Returns false and keeps the existing spec when an entry for the same
    // problem and batch size is already present.
    bool Insert(GemmDesc desc, const LaunchSpec& spec)
    {
        const int batch_size = extract_batch_size(desc);

        auto it = cache_.find(desc);
        if (it == cache_.end()) {
            it = cache_.emplace_hint(it, desc, Flat{});
        }
        auto& [idxs, specs] = it->second;
        // Find index via key
        const auto p =
            std::lower_bound(idxs.begin(), idxs.end(), std::make_pair(batch_size, 0), [](auto& a, auto& b) {  //
                return a.first < b.first;
            });
        // Exact match, skip
        if (p != idxs.end() && p->first == batch_size) {
            return false;
        }
        // Insert, keeping `idxs` sorted by batch size; the spec itself is appended to `specs`
        idxs.insert(p, {batch_size, (int)specs.size()});
        specs.push_back(spec);
        return true;
    }

    // Flatten the cache into (GemmDesc, LaunchSpec) pairs with the batch size restored, print a usage
    // summary, and serialize them to `os`. Returns the number of entries exported.
    int Export(std::ostream& os) const
    {
        std::vector> entries;
        for (const auto& [desc, flat] : cache_) {
            auto tmp = desc;
            for (const auto& [batch_size, index] : flat.idxs) {
                set_batch_size(tmp, batch_size);
                entries.emplace_back(tmp, flat.specs[index]);
            }
        }
        Summary(entries);
        ExportDispatchCache(os, entries);
        return entries.size();
    }

    // Read entries from `is` and merge them into the cache. When a (desc, batch size) key already exists,
    // the entry that was in the cache before the import wins. Among duplicates inside the stream, the
    // earliest wins: stable_sort + unique keep the first occurrence. Returns the number of entries decoded
    // from the stream, including duplicates that were dropped.
    int Import(std::istream& is)
    {
        std::vector> entries;
        ImportDispatchCache(is, entries, kernels_);
        Summary(entries);
        for (auto [desc, spec] : entries) {
            const int batch_size = extract_batch_size(desc);
            auto      it         = cache_.find(desc);
            if (it == cache_.end()) {
                it = cache_.emplace_hint(it, desc, Flat{});
            }
            auto& [idxs, specs] = it->second;
            // Order is not maintained at this point
            idxs.emplace_back(batch_size, (int)specs.size());
            specs.push_back(spec);
        }
        // Sort indices and deduplicate
        for (auto& [desc, flat] : cache_) {
            auto& [idxs, specs] = flat;
            std::stable_sort(idxs.begin(), idxs.end(), [](auto a, auto b) { return a.first < b.first; });
            idxs.erase(std::unique(idxs.begin(), idxs.end(), [](auto a, auto b) { return a.first == b.first; }),
                       idxs.end());
            // Remove unreferenced specs and update spec indices
            std::vector tmp;
            for (auto& [key, val] : idxs) {
                int old = std::exchange(val, tmp.size());
                tmp.push_back(specs[old]);
            }
            specs = std::move(tmp);
        }
        return entries.size();
    }

    // Print a summary of how many cases a kernel is used
    // Every kernel in `kernels_` is listed (unused ones with count 0), sorted by descending count. The
    // leading nullptr sentinel guarantees that the first real kernel opens a new bucket, so
    // `count.back()` is valid.
    // NOTE(review): assumes no entry has a null kernel (it would equal the sentinel and call
    // count.back() on an empty vector). A kernel missing from `kernels_` would be under-counted by one.
    void Summary(const std::vector>& entries) const
    {
        std::vector uses{nullptr};
        std::copy(kernels_.begin(), kernels_.end(), std::back_inserter(uses));

        for (const auto& [_, s] : entries) {
            uses.push_back(s.kernel);
        }
        std::sort(uses.begin(), uses.end());
        std::vector> count;
        for (size_t i = 1; i < uses.size(); ++i) {
            if (uses[i] != uses[i - 1]) {
                count.emplace_back(-1, uses[i]);
            }
            ++count.back().first;
        }
        std::sort(count.begin(), count.end(), std::greater<>{});
        for (const auto& [n, k] : count) {
            std::cout << k->name() << ": " << n << "\n";
        }
    }
};

// `kernels` is the candidate kernel set; Import uses it to map stored kernel descriptors back to kernels.
DispatchCache::DispatchCache(std::vector kernels): impl_(std::make_unique(std::move(kernels))) {}

DispatchCache::~DispatchCache() = default;

// Exact lookup: only a spec recorded for the same batch size is returned.
std::optional DispatchCache::Find(const GemmDesc& desc) const
{
    return impl_->Find(desc, true);
}

// Relaxed lookup: the spec recorded for the smallest batch size >= the requested one.
std::optional DispatchCache::LowerBound(const GemmDesc& desc) const
{
    return impl_->Find(desc, false);
}

bool DispatchCache::Insert(const GemmDesc& desc, const LaunchSpec& spec)
{
    return impl_->Insert(desc, spec);
}

int DispatchCache::Export(std::ostream& os) const
{
    return impl_->Export(os);
}

int DispatchCache::Import(std::istream& is)
{
    return impl_->Import(is);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/dispatch_cache.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/desc.h"

#include 
#include 
#include 
#include <iosfwd>

namespace turbomind::gemm {

// Cache of GEMM launch decisions.
// Specs are keyed by the problem descriptor; the batch size is handled separately, so that lookups can
// either require the same batch size (Find) or accept the nearest recorded larger one (LowerBound).
class DispatchCache {
public:
    // `kernels`: candidate kernels, used to resolve stored kernel descriptors when importing.
    DispatchCache(std::vector kernels);

    ~DispatchCache();

    // Spec recorded for the same problem with the smallest batch size >= desc's batch size.
    std::optional LowerBound(const GemmDesc& desc) const;

    // Spec recorded for exactly this problem and batch size.
    std::optional Find(const GemmDesc& desc) const;

    // Returns false (keeping the existing spec) if an entry for this problem and batch size exists.
    bool Insert(const GemmDesc& desc, const LaunchSpec& spec);

    // Writes all entries as binary records; returns the number of entries.
    int Export(std::ostream& os) const;

    // Merges binary records from `is` (existing entries take precedence); returns the number decoded.
    int Import(std::istream& is);

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/epilogue.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/sync.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/predicate.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Per-channel affine epilogue. Each fragment row `s` is transformed as x = x * tmp[0] + tmp[1], where the
// pair (tmp[0], tmp[1]) is loaded from `scale_bias_ptr` for that row. The pointer starts at element
// cs0.y and advances `ds` = sizeof(Tc) * delta_s bytes per row (assuming the cast target is a byte
// pointer; the type is not visible here). Does nothing when `scale_bias_ptr` is null.
template
struct ChannelCombination_v3 {
    const Tc* __restrict__ scale_bias_ptr;

    template
    __device__ void operator()(Array (&x)[S][C], int2 cs0, pair, Pred& pred) const
    {
        __align__(16) Array scale_bias[S];

        if (scale_bias_ptr) {
            constexpr int ds  = sizeof(Tc) * delta_s;
            auto          ptr = reinterpret_cast(scale_bias_ptr + cs0.y);
            // Load phase: rows rejected by the predicate are not loaded
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                if (pred(s, 0)) {
                    Ldg(scale_bias[s], reinterpret_cast(ptr));
                }
                ptr += ds;
            }
            // NOTE(review): for rows where pred(s, 0) is false, scale_bias[s] is left uninitialized but still
            // used below; presumably harmless because those rows are masked again when results are stored.
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                auto tmp = cast(scale_bias[s]);
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    using namespace ops;
                    x[s][c] = x[s][c] * tmp[0] + tmp[1];
                }
            }
        }
    }
};

// Optional scaling of accumulator fragments. Its only visible call site is commented out in
// Epilogue_::operator().
//  - scale_S: row `s` is multiplied by a scalar read from `param_S` at S-index cs0.y + s * delta_S,
//    remapped through `mat.idxs` when that gather index is present.
//  - scale_C: column fragment `c` is multiplied by a vector read from `param_C`.
template
__device__ void Scale(pair,
                      pair,
                      pair,
                      Array (&x)[S][C],
                      const MatrixParam& param_S,
                      const MatrixParam& param_C,
                      int                gemm_id,
                      int2               cs0,
                      Pred&              pred)
{
    if (scale_S && param_S.ptr) {
        const auto mat = resolve(param_S, gemm_id);
        const T*   ptr = (const T*)mat.ptr.ptr;
        T          param[S];
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            const int ss  = cs0.y + s * delta_S;
            // NOTE(review): the `idxs` lookup runs before the `pred(s, 0)` check, i.e. also for out-of-range
            // rows. Confirm `idxs` is padded, or move the lookup under the predicate.
            const int idx = mat.idxs ? __ldg(mat.idxs + ss) : ss;
            if (pred(s, 0)) {
                param[s] = __ldg((const T*)(ptr + idx));
            }
            PRAGMA_UNROLL
            for (int c = 0; c < C; ++c) {
                using namespace ops;
                x[s][c] = x[s][c] * param[s];
            }
        }
    }

    if (scale_C && param_C.ptr) {
        // NOTE(review): `dc` is computed from sizeof(...), i.e. in bytes, but `ptr` is a `const T*`, so
        // `ptr + dc * c` advances by dc * c *elements*. The other epilogue helpers step a char pointer by
        // byte counts. Confirm the intended stride before enabling this path.
        const T*      ptr = (const T*)resolve(param_C, gemm_id).ptr.ptr + cs0.x;
        constexpr int dc  = sizeof(Array) * delta_C;
        Array   param[C];
        PRAGMA_UNROLL
        for (int c = 0; c < C; ++c) {
            if (pred(0, c)) {
                Ldg(param[c], (const T*)(ptr + dc * c));
            }
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                using namespace ops;
                x[s][c] = x[s][c] * param[c];
            }
        }
    }
}

// Combines the accumulator with a source matrix: x = alpha * x + beta * C.
// When beta == 0, C is not read and x is only scaled by alpha; that scaling is skipped when alpha == 1.
struct MatrixCombination_v3 {

    MatrixParam param_c;
    float       alpha;
    float       beta;

    template
    __device__ void operator()(Tc*,  //
                               constant,
                               Array (&x)[S][C],
                               int2 cs0,
                               int  gemm_id,
                               pair,
                               Pred& pred) const
    {
        if (beta) {
            const auto c = resolve(param_c, gemm_id);

            // Byte strides: `dc` between column fragments, `ds` between rows (includes the row stride)
            Array  frag[S][C];
            constexpr int dc  = sizeof(Tc) * delta_c;
            const int     ds  = sizeof(Tc) * delta_s * c.ptr.stride;
            const char*   ptr = (const char*)c.ptr.ptr + sizeof(Tc) * dot(cs0, long2{1, c.ptr.stride});
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                // NOTE: the loop index `c` shadows the resolved matrix `c`; harmless because `c.ptr`
                // is only read above, before the loops.
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    if (pred(s, c)) {
                        Load(frag[s][c], reinterpret_cast(ptr));
                        using namespace ops;
                        x[s][c] = x[s][c] * alpha + cast(frag[s][c]) * beta;
                    }
                    ptr += dc;
                }
                // rewind to the first column fragment, then move to the next row
                ptr -= dc * C;
                ptr += ds;
            }
        }
        else if (alpha != 1.f) {
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    using namespace ops;
                    x[s][c] = x[s][c] * alpha;
                }
            }
        }
    }
};

// Gated activation over interleaved pairs. For each pair (x[2k], x[2k+1]) it writes
// Act::apply(x[2k]) * x[2k+1] into x[k]. The in-place compaction is safe because x[i / 2] is written only
// after x[i] and x[i + 1] were read, and i / 2 <= i. Afterwards only the first N / 2 elements are
// meaningful; the upper half keeps stale values.
template
struct GatedActivation {
    template
    __device__ static void apply(Array& x)
    {
        static_assert(N % 2 == 0);
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 2) {
            x[i / 2] = static_cast(Act::apply(x[i]) * x[i + 1]);
        }
    }
};

// SiLU (swish) activation: x * sigmoid(x), evaluated as x / (1 + e^-x).
struct Silu {
    __device__ static float apply(float x)
    {
        const float denom = 1.f + expf(-x);
        return fdividef(x, denom);
    }
};

// Runtime parameters consumed by Epilogue_::operator().
struct EpilogueParam {
    MatrixParam c;         // output matrix (resolved per gemm id)
    MatrixParam partials;  // split-K partial-sum workspace
    int*        locks;     // per-tile semaphores (indexed by tile_id) serializing the split-K reduction

    // MatrixParam scale_S;
    // MatrixParam scale_C;

    MatrixCombination_v3 combine_mat;  // alpha/beta combination with the source matrix

    bool silu_act;  // apply gated SiLU to interleaved column pairs; the output then has half the columns
};

// CTA-level GEMM epilogue. It does the following:
//  1. rearranges the accumulator fragments through shared memory into the output thread map;
//  2. optionally reduces split-K partials (serialized per tile by semaphores);
//  3. applies the alpha/beta combination;
//  4. stores the tile, optionally through a gated SiLU that halves the number of output columns.
template
struct Epilogue_ {

    using Dtype = typename OperandC::Dtype;

    static constexpr auto kOrder = OperandC::kOrder;
    static constexpr auto kMode  = mode_C;
    static constexpr bool SplitK = SplitK_;

    using Tc = Tc_;

    // Size of the sub-tile (TM x TN) staged through shared memory per pass
    static constexpr int TM = TM_;
    static constexpr int TN = TN_;

    using SmemLayout = decltype(OperandC::GetSmemLayout::apply(pair{}));

    using SmemAccessorV2 = SmemAccessorV2;

    using SharedStorage = Array;

    using Map = decltype(OperandC::GetThreadMap::apply(pair{}, constant{}));

    // Per-thread iteration counts over S/C and vector width of each access
    static constexpr int S       = Map::kIterS;
    static constexpr int C       = Map::kIterC;
    static constexpr int kAccess = Map::kAccessC;

    template
    using OutputC = Array;

    // Moves `frag_C` from the MMA fragment layout into `out`, laid out by the thread map `Map`.
    // Each TM x TN sub-tile is stored to shared memory, then read back by the threads owning it, with a
    // __syncthreads() between passes because the shared buffer is reused.
    template
    __device__ void Rearrange(FragC& frag_C, SharedStorage& storage, OutputC (&out)[S][C])
    {
        SmemAccessorV2 smem_C{storage.data()};

        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);

        // The smem layout repeats with period (S0, C0); precompute per-phase offsets once
        constexpr int kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);
        constexpr int kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);

        int phases[kPeriodS][kPeriodC];
        PRAGMA_UNROLL
        for (int s = 0; s < kPeriodS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < kPeriodC; ++c) {
                phases[s][c] = SmemLayout::apply(s * Map::kDeltaS + thr_cs.y, c * Map::kDeltaC + thr_cs.x);
            }
        }

        constexpr bool kRaked = true;

        PRAGMA_UNROLL
        for (int m = 0; m < M; m += TM) {
            PRAGMA_UNROLL
            for (int n = 0; n < N; n += TN) {
                // Store to shared memory
                RearrangeC::apply(frag_C, smem_C, {m, n}, pair{});

                // Load from shared memory
                PRAGMA_UNROLL
                for (int s = 0; s < S; ++s) {
                    PRAGMA_UNROLL
                    for (int c = 0; c < C; ++c) {
                        const int cc = c * Map::kDeltaC + thr_cs.x;
                        const int ss = s * Map::kDeltaS + thr_cs.y;

                        // Only load the elements that fall into the current TM x TN sub-tile
                        const int2 mn =
                            kRaked ? cs2mk(c * Map::kDeltaC, s * Map::kDeltaS) : cs2mk(cc, ss);
                        const int  mm   = mn.x - m;
                        const int  nn   = mn.y - n;
                        const bool mask = (M <= TM || (0 <= mm && mm < TM)) && ((N <= TN) || (0 <= nn && nn < TN));

                        const int2 _cs      = mk2cs(m, n);
                        const int  offset_0 = SmemLayout::apply(  //
                            s / kPeriodS * kPeriodS * Map::kDeltaS - _cs.y,
                            c / kPeriodC * kPeriodC * Map::kDeltaC - _cs.x);
                        const int  offset_p = phases[s % kPeriodS][c % kPeriodC];

                        if (mask) {
                            Load(out[s][c], &storage[offset_0 + offset_p]);
                        }
                    }
                }
                __syncthreads();
            }
        }
    }

    // Converts each fragment to T and stores the predicated ones to `c`, starting at (c, s) = cs0.
    // `dc` / `ds` are byte strides between column fragments / rows.
    template
    __device__ void StoreC(const VecC& vec_C, const MatrixData& c, int2 cs0, Pred& pred)
    {
        constexpr int dc  = sizeof(T) * Map::kDeltaC;
        const int     ds  = sizeof(T) * Map::kDeltaS * c.ptr.stride;
        char*         ptr = (char*)c.ptr.ptr + sizeof(T) * dot(cs0, long2{1, c.ptr.stride});
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < C; ++c) {
                const auto tmp = cast(vec_C[s][c]);
                if (pred(s, c)) {
                    Store(reinterpret_cast(ptr), tmp);
                }
                ptr += dc;
            }
            ptr -= dc * C;
            ptr += ds;
        }
    }

    // Disabled legacy reduction (kept for reference).
    // NOTE(review): `¶m` below looks like an encoding-garbled `&param` (HTML entity `&para;`).
#if 0
    template
    __device__ void
    Reduce(FragC& frag_C, int splits, int64_t split_size, const int2& cta_cs, Pred& pred, const EpilogueParam& param)
    {
        using Vec         = OutputC;
        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);
        for (int k = 0; k < splits; ++k) {
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    const int     ss  = thr_cs.y + s * Map::kDeltaS;
                    const int     cc  = thr_cs.x + c * Map::kDeltaC;
                    const int64_t idx = k * split_size + (cta_cs.y + ss) * param.partial_C_ld + (cta_cs.x + cc);
                    if (true) {
                        Vec tmp;
                        Load(tmp, ¶m.partial_C[idx]);
                        using namespace ops;
                        frag_C[s][c] = frag_C[s][c] + tmp;
                    }
                }
            }
        }
    }
#endif

    // One step of the serialized split-K reduction through the partials buffer `p`:
    // the running sum is loaded (skipped for the first split, which starts from zero), added to frag_C,
    // and stored back (skipped for the last split, which keeps the total in registers for the epilogue).
    template
    __device__ void Reduce(FragC& frag_C, const MatrixData& p, bool is_first, bool is_last, int2 cs0, Pred& pred)
    {
        constexpr int dc = sizeof(Dtype) * Map::kDeltaC;
        const int     ds = sizeof(Dtype) * Map::kDeltaS * p.ptr.stride;

        char* ptr = (char*)p.ptr.ptr + sizeof(Dtype) * dot(cs0, long2{1, p.ptr.stride});

        Pred ld_mask = is_first ? Pred{} : pred;
        Pred st_mask = is_last ? Pred{} : pred;

        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < C; ++c) {
                OutputC tmp{};  // ! ZERO-filled
                if (ld_mask(s, c)) {
                    Load(tmp, reinterpret_cast(ptr));
                }
                if (1) {
                    using namespace ops;
                    frag_C[s][c] = frag_C[s][c] + tmp;
                }
                if (st_mask(s, c)) {
                    Store(reinterpret_cast(ptr), frag_C[s][c]);
                }
                ptr += dc;
            }
            ptr -= dc * C;
            ptr += ds;
        }
    }

    // Epilogue entry point for one CTA tile.
    // tile_offset: (.x, .y) tile coordinates in M/N, .z split index, .w gemm id (as used below).
    // extents: valid (m, n) extent used to build the bounds predicate.
    // With split-K, non-last splits only deposit their partial sums and return early.
    template
    __device__ void operator()(FragC&               frag_C,
                               const int4&          tile_offset,
                               const int2&          extents,
                               int                  splits,
                               int                  tile_id,
                               bool                 is_last,
                               const EpilogueParam& param,
                               SharedStorage&       storage)
    {
        const int2 cta_cs = mk2cs(tile_offset.x * M, tile_offset.y * N);
        const int2 end_cs = mk2cs(extents);

        OutputC tmp_C[S][C];

        Rearrange(frag_C, storage, tmp_C);

        Predicate pred{};  //  1 regs

        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);
        const int2 cs0    = {cta_cs.x + thr_cs.x, cta_cs.y + thr_cs.y};

        // Bounds predicate for this thread's accesses
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < C; ++c) {
                const int ss = thr_cs.y + s * Map::kDeltaS;
                const int cc = thr_cs.x + c * Map::kDeltaC;
                if (ss < end_cs.y && cc < end_cs.x) {
                    pred.set(s, c);
                }
            }
        }

        if (SplitK_ && splits > 1) {
            // Splits of the same tile take turns: wait until the semaphore reaches this split's index,
            // reduce, then hand over to the next split (the last split resets it to 0).
            // NOTE(review): `¶m` looks like an encoding-garbled `&param`; the intended expression is
            // presumably `&param.locks[tile_id]`.
            int* barrier = ¶m.locks[tile_id];

            sem_wait(barrier, tile_offset.z, threadIdx.x == 0);

            const MatrixData p = resolve(param.partials, tile_offset.w);

            Reduce(tmp_C, p, tile_offset.z == 0, is_last, cs0, pred);

            const int post_id = is_last ? 0 : tile_offset.z + 1;
            sem_post(barrier, post_id, threadIdx.x == 0);

            if (!is_last) {
                return;
            }
        }

        constexpr pair delta_cs{};

        // opt-in scaling
        // Scale(scale_SC{}, mode_SC{}, delta_cs, tmp_C, param.scale_S, param.scale_C, tile_offset.w, cs0, pred);

        param.combine_mat((Tc*)0, constant{}, tmp_C, cs0, tile_offset.w, delta_cs, pred);

        const MatrixData c = resolve(param.c, tile_offset.w);

        if (param.silu_act) {
            // Gated SiLU collapses each interleaved column pair into one output column, so column
            // offsets and strides are halved relative to the accumulator layout.
            constexpr int dc  = sizeof(Tc) * Map::kDeltaC / 2;
            const int     ds  = sizeof(Tc) * Map::kDeltaS * c.ptr.stride;
            auto          ptr = (char*)c.ptr.ptr + sizeof(Tc) * dot({cs0.x / 2, cs0.y}, long2{1, c.ptr.stride});
            PRAGMA_UNROLL
            for (int s = 0; s < S; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    GatedActivation::apply(tmp_C[s][c]);
                    if (pred(s, c)) {
                        const auto tmp = cast((Array&)tmp_C[s][c]);
                        Store(reinterpret_cast(ptr), tmp);
                    }
                    ptr += dc;
                }
                ptr -= dc * C;
                ptr += ds;
            }
        }
        else {
            StoreC(tmp_C, c, cs0, pred);
        }
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/format.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"

namespace turbomind::gemm {

// Primary template: no conversion is defined. The specializations below provide `operator()` for
// supported format pairs, so an unsupported pair fails to compile where the operator is used.
template
struct Converter {
};

// Presumably the same-format specialization (source and target identical): returns the input unchanged.
// The specialization's template arguments are not visible in this rendering — confirm.
template
struct Converter {
    template
    __device__ Array operator()(Array x)
    {
        return x;
    }
};

// Packs 16-bit inputs (static_assert: sizeof(U) == 2) into a sub-byte format.
// Each element is first narrowed with static_cast into `tmp`, presumably one value per byte. Every group
// of 8 values is then packed into one 32-bit word by `pack`, i.e. 4 bits per value. The specialization's
// template arguments are not visible here; presumably the target is a 4-bit integer type — confirm.
template<>
struct Converter {

    // Packs 8 values, one per byte in the low nibble, into one 32-bit word.
    // NOTE(review): assumes the high nibble of every input byte is zero; otherwise the OR below mixes values.
    // By this reading, the resulting nibbles (high -> low) are 7 5 3 1 6 4 2 0 — verify against the consumer.
    static __device__ Array pack(const Array& vi)
    {
        Array ui = (Array&)vi;

        ui[0] |= (ui[0] >> 12);
        ui[1] |= (ui[1] >> 12);

        //  7 6 5 4 3 2 1 0
        // _7_67564_3_23120
        uint32_t uo = __byte_perm(ui[0], ui[1], 0x5140);

        return (Array&)uo;
    }

    template
    __device__ Array operator()(const Array& x)
    {
        static_assert(sizeof(U) == 2);
        auto&             vi = (const Array&)x;
        Array tmp;
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            tmp[i] = static_cast(vi[i]);
        }
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 8) {
            (Array&)vo[i] = pack((Array&)tmp[i]);
        }
        return vo;
    }
};

// Narrows each value to uint8_t and reorders every group of 4 as [x0, x2, x1, x3] ("3120" read from the
// high to the low position). The specialization's template arguments are not visible here; presumably
// an 8-bit target format — confirm.
template<>
struct Converter {
    template
    __device__ Array operator()(const Array& x)
    {
        static_assert(N % 4 == 0);
        Array vo;
        PRAGMA_UNROLL
        for (int i = 0; i < N; i += 4) {
            // 3120
            vo[i + 0] = (uint8_t)x[i + 0];
            vo[i + 1] = (uint8_t)x[i + 2];
            vo[i + 2] = (uint8_t)x[i + 1];
            vo[i + 3] = (uint8_t)x[i + 3];
        }
        return vo;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/core/check.h"
#include "src/turbomind/kernels/gemm/context.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/dispatch_cache.h"
#include "src/turbomind/kernels/gemm/gemm.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/tuner/params.h"
#include "src/turbomind/kernels/gemm/tuner/sampler.h"
#include "src/turbomind/kernels/gemm/types.h"
#include 
#include 
#include 
#include 
#include 
#include 

namespace turbomind::gemm {

// Binary (de)serialization helpers for dispatch-cache entries, defined in dispatch_cache.cu.
// NOTE(review): as rendered, the `kernels` parameter below is a nested template (`std::vector<...<...>>`),
// while the definition in dispatch_cache.cu takes a plain `std::vector<...>`. If so, this declares a
// different overload that is never defined, and calling it would fail to link — confirm and align.
void ExportDispatchCache(std::ostream& os, const std::vector>& entries);

void ImportDispatchCache(std::istream&                                 is,
                         std::vector>& entries,
                         const std::vector>&   kernels);

namespace {

// Returns the indices 0 .. size-1 stably sorted by `cmp`, which compares two indices.
// Equal elements keep their original relative order.
template
std::vector ArgSort(size_t size, const Cmp& cmp)
{
    std::vector idxs(size);
    std::iota(idxs.begin(), idxs.end(), 0);
    std::stable_sort(idxs.begin(), idxs.end(), cmp);
    return idxs;
}

}  // namespace

// Implementation of Gemm: owns the device properties, kernel registry, tuning parameters, measurer and
// dispatch cache.
struct Gemm::Impl {

    // Reads configuration from the environment:
    //  - TM_GEMM_TUNE: parsed by ParseTuningParams; defaults are used if parsing throws.
    //  - TM_GEMM_WARN_CACHE_MISS: if set (any value), log when a kReuse dispatch misses the cache.
    // NOTE: `cache_` is initialized from `registry_.kernels()`. This relies on `registry_` being declared
    // before `cache_` (members are initialized in declaration order).
    Impl():
        props_{GetCudaDeviceProps()},
        arch_{props_->major * 100 + props_->minor * 10},
        registry_{props_},
        cache_{registry_.kernels()}
    {
        if (auto str = std::getenv("TM_GEMM_TUNE")) {
            try {
                ParseTuningParams(tuning_, str);
            }
            catch (...) {
                std::cerr << "[Gemm2] Failed to parse `TM_GEMM_TUNE`, default value will be used.\n";
                tuning_ = {};
            }
        }
        if (std::getenv("TM_GEMM_WARN_CACHE_MISS")) {
            warn_cache_miss_ = true;
        }
        measurer_.emplace(CreateStoppingCriterion(tuning_.min_iter, tuning_.max_iter, tuning_.max_time));
    }

    // find launch spec in dispatch cache, dispatch by heuristic on cache miss
    //  1. with DispatchPolicy::kReuse: accept the cached spec for the smallest batch size >= the requested one;
    //  2. accept an exact cache hit;
    //  3. otherwise rank kernels heuristically (Find with top_k = 1), cache the best spec and return it.
    // Returns a default-constructed LaunchSpec (null kernel) when no kernel is feasible.
    LaunchSpec Dispatch(Context& ctx, DispatchPolicy policy, size_t barriers_size, size_t partials_size)
    {
        const auto& desc = ctx.desc();
        if (policy & DispatchPolicy::kReuse) {
            if (auto spec = cache_.LowerBound(desc)) {
                return *spec;
            }
            if (warn_cache_miss_) {
                std::cerr << "Failed to find a feasible kernel in the cache, will dispatch by heuristic: "
                          << to_string(ctx.desc()) << std::endl;
            }
        }

        if (auto spec = cache_.Find(desc)) {
            return *spec;
        }

        auto specs = Find(ctx, barriers_size, partials_size, 1);
        if (!specs.empty()) {
            cache_.Insert(desc, specs.front());
            return specs.front();
        }
        return {};
    }

    // Heuristic ranking of candidate launch specs for ctx.desc().
    // Feasible kernels are clustered, and one proxy kernel per cluster is populated with launch candidates.
    // Candidates are ranked by the mean of their mio and mma cost estimates, each normalized by its maximum
    // over all candidates (stable sort, ascending). The best `top_k` candidates (all of them when
    // top_k <= 0) are expanded to every kernel of their cluster, sharing the candidate's parameters.
    // NOTE(review): if every candidate has a zero mio (or mma) estimate, the normalization divides 0 by 0,
    // and the resulting NaNs break the sort's ordering.
    std::vector Find(Context& ctx, size_t barrier_size, size_t partials_size, int top_k)
    {
        std::vector feasible = ctx.Filter(registry_.kernels());

        std::vector> clusters;
        {
            std::vector tmp;
            tmp.reserve(feasible.size());
            for (const auto& k : feasible) {
                LaunchSpec spec{k};
                tmp.push_back(spec);
            }
            clusters = Cluster(tmp, ClusteringParam{false, true});
        }
        std::vector proxies;
        proxies.reserve(clusters.size());

        for (const auto& c : clusters) {
            proxies.push_back(c.front().kernel);
        }

        std::vector> specs;

        PopulateParam param{};
        param.max_splits    = tuning_.max_splits;
        param.max_waves     = tuning_.max_waves;
        param.swizzle       = tuning_.swizzle.at(0);
        param.barriers_size = barrier_size;
        param.partials_size = partials_size;

        for (int cluster_id = 0; cluster_id < (int)proxies.size(); ++cluster_id) {
            auto& kernel = *proxies[cluster_id];

            auto tmp = ctx.Populate(kernel, param);
            for (const auto& s : tmp) {
                specs.emplace_back(cluster_id, s);
            }
        }

        // std::cerr << "#kernel: " << kernels.size() << ", #cluster: " << clusters.size()
        //           << ", #metric: " << metrics.size() << "\n";

        int64_t mio_max = 0;
        int64_t mma_max = 0;
        for (const auto& [_, s] : specs) {
            auto& [mio, mma] = s.estimated;
            mio_max          = std::max(mio_max, mio);
            mma_max          = std::max(mma_max, mma);
        }
        std::vector mio_ratio;
        std::vector mma_ratio;
        std::vector avg_ratio;
        for (const auto& [_, s] : specs) {
            auto& [mio, mma] = s.estimated;
            mio_ratio.push_back((float)mio / mio_max);
            mma_ratio.push_back((float)mma / mma_max);
            avg_ratio.push_back(.5 * (mio_ratio.back() + mma_ratio.back()));
        }
        auto idxs = ArgSort(specs.size(), [&](int i, int j) {  //
            return avg_ratio[i] < avg_ratio[j];
        });

        // for (const auto& i : idxs) {
        //     auto [cid, s, m] = metrics[i];
        //     std::cout << clusters[cid].front().kernel->name() << " s" << s << " " << avg_ratio[i] << " " <<
        //     mio_ratio[i]
        //               << " " << mma_ratio[i] << " " << m.mio_cost << " " << m.mma_cost << "\n";
        // }

        top_k = top_k > 0 ? std::min(idxs.size(), top_k) : (int)idxs.size();
        std::vector ret;
        ret.reserve(top_k);
        for (int i = 0; i < top_k; ++i) {
            const auto& [cluster_id, spec] = specs[idxs[i]];
            // Apply `splits` to all kernels in the cluster
            for (const auto& s : clusters[cluster_id]) {
                auto tmp   = spec;
                tmp.kernel = s.kernel;
                ret.push_back(tmp);
            }
        }

        return ret;
    }

    // Benchmark candidates for ctx.desc() and cache the first spec returned by Sampler::Run (presumably
    // the fastest). Returns 0 on success or when an exact cache entry already exists, -1 when no
    // candidate remained.
    // NOTE(review): the `top_k` parameter is unused; `tuning_.top_k` is passed to Find instead — confirm
    // whether callers expect their `top_k` to take effect.
    template
    int Measure(
        Context& ctx, size_t barriers_size, size_t partials_size, int top_k, LaunchFunc launch_func, cudaStream_t st)
    {
        // Early exit on exact match
        if (cache_.Find(ctx.desc())) {
            return 0;
        }
        // std::cerr << "GEMM: " << desc.m << "x" << desc.n << "x" << desc.k << "\n";

        const auto tmp = Find(ctx, barriers_size, partials_size, tuning_.top_k);

        std::vector specs;
        for (const auto& spec : tmp) {
            // populate swizzle parameters
            const auto swis = ctx.Swizzle(spec, tuning_.swizzle);
            specs.insert(specs.end(), swis.begin(), swis.end());
        }

        specs = Sampler{*measurer_, tuning_.clusters}.Run(specs, launch_func, st);

        // for (const auto& s : specs) {
        //     std::cout << s.kernel->name()          //
        //               << " swizzle=" << s.swizzle  //
        //               << ", splits=" << s.splits   //
        //               << ", measured=" << s.measured << "ms\n";
        //     break;
        // }

        if (!specs.empty()) {
            cache_.Insert(ctx.desc(), specs.front());
        }
        else {
            std::cerr << "No valid kernel found for the problem\n";
            return -1;
        }

        return 0;
    }

    /// TODO: move to cuda utils
    // Properties of the current CUDA device.
    // NOTE(review): the cudaError_t results of cudaGetDevice / cudaGetDeviceProperties are ignored.
    static std::unique_ptr GetCudaDeviceProps()
    {
        auto props     = std::make_unique();
        int  device_id = -1;
        cudaGetDevice(&device_id);
        cudaGetDeviceProperties(props.get(), device_id);
        return props;
    }

    std::shared_ptr props_;

    int arch_;  // major * 100 + minor * 10 of the current device

    Registry registry_;

    TuningParams tuning_;

    bool warn_cache_miss_{};

    std::optional measurer_;

    DispatchCache cache_;  // must stay declared after registry_ (see constructor)
};

// implementation of GEMM interfaces

// Create the private implementation up front; the public class only holds `impl_`.
Gemm::Gemm()
{
    impl_.reset(new Impl{});
}

// Out-of-line so `impl_` is destroyed where Impl is a complete type.
Gemm::~Gemm() = default;

/// Select a kernel for the problem described by the operand layouts and launch it on `stream`.
///
/// All scalars, operand pointers and layouts are forwarded unchanged to the chosen kernel's
/// Launch. U/V are auxiliary operands (presumably quantization scales for A/B -- confirm).
/// If `operation.dispatch` contains kMeasure, the problem is auto-tuned first. A spec is then
/// obtained from Impl::Dispatch.
///
/// Returns the value of the kernel Launch. Returns -1 after TM_CHECK(0) on invalid layouts
/// or when no kernel is feasible (TM_CHECK presumably aborts, making those returns fallbacks).
int Gemm::Run(const Operation&    operation,
              float               alpha,
              const void*         A,
              const MatrixLayout& Adesc,
              const void*         U,
              const MatrixLayout& Udesc,
              const void*         B,
              const MatrixLayout& Bdesc,
              const void*         V,
              const MatrixLayout& Vdesc,
              float               beta,
              const void*         C,
              const MatrixLayout& Cdesc,
              void*               D,
              const MatrixLayout& Ddesc,
              const Workspace&    workspace,
              cudaStream_t        stream)
{

    // Per-call problem context built from the cached device properties.
    Context context{*impl_->props_};

    // `desc` evaluates to false when the layouts do not describe a valid problem.
    const auto desc = context.Init(operation, Adesc, Udesc, Bdesc, Vdesc, Cdesc, Ddesc);

    if (!desc) {
        fprintf(stderr, "invalid argument.\n");
        TM_CHECK(0);
        return -1;
    }

    // Everything is captured by value so the callable can be handed to the tuner (Measure).
    // `_workspace` is a local mutable copy of the (const) captured workspace, presumably
    // because Launch takes a non-const Workspace -- confirm.
    const auto launch = [=](LaunchSpec spec, cudaStream_t st) {
        auto _workspace = workspace;
        return spec.kernel->Launch(operation,
                                   alpha,
                                   A,
                                   Adesc,
                                   U,
                                   Udesc,
                                   B,
                                   Bdesc,
                                   V,
                                   Vdesc,
                                   beta,
                                   C,
                                   Cdesc,
                                   D,
                                   Ddesc,
                                   spec.swizzle,
                                   spec.splits,
                                   _workspace,
                                   st);
    };

    // Disabled: export every candidate spec as a launch closure through `operation.reserved`.
#if 0
    if (operation.reserved) {
        auto specs = impl_->Find(context, workspace.barriers_size, workspace.partials_size, 0);
        auto cases = (std::vector>*)operation.reserved;
        for (const auto& spec : specs) {
            cases->push_back([=] {
                launch(spec, stream);
                return spec;
            });
        }
        return -1;
    }
#endif

    LaunchSpec spec{};

    // Tune on demand. Measure's return value is ignored; Dispatch below runs regardless.
    if (operation.dispatch & DispatchPolicy::kMeasure) {
        impl_->Measure(context, workspace.barriers_size, workspace.partials_size, 1, launch, stream);
    }

    spec = impl_->Dispatch(context, operation.dispatch, workspace.barriers_size, workspace.partials_size);

    if (spec.kernel) {
        // std::cout << "[Gemm] dispatch: " << spec.kernel->name()  //
        //           << " split_k=" << spec.splits                  //
        //           << " swizzle=" << spec.swizzle << std::endl;
        return launch(spec, stream);
    }

    TM_CHECK(0) << "No feasible kernel found for the problem: " << to_string(context.desc());

    return -1;
}

/// Write the dispatch cache (problem descriptor -> tuned launch spec) to `os`.
/// Returns the status reported by the cache's Export.
int Gemm::Export(std::ostream& os)
{
    auto& cache = impl_->cache_;
    return cache.Export(os);
}

/// Read dispatch-cache entries previously produced by Export from `is`.
/// Returns the status reported by the cache's Import.
int Gemm::Import(std::istream& is)
{
    auto& cache = impl_->cache_;
    return cache.Import(is);
}

/// Return a copy of the `seq` field of the tuning parameters held by the implementation
/// (its meaning is defined by TuningParams -- presumably the problem sizes to tune; confirm).
std::vector Gemm::GetTuningSeq() const
{
    return impl_->tuning_.seq;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include 

#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

/// Public GEMM entry point. Uses the pimpl idiom: all state is behind `impl_`.
class Gemm {
public:
    // Default workspace sizes for barriers and partial results, presumably in bytes and used
    // by split-k kernels -- confirm against the Workspace allocation site.
    static constexpr size_t kBarriersSize = 1 << 20;
    static constexpr size_t kPartialsSize = 32 << 20;

    Gemm();

    ~Gemm();

    /// Select a kernel for the problem described by the layouts and launch it on `stream`.
    /// The returned status must be checked by the caller ([[nodiscard]]).
    [[nodiscard]] int Run(const Operation&    operation,
                          float               alpha,
                          const void*         A,
                          const MatrixLayout& Adesc,
                          const void*         U,
                          const MatrixLayout& Udesc,
                          const void*         B,
                          const MatrixLayout& Bdesc,
                          const void*         V,
                          const MatrixLayout& Vdesc,
                          float               beta,
                          const void*         C,
                          const MatrixLayout& Cdesc,
                          void*               D,
                          const MatrixLayout& Ddesc,
                          const Workspace&    workspace,
                          cudaStream_t        stream);

    /// Serialize the tuned dispatch cache to `os`.
    [[maybe_unused]] int Export(std::ostream& os);

    /// Load a previously exported dispatch cache from `is`.
    [[maybe_unused]] int Import(std::istream& is);

    /// Sequence stored in the tuning parameters.
    [[nodiscard]] std::vector GetTuningSeq() const;

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/math.h"

#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/thread_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

/// Operand descriptors passed to the kernel: main operands A and B, and auxiliary operands
/// U and V (presumably per-group scales for A and B -- confirm).
struct GemmParam {
    MatrixParam a;
    MatrixParam b;
    MatrixParam u;
    MatrixParam v;
};

/// Resolve an operand's MatrixParam into the MatrixData for GEMM group `gemm_id`
/// (thin forwarding wrapper around `resolve`).
template
__inline__ __device__ MatrixData resolve_op(const MatrixParam& param, int gemm_id)
{
    return resolve(param, gemm_id);
}

/// Generic GEMM kernel body composed from three policies:
///   - Mainloop: iterates the operands and accumulates into a register fragment,
///   - Scheduler_: assigns a (tile, k-range, group) to the CTA,
///   - Epilogue_: consumes the accumulator fragment.
template
struct GemmUniversal {

    // using Impl = typename Mainloop::Impl;
    using Impl = Mainloop;

    using Ta = typename Impl::Ta;
    using Tb = typename Impl::Tb;
    using Tu = typename Impl::Tu;
    using Tv = typename Impl::Tv;

    using Arch      = Arch_;
    using Scheduler = Scheduler_;
    using Epilogue  = Epilogue_;

    using Tc = typename Epilogue::Tc;

    // col major == M-major (A)
    // row major == N-major (B)
    static constexpr Order kOrderC = Epilogue::kOrder;

    static constexpr int CTA_M = Impl::CTA_M;
    static constexpr int CTA_N = Impl::CTA_N;
    static constexpr int CTA_K = Impl::CTA_K;

    // A scheduler with a non-negative group axis indicates grouped (dynamic) scheduling.
    static constexpr bool kDynamicSched = Scheduler::group_axis >= 0;
    static constexpr bool kSplitK       = Epilogue::SplitK;

    using FragC = typename Impl::FragC;

    static constexpr int WARP_CNT = Impl::WARPS;

    using OperandA = typename Mainloop::OperandA;
    using OperandB = typename Mainloop::OperandB;
    using OperandU = typename Mainloop::OperandU;
    using OperandV = typename Mainloop::OperandV;

    // K granularity: at least one CTA_K tile and at least one U/V quantization group.
    static constexpr int kChunkSizeK = std::max(CTA_K, std::max(OperandU::kGroupSize, OperandV::kGroupSize));

    static constexpr int kGSizeU = OperandU::kGroupSize;
    static constexpr int kGSizeV = OperandV::kGroupSize;

    // The mainloop and the epilogue run one after the other, so their shared memory is
    // overlapped in a union. The scheduler's storage is kept separate.
    struct SharedStorage {
        union {
            typename Mainloop::SharedStorage mainloop;
            typename Epilogue::SharedStorage epilogue;
        };
        typename Scheduler::SharedStorage sched;
    };

    static constexpr Order kOrderA = OperandA::kOrder;
    static constexpr Order kOrderB = OperandB::kOrder;
    static constexpr Order kOrderU = OperandU::kOrder;
    static constexpr Order kOrderV = OperandV::kOrder;

    static constexpr Pack kPackA = OperandA::kPack;
    static constexpr Pack kPackB = OperandB::kPack;

    using Param = GemmParam;

    /// Process one tile: query the scheduler, run the mainloop over the tile's k-range into
    /// `frag_C`, then pass the fragment to the epilogue. `smem_buf` must hold SharedStorage.
    __device__ void operator()(const Param& param, const EpilogueParam& epi_param, Scheduler& sched, char* smem_buf)
    {
        SharedStorage& storage = *reinterpret_cast(smem_buf);

        typename Scheduler::Tile tile;

        // No tile assigned -> nothing to do for this CTA.
        if (!sched.init(tile, storage.sched, std::false_type{})) {
            return;
        }

        const auto& [M, N, K] = tile.shape.__a;

        const auto tile_id = tile.tile_id;

        const int offset_m = tile_id[0] * CTA_M;
        const int offset_n = tile_id[1] * CTA_N;

        // tile.k_iters = {first k-iteration, number of k-iterations} for this tile.
        const int offset_k = tile.k_iters[0] * CTA_K;

        if (offset_m >= M || offset_n >= N || offset_k >= K) {  // empty tile
            return;
        }

        // Valid extents of a (possibly partial) boundary tile.
        const int extent_m = min(CTA_M, M - offset_m);
        const int extent_n = min(CTA_N, N - offset_n);

        // Is 8 enough?
        __align__(8) FragC frag_C{};

        int tile_iter = tile.k_iters[1];

        const int g = tile.group_id;

        const auto mat_A = resolve_op(param.a, g);
        const auto mat_B = resolve_op(param.b, g);
        const auto mat_U = resolve_op(param.u, g);
        const auto mat_V = resolve_op(param.v, g);

        typename OperandA::GmemIter gmem_A{mat_A, {offset_m, offset_k}, {extent_m, CTA_K}};
        typename OperandB::GmemIter gmem_B{mat_B, {offset_n, offset_k}, {extent_n, CTA_K}};

        // U/V are indexed in units of their K group size rather than elements of K.
        const int2 offset_U{offset_m, cdiv(offset_k, kGSizeU)}, extent_U{extent_m, cdiv(CTA_K, kGSizeU)};
        typename OperandU::GmemIter gmem_U{mat_U, offset_U, extent_U};

        const int2 offset_V{offset_n, cdiv(offset_k, kGSizeV)}, extent_V{extent_n, cdiv(CTA_K, kGSizeV)};
        typename OperandV::GmemIter gmem_V{mat_V, offset_V, extent_V};

        Mainloop mainloop{};
        mainloop(gmem_A, gmem_B, gmem_U, gmem_V, frag_C, tile_iter, storage.mainloop);

        {
            // Second scheduler query (std::true_type variant) before the epilogue; its
            // return value is not checked. Tag semantics are defined by the Scheduler -- confirm.
            sched.init(tile, storage.sched, std::true_type{});

            const auto [M, N, K] = tile.shape.__a;

            int4 tile_offset{tile.tile_id[0], tile.tile_id[1], tile.tile_id[2], tile.group_id};

            const int2 extents = {min(CTA_M, M - tile_offset.x * CTA_M), min(CTA_N, N - tile_offset.y * CTA_N)};

            // True when this tile's k-range ends exactly at K (the last k-split).
            const bool is_last = (tile.k_iters[0] + tile.k_iters[1]) * CTA_K == K;

            // sched.tiles_[2] is presumably the number of k-splits -- confirm in Scheduler.
            Epilogue epilogue{};
            epilogue(frag_C,  //
                     tile_offset,
                     extents,
                     sched.tiles_[2],
                     tile.linear_tile_id,
                     is_last,
                     epi_param,
                     storage.epilogue);
        }
    }
};

// Dynamic shared memory buffer; its size is set at launch time.
extern __shared__ char smem_buf[];

/// Kernel entry for GemmUniversal-style kernels. The body is compiled only in device passes
/// whose __CUDA_ARCH__ is accepted by `Kernel::Arch::is_compatible`. For other
/// architectures the kernel is an empty stub.
template
__global__ void gemm_kernel(Param param, EpilogueParam epi_param, Scheduler sched)
{
#if __CUDA_ARCH__
    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {
        Kernel kernel;
        kernel(param, epi_param, sched, smem_buf);
    }
#endif
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal_sm90.h
================================================
#pragma once

#include 

#include 

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
#include "cute/arch/mma_sm90_desc.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/iterator_sm70.h"
#include "src/turbomind/kernels/gemm/iterator_sm90.h"
#include "src/turbomind/kernels/gemm/scheduler.cuh"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

namespace GMMA = cute::SM90::GMMA;

/// Build a WGMMA shared-memory matrix descriptor for the tile starting at `smem_ptr`.
///
/// The start address is encoded in 16-byte units (`>> 4`). `layout_type` is written verbatim;
/// callers in this file pass 1, presumably CUTLASS's 128B-swizzle encoding -- confirm. The
/// stride byte offset is fixed at 1024 bytes; the leading byte offset and base offset are 0.
///
/// NOTE(review): gemm_universal_sm90_v2.h carries an identical inline definition in the same
/// namespace. Keep the two token-identical (ODR) or move the helper to a shared header.
inline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)
{
    auto uint_ptr = cast_smem_ptr_to_uint(smem_ptr);

    cute::GmmaDescriptor desc{};
    desc.bitfield.start_address_       = uint_ptr >> 4;
    desc.bitfield.layout_type_         = layout_type;
    desc.bitfield.leading_byte_offset_ = 0;
    desc.bitfield.stride_byte_offset_  = 1024 >> 4;
    desc.bitfield.base_offset_         = 0;

    return desc;
}

/// Moves a 64-bit WGMMA descriptor through a multi-stage shared-memory ring buffer.
///
/// Only the low 32-bit word (`u32_[0]`) is modified, and `base_` records its initial value.
/// Offsets are added directly to that word, so they are in the descriptor's 16-byte address
/// units (cf. the `>> 4` encoding in make_smem_desc). Access uses union type punning between
/// `u64_` and `u32_`, which NVCC supports but ISO C++ does not guarantee.
///
/// NOTE(review): gemm_universal_sm90_v2.h defines a different `SmemDescIterV2` (with an extra
/// Reset()) in the same namespace. Linking TUs that include both headers would be an ODR
/// violation -- confirm they are never mixed.
template
struct SmemDescIterV2 {
    union {
        uint32_t u32_[2];
        uint64_t u64_;
    };

    uint32_t base_;

    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}

    // Move to the next stage's buffer. After the last stage (stage == Stages - 1),
    // wrap back to the stage-0 base address.
    __device__ void Advance(int stage)
    {
        u32_[0] += Step;
        if (stage == Stages - 1) {
            u32_[0] = base_;
        }
    }

    __device__ SmemDescIterV2& operator+=(int offset)
    {
        u32_[0] += offset;
        return *this;
    }

    __device__ SmemDescIterV2& operator-=(int offset)
    {
        u32_[0] -= offset;
        return *this;
    }

    // Implicit conversion so the iterator can be passed wherever a raw descriptor is expected.
    __device__ operator uint64_t()
    {
        return u64_;
    }
};

/// Issue one WGMMA through `MMA_Atom::fma`, expanding `frag_C[Is]...` into the atom's
/// accumulator register arguments. `clear` selects ScaleOut::Zero, which overwrites the
/// accumulators. Otherwise ScaleOut::One is used and the result accumulates into `frag_C`.
template
inline __device__ void
wgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence)
{
    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);
}

/// Convenience wrapper: issue a WGMMA on an N-register accumulator array by generating the
/// index sequence 0..N-1 for wgmma_impl.
template
inline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)
{
    return wgmma_impl(desc_a, desc_b, frag_C, clear, std::make_index_sequence{});
}

/// Compiler-level fence on one accumulator register. The empty asm has `reg` as a read-write
/// ("+f") operand and a "memory" clobber, so the compiler treats `reg` as used and modified
/// here and cannot move accesses to it across this point. No instruction is emitted.
inline __device__ void warpgroup_fence_operand(float& reg)
{
    asm volatile("" : "+f"(reg)::"memory");
}

/// Apply the single-register compiler fence to every element of an M x N x K register array
/// (used on the WGMMA accumulator fragments around asynchronous MMA issue).
template
inline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])
{
    PRAGMA_UNROLL
    for (int m = 0; m < M; ++m) {
        PRAGMA_UNROLL
        for (int n = 0; n < N; ++n) {
            PRAGMA_UNROLL
            for (int k = 0; k < K; ++k) {
                warpgroup_fence_operand(x[m][n][k]);
            }
        }
    }
}

/// SM90 FP8 (e4m3 x e4m3 -> bf16) GEMM with block scaling, warp-specialized:
///   - warpgroup `WARPGORUPS` (the last one) is the producer; a single thread issues
///     tensor-map loads of A, B and V into a `Stages`-deep shared-memory ring,
///   - warpgroups 0..WARPGORUPS-1 are consumers; they issue WGMMA on each stage, rescale
///     the partial product by U (one float per 128 rows of M, read straight from global
///     memory) times V (per N column, from shared memory), and accumulate in registers,
///   - the epilogue stages the bf16 tile in shared memory and writes it with a TMA store.
/// Stages are handed off with cluster transaction barriers (producer -> consumer) and cluster
/// barriers (consumer -> producer). B is multicast across `kMulticastB` CTAs of a cluster.
///
/// NOTE(review): `tm_u`, `V_`, `ldV` and `storage.source.U` are not used by this version.
template
struct GemmUniversalSm90 {

    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS;
    using MMA_Atom = GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN<>;
    static constexpr typename cute::MMA_Traits::Shape_MNK MMA_Shape{};

    static constexpr int MMA_ATOM_M = cute::get<0>(MMA_Shape);
    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);
    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);

    // Consumer warpgroups are laid out 1 (M) x 2 (N); each owns one MMA_ATOM_N slice of the tile.
    static constexpr int kWorkGroupM = 1;
    static constexpr int kWorkGroupN = 2;

    static constexpr int CTA_M = 128;
    static constexpr int CTA_N = MMA_ATOM_N * kWorkGroupN;
    static constexpr int CTA_K = 128;

    // Number of consumer warpgroups (identifier spelling kept as-is).
    static constexpr int WARPGORUPS = kWorkGroupM * kWorkGroupN;

    static constexpr int MMA_M = MMA_ATOM_M * kWorkGroupM;
    static constexpr int MMA_N = MMA_ATOM_N * kWorkGroupN;
    static constexpr int MMA_K = MMA_ATOM_K;

    static constexpr int MMA_ITER_M = CTA_M / MMA_M;  // 2
    static constexpr int MMA_ITER_N = CTA_N / MMA_N;  // 1
    static constexpr int MMA_ITER_K = CTA_K / MMA_K;  // 4

    static constexpr int kMulticastA = 1;
    static constexpr int kMulticastB = 2;

    static constexpr int kClusterSize = kMulticastA * kMulticastB;

    static constexpr int Stages = 3;

    static constexpr bool kSplitK     = false;
    static constexpr int  kChunkSizeK = CTA_K;

    static constexpr int WARPGROUP_SIZE = 128;

    // Consumer warpgroups plus one producer warpgroup.
    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);

    using Ta = __nv_fp8_e4m3;
    using Tb = __nv_fp8_e4m3;
    using Tc = nv_bfloat16;

    using Tu = float;
    using Tv = float;

    using Arch      = Arch_;
    using Scheduler = TileScheduler;

    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;
    using ConsumerBar = cutlass::arch::ClusterBarrier;

    // Scale tile extents: U has one entry per 128 rows x 128 k, V one entry per column x 128 k.
    static constexpr int CTA_M_U = cdiv(CTA_M, 128);
    static constexpr int CTA_K_U = cdiv(CTA_K, 128);
    static constexpr int CTA_K_V = cdiv(CTA_K, 128);
    static constexpr int CTA_N_V = cdiv(CTA_N, 1);

    // Bytes the producer barrier expects per stage: A + B + V tiles (U is not loaded via TMA).
    static constexpr int kTmaTxBytes =
        sizeof(Ta) * (CTA_M * CTA_K) + sizeof(Tb) * (CTA_K * CTA_N) + sizeof(Tv) * CTA_N_V * CTA_K_V;

    struct SharedStorage {
        struct Source {
            __align__(128) Array A;
            __align__(128) Array B;
            __align__(128) Tu U[Stages][round_up(CTA_M_U * CTA_K_U, 32)];
            __align__(128) Tv V[Stages][round_up(CTA_N_V * CTA_K_V, 32)];  // (k1,n256)
        };
        Source source;
        // Output staging tile for the TMA store.
        __align__(128) Array C;
        // Per-consumer-warpgroup scratch for the combined U*V scale factors.
        __align__(128) float UV[WARPGORUPS][round_up(CTA_M_U * CTA_N_V, 32)];
        __align__(128) uint64_t producer_bar[Stages];
        __align__(128) uint64_t consumer_bar[Stages];
    };

    __device__ void operator()(const CUtensorMap& tm_a,
                               const CUtensorMap& tm_b,
                               const CUtensorMap& tm_c,
                               const CUtensorMap& tm_u,
                               const CUtensorMap& tm_v,
                               const void*        U_,
                               int                ldU,
                               const void*        V_,
                               int                ldV,
                               Scheduler          sched,
                               char*              smem_buf)
    {
        sched.grid_init();

        SharedStorage& storage = *reinterpret_cast(smem_buf);

        uint64_t* producer_bar = storage.producer_bar;
        uint64_t* consumer_bar = storage.consumer_bar;

        // One thread initializes the pipeline barriers. Producer barriers expect one arrival
        // (plus TMA bytes). Consumer barriers expect one arrival per consumer warpgroup of
        // every CTA in the cluster, presumably because multicast loads write into peer CTAs.
        if (threadIdx.x == 0) {
            PRAGMA_UNROLL
            for (int s = 0; s < Stages; ++s) {
                ProducerBar::init(&producer_bar[s], 1);
                ConsumerBar::init(&consumer_bar[s], kClusterSize * WARPGORUPS);
            }
            cutlass::arch::fence_view_async_shared();
            if constexpr (kClusterSize > 1) {
                cutlass::arch::fence_barrier_init();
            }
        }

        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();

        const int warpgroup_id = cutlass::canonical_warp_group_idx();

        if (warpgroup_id == WARPGORUPS) {
            // ---- producer warpgroup: donate registers to the consumers ----
            cutlass::arch::warpgroup_reg_dealloc<32>();

            static_assert(CTA_M % kMulticastA == 0);
            static_assert(CTA_N % kMulticastB == 0);

            const int cta_id = cute::block_id_in_cluster().x;

            // With multicast, each CTA of the cluster loads its own slice of the shared tile.
            const int mc_offset_m = kMulticastA > 1 ? cta_id * (CTA_M / kMulticastA) : 0;
            const int mc_offset_n = kMulticastB > 1 ? cta_id * (CTA_N / kMulticastB) : 0;

            auto  smem_A = storage.source.A.data() + mc_offset_m * CTA_K;
            auto  smem_B = storage.source.B.data() + mc_offset_n * CTA_K;
            auto& smem_U = storage.source.U;
            auto& smem_V = storage.source.V;

            // Only the first thread of the producer warpgroup issues loads.
            if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {
                // Presumably starts in phase 1 so the first waits on fresh consumer barriers pass.
                cutlass::PipelineState write_state{0, 1, 0};
                while (sched.next()) {
                    auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                    if (!cluster_tile_p) {
                        // OOB tile caused by swizzle pattern
                        continue;
                    }

                    const auto tile_offset              = sched.tile_offset();
                    const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                    const int offset_m = tile_offset.x * CTA_M;
                    const int offset_n = tile_offset.y * CTA_N;
                    // NOTE(review): loads always start at k = 0; iter_k_beg only affects the
                    // iteration count. Fine while kSplitK == false -- confirm.
                    const int offset_k = 0 * CTA_K;

                    int k_iter = iter_k_end - iter_k_beg;

                    GmemIteratorSm90 gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {CTA_K, 0}};
                    GmemIteratorSm90 gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {CTA_K, 0}};

                    // column-major
                    GmemIteratorSm90 gmem_V{&tm_v, {offset_n, offset_k / 128}, {0, 1}};

                    // auto gmem_U = (const Tu*)U_ + (offset_m / 128) * ldU + (offset_k / 128);
                    // auto step_U = 1;

                    // auto gmem_V = (const Tv*)V_ + offset_n + (offset_k / 128) * ldV;
                    // auto step_V = ldV;

                    // For each k-iteration: wait until the stage is released by all consumers,
                    // arm the producer barrier with the expected byte count, then issue loads.
                    while (k_iter > 0) {
                        int pipe = write_state.index();
                        ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                        ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);
                        gmem_A.Load(&producer_bar[pipe], &smem_A[pipe * CTA_M * CTA_K]);
                        // {
                        //     // printf("%f\n", *gmem_U);
                        // smem_U[pipe][0] = *gmem_U;
                        // gmem_U += step_U;
                        // }
                        gmem_B.Load(&producer_bar[pipe], &smem_B[pipe * CTA_N * CTA_K]);
                        gmem_V.Load(&producer_bar[pipe], &smem_V[pipe][0]);

                        ++write_state;
                        --k_iter;
                    }
                }
            }
        }
        else {
            // ---- consumer warpgroups: take the extra registers for the accumulators ----
            cutlass::arch::warpgroup_reg_alloc<232>();

            auto& smem_A  = storage.source.A;
            auto& smem_B  = storage.source.B;
            auto& smem_U  = storage.source.U;
            auto& smem_V  = storage.source.V;
            auto& smem_UV = storage.UV[warpgroup_id];

            const int warp_group_id_m = warpgroup_id % kWorkGroupM;
            const int warp_group_id_n = warpgroup_id / kWorkGroupM;

            auto smem_desc_A = make_smem_desc(&smem_A[warp_group_id_m * MMA_ATOM_M * CTA_K], 1);
            auto smem_desc_B = make_smem_desc(&smem_B[warp_group_id_n * MMA_ATOM_N * CTA_K], 1);

            SmemDescIterV2> 4)> smem_iter_A{smem_desc_A};
            SmemDescIterV2> 4)> smem_iter_B{smem_desc_B};

            // Descriptor offsets in 16-byte units for moving along M/N and K within a stage.
            constexpr int kStepMA = (sizeof(Ta) * MMA_M * CTA_K) >> 4;
            constexpr int kStepNB = (sizeof(Tb) * MMA_N * CTA_K) >> 4;
            constexpr int kStepKA = (sizeof(Ta) * MMA_K) >> 4;
            constexpr int kStepKB = (sizeof(Tb) * MMA_K) >> 4;

            cutlass::PipelineState read_state{};
            cutlass::PipelineState release_state{};

            while (sched.next()) {
                auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                if (!cluster_tile_p) {
                    // OOB tile caused by swizzle pattern
                    continue;
                }

                // frag_C: this stage's raw product; accum_C: scaled running sum over k.
                MMA_Atom::CRegisters frag_C[MMA_ITER_M][MMA_ITER_N];
                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};  /// TODO: check the z-fill is eliminated

                const auto tile_offset              = sched.tile_offset();
                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                const int offset_m = tile_offset.x * CTA_M;
                const int offset_n = tile_offset.y * CTA_N;
                const int offset_k = 0;

                // U is read directly from global memory, one float per 128-wide k block.
                auto gmem_U = (const Tu*)U_ + (offset_m / 128) * ldU + (offset_k / 128);
                auto step_U = 1;

                int k_iter = iter_k_end - iter_k_beg;

                // Issue all WGMMAs for the current stage. `k == 0` clears frag_C, so it holds
                // only this stage's product. Then commit the batch and advance the descriptors
                // to the next stage.
                auto tile_gemm = [&] {
                    PRAGMA_UNROLL
                    for (int k = 0; k < MMA_ITER_K; ++k) {
                        PRAGMA_UNROLL
                        for (int m = 0; m < MMA_ITER_M; ++m) {
                            PRAGMA_UNROLL
                            for (int n = 0; n < MMA_ITER_N; ++n) {
                                wgmma(smem_iter_A, smem_iter_B, frag_C[m][n], k == 0);
                                smem_iter_B += kStepNB;
                            }
                            smem_iter_B -= MMA_ITER_N * kStepNB;
                            smem_iter_A += kStepMA;
                        }
                        smem_iter_A += kStepKA - MMA_ITER_M * kStepMA;
                        smem_iter_B += kStepKB;
                    }
                    smem_iter_A -= MMA_ITER_K * kStepKA;
                    smem_iter_B -= MMA_ITER_K * kStepKB;
                    cute::warpgroup_commit_batch();

                    smem_iter_A.Advance(read_state.index());
                    smem_iter_B.Advance(read_state.index());
                };

                // Release the current stage. In a cluster, lanes < kClusterSize each signal one
                // CTA's consumer barrier (presumably CTA id = lane). Otherwise one thread per
                // warpgroup arrives locally.
                auto consumer_arrive = [&] {
                    if constexpr (kClusterSize > 1) {
                        ConsumerBar::arrive(&consumer_bar[release_state.index()],
                                            threadIdx.x % WARPGROUP_SIZE,
                                            threadIdx.x % WARPGROUP_SIZE < kClusterSize);
                    }
                    else {
                        if (threadIdx.x % WARPGROUP_SIZE == 0) {
                            ConsumerBar::arrive(&consumer_bar[release_state.index()]);
                        }
                    }
                };

                if constexpr (kClusterSize > 1) {
                    if (!cta_tile_p) {
                        // other CTAs in the cluster are still alive
                        // Drain the stages without computing so cluster-wide barriers stay balanced.
                        for (; k_iter > 0; --k_iter) {
                            ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());
                            consumer_arrive();
                            smem_iter_A.Advance(read_state.index());
                            smem_iter_B.Advance(read_state.index());
                            ++read_state;
                            ++release_state;
                        }
                        continue;
                    }
                }

                float scale_U{};
                auto  Load_U = [&] {
                    scale_U = *gmem_U;
                    gmem_U += step_U;
                };

                // Combine scale_U with this warpgroup's V slice in smem_UV. Then wait for the
                // WGMMAs and accumulate accum_C += frag_C * (U * V).
                // NOTE(review): nothing separates this iteration's smem_UV reads from the next
                // iteration's smem_UV writes by other warps of the warpgroup -- possible race; confirm.
                auto scale_accum = [&]() {  // cta_n = mma_iter_n * wg_n * mma_atom_n
                    // auto scale_U = smem_U[read_state.index()][0];

                    PRAGMA_UNROLL
                    for (int i = threadIdx.x % WARPGROUP_SIZE; i < MMA_ATOM_N; i += WARPGROUP_SIZE) {
                        smem_UV[i] = scale_U * smem_V[read_state.index()][i + warp_group_id_n * MMA_ATOM_N];
                    }
                    cute::warpgroup_wait<0>();

                    const int lane_id = threadIdx.x % WARP_SIZE;

                    cutlass::arch::NamedBarrier(WARPGROUP_SIZE, warpgroup_id + 1).sync();
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {
                        PRAGMA_UNROLL
                        for (int c = 0; c < MMA_ATOM_N; c += 8) {
                            Array scale_Vs;
                            int             idx = n * MMA_N + c + (lane_id & 3) * 2;
                            Load(scale_Vs, &smem_UV[idx]);
                            PRAGMA_UNROLL
                            for (int m = 0; m < MMA_ITER_M; ++m) {
                                accum_C[m][n][c / 2 + 0] += frag_C[m][n][c / 2 + 0] * scale_Vs[0];
                                accum_C[m][n][c / 2 + 1] += frag_C[m][n][c / 2 + 1] * scale_Vs[1];
                                accum_C[m][n][c / 2 + 2] += frag_C[m][n][c / 2 + 2] * scale_Vs[0];
                                accum_C[m][n][c / 2 + 3] += frag_C[m][n][c / 2 + 3] * scale_Vs[1];
                            }
                        }
                    }
                };

                // First k-iteration (peeled), then the remaining ones; both do the same steps.
                Load_U();
                ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());
                cute::warpgroup_arrive();
                warpgroup_fence_operand(frag_C);
                tile_gemm();
                warpgroup_fence_operand(frag_C);
                scale_accum();
                consumer_arrive();
                --k_iter;
                ++read_state;
                ++release_state;

                while (k_iter > 0) {
                    Load_U();
                    ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());
                    cute::warpgroup_arrive();
                    warpgroup_fence_operand(frag_C);
                    tile_gemm();
                    warpgroup_fence_operand(frag_C);
                    scale_accum();
                    consumer_arrive();
                    --k_iter;
                    ++read_state;
                    ++release_state;
                }

                // Make sure the previous tile's TMA store has finished with smem_C before reuse.
                if (threadIdx.x == 0) {
                    cute::tma_store_wait<0>();
                }

                cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE).sync();

                // epilogue
                const int warp_id = threadIdx.x / WARP_SIZE;
                const int lane_id = threadIdx.x % WARP_SIZE;

                auto& smem_C = storage.C;

                // (M,N):(1,M)
                PRAGMA_UNROLL
                for (int m = 0; m < MMA_ITER_M; ++m) {
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {
                        PRAGMA_UNROLL
                        for (int i = 0; i < MMA_ATOM_N; i += 16) {
                            // clang-format off
                            // const int mm   = m * MMA_M + warp_id * 16 + (lane_id & 8);
                            // const int nn   = n * MMA_N +     i        + (lane_id & 7) + (lane_id & 16) / 2;
                            const int mm   = m * MMA_M + (warp_id & 3) * 16 + (lane_id & 8);
                            const int nn   = n * MMA_N + warp_group_id_n * MMA_ATOM_N + i + (lane_id & 7) + (lane_id & 16) / 2;
                            // clang-format on
                            // Convert 8 fp32 accumulators to Tc and write them with a transposing stmatrix.
                            __align__(16) Array tvec = cast(*(Array*)&accum_C[m][n][i / 2]);
                            cute::SM90_U16x8_STSM_T::copy((uint32_t&)tvec[0],
                                                          (uint32_t&)tvec[2],
                                                          (uint32_t&)tvec[4],
                                                          (uint32_t&)tvec[6],
                                                          (cutlass::uint128_t&)smem_C[nn * CTA_M + mm]);
                        }
                    }
                }
                cute::tma_store_fence();  // visibility: smem -> async proxy
                cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE).sync();

                if (threadIdx.x == 0) {
                    cute::SM90_TMA_STORE_2D::copy(&tm_c, &smem_C, offset_m, offset_n);
                    cute::tma_store_arrive();
                }
            }  // scheduler loop

            // Wait for the last outstanding TMA store before leaving.
            if (threadIdx.x == 0) {
                cute::tma_store_wait<0>();
            }
        }

        // Cluster-wide rendezvous before exit, presumably so no CTA leaves while peers may
        // still signal its barriers or write its shared memory.
        cute::cluster_arrive();
        cute::cluster_wait();

    }  // operator()
};

// Dynamic shared memory buffer; its size is set at launch time.
extern __shared__ char smem_buf[];

/// Kernel entry for GemmUniversalSm90-style kernels. The body is compiled only for device
/// architectures accepted by `Kernel::Arch::is_compatible`. `__launch_bounds__` declares
/// `Kernel::CTA_SIZE` threads per block and a minimum of 1 resident block per SM. Tensor maps
/// are `__grid_constant__` parameters so their addresses can be given to TMA instructions.
template
__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_sm90(const __grid_constant__ CUtensorMap tm_a,
                                                                        const __grid_constant__ CUtensorMap tm_b,
                                                                        const __grid_constant__ CUtensorMap tm_c,
                                                                        const __grid_constant__ CUtensorMap tm_u,
                                                                        const __grid_constant__ CUtensorMap tm_v,
                                                                        const void*                         U_,
                                                                        int                                 ldU,
                                                                        const void*                         V_,
                                                                        int                                 ldV,
                                                                        typename Kernel::Scheduler          sched)
{
#if __CUDA_ARCH__
    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {
        Kernel kernel;
        kernel(tm_a, tm_b, tm_c, tm_u, tm_v, U_, ldU, V_, ldV, sched, smem_buf);
    }
#endif
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal_sm90_v2.h
================================================
#pragma once

#include 
#include 

#include 

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
#include "cute/arch/mma_sm90_desc.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/smem.h"

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/iterator_sm90.h"
#include "src/turbomind/kernels/gemm/scheduler.cuh"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Short alias for CuTe's SM90 warpgroup-MMA (wgmma) namespace (MMA atoms, ScaleOut, ...).
namespace GMMA = cute::SM90::GMMA;

inline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)
{
    // Builds a wgmma shared-memory matrix descriptor for the operand tile at `smem_ptr`:
    //   start address : shared-space byte address of the tile, in 16-byte units
    //   layout type   : stored verbatim (callers in this file pass 1 -- presumably the
    //                   128B-swizzle encoding; confirm against cute::GMMA::LayoutType)
    //   LBO           : 0
    //   SBO           : 1024 bytes, in 16-byte units
    //   base offset   : 0
    constexpr int kStrideByteOffset = 1024;

    const auto smem_addr = cast_smem_ptr_to_uint(smem_ptr);

    cute::GmmaDescriptor result{};
    result.bitfield.base_offset_         = 0;
    result.bitfield.leading_byte_offset_ = 0;
    result.bitfield.stride_byte_offset_  = kStrideByteOffset >> 4;
    result.bitfield.layout_type_         = layout_type;
    result.bitfield.start_address_       = smem_addr >> 4;
    return result;
}

template
struct SmemDescIterV2 {
    // The 64-bit wgmma descriptor viewed as two 32-bit words. Only the low word (which holds
    // the start-address field) is ever modified by this iterator.
    union {
        uint32_t u32_[2];
        uint64_t u64_;
    };

    // Low descriptor word as passed to the constructor, i.e. the stage-0 position.
    uint32_t base_;

    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}

    // Step to the next pipeline stage; after the last stage (Stages - 1) wrap back to stage 0.
    __device__ void Advance(int stage)
    {
        u32_[0] = (stage == Stages - 1) ? base_ : u32_[0] + Step;
    }

    // Jump directly to `stage` (0-based), measured from the stage-0 position.
    __device__ void Reset(int stage)
    {
        const uint32_t stage_offset = stage * Step;
        u32_[0]                     = base_ + stage_offset;
    }

    // Raw adjustment of the low descriptor word (start address is in 16-byte units).
    __device__ SmemDescIterV2& operator+=(int offset)
    {
        u32_[0] = u32_[0] + offset;
        return *this;
    }

    __device__ SmemDescIterV2& operator-=(int offset)
    {
        u32_[0] = u32_[0] - offset;
        return *this;
    }

    // Yields the full 64-bit descriptor for use as a wgmma operand.
    __device__ operator uint64_t()
    {
        const uint64_t full_desc = u64_;
        return full_desc;
    }
};

// Issues a single wgmma through MMA_Atom::fma, expanding `frag_C` into one argument per
// accumulator register via the index pack `Is...`.
//   clear == true  -> GMMA::ScaleOut::Zero: existing accumulator contents are not added (D = A*B)
//   clear == false -> GMMA::ScaleOut::One : accumulate (D = A*B + D)
// NOTE(review): the template parameter list (presumably MMA_Atom and size_t... Is) and the
// index_sequence's template arguments are not visible in this view.
template
inline __device__ void
wgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence)
{
    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);
}

// Convenience wrapper around wgmma_impl: takes the N-register accumulator fragment by array
// reference and generates the 0..N-1 index sequence used to expand it.
// desc_a / desc_b are shared-memory matrix descriptors (see make_smem_desc / SmemDescIterV2).
// NOTE(review): template parameters/arguments (presumably MMA_Atom and N) are not visible here.
template
inline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)
{
    return wgmma_impl(desc_a, desc_b, frag_C, clear, std::make_index_sequence{});
}

// Compiler-level fence on one float register: the empty asm reads and writes `reg` ("+f")
// and clobbers memory, so the value must live in a register at this point and the compiler
// may not move accesses to it across the statement. No instruction is emitted.
// Presumably meant to pin accumulators around asynchronous wgmma (cf. cute::warpgroup_fence_operand).
inline __device__ void warpgroup_fence_operand(float& reg)
{
    asm volatile("" : "+f"(reg)::"memory");
}

// Applies the scalar register fence to every element of an M x N x K float array.
// The loops must stay fully unrolled (compile-time bounds + PRAGMA_UNROLL); a rolled loop
// would index the array dynamically and could force it out of registers.
// Note: this overload is declared before the 2-D overload, so it cannot delegate to it from
// inside this template (no ADL for float arrays); it calls the scalar overload directly.
// NOTE(review): the template parameter list (presumably int M, int N, int K) is not visible here.
template
inline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])
{
    PRAGMA_UNROLL
    for (int m = 0; m < M; ++m) {
        PRAGMA_UNROLL
        for (int n = 0; n < N; ++n) {
            PRAGMA_UNROLL
            for (int k = 0; k < K; ++k) {
                warpgroup_fence_operand(x[m][n][k]);
            }
        }
    }
}

// Applies the scalar register fence to every element of an N x K float array.
// Fully unrolled for the same reason as the 3-D overload: the elements must remain registers.
// NOTE(review): the template parameter list (presumably int N, int K) is not visible here.
template
inline __device__ void warpgroup_fence_operand(float (&x)[N][K])
{
    PRAGMA_UNROLL
    for (int n = 0; n < N; ++n) {
        PRAGMA_UNROLL
        for (int k = 0; k < K; ++k) {
            warpgroup_fence_operand(x[n][k]);
        }
    }
}

// Compile-time loop: calls `func` once per index of the sequence, in order, via a fold over
// the comma operator, passing each index wrapped in a `constant` tag object.
// NOTE(review): the template parameters (presumably size_t... Is and Func) and the template
// arguments of `constant` (presumably `constant<Is>`) are not visible in this view.
template
__device__ void for_(std::index_sequence, Func func)
{
    return (func(constant{}), ...);
}

namespace arch {

// Compile-time shape of a thread-block cluster of M x N CTAs, plus helpers that map a CTA's
// linear rank in the cluster to its (m, n) coordinate and to bitmasks of the CTAs that share
// its m or n coordinate (used as TMA multicast masks by GemmUniversalSm90_v2).
//
// The rank is split as rank = c + s * C, where c is the "contiguous" and s the "strided"
// coordinate; `order` decides which of them corresponds to m.
// NOTE(review): the template parameter list (presumably M_, N_ and `order`) and the template
// arguments of mk2cs / cs2mk are not visible in this view.
template
struct Cluster {
    static constexpr int M = M_;
    static constexpr int N = N_;

    // Cluster extents in (contiguous, strided) order; presumably mk2cs reorders (M, N)
    // according to `order` -- confirm against its definition.
    static constexpr int C = mk2cs(M, N).x;
    static constexpr int S = mk2cs(M, N).y;

    static constexpr int size = M * N;

    // kMaskC: C consecutive rank bits -> all CTAs with one strided index (shift by s * C).
    // kMaskS: one bit every C ranks   -> all CTAs with one contiguous index (shift by c);
    //         (2^size - 1) / (2^C - 1) == 1 + 2^C + 2^(2C) + ...
    static constexpr uint16_t kMaskC = (1 << C) - 1;
    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;

    // .x: ranks sharing this CTA's contiguous coordinate c; .y: ranks sharing its strided coordinate s.
    __device__ static ushort2 mask_cs(int cta_id)
    {
        const auto [c, s] = cta_cs(cta_id);
        return make_ushort2(kMaskS << c, kMaskC << s * C);
    }

    // The same two masks reordered to (m, n) according to `order`.
    __device__ static ushort2 mask_mn(int cta_id)
    {
        auto [c, s] = mask_cs(cta_id);
        return order == kColMajor ? ushort2{c, s} : ushort2{s, c};
    }

    // (c, s) coordinate of a rank; degenerate dimensions of size 1 are folded to 0.
    __device__ static int2 cta_cs(int cta_id)
    {
        return {C > 1 ? cta_id % C : 0, S > 1 ? cta_id / C : 0};
    }

    // (m, n) coordinate of a rank (presumably cs2mk is the inverse reordering of mk2cs).
    __device__ static int2 cta_mn(int cta_id)
    {
        return cs2mk(cta_cs(cta_id));
    }

    // Values cached for the CTA this object was constructed for.
    int2    cta_mn_;
    ushort2 mask_mn_;

    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}

    __device__ int cta_m()
    {
        return cta_mn_.x;
    }

    __device__ int cta_n()
    {
        return cta_mn_.y;
    }

    // Multicast mask used for A loads by the kernel below.
    __device__ uint16_t mask_m()
    {
        return mask_mn_.x;
    }

    // Multicast mask used for B loads by the kernel below.
    __device__ uint16_t mask_n()
    {
        return mask_mn_.y;
    }
};

}  // namespace arch

// Warp-specialized FP8 GEMM for SM90.
//
//   A, B : __nv_fp8_e4m3 operands, loaded by TMA into a `Stages`-deep shared-memory ring.
//   C    : nv_bfloat16 output, staged in shared memory and written with TMA stores.
//   U    : float scales for A, one per row of A per 128 elements of K (TMA-loaded with A/B).
//   V    : float scales for B, one per 128 columns of N per 128 elements of K (copied from
//          global memory into shared memory with __ldg at the start of every tile).
//
// Thread roles (CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1)):
//   - warpgroup WARPGORUPS (the last one) is the producer; a single thread issues all TMA loads;
//   - warpgroups 0 .. WARPGORUPS-1 are math warpgroups. Each computes a full TILE_M x TILE_N
//     tile. They take turns on the main loop (ping-pong), handing off through named barriers
//     (`math_barrier_sync`) and the shared `pipe_count` slots.
struct GemmUniversalSm90_v2 {

    static constexpr bool kDebug = false;

    using Arch = Sm90;

    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS;
    using MMA_Atom = GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>;
    static constexpr typename cute::MMA_Traits::Shape_MNK MMA_Shape{};

    static constexpr int MMA_ATOM_M = cute::get<0>(MMA_Shape);
    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);
    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);

    // Number of math warpgroups (identifier spelling kept for compatibility).
    static constexpr int WARPGORUPS = 2;

    static constexpr int TILE_M = 128;
    static constexpr int TILE_N = MMA_ATOM_N;
    static constexpr int TILE_K = 128;

    static constexpr int MMA_ITER_M = TILE_M / MMA_ATOM_M;
    static constexpr int MMA_ITER_N = TILE_N / MMA_ATOM_N;
    static constexpr int MMA_ITER_K = TILE_K / MMA_ATOM_K;

    // Number of CTAs in the cluster that share (multicast) one A / B tile.
    static constexpr int kMulticastA = 1;
    static constexpr int kMulticastB = 2;

    static constexpr int kClusterSize = kMulticastA * kMulticastB;

    static constexpr int Stages = 4;

    static constexpr bool kSplitK     = false;
    static constexpr int  kChunkSizeK = TILE_K;

    static constexpr int WARPGROUP_SIZE = 128;

    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);

    using Ta = __nv_fp8_e4m3;
    using Tb = __nv_fp8_e4m3;
    using Tc = nv_bfloat16;

    using Tu = float;
    using Tv = float;

    using Cluster = arch::Cluster;

    using Scheduler = TileScheduler;

    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;
    using ConsumerBar = cutlass::arch::ClusterBarrier;

    // Capacity of the per-warpgroup V buffers below (cdiv(MAX_K, 128) entries each).
    // NOTE(review): nothing in this struct checks K <= MAX_K, while Copy() in the math loop
    // writes cdiv(K, 128) entries -- presumably enforced by the host side; confirm.
    static constexpr int MAX_K = 32768;

    static constexpr int TILE_M_U = cdiv(TILE_M, 1);
    static constexpr int CTA_K_U  = cdiv(TILE_K, 128);

    // Bytes expected per pipeline stage on the producer (transaction) barrier:
    // one A tile + one B tile + the U scales of the stage.
    static constexpr int kTmaTxBytes =
        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * TILE_M_U * CTA_K_U;

    // ! Smem addr must be SBO aligned for TMA load/store
    struct SharedStorage {
        struct Source {
            __align__(1024) Array A;
            __align__(1024) Array B;
            __align__(1024) Tu U[Stages][round_up(TILE_M_U * CTA_K_U, 32)];
            // [0]: scales of the current 128-column block of N, [1]: of the next block (for MMA
            // atoms that straddle a block boundary); one row per math warpgroup.
            __align__(1024) Tv V[2][WARPGORUPS][cdiv(MAX_K, 128)];
        };
        Source source;
        // Epilogue staging buffer for the bf16 output tile (shared by both math warpgroups;
        // the ping-pong barriers serialize their epilogues).
        __align__(1024) Array C;
        __align__(128) uint64_t producer_bar[Stages];
        __align__(128) uint64_t consumer_bar[Stages];
        // Pipeline position (PipelineState::count) each math warpgroup reached after its last
        // tile; the other warpgroup resumes from it.
        int pipe_count[WARPGORUPS];
    };

    static constexpr int kSmemSize = sizeof(SharedStorage);

    static constexpr int kSwizzleC = 2 * std::gcd(TILE_N, 128 / sizeof(Tc));

    using LayoutC = std::conditional_t= 32,
                                       SmemLayoutV2,
                                       SmemLayoutV2>;

    __device__ void operator()(const CUtensorMap& tm_a,
                               const CUtensorMap& tm_b,
                               const CUtensorMap& tm_c,
                               const CUtensorMap& tm_u,
                               const CUtensorMap& tm_v,
                               const void*        U_,
                               int                ldU,
                               const void*        V_,
                               int                ldV,
                               Scheduler          sched,
                               char*              smem_buf)
    {
        SharedStorage& storage = *reinterpret_cast(smem_buf);

        uint64_t* producer_bar = storage.producer_bar;
        uint64_t* consumer_bar = storage.consumer_bar;

        // One thread initializes the pipeline barriers and the hand-off counters.
        // Producer barrier: 1 arrival (the TMA-issuing thread) + transaction bytes.
        // Consumer barrier: 4 warps of the consuming warpgroup in each of the kClusterSize CTAs.
        if (threadIdx.x == 0) {
            PRAGMA_UNROLL
            for (int s = 0; s < Stages; ++s) {
                ProducerBar::init(&producer_bar[s], 1);
                ConsumerBar::init(&consumer_bar[s], kClusterSize * 4);
            }
            cutlass::arch::fence_view_async_shared();
            if constexpr (kClusterSize > 1) {
                cutlass::arch::fence_barrier_init();
            }
            PRAGMA_UNROLL
            for (int i = 0; i < WARPGORUPS; ++i) {
                storage.pipe_count[i] = 0;
            }
        }

        // Barrier init must be visible cluster-wide before any remote arrive / multicast.
        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();

        const int warpgroup_id = cutlass::canonical_warp_group_idx();

        if (warpgroup_id == WARPGORUPS) {
            // ---------------- producer warpgroup ----------------
            cutlass::arch::warpgroup_reg_dealloc<40>();

            static_assert(TILE_M % kMulticastA == 0);
            static_assert(TILE_N % kMulticastB == 0);

            if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {

                Cluster cluster(cute::block_id_in_cluster().x);

                // Each CTA loads only its 1/kMulticast slice of the shared tiles and multicasts it.
                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);
                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);

                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;
                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;
                auto& smem_U = storage.source.U;

                sched.grid_init();

                // Producer starts in phase 1 so the first wait on each empty stage passes.
                cutlass::PipelineState write_state{0, 1, 0};

                while (sched.next()) {
                    auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                    if (!cluster_tile_p) {
                        // OOB tile caused by swizzle pattern
                        continue;
                    }

                    const auto tile_offset              = sched.tile_offset();
                    const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                    const int offset_k = iter_k_beg * TILE_K;

                    const uint16_t mask_A = cluster.mask_m();
                    const uint16_t mask_B = cluster.mask_n();

                    const int offset_m = tile_offset.x * TILE_M;
                    const int offset_n = tile_offset.y * TILE_N;

                    int k_iter = iter_k_end - iter_k_beg;

                    GmemIteratorSm90 gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};
                    GmemIteratorSm90 gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};

                    // column-major
                    GmemIteratorSm90 gmem_U{&tm_u, {offset_m + mc_offset_m, offset_k / 128}, {0, 1}};

                    for (; k_iter > 0; --k_iter) {
                        int pipe = write_state.index();
                        // Wait until the stage has been released by consumers in all cluster CTAs.
                        ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                        ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);
                        gmem_A.Load(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);
                        gmem_B.Load(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);
                        gmem_U.Load(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);
                        ++write_state;
                    }
                }
            }
        }
        else {
            // ---------------- math warpgroups ----------------
            cutlass::arch::warpgroup_reg_alloc<232>();

            sched.grid_init(WARPGORUPS);

            auto& smem_A = storage.source.A;
            auto& smem_B = storage.source.B;
            auto& smem_U = storage.source.U;

            auto smem_desc_A = make_smem_desc(&smem_A, 1);
            auto smem_desc_B = make_smem_desc(&smem_B, 1);

            SmemDescIterV2> 4)> smem_iter_A{smem_desc_A};
            SmemDescIterV2> 4)> smem_iter_B{smem_desc_B};

            // Descriptor start-address increments (16-byte units) between MMA atoms.
            constexpr int kStepMA = (sizeof(Ta) * MMA_ATOM_M * TILE_K) >> 4;
            constexpr int kStepNB = (sizeof(Tb) * MMA_ATOM_N * TILE_K) >> 4;
            constexpr int kStepKA = (sizeof(Ta) * MMA_ATOM_K) >> 4;
            constexpr int kStepKB = (sizeof(Tb) * MMA_ATOM_K) >> 4;

            // Named barrier over both math warpgroups (2 * WARPGROUP_SIZE threads) that also
            // OR-reduces `alive`; returns 1 if any participating thread passed alive != 0.
            // Barrier id = base + (warpgroup_id ^ phase): (wg0, phase 0) pairs with
            // (wg1, phase 1) and (wg0, phase 1) with (wg1, phase 0), which implements the
            // ping-pong hand-off of the main loop between the two warpgroups.
            auto math_barrier_sync = [&](int phase, int alive = 1) {
                constexpr int base    = (int)cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;
                constexpr int threads = WARPGORUPS * WARPGROUP_SIZE;
                int           res;
                // Parenthesized explicitly: `+` binds tighter than `^`, and without the
                // parentheses the id is only correct when `base` happens to be even.
                asm volatile("{\n"
                             "  .reg.pred p;\n"
                             "  setp.ne.b32 p, %3, 0;\n"
                             "  barrier.cta.red.or.pred p, %1, %2, p;\n"
                             "  selp.s32 %0, 1, 0, p;\n"
                             "}\n"
                             : "=r"(res)
                             : "r"(base + (warpgroup_id ^ phase)), "r"(threads), "r"(alive));
                return res;
            };

            // Per-warpgroup barrier for the epilogue (ids 2 and 3).
            cutlass::arch::NamedBarrier wg_barrier(WARPGROUP_SIZE, warpgroup_id + 2);  // 2,3

            // Stagger the two warpgroups by one tile; each then strides by WARPGORUPS tiles.
            sched.next(warpgroup_id);

            // Warpgroup 1 pre-arrives so that warpgroup 0 runs the first main loop.
            if (warpgroup_id == 1) {
                math_barrier_sync(1);
            }

            while (sched.next(WARPGORUPS)) {
                auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                if (!cluster_tile_p) {
                    // OOB tile caused by swizzle pattern
                    // NOTE(review): this skip does not take part in the ping-pong barriers. If
                    // the scheduler can yield such tiles before the last valid one, the other
                    // warpgroup's next tile could run out of order w.r.t. the producer's load
                    // order -- confirm the scheduler only emits them at the tail.
                    continue;
                }

                MMA_Atom::CRegisters frag_C[MMA_ITER_M][MMA_ITER_N];
                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};

                const auto tile_offset              = sched.tile_offset();
                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                const auto [M, N, K, L] = sched.gemm_shape();

                const int offset_m = tile_offset.x * TILE_M;
                const int offset_n = tile_offset.y * TILE_N;
                const int offset_k = 0;

                int k_iter = iter_k_end - iter_k_beg;

                const int warp_id = threadIdx.x / WARP_SIZE;
                const int lane_id = threadIdx.x % WARP_SIZE;

                const int wg_lane = threadIdx.x % WARPGROUP_SIZE;

                cutlass::PipelineState pipe_state{};

                // Release the current stage: lanes 0..kClusterSize-1 of each warp arrive on the
                // stage's consumer barrier in the CTA whose cluster rank equals the lane id.
                auto consumer_arrive = [&] {
                    __syncwarp();
                    if constexpr (kClusterSize > 1) {
                        ConsumerBar::arrive(&consumer_bar[pipe_state.index()], lane_id, lane_id < kClusterSize);
                    }
                    else {
                        if (lane_id == 0) {
                            ConsumerBar::arrive(&consumer_bar[pipe_state.index()]);
                        }
                    }
                };

                if constexpr (kClusterSize > 1) {
                    if (!cta_tile_p) {  // other CTAs in the cluster are still alive
                        // This CTA's tile is out of bounds, but the producer still loads and
                        // multicasts for the cluster: drain the stages so peers are not blocked.
                        math_barrier_sync(0);
                        pipe_state.advance(storage.pipe_count[warpgroup_id ^ 1]);
                        for (; k_iter > 0; --k_iter) {
                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                            consumer_arrive();
                            ++pipe_state;
                        }
                        if (wg_lane == 0) {
                            storage.pipe_count[warpgroup_id] = pipe_state.count();
                        }
                        math_barrier_sync(1);
                        continue;
                    }
                }

                // Copy one row of V scales (cdiv(K, 128) entries) from global to shared memory,
                // strided over the warpgroup's threads.
                auto Copy = [k = cdiv(K, 128)](Tv* dst, const Tv* src) {
                    for (int i = threadIdx.x % WARPGROUP_SIZE; i < k; i += WARPGROUP_SIZE) {
                        dst[i] = __ldg(&src[i]);
                    }
                };
                auto gmem_V = (const Tv*)V_ + (offset_n / 128) * ldV + (offset_k / 128);
                Copy(storage.source.V[0][warpgroup_id], gmem_V);

                // Bit j of pred_V set -> columns [j * OUTER_N, (j + 1) * OUTER_N) of the MMA atom
                // belong to the next 128-column scale block and must use scale_V[1].
                uint32_t pred_V{};
                int      iter_V{};

                constexpr int OUTER_N = std::gcd(MMA_ATOM_N, 128);
                if constexpr (OUTER_N != 128) {

                    static_assert(MMA_ATOM_N <= 128 + OUTER_N, "MMA inst is crossing more than 2 scale blocks");

                    // One bit per OUTER_N-wide column slice of an MMA atom. scale_accum only
                    // tests bits c0 / OUTER_N with c0 < MMA_ATOM_N, so the mask is exactly
                    // MMA_ATOM_N / OUTER_N bits wide; a wider mask would leave a stray high bit
                    // set and make pred_V non-zero (triggering the V[1] copy below) even when no
                    // column crosses into the next scale block.
                    constexpr uint32_t mask = (1U << (MMA_ATOM_N / OUTER_N)) - 1;

                    // Columns from `phase` onward fall into the next 128-column block.
                    int phase = 128 - offset_n % 128;
                    pred_V    = (mask << (phase / OUTER_N)) & mask;

                    if (pred_V && offset_n / 128 + 1 < cdiv(N, 128)) {
                        Copy(storage.source.V[1][warpgroup_id], gmem_V + ldV);
                    }

                    // if constexpr (kWorkGroupN > 1) {
                    //     constexpr int tiles = MMA_ATOM_N / OUTER_N;
                    //     pred_V              = (pred_V >> (warp_group_id_n * tiles)) & ((1 << tiles) - 1);
                    // }
                }

                // Load the V scale(s) of the next K block from shared memory.
                float scale_V[2];
                auto  Load_V = [&] {
                    scale_V[0] = storage.source.V[0][warpgroup_id][iter_V];
                    if (pred_V) {
                        scale_V[1] = storage.source.V[1][warpgroup_id][iter_V];
                    }
                    ++iter_V;
                };

                // Load the U scales of the two accumulator rows (r and r + 8) this thread owns in
                // each MMA_ATOM_M-row slice, from the current pipeline stage.
                float     scale_U[MMA_ITER_M][2];
                const int offset_U = warp_id % 4 * 16 + lane_id / 4;
                auto      Load_U   = [&] {
                    for (int m = 0; m < MMA_ITER_M; ++m) {
                        scale_U[m][0] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M];
                        scale_U[m][1] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M + 8];
                    }
                };

                // accum_C[m] += (U * V) * frag_C[m]; every 8 columns a thread owns 4 registers:
                // (row r: 2 cols), (row r + 8: 2 cols).
                auto scale_accum = [&](int m) {  // cta_n = mma_iter_n * wg_n * mma_atom_n
                    float scales[2][2];
                    scales[0][0] = scale_U[m][0] * scale_V[0];
                    scales[1][0] = scale_U[m][1] * scale_V[0];
                    scales[0][1] = scale_U[m][0] * scale_V[1];
                    scales[1][1] = scale_U[m][1] * scale_V[1];
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {
                        PRAGMA_UNROLL
                        for (int c0 = 0; c0 < MMA_ATOM_N; c0 += OUTER_N) {
                            bool pred = (pred_V & (1U << (c0 / OUTER_N)));
                            PRAGMA_UNROLL
                            for (int cc = 0; cc < OUTER_N; cc += 8) {
                                int c = c0 + cc;
                                // clang-format off
                                accum_C[m][n][c / 2 + 0] += (pred ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 0];
                                accum_C[m][n][c / 2 + 1] += (pred ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 1];
                                accum_C[m][n][c / 2 + 2] += (pred ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 2];
                                accum_C[m][n][c / 2 + 3] += (pred ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 3];
                                // clang-format on
                            }
                        }
                    }

                };

                // Issue the wgmmas of one MMA_ATOM_M-row slice over the whole TILE_K and commit
                // them as one batch; frag_C is cleared on the first K step (ScaleOut::Zero).
                auto gmma = [&](int m) {
                    PRAGMA_UNROLL
                    for (int k = 0; k < MMA_ITER_K; ++k) {
                        PRAGMA_UNROLL
                        for (int n = 0; n < MMA_ITER_N; ++n) {
                            wgmma(smem_iter_A, smem_iter_B, frag_C[m][n], k == 0);
                            smem_iter_B += kStepNB;
                        }
                        smem_iter_B -= MMA_ITER_N * kStepNB;
                        smem_iter_A += kStepKA;
                        smem_iter_B += kStepKB;
                    }
                    smem_iter_A -= MMA_ITER_K * kStepKA;
                    smem_iter_B -= MMA_ITER_K * kStepKB;
                    smem_iter_A += kStepMA;
                    cute::warpgroup_commit_batch();
                };

                static_assert(MMA_ITER_N == 1);

                // Acquire the main loop from the other math warpgroup and resume the pipeline
                // where it stopped.
                math_barrier_sync(0);

                pipe_state.advance(storage.pipe_count[warpgroup_id ^ 1]);

                // First K block (peeled).
                smem_iter_A.Reset(pipe_state.index());
                smem_iter_B.Reset(pipe_state.index());
                Load_V();
                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                Load_U();
                cute::warpgroup_arrive();
                gmma(0);
                gmma(1);
                cute::warpgroup_wait<1>();
                scale_accum(0);
                cute::warpgroup_wait<0>();
                scale_accum(1);
                consumer_arrive();
                ++pipe_state;
                --k_iter;

                Load_V();
                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                Load_U();
                smem_iter_A.Reset(pipe_state.index());
                smem_iter_B.Reset(pipe_state.index());

                // Steady state: compute the current stage, then prefetch scales of the next one.
                for (; k_iter > 1; --k_iter) {
                    cute::warpgroup_arrive();
                    gmma(0);
                    gmma(1);
                    cute::warpgroup_wait<1>();
                    scale_accum(0);
                    cute::warpgroup_wait<0>();
                    scale_accum(1);
                    consumer_arrive();
                    ++pipe_state;
                    Load_V();
                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());
                }

                // Last K block (peeled).
                cute::warpgroup_arrive();
                gmma(0);
                gmma(1);
                cute::warpgroup_wait<1>();
                scale_accum(0);
                cute::warpgroup_wait<0>();
                scale_accum(1);
                consumer_arrive();
                ++pipe_state;

                // Publish the pipeline position and hand the main loop to the other warpgroup.
                if (wg_lane == 0) {
                    storage.pipe_count[warpgroup_id] = pipe_state.count();
                }

                math_barrier_sync(1);

                // epilogue
                PRAGMA_UNROLL
                for (int m = 0; m < MMA_ITER_M; ++m) {
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {

                        constexpr int N       = LayoutC::C0;
                        constexpr int SW_bits = log2(kSwizzleC / 16);

                        static_assert(!SW_bits || MMA_ATOM_N % LayoutC::C0 == 0);

                        const int m0 = m * MMA_ATOM_M;
                        const int n0 = n * MMA_ATOM_N;

                        // Convert 16 columns of accumulators to bf16 and store them to the
                        // (swizzled) staging buffer with one stmatrix x4 per warp.
                        PRAGMA_UNROLL
                        for (int i = 0; i < MMA_ATOM_N; i += 16) {
                            __align__(16) Array tvec = cast(*(Array*)&accum_C[m][n][i / 2]);

                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);
                            int nn = n0 + i / N * N;

                            int addr = ((nn / N) * TILE_M * N) + (mm * N) + (nn % N);

                            int s = lane_id % 8;
                            int c = (lane_id & 16) / 2 + i % N;

                            addr += Swizzle::apply(s * N + c);

                            auto& uvec = (Array&)tvec;
                            cute::SM90_U32x4_STSM_N::copy(
                                uvec[0], uvec[1], uvec[2], uvec[3], (cutlass::uint128_t&)storage.C[addr]);
                        }
                    }
                }

                cute::tma_store_fence();  // visibility: smem -> async proxy

                wg_barrier.sync();

                const int wg_thread_id = threadIdx.x % WARPGROUP_SIZE;

                // One thread per LayoutC::C0-wide column slab issues a TMA store and waits for it
                // so the staging buffer can be reused.
                if (wg_thread_id < LayoutC::C1) {
                    const int tma_n = wg_thread_id * LayoutC::C0;
                    cute::SM90_TMA_STORE::copy(
                        &tm_c, &storage.C[wg_thread_id * TILE_M * LayoutC::C0], offset_n + tma_n, offset_m);
                    cute::tma_store_arrive();
                    cute::tma_store_wait<0>();
                }

                wg_barrier.sync();

            }  // scheduler loop

            // Termination: keep answering the other warpgroup's ping-pong barriers (alive = 0)
            // until neither side reports alive, so no warpgroup blocks on a missing partner.
            if (warpgroup_id == 0) {
                math_barrier_sync(0, 0);
                while (math_barrier_sync(1, 0)) {
                    math_barrier_sync(0, 0);
                }
            }
            else {
                while (math_barrier_sync(0, 0)) {
                    math_barrier_sync(1, 0);
                }
            }

            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {
                cute::tma_store_wait<0>();
            }
        }

        // Keep this CTA's shared memory (barriers, multicast targets) alive until all CTAs
        // of the cluster are done.
        if constexpr (kClusterSize > 1) {
            cute::cluster_arrive();
            cute::cluster_wait();
        }

    }  // operator()
};

// Dynamically sized shared-memory buffer: an unbounded `extern __shared__` array whose byte
// count is supplied at kernel launch. The kernel below passes it to Kernel::operator(), which
// (for GemmUniversalSm90_v2) reinterprets it as Kernel::SharedStorage.
extern __shared__ char smem_buf[];

// SM90 GEMM kernel entry point: default-constructs the `Kernel` functor on the device and runs
// it with the TMA tensor maps, the raw scale pointers/leading dims and the tile scheduler.
//
// - Tensor maps are `const __grid_constant__` parameters, so their address refers to kernel
//   parameter space and they are not copied into local memory (as TMA descriptors require).
// - __launch_bounds__(Kernel::CTA_SIZE, 1): at most CTA_SIZE threads per block, 1 block per SM.
// - A body is emitted only in the device pass (`__CUDA_ARCH__`) and only for architectures that
//   Kernel::Arch reports as compatible; every other target compiles to an empty kernel.
//
// NOTE(review): the template parameter list (presumably `class Kernel`) is not visible in this
// view. Another SM90 GEMM header appears to declare a kernel with the same name and signature in
// this namespace; including both in one TU would presumably be a redefinition -- confirm.
template
__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_sm90(const __grid_constant__ CUtensorMap tm_a,
                                                                        const __grid_constant__ CUtensorMap tm_b,
                                                                        const __grid_constant__ CUtensorMap tm_c,
                                                                        const __grid_constant__ CUtensorMap tm_u,
                                                                        const __grid_constant__ CUtensorMap tm_v,
                                                                        const void*                         U_,
                                                                        int                                 ldU,
                                                                        const void*                         V_,
                                                                        int                                 ldV,
                                                                        typename Kernel::Scheduler          sched)
{
#if __CUDA_ARCH__
    // Compile-time filter: incompatible architectures discard the whole body.
    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {
        Kernel kernel;
        kernel(tm_a, tm_b, tm_c, tm_u, tm_v, U_, ldU, V_, ldV, sched, smem_buf);
    }
#endif
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h
================================================
#pragma once

#include 
#include 

#include 
#include 

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_desc.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
#include "cute/arch/mma_sm90_desc.hpp"

#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"

#include "src/turbomind/kernels/core/smem.h"

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/cp_async.h"
#include "src/turbomind/kernels/gemm/iterator_sm90.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/scheduler.cuh"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

#include "src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h"
#include "src/turbomind/kernels/gemm/sm90_utils.h"

namespace turbomind::gemm {

// SM90 FP8 GEMM (e4m3 x e4m3 -> bf16) with block scaling factors, organised as one
// producer warpgroup feeding WARPGORUPS math warpgroups through a Stages-deep
// shared-memory pipeline.
//
// Referenced template parameters: `multicast_a` / `multicast_b` (cluster multicast
// factors used for the A / B loads) and `is_grouped_gemm_`.
//
// Operands, as read by the producer loop in operator():
//   A : TILE_M x TILE_K fp8 tile per stage, TMA coordinates {k, m}
//   B : TILE_N x TILE_K fp8 tile per stage, TMA coordinates {k, n}
//   U : float scales, one per row of A per 128-wide K chunk (kBoxU values per stage, TMA)
//   V : float scales, one per 128(n) x 128(k) block of B; two values per stage
//       (gmem_V0 / gmem_V1, cp.async) because a 192-wide N tile can straddle two
//       128-row blocks of B
//   C : bf16 output, staged in shared memory and written with TMA stores
//
// Thread roles: warpgroups [0, WARPGORUPS) are math warpgroups laid out WG_M x WG_N;
// warpgroup WARPGORUPS is the producer (its warp 0 issues loads, its warp 1 on cluster
// CTA 0 drives the tile scheduler).
template
struct GemmUniversalSm90_v3 {

    static constexpr bool kDebug = false;

    using Arch = Sm90;

    // CTA tile shape.
    static constexpr int TILE_M = 128;
    static constexpr int TILE_N = 192;
    static constexpr int TILE_K = 128;

    // Math warpgroup arrangement; each math warpgroup owns a WG_TILE_M x WG_TILE_N sub-tile.
    static constexpr int WG_M = 2;
    static constexpr int WG_N = 1;

    static constexpr int WG_TILE_M = TILE_M / WG_M;
    static constexpr int WG_TILE_N = TILE_N / WG_N;

    static constexpr int kSchedWarpGroups = 1;

    // Number of math warpgroups (identifier spelling kept as-is; it may be referenced elsewhere).
    static constexpr int WARPGORUPS = WG_M * WG_N;

    using GMMA = ScaledGmmaFP8_TN;

    static constexpr int kMulticastA = multicast_a;
    static constexpr int kMulticastB = multicast_b;

    static constexpr int kClusterSize = kMulticastA * kMulticastB;

    // Depth of the shared-memory pipeline.
    static constexpr int Stages = 4;

    // No split-K: each tile iterates over the whole K range in TILE_K steps.
    static constexpr bool kSplitK     = false;
    static constexpr int  kChunkSizeK = TILE_K;

    static constexpr int WARPGROUP_SIZE = 128;

    static constexpr int kMathGroupSize = WARPGROUP_SIZE * WARPGORUPS;

    // Math warpgroups plus one producer warpgroup.
    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);

    using Ta = __nv_fp8_e4m3;
    using Tb = __nv_fp8_e4m3;
    using Tc = nv_bfloat16;

    using Tu = float;
    using Tv = float;

    using Cluster = arch::Cluster;

    static constexpr auto is_grouped_gemm = is_grouped_gemm_;

    using Scheduler = TileScheduler;

    // U follows A's multicast, except in grouped mode where each CTA loads its own U.
    static constexpr int kMulticastU = is_grouped_gemm ? 1 : kMulticastA;

    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;
    using ConsumerBar = cutlass::arch::ClusterBarrier;

    // 16-byte alignment expressed in Tu elements. In grouped mode the U base address is
    // rounded down to this alignment (beg_u in operator()), so the box carries
    // kAlignmentU extra elements to still cover TILE_M rows.
    static constexpr int kAlignmentU = 16 / sizeof(Tu);
    static constexpr int kBoxU       = TILE_M + (is_grouped_gemm ? kAlignmentU : 0);

    // Alignment requirement for SMEM addr. This forbids multicast factor 8.
    static_assert(kMulticastU == 1 || sizeof(Tu) * kBoxU / kMulticastU % 128 == 0);

    // Bytes delivered by TMA per stage (A + B + U). V is loaded with cp.async and is
    // accounted for by cpasync_barrier_arrive_noinc instead.
    static constexpr int kTmaTxBytes =
        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * kBoxU;

    // ! SMEM addr must be SBO aligned for TMA load/store
    struct SharedStorage {
        __align__(1024) Array A;
        __align__(1024) Array B;
        __align__(1024) Array C;
        __align__(128) Tu U[Stages][round_up(kBoxU, 128)];  // at least 128 byte alignment
        __align__(128) Tv V[Stages][2];
        // Per-CTA descriptor copies: a, b, u, c0, c1 (one C descriptor per math warpgroup).
        __align__(128) CUtensorMap tensor_map[5];
        __align__(8) uint64_t producer_bar[Stages];
        __align__(8) uint64_t consumer_bar[Stages];
        typename Scheduler::Storage sched;
    };

    static constexpr int kSmemSize = sizeof(SharedStorage);

    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));

    using LayoutC = std::conditional_t= 32,
                                       SmemLayoutV2,
                                       SmemLayoutV2>;

    static constexpr int OUTER_N       = GMMA::OUTER_N;
    static constexpr int MMA_SUBTILE_N = GMMA::OP_N / OUTER_N;

    // Kernel body, executed by all CTA_SIZE threads of a CTA.
    //   tm_a/tm_b/tm_c/tm_u/tm_v : TMA descriptors (tm_v is not read in this function)
    //   param_*                  : raw pointers/strides; in grouped mode param_B.ptr and
    //                              param_V.ptr are read as per-group pointer arrays
    //   sched                    : tile scheduler
    //   tensormap_buf            : global scratch with 5 CUtensorMap slots per CTA
    //                              (indexed by blockIdx.x), used in grouped mode
    //   smem_buf                 : dynamic shared memory holding SharedStorage
    __device__ void operator()(const CUtensorMap& tm_a,
                               const CUtensorMap& tm_b,
                               const CUtensorMap& tm_c,
                               const CUtensorMap& tm_u,
                               const CUtensorMap& tm_v,
                               const MatrixParam& param_A,
                               const MatrixParam& param_B,
                               const MatrixParam& param_U,
                               const MatrixParam& param_V,
                               const MatrixParam& param_C,
                               Scheduler          sched,
                               CUtensorMap*       tensormap_buf,
                               char*              smem_buf)
    {
        SharedStorage& storage = *reinterpret_cast(smem_buf);

        uint64_t* producer_bar = storage.producer_bar;
        uint64_t* consumer_bar = storage.consumer_bar;

        if (threadIdx.x == 0) {
            PRAGMA_UNROLL
            for (int s = 0; s < Stages; ++s) {
                // Two arrivals per stage: arrive_and_expect_tx (TMA) + cpasync_barrier_arrive_noinc (V).
                ProducerBar::init(&producer_bar[s], 1 + 1);
                // One arrival per math warp (4 per warpgroup) from every CTA in the cluster.
                ConsumerBar::init(&consumer_bar[s], WARPGORUPS * kClusterSize * 4);
            }
            // Scheduler consumers: every math warp plus the load warp, in every cluster CTA.
            sched.init_dyanmic(storage.sched, kClusterSize * (WARPGORUPS * 4 + 1));
            cutlass::arch::fence_view_async_shared();
            if constexpr (kClusterSize > 1) {
                cutlass::arch::fence_barrier_init();
            }
        }

        // Peer CTAs arrive on this CTA's consumer barriers, so the sync is cluster-wide when clustered.
        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();

        const int wg_idx = cutlass::canonical_warp_group_idx();

        if (wg_idx == WARPGORUPS) {
            // ---------------- producer warpgroup ----------------
            cutlass::arch::warpgroup_reg_dealloc<40>();

            static_assert(TILE_M % kMulticastA == 0);
            static_assert(TILE_N % kMulticastB == 0);

            // Named barrier 7 between the load warp and the scheduler warp (2 warps).
            cutlass::arch::NamedBarrier producers_bar(WARP_SIZE * 2, 7);

            const int  warp_id = cutlass::canonical_warp_idx_sync();
            const bool cta_0   = cute::block_id_in_cluster().x == 0;

            if (warp_id % 4 == 0) {
                // Load warp: one elected lane issues TMA / cp.async for every k step of every tile.

                Cluster cluster(cute::block_id_in_cluster().x);

                // Each CTA loads its slice of the A / B tiles; the masks below multicast it to peers.
                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);
                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);

                auto  smem_A = storage.A.data() + mc_offset_m * TILE_K;
                auto  smem_B = storage.B.data() + mc_offset_n * TILE_K;
                auto& smem_U = storage.U;
                auto& smem_V = storage.V;

                if constexpr (is_grouped_gemm) {
                    init_tma_descs<3>({&tm_a, &tm_b, &tm_u}, storage.tensor_map);
                }

                cutlass::PipelineState write_state{0, 1, 0};

                auto sched_state = sched.init_consumer(storage.sched);

                int lane_predicate = cute::elect_one_sync();

                typename Scheduler::Tile* tile;

                while (sched_state.acquire(tile)) {

                    // Loads are issued whenever any CTA of the cluster has a valid tile, because
                    // the multicast stages land in every peer's shared memory.
                    if (tile->is_valid_cluster) {

                        const CUtensorMap* Adesc = &tm_a;
                        const CUtensorMap* Bdesc = &tm_b;
                        const CUtensorMap* Udesc = &tm_u;

                        const Tv* gmem_V0 = (const Tv*)param_V.ptr;
                        const Tv* gmem_V1;

                        if constexpr (is_grouped_gemm) {
                            // Patch the A/B/U descriptors for this tile's group: A and U start
                            // at row m0 of the group, B and V come from per-group pointer arrays.
                            const int g  = tile->group_idx;
                            const int m0 = tile->m0;
                            const int m1 = tile->m1;
                            const int m  = m1 - m0;

                            Array global_addrs;
                            global_addrs[0] = (Ta*)param_A.ptr + m0 * (int64_t)param_A.stride;
                            global_addrs[1] = ((void**)param_B.ptr)[g];

                            // U base rounded down to kAlignmentU (consumers add m0 % kAlignmentU back).
                            const int beg_u = m0 / kAlignmentU * kAlignmentU;
                            const int end_u = round_up(m1, kAlignmentU);
                            global_addrs[2] = (Tu*)param_U.ptr + beg_u;

                            Array dims;
                            dims[0] = m;
                            dims[1] = sched.gemm_shape().y;
                            dims[2] = end_u - beg_u;

                            auto descs = update_tma_descs(tensormap_buf, storage.tensor_map, global_addrs, dims);
                            Adesc      = &descs[0];
                            Bdesc      = &descs[1];
                            Udesc      = &descs[2];

                            gmem_V0 = ((Tv**)gmem_V0)[g];

                            PRAGMA_UNROLL
                            for (int i = 0; i < 3; ++i) {
                                cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)&descs[i]);
                            }
                        }

                        if (lane_predicate) {
                            const int offset_k = 0;

                            const uint16_t mask_A = cluster.mask_m();
                            const uint16_t mask_B = cluster.mask_n();

                            const int offset_m = tile->offset_m;
                            const int offset_n = tile->offset_n;

                            int k_iter = sched.k_iters_;

                            GmemIteratorSm90 gmem_A{
                                Adesc, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};
                            GmemIteratorSm90 gmem_B{
                                Bdesc, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};

                            const int mc_offset_u = kMulticastU > 1 ? mc_offset_m : 0;
                            // column-major
                            GmemIteratorSm90 gmem_U{
                                Udesc, {offset_m + mc_offset_u, offset_k / 128}, {0, 1}};

                            // V0: scale block of this tile's first 128-row block of N; V1: the next
                            // block if it exists, otherwise V0 again.
                            gmem_V0 += (offset_n / 128) * param_V.stride + (offset_k / 128);
                            gmem_V1 = gmem_V0;
                            if (offset_n / 128 + 1 < cdiv(sched.gemm_shape().y, 128)) {
                                gmem_V1 += param_V.stride;
                            }

                            for (; k_iter > 0; --k_iter) {
                                int pipe = write_state.index();
                                // Wait until all consumers (cluster-wide) released this stage.
                                ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                                ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);
                                gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);
                                gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);
                                gmem_U.Step(&producer_bar[pipe], smem_U[pipe] + mc_offset_u, mask_A);
                                uint32_t uint_ptr_V = cast_smem_ptr_to_uint(smem_V[pipe]);
                                CP_ASYNC::apply(uint_ptr_V, gmem_V0, true);
                                CP_ASYNC::apply(uint_ptr_V + sizeof(Tv), gmem_V1, true);
                                ++gmem_V0;
                                ++gmem_V1;
                                // Second producer arrival, triggered when the cp.async copies complete.
                                cutlass::arch::cpasync_barrier_arrive_noinc(&producer_bar[pipe]);
                                ++write_state;
                            }
                        }
                    }

                    if constexpr (Scheduler::is_dynamic) {
                        if (cta_0) {
                            producers_bar.arrive_unaligned();
                        }
                    }

                    sched_state.release();

                }  // scheduler loop

                // release last tile
                sched_state.release();

                if constexpr (kClusterSize > 1) {
                    // Wait for one more consumer phase on every stage before exiting; presumably
                    // this keeps the CTA alive until all remote arrivals on its barriers are done.
                    if (lane_predicate) {
                        for (int i = 0; i < Stages; ++i) {
                            ConsumerBar::wait(&consumer_bar[write_state.index()], write_state.phase());
                            ++write_state;
                        }
                    }
                    __syncwarp();
                }
            }
            else if (warp_id % 4 == 1 && cta_0) {
                // Scheduler warp (cluster CTA 0 only).
                auto state = sched.init_producer(storage.sched);
                while (state.next()) {
                    if constexpr (Scheduler::is_dynamic) {
                        producers_bar.arrive_and_wait_unaligned();
                    }
                }
                sched.tail(state);
            }
        }
        else {
            // ---------------- math warpgroups ----------------
            cutlass::arch::warpgroup_reg_alloc<232>();

            if constexpr (is_grouped_gemm) {
                // First warp of each math warpgroup copies tm_c into its own slot (3 + wg_idx).
                if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {
                    init_tma_descs<1>({&tm_c}, storage.tensor_map + 3 + wg_idx);
                }
            }

            auto& smem_A = storage.A;
            auto& smem_B = storage.B;
            auto& smem_U = storage.U;
            auto& smem_V = storage.V;

            const int wg_idx_m = WG_M > 1 ? wg_idx % WG_M : 0;
            const int wg_idx_n = WG_N > 1 ? wg_idx / WG_M : 0;

            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);
            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);

            SmemDescIterV2> 4)> smem_iter_A{smem_desc_A};
            SmemDescIterV2> 4)> smem_iter_B{smem_desc_B};

            cutlass::arch::NamedBarrier barrier(WARPGROUP_SIZE, 2 + wg_idx);  // 0, 1

            cutlass::PipelineState pipe_state{};

            const int warp_id = cutlass::canonical_warp_idx_sync();
            const int lane_id = cutlass::canonical_lane_idx();

            // Release the current stage: with a cluster, lane i arrives on CTA i's barrier so every
            // CTA whose producer multicast into this stage is notified; otherwise lane 0 arrives locally.
            auto consumer_arrive = [&] {
                auto bar = &consumer_bar[pipe_state.index()];
                __syncwarp();
                if constexpr (kClusterSize > 1) {
                    ConsumerBar::arrive(bar, lane_id, lane_id < kClusterSize);
                }
                else {
                    if (lane_id == 0) {
                        ConsumerBar::arrive(bar);
                    }
                }
            };

            auto sched_state = sched.init_consumer(storage.sched);

            typename Scheduler::Tile* tile;

            sched_state.acquire(tile);

            while (tile->alive) {

                if (tile->is_valid_cta) {
                    GMMA::AccumC accum_C{};
                    GMMA::FragC  frag_C;

                    auto pred_V = Fetch_V(tile, wg_idx_n);

                    float scale_V[2];
                    auto  Load_V = [&] {
                        scale_V[0] = smem_V[pipe_state.index()][0];
                        scale_V[1] = smem_V[pipe_state.index()][1];
                    };

                    // Index into the stage's U buffer for this thread's first accumulator row;
                    // the second row is 8 below it.
                    int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;
                    if constexpr (is_grouped_gemm) {
                        // Undo the round-down of the U base address done by the producer.
                        offset_U += tile->m0 % kAlignmentU;
                    }
                    GMMA::FragU frag_U;
                    auto        Load_U = [&] {
                        GMMA::foreach_m(frag_U, [&](auto& U, int m) {
                            U[0] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M];
                            U[1] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M + 8];
                        });
                    };

                    auto gmma = [&] {  //
                        GMMA::apply(smem_iter_A, smem_iter_B, frag_C, accum_C, frag_U, scale_V, pred_V);
                    };

                    if constexpr (is_grouped_gemm) {
                        // Wait for the previous tile's TMA store before patching this warpgroup's
                        // C descriptor for the new group.
                        if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {
                            cute::tma_store_wait<0>();
                        }
                        // No need to sync here as the update is warp synchronized
                        if (warp_id % 4 == 0) {
                            int  m0 = tile->m0, m1 = tile->m1;
                            auto global_addr = (Tc*)param_C.ptr + m0 * (int64_t)param_C.stride;
                            int  idx         = 3 + wg_idx;
                            update_tma_descs<1>(
                                tensormap_buf + idx, storage.tensor_map + idx, {global_addr}, {m1 - m0});
                        }
                        barrier.sync();
                    }

                    // NOTE(review): the peeled first step plus the peeled final gmma() below always
                    // consume at least two stages per tile, so this appears to require
                    // sched.k_iters_ >= 2 (K > TILE_K). With k_iters_ == 1 it would wait on a stage
                    // the producer did not fill for this tile — confirm callers guarantee this.
                    int k_iter = sched.k_iters_;

                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_V();
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());
                    gmma();
                    consumer_arrive();
                    ++pipe_state;
                    --k_iter;

                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_V();
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());

                    PRAGMA_NO_UNROLL
                    for (; k_iter > 1; --k_iter) {
                        gmma();
                        consumer_arrive();
                        ++pipe_state;
                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                        Load_V();
                        Load_U();
                        smem_iter_A.Reset(pipe_state.index());
                        smem_iter_B.Reset(pipe_state.index());
                    }

                    gmma();

                    const int thread_idx = threadIdx.x % WARPGROUP_SIZE;
                    if constexpr (!is_grouped_gemm) {
                        // The previous tile's TMA store must finish reading smem_C before the
                        // epilogue below overwrites it.
                        if (thread_idx < LayoutC::C1) {
                            cute::tma_store_wait<0>();
                        }
                        barrier.sync();
                    }

                    consumer_arrive();
                    ++pipe_state;

                    Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];

                    // epilogue: convert accumulators to Tc and write them to the swizzled smem_C
                    // staging buffer with stmatrix (SM90_U32x4_STSM_N).
                    GMMA::foreach_C(accum_C, [&](const auto& C, int m, int n) {
                        constexpr int N       = LayoutC::C0;
                        constexpr int SW_bits = log2(kSwizzleC / 16);

                        static_assert(!SW_bits || GMMA::OP_N % LayoutC::C0 == 0);
                        static_assert(GMMA::OP_N % 16 == 0);

                        const int m0 = m * GMMA::OP_M;
                        const int n0 = n * GMMA::OP_N;

                        PRAGMA_UNROLL
                        for (int i = 0; i < GMMA::OP_N; i += 16) {
                            __align__(16) Array tvec = cast((Array&)C[i / 2]);

                            // fill(tvec, Tc(255));

                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);
                            int nn = n0 + i / N * N;

                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);

                            int s = lane_id % 8;
                            int c = (lane_id & 16) / 2 + i % N;

                            addr += Swizzle::apply(s * N + c);

                            auto& uvec = (Array&)tvec;
                            cute::SM90_U32x4_STSM_N::copy(uvec[0],  //
                                                          uvec[1],
                                                          uvec[2],
                                                          uvec[3],
                                                          (cutlass::uint128_t&)smem_C[addr]);
                        }
                    });

                    cute::tma_store_fence();  // visibility: smem -> async proxy

                    barrier.sync();

                    const int offset_m = tile->offset_m;
                    const int offset_n = tile->offset_n;

                    const void* Cdesc = &tm_c;

                    // Threads [0, LayoutC::C1) each store one LayoutC::C0-wide column slab.
                    if (thread_idx < LayoutC::C1) {
                        const int tma_n = thread_idx * LayoutC::C0;
                        if constexpr (is_grouped_gemm) {
                            // Same slot layout as update_tma_descs: 5 descriptors per CTA.
                            Cdesc = tensormap_buf + blockIdx.x * 5 + 3 + wg_idx;
                            cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);
                        }
                        cute::SM90_TMA_STORE::copy(Cdesc,
                                                   &smem_C[thread_idx * WG_TILE_M * LayoutC::C0],
                                                   offset_n + wg_idx_n * WG_TILE_N + tma_n,
                                                   offset_m + wg_idx_m * WG_TILE_M);
                        cute::tma_store_arrive();
                    }
                }
                else if (tile->is_valid_cluster) {
                    // This CTA has no output for the tile, but the producer still filled its stages
                    // (cluster-valid tile), so drain and release them to keep the pipeline in step.
                    int k_iter = sched.k_iters_;
                    for (; k_iter > 0; --k_iter) {
                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                        consumer_arrive();
                        ++pipe_state;
                    }
                }

                sched_state.release();
                sched_state.acquire(tile);

            }  // scheduler loop

            // release last tile
            sched_state.release();

            // Outstanding TMA stores must complete before the CTA exits.
            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {
                cute::tma_store_wait<0>();
            }
        }

    }  // operator()

    // Copy N CUtensorMap objects into shared memory, 8 bytes (uint2) per lane across the warp.
    // Must be called by a full warp; ends with __syncwarp().
    template
    __device__ void init_tma_descs(Array param_desc, CUtensorMap* smem_desc)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;

        if (lane_id < sizeof(CUtensorMap) / sizeof(uint2)) {
            PRAGMA_UNROLL
            for (int i = 0; i < N; ++i) {
                ((uint2*)&smem_desc[i])[lane_id] = ((uint2*)param_desc[i])[lane_id];
            }
        }

        __syncwarp();
    }

    // Patch the global address and one global dimension of N shared-memory descriptors
    // (lane 0), then copy them to this CTA's slots in gmem_desc with a release fence-proxy
    // copy and return the first slot. Dimension 1 is patched, except for index 2 (U, which
    // is column-major), where dimension 0 is patched. Must be called by a full warp
    // (the copy instruction is .sync.aligned).
    template
    __device__ CUtensorMap*
    update_tma_descs(CUtensorMap* gmem_desc, CUtensorMap* smem_desc, Array global_addrs, Array dims)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < N; ++i) {
                uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);
                // clang-format off
                asm volatile("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" ::"r"(uint_ptr), "l"(global_addrs[i]));
                if (i != 2) {
                    asm volatile("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" ::"r"(uint_ptr), "r"(dims[i]));
                } else { // special case for U
                    asm volatile("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" ::"r"(uint_ptr), "r"(dims[i]));
                }
                // clang-format on
            }
        }

        __syncwarp();

        constexpr int kNumPerCta = 5;  // a,b,u,c0,c1
        auto          gmem_ptr   = &gmem_desc[blockIdx.x * kNumPerCta];
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);
            // clang-format off
            asm volatile("tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" :: "l"(gmem_ptr + i), "r"(uint_ptr));
            // clang-format on
        }

        return gmem_ptr;
    }

    // Per-MMA-subtile predicates for the V scale: pred_V[i] is true when sub-tile i
    // (OUTER_N wide) starts at or beyond the next 128-wide block boundary of N, i.e. it
    // presumably should use scale_V[1] inside GMMA::apply. pred_V[0] is always false.
    __device__ auto Fetch_V(typename Scheduler::Tile* tile, int wg_idx_n)
    {
        constexpr int BLK_SUBTILE_N = 128 / OUTER_N;
        static_assert(MMA_SUBTILE_N - 1 < BLK_SUBTILE_N + 1);  // n1 - 1 + n0 - 1 < 2 * n0

        Array pred_V{};
        if constexpr (MMA_SUBTILE_N != 1) {
            int offset = tile->offset_n % 128 + wg_idx_n * WG_TILE_N;
            static_assert(WG_N == 1);
            // Safely skip pred_V_0 when distributing WGs along M
            PRAGMA_UNROLL
            for (int i = 1; i < MMA_SUBTILE_N; ++i) {
                pred_V[i] = (i * OUTER_N + offset) >= 128;
            }
        }

        return pred_V;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal_sm90_v4.h
================================================
#pragma once

#include 
#include 

#include 

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_desc.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
#include "cute/arch/mma_sm90_desc.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/smem.h"

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/iterator_sm90.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/scheduler.cuh"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

namespace GMMA = cute::SM90::GMMA;

// Build a GMMA shared-memory matrix descriptor for the operand that starts at
// `smem_ptr`. The start address and the byte offsets are encoded in 16-byte units.
// The leading-byte offset and base offset are left at 0, and the stride-byte offset
// is fixed at 1024 bytes.
inline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)
{
    constexpr int kStrideBytes = 1024;

    cute::GmmaDescriptor desc{};
    desc.bitfield.layout_type_         = layout_type;
    desc.bitfield.base_offset_         = 0;
    desc.bitfield.leading_byte_offset_ = 0;
    desc.bitfield.stride_byte_offset_  = kStrideBytes >> 4;
    desc.bitfield.start_address_       = cast_smem_ptr_to_uint(smem_ptr) >> 4;
    return desc;
}

// Iterator over GMMA shared-memory descriptors for a `Stages`-deep pipeline whose
// per-stage buffers are `Step` descriptor-address units apart. Only the low 32-bit
// word of the 64-bit descriptor (the word that holds the start address) is modified;
// `base_` keeps its stage-0 value.
template
struct SmemDescIterV2 {
    union {
        uint32_t u32_[2];
        uint64_t u64_;
    };

    uint32_t base_;

    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}

    // Step to the next stage; after stage `Stages - 1` wrap back to stage 0.
    __device__ void Advance(int stage)
    {
        const bool wrap_around = (stage == Stages - 1);
        u32_[0]                = wrap_around ? base_ : u32_[0] + Step;
    }

    // Jump directly to `stage`.
    __device__ void Reset(int stage)
    {
        u32_[0] = base_;
        u32_[0] += stage * Step;
    }

    // Shift the address word by a raw offset (descriptor units).
    __device__ SmemDescIterV2& operator+=(int offset)
    {
        u32_[0] = u32_[0] + offset;
        return *this;
    }

    __device__ SmemDescIterV2& operator-=(int offset)
    {
        u32_[0] = u32_[0] - offset;
        return *this;
    }

    // The full 64-bit descriptor, as consumed by wgmma.
    __device__ operator uint64_t()
    {
        const uint64_t packed = u64_;
        return packed;
    }
};

// Issue one `MMA_Atom::fma` on the shared-memory descriptors desc_a / desc_b,
// expanding frag_C[Is]... into the atom's per-register accumulator arguments.
// `clear` passes GMMA::ScaleOut::Zero as the accumulator scale, otherwise
// GMMA::ScaleOut::One.
template
inline __device__ void
wgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence)
{
    const auto scale_out = clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
    MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., scale_out);
}

// Convenience wrapper: issue one MMA_Atom GMMA over all N accumulator registers in
// `frag_C` by forwarding an index sequence of length N to wgmma_impl.
// `clear` selects the zero accumulator scale (see wgmma_impl).
template
inline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)
{
    return wgmma_impl(desc_a, desc_b, frag_C, clear, std::make_index_sequence{});
}

// Compiler-level fence on a single accumulator register: the empty asm statement
// declares `reg` as read and written ("+f") and clobbers "memory". The compiler
// therefore cannot move accesses to `reg` across this point. No instruction is emitted.
inline __device__ void warpgroup_fence_operand(float& reg)
{
    asm volatile("" : "+f"(reg)::"memory");
}

// Apply the scalar register fence to every element of an M x N x K accumulator array.
// The loops are fully unrolled so each element stays in a register.
template
inline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])
{
    PRAGMA_UNROLL
    for (int i = 0; i < M; ++i) {
        auto& plane = x[i];
        PRAGMA_UNROLL
        for (int j = 0; j < N; ++j) {
            auto& row = plane[j];
            PRAGMA_UNROLL
            for (int kk = 0; kk < K; ++kk) {
                warpgroup_fence_operand(row[kk]);
            }
        }
    }
}

// Apply the scalar register fence to every element of an N x K accumulator array.
// The loops are fully unrolled so each element stays in a register.
template
inline __device__ void warpgroup_fence_operand(float (&x)[N][K])
{
    PRAGMA_UNROLL
    for (int row_idx = 0; row_idx < N; ++row_idx) {
        auto& row = x[row_idx];
        PRAGMA_UNROLL
        for (int col_idx = 0; col_idx < K; ++col_idx) {
            warpgroup_fence_operand(row[col_idx]);
        }
    }
}

// Compile-time loop: call `func` once per index in the index sequence, each time
// with the index wrapped in a `constant` (integral-constant) object. The calls are
// sequenced left to right by the comma fold expression.
template
__device__ void for_(std::index_sequence, Func func)
{
    return (func(constant{}), ...);
}

namespace arch {

// Static description of a thread-block cluster of M x N CTAs.
//
// CTAs are numbered linearly by `cta_id` in a (c, s) grid of C x S CTAs, with c the
// fast-varying coordinate (cta_id = c + s * C). mk2cs / cs2mk (defined elsewhere)
// convert between (m, n) and (c, s), presumably according to `order`.
template
struct Cluster {
    static constexpr int M = M_;
    static constexpr int N = N_;

    static constexpr int C = mk2cs(M, N).x;
    static constexpr int S = mk2cs(M, N).y;

    static constexpr int size = M * N;

    // C consecutive bits: every c within one s-row.
    static constexpr uint16_t kMaskC = (1 << C) - 1;
    // (2^size - 1) / (2^C - 1): bits 0, C, 2C, ... (one CTA per s-row, assuming C * S == size).
    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;

    // x: CTAs sharing this CTA's c coordinate; y: CTAs sharing its s coordinate.
    __device__ static ushort2 mask_cs(int cta_id)
    {
        const int2 cs = cta_cs(cta_id);
        return make_ushort2(kMaskS << cs.x, kMaskC << (cs.y * C));
    }

    // mask_cs() expressed in (m, n) order.
    __device__ static ushort2 mask_mn(int cta_id)
    {
        const ushort2 cs = mask_cs(cta_id);
        return order == kColMajor ? cs : make_ushort2(cs.y, cs.x);
    }

    // (c, s) coordinates of `cta_id`; a degenerate dimension always yields 0.
    __device__ static int2 cta_cs(int cta_id)
    {
        const int c = C > 1 ? cta_id % C : 0;
        const int s = S > 1 ? cta_id / C : 0;
        return {c, s};
    }

    // (m, n) coordinates of `cta_id`.
    __device__ static int2 cta_mn(int cta_id)
    {
        return cs2mk(cta_cs(cta_id));
    }

    int2    cta_mn_;
    ushort2 mask_mn_;

    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}

    __device__ int cta_m() { return cta_mn_.x; }

    __device__ int cta_n() { return cta_mn_.y; }

    __device__ uint16_t mask_m() { return mask_mn_.x; }

    __device__ uint16_t mask_n() { return mask_mn_.y; }
};

}  // namespace arch

// Hopper (SM90) FP8 (e4m3) x FP8 (e4m3) -> BF16 GEMM kernel with scale factors.
//
// Work split inside a CTA of CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1) threads:
//   * warp groups [0, WARPGORUPS) are math warp groups. Each one issues WGMMA
//     on a WG_TILE_M x WG_TILE_N slice of the CTA tile, rescales every k-step's
//     partial product by U x V scales, and writes C through STSM into shared
//     memory followed by a TMA store.
//   * warp group WARPGORUPS is the producer. Its first warp issues TMA loads of
//     A, B and U for every k-step, and of V for the first k-step of each tile,
//     into a `Stages`-deep shared-memory pipeline.
// Judging from the `/ 128` indexing below, U holds one scale per row of A per
// 128-wide K block, and V holds scales per 128-wide N block per 128-wide K block.
struct GemmUniversalSm90_v3 {

    static constexpr bool kDebug = false;

    using Arch = Sm90;

    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS;
    using MMA_Atom = GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>;
    static constexpr typename cute::MMA_Traits::Shape_MNK MMA_Shape{};

    static constexpr int MMA_ATOM_M = cute::get<0>(MMA_Shape);
    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);
    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);

    static constexpr int TILE_M = 128;
    static constexpr int TILE_N = 192;
    static constexpr int TILE_K = 128;

    // Arrangement of the math warp groups over the CTA tile
    static constexpr int WG_M = 2;
    static constexpr int WG_N = 1;

    static constexpr int WG_TILE_M = TILE_M / WG_M;
    static constexpr int WG_TILE_N = TILE_N / WG_N;

    static constexpr int kSchedWarpGroups = 1;

    // Number of math warp groups (identifier spelling is used throughout; kept as-is)
    static constexpr int WARPGORUPS = WG_M * WG_N;

    static constexpr int MMA_ITER_M = WG_TILE_M / MMA_ATOM_M;
    static constexpr int MMA_ITER_N = WG_TILE_N / MMA_ATOM_N;
    static constexpr int MMA_ITER_K = TILE_K / MMA_ATOM_K;

    static constexpr int kMulticastA = 1;
    static constexpr int kMulticastB = 2;

    static constexpr int kClusterSize = kMulticastA * kMulticastB;

    static constexpr int Stages = 4;

    static constexpr bool kSplitK     = false;
    static constexpr int  kChunkSizeK = TILE_K;

    static constexpr int WARPGROUP_SIZE = 128;

    // Math warp groups plus one producer warp group
    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);

    using Ta = __nv_fp8_e4m3;
    using Tb = __nv_fp8_e4m3;
    using Tc = nv_bfloat16;

    using Tu = float;
    using Tv = float;

    using Cluster = arch::Cluster;

    using Scheduler = TileScheduler;

    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;
    using ConsumerBar = cutlass::arch::ClusterBarrier;

    static constexpr int MAX_K = 32768;

    static constexpr int TILE_M_U = cdiv(TILE_M, 1);
    static constexpr int CTA_K_U  = cdiv(TILE_K, 128);

    // Bytes each pipeline stage receives from the A, B and U loads. The V bytes
    // are added on top of this only for the first k-step of a tile.
    static constexpr int kTmaTxBytes =
        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * TILE_M_U * CTA_K_U;

    // ! Smem addr must be SBO aligned for TMA load/store
    struct SharedStorage {
        struct Source {
            __align__(1024) Array A;
            __align__(1024) Array B;
            __align__(1024) Tu U[Stages][TILE_M_U * CTA_K_U];
            // __align__(1024) Tv V[2][WARPGORUPS][cdiv(MAX_K, 128)];
            __align__(1024) Tv V[Stages][2 * cdiv(MAX_K, 128)];
        };
        Source source;
        __align__(1024) Array C;
        __align__(128) uint64_t producer_bar[Stages];
        __align__(128) uint64_t consumer_bar[Stages];
        __align__(128) CUtensorMap tma_desc_C[WARPGORUPS];
    };

    static constexpr int kSmemSize = sizeof(SharedStorage);

    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));

    using LayoutC = std::conditional_t= 32,
                                       SmemLayoutV2,
                                       SmemLayoutV2>;

    __device__ void operator()(const CUtensorMap& tm_a,
                               const CUtensorMap& tm_b,
                               const CUtensorMap& tm_c,
                               const CUtensorMap& tm_u,
                               const CUtensorMap& tm_v,
                               const MatrixParam& param_A,
                               const MatrixParam& param_B,
                               const MatrixParam& param_U,
                               const MatrixParam& param_V,
                               const MatrixParam& param_C,
                               uint2              box_V,
                               Scheduler          sched,
                               CUtensorMap*       tensormap_buf,
                               char*              smem_buf)
    {
        SharedStorage& storage = *reinterpret_cast(smem_buf);

        uint64_t* producer_bar = storage.producer_bar;
        uint64_t* consumer_bar = storage.consumer_bar;

        // One thread initializes the pipeline barriers. The cluster-wide (or
        // CTA-wide) sync below makes the initialization visible before use.
        if (threadIdx.x == 0) {
            PRAGMA_UNROLL
            for (int s = 0; s < Stages; ++s) {
                // "full" barrier: a single arrival from the elected producer thread plus TMA tx bytes
                ProducerBar::init(&producer_bar[s], 1);
                // "empty" barrier: one arrival per math warp (4 per warp group) of every CTA in the cluster
                ConsumerBar::init(&consumer_bar[s], WARPGORUPS * kClusterSize * 4);
            }
            cutlass::arch::fence_view_async_shared();
            if constexpr (kClusterSize > 1) {
                cutlass::arch::fence_barrier_init();
            }
        }

        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();

        const int wg_idx = cutlass::canonical_warp_group_idx();

        if (wg_idx == WARPGORUPS) {
            // Producer warp group: hand registers back to the math warp groups
            cutlass::arch::warpgroup_reg_dealloc<40>();

            static_assert(TILE_M % kMulticastA == 0);
            static_assert(TILE_N % kMulticastB == 0);

            // if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {
            if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {

                Cluster cluster(cute::block_id_in_cluster().x);

                // Offsets of the slice of the (multicast) A / B tiles this CTA loads
                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);
                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);

                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;
                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;
                auto& smem_U = storage.source.U;
                auto& smem_V = storage.source.V;

                sched.grid_init();

                cutlass::PipelineState write_state{0, 1, 0};
                cutlass::PipelineState v_state{0, 1, 0};

                while (sched.next()) {
                    if (cute::elect_one_sync()) {
                        auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                        if (!cluster_tile_p) {
                            // OOB tile caused by swizzle pattern
                            continue;
                        }

                        const auto tile_offset              = sched.tile_offset();
                        const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                        const int offset_k = iter_k_beg * TILE_K;

                        const uint16_t mask_A = cluster.mask_m();
                        const uint16_t mask_B = cluster.mask_n();

                        const int offset_m = tile_offset.x * TILE_M;
                        const int offset_n = tile_offset.y * TILE_N;

                        int k_iter = iter_k_end - iter_k_beg;

                        GmemIteratorSm90 gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};
                        GmemIteratorSm90 gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};
                        GmemIteratorSm90 gmem_U{&tm_u, {offset_m + mc_offset_m, offset_k / 128}, {0, 1}};
                        GmemIteratorSm90<1>           gmem_V(&tm_v, {0, offset_n / 128}, {0, 0});

                        // First k-step: additionally load this tile's V scales
                        // (box_V.x * box_V.y values) into V buffer v_state.index()
                        {
                            int pipe = write_state.index();
                            ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                            const int v_bytes = sizeof(Tv) * box_V.x * box_V.y;
                            ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes + v_bytes);
                            gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);
                            gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);
                            gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);
                            gmem_V.Step(&producer_bar[pipe], &smem_V[v_state.index()], 0);
                            ++write_state;
                            --k_iter;
                        }

                        // Remaining k-steps: A, B and U only
                        for (; k_iter > 0; --k_iter) {
                            int pipe = write_state.index();
                            ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                            ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);
                            gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);
                            gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);
                            gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);
                            ++write_state;
                        }

                        ++v_state;
                    }
                }
            }
        }
        else {
            // Math warp groups: take the registers released by the producer
            cutlass::arch::warpgroup_reg_alloc<232>();

            sched.grid_init(kSchedWarpGroups);

            auto& smem_A = storage.source.A;
            auto& smem_B = storage.source.B;
            auto& smem_U = storage.source.U;

            const int wg_idx_m = WG_M > 1 ? wg_idx % WG_M : 0;
            const int wg_idx_n = WG_N > 1 ? wg_idx / WG_M : 0;

            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);
            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);

            SmemDescIterV2> 4)> smem_iter_A{smem_desc_A};
            SmemDescIterV2> 4)> smem_iter_B{smem_desc_B};

            // Descriptor strides are expressed in 16-byte units (hence >> 4)
            constexpr int kStepMA = (sizeof(Ta) * MMA_ATOM_M * TILE_K) >> 4;
            constexpr int kStepNB = (sizeof(Tb) * MMA_ATOM_N * TILE_K) >> 4;
            constexpr int kStepKA = (sizeof(Ta) * MMA_ATOM_K) >> 4;
            constexpr int kStepKB = (sizeof(Tb) * MMA_ATOM_K) >> 4;

            cutlass::arch::NamedBarrier wg_barrier(WARPGROUP_SIZE, wg_idx + 2);  // 2,3

            // Inter-warp-group epilogue ordering is currently disabled (no-op barrier)
            auto epi_barrier = [&](int phase) {  // 0, 1
                return EmptyBarrier{};
                // return cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE, wg_idx ^ phase);
            };

            if (wg_idx == 1) {
                epi_barrier(1).arrive_unaligned();
            }

            cutlass::PipelineState pipe_state{};
            cutlass::PipelineState v_state{};

            while (sched.next(kSchedWarpGroups)) {
                auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();

                if (!cluster_tile_p) {
                    // OOB tile caused by swizzle pattern
                    continue;
                }

                MMA_Atom::CRegisters frag_C[MMA_ITER_N];
                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};

                const auto tile_offset              = sched.tile_offset();
                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();

                // const auto [M, N, K, L] = sched.gemm_shape();

                const int offset_m = tile_offset.x * TILE_M;
                const int offset_n = tile_offset.y * TILE_N;

                const int wg_offset_n = offset_n + wg_idx_n * WG_TILE_N;

                int k_iter = iter_k_end - iter_k_beg;

                const int warp_id = threadIdx.x / WARP_SIZE;
                const int lane_id = threadIdx.x % WARP_SIZE;

                // Release the current stage. With a cluster, lane i (< kClusterSize)
                // arrives on the consumer barrier of CTA i, so each warp signals
                // every CTA that multicasts into this stage.
                auto consumer_arrive = [&] {
                    __syncwarp();
                    if constexpr (kClusterSize > 1) {
                        ConsumerBar::arrive(&consumer_bar[pipe_state.index()], lane_id, lane_id < kClusterSize);
                    }
                    else {
                        if (lane_id == 0) {
                            ConsumerBar::arrive(&consumer_bar[pipe_state.index()]);
                        }
                    }
                    __syncwarp();
                };

                if constexpr (kClusterSize > 1) {
                    if (!cta_tile_p) {  // other CTAs in the cluster are still alive
                        for (; k_iter > 0; --k_iter) {
                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                            consumer_arrive();
                            ++pipe_state;
                        }
                        epi_barrier(0).arrive_and_wait_unaligned();
                        epi_barrier(1).arrive_unaligned();
                        continue;
                    }
                }

                // auto Copy = [k = cdiv(K, 128)](Tv* dst, const Tv* src) {
                //     PRAGMA_NO_UNROLL
                //     for (int i = threadIdx.x % WARPGROUP_SIZE; i < k; i += WARPGROUP_SIZE) {
                //         dst[i] = __ldg(&src[i]);
                //     }
                // };
                // auto gmem_V = (const Tv*)param_V.ptr + (wg_offset_n / 128) * param_V.stride + (offset_k / 128);
                // Copy(storage.source.V[0][wg_idx], gmem_V);

                // Bit j of pred_V is set when the j-th OUTER_N-wide column chunk
                // of this warp group's tile lies in the next 128-wide N block, and
                // so must use the second V scale (scale_V[1]).
                uint32_t pred_V{};
                int      iter_V{};

                constexpr int OUTER_N = std::gcd(MMA_ATOM_N, 128);
                if constexpr (OUTER_N != 128) {

                    static_assert(MMA_ATOM_N <= 128 + OUTER_N, "MMA inst is crossing more than 2 scale blocks");

                    constexpr uint32_t mask = (1UL << (WG_TILE_N / OUTER_N)) - 1;

                    // Number of columns left in the current 128-wide block
                    int phase = 128 - wg_offset_n % 128;
                    pred_V    = (mask << (phase / OUTER_N)) & mask;

                    // if (pred_V && wg_offset_n / 128 + 1 < cdiv(N, 128)) {
                    //     Copy(storage.source.V[1][wg_idx], gmem_V + param_V.stride);
                    // }
                    // if constexpr (WG_N > 1) {
                    //     constexpr int tiles = MMA_ATOM_N / OUTER_N;
                    //     pred_V              = (pred_V >> (wg_idx_n * tiles)) & ((1 << tiles) - 1);
                    // }
                }

                __syncwarp();

                float scale_V[2];
                // auto  Load_V = [&] {
                //     scale_V[0] = storage.source.V[0][wg_idx][iter_V];
                //     if (pred_V) {
                //         scale_V[1] = storage.source.V[1][wg_idx][iter_V];
                //     }
                //     ++iter_V;
                // };
                // Read the V scale(s) of k-step iter_V from this tile's V buffer;
                // the second N block's scales start at offset box_V.x
                auto Load_V = [&] {
                    // scale_V[0] = scale_V[1] = 1;
                    scale_V[0] = storage.source.V[v_state.index()][iter_V];
                    if (pred_V) {
                        scale_V[1] = storage.source.V[v_state.index()][box_V.x + iter_V];
                    }
                    ++iter_V;
                };

                // Each thread needs the U scales of rows r and r + 8 for every
                // MMA_ATOM_M-row block (the rows its accumulator fragment covers)
                float     scale_U[MMA_ITER_M][2];
                const int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;
                auto      Load_U   = [&] {
                    for (int m = 0; m < MMA_ITER_M; ++m) {
                        scale_U[m][0] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M];
                        scale_U[m][1] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M + 8];
                    }
                };

                // Wait for the committed WGMMA batch, then accumulate frag_C into
                // accum_C, scaled by U (per row) x V (per column chunk, chosen by pred_V)
                auto scale_accum = [&](int m) {  // cta_n = mma_iter_n * wg_n * mma_atom_n
                    float scales[2][2];
                    scales[0][0] = scale_U[m][0] * scale_V[0];
                    scales[1][0] = scale_U[m][1] * scale_V[0];
                    scales[0][1] = scale_U[m][0] * scale_V[1];
                    scales[1][1] = scale_U[m][1] * scale_V[1];
                    cute::warpgroup_wait<0>();
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {
                        PRAGMA_UNROLL
                        for (int c0 = 0; c0 < MMA_ATOM_N; c0 += OUTER_N) {
                            bool pred = (pred_V & (1U << (c0 / OUTER_N)));
                            PRAGMA_UNROLL
                            for (int cc = 0; cc < OUTER_N; cc += 8) {
                                int c = c0 + cc;
                                // clang-format off
                                accum_C[m][n][c / 2 + 0] += (pred ? scales[0][1] : scales[0][0]) * frag_C[n][c / 2 + 0];
                                accum_C[m][n][c / 2 + 1] += (pred ? scales[0][1] : scales[0][0]) * frag_C[n][c / 2 + 1];
                                accum_C[m][n][c / 2 + 2] += (pred ? scales[1][1] : scales[1][0]) * frag_C[n][c / 2 + 2];
                                accum_C[m][n][c / 2 + 3] += (pred ? scales[1][1] : scales[1][0]) * frag_C[n][c / 2 + 3];
                                // clang-format on
                            }
                        }
                    }
                };

                // Issue all MMA_ITER_K x MMA_ITER_N WGMMAs of row block m into
                // frag_C (zero-initialized at k == 0) and commit them as one batch.
                // smem_iter_B is restored afterwards; smem_iter_A advances to the next row block.
                auto gmma = [&](int m) {
                    PRAGMA_UNROLL
                    for (int k = 0; k < MMA_ITER_K; ++k) {
                        PRAGMA_UNROLL
                        for (int n = 0; n < MMA_ITER_N; ++n) {
                            wgmma(smem_iter_A, smem_iter_B, frag_C[n], k == 0);
                            smem_iter_B += kStepNB;
                        }
                        smem_iter_B -= MMA_ITER_N * kStepNB;
                        smem_iter_A += kStepKA;
                        smem_iter_B += kStepKB;
                    }
                    smem_iter_A -= MMA_ITER_K * kStepKA;
                    smem_iter_B -= MMA_ITER_K * kStepKB;
                    smem_iter_A += kStepMA;
                    cute::warpgroup_commit_batch();
                };

                static_assert(MMA_ITER_N == 1);

                wg_barrier.sync();

                // First k-step. The first Load_V happens after the wait, because
                // V arrives with this stage.
                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                Load_V();
                Load_U();
                smem_iter_A.Reset(pipe_state.index());
                smem_iter_B.Reset(pipe_state.index());
                cute::warpgroup_arrive();
                gmma(0);
                scale_accum(0);
                consumer_arrive();
                ++pipe_state;
                --k_iter;

                // V is already resident (loaded with the first stage), so later
                // Load_V calls may run before waiting on the next stage
                Load_V();
                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                Load_U();
                smem_iter_A.Reset(pipe_state.index());
                smem_iter_B.Reset(pipe_state.index());

                for (; k_iter > 1; --k_iter) {
                    cute::warpgroup_arrive();
                    gmma(0);
                    scale_accum(0);
                    consumer_arrive();
                    ++pipe_state;
                    Load_V();
                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());
                }

                cute::warpgroup_arrive();
                gmma(0);
                scale_accum(0);
                consumer_arrive();
                ++pipe_state;
                ++v_state;

                const int wg_lane = threadIdx.x % WARPGROUP_SIZE;

                // The previous tile's TMA stores must finish reading smem_C before it is overwritten
                if (wg_lane < LayoutC::C1) {
                    cute::tma_store_wait<0>();
                }

                epi_barrier(0).arrive_and_wait_unaligned();

                wg_barrier.sync();

                // void* Cdesc{};
                // if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {
                //     Cdesc = update_tma_desc(tm_c, tensormap_buf, &storage.tma_desc_C[wg_idx], wg_idx, param_C.ptr,
                //     M);
                // }

                Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];

                // epilogue: convert accumulators to Tc and write them to the
                // swizzled smem_C tile with STSM (16 columns per step)
                PRAGMA_UNROLL
                for (int m = 0; m < MMA_ITER_M; ++m) {
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA_ITER_N; ++n) {

                        constexpr int N       = LayoutC::C0;
                        constexpr int SW_bits = log2(kSwizzleC / 16);

                        static_assert(!SW_bits || MMA_ATOM_N % LayoutC::C0 == 0);

                        const int m0 = m * MMA_ATOM_M;
                        const int n0 = n * MMA_ATOM_N;

                        PRAGMA_UNROLL
                        for (int i = 0; i < MMA_ATOM_N; i += 16) {
                            __align__(16) Array tvec = cast(*(Array*)&accum_C[m][n][i / 2]);

                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);
                            int nn = n0 + i / N * N;

                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);

                            int s = lane_id % 8;
                            int c = (lane_id & 16) / 2 + i % N;

                            addr += Swizzle::apply(s * N + c);

                            auto& uvec = (Array&)tvec;
                            cute::SM90_U32x4_STSM_N::copy(uvec[0],  //
                                                          uvec[1],
                                                          uvec[2],
                                                          uvec[3],
                                                          (cutlass::uint128_t&)smem_C[addr]);
                        }
                    }
                }

                cute::tma_store_fence();  // visibility: smem -> async proxy

                wg_barrier.sync();

                // One lane per LayoutC::C0-column panel issues its TMA store
                if (wg_lane < LayoutC::C1) {
                    const int tma_n = wg_lane * LayoutC::C0;
                    // cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);
                    cute::SM90_TMA_STORE::copy(&tm_c,
                                               &smem_C[wg_lane * WG_TILE_M * LayoutC::C0],
                                               offset_n + wg_idx_n * WG_TILE_N + tma_n,
                                               offset_m + wg_idx_m * WG_TILE_M);
                    cute::tma_store_arrive();
                }

                epi_barrier(1).arrive_unaligned();

            }  // scheduler loop

            // Drain outstanding TMA stores before the kernel exits
            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {
                cute::tma_store_wait<0>();
            }

            if (wg_idx == 0) {
                epi_barrier(0).arrive_and_wait_unaligned();
            }
        }

        // Keep every CTA of the cluster alive until all remote barrier
        // arrivals and multicast traffic targeting it are done
        if constexpr (kClusterSize > 1) {
            cute::cluster_arrive();
            cute::cluster_wait();
        }

    }  // operator()

    // No-op stand-in for a named barrier (accepts and ignores any constructor arguments)
    struct EmptyBarrier {
        __device__      EmptyBarrier(...) {}
        __device__ void arrive_and_wait_unaligned() {}
        __device__ void arrive_unaligned() {}
    };

    // Build a patched copy of the tensor map `param_desc` and publish it to global memory.
    //
    // The warp copies `param_desc` into `smem_desc` (8 bytes per lane). Lane 0 then
    // replaces the global address with `global_addr` and the size of dimension 1
    // with `dim`. The result is copied, with a release proxy fence, into slot
    // `blockIdx.x * 4 + index` of `gmem_desc`.
    // Must be called by a full warp (uses __syncwarp and a .sync.aligned instruction).
    // Returns the global-memory address of the updated descriptor. Its only call
    // site in operator() is currently commented out.
    __device__ void* update_tma_desc(const CUtensorMap& param_desc,
                                     CUtensorMap*       gmem_desc,
                                     CUtensorMap*       smem_desc,
                                     int                index,
                                     void*              global_addr,
                                     int                dim)
    {
        uint32_t uint_ptr = cast_smem_ptr_to_uint(smem_desc);

        const int lane_id = threadIdx.x % WARP_SIZE;

        if (lane_id < sizeof(CUtensorMap) / sizeof(uint2)) {
            // Fixed: the address-of operand had been mis-encoded as "¶m_desc" (HTML entity "&para")
            ((uint2*)smem_desc)[lane_id] = ((uint2*)&param_desc)[lane_id];
        }

        __syncwarp();

        if (lane_id == 0) {
            // clang-format off
            asm volatile("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" ::"r"(uint_ptr), "l"(global_addr));
            asm volatile("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" ::"r"(uint_ptr), "r"(dim));
            // clang-format on
        }

        __syncwarp();

        constexpr int kNumPerCta = 4;

        auto gmem_ptr = (void*)&gmem_desc[blockIdx.x * kNumPerCta + index];

        // clang-format off
        asm volatile("tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" :: "l"(gmem_ptr), "r"(uint_ptr));
        // clang-format on

        return gmem_ptr;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gemm_universal_sm90_v5.h
================================================
#pragma once

#include 
#include 

#include 
#include 

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
#include "cute/arch/mma_sm90_desc.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/smem.h"

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/cp_async.h"
#include "src/turbomind/kernels/gemm/iterator_sm90.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/scheduler.cuh"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

#include "src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h"
#include "src/turbomind/kernels/gemm/sm90_utils.h"

namespace turbomind::gemm {

template
struct GemmUniversalSm90_v5 {

    static constexpr bool kDebug = false;

    using Arch = Sm90;

    static constexpr int WARPGORUPS = 4;

    static constexpr int TILE_M = 128;
    static constexpr int TILE_N = 96;
    static constexpr int TILE_K = 128;

    static constexpr int WG_M = 2;
    static constexpr int WG_N = 1;

    static constexpr int WG_TILE_M = TILE_M / WG_M;
    static constexpr int WG_TILE_N = TILE_N / WG_N;

    using GMMA = ScaledGmmaFP8_TN;

    static constexpr int kMulticastA = multicast_a;
    static constexpr int kMulticastB = multicast_b;

    static constexpr int kClusterSize = kMulticastA * kMulticastB;

    static constexpr int Stages = 5;

    static constexpr bool kSplitK     = false;
    static constexpr int  kChunkSizeK = TILE_K;

    static constexpr int WARPGROUP_SIZE = 128;
    static constexpr int kMathGroupSize = 256;

    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);

    using Ta = __nv_fp8_e4m3;
    using Tb = __nv_fp8_e4m3;
    using Tc = nv_bfloat16;

    using Tu = float;
    using Tv = float;

    using Cluster = arch::Cluster;

    static constexpr auto is_grouped_gemm = is_grouped_gemm_;

    using Scheduler = TileScheduler;

    static constexpr int kMulticastU = is_grouped_gemm ? 1 : kMulticastA;

    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;
    using ConsumerBar = cutlass::arch::ClusterBarrier;

    static constexpr int MAX_K        = 32768;
    static constexpr int MAX_K_BLOCKS = cdiv(MAX_K, 128);

    static constexpr int kAlignmentU = 16 / sizeof(Tu);
    static constexpr int kBoxU       = TILE_M + (is_grouped_gemm ? kAlignmentU : 0);

    // Alignment requirement for SMEM addr. This forbids multicast factor 8.
    static_assert(kMulticastU == 1 || sizeof(Tu) * kBoxU / kMulticastU % 128 == 0);

    static constexpr int kTmaTxBytes =
        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * kBoxU;

    static constexpr int kTmaDescNum = 7;

    // ! Smem addr must be SBO aligned for TMA load/store
    struct SharedStorage {
        struct Source {
            __align__(1024) Array A;
            __align__(1024) Array B;
            __align__(1024) Tu U[Stages][round_up(kBoxU, 128)];
            __align__(1024) Tv V[Stages][2];
        };
        Source source;
        __align__(1024) Array C;
        __align__(128) uint64_t producer_bar[Stages];
        __align__(128) uint64_t consumer_bar[Stages];
        __align__(128) CUtensorMap tma_desc_buf[kTmaDescNum];  //
        int                         pipe_count[2];
        typename Scheduler::Storage sched;
    };

    static constexpr int kSmemSize = sizeof(SharedStorage);

    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));

    using LayoutC = std::conditional_t= 32,
                                       SmemLayoutV2,
                                       SmemLayoutV2>;

    static constexpr int OUTER_N       = GMMA::OUTER_N;
    static constexpr int MMA_SUBTILE_N = GMMA::OP_N / OUTER_N;

    __device__ void operator()(const CUtensorMap& tm_a,
                               const CUtensorMap& tm_b,
                               const CUtensorMap& tm_c,
                               const CUtensorMap& tm_u,
                               const CUtensorMap& tm_v,
                               const MatrixParam& param_A,
                               const MatrixParam& param_B,
                               const MatrixParam& param_U,
                               const MatrixParam& param_V,
                               const MatrixParam& param_C,
                               Scheduler          sched,
                               CUtensorMap*       tensormap_buf,
                               char*              smem_buf)
    {
        SharedStorage& storage = *reinterpret_cast(smem_buf);

        uint64_t* producer_bar = storage.producer_bar;
        uint64_t* consumer_bar = storage.consumer_bar;

        if (threadIdx.x == 0) {
            PRAGMA_UNROLL
            for (int s = 0; s < Stages; ++s) {
                ProducerBar::init(&producer_bar[s], 1 + 1);
                ConsumerBar::init(&consumer_bar[s], kClusterSize * (kMathGroupSize / WARP_SIZE));
            }
            sched.init_dyanmic(storage.sched, kClusterSize * (kMathGroupSize / WARP_SIZE + 1));
            cutlass::arch::fence_view_async_shared();
            if constexpr (kClusterSize > 1) {
                cutlass::arch::fence_barrier_init();
            }
            PRAGMA_UNROLL
            for (int i = 0; i < 2; ++i) {
                storage.pipe_count[i] = 0;
            }
        }

        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();

        const int wg_idx = cutlass::canonical_warp_group_idx();

        if (wg_idx == WARPGORUPS) {
            cutlass::arch::warpgroup_reg_dealloc<32>();

            static_assert(TILE_M % kMulticastA == 0);
            static_assert(TILE_N % kMulticastB == 0);

            cutlass::arch::NamedBarrier producers_bar(WARP_SIZE * 2, 6);

            const int  warp_id = cutlass::canonical_warp_idx_sync();
            const bool cta_0   = cute::block_id_in_cluster().x == 0;

            if (warp_id % 4 == 0) {

                Cluster cluster(cute::block_id_in_cluster().x);

                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);
                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);

                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;
                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;
                auto& smem_U = storage.source.U;
                auto& smem_V = storage.source.V;

                if constexpr (is_grouped_gemm) {
                    init_tma_descs<3>({&tm_a, &tm_b, &tm_u}, storage.tma_desc_buf);
                }

                cutlass::PipelineState write_state{0, 1, 0};

                auto sched_state = sched.init_consumer(storage.sched);

                int lane_predicate = cute::elect_one_sync();

                typename Scheduler::Tile* tile;

                while (sched_state.acquire(tile)) {

                    // if (cute::elect_one_sync()) {
                    //     printf("READ m %d n %d g %d v %s\n",
                    //            tile->offset_m,
                    //            tile->offset_n,
                    //            tile->group_idx,
                    //            tile->is_valid_cluster ? "true" : "false");
                    // }

                    if constexpr (Scheduler::is_dynamic) {
                        if (cta_0) {
                            producers_bar.arrive_unaligned();
                        }
                    }

                    if (tile->is_valid_cluster) {

                        const CUtensorMap* Adesc = &tm_a;
                        const CUtensorMap* Bdesc = &tm_b;
                        const CUtensorMap* Udesc = &tm_u;

                        const Tv* gmem_V0 = (const Tv*)param_V.ptr;
                        const Tv* gmem_V1;

                        if constexpr (is_grouped_gemm) {
                            const int g  = tile->group_idx;
                            const int m0 = tile->m0;
                            const int m1 = tile->m1;
                            const int m  = m1 - m0;

                            Array global_addrs;
                            global_addrs[0] = (Ta*)param_A.ptr + m0 * (int64_t)param_A.stride;
                            global_addrs[1] = ((void**)param_B.ptr)[g];

                            const int beg_u = m0 / kAlignmentU * kAlignmentU;
                            const int end_u = round_up(m1, kAlignmentU);
                            global_addrs[2] = (Tu*)param_U.ptr + beg_u;

                            Array dims;
                            dims[0] = m;
                            dims[1] = sched.gemm_shape().y;
                            dims[2] = end_u - beg_u;

                            auto descs = update_tma_descs(tensormap_buf, storage.tma_desc_buf, global_addrs, dims);
                            Adesc      = &descs[0];
                            Bdesc      = &descs[1];
                            Udesc      = &descs[2];

                            gmem_V0 = ((Tv**)gmem_V0)[g];

                            PRAGMA_UNROLL
                            for (int i = 0; i < 3; ++i) {
                                cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)&descs[i]);
                            }
                        }

                        if (lane_predicate) {

                            const int offset_k = 0;

                            const uint16_t mask_A = cluster.mask_m();
                            const uint16_t mask_B = cluster.mask_n();

                            const int offset_m = tile->offset_m;
                            const int offset_n = tile->offset_n;

                            int k_iter = sched.k_iters_;

                            GmemIteratorSm90 gmem_A{
                                Adesc, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};
                            GmemIteratorSm90 gmem_B{
                                Bdesc, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};

                            const int mc_offset_u = kMulticastU > 1 ? mc_offset_m : 0;
                            // column-major
                            GmemIteratorSm90 gmem_U{
                                Udesc, {offset_m + mc_offset_u, offset_k / 128}, {0, 1}};

                            gmem_V0 += (offset_n / 128) * param_V.stride + (offset_k / 128);
                            gmem_V1 = gmem_V0;
                            if (offset_n / 128 + 1 < cdiv(sched.gemm_shape().y, 128)) {
                                gmem_V1 += param_V.stride;
                            }

                            for (; k_iter > 0; --k_iter) {
                                int pipe = write_state.index();
                                ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());
                                ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);
                                gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);
                                gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);
                                gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_u, mask_A);
                                uint32_t uint_ptr_V = cast_smem_ptr_to_uint(smem_V[pipe]);
                                CP_ASYNC::apply(uint_ptr_V, gmem_V0, true);
                                CP_ASYNC::apply(uint_ptr_V + sizeof(Tv), gmem_V1, true);
                                ++gmem_V0;
                                ++gmem_V1;
                                cutlass::arch::cpasync_barrier_arrive_noinc(&producer_bar[pipe]);
                                ++write_state;
                            }
                        }
                    }

                    sched_state.release();

                }  // scheduler loop

                sched_state.release();

                // pair with the EXTRA tile
                sched_state.acquire(tile);
                sched_state.release();

                if constexpr (kClusterSize > 1) {
                    if (lane_predicate) {
                        for (int i = 0; i < Stages; ++i) {
                            ConsumerBar::wait(&consumer_bar[write_state.index()], write_state.phase());
                            ++write_state;
                        }
                    }
                    __syncwarp();
                }
            }
            else if (warp_id % 4 == 1 && cta_0) {
                auto sched_state = sched.init_producer(storage.sched);
                while (sched_state.next()) {
                    if constexpr (Scheduler::is_dynamic) {
                        producers_bar.arrive_and_wait_unaligned();
                    }
                }
                // send EXTRA null tile (to math WGs)
                sched_state.next();
                sched.tail(sched_state);
            }
        }
        else {
            cutlass::arch::warpgroup_reg_alloc<112>();

            const int math_group_idx = wg_idx / 2;

            if constexpr (is_grouped_gemm) {
                if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {
                    init_tma_descs<1>({&tm_c}, storage.tma_desc_buf + 3 + wg_idx);
                }
            }

            auto& smem_A = storage.source.A;
            auto& smem_B = storage.source.B;
            auto& smem_U = storage.source.U;
            auto& smem_V = storage.source.V;

            const int wg_idx_m = WG_M > 1 ? wg_idx % 2 % WG_M : 0;
            const int wg_idx_n = WG_N > 1 ? wg_idx % 2 / WG_M : 0;

            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);
            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);

            SmemDescIterV2> 4)> smem_iter_A{smem_desc_A};
            SmemDescIterV2> 4)> smem_iter_B{smem_desc_B};

            const int  thread_idx    = threadIdx.x % WARPGROUP_SIZE;
            const bool math_leader_p = threadIdx.x % kMathGroupSize == 0;

            auto math_barrier_sync = [&](int phase, int alive = 1) {
                constexpr int base       = (int)cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;
                const int     barrier_id = base + math_group_idx ^ phase;
                constexpr int threads    = WARPGORUPS * WARPGROUP_SIZE;
                int           res        = 0;
                asm volatile("{\n"
                             "  .reg.pred p;\n"
                             "  setp.ne.b32 p, %3, 0;\n"
                             "  barrier.cta.red.or.pred p, %1, %2, p;\n"
                             "  selp.s32 %0, 1, 0, p;\n"
                             "}\n"
                             : "=r"(res)
                             : "r"(barrier_id), "r"(threads), "r"(alive));
                return res;
            };

            cutlass::arch::NamedBarrier barrier(WARPGROUP_SIZE, 2 + wg_idx);  // 2,3,4,5

            cutlass::PipelineState pipe_state{};

            const int warp_id = cutlass::canonical_warp_idx_sync();
            const int lane_id = cutlass::canonical_lane_idx();

            auto consumer_arrive = [&] {
                auto bar = &consumer_bar[pipe_state.index()];
                __syncwarp();
                if constexpr (kClusterSize > 1) {
                    ConsumerBar::arrive(bar, lane_id, lane_id < kClusterSize);
                }
                else {
                    if (lane_id == 0) {
                        ConsumerBar::arrive(bar);
                    }
                }
            };

            auto sched_state = sched.init_consumer(storage.sched);

            if (math_group_idx == 1) {
                ++sched_state.pipe;
                math_barrier_sync(1);
            }

            typename Scheduler::Tile* tile;

            sched_state.acquire(tile);

            while (tile->alive) {

                if (tile->is_valid_cta) {

                    GMMA::AccumC accum_C{};
                    GMMA::FragC  frag_C;

                    const auto [_, N, K, L] = sched.gemm_shape();

                    const int offset_m = tile->offset_m;
                    const int offset_n = tile->offset_n;

                    int k_iter = sched.k_iters_;

                    auto pred_V = Fetch_V(param_V, K, N, tile, math_group_idx, wg_idx_n, storage);

                    float scale_V[2];
                    auto  Load_V = [&] {
                        scale_V[0] = smem_V[pipe_state.index()][0];
                        scale_V[1] = smem_V[pipe_state.index()][1];
                    };

                    int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;
                    if constexpr (is_grouped_gemm) {
                        offset_U += tile->m0 % kAlignmentU;
                    }
                    GMMA::FragU frag_U;
                    auto        Load_U = [&] {
                        GMMA::foreach_m(frag_U, [&](auto& U, int m) {
                            U[0] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M];
                            U[1] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M + 8];
                        });
                    };

                    auto gmma = [&] {  //
                        GMMA::apply(smem_iter_A, smem_iter_B, frag_C, accum_C, frag_U, scale_V, pred_V);
                    };

                    if constexpr (is_grouped_gemm) {
                        if (warp_id % 4 == 0) {
                            int  m0 = tile->m0, m1 = tile->m1;
                            auto addr = (Tc*)param_C.ptr + m0 * (int64_t)param_C.stride;
                            int  idx  = 3 + wg_idx;
                            update_tma_descs<1>(tensormap_buf + idx, storage.tma_desc_buf + idx, {addr}, {m1 - m0});
                        }
                    }

                    math_barrier_sync(0);

                    pipe_state = {};
                    pipe_state.advance(storage.pipe_count[math_group_idx ^ 1]);

                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_V();
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());
                    gmma();
                    consumer_arrive();
                    ++pipe_state;
                    --k_iter;

                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                    Load_V();
                    Load_U();
                    smem_iter_A.Reset(pipe_state.index());
                    smem_iter_B.Reset(pipe_state.index());

                    PRAGMA_NO_UNROLL
                    for (; k_iter > 1; --k_iter) {
                        gmma();
                        consumer_arrive();
                        ++pipe_state;
                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                        Load_V();
                        Load_U();
                        smem_iter_A.Reset(pipe_state.index());
                        smem_iter_B.Reset(pipe_state.index());
                    }

                    if (math_leader_p) {
                        storage.pipe_count[math_group_idx] = pipe_state.count() + 1;
                    }
                    math_barrier_sync(1);

                    gmma();
                    consumer_arrive();

                    Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];

                    GMMA::foreach_C(accum_C, [&](const auto& C, int m, int n) {
                        constexpr int N       = LayoutC::C0;
                        constexpr int SW_bits = log2(kSwizzleC / 16);

                        static_assert(!SW_bits || GMMA::OP_N % LayoutC::C0 == 0);
                        static_assert(GMMA::OP_N % 16 == 0);

                        const int m0 = m * GMMA::OP_M;
                        const int n0 = n * GMMA::OP_N;

                        PRAGMA_UNROLL
                        for (int i = 0; i < GMMA::OP_N; i += 16) {
                            __align__(16) Array tvec = cast(*(Array*)&C[i / 2]);
                            // fill(tvec, Tc(255));
                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);
                            int nn = n0 + i / N * N;

                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);

                            int s = lane_id % 8;
                            int c = (lane_id & 16) / 2 + i % N;

                            addr += Swizzle::apply(s * N + c);

                            auto& uvec = (Array&)tvec;
                            cute::SM90_U32x4_STSM_N::copy(
                                uvec[0], uvec[1], uvec[2], uvec[3], (cutlass::uint128_t&)smem_C[addr]);
                        }
                    });

                    cute::tma_store_fence();  // visibility: smem -> async proxy

                    barrier.sync();

                    if (thread_idx < LayoutC::C1) {
                        const void* Cdesc = &tm_c;
                        const int   tma_n = thread_idx * LayoutC::C0;
                        if constexpr (is_grouped_gemm) {
                            Cdesc = &tensormap_buf[blockIdx.x * kTmaDescNum + 3 + wg_idx];
                            cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);
                        }
                        cute::SM90_TMA_STORE::copy(Cdesc,
                                                   &smem_C[thread_idx * WG_TILE_M * LayoutC::C0],
                                                   offset_n + wg_idx_n * WG_TILE_N + tma_n,
                                                   offset_m + wg_idx_m * WG_TILE_M);
                        cute::tma_store_arrive();
                        cute::tma_store_wait<0>();
                    }

                }  // valid cta tile
                else {
                    math_barrier_sync(0);

                    pipe_state = {};
                    pipe_state.advance(storage.pipe_count[math_group_idx ^ 1]);

                    if (tile->is_valid_cluster) {
                        // other CTAs in the cluster are still alive
                        for (int k_iter = sched.k_iters_; k_iter > 0; --k_iter) {
                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());
                            consumer_arrive();
                            ++pipe_state;
                        }
                    }

                    if (math_leader_p) {
                        storage.pipe_count[math_group_idx] = pipe_state.count();
                    }

                    math_barrier_sync(1);
                }

                sched_state.release(2);
                sched_state.acquire(tile);
            }  // scheduler loop

            sched_state.release(2);  // release the last tile

            if (math_group_idx == 0) {
                math_barrier_sync(0, 0);
                if (math_leader_p) {
                    storage.pipe_count[0] = storage.pipe_count[1];
                }
                while (math_barrier_sync(1, 0)) {
                    math_barrier_sync(0, 0);
                    if (math_leader_p) {
                        storage.pipe_count[0] = storage.pipe_count[1];
                    }
                }
            }
            else {
                while (math_barrier_sync(0, 0)) {
                    if (math_leader_p) {
                        storage.pipe_count[1] = storage.pipe_count[0];
                    }
                    math_barrier_sync(1, 0);
                }
            }
        }

    }  // operator()

    template
    __device__ void init_tma_descs(Array param_desc, CUtensorMap* smem_desc)
    {
        // Copies N tensor-map descriptors into the shared-memory buffer `smem_desc`.
        // Each descriptor is moved as a sequence of 8-byte words, one word per lane,
        // so only the first sizeof(CUtensorMap) / sizeof(uint2) lanes do any copying.
        // Must be entered by a whole warp: it ends with __syncwarp().
        constexpr int kWordsPerDesc = sizeof(CUtensorMap) / sizeof(uint2);

        const int  lane     = threadIdx.x % WARP_SIZE;
        const bool has_word = lane < kWordsPerDesc;

        if (has_word) {
            PRAGMA_UNROLL
            for (int d = 0; d < N; ++d) {
                const uint2* src = (const uint2*)param_desc[d];
                uint2*       dst = (uint2*)&smem_desc[d];
                dst[lane]        = src[lane];
            }
        }

        // Make the copied words visible to every lane of the warp.
        __syncwarp();
    }

    // Re-targets N tensor-map descriptors held in shared memory (`smem_desc`) and
    // publishes them to this CTA's slot in the global descriptor buffer.
    //
    // For each descriptor i:
    //   * the global base address is replaced with global_addrs[i];
    //   * one global dimension is replaced with dims[i]: dimension ordinal 1 for
    //     every i except i == 2, which replaces ordinal 0 (the "U" descriptor).
    // The patched 128-byte descriptors are then copied to
    // gmem_desc[blockIdx.x * kTmaDescNum + i] with a release fence at GPU scope.
    //
    // Returns a pointer to this CTA's first descriptor in `gmem_desc`. Readers are
    // expected to issue a tensormap acquire fence before using it.
    //
    // Must be called by all 32 lanes of a converged warp: the copy uses the
    // warp-collective `.sync.aligned` form of tensormap.cp_fenceproxy.
    // NOTE(review): the `i != 2` special case hard-codes the {A, B, U} layout of the
    // N == 3 call; with N == 1 the branch is never taken — confirm for other N.
    template
    __device__ CUtensorMap*
    update_tma_descs(CUtensorMap* gmem_desc, CUtensorMap* smem_desc, Array global_addrs, Array dims)
    {
        const int lane_id = threadIdx.x % WARP_SIZE;
        // Only lane 0 patches the shared-memory copies.
        if (lane_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < N; ++i) {
                uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);
                // clang-format off
                asm volatile("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" ::"r"(uint_ptr), "l"(global_addrs[i]));
                if (i != 2) {
                    asm volatile("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" ::"r"(uint_ptr), "r"(dims[i]));
                } else { // special case for U
                    asm volatile("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" ::"r"(uint_ptr), "r"(dims[i]));
                }
                // clang-format on
            }
        }

        // Order lane 0's shared-memory writes before the warp-wide copy below.
        __syncwarp();

        // This CTA's slot in the global descriptor buffer.
        auto gmem_ptr = &gmem_desc[blockIdx.x * kTmaDescNum];
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);
            // Copy 128 bytes smem -> gmem and release the tensormap proxy at GPU scope.
            // clang-format off
            asm volatile("tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" :: "l"(gmem_ptr + i), "r"(uint_ptr));
            // clang-format on
        }

        return gmem_ptr;
    }

    __device__ auto Fetch_V(const MatrixParam&        param_V,
                            int                       K,
                            int                       N,
                            typename Scheduler::Tile* tile,
                            int                       math_group_idx,
                            int                       wg_idx_n,
                            SharedStorage&            storage)
    {
        // Builds the per-N-subtile predicate for the V scales. Entry i is true when
        // sub-tile i starts at or beyond the next 128-column boundary, counted from
        // this tile's position inside its 128-wide V block. Entry 0 is always false.
        // NOTE(review): presumably GMMA::apply uses this to pick the second of the two
        // V scale values; confirm against GMMA::apply.
        // param_V, K, N, math_group_idx and storage are not used here.
        Array pred_V{};

        if constexpr (MMA_SUBTILE_N != 1) {
            static_assert(WG_N == 1);
            // Column of this warpgroup's first element within its 128-wide V block.
            // Sub-tile 0 never crosses the boundary, so it is safely left false when WGs
            // are distributed along M.
            const int base_n = tile->offset_n % 128 + wg_idx_n * WG_TILE_N;
            PRAGMA_UNROLL
            for (int i = 1; i < MMA_SUBTILE_N; ++i) {
                const int col = base_n + i * OUTER_N;
                pred_V[i]     = col >= 128;
            }
        }

        return pred_V;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gpu_metric.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/gpu_metric.h"
#include 

#include 

namespace turbomind::gemm {

using thrust::device_vector;

namespace {

// L2 read-bandwidth probe kernel. The template parameter list is not visible in
// this view; the body uses BLOCK_NUM, BLOCK_DIM and LOG_TILE.
//
// Each thread walks `array` in steps of NUM_THREADS * 4 floats. At each step it
// reads one float4 at element offset `tid * 4`. `tid` is derived from
// `blockIdx.x >> LOG_TILE`, so every group of 2^LOG_TILE consecutive blocks reads
// exactly the same addresses.
//
// NOTE(review): `offset` has no bounds check. This assumes:
//   * `count` is a multiple of NUM_THREADS * 4;
//   * `array` is 16-byte aligned;
//   * (gridDim.x >> LOG_TILE) <= BLOCK_NUM and blockDim.x == BLOCK_DIM.
// The host caller presumably guarantees this — confirm.
template
__global__ void l2_bw(float* dsink, const float* array, int count)
{
    int    tid = threadIdx.x + (blockIdx.x >> LOG_TILE) * blockDim.x;
    float4 sink{};

    // Threads covering one sweep step (BLOCK_NUM blocks of BLOCK_DIM threads).
    constexpr int NUM_THREADS = BLOCK_NUM * BLOCK_DIM;

    for (int i = 0; i < count; i += NUM_THREADS * 4) {
        const float* ptr    = array + i;
        const int    offset = tid * 4;
        // __ldcg: cache-global load (cached in L2, not L1), so reads exercise L2.
        float4       data   = __ldcg(reinterpret_cast(ptr + offset));
        sink.x += data.x;
        sink.y += data.y;
        sink.z += data.z;
        sink.w += data.w;
    }

    // Store the accumulated sum so the loads cannot be optimized away. All blocks
    // write the same dsink slots, so the stored value is presumably meaningless by design.
    dsink[threadIdx.x] = sink.x + sink.y + sink.z + sink.w;
}

}  // namespace

// Roughly measures L2-cache read bandwidth on the current device.
//
// Steps:
//   * allocates a float buffer of about 64x the reported L2 size, rounded down to
//     a whole number of CHUNK_SIZE floats;
//   * zero-fills it;
//   * times a single l2_bw launch on a private stream with CUDA events.
// Prints the buffer size (bytes) and the measured GB/s to stdout.
// Returns bytes / second.
//
// NOTE(review): CUDA runtime return codes are not checked.
// NOTE(review): `bytes` assumes the launch performs BLOCK_Y full sweeps of `data`;
// the launch-configuration line is not fully visible here — confirm they match.
float MeasureL2CacheThroughput()
{
    cudaDeviceProp prop{};
    int            device{};
    cudaGetDevice(&device);
    cudaGetDeviceProperties(&prop, device);

    // Working-set size in bytes: 64x the L2 capacity.
    size_t size = static_cast(prop.l2CacheSize) * 64;

    // Debug output: buffer size in bytes (unlabeled).
    std::cout << size << std::endl;

    constexpr int BLOCK_X  = 128;  // blocks participating single sweep
    constexpr int BLOCK_Y  = 128;  // full sweep iters
    constexpr int LOG_TILE = 5;    // swizzling factor to bring up L2 hit rate, set to 0 will minimize hit rate

    constexpr int BLOCK_DIM = 256;

    constexpr int CHUNK_SIZE = BLOCK_X * BLOCK_DIM * 4;  // x4 for float4 load pattern

    // Element count rounded down to a multiple of CHUNK_SIZE so the kernel needs no tail handling.
    device_vector data(ceil_div(size, sizeof(float)) / CHUNK_SIZE * CHUNK_SIZE);
    device_vector dsink(BLOCK_DIM);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Enqueued before ev_start, so the zero-fill is not part of the timed region.
    cudaMemsetAsync(data.data().get(), 0, sizeof(float) * data.size(), stream);

    cudaEvent_t ev_start, ev_end;

    cudaEventCreate(&ev_start);
    cudaEventCreate(&ev_end);

    cudaEventRecord(ev_start, stream);

    l2_bw<<> LOG_TILE), BLOCK_DIM, 0, stream>>>(
        dsink.data().get(), data.data().get(), data.size());

    cudaEventRecord(ev_end, stream);

    cudaEventSynchronize(ev_end);

    float ms{};
    cudaEventElapsedTime(&ms, ev_start, ev_end);

    // Total bytes read: BLOCK_Y passes over the whole buffer.
    size_t bytes = BLOCK_Y * sizeof(float) * data.size();

    const float bytes_per_second = bytes / ms * 1e3;
    std::cout << bytes_per_second / 1e9 << " GB/s" << std::endl;

    cudaEventDestroy(ev_start);
    cudaEventDestroy(ev_end);

    cudaStreamDestroy(stream);

    return bytes_per_second;
}

// Measures dense GEMM throughput with cuBLAS on the current device.
//
// Runs one square problem_size x problem_size x problem_size GEMM: FP16 storage
// (CUDA_R_16F), FP32 compute, no transposes. The run is timed on a private stream
// with CUDA events, after one untimed warm-up call.
// Prints GFLOP/s (2 FLOPs per FMA) to stdout.
// Returns fused multiply-adds per second.
//
// NOTE(review): CUDA / cuBLAS return codes are not checked.
float MeasureMmaThroughput(int problem_size)
{
    // Element count of each square operand. Computed in size_t so that
    // problem_size > 46340 does not overflow a 32-bit int.
    const size_t elems = (size_t)problem_size * problem_size;

    device_vector a(elems);
    device_vector b(a.size());
    device_vector c(a.size());

    cublasHandle_t cublas{};
    cublasCreate(&cublas);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cublasSetStream(cublas, stream);

    float alpha = 1.f;
    float beta  = 0.f;

    // C = alpha * A * B + beta * C (column-major, leading dimension = problem_size).
    auto run_gemm = [&] {
        cublasGemmEx(cublas,
                     CUBLAS_OP_N,
                     CUBLAS_OP_N,
                     problem_size,
                     problem_size,
                     problem_size,
                     &alpha,
                     a.data().get(),
                     CUDA_R_16F,
                     problem_size,
                     b.data().get(),
                     CUDA_R_16F,
                     problem_size,
                     &beta,
                     c.data().get(),
                     CUDA_R_16F,
                     problem_size,
                     CUBLAS_COMPUTE_32F,
                     CUBLAS_GEMM_DEFAULT);
    };

    // Untimed warm-up: the first call on a fresh handle/stream can include one-time
    // costs (algorithm selection, lazy kernel loading). Timing it would understate
    // the steady-state throughput.
    run_gemm();

    cudaEvent_t ev_start, ev_end;

    cudaEventCreate(&ev_start);
    cudaEventCreate(&ev_end);

    cudaEventRecord(ev_start, stream);

    run_gemm();

    cudaEventRecord(ev_end, stream);

    cudaEventSynchronize(ev_end);

    float ms{};
    cudaEventElapsedTime(&ms, ev_start, ev_end);

    cudaEventDestroy(ev_start);
    cudaEventDestroy(ev_end);

    cudaStreamDestroy(stream);

    cublasDestroy(cublas);

    // One FMA per (m, n, k) triple.
    const size_t ops = (size_t)problem_size * problem_size * problem_size;

    float fma_per_second = ops / ms * 1e3;

    // 2 FLOPs per FMA, scaled by 1e9 -> GFLOP/s.
    std::cout << 2 * fma_per_second / 1e9 << " GFLOP/s" << std::endl;

    return fma_per_second;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/gpu_metric.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Measures L2-cache read bandwidth on the current device.
// Returns bytes / second.
float MeasureL2CacheThroughput();

// Measures cuBLAS FP16 GEMM throughput (FP32 compute) for a square
// problem_size x problem_size x problem_size problem on the current device.
// Returns fused multiply-adds / second; multiply by 2 for FLOP/s.
float MeasureMmaThroughput(int problem_size = 16384);

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/iterator.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/thread_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Placeholder global-memory iterator. It satisfies the iterator interface, but
// every operation is a no-op, the iteration count is zero, and its global mask
// starts out false.
struct VoidGmemIter {
    using Fragments = int;

    static constexpr int  ITER_S = 0;
    static constexpr auto kMode  = Striding::kFlat;

    // Accepts (and ignores) any constructor arguments.
    __device__ VoidGmemIter(...) {}

    __device__ void ClearSmem() {}

    __device__ void Prefetch(int, int, bool) {}
    __device__ void Prefetch(bool) {}

    __device__ void Fetch(Fragments&, bool) {}
    __device__ void Store(const Fragments&) {}

    __device__ void Advance() {}

    int* smem_data_;
    bool g_mask{false};
};

// Type-level helper that selects the global-memory iterator type for one GEMM
// operand. apply() takes its inputs as tag values. The template argument lists
// are not visible in this view. It returns the chosen iterator type wrapped in
// type_c; no runtime work is done.
struct GetGmemIter {
    template
    static constexpr auto
        apply(basic_type, basic_type, basic_type, pair, constant)
    {
        using Dtype = typename Operand::Dtype;

        // Elements per vector access: each thread's share of the M x K tile
        // (M * K / (WARPS * WARP_SIZE)), clamped between 32 and 128 bits' worth of
        // elements. NOTE(review): the bitsof arguments are not visible here;
        // presumably bitsof<Dtype>.
        constexpr int kAccessSize =
            std::min(128 / bitsof, std::max(32 / bitsof, M * K / (WARPS * WARP_SIZE)));

        // (m, k) -> (c, s) conversions. kAligned.x / kAligned.y feed the
        // "aligned C" / "aligned S" iterator arguments below; kCS is the tile extent.
        constexpr int2 kAligned = mk2cs(0, 1);
        constexpr int2 kCS      = mk2cs(M, K);

// Disabled (#if 0) alternative thread-count heuristic; not compiled.
#if 0
        constexpr int kMaxThrS = std::min(WARP_SIZE, ceil_div(kCS.y, WARPS));
        constexpr int kMaxThrC = std::min(WARP_SIZE, ceil_div(kCS.x, kAccessSize));

        constexpr int kTgtThrC = ceil_div(256, sizeof(Array));

        constexpr int kWarpThrC = std::min(kMaxThrC, std::max(WARP_SIZE / kMaxThrS, kTgtThrC));
#endif
        using GmemIter = typename Iterator::template Type,
                                                          SmemLayout,
                                                          Operand::kPack,
                                                          Operand::kOrder,
                                                          kAligned.x,   // aligned C
                                                          kAligned.y>;  // aligned S
        return type_c;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/iterator_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/gemm/cp_async.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/predicate.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include 
#include 

namespace turbomind::gemm {

// Loads one Array fragment of 4, 8 or 16 bytes from `src` in a single vector
// access using __ldcs, the CUDA "cache streaming" load hint. Any other fragment
// size fails to compile.
// NOTE(review): assumes `src` is aligned to the fragment size — confirm at call sites.
template
inline __device__ void _Ld(Array& dst, const T* src)
{
    static_assert(sizeof(Array) <= sizeof(uint4));

    if constexpr (sizeof(Array) == sizeof(uint4)) {
        (uint4&)dst = __ldcs((const uint4*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint2)) {
        (uint2&)dst = __ldcs((const uint2*)src);
    }
    else if constexpr (sizeof(Array) == sizeof(uint)) {
        (uint&)dst = __ldcs((const uint*)src);
    }
    else {
        // Dependent-false guard: rejects fragment sizes with no matching vector type.
        static_assert(!std::is_same_v);
    }
}

// Global-memory tile iterator for the SM70 GEMM path. The template parameter list
// is not visible in this view; the body uses T, Map (thread map), SmemLayout,
// Policy_ and `mode`.
//
// Interface visible here:
//   * Fetch() loads one k-tile into register Fragments, predicated per access;
//   * Store() writes those Fragments into shared memory through smem_data_;
//   * Advance() adjusts the source pointer between k-tiles;
//   * ClearSmem() zero-fills this thread's shared-memory slots.
template
struct GmemIteratorSm70 {

    using ThreadMap = Map;

    using AccessType = Array;
    using Pointer    = get_pointer_type;

    using Policy = Policy_;

    // Accesses per thread along the strided (S) and contiguous (C) dimensions.
    static constexpr int ITER_S = Map::kIterS;
    static constexpr int ITER_C = Map::kIterC;

    // In indexed mode each S row is gathered through mat.idxs (see constructor).
    static constexpr Striding kMode      = mode;
    static constexpr bool     is_indexed = mode == Striding::kIndexed;

    // Current source pointer (non-indexed mode), byte-addressed.
    const char* src_data_;

    int src_offset_;
    // NOTE(review): not referenced within this struct.
    int dst_offset_;

    // This thread's starting (c, s) position from Map::get_offset.
    int offset_c_;
    int offset_s_;

    // Byte steps between consecutive C accesses, S accesses, and k-tiles.
    int src_step_c_;
    int src_step_s_;

    int src_step_k_;

    // Per-access in-bounds predicate (set in the constructor when active).
    Predicate pred_;

    // Global mask ANDed into every Fetch; see also Advance().
    bool g_mask{true};

    // Shared-memory destination. Both constructors initialize it with a null pointer;
    // presumably the caller assigns the real buffer before Store/ClearSmem — confirm.
    SmemAccessor smem_data_;

    static constexpr int2 kMK0     = cs2mk(SmemLayout::C0, SmemLayout::S0);
    // Number of thread-map steps after which the SmemLayout pattern repeats.
    static constexpr int  kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);
    static constexpr int  kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);

    // Precomputed SmemLayout offsets for one period of this thread's accesses.
    int phases_[kPeriodS][kPeriodC];

    // Per-S-row source pointers (indexed mode only).
    const char* src_data_vec_[ITER_S];

    using Fragments = AccessType[Map::kIterS][Map::kIterC];

    __device__ static constexpr int2 pack(int2 mk)
    {
        return Packing_v2::apply(mk);
    }

    __device__ static constexpr int2 to_cs(int2 mk)
    {
        return mk2cs(mk.x, mk.y);
    }

    __device__ GmemIteratorSm70(): smem_data_{Pointer{nullptr}} {};

    // mat:    source matrix (pointer, stride, optional row indices `idxs`).
    // offset: CTA tile origin in (m, k) order.
    // extent: valid region used to build the bounds predicate.
    __device__ GmemIteratorSm70(const MatrixData& mat, int2 offset, int2 extent): smem_data_{Pointer{(T*)nullptr}}
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        const Pointer data{(T*)mat.ptr.ptr};
        const int     ld = mat.ptr.stride;

        const int2 offsets = Map::get_offset(warp_id, lane_id);

        offset_c_ = offsets.x;
        offset_s_ = offsets.y;

        // auto src_ptr = reinterpret_cast((T*)data);

        // Mark each (s, c) access that lies inside `extent` (packed, (c, s) order).
        if constexpr (pred_.is_active) {
            extent = to_cs(pack(extent));
            PRAGMA_UNROLL
            for (int s = 0; s < Map::kIterS; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < Map::kIterC; ++c) {
                    int ss = offset_s_ + s * Map::kDeltaS;
                    int cc = offset_c_ + c * Map::kDeltaC;
                    if (ss < extent.y && cc < extent.x) {
                        pred_.set(s, c);
                    }
                }
            }
        }

        PRAGMA_UNROLL
        for (int s = 0; s < kPeriodS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < kPeriodC; ++c) {
                phases_[s][c] = SmemLayout::apply(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);
            }
        }

        // Indexed mode applies only the C offset here; the row comes from the gather below.
        const int src_offset = is_indexed ? offsets.x : offsets.x + offsets.y * ld;

        // Element offsets/steps scaled to bytes. NOTE(review): bitsof arguments are not
        // visible; presumably bitsof<T> / bitsof<char>.
        src_offset_ = src_offset * bitsof / bitsof;

        src_step_c_ = bitsof * Map::kDeltaC / bitsof;
        src_step_s_ = bitsof * Map::kDeltaS * ld / bitsof;

        src_step_k_ = bitsof * cs2mk(Map::kDimC, Map::kDimS * ld).y / bitsof;

        // initialize for the first tile
        if constexpr (is_indexed) {
            const int2 cta_cs = to_cs(offset);
            for (int s = 0; s < ITER_S; ++s) {
                const int  ss    = cta_cs.y + offset_s_ + s * Map::kDeltaS;
                // Row gather: use mat.idxs when present and the row is in bounds.
                const int  idx   = (mat.idxs && pred_(s, 0)) ? __ldg(mat.idxs + ss) : ss;
                const auto tmp   = data + cs2idx({cta_cs.x, idx}, ld);
                src_data_vec_[s] = reinterpret_cast((T*)tmp) + src_offset_;
            }
        }
        else {
            auto src_data = data + cs2idx(to_cs(pack(offset)), ld);
            src_data_     = reinterpret_cast((T*)src_data) + src_offset_;
        }
    }

    __device__ constexpr int _src_step_k() const
    {
        return src_step_k_;
    }

    // Zero-fills this thread's shared-memory slots. When the map is not fully aligned,
    // slots outside the Map::kDimS x Map::kDimC tile are skipped.
    // `pipe_iter` is unused here.
    __device__ void ClearSmem(int pipe_iter = 0)
    {
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                const int pred_s = offset_s_ + s * Map::kDeltaS < Map::kDimS;
                const int pred_c = offset_c_ + c * Map::kDeltaC < Map::kDimC;
                auto      ptr    = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);
                if ((Map::kAlignedC && Map::kAlignedS) || (pred_s && pred_c)) {
                    turbomind::Store(ptr, Array{});
                }
            }
        }
    }

    // Fetch() always advances src_data_ by one k-step. When g_mask is false, this undoes
    // that step so the source pointer stays where it was (non-indexed mode only).
    __device__ void Advance()
    {
        if constexpr (!is_indexed) {
            if (!g_mask) {
                src_data_ -= _src_step_k();
            }
        }
    }

    // Predicated global load of one access straight into shared memory `dst`.
    // NOTE(review): not called within this struct; Fetch() uses Copy2() instead.
    __device__ void Copy(std::true_type, T* dst, const char* __restrict__ src, bool mask)
    {
        if (mask) {
            AccessType frag;
            if constexpr (Policy_::kEvictPolicy != EvictPolicy::kEvictNormal) {
                _Ld(frag, (const T*)src);
            }
            else {
                Ldg(frag, (const T*)src);
            }
            turbomind::Store(dst, frag);
        }
    }

    // Loads one k-tile into `frags`. Each access is gated by
    // tile_mask && g_mask && pred_(s, c); masked-off fragments are left unchanged.
    // The source pointer(s) end up advanced by one k-step.
    __device__ void Fetch(Fragments& frags, bool tile_mask)
    {
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {

            if constexpr (is_indexed) {
                src_data_ = src_data_vec_[s];
            }

            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                Copy2(frags[s][c], src_data_ + src_step_c_ * c, tile_mask && g_mask && pred_(s, c));
            }

            if constexpr (is_indexed) {
                src_data_vec_[s] += _src_step_k();
            }
            else {
                src_data_ += src_step_s_;
                // After the last row: rewind the S steps and move to the next k-tile.
                if (s == Map::kIterS - 1) {
                    src_data_ -= src_step_s_ * Map::kIterS;
                    src_data_ += _src_step_k();
                }
            }
        }
    }

    // Writes `frags` to shared memory. Each address is a period-base offset i0 plus a
    // precomputed phase i1, written only where pred_(s, c) is set.
    __device__ void Store(Fragments& frags)
    {
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                // auto dst = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);

                const int i0  = SmemLayout::apply(  //
                    s / kPeriodS * kPeriodS * Map::kDeltaS,
                    c / kPeriodC * kPeriodC * Map::kDeltaC);
                const int i1  = phases_[s % kPeriodS][c % kPeriodC];
                auto      dst = &smem_data_.ptr_[i0 + i1];

                if (pred_(s, c)) {
                    turbomind::Store(dst, frags[s][c]);
                }
            }
        }
    }

    // Predicated global load into a register fragment. Uses _Ld (streaming hint) unless
    // the policy's evict policy is kEvictNormal, in which case Ldg is used.
    __device__ void Copy2(AccessType& frag, const char* __restrict__ src, bool mask)
    {
        if (mask) {
            if constexpr (Policy_::kEvictPolicy != EvictPolicy::kEvictNormal) {
                _Ld(frag, (const T*)src);
            }
            else {
                Ldg(frag, (const T*)src);
            }
        }
    }
};

// Adapter that exposes GmemIteratorSm70 through the nested alias template
// `Type`.
template
struct IteratorSm70 {
    template
    using Type = GmemIteratorSm70;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/iterator_sm80.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/cp_async.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/predicate.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include 
#include 

namespace turbomind::gemm {

// Global -> shared memory iterator used by the SM80 GEMM kernels.
//
// Each thread handles Map::kIterS x Map::kIterC accesses of type AccessType,
// starting at the (offset_s_, offset_c_) position given by Map::get_offset.
// Prefetch() issues asynchronous copies (CpAsync -> CP_ASYNC) into shared
// memory and then advances the source pointer(s) by _src_step_k(). In
// Striding::kIndexed mode, each s-iteration keeps its own source pointer in
// src_data_vec_, and its rows may be gathered through mat.idxs.
template
struct GmemIteratorSm80 {

    using ThreadMap = Map;

    using AccessType = Array;
    using Pointer    = get_pointer_type;

    using Policy = Policy_;

    static constexpr int ITER_S = Map::kIterS;
    static constexpr int ITER_C = Map::kIterC;

    static constexpr Striding kMode      = mode;
    static constexpr bool     is_indexed = mode == Striding::kIndexed;

    // Current source pointer. It is a `const char*`, so every step added to
    // it is a byte offset. In indexed mode it is reloaded from src_data_vec_.
    const char* src_data_;

    int src_offset_;  // this thread's offset from the tile base, applied in the constructor
    int dst_offset_;  // NOTE(review): not referenced anywhere in this struct

    int offset_c_;  // this thread's first access position along C (from Map::get_offset)
    int offset_s_;  // this thread's first access position along S

    int src_step_c_;  // source-pointer step between consecutive c-iterations
    int src_step_s_;  // source-pointer step between consecutive s-iterations

    int src_step_k_;  // source-pointer step from one k-tile to the next, see _src_step_k()

    Predicate pred_;  // per-(s, c) bounds predicates, set in the constructor when pred_.is_active

    bool g_mask{true};  // ANDed into every access in Prefetch(); see also Advance()

    SmemAccessor smem_data_;  // shared-memory destination; its ptr_ is indexed in Prefetch()

    // NOTE(review): kMK0 is not referenced within this struct.
    static constexpr int2 kMK0     = cs2mk(SmemLayout::C0, SmemLayout::S0);
    // Number of c- / s-iterations spanning one SmemLayout::C0 / S0 period.
    static constexpr int  kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);
    static constexpr int  kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);

    // SmemLayout offsets of this thread's accesses within the first period.
    // They are precomputed in the constructor; Prefetch() adds the offset of
    // each later period to them.
    int phases_[kPeriodS][kPeriodC];

    // Per-s source pointers, used only when is_indexed.
    const char* src_data_vec_[ITER_S];

    // L2 cache policy created in the constructor (SM80 builds only) when
    // Policy::kEvictPolicy != EvictPolicy::kEvictNormal.
    uint64_t cache_policy_{};

    // Apply the operand packing (Packing_v2) to an (m, k) pair.
    __device__ static constexpr int2 pack(int2 mk)
    {
        return Packing_v2::apply(mk);
    }

    // Convert an (m, k) pair to (c, s) order via mk2cs.
    __device__ static constexpr int2 to_cs(int2 mk)
    {
        return mk2cs(mk.x, mk.y);
    }

    // Default-constructed iterator with a null shared-memory destination.
    __device__ GmemIteratorSm80(): smem_data_{Pointer{nullptr}} {};

    // mat:    source matrix: base pointer, stride, and an optional row-index
    //         table `idxs` that is read in indexed mode
    // offset: tile origin, given in (m, k) order and converted with to_cs
    // extent: valid (m, k) extent, used to build the bounds predicates
    // The shared-memory destination is left null by this constructor.
    __device__ GmemIteratorSm80(const MatrixData& mat, int2 offset, int2 extent): smem_data_{Pointer{(T*)nullptr}}
    {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        const Pointer data{(T*)mat.ptr.ptr};
        const int     ld = mat.ptr.stride;

        const int2 offsets = Map::get_offset(warp_id, lane_id);

        offset_c_ = offsets.x;
        offset_s_ = offsets.y;

        // Mark the accesses that fall inside `extent` (converted to (c, s)).
        if constexpr (pred_.is_active) {
            extent = to_cs(pack(extent));
            PRAGMA_UNROLL
            for (int s = 0; s < Map::kIterS; ++s) {
                PRAGMA_UNROLL
                for (int c = 0; c < Map::kIterC; ++c) {
                    int ss = offset_s_ + s * Map::kDeltaS;
                    int cc = offset_c_ + c * Map::kDeltaC;
                    if (ss < extent.y && cc < extent.x) {
                        pred_.set(s, c);
                    }
                }
            }
        }

        // Cache the layout offsets of this thread's accesses in the first period.
        PRAGMA_UNROLL
        for (int s = 0; s < kPeriodS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < kPeriodC; ++c) {
                phases_[s][c] = SmemLayout::apply(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);
            }
        }

        // In indexed mode the row part comes from the per-row pointers below,
        // so only the C offset is applied here.
        const int src_offset = is_indexed ? offsets.x : offsets.x + offsets.y * ld;

        // The offsets/steps below are scaled by a bitsof ratio, presumably to
        // convert element units into the byte units of the char source pointer.
        src_offset_ = src_offset * bitsof / bitsof;

        src_step_c_ = bitsof * Map::kDeltaC / bitsof;
        src_step_s_ = bitsof * Map::kDeltaS * ld / bitsof;

        src_step_k_ = bitsof * cs2mk(Map::kDimC, Map::kDimS * ld).y / bitsof;

        // Initialize for the first tile
        if constexpr (is_indexed) {
            // Each s-iteration gets its own pointer. Row `ss` is remapped
            // through mat.idxs when a table is present and pred_(s, 0) holds.
            const int2 cta_cs = to_cs(offset);
            for (int s = 0; s < ITER_S; ++s) {
                const int  ss    = cta_cs.y + offset_s_ + s * Map::kDeltaS;
                const int  idx   = (mat.idxs && pred_(s, 0)) ? __ldg(mat.idxs + ss) : ss;
                const auto tmp   = data + cs2idx({cta_cs.x, idx}, ld);
                src_data_vec_[s] = reinterpret_cast((T*)tmp) + src_offset_;
            }
        }
        else {
            // A single pointer at the (packed) tile origin.
            auto src_data = data + cs2idx(to_cs(pack(offset)), ld);
            src_data_     = reinterpret_cast((T*)src_data) + src_offset_;
        }

#if TURBOMIND_ARCH_SM80
        // Create an L2 evict_first policy, which CpAsync then passes to CP_ASYNC.
        if constexpr (Policy::kEvictPolicy != EvictPolicy::kEvictNormal) {
            asm volatile("createpolicy.fractional.L2::evict_first.b64 %0;\n" : "=l"(cache_policy_) :);
        }
#endif
    }

    // Source-pointer step used to move to the next k-tile.
    __device__ constexpr int _src_step_k() const
    {
        return src_step_k_;
    }

    // Store a value-initialized Array{} (presumably zeros) into each of this
    // thread's shared-memory slots. Unless Map is aligned in both C and S,
    // slots outside Map::kDimS x Map::kDimC are skipped.
    // `pipe_iter` is not used by this implementation.
    __device__ void ClearSmem(int pipe_iter = 0)
    {
        PRAGMA_UNROLL
        for (int s = 0; s < Map::kIterS; ++s) {
            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                const int pred_s = offset_s_ + s * Map::kDeltaS < Map::kDimS;
                const int pred_c = offset_c_ + c * Map::kDeltaC < Map::kDimC;
                auto      ptr    = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);
                if ((Map::kAlignedC && Map::kAlignedS) || (pred_s && pred_c)) {
                    Store(ptr, Array{});
                }
            }
        }
    }

    // Issue async copies for s-iterations in [begin, begin + count), clipped
    // to Map::kIterS. Each copy's mask is tile_mask && g_mask && pred_(s, c)
    // and is forwarded to CP_ASYNC. In non-indexed mode, the pointer steps by
    // src_step_s_ per s-iteration. After the last s-iteration it is rewound
    // and advanced by _src_step_k().
    __device__ void Prefetch(int begin, int count, bool tile_mask)
    {
        PRAGMA_UNROLL
        for (int s = begin; s < begin + count && s < Map::kIterS; ++s) {

            if constexpr (is_indexed) {
                src_data_ = src_data_vec_[s];
            }

            PRAGMA_UNROLL
            for (int c = 0; c < Map::kIterC; ++c) {
                // auto dst = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);

                // Destination = layout offset of the enclosing period (i0)
                //             + cached in-period offset (i1).
                const int i0  = SmemLayout::apply(  //
                    s / kPeriodS * kPeriodS * Map::kDeltaS,
                    c / kPeriodC * kPeriodC * Map::kDeltaC);
                const int i1  = phases_[s % kPeriodS][c % kPeriodC];
                auto      dst = &smem_data_.ptr_[i0 + i1];

                CpAsync(std::true_type{}, dst, src_data_ + src_step_c_ * c, tile_mask && g_mask && pred_(s, c));
            }

            if constexpr (is_indexed) {
                src_data_vec_[s] += _src_step_k();
            }
            else {
                src_data_ += src_step_s_;
                if (s == Map::kIterS - 1) {
                    src_data_ -= src_step_s_ * Map::kIterS;
                    src_data_ += _src_step_k();
                }
            }
        }
    }

    // Issue all Map::kIterS s-iterations.
    __device__ void Prefetch(bool tile_mask)
    {
        Prefetch(0, Map::kIterS, tile_mask);
    }

    // Non-indexed mode only. When g_mask is false, undo the _src_step_k()
    // advance that Prefetch() applies after its last s-iteration.
    __device__ void Advance()
    {
        if constexpr (!is_indexed) {
            if (!g_mask) {
                src_data_ -= _src_step_k();
            }
        }
    }

    // Issue one asynchronous copy of sizeof(AccessType) (<= 16) bytes from
    // `src` to the shared-memory address `dst` through CP_ASYNC, forwarding
    // `mask`. cache_policy_ is also passed when an evict policy is configured.
    // On non-SM80 builds only `assert(TURBOMIND_ARCH_SM80)` is compiled.
    __device__ void CpAsync(std::true_type, T* dst, const char* __restrict__ src, bool mask)
    {
#if TURBOMIND_ARCH_SM80
        constexpr int size = sizeof(AccessType);
        static_assert(size <= 16);

        constexpr int prefetch_size = std::min(256, size * Map::kWarpThreadC);

        auto ptr = cast_smem_ptr_to_uint(dst);

        static constexpr auto cache_op = GetCacheOp::value;

        if constexpr (Policy::kEvictPolicy != EvictPolicy::kEvictNormal) {
            CP_ASYNC::apply(ptr, src, cache_policy_, mask);
        }
        else {
            CP_ASYNC::apply(ptr, src, mask);
        }
#else
        assert(TURBOMIND_ARCH_SM80);
#endif
    }
};

// Adapter that exposes GmemIteratorSm80 through the nested alias template
// `Type`.
template
struct IteratorSm80 {
    template
    using Type = GmemIteratorSm80;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/iterator_sm90.h
================================================
#pragma once

#include 
#include 

namespace turbomind::gemm {

// Global-memory iterator for SM90 that issues 2D TMA loads (cute
// SM90_TMA_LOAD_2D / SM90_TMA_LOAD_MULTICAST_2D) through a CUtensorMap
// descriptor. Each Step() loads at the current coordinate `offset_` and then
// advances it by `step_`.
template
struct GmemIteratorSm90 {

    const CUtensorMap* desc_ptr_;  // tensor-map descriptor used for every load
    int2               offset_;    // current (x, y) load coordinate
    int2               step_;      // coordinate increment applied after each Step()

    __device__ GmemIteratorSm90(const CUtensorMap* desc_ptr, int2 offset, int2 step)
    {
        desc_ptr_ = desc_ptr;
        offset_   = offset;
        step_     = step;
    }

    // Issue one 2D TMA load of the tile at offset_ into smem_ptr, then advance
    // offset_ by step_.
    //   mbar_ptr:   barrier handed to the TMA copy (presumably the mbarrier
    //               that tracks completion — confirm at the call site)
    //   mask:       multicast mask; only used when multicast > 1
    //   cache_hint: forwarded unchanged to the copy
    __device__ void Step(uint64_t* mbar_ptr, void* smem_ptr, uint16_t mask, uint64_t cache_hint = 0)
    {
        if constexpr (multicast > 1) {
            cute::SM90_TMA_LOAD_MULTICAST_2D::copy(
                desc_ptr_, mbar_ptr, mask, cache_hint, smem_ptr, offset_.x, offset_.y);
        }
        else {
            cute::SM90_TMA_LOAD_2D::copy(desc_ptr_, mbar_ptr, cache_hint, smem_ptr, offset_.x, offset_.y);
        }
        offset_.x += step_.x;
        offset_.y += step_.y;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm70_884_16.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch/config_sm70_s884.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm70_s884;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm70_s884 kernel configurations built on Config_F16.
// `Add` is called once per configuration. The group sits in an
// `if constexpr (1)` scope, so changing 1 to 0 compiles it out.
void Registry::sm70_884_16()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_F16;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm70_884_4.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch/config_sm70_s884.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm70_s884;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm70_s884 kernel configurations built on Config_U4_d,
// Config_U4_g and Config_MXF4, one `if constexpr (1)` group per config
// (change 1 to 0 to compile a group out). `Add` is called once per
// configuration.
void Registry::sm70_884_4()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_U4_d;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_U4_g;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm70_884_8.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch/config_sm70_s884.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm70_s884;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm70_s884 kernel configurations built on Config_E4M3.
// `Add` is called once per configuration.
void Registry::sm70_884_8()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_E4M3;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm75_16816_16.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm75_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm75_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm75_s16816 kernel configurations built on Config_F16.
// `Add` is called once per configuration.
void Registry::sm75_16816_16()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_F16;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm75_16816_4.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch/config_sm75_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm75_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm75_s16816 kernel configurations built on Config_U4_d,
// Config_U4_g and Config_MXF4, one `if constexpr (1)` group per config.
// `Add` is called once per configuration.
void Registry::sm75_16816_4()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_U4_d;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_U4_g;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm75_16816_8.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm75_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm75_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm75_s16816 kernel configurations built on Config_E4M3.
// `Add` is called once per configuration.
void Registry::sm75_16816_8()
{
    if constexpr (1) {
        // clang-format off
        using Cg = Config_E4M3;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm80_16816_16.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm80_s16816 kernel configurations built on Config_F16_g, in
// two `if constexpr (1)` groups. `Add` is called once per configuration.
// NOTE(review): the meaning of the trailing numeric / `*` comments is not
// recorded here.
void Registry::sm80_16816_16()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_F16_g;
        Add>();
        Add>(); // 10
        Add>();
        Add>(); // 6
        Add>();
        Add>();
        Add>(); // 2
        Add>();
        Add>(); // *
        Add>();
        Add>(); // 4
        Add>();
        Add>();
        Add>(); // 10
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_F16_g;
        Add>();
        Add>(); // 10
        Add>();
        Add>(); // 6
        Add>();
        Add>();
        Add>(); // 2
        Add>();
        Add>(); // *
        Add>();
        Add>(); // 4
        Add>();
        Add>();
        Add>(); // 10
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm80_16816_4.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm80_s16816 kernel configurations built on Config_U4_d,
// Config_U4_g and Config_MXF4, one `if constexpr (1)` group per config.
// `Add` is called once per configuration; commented-out `Add` lines are
// disabled entries. The `Cd` alias in the MXF4 group is only used by such a
// commented-out line.
// NOTE(review): the meaning of the trailing numeric comments is not
// recorded here.
void Registry::sm80_16816_4()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_U4_d;
        // Add>(); // 0/0
        Add>(); // 30/3
        Add>(); // --/20
        Add>();  // --/13
        Add>();  // 21/13
        Add>();  // 6/6

        Add>();  // --/3
        Add>();  // 13/13
        Add>();  // 14/10
        Add>();  // 2/2

        Add>(); // --/21
        Add>(); // 27/13
        Add>();  // 8/5
        Add>();  // 7/5
        Add>();  // 6/7
        Add>();

        Add>(); // 1/1
        Add>();  // 1/1
        Add>();  // 4/4
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_U4_g;
        Add>();  // 10 + 5 + 4 + 10 + 10, 37
        Add>();  // 1 + 6 + 4 + 4 + 2, 3
        Add>();  // 7 + 4 + 6 + 2 + 4, 26
        Add>();  // 18
        Add>();  // 2
        Add>();  // 1 + 2 + 2 + 2 + 2, 2
        Add>();  // 9
        Add>();  // 22
        Add>();  // 8
        Add>();  // 1 + 13 + 9 + 13 + 7, 7
        Add>();  // 12 + 2 + 6 + 2 + 8, 42
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using Cd = Config_MXF4;
        // Add>();

        using Cg = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        using C8 = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm80_16816_8.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register the sm80_s16816 kernel configurations built on Config_E4M3.
// `Add` is called once per configuration. The `Cd` alias is only used by
// the commented-out `Add` line.
void Registry::sm80_16816_8()
{
    if constexpr (1) {
        // clang-format off
        using Cd = Config_E4M3;
        // Add>();

        using Cg = Config_E4M3;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        using C8 = Config_E4M3;
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm90_16816_16.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register kernel configurations for sm90 that reuse the sm80_s16816
// configs (Config_F16_g), in two `if constexpr (1)` groups.
// `Add` is called once per configuration.
void Registry::sm90_16816_16()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_F16_g;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_F16_g;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm90_16816_4.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register kernel configurations for sm90 that reuse the sm80_s16816
// configs Config_U4_d, Config_U4_g and Config_MXF4, one `if constexpr (1)`
// group per config. `Add` is called once per configuration. The `Cd` alias
// is only used by a commented-out `Add` line.
void Registry::sm90_16816_4()
{
    if constexpr (1) {
        // clang-format off
        using C = Config_U4_d;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using C = Config_U4_g;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }

    if constexpr (1) {
        // clang-format off
        using Cd = Config_MXF4;
        // Add>();

        using Cg = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        using C8 = Config_MXF4;
        Add>();
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm90_16816_8.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h"
#include "src/turbomind/kernels/gemm/registry.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

using namespace sm80_s16816;
using namespace cache_policy;
using S = cache_policy::Stream;
using D = cache_policy::Default;

// Register kernel configurations for sm90 that reuse the sm80_s16816
// config Config_E4M3. `Add` is called once per configuration. The `Cd`
// alias is only used by the commented-out `Add` line.
void Registry::sm90_16816_8()
{
    if constexpr (1) {
        // clang-format off
        using Cd = Config_E4M3;
        // Add>();

        using Cg = Config_E4M3;
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();
        Add>();

        using C8 = Config_E4M3;
        Add>();
        Add>();
        Add>();
        // clang-format on
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel/sm90_64n32_8.cu
================================================

#include 

#include "src/turbomind/kernels/gemm/registry.h"

// We need modifiable TMA, which is added in 12.3
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 3))

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h"
#include "src/turbomind/kernels/gemm/gemm_universal_sm90_v5.h"
#include "src/turbomind/kernels/gemm/kernel_impl_sm90.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Register the SM90 kernel objects (built from the gemm_universal_sm90_v3 /
// v5 headers included above): twelve `Add` calls in four groups of three.
// This definition is compiled only with CUDA 12.3 or newer; older compilers
// get an empty definition instead.
void Registry::sm90_64n32_8()
{
    Add(std::make_unique>>());
    Add(std::make_unique>>());
    Add(std::make_unique>>());

    Add(std::make_unique>>());
    Add(std::make_unique>>());
    Add(std::make_unique>>());

    Add(std::make_unique>>());
    Add(std::make_unique>>());
    Add(std::make_unique>>());

    Add(std::make_unique>>());
    Add(std::make_unique>>());
    Add(std::make_unique>>());
}

}  // namespace turbomind::gemm

#else

namespace turbomind::gemm {

// Fallback for CUDA compilers older than 12.3: registers nothing.
void Registry::sm90_64n32_8() {}

}  // namespace turbomind::gemm

#endif


================================================
FILE: src/turbomind/kernels/gemm/kernel.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 

#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Return whether a kernel built for striding `a` can serve an operand whose
// requested striding is `b` (is_feasible passes the kernel's mode as `a`).
//   kBlocked accepts kBlocked and kFlat;
//   kIndexed accepts kFlat, kBlocked and kIndexed;
//   any other mode accepts only itself.
bool accept(Striding a, Striding b)
{
    switch (a) {
        case Striding::kBlocked:
            return b == Striding::kBlocked || b == Striding::kFlat;
        case Striding::kIndexed:
            return b == Striding::kFlat || b == Striding::kBlocked || b == Striding::kIndexed;
        default:
            return a == b;
    }
}

bool Kernel::is_feasible(const GemmDesc& desc) const noexcept
{
    constexpr bool debug = 0;

    if constexpr (debug)
        printf("S\n");

    // printf("%d %d\n", desc.arch, desc_.arch);

    if (!is_arch_compatible(desc_.arch, desc.arch)) {
        return false;
    }

    if constexpr (debug)
        printf("S0\n");

    if (std::tie(desc.order_a, desc.order_b, desc.order_c) != std::tie(desc_.order_a, desc_.order_b, desc_.order_c)) {
        return false;
    }

    if (desc.group_axis >= 0 && desc.group_axis != desc_.group_axis) {
        return false;
    }

    if (!(accept(desc_.striding_a, desc.striding_a)     //
          && accept(desc_.striding_b, desc.striding_b)  //
          && accept(desc_.striding_c, desc.striding_c))) {
        return false;
    }

    if constexpr (debug)
        printf("A\n");

    if (std::tie(desc.type_a, desc.type_b, desc.type_c) != std::tie(desc_.type_a, desc_.type_b, desc_.type_c)) {
        return false;
    }

    if constexpr (debug) {
        printf("B\n");
        printf("%X %X %X %X\n", desc.pack_a, desc_.pack_a, desc.pack_u, desc_.pack_u);
    }

    if (std::tie(desc.pack_a, desc.pack_u) != std::tie(desc_.pack_a, desc_.pack_u)) {
        return false;
    }

    if constexpr (debug) {
        printf("C\n");
        printf("%X %X %X %X\n", desc.pack_b, desc_.pack_b, desc.pack_v, desc_.pack_v);
    }

    if (std::tie(desc.pack_b, desc.pack_v) != std::tie(desc_.pack_b, desc_.pack_v)) {
        return false;
    }

    if constexpr (debug)
        printf("D\n");

    if (desc.quant_a.type != desc_.quant_a.type || desc.quant_a.group_size != desc_.quant_a.group_size) {
        return false;
    }

    if constexpr (debug)
        printf("E\n");

    if (desc.quant_b.type != desc_.quant_b.type || desc.quant_b.group_size != desc_.quant_b.group_size) {
        return false;
    }

    if constexpr (debug)
        printf("F\n");

    if (desc.m % desc_.align.x || desc.n % desc_.align.y || desc.k % desc_.align.z) {
        return false;
    }

    if constexpr (debug)
        printf("success\n");

    return true;
}

//  mm:     m * n * k,     m * k,     n * k,     m * n
// Bmm: b * m * n * k, b * m * k, b * n * k, b * m * n
// Gmm: S $ M * n * k, S $ M * k, S $ n * k, S $ M * n

// Build a descriptive kernel name from desc_. The name has '_'-separated
// fields: architecture ("sm" + arch / 10), the A/B/C types (A and B with an
// optional quantization suffix), the A/B/C order letters ('n' = column-major,
// 't' otherwise), the A/B/C striding tags, the CTA tile, stages, cluster
// shape, op class, MMA tile, an optional "<m|n>group" tag, the C tile
// ("c..."), the alignment ("a...") and the A/B cache policies.
std::string Kernel::GetName() const
{
    const auto& d = desc_;

    const auto order_tag = [](const auto& order) { return order == kColMajor ? 'n' : 't'; };

    std::stringstream name;

    name << "sm" << d.arch / 10;

    name << "_" << to_string(d.type_a);
    if (d.quant_a) {
        name << to_string(d.quant_a);
    }
    name << "_" << to_string(d.type_b);
    if (d.quant_b) {
        name << to_string(d.quant_b);
    }
    name << "_" << to_string(d.type_c);

    name << "_" << order_tag(d.order_a) << order_tag(d.order_b) << order_tag(d.order_c);
    name << "_" << to_string(d.striding_a) << to_string(d.striding_b) << to_string(d.striding_c);

    name << "_" << d.cta_tile.x << "x" << d.cta_tile.y << "x" << d.cta_tile.z;
    name << "_" << d.stages;
    name << "_" << d.cluster_shape.x << "x" << d.cluster_shape.y;
    name << "_" << to_string(d.op_class);
    name << "_" << d.mma_tile.x << "x" << d.mma_tile.y << "x" << d.mma_tile.z;

    if (d.group_axis >= 0) {
        name << "_" << "mn"[d.group_axis] << "group";
    }

    name << "_c" << d.c_tile.x << "x" << d.c_tile.y;
    name << "_a" << d.align.x << "x" << d.align.y << "x" << d.align.z;
    name << "_" << d.policy_a << d.policy_b;

    return name.str();
}

class TransposedKernel: public Kernel {
public:
    explicit TransposedKernel(Kernel& kernel): kernel_(&kernel)
    {
        desc_ = kernel.desc();
        info_ = kernel.info();

        desc_.transpose = !desc_.transpose;
    }

    int Launch(const Operation&    operation,
               float               alpha,
               const void*         A,
               const MatrixLayout& Adesc,
               const void*         U,
               const MatrixLayout& Udesc,
               const void*         B,
               const MatrixLayout& Bdesc,
               const void*         V,
               const MatrixLayout& Vdesc,
               float               beta,
               const void*         C,
               const MatrixLayout& Cdesc,
               void*               D,
               const MatrixLayout& Ddesc,
               int                 swizzle,
               int                 splits,
               Workspace&          workspace,
               cudaStream_t        stream) override
    {
        return kernel_->Launch(transpose(operation),
                               alpha,
                               B,
                               transpose(Bdesc),
                               V,
                               transpose(Vdesc),
                               A,
                               transpose(Adesc),
                               U,
                               transpose(Udesc),
                               beta,
                               C,
                               transpose(Cdesc),
                               D,
                               transpose(Ddesc),
                               swizzle,
                               splits,
                               workspace,
                               stream);
    }

    bool is_feasible(const GemmDesc& desc) const noexcept override
    {
        return kernel_->is_feasible(desc);
    }

    int GetMaxSwizzle(const int4& shape) const override
    {
        return kernel_->GetMaxSwizzle(shape);
    }

    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const override
    {
        return kernel_->GetMaxSplits(shape, swizzle, bsize, psize);
    }

private:
    Kernel* kernel_;
};

// Wrap `kernel` in a TransposedKernel adapter. The adapter keeps a raw
// pointer to `kernel` and does not own it, so `kernel` must outlive the
// returned object.
std::unique_ptr transpose(Kernel& kernel)
{
    return std::make_unique(kernel);
}

template
inline static bool cmp(const int3& a, const int3& b, Op op)
{
    // Apply `op` (e.g. std::less<> or std::equal_to<>) to the (x, y, z)
    // components as tuples, i.e. a lexicographic comparison.
    const auto lhs = std::tie(a.x, a.y, a.z);
    const auto rhs = std::tie(b.x, b.y, b.z);
    return op(lhs, rhs);
}

// Partition `specs` into clusters of equivalent launch specs.
//
// Specs are stable-sorted by the key
//   (cta_tile, mma_tile, [policy_a, policy_b], [max_active_ctas], splits)
// where the bracketed parts are included only when `param.cache_policy` /
// `param.max_active_ctas` are set. Each cluster is a maximal run of specs that
// compare equivalent under that ordering. Clusters are returned in sort order,
// and within a cluster the input order is kept (stable sort). Specs are copied
// into the result. Empty input yields an empty result.
std::vector> Cluster(const std::vector& specs, const ClusteringParam& param)
{
    if (specs.empty()) {
        return {};
    }

    std::vector ptrs;  // pointer into `specs`
    ptrs.reserve(specs.size());
    for (auto& s : specs) {
        ptrs.push_back(&s);
    }

    // Strict weak ordering over launch specs (see the key above).
    auto less = [&](const LaunchSpec* u, const LaunchSpec* v) {
        const auto& a = u->kernel->desc();
        const auto& b = v->kernel->desc();
        if (!cmp(a.cta_tile, b.cta_tile, std::equal_to<>{})) {
            return cmp(a.cta_tile, b.cta_tile, std::less<>{});
        }
        if (!cmp(a.mma_tile, b.mma_tile, std::equal_to<>{})) {
            return cmp(a.mma_tile, b.mma_tile, std::less<>{});
        }
        if (param.cache_policy) {
            const auto pa = std::tie(a.policy_a, a.policy_b);
            const auto pb = std::tie(b.policy_a, b.policy_b);
            if (pa != pb) {
                return pa < pb;
            }
        }
        if (param.max_active_ctas) {
            // Distinct names so the descriptor references above are not shadowed.
            const auto& info_u = u->kernel->info();
            const auto& info_v = v->kernel->info();
            if (info_u.max_active_ctas != info_v.max_active_ctas) {
                return info_u.max_active_ctas < info_v.max_active_ctas;
            }
        }
        return u->splits < v->splits;
    };

    std::stable_sort(ptrs.begin(), ptrs.end(), less);

    // Two specs are equivalent when neither orders before the other.
    auto equal = [&](const LaunchSpec* u, const LaunchSpec* v) {
        return !less(u, v) && !less(v, u);
    };

    std::vector> clusters{{*ptrs[0]}};

    // Compare each spec with the first member of the current cluster. Tracking a
    // pointer (instead of an `int` index) avoids narrowing from size_t.
    const LaunchSpec* head = ptrs[0];
    for (size_t i = 1; i < ptrs.size(); ++i) {
        if (equal(head, ptrs[i])) {
            clusters.back().push_back(*ptrs[i]);
        }
        else {
            clusters.push_back({*ptrs[i]});
            head = ptrs[i];
        }
    }

    return clusters;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 

#include 

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Pair of cost figures for a kernel. Both default to zero, so a
// default-constructed metric never holds indeterminate values; aggregate
// initialization (`KernelMetric{mio, mma}`) still works.
struct KernelMetric {
    int64_t mio_cost{0};  // presumably memory-I/O cost — TODO confirm units
    int64_t mma_cost{0};  // presumably MMA (math) cost — TODO confirm units
};

// Abstract interface of a GEMM kernel. Subclasses fill in `desc_` (static
// properties such as tile shapes, stages and data types) and `info_` (CUDA
// function attributes, dynamic smem size, occupancy, name), and implement Launch().
class Kernel {
public:
    Kernel(): desc_{}, info_{} {}

    virtual ~Kernel() = default;

    // Enqueue the GEMM on `stream`, writing D; C is combined using alpha/beta.
    // U and V are auxiliary operands paired with A and B (the implementations
    // derive quant_a / quant_b from them). The visible implementations return 0.
    virtual int Launch(const Operation&    operation,
                       float               alpha,
                       const void*         A,
                       const MatrixLayout& Adesc,
                       const void*         U,
                       const MatrixLayout& Udesc,
                       const void*         B,
                       const MatrixLayout& Bdesc,
                       const void*         V,
                       const MatrixLayout& Vdesc,
                       float               beta,
                       const void*         C,
                       const MatrixLayout& Cdesc,
                       void*               D,
                       const MatrixLayout& Ddesc,
                       int                 swizzle,
                       int                 splits,
                       Workspace&          workspace,
                       cudaStream_t        stream) = 0;

    // true if this kernel can be used to compute the gemm
    virtual bool is_feasible(const GemmDesc& desc) const noexcept;

    // Largest swizzle supported for a problem of the given shape.
    virtual int GetMaxSwizzle(const int4& shape) const = 0;

    // Largest usable split count, given the barrier (`bsize`) and partials
    // (`psize`) workspace sizes available.
    virtual int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const = 0;

    // Read-only views of the static description and the runtime info.
    const KernelDesc& desc() const noexcept { return desc_; }
    const KernelInfo& info() const noexcept { return info_; }

    // Convenience accessors.
    int3 cta_tile_size() const noexcept { return desc_.cta_tile; }
    int3 warp_tile_size() const noexcept { return desc_.mma_tile; }  // reports desc_.mma_tile
    int  chunk_size_k() const noexcept { return info_.chunk_size_k; }
    int  stages() const noexcept { return desc_.stages; }
    bool split_k() const noexcept { return desc_.split_k; }
    int  arch() const noexcept { return desc_.arch; }

    // Shared memory per CTA: static (from the function attributes) plus dynamic.
    int smem_size() const noexcept
    {
        const auto total_bytes = info_.attr.sharedSizeBytes + info_.dynamic_smem_size;
        return static_cast<int>(total_bytes);
    }

    std::string name() const { return info_.name; }

protected:
    // Builds the kernel's display name (defined out of line).
    std::string GetName() const;

    KernelDesc desc_;  // static description, filled in by subclasses
    KernelInfo info_;  // runtime attributes, filled in by subclasses
};

// Optional keys that Cluster() adds to its grouping key, on top of tile shapes
// and split count. Both default to false, so a default-constructed param never
// holds indeterminate values; aggregate initialization still works.
struct ClusteringParam {
    bool cache_policy{false};     // also key on the A/B eviction policies (policy_a, policy_b)
    bool max_active_ctas{false};  // also key on the kernels' max active CTAs per SM
};

// Group `specs` into clusters of launch specs that compare equivalent on CTA
// tile, MMA tile and split count. Cache policies and/or max active CTAs are also
// compared when enabled in `param`. Each cluster is a contiguous run of the specs
// after a stable sort by that key. Empty input yields an empty result.
std::vector> Cluster(const std::vector& specs, const ClusteringParam& param);

// Return an adapter Kernel that forwards to `kernel` with A/B (and U/V) swapped
// and every layout transposed. The adapter presumably holds a non-owning pointer,
// so `kernel` must outlive it — confirm against the definition.
std::unique_ptr transpose(Kernel& kernel);

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel_impl.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"

#include "src/turbomind/kernels/gemm/context.h"
#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/gemm_universal.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Concrete Kernel for the `gemm_kernel` entry point, configured entirely by the
// compile-time `Gemm` type. The constructor records the kernel's static
// properties in `desc_` and queries the CUDA runtime for `info_`.
template
class KernelImpl: public Kernel {
public:
    // import frequently used constants
    static constexpr int CTA_M = Gemm::CTA_M;
    static constexpr int CTA_N = Gemm::CTA_N;
    static constexpr int CTA_K = Gemm::CTA_K;

    using Impl  = typename Gemm::Impl;
    using Sched = typename Gemm::Scheduler;

    using OpA = typename Gemm::OperandA;
    using OpB = typename Gemm::OperandB;
    using OpU = typename Gemm::OperandU;
    using OpV = typename Gemm::OperandV;

    KernelImpl()
    {
        desc_.order_a = OpA::kOrder;
        // Launch() transposes Bdesc before use, so the order advertised here is
        // presumably the transpose of operand B's internal order.
        desc_.order_b = transpose(OpB::kOrder);
        desc_.order_c = Gemm::kOrderC;

        desc_.type_a = data_type_v;
        desc_.type_b = data_type_v;
        desc_.type_c = data_type_v;

        using IterA = typename OpA::GmemIter;
        using IterB = typename OpB::GmemIter;

        desc_.striding_a = IterA::kMode;
        desc_.striding_b = IterB::kMode;
        desc_.striding_c = Gemm::Epilogue::kMode;

        desc_.pack_a = OpA::kPack;
        desc_.pack_b = OpB::kPack;
        desc_.pack_u = OpU::kPack;
        desc_.pack_v = OpV::kPack;

        // No quantization by default. It is set only when the U / V operand's
        // shared-memory layout holds more than one element.
        desc_.quant_a = QuantDesc{};
        desc_.quant_b = QuantDesc{};

        if constexpr (OpU::SmemLayout::kSize > 1) {
            desc_.quant_a = QuantDesc{QuantType::kDefault, OpU::kGroupSize};
        }

        if constexpr (OpV::SmemLayout::kSize > 1) {
            desc_.quant_b = QuantDesc{QuantType::kDefault, OpV::kGroupSize};
        }

        desc_.cta_tile = {Gemm::CTA_M, Gemm::CTA_N, Gemm::CTA_K};
        desc_.mma_tile = {Impl::MMA_Map::kGroupM, Impl::MMA_Map::kGroupN, Impl::MMA_Map::kGroupK};

        info_.chunk_size_k = Gemm::kChunkSizeK;

        // Alignment requirements along (m, n, k); presumably in elements — confirm.
        desc_.align.x = OpA::kOrder == kColMajor ? IterA::ThreadMap::kAccessC : 1;
        desc_.align.y = OpB::kOrder == kColMajor ? IterB::ThreadMap::kAccessC : 1;
        desc_.align.z = Gemm::CTA_K;

        desc_.policy_a = (int)IterA::Policy::kEvictPolicy;
        desc_.policy_b = (int)IterB::Policy::kEvictPolicy;
        desc_.c_tile   = {Gemm::Epilogue::TM, Gemm::Epilogue::TN};
        desc_.op_class = Impl::kOpClass;

        desc_.cluster_shape = {1, 1};

        auto func = gemm_kernel;

        // NOTE(review): the cudaError_t results of the runtime calls below are ignored.
        cudaFuncGetAttributes(&info_.attr, func);

        info_.dynamic_smem_size = sizeof(typename Gemm::SharedStorage);

        // Kernels needing more than 48 KiB of dynamic shared memory must opt in.
        if (info_.dynamic_smem_size > (48 << 10)) {
            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);
        }

        // Number of CTAs of this kernel that can be resident on one SM.
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &info_.max_active_ctas, func, Impl::WARPS * WARP_SIZE, info_.dynamic_smem_size);

        desc_.stages     = Impl::Stages;
        desc_.split_k    = Gemm::kSplitK;
        desc_.group_axis = Sched::group_axis;

        desc_.arch = Gemm::Arch::value;

        // Set last; presumably GetName() derives the name from the fields above.
        info_.name = GetName();
    }

    // Enqueue the GEMM on `stream`. Problem size: m x n from the output D, k from
    // A's columns, l = Ddesc.num clamped to at least 1. Always returns 0.
    // NOTE(review): the <<<>>> launch result is not checked here — confirm callers do.
    int Launch(const Operation&    operation,
               float               alpha,
               const void*         A,
               const MatrixLayout& _Adesc,
               const void*         U,
               const MatrixLayout& Udesc,
               const void*         B,
               const MatrixLayout& _Bdesc,
               const void*         V,
               const MatrixLayout& _Vdesc,
               float               beta,
               const void*         C,
               const MatrixLayout& Cdesc,
               void*               D,
               const MatrixLayout& Ddesc,
               int                 swizzle,
               int                 splits,
               Workspace&          workspace,
               cudaStream_t        stream) override
    {
        MatrixLayout Adesc = _Adesc;

        const int m = Ddesc.rows;
        const int n = Ddesc.cols;
        const int k = Adesc.cols;
        const int l = std::max(1, Ddesc.num);

        // Local helper: view a layout transposed (swap rows/cols and flip the order).
        auto transpose = [](MatrixLayout x) {
            std::swap(x.rows, x.cols);
            x.order = gemm::transpose(x.order);
            return x;
        };

        // B and V are consumed transposed; A and U are used as given.
        MatrixLayout Bdesc = transpose(_Bdesc);
        MatrixLayout Vdesc = transpose(_Vdesc);

        // Clamp the requested split count to what the workspace and K extent allow.
        auto max_splits = GetMaxSplits({m, n, k, l}, swizzle, workspace.barriers_size, workspace.partials_size);

        Sched sched{{m, n, k, l}, swizzle, std::min(splits, max_splits)};
        sched.offsets_ = Ddesc.offsets;

        using Ta = typename Gemm::Ta;
        using Tb = typename Gemm::Tb;
        using Tc = typename Gemm::Tc;

        // Debug-only dump of thread maps and smem sizes (compiled out).
        if constexpr (0) {
            [[maybe_unused]] static const int _ = [] {
                std::cout << "A:\n";
                Print(typename Gemm::OperandA::GmemIter::ThreadMap{});
                std::cout << "\nB:\n";
                Print(typename Gemm::OperandB::GmemIter::ThreadMap{});
                if constexpr (!std::is_same_v) {
                    std::cout << "\nU:\n";
                    Print(typename Gemm::OperandU::GmemIter::ThreadMap{});
                }
                if constexpr (!std::is_same_v) {
                    std::cout << "\nV:\n";
                    Print(typename Gemm::OperandV::GmemIter::ThreadMap{});
                }
                printf("warp count: %d\n", Impl::WARPS);
                Print_(typename Gemm::Impl::MMA_Map{});

                printf("C:\n");
                Print(typename Gemm::Epilogue::Map{});

                std::cout << "Smem for mainloop: " << sizeof(Gemm::SharedStorage::mainloop) << "\n";
                std::cout << "Smem for epilogue: " << sizeof(Gemm::SharedStorage::epilogue) << "\n";

                return 0;
            }();
        }

        // Gated-SiLU epilogue is enabled when that bit is set in operation.epilogue.
        const bool silu_act = ((int)operation.epilogue & (int)Epilogue::kGatedSilu);

        // Partials buffer: same shape as D, leading dimension recomputed via mk2cs.
        MatrixLayout Pdesc = Ddesc;
        Pdesc.ld           = mk2cs(Pdesc.rows, Pdesc.cols).x;

        MatrixCombination_v3 combin_mat{to_param((void*)C, Cdesc), alpha, beta};

        // Epilogue: writes D, combines with C via alpha/beta, and uses the workspace
        // partials/barriers (sized by GetWorkspaceSize).
        EpilogueParam epilogue{to_param((void*)D, Ddesc),
                               to_param((void*)workspace.partials, Pdesc),
                               (int*)workspace.barriers,
                               combin_mat,
                               silu_act};

        // std::cout << Adesc.offsets << " " << Adesc.idxs << "\n";

        GemmParam param{
            to_param((void*)A, Adesc),
            to_param((void*)B, Bdesc),
            to_param((void*)U, Udesc),
            to_param((void*)V, Vdesc),
        };

        // Grid comes from the scheduler; each block has Impl::WARPS warps.
        const auto grid  = sched.get_grid_shape();
        const auto block = Gemm::Impl::WARPS * WARP_SIZE;

        // std::cout << info_.name << " " << splits << " " << swizzle << " " << sched.tiles_[0] << " " <<
        // sched.tiles_[1]
        //           << std::endl;
        // std::cout << grid.x << " " << grid.y << " " << grid.z << "\n";

        gemm_kernel<<>>(param, epilogue, sched);

        return 0;
    }

    // Workspace bytes for `tiles` output tiles: one int barrier and a CTA_M x CTA_N
    // float partial tile per tile. kSerial is always true, so the sizes do not scale
    // with `splits`.
    std::array GetWorkspaceSize(int tiles, int splits) const
    {
        static constexpr bool kSerial = true;

        size_t barriers_size = sizeof(int) * tiles;
        size_t partials_size = sizeof(float) * CTA_M * CTA_N * tiles;

        if constexpr (!kSerial) {
            barriers_size *= splits;
            partials_size *= splits;
        }

        return {barriers_size, partials_size};
    }

    // Returns 1 unless the kernel supports split-k and `bsize`/`psize` cover the
    // workspace for all (padded) tiles. In that case it returns the number of
    // K chunks, cdiv(k, kChunkSizeK).
    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const override
    {
        if (!Gemm::kSplitK) {
            return 1;
        }

        const auto& [m, n, k, l] = shape;

        Sched sched{{m, n, k, l}, swizzle};  // for getting padded tiles

        const auto& [a, b] = GetWorkspaceSize(sched.tiles_[0] * sched.tiles_[1], 1);

        if (bsize >= a && psize >= b) {
            // Serial split-k requires workspace for 1 split only
            // But it can't exceed num of k chunks
            return cdiv(k, Gemm::kChunkSizeK);
        }
        else {
            return 1;
        }
    }

    // Largest swizzle the scheduler supports for this problem shape.
    int GetMaxSwizzle(const int4& shape) const override
    {
        const auto& [m, n, k, l] = shape;

        auto swizzle = Sched{{m, n, k, l}}.get_max_swizzle();
        // std::cout << m << " " << n << " " << k << " " << l << " " << swizzle << "\n";
        return swizzle;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/kernel_impl_sm90.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "cute/util/debug.hpp"
#include "src/turbomind/core/check.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/gemm/context.h"
#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/epilogue.h"
#include "src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/matrix_ptr.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/thread_group_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

#include "src/turbomind/kernels/gemm/tma.h"

#include "src/turbomind/utils/cuda_utils.h"

// Chooses the symbol name of the SM90 GEMM entry point defined below:
// 1 -> `cutlass_gemm_kernel_sm90`, 0 (default) -> `gemm_kernel_sm90`.
// Only the name changes; the kernel body is the same either way.
#define TM_GEMM_CUTLASS_NAME 0

#if TM_GEMM_CUTLASS_NAME
#define gemm_kernel_name cutlass_gemm_kernel_sm90
#else
#define gemm_kernel_name gemm_kernel_sm90
#endif

namespace turbomind::gemm {

// Dynamic shared memory for the SM90 kernel. Its size is supplied at launch
// through `config.dynamicSmemBytes`.
extern __shared__ char smem_buf[];

// Device entry point for the SM90 GEMM. `Kernel` is presumably the template
// parameter here (it must provide CTA_SIZE, Arch and Scheduler), not the Kernel
// base class. The parameter order must match the host-side cudaLaunchKernelEx
// call in KernelImplSm90::Launch, which passes the output D as `param_C`.
// __launch_bounds__(CTA_SIZE, 1): at most CTA_SIZE threads per block, and at
// least one resident block per SM requested.
template
__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_name(const __grid_constant__ CUtensorMap tm_a,
                                                                        const __grid_constant__ CUtensorMap tm_b,
                                                                        const __grid_constant__ CUtensorMap tm_c,
                                                                        const __grid_constant__ CUtensorMap tm_u,
                                                                        const __grid_constant__ CUtensorMap tm_v,
                                                                        const MatrixParam                   param_A,
                                                                        const MatrixParam                   param_B,
                                                                        const MatrixParam                   param_U,
                                                                        const MatrixParam                   param_V,
                                                                        const MatrixParam                   param_C,
                                                                        typename Kernel::Scheduler          sched,
                                                                        void* tensormap_buf)
{

#if __CUDA_ARCH__
    // Device compilation only. For architectures the kernel is not compatible
    // with, the body compiles to nothing.
    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {
        Kernel kernel;
        kernel(tm_a,
               tm_b,
               tm_c,
               tm_u,
               tm_v,
               param_A,
               param_B,
               param_U,
               param_V,
               param_C,
               sched,
               (CUtensorMap*)tensormap_buf,
               smem_buf);
    }
#endif
}

// Concrete Kernel for the SM90 entry point (`gemm_kernel_name`), configured by
// the compile-time `Gemm` type. Launches go through cudaLaunchKernelEx with a
// thread-block-cluster attribute and TMA tensor-map descriptors. Several desc_
// fields are fixed values; the trailing comments show the generic KernelImpl
// expressions they stand in for.
template
class KernelImplSm90: public Kernel {
public:
    // import frequently used constants
    static constexpr int TILE_M = Gemm::TILE_M;
    static constexpr int TILE_N = Gemm::TILE_N;
    static constexpr int TILE_K = Gemm::TILE_K;

    static constexpr auto is_grouped_gemm = Gemm::is_grouped_gemm;

    KernelImplSm90()
    {
        desc_.order_a = kRowMajor;  // m, k
        desc_.order_b = kColMajor;  // k, n
        desc_.order_c = kRowMajor;

        desc_.type_a = data_type_v;
        desc_.type_b = data_type_v;
        desc_.type_c = data_type_v;

        // Grouped GEMM uses blocked striding for all operands; otherwise flat.
        desc_.striding_a = {is_grouped_gemm ? Striding::kBlocked : Striding::kFlat};  // IterA::kMode;
        desc_.striding_b = {is_grouped_gemm ? Striding::kBlocked : Striding::kFlat};  // IterB::kMode;
        desc_.striding_c = {is_grouped_gemm ? Striding::kBlocked : Striding::kFlat};  // Gemm::Epilogue::kMode;

        desc_.pack_a = {};  // OpA::kPack;
        desc_.pack_b = {};  // OpB::kPack;
        desc_.pack_u = {};  // OpU::kPack;
        desc_.pack_v = {};  // OpV::kPack;

        // Fixed quantization descriptors: QuantType::kK for A, QuantType::kB for B,
        // group size 128 for both.
        desc_.quant_a = QuantDesc{QuantType::kK, 128};
        desc_.quant_b = QuantDesc{QuantType::kB, 128};

        desc_.cta_tile = {TILE_M, TILE_N, TILE_K};
        desc_.mma_tile = {1, 1, 1};

        info_.chunk_size_k = Gemm::TILE_K;

        desc_.align.x = 1;  // OpA::kOrder == kColMajor ? IterA::ThreadMap::kAccessC : 1;
        desc_.align.y = 1;  // OpB::kOrder == kColMajor ? IterB::ThreadMap::kAccessC : 1;
        desc_.align.z = 1;  // Gemm::TILE_K;

        desc_.policy_a = 0;                 // (int)IterA::Policy::kEvictPolicy;
        desc_.policy_b = 0;                 // (int)IterB::Policy::kEvictPolicy;
        desc_.c_tile   = {TILE_M, TILE_N};  // {Gemm::Epilogue::TM, Gemm::Epilogue::TN};
        desc_.op_class = OpClass::kGMMA_s64n16;

        desc_.cluster_shape = {Gemm::Cluster::M, Gemm::Cluster::N};

        info_.dynamic_smem_size = Gemm::kSmemSize;

        desc_.stages     = Gemm::Stages;
        desc_.split_k    = 1;  // Gemm::kSplitK;
        desc_.group_axis = is_grouped_gemm ? 0 : -1;

        desc_.arch = Gemm::Arch::value;

        auto func = gemm_kernel_name;

        // NOTE(review): the cudaError_t results of the runtime calls below are ignored.
        cudaFuncGetAttributes(&info_.attr, func);

        // Kernels needing more than 48 KiB of dynamic shared memory must opt in.
        if (info_.dynamic_smem_size > (48 << 10)) {
            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);
        }

        // Allow non-portable cluster sizes.
        // NOTE(review): the CUDA docs describe this attribute as a 0/1 flag; 16 is
        // passed here and the return code is ignored — confirm it is accepted.
        if (1) {
            cudaFuncSetAttribute(func, cudaFuncAttributeNonPortableClusterSizeAllowed, 16);
        }

        cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &info_.max_active_ctas, func, Gemm::CTA_SIZE, info_.dynamic_smem_size);

        // Cached once at construction. Presumably this is the SM count of the device
        // current at that time — confirm.
        sm_count_ = getSMCount();

        info_.name = GetName();
    }

    // Enqueue the GEMM on `stream` via cudaLaunchKernelEx. Launch failures are
    // reported through TM_CHECK_EQ; otherwise returns 0. `splits` is not used.
    int Launch(const Operation&    operation,
               float               alpha,
               const void*         A,
               const MatrixLayout& _Adesc,
               const void*         U,
               const MatrixLayout& Udesc,
               const void*         B,
               const MatrixLayout& _Bdesc,
               const void*         V,
               const MatrixLayout& _Vdesc,
               float               beta,
               const void*         C,
               const MatrixLayout& Cdesc,
               void*               D,
               const MatrixLayout& Ddesc,
               int                 swizzle,
               int                 splits,
               Workspace&          workspace,
               cudaStream_t        stream) override
    {
        using Sched = typename Gemm::Scheduler;

        MatrixLayout Adesc = _Adesc;

        [[maybe_unused]] const int m = Ddesc.rows;
        [[maybe_unused]] const int n = Ddesc.cols;
        [[maybe_unused]] const int k = Adesc.cols;

        // std::cout << "M: " << m << ", N: " << n << ", K: " << k << "\n";

        auto transpose = [](MatrixLayout x) {
            std::swap(x.rows, x.cols);
            x.order = gemm::transpose(x.order);
            return x;
        };

        // (K, N) -> (N, K)
        MatrixLayout Bdesc = transpose(_Bdesc);
        MatrixLayout Vdesc = transpose(_Vdesc);

        // Build the tile scheduler. `swizzle` is overwritten here with
        // Sched::get_log_tile(tiles, 1 << swizzle); the incoming value appears to be
        // a log2 — confirm. The scheduler's next-cluster counter lives in
        // workspace.flags (must be non-null); offsets come from A's layout.
        auto sched = [&] {
            const int2 tiles = get_tiled_shape(m, n, TILE_M, TILE_N);
            const int4 shape{m, n, k, Adesc.num};

            swizzle = Sched::get_log_tile(tiles, 1 << swizzle);

            Sched sched{};
            sched.init(shape, swizzle, {TILE_M, TILE_N, TILE_K});

            sched.next_cluster_id_ = TM_CHECK_NOTNULL(workspace.flags);

            sched.offsets_ = Adesc.offsets;

            return sched;
        }();

        constexpr int kMulticastA = Gemm::kMulticastA;
        constexpr int kMulticastB = Gemm::kMulticastB;
        constexpr int kMulticastU = Gemm::kMulticastU;

        constexpr int kTileM = Gemm::TILE_M;
        constexpr int kTileN = Gemm::TILE_N;

        // Reset the counter that sched.next_cluster_id_ points at.
        if (Gemm::Scheduler::is_dynamic) {
            check_cuda_error(cudaMemsetAsync(workspace.flags, 0, sizeof(int), stream));
        }

        // TMA descriptors. Box sizes are tile extents divided by the multicast factors.
        // std::cout << "A: " << Adesc << "\n";
        auto tm_a = make_2d_tma_desc((void*)A, Adesc, {kTileM / kMulticastA, TILE_K}, CU_TENSOR_MAP_SWIZZLE_128B);

        // For grouped GEMM the B descriptor gets a null base address; presumably
        // per-group descriptors are handled on device via workspace.tensormaps — confirm.
        // std::cout << "B: " << Bdesc << "\n";
        auto tm_b = make_2d_tma_desc(Gemm::is_grouped_gemm ? nullptr : (void*)B,
                                     Bdesc,
                                     {kTileN / kMulticastB, TILE_K},
                                     CU_TENSOR_MAP_SWIZZLE_128B);

        // std::cout << "C: " << Cdesc << "\n";
        using LayoutC = typename Gemm::LayoutC;
        auto tm_c     = make_2d_tma_desc((void*)C, Cdesc, {LayoutC::S0, LayoutC::C0}, get_tma_swizzle(Gemm::kSwizzleC));

        // U descriptor only when U is provided; otherwise passed zero-initialized.
        CUtensorMap tm_u{};
        if (U) {
            // std::cout << "U: " << Udesc << "\n";
            tm_u = make_2d_tma_desc((void*)U, Udesc, {Gemm::kBoxU / kMulticastU, 1}, CU_TENSOR_MAP_SWIZZLE_NONE);
        }

        // tm_v is never populated (the construction is commented out) and is
        // passed zero-initialized.
        CUtensorMap            tm_v{};
        [[maybe_unused]] uint2 box_v{};
        if (V) {
            // std::cout << "V: " << Vdesc << "\n";
            // box_v = {(uint32_t)round_up(cdiv(k, 128), 4), 2};
            // std::cout << "V: " << Vdesc << ", box: " << box_v.x << "," << box_v.y << "\n";
            // tm_v = make_2d_tma_desc((void*)V, Vdesc, {box_v.y, box_v.x}, CU_TENSOR_MAP_SWIZZLE_NONE);
        }

        const int sm_count = sm_count_;

        static constexpr int cluster_size = Gemm::kClusterSize;

        // Grid is sized to the SM count rounded down to whole clusters, not to the
        // tile count. This suggests CTAs loop over tiles via `sched`.
        auto       grid  = sm_count / cluster_size * cluster_size;
        const auto block = Gemm::CTA_SIZE;

        cudaLaunchConfig_t config{};
        config.gridDim          = grid;
        config.blockDim         = block;
        config.dynamicSmemBytes = info_.dynamic_smem_size;
        config.stream           = stream;

        auto func = gemm_kernel_name;

        // One-time debug query of the max potential cluster size; result unused.
        [[maybe_unused]] static bool _ = [&] {
            int max_cluster_size = 0;
            cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, func, &config);
            // std::cout << "max cluster size: " << max_cluster_size << "\n";
            return false;
        }();

        // Launch with clusters of shape (cluster_size, 1, 1).
        cudaLaunchAttribute attrs[1];

        attrs[0].id               = cudaLaunchAttributeClusterDimension;
        attrs[0].val.clusterDim.x = cluster_size;
        attrs[0].val.clusterDim.y = 1;
        attrs[0].val.clusterDim.z = 1;

        config.attrs    = attrs;
        config.numAttrs = std::size(attrs);

        // Limit the grid to the clusters that can be co-resident.
        // NOTE(review): the occupancy query's error is ignored; if it fails,
        // max_active_cluster presumably stays 0 and the launch below fails with a
        // zero-sized grid (caught by TM_CHECK_EQ).
        int max_active_cluster{};
        cudaOccupancyMaxActiveClusters(&max_active_cluster, func, &config);
        config.gridDim = std::min(config.gridDim.x, max_active_cluster * cluster_size);

        // std::cout << "max active cluster: " << max_active_cluster << "\n";

        // std::cout << "swizzle: " << swizzle << ", split: " << splits << "\n";

        // Argument order must match the gemm_kernel_name parameters; D is passed in
        // the `param_C` slot.
        auto ec = cudaLaunchKernelEx(&config,
                                     func,
                                     tm_a,
                                     tm_b,
                                     tm_c,
                                     tm_u,
                                     tm_v,
                                     to_param((void*)A, Adesc),
                                     to_param((void*)B, Bdesc),
                                     to_param((void*)U, Udesc),
                                     to_param((void*)V, Vdesc),
                                     to_param((void*)D, Ddesc),
                                     sched,
                                     workspace.tensormaps);
        TM_CHECK_EQ(ec, cudaSuccess) << cudaGetErrorString(ec);

        return 0;
    }

    // Workspace bytes for `tiles` output tiles: one int barrier and a
    // TILE_M x TILE_N float partial tile per tile (independent of `splits`).
    std::array GetWorkspaceSize(int tiles, int splits) const
    {
        static constexpr bool kSerial = true;

        size_t barriers_size = sizeof(int) * tiles;
        size_t partials_size = sizeof(float) * TILE_M * TILE_N * tiles;

        if constexpr (!kSerial) {
            barriers_size *= splits;
            partials_size *= splits;
        }

        return {barriers_size, partials_size};
    }

    // Split-k is not supported by this kernel.
    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const override
    {
        return 1;
    }

    // Largest log-tile swizzle for the (m, n) tile grid, capped by 1 << 10.
    int GetMaxSwizzle(const int4& shape) const override
    {
        using Map = typename Gemm::Scheduler;
        // TODO: fix tiled shape
        const auto tiles = get_tiled_shape(shape.x, shape.y, TILE_M, TILE_N);
        return Map::get_log_tile(tiles, 1 << 10);
    }

    // Requires an exact striding match for A, B and C before the generic checks.
    bool is_feasible(const GemmDesc& desc) const noexcept override
    {
        if (desc.striding_a != desc_.striding_a) {
            return false;
        }
        if (desc.striding_b != desc_.striding_b) {
            return false;
        }
        if (desc.striding_c != desc_.striding_c) {
            return false;
        }
        return Kernel::is_feasible(desc);
    }

private:
    int sm_count_ = 0;  // set once in the constructor from getSMCount()
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/mainloop_sm70.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/thread_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include 

namespace turbomind::gemm {

template
struct GroupIter {

    // Must be a positive power of two. Stages == 0 would pass the bit trick alone
    // but make `% Stages` a division by zero.
    static_assert(Stages > 0 && (Stages & (Stages - 1)) == 0, "Stages must be a positive power of two");

    int iter_ = 0;

    // Step the counter cyclically through [0, Stages).
    __device__ void Advance()
    {
        iter_ = (iter_ + 1) % Stages;
    }

    // True whenever the counter is at 0, i.e. once every Stages calls to Advance().
    // Const so it can also be queried on const objects.
    __device__ constexpr explicit operator bool() const
    {
        return iter_ == 0;
    }
};

// Single-stage specialization: no counter is needed. Advance() is a no-op and the
// conversion is always true, matching the primary template with Stages == 1.
template<>
struct GroupIter<1> {
    __device__ void Advance() {}

    // Const for consistency with the primary template; usable on const objects.
    __device__ constexpr explicit operator bool() const
    {
        return true;
    }
};

template
struct SmemIter {
    Pointer pointer;  // currently selected pointer (initially `base`)
    Pointer other_;   // the alternate pointer (initially `base + Step`)

    __device__ SmemIter(Pointer base): pointer{base}, other_{base + Step} {}

    // Exchange the selected and alternate pointers, so repeated calls alternate
    // between `base` and `base + Step`.
    __device__ void Advance()
    {
        Pointer prev = other_;
        other_       = pointer;
        pointer      = prev;
    }
};

template
struct Binding {
    // Non-owning references to four related objects, grouped so they can be
    // passed around together.
    A& a;
    B& b;
    U& u;
    V& v;

    // Taking each member by reference lets class template argument deduction
    // (CTAD) infer A, B, U and V from the arguments.
    __device__ Binding(A& a_ref, B& b_ref, U& u_ref, V& v_ref): a(a_ref), b(b_ref), u(u_ref), v(v_ref) {}
};

// Inspired by
// https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/threadblock/mma_pipelined.h
//
// Register-staged GEMM mainloop. Each k-tile is fetched from global memory into
// per-thread fragments (rmem_*), stored to shared memory, and then read back one
// k-slice at a time into MMA fragments (preload + TransformA/B::apply). The
// global fetch of the next k-tile is issued at k == 0 and stored to shared
// memory only at k == ITER_K - 1, so it overlaps with the MMAs in between.
// Shared-memory hand-off is synchronised with plain __syncthreads().
template
struct MainloopSm70 {

    using MMA_Atom = typename MMA::Atom;
    using MMA_Map  = typename MMA::Map;

    using FragC = typename MMA_Atom::FragC[MMA::kMmaIterM][MMA::kMmaIterN];

    static constexpr int Stages = Stages_;

    static constexpr int CTA_M = MMA::M;
    static constexpr int CTA_N = MMA::N;
    static constexpr int CTA_K = MMA::K;

    static constexpr auto kOpClass = MMA_Atom::kOpClass;

    static constexpr int WARPS = MMA::kThreadCount / WARP_SIZE;

    using OperandA = MakeOperand;
    using OperandU = MakeOperand;

    using OperandB = MakeOperand;
    using OperandV = MakeOperand;

    using TransformA = TransformA_;
    using TransformB = TransformB_;

    using Ta = typename OperandA::Dtype;
    using Tb = typename OperandB::Dtype;
    using Tu = typename OperandU::Dtype;
    using Tv = typename OperandV::Dtype;

    using SmemLayoutA = typename OperandA::SmemLayout;
    using SmemLayoutB = typename OperandB::SmemLayout;
    using SmemLayoutU = typename OperandU::SmemLayout;
    using SmemLayoutV = typename OperandV::SmemLayout;

    using SmemCopyA = SmemCopy;
    using SmemCopyU = SmemCopy;
    using SmemCopyB = SmemCopy;
    using SmemCopyV = SmemCopy;

    using SmemAccessorA = SmemAccessor;
    using SmemAccessorB = SmemAccessor;
    using SmemAccessorU = SmemAccessor;
    using SmemAccessorV = SmemAccessor;

    using GmemIterA = typename OperandA::GmemIter;
    using GmemIterB = typename OperandB::GmemIter;
    using GmemIterU = typename OperandU::GmemIter;
    using GmemIterV = typename OperandV::GmemIter;

    // Shared-memory buffers for the four operands, each 16-byte aligned.
    struct SharedStorage {
        __align__(16) Array A;
        __align__(16) Array B;
        __align__(16) Array U;
        __align__(16) Array V;
    };

    // Make the global iterator write into the smem iterator's current buffer,
    // then move the smem iterator on to its next buffer.
    template
    __device__ void _advance_smem(GmemIter& gmem_iter, SmemIter& smem_iter)
    {
        gmem_iter.smem_data_ = smem_iter.pointer;
        smem_iter.Advance();
    }

    // zip with
    // Apply _advance_smem to each (gmem, smem) iterator pair.
    template
    __device__ void AdvanceSmemStage(BindingG& g, BindingS& s)
    {
        _advance_smem(g.a, s.a);
        _advance_smem(g.b, s.b);
        _advance_smem(g.u, s.u);
        _advance_smem(g.v, s.v);
    }

    // Clear each global iterator's current smem destination.
    template
    __device__ void ClearSmem(Binding& g)
    {
        g.a.ClearSmem();
        g.b.ClearSmem();
        g.u.ClearSmem();
        g.v.ClearSmem();
    }

    // Load one k-tile of every operand from global memory into fragments `f`;
    // `mask` is false once all k-tiles have been fetched.
    template
    __device__ void Fetch(Binding& g, Fragments& f, bool mask)
    {
        g.a.Fetch(f.a, mask);
        g.b.Fetch(f.b, mask);
        g.u.Fetch(f.u, mask);
        g.v.Fetch(f.v, mask);
    }

    // Write previously fetched fragments into each iterator's smem destination.
    template
    __device__ void Store(Binding& g, Fragments& f)
    {
        g.a.Store(f.a);
        g.b.Store(f.b);
        g.u.Store(f.u);
        g.v.Store(f.v);
    }

    // Move every global iterator to the next k-tile.
    template
    __device__ void AdvanceGmemStage(Binding& g)
    {
        g.a.Advance();
        g.b.Advance();
        g.u.Advance();
        g.v.Advance();
    }

    // Runs `tile_iter` k-tiles, accumulating into frag_C. The gmem_* iterators
    // are advanced by one k-tile per fetch_stage(); `storage` provides the
    // shared-memory buffers.
    __device__ void operator()(GmemIterA&     gmem_A,
                               GmemIterB&     gmem_B,
                               GmemIterU&     gmem_U,
                               GmemIterV&     gmem_V,
                               FragC&         frag_C,
                               int            tile_iter,
                               SharedStorage& storage)
    {
        static_assert(MMA::kAtomK == 1);

        // NOTE(review): unlike MainloopSm80_v2, UU/VV are fixed at 1 here, so
        // U/V are re-read for every k-slice (still gated by the group iterators).
        static constexpr int UU = 1;  // ceil_div(GroupSizeU_, MMA_Map::TileK);
        static constexpr int VV = 1;  // ceil_div(GroupSizeV_, MMA_Map::TileK);

        // mma_iter_x = tile_iter_x * atom_x
        typename MMA_Atom::FragA frag_A[MMA::kTileIterK][MMA::kMmaIterM];
        typename MMA_Atom::FragB frag_B[MMA::kTileIterK][MMA::kMmaIterN];

        typename SmemCopyA::Frag data_A[MMA::kTileIterK];
        typename SmemCopyB::Frag data_B[MMA::kTileIterK];
        typename SmemCopyU::Frag data_U[ceil_div(MMA::kTileIterK, UU)];
        typename SmemCopyV::Frag data_V[ceil_div(MMA::kTileIterK, VV)];

        SmemIter, SmemLayoutA::kSize, Stages> smem_A{storage.A.data()};
        SmemIter, SmemLayoutB::kSize, Stages> smem_B{storage.B.data()};
        SmemIter, SmemLayoutU::kSize, Stages> smem_U{storage.U.data()};
        SmemIter, SmemLayoutV::kSize, Stages> smem_V{storage.V.data()};

        // Staging fragments for the k-tile in flight (gmem -> rmem -> smem).
        typename GmemIterA::Fragments rmem_A;
        typename GmemIterB::Fragments rmem_B;
        typename GmemIterU::Fragments rmem_U;
        typename GmemIterV::Fragments rmem_V;

        GroupIter gmem_group_iter_U{};
        GroupIter gmem_group_iter_V{};

        auto smem_group_iter_U = gmem_group_iter_U;
        auto smem_group_iter_V = gmem_group_iter_V;

        // a separate counter tends to generate better code
        // (gmem_mask is used as a boolean: cleared once the last k-tile is fetched)
        int gmem_iter = tile_iter;
        int gmem_mask = true;

        Binding gmem_iters{gmem_A, gmem_B, gmem_U, gmem_V};
        Binding smem_iters{smem_A, smem_B, smem_U, smem_V};
        Binding rmem{rmem_A, rmem_B, rmem_U, rmem_V};

        // The "rX,wY" markers below appear to track which smem buffer preload()
        // reads (smem_*.pointer) and which one Store() writes (smem_data_).

        // r0,w_

        // Clear every stage's shared-memory buffer before use.
        PRAGMA_UNROLL
        for (int i = 0; i < Stages; ++i) {
            AdvanceSmemStage(gmem_iters, smem_iters);
            ClearSmem(gmem_iters);
        }

        // r0,w1

        __syncthreads();

        // Fetch one k-tile into `rmem`, advance the global iterators and the
        // U/V group counters, and drop the mask after the last k-tile.
        auto fetch_stage = [&](auto& rmem) {
            Fetch(gmem_iters, rmem, gmem_mask);
            AdvanceGmemStage(gmem_iters);
            gmem_group_iter_U.Advance();
            gmem_group_iter_V.Advance();
            gmem_U.g_mask = (bool)gmem_group_iter_U;
            gmem_V.g_mask = (bool)gmem_group_iter_V;
            if (--gmem_iter == 0) {
                gmem_mask = false;
            }
        };

        // Barrier so the just-stored buffer is visible, then swap read/write buffers.
        auto advance_and_wait_smem_stage = [&] {
            __syncthreads();
            AdvanceSmemStage(gmem_iters, smem_iters);
        };

        const int3 offset_mnk = MMA::get_offset(threadIdx.x);
        const int  offset_m   = offset_mnk.x;
        const int  offset_n   = offset_mnk.y;
        const int  offset_k   = offset_mnk.z;

        SmemCopyA smem_copy_A{{offset_m, offset_k}};
        SmemCopyU smem_copy_U{{offset_m, offset_k}};
        SmemCopyB smem_copy_B{{offset_n, offset_k}};
        SmemCopyV smem_copy_V{{offset_n, offset_k}};

        // Copy k-slice `k` of the current read buffer into data_[A,B,U,V];
        // U/V are only copied at group starts.
        auto preload = [&](int k) {
            smem_copy_A(smem_A.pointer, data_A[k], k);
            smem_copy_U(smem_U.pointer, data_U[k / UU], k, k % UU == 0 && (bool)smem_group_iter_U);

            smem_copy_B(smem_B.pointer, data_B[k], k);
            smem_copy_V(smem_V.pointer, data_V[k / VV], k, k % VV == 0 && (bool)smem_group_iter_V);
        };

        AdvanceSmemStage(gmem_iters, smem_iters);
        // r1,w0

        fetch_stage(rmem);  // gmem -> rmem

        Store(gmem_iters, rmem);  // rmem -> smem

        advance_and_wait_smem_stage();
        // r0,w1

        preload(0);  // smem -> data_[A,B,U,V]

        TransformA::apply(frag_A, 0, data_A, data_U, UU);
        TransformB::apply(frag_B, 0, data_B, data_V, VV);

        PRAGMA_NO_UNROLL
        for (; tile_iter > 0; --tile_iter) {
            constexpr int ITER_K = MMA::kTileIterK;
            static_assert(ITER_K > 1);

            PRAGMA_UNROLL
            for (int k = 0; k < ITER_K; ++k) {
                // The last iter, store prefetched fragments to smem
                if (k == ITER_K - 1) {
                    Store(gmem_iters, rmem);
                    advance_and_wait_smem_stage();  // swap rw
                    smem_group_iter_U.Advance();
                    smem_group_iter_V.Advance();
                }

                // Preload for next iter, smem -> data_[A,B,U,V]
                preload((k + 1) % ITER_K);

                // The first iter, issue the prefetching of next stage
                if (k == 0) {
                    fetch_stage(rmem);
                }

                // PRAGMA_UNROLL
                // for (int n = 0; n < MMA::kMmaIterN; ++n) {
                //     PRAGMA_UNROLL
                //     for (int m = 0; m < MMA::kMmaIterM; ++m) {
                //         int mm = n % 2 ? MMA::kMmaIterM - m - 1 : m;
                //         MMA_Atom::fma(frag_C[mm][n], frag_A[k][mm], frag_B[k][n], frag_C[mm][n]);
                //     }
                // }

                // Serpentine order over n: odd rows m walk n in reverse.
                PRAGMA_UNROLL
                for (int m = 0; m < MMA::kMmaIterM; ++m) {
                    PRAGMA_UNROLL
                    for (int n = 0; n < MMA::kMmaIterN; ++n) {
                        int nn = m % 2 ? MMA::kMmaIterN - n - 1 : n;
                        MMA_Atom::fma(frag_C[m][nn], frag_A[k][m], frag_B[k][nn], frag_C[m][nn]);
                    }
                }

                TransformA::apply(frag_A, (k + 1) % ITER_K, data_A, data_U, UU);
                TransformB::apply(frag_B, (k + 1) % ITER_K, data_B, data_V, VV);
            }
        }

        // Final barrier; presumably so the caller can safely reuse shared memory.
        __syncthreads();
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/mainloop_sm80_v2.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/thread_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include 

namespace turbomind::gemm {

// Cyclic counter over `Stages` consecutive steps that converts to true at the
// start of every cycle. The mainloop below uses it to gate the U/V operand
// loads (g_mask / the smem copy predicate).
template
struct GroupIter {

    // The counter wraps modulo Stages, which must be a power of two.
    static_assert((Stages & (Stages - 1)) == 0);

    int iter_ = 0;

    // Step to the next position in the cycle, wrapping back to zero.
    __device__ void Advance()
    {
        const int next = iter_ + 1;
        iter_          = next % Stages;
    }

    // True exactly when the counter is at the start of a cycle.
    __device__ constexpr explicit operator bool()
    {
        return 0 == iter_;
    }
};

// Cycle of length one: there is no state to track, so Advance() is a no-op and
// the iterator always converts to true.
template<>
struct GroupIter<1> {
    __device__ void Advance() {}

    __device__ constexpr explicit operator bool()
    {
        return true;
    }
};

// Ring-buffer cursor over `Stages` shared-memory slots spaced `Step` apart:
// Advance() moves to the next slot and wraps back to `base` after the last one.
template
struct SmemIter {
    Pointer base_;       // first slot
    Pointer pointer;     // current slot
    int     pipe_iter_;  // index of the current slot

    __device__ SmemIter(Pointer base): base_{base}, pointer{base}, pipe_iter_{} {}

    __device__ void Advance()
    {
        ++pipe_iter_;
        if (pipe_iter_ != Stages) {
            pointer = pointer + Step;
            return;
        }
        // Wrapped around: restart at the first slot.
        pipe_iter_ = 0;
        pointer    = base_;
    }
};

// Bundles references to four per-operand objects (a, b, u, v) so the mainloop
// helpers can apply the same operation to all of them in one call.
template
struct Binding {
    A&         a;
    B&         b;
    U&         u;
    V&         v;
    // Explicit constructor so class template argument deduction can infer
    // A, B, U, V from `Binding{x, y, z, w}`.
    __device__ Binding(A& a, B& b, U& u, V& v): a{a}, b{b}, u{u}, v{v} {}  // CTAD
};

// Inspired by
// https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/threadblock/mma_multistage.h
// https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/collective/sm80_mma_multistage.hpp
//
// Multistage GEMM mainloop built on the CUDA __pipeline_* async-copy
// primitives. Stages - 1 k-tiles are prefetched into a ring of shared-memory
// slots before the main loop. Each prefetched stage is one committed pipeline
// group; Wait() blocks until the oldest group has landed. When kFusePrefetch is
// set, the next stage's copies are issued in per-k-slice batches interleaved
// with the MMAs (prefetch_batch). Otherwise they are issued all at once at the
// top of each k-tile (prefetch_stage).
template
struct MainloopSm80_v2 {

    using MMA_Atom = typename MMA::Atom;
    using MMA_Map  = typename MMA::Map;

    using FragC = typename MMA_Atom::FragC[MMA::kMmaIterM][MMA::kMmaIterN];

    static constexpr int Stages = Stages_;

    static constexpr int CTA_M = MMA::M;
    static constexpr int CTA_N = MMA::N;
    static constexpr int CTA_K = MMA::K;

    static constexpr auto kOpClass = MMA_Atom::kOpClass;

    static constexpr int WARPS = MMA::kThreadCount / WARP_SIZE;

    using OperandA = MakeOperand;
    using OperandU = MakeOperand;

    using OperandB = MakeOperand;
    using OperandV = MakeOperand;

    using TransformA = TransformA_;
    using TransformB = TransformB_;

    using Ta = typename OperandA::Dtype;
    using Tb = typename OperandB::Dtype;
    using Tu = typename OperandU::Dtype;
    using Tv = typename OperandV::Dtype;

    using SmemLayoutA = typename OperandA::SmemLayout;
    using SmemLayoutB = typename OperandB::SmemLayout;
    using SmemLayoutU = typename OperandU::SmemLayout;
    using SmemLayoutV = typename OperandV::SmemLayout;

    using SmemCopyA = SmemCopy;
    using SmemCopyU = SmemCopy;
    using SmemCopyB = SmemCopy;
    using SmemCopyV = SmemCopy;

    using SmemAccessorA = SmemAccessor;
    using SmemAccessorB = SmemAccessor;
    using SmemAccessorU = SmemAccessor;
    using SmemAccessorV = SmemAccessor;

    using GmemIterA = typename OperandA::GmemIter;
    using GmemIterB = typename OperandB::GmemIter;
    using GmemIterU = typename OperandU::GmemIter;
    using GmemIterV = typename OperandV::GmemIter;

    static constexpr int kFusePrefetch = FusePrefetch_;

    // NOTE(review): with kMaxPrefetchIter == 1 every kBatchX equals ITER_S, so
    // batch k == 0 issues all copies of a stage and batches k >= 1 get a
    // non-positive count in Prefetch(g, k, mask).
    static constexpr int kMaxPrefetchIter = 1;
    // std::min(ceil_div(std::max(GmemIterA::ITER_S, GmemIterB::ITER_S), 4), MMA::kTileIterK);

    static constexpr int kBatchA = ceil_div(GmemIterA::ITER_S, kMaxPrefetchIter);
    static constexpr int kBatchB = ceil_div(GmemIterB::ITER_S, kMaxPrefetchIter);
    static constexpr int kBatchU = ceil_div(GmemIterU::ITER_S, kMaxPrefetchIter);
    static constexpr int kBatchV = ceil_div(GmemIterV::ITER_S, kMaxPrefetchIter);

    // Shared-memory buffers for the four operands, each 16-byte aligned.
    struct SharedStorage {
        __align__(16) Array A;
        __align__(16) Array B;
        __align__(16) Array U;
        __align__(16) Array V;
    };

    // Wait until at most Stages - 2 committed copy groups are still pending
    // (i.e. the oldest in-flight stage has landed), then barrier so every
    // thread observes it.
    __device__ void Wait()
    {
        __pipeline_wait_prior(Stages - 2);
        __syncthreads();
    }

    // Make the global iterator write into the smem iterator's current slot,
    // then move the smem iterator on to its next slot.
    template
    __device__ void _advance_smem(GmemIter& gmem_iter, SmemIter& smem_iter)
    {
        gmem_iter.smem_data_ = smem_iter.pointer;
        smem_iter.Advance();
    }

    // zip with
    // Apply _advance_smem to each (gmem, smem) iterator pair.
    template
    __device__ void AdvanceSmemStage(BindingG& g, BindingS& s)
    {
        _advance_smem(g.a, s.a);
        _advance_smem(g.b, s.b);
        _advance_smem(g.u, s.u);
        _advance_smem(g.v, s.v);
    }

    // Clear each global iterator's current smem destination.
    template
    __device__ void ClearSmem(Binding& g)
    {
        g.a.ClearSmem();
        g.b.ClearSmem();
        g.u.ClearSmem();
        g.v.ClearSmem();
    }

    // Issue the copies for one whole k-tile of every operand.
    template
    __device__ void Prefetch(Binding& g, bool mask)
    {
        g.a.Prefetch(mask);
        g.b.Prefetch(mask);
        g.u.Prefetch(mask);
        g.v.Prefetch(mask);
    }

    // Issue batch `k` of a k-tile's copies: iterations [k * kBatchX,
    // min((k + 1) * kBatchX, ITER_S)) of each operand's global iterator.
    template
    __device__ void Prefetch(Binding& g, int k, bool mask)
    {
        int batch_A = min((k + 1) * kBatchA, GmemIterA::ITER_S) - k * kBatchA;
        int batch_B = min((k + 1) * kBatchB, GmemIterB::ITER_S) - k * kBatchB;
        int batch_U = min((k + 1) * kBatchU, GmemIterU::ITER_S) - k * kBatchU;
        int batch_V = min((k + 1) * kBatchV, GmemIterV::ITER_S) - k * kBatchV;
        g.a.Prefetch(k * kBatchA, batch_A, mask);
        g.b.Prefetch(k * kBatchB, batch_B, mask);
        g.u.Prefetch(k * kBatchU, batch_U, mask);
        g.v.Prefetch(k * kBatchV, batch_V, mask);
    }

    // Move every global iterator to the next k-tile.
    template
    __device__ void AdvanceGmemStage(Binding& g)
    {
        g.a.Advance();
        g.b.Advance();
        g.u.Advance();
        g.v.Advance();
    }

    // Runs `tile_iter` k-tiles, accumulating into frag_C. The gmem_* iterators
    // are advanced by one k-tile per committed stage; `storage` provides the
    // shared-memory ring buffers.
    __device__ void operator()(GmemIterA&     gmem_A,
                               GmemIterB&     gmem_B,
                               GmemIterU&     gmem_U,
                               GmemIterV&     gmem_V,
                               FragC&         frag_C,
                               int            tile_iter,
                               SharedStorage& storage)
    {
        static_assert(MMA::kAtomK == 1);

        // Number of k-slices sharing one U / V fragment.
        static constexpr int UU = ceil_div(GroupSizeU_, MMA_Map::TileK);
        static constexpr int VV = ceil_div(GroupSizeV_, MMA_Map::TileK);

        // mma_iter_x = tile_iter_x * atom_x
        typename MMA_Atom::FragA frag_A[MMA::kTileIterK][MMA::kMmaIterM];
        typename MMA_Atom::FragB frag_B[MMA::kTileIterK][MMA::kMmaIterN];

        typename SmemCopyA::Frag data_A[MMA::kTileIterK];
        typename SmemCopyB::Frag data_B[MMA::kTileIterK];
        typename SmemCopyU::Frag data_U[ceil_div(MMA::kTileIterK, UU)];
        typename SmemCopyV::Frag data_V[ceil_div(MMA::kTileIterK, VV)];

        SmemIter, SmemLayoutA::kSize, Stages> smem_A{storage.A.data()};
        SmemIter, SmemLayoutB::kSize, Stages> smem_B{storage.B.data()};
        SmemIter, SmemLayoutU::kSize, Stages> smem_U{storage.U.data()};
        SmemIter, SmemLayoutV::kSize, Stages> smem_V{storage.V.data()};

        GroupIter gmem_group_iter_U{};
        GroupIter gmem_group_iter_V{};

        auto smem_group_iter_U = gmem_group_iter_U;
        auto smem_group_iter_V = gmem_group_iter_V;

        // a separate counter tends to generate better code
        // (gmem_mask is used as a boolean: cleared once the last k-tile is issued)
        int gmem_iter = tile_iter;
        int gmem_mask = true;

        Binding gmem_iters{gmem_A, gmem_B, gmem_U, gmem_V};
        Binding smem_iters{smem_A, smem_B, smem_U, smem_V};

        // Clear every slot of the shared-memory ring before use.
        PRAGMA_UNROLL
        for (int i = 0; i < Stages; ++i) {
            AdvanceSmemStage(gmem_iters, smem_iters);
            ClearSmem(gmem_iters);
        }

        // r: 0, w:s-1

        __syncthreads();

        // Issue and commit all copies for one k-tile, then advance the global
        // iterators and U/V group counters; drop the mask after the last k-tile.
        auto prefetch_stage = [&] {
            Prefetch(gmem_iters, gmem_mask);
            __pipeline_commit();
            AdvanceGmemStage(gmem_iters);
            gmem_group_iter_U.Advance();
            gmem_group_iter_V.Advance();
            gmem_U.g_mask = (bool)gmem_group_iter_U;
            gmem_V.g_mask = (bool)gmem_group_iter_V;
            if (--gmem_iter == 0) {
                gmem_mask = false;
            }
        };

        // Fused variant: issue batch k of the copies; the last batch commits the
        // group and performs the same bookkeeping as prefetch_stage.
        [[maybe_unused]] auto prefetch_batch = [&](int k) {
            Prefetch(gmem_iters, k, gmem_mask);
            if (k == MMA::kTileIterK - 1) {
                __pipeline_commit();
                AdvanceGmemStage(gmem_iters);
                gmem_group_iter_U.Advance();
                gmem_group_iter_V.Advance();
                gmem_U.g_mask = (bool)gmem_group_iter_U;
                gmem_V.g_mask = (bool)gmem_group_iter_V;
                if (--gmem_iter == 0) {
                    gmem_mask = false;
                }
            }
        };

        // Wait for the oldest in-flight stage, then rotate the ring.
        auto advance_and_wait_smem_stage = [&] {
            Wait();
            AdvanceSmemStage(gmem_iters, smem_iters);
        };

        const int3 offset_mnk = MMA::get_offset(threadIdx.x);
        const int  offset_m   = offset_mnk.x;
        const int  offset_n   = offset_mnk.y;
        const int  offset_k   = offset_mnk.z;

        SmemCopyA smem_copy_A{{offset_m, offset_k}};
        SmemCopyU smem_copy_U{{offset_m, offset_k}};
        SmemCopyB smem_copy_B{{offset_n, offset_k}};
        SmemCopyV smem_copy_V{{offset_n, offset_k}};

        // Copy k-slice `k` of the current read slot into data_[A,B,U,V]; U/V
        // are only copied on the first slice of a group and at group starts.
        auto preload = [&](int k) {
            smem_copy_A(smem_A.pointer, data_A[k], k);
            smem_copy_U(smem_U.pointer, data_U[k / UU], k, k % UU == 0 && (bool)smem_group_iter_U);

            smem_copy_B(smem_B.pointer, data_B[k], k);
            smem_copy_V(smem_V.pointer, data_V[k / VV], k, k % VV == 0 && (bool)smem_group_iter_V);
        };

        // Prologue: put Stages - 1 k-tiles in flight.
        PRAGMA_UNROLL
        for (int stage = 0; stage < Stages - 1; ++stage) {
            AdvanceSmemStage(gmem_iters, smem_iters);
            prefetch_stage();
        }
        // r:-1, w:-2

        advance_and_wait_smem_stage();
        // r: 0, w:-1

        preload(0);

        TransformA::apply(frag_A, 0, data_A, data_U, UU);
        TransformB::apply(frag_B, 0, data_B, data_V, VV);

        if constexpr (kFusePrefetch) {
            prefetch_batch(0);
        }

        PRAGMA_NO_UNROLL
        for (; tile_iter > 0; --tile_iter) {
            if constexpr (!kFusePrefetch) {
                prefetch_stage();
            }
            constexpr int ITER_K = MMA::kTileIterK;
            static_assert(ITER_K > 1);

            PRAGMA_UNROLL
            for (int k = 0; k < ITER_K; ++k) {
                // preload for next iter
                preload((k + 1) % ITER_K);

                MMA::mma_k_iter(frag_C, frag_A[k], frag_B[k], frag_C);

                if constexpr (kFusePrefetch) {
                    prefetch_batch((k + 1) % ITER_K);
                }

                // Rotate one slice early so that the preload at k == ITER_K - 1
                // (i.e. preload(0)) already reads the next stage.
                if (k + 1 == ITER_K - 1) {
                    advance_and_wait_smem_stage();
                    smem_group_iter_U.Advance();
                    smem_group_iter_V.Advance();
                }

                TransformA::apply(frag_A, (k + 1) % ITER_K, data_A, data_U, UU);
                TransformB::apply(frag_B, (k + 1) % ITER_K, data_B, data_V, VV);
            }
        }

        // Drain: commit any remaining (possibly empty) group and wait for every
        // outstanding copy before leaving the mainloop.
        __pipeline_commit();
        __pipeline_wait_prior(0);

        __syncthreads();
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/matrix_ptr.h
================================================
#pragma once

#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// A base pointer paired with its stride (leading dimension; resolve() scales it
// by a bit-width ratio, so presumably counted in elements). Aligned to 16 bytes
// so resolve() can load a whole StridedPtr from an array of them with a single
// 16-byte __ldg through a uint4 view — keep the layout unchanged.
struct __align__(16) StridedPtr
{
    void* ptr;
    int   stride;
};

// Kernel argument describing one GEMM operand, built on the host by to_param().
// When `stride` is 0, resolve() treats `ptr` as an array of per-GEMM StridedPtr
// entries (and, in indexed mode, `idxs` as an array of per-GEMM int* entries).
struct MatrixParam {
    void* ptr;
    int   stride;
    int*  offsets;  // optional per-GEMM offsets (may be null)
    int*  idxs;     // optional index data used in indexed mode (may be null)
};

// Per-GEMM operand view returned by resolve(): the effective strided pointer and
// an index array that is null unless resolve() runs in indexed mode with indices.
struct MatrixData {
    StridedPtr ptr;
    const int* idxs;
};

// Package a data pointer together with the stride (`ld`), offsets and index
// fields of a MatrixLayout into the MatrixParam consumed by the kernels.
inline MatrixParam to_param(void* ptr, MatrixLayout layout)
{
    MatrixParam param{};
    param.ptr     = ptr;
    param.stride  = layout.ld;
    param.offsets = layout.offsets;
    param.idxs    = layout.idxs;
    return param;
}

// Earlier implementation of resolve(), disabled and kept for reference only.
// Unlike the active version below, it ignores MatrixParam::offsets.
#if 0
template
__inline__ __device__ MatrixData resolve(const MatrixParam& param, int gemm_id)
{
    if constexpr (mode == Striding::kFlat) {
        return {{param.ptr, param.stride}, nullptr};
    }
    else if constexpr (mode == Striding::kBlocked) {
        StridedPtr ptr{param.ptr, param.stride};
        if (param.stride == 0) {
            (uint4&)ptr = __ldg((const uint4*)param.ptr + gemm_id);
        }
        return {ptr, nullptr};
    }
    else if constexpr (mode == Striding::kIndexed) {
        const uintptr_t idx = param.idxs ? __ldg((uintptr_t*)param.idxs + gemm_id) : 0;
        StridedPtr      ptr{param.ptr, param.stride};
        if (param.stride == 0) {
            (uint4&)ptr = __ldg((const uint4*)param.ptr + gemm_id);
        }
        return {ptr, reinterpret_cast(idx)};
    }
    else {
        static_assert(mode != mode, "Not implemented.");
        return {};
    }
}
#endif

// Resolve the operand pointer (and optional index array) for GEMM `g` according
// to the compile-time striding `mode`:
//   kFlat    - use param.ptr / param.stride as-is; no indices.
//   kBlocked - if param.stride == 0, param.ptr is an array of StridedPtr and
//              entry g is loaded; then, when offsets are given, the pointer is
//              advanced by offsets[g] strides (converted to bytes via a
//              bit-width ratio, presumably element bits / char bits).
//   kIndexed - pointer resolved as in kBlocked; the index array is param.idxs,
//              or entry g of param.idxs viewed as int*[] when stride == 0.
//              offsets[g] shifts the index array if there is one, otherwise
//              the pointer.
template
__inline__ __device__ MatrixData resolve(const MatrixParam& param, int g)
{
    StridedPtr ptr{param.ptr, param.stride};
    const int* idxs{};
    if constexpr (mode == Striding::kFlat) {
        // pass
    }
    else if constexpr (mode == Striding::kBlocked) {
        if (ptr.stride == 0) {
            // Load {ptr, stride} for GEMM g in one 16-byte read.
            (uint4&)ptr = __ldg((const uint4*)param.ptr + g);
        }  // Post-condition: ptr.stride != 0
        if (param.offsets) {
            ptr.ptr = (char*)ptr.ptr + __ldg(param.offsets + g) * (size_t)ptr.stride * bitsof / bitsof;
        }
    }
    else if constexpr (mode == Striding::kIndexed) {
        idxs = param.idxs;
        if (ptr.stride == 0) {
            (uint4&)ptr = __ldg((const uint4*)param.ptr + g);
            idxs        = idxs ? ((int**)idxs)[g] : nullptr;
        }  // Post-condition: ptr.stride != 0
        if (param.offsets) {
            const int offset = __ldg(param.offsets + g);
            if (idxs) {
                idxs += offset;
            }
            else {
                ptr.ptr = (char*)ptr.ptr + offset * (size_t)ptr.stride * bitsof / bitsof;
            }
        }
    }
    else {
        static_assert(mode != mode, "Not implemented.");
    }
    return {ptr, idxs};
}

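// Summary of resolve() for GEMM g. Columns: pitch = (param.stride != 0),
// offset = (param.offsets != nullptr), idxs = (param.idxs != nullptr); rows with
// idxs = 1 apply to kIndexed mode only. When pitch is 0, param.ptr and
// param.idxs are per-GEMM arrays (dat_ptrs, idx_ptrs), and o = param.offsets:
// p <- dat_ptrs[g]
// i <- idx_ptrs[g]

// pitch offset idxs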
//    1     0     0   -> {ptr, pitch}       , 0
//    1     0     1   -> {ptr, pitch}       , idxs
//    1     1     0   -> {ptr, pitch} + o[g], 0
//    1     1     1   -> {ptr, pitch}       , idxs + o[g]
//    0     0     0   ->       p            , 0
//    0     0     1   ->       p            , i
//    0     1     0   ->       p      + o[g], 0
//    0     1     1   ->       p            , i    + o[g]

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/moe_utils_v2.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"

namespace turbomind {

// Per-token top-k expert selection followed by a softmax over the selected
// logits. For each token ti < tokens:
//   - masks[e * tokens_padded + ti] = rank of expert e among the token's selected
//     experts, or -1 if not selected. Ranks follow ascending expert index, not
//     logit order, because the selected set is walked from its lowest bit.
//   - scales[i * tokens + ti] = softmax weight of the rank-i expert.
//   - accum[e * tiles + t] is atomically incremented by the number of tokens of
//     tile t = ti >> log_tile routed to expert e (presumably zeroed by the caller).
// NOTE(review): the selected set is held as bits of an `int` and shared_accum
// has only 32 rows, so this assumes experts <= 32 — confirm the host enforces
// it. With experts == 32, `bit *= 2` also reaches 1 << 31 on a signed int.
// NOTE(review): a top-k pass selects nothing (max_bit == 0) when no remaining
// logit compares greater than -inf (e.g. -inf or NaN entries). The second loop
// then sees lowbit == 0, computes e = 31 - __clz(0) = -1, and indexes masks and
// logits out of bounds. Confirm logits are always finite.
template
__global__ void MoeGateKernel_V2(float*       scales,  // [e,n] with e = top_k
                                 int8_t*      masks,   // [E,n], padded
                                 int*         accum,   // [E,tiles]
                                 const float* logits,  // [n,E]: logits[ti * experts + e]
                                 int          log_tile,
                                 int          tiles,
                                 int          tokens,
                                 int          tokens_padded,
                                 int          experts)
{
    constexpr int max_tiles = kMoeGateMaxTiles;

    // Brute-force per thread top-k using a flat thread mapping
    const int ti = threadIdx.x + blockIdx.x * blockDim.x;

    // Clear masks
    // (assumes the grid provides at least tokens_padded threads — TODO confirm)
    for (int e = 0; e < experts; ++e) {
        if (ti < tokens_padded) {
            masks[e * tokens_padded + ti] = -1;
        }
    }

    // Per-block token counts per (expert, tile), flushed to `accum` at the end.
    __shared__ int shared_accum[32][max_tiles];

    for (int i = threadIdx.x; i < experts * max_tiles; i += block_dim) {
        int e = i / max_tiles;
        int t = i % max_tiles;
        if (e < experts && t < tiles) {
            shared_accum[e][t] = 0;
        }
    }

    __syncthreads();

    if (ti < tokens) {

        static_assert(top_k <= 32);
        // Bit e set <=> expert e is still a candidate.
        int mask = -1;

        float max_logit = 0.f;

        // Find top-k
        // Repeated arg-max over the remaining candidates; ties go to the lowest
        // index because the comparison is strict.
        PRAGMA_UNROLL
        for (int k = 0; k < top_k; ++k) {
            int   max_bit = 0;
            float max_val = -std::numeric_limits::infinity();
            int   bit     = 1;
            for (int e = 0; e < experts; ++e) {
                const auto val = logits[ti * experts + e];
                // const auto val = logits[e * tokens + ti];
                if ((mask & bit) && val > max_val) {
                    max_bit = bit;
                    max_val = val;
                }
                bit *= 2;
            }
            mask -= max_bit;
            // The first pick is the overall maximum, used for a stable softmax.
            if (k == 0) {
                max_logit = max_val;
            }
        }

        // Now bit e set <=> expert e was selected.
        mask = ~mask;

        Array top_val;
        PRAGMA_UNROLL
        for (int i = 0; i < top_k; ++i) {
            // Isolate the lowest selected expert.
            const int lowbit = (mask & -mask);
            const int e      = 31 - __clz(lowbit);

            // printf("e = %d, ti = %d, idx = %d\n", e, ti, i);

            masks[e * tokens_padded + ti] = i;
            atomicAdd(&shared_accum[e][ti >> log_tile], 1);
            top_val[i] = logits[ti * experts + e];
            // top_val[i] = logits[e * tokens + ti];

            mask -= lowbit;
        }

        // Softmax over the selected logits, shifted by the maximum.
        float prob_sum = 0.f;
        PRAGMA_UNROLL
        for (int i = 0; i < top_k; ++i) {
            top_val[i] = expf(top_val[i] - max_logit);
            prob_sum += top_val[i];
        }

        PRAGMA_UNROLL
        for (int i = 0; i < top_k; ++i) {
            scales[i * tokens + ti] = fdividef(top_val[i], prob_sum);
        }
    }

    __syncthreads();

    // Flush this block's counts to the global per-(expert, tile) histogram.
    for (int i = threadIdx.x; i < experts * max_tiles; i += block_dim) {
        int e = i / max_tiles;
        int t = i % max_tiles;
        if (e < experts && t < tiles) {
            atomicAdd(accum + e * tiles + t, shared_accum[e][t]);
        }
    }
}

// Compacts the routing masks into flat, expert-major arrays. Block
// (tile_id, ei) = (blockIdx.x, blockIdx.y) handles token tile `tile_id` of
// expert `ei`:
//   1. offset = sum of accum over all (expert, tile) entries preceding
//      (ei, tile_id) in expert-major order, i.e. this tile's first flat slot.
//   2. Tokens with masks[ei][ti] >= 0 get consecutive flat ids in token order
//      (block-wide exclusive scan), writing
//        f2n[flat] = ti, f2E[flat] = ei, en2f[rank * tokens + ti] = flat,
//      where rank = masks[ei][ti].
// offsets[ei] receives expert ei's first flat slot. The extra row ei == experts
// (presumably launched with gridDim.y == experts + 1) only writes the grand
// total into offsets[experts] and returns.
template
__global__ void MoeScanKernel_v2(int*       f2n,      // [e*n]
                                 int*       f2E,      // [e*n]
                                 int*       en2f,     // [e,n]
                                 int*       offsets,  // [E+1]
                                 Mask*      masks,    // [E,n], padded
                                 const int* accum,    // [E,tiles]
                                 int        log_tile,
                                 int        tiles,
                                 int        tokens,
                                 int        tokens_padded,
                                 int        experts)
{
    using BlockReduce = cub::BlockReduce;
    using BlockScan   = cub::BlockScan;

    // Reduce and scan are never live at the same time, so they share storage.
    __shared__ union TempStorage {
        typename BlockReduce::TempStorage reduce;
        typename BlockScan::TempStorage   scan;
    } temp_storage;

    constexpr int vec_size = kMoeGateVecSize;

    using Vec = Array;

    const int tile_id = blockIdx.x;
    const int ei      = blockIdx.y;

    // `<=` keeps tile 0 of the sentinel row (global_tile_id == experts * tiles)
    // valid, so that block sums every entry of accum.
    const int  global_tile_id = ei * tiles + tile_id;
    const bool is_valid       = global_tile_id <= experts * tiles;

#if 0
    int vacc[4]{};
    {
        int idx = threadIdx.x;
        PRAGMA_UNROLL
        for (int i = 0; i < 4; ++i) {
            if (idx < global_tile_id) {
                vacc[i] = accum[idx];
            }
            idx += block_dim;
        }
    }

    int offset = BlockReduce{temp_storage.reduce}.Sum(vacc);
#else

    // Each thread sums a strided slice of accum[0, global_tile_id); the inner
    // `i < global_tile_id` test repeats the loop bound.
    int vacc = 0;
    for (int i = threadIdx.x; i < global_tile_id; i += block_dim) {
        if (is_valid && i < global_tile_id) {
            vacc += accum[i];
        }
    }

    int offset = BlockReduce{temp_storage.reduce}.Sum(vacc);

#endif

    // cub::BlockReduce returns the aggregate only in thread 0; broadcast it.
    __shared__ int shared_offset;

    if (threadIdx.x == 0) {
        shared_offset = offset;
        if (tile_id == 0) {
            offsets[ei] = offset;
        }
    }

    // Block-uniform exit for the sentinel row, so skipping the barrier is safe.
    if (ei == experts) {
        return;
    }

    __syncthreads();

    offset = shared_offset;

    const int token_vecs = tokens_padded / vec_size;

    const int tile_size     = 1 << log_tile;
    const int tile_vec_size = tile_size / vec_size;

    // Loop bound is padded to a multiple of block_dim so every thread takes part
    // in each BlockScan; lanes past tile_vec_end contribute nothing.
    const int tile_vec_beg    = tile_id * tile_vec_size;
    const int tile_vec_end    = std::min(tile_vec_beg + tile_vec_size, token_vecs);
    const int tile_vec_padded = tile_vec_beg + round_up(tile_vec_size, block_dim);

    // if (threadIdx.x == 0) {
    //     printf("%d %d %d\n", tile_vec_beg, tile_vec_end, tile_vec_padded);
    // }

    auto mask_ptr = (Vec*)masks + ei * token_vecs;

    for (int vi = tile_vec_beg + threadIdx.x; vi < tile_vec_padded; vi += block_dim) {

        const bool pred = vi < tile_vec_end;

        Vec data;
        fill(data, Mask{-1});
        if (pred) {
            Ldg(data, mask_ptr[vi].data());
        }

        // 1 for every token routed to this expert, 0 otherwise.
        int prefix[vec_size];
        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            prefix[i] = int(data[i] >= 0);
        }

        int block_sum = 0;

        BlockScan{temp_storage.scan}.ExclusiveSum(prefix, prefix, block_sum);
        // temp_storage.scan is reused by the next iteration's scan.
        __syncthreads();

        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            if (pred && data[i] >= 0) {
                const int flat_id = prefix[i] + offset;
                const int ti      = vi * vec_size + i;
                f2n[flat_id]      = ti;
                f2E[flat_id]      = ei;
                // No ti is generated for padded tokens so we are safe
                en2f[data[i] * tokens + ti] = flat_id;
            }
        }

        offset += block_sum;
    }
}

// Top-k MoE gating kernel.
//
// Layout: `threads_per_token = max_expert_num / items_per_thread` consecutive threads cooperate
// on one token; each of them keeps `items_per_thread` logits in registers. In the compiled
// (`#elif 1`) branch the partition is blocked: thread `ei` of a token owns experts
// [ei * items_per_thread, (ei + 1) * items_per_thread). Experts >= expert_num and tokens
// >= token_num are padded with -inf.
//
// Steps per token:
//   1. `top_k` rounds of arg-max: a per-thread scan over not-yet-selected items followed by an
//      xor-shuffle max across the token's threads. Ties go to the lowest thread and, inside a
//      thread, to the lowest item, i.e. to the lowest expert index.
//   2. Optional softmax (max-subtracted). With `norm_topk` only the selected experts enter the
//      sum (renormalized top-k weights); otherwise all experts do.
//   3. The selected (expert, weight) pairs are compacted per token in shared memory, then
//      written as `masks[expert * token_num_padded + token] = k` and
//      `scales[k * token_num + token] = weight * routed_scale`.
//   4. Per-(tile, expert) counts, with tiles of `1 << log_tile` tokens, are accumulated in
//      shared memory and added to `accum[expert * tiles + tile]` with atomics. `accum` is only
//      added to here, never cleared.
template
__global__ void MoeGateKernel_v8(float*       scales,  // [e,n]
                                 Mask*        masks,   // [E,n], padded
                                 int*         accum,   // [E,tiles]
                                 const float* logits,  // [n,E]
                                 int          log_tile,
                                 int          tiles,
                                 int          token_num,
                                 int          token_num_padded,
                                 int          expert_num,
                                 int          top_k,
                                 bool         softmax,
                                 bool         norm_topk,
                                 float        routed_scale)
{
    constexpr int max_tiles         = kMoeGateMaxTiles;
    constexpr int threads_per_token = max_expert_num / items_per_thread;  // 8
    constexpr int tokens_per_cta    = block_dim / threads_per_token;

    // We use bits in a uint32_t to represent selected experts
    static_assert(items_per_thread <= 32);
    // We use warp-level primitives for reduction
    static_assert(threads_per_token <= 32);

    static_assert((threads_per_token & (threads_per_token - 1)) == 0);

    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;

    const int ti = thread_idx / threads_per_token;
    const int ei = thread_idx % threads_per_token;

    const int bti = threadIdx.x / threads_per_token;

    const int warp_ti = threadIdx.x % WARP_SIZE / threads_per_token;

    // const int warp_offset  = thread_idx / WARP_SIZE * WARP_SIZE / threads_per_token;
    // const int block_offset = thread_idx / block_dim * block_dim / threads_per_token;

    float data[items_per_thread];
    int   idxs[items_per_thread];

    // Only the `#elif 1` branch is compiled. The other two are alternative load strategies kept
    // for reference; the `#else` branch uses `warp_offset`, whose definition above is commented out.
#if 0
    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        data[i] = -std::numeric_limits::infinity();
        idxs[i] = threads_per_token * (i / access_size * access_size) + i % access_size + ei * access_size;
    }
    if (ti < token_num) {
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; i += access_size) {
            const int e = threads_per_token * i + ei * access_size;
            if (e < expert_num) {
                Ldg((Array&)data[i], &logits[ti * expert_num + e]);
            }
        }
    }

    __shared__ union {
        struct {
            // +1 padding greatly reduced (-80%) bank conflicts
            int   shared_accum[max_tiles][max_expert_num + 1];
            float shared_scales[max_top_k][tokens_per_cta];
            int   shared_exp_id[max_top_k][tokens_per_cta];
        };
    } smem;
#elif 1
    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        data[i] = -std::numeric_limits::infinity();
        // idxs[i] = threads_per_token * (i / access_size * access_size) + i % access_size + ei * access_size;
        idxs[i] = ei * items_per_thread + i;
    }
    if (ti < token_num) {
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; i += access_size) {
            // const int e = threads_per_token * i + ei * access_size;
            const int e = ei * items_per_thread + i;
            if (e < expert_num) {
                Ldg((Array&)data[i], &logits[ti * expert_num + e]);
            }
        }
    }

    __shared__ union {
        struct {
            // +1 padding greatly reduced (-80%) bank conflicts
            int   shared_accum[max_tiles][max_expert_num + 1];
            float shared_scales[max_top_k][tokens_per_cta];
            int   shared_exp_id[max_top_k][tokens_per_cta];
        };
    } smem;
#else

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    constexpr int vecs_per_thread = items_per_thread / access_size;

    using Vec            = Array;
    constexpr int banks  = 128 / sizeof(Vec);
    constexpr int chunks = 4;  // block_dim / WARP_SIZE;

    __shared__ union {
        Vec shared_data[chunks][vecs_per_thread * WARP_SIZE / banks][banks + 1];
        struct {
            // +1 padding greatly reduced (-80%) bank conflicts
            int   shared_accum[max_tiles][max_expert_num + 1];
            float shared_scales[max_top_k][tokens_per_cta];
            int   shared_exp_id[max_top_k][tokens_per_cta];
        };
    } smem;

    __align__(16) Vec vecs[vecs_per_thread];

    {
        const int warp_end = min(warp_offset + WARP_SIZE / threads_per_token, token_num) * expert_num;
        int       p        = warp_offset * expert_num + access_size * lane_id;
        PRAGMA_UNROLL
        for (int i = 0; i < vecs_per_thread; ++i) {
            fill(vecs[i], -std::numeric_limits::infinity());
            // const int p = warp_offset * expert_num + access_size * (lane_id + i * WARP_SIZE);
            if (p < warp_end) {
                Ldg(vecs[i], &logits[p]);
            }
            p += access_size * WARP_SIZE;
        }
    }

    PRAGMA_UNROLL
    for (int c = 0; c < block_dim / WARP_SIZE; c += chunks) {
        PRAGMA_UNROLL
        for (int i = 0; i < vecs_per_thread; ++i) {
            int p = i * WARP_SIZE + lane_id;
            if (c <= warp_id && warp_id < c + chunks) {
                Store(smem.shared_data[warp_id - c][p / banks][p % banks].data(), vecs[i]);
            }
        }

        __syncwarp();

        PRAGMA_UNROLL
        for (int i = 0; i < vecs_per_thread; ++i) {
            int p = lane_id * vecs_per_thread + i;
            if (c <= warp_id && warp_id < c + chunks) {
                Load(vecs[i], smem.shared_data[warp_id - c][p / banks][p % banks].data());
            }
        }

        __syncthreads();
    }

    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        idxs[i] = ei * items_per_thread + i;
    }
    PRAGMA_UNROLL
    for (int i = 0; i < vecs_per_thread; ++i) {
        (Array&)data[i * access_size] = vecs[i];
    }

#endif

    // constexpr float kLog2e = 1.4426950408889634074;
    // if (k == 0) {
    //     PRAGMA_UNROLL
    //     for (int i = 0; i < items_per_thread; ++i) {
    //         data[i] *= kLog2e;
    //     }
    // }

    // Bit i set <=> item i of this thread is still selectable (cleared once selected).
    unsigned mask = (unsigned)-1;
    float    max_logit;

    // Number of experts this thread contributed to the token's top-k.
    int count{};

    const int warp_ti_offset = warp_ti * threads_per_token;

    // One arg-max round: select the largest remaining logit of the token and clear its bit.
    auto run = [&](int k) {
        unsigned bit     = 1;
        unsigned max_bit = 0;
        float    max_val = -std::numeric_limits::infinity();
        // local maximum
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; ++i) {
            if ((mask & bit) && data[i] > max_val) {
                max_bit = bit;
                max_val = data[i];
            }
            // weird thing that nvcc tends to use funnel shift for `bit <<= 1`
            asm("shl.b32 %0, %1, 1;\n" : "=r"(bit) : "r"(bit));
        }

        int   g_max_ei  = ei;
        float g_max_val = max_val;
        if constexpr (threads_per_token > 1) {
            // global maximum
            PRAGMA_UNROLL
            for (int m = threads_per_token / 2; m >= 1; m /= 2) {
                g_max_val = fmaxf(g_max_val, __shfl_xor_sync((uint32_t)-1, g_max_val, m));
            }
            // tie breaking: lowest thread of this token's group that holds the maximum
            const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);
            g_max_ei          = __ffs(active >> (unsigned)warp_ti_offset) - 1;
        }
        if (k == 0) {
            // First round yields the global max, reused to stabilize the softmax below.
            max_logit = g_max_val;
        }
        if (ei == g_max_ei) {
            mask -= max_bit;
            ++count;
        }
    };

    run(0);

    for (int k = 1; k < top_k; ++k) {
        run(k);
    }

    // Bit i set <=> item i was selected.
    mask = ~mask;

    int used[items_per_thread];
    {
        unsigned bit = 1;
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; ++i) {
            used[i] = (mask & bit) > 0;
            asm("shl.b32 %0, %1, 1;\n" : "=r"(bit) : "r"(bit));
        }
    }

    float sum_prob{};

    if (softmax) {
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; ++i) {
            if (!norm_topk || used[i]) {
                data[i] = expf(data[i] - max_logit);
                sum_prob += data[i];
            }
        }
        PRAGMA_UNROLL
        for (int m = threads_per_token / 2; m >= 1; m /= 2) {
            sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
        }
        sum_prob = fdividef(1.f, sum_prob);
    }
    else {
        sum_prob = 1.f;
    }

    using WarpScan = cub::WarpScan;
    __shared__ typename WarpScan::TempStorage temp_storage[tokens_per_cta];

    // Exclusive prefix of `count` over the token's threads = first output slot of this thread.
    int idx{};
    WarpScan{temp_storage[bti]}.ExclusiveSum(count, idx);

    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        if (used[i]) {
            smem.shared_exp_id[idx][bti] = idxs[i];
            smem.shared_scales[idx][bti] = data[i] * sum_prob;
            ++idx;
        }
    }

    PRAGMA_UNROLL
    for (int i = 0; i < max_tiles * max_expert_num; i += block_dim) {
        int e = (i + threadIdx.x) % max_expert_num;
        int t = (i + threadIdx.x) / max_expert_num;
        if (t < max_tiles) {
            smem.shared_accum[t][e] = 0;
        }
    }

    __syncthreads();

    // Re-map threads so that consecutive threads handle consecutive tokens (coalesced stores).
    constexpr int k_per_thread = cdiv(max_top_k, threads_per_token);

    const int bti2 = threadIdx.x % tokens_per_cta;
    const int ei2  = threadIdx.x / tokens_per_cta;
    const int ti2  = blockIdx.x * tokens_per_cta + bti2;

    PRAGMA_UNROLL
    for (int i = 0; i < k_per_thread; ++i) {
        const int idx = ei2 * k_per_thread + i;
        // `threads_per_token * k_per_thread` may exceed `max_top_k` (e.g. the 160-expert config:
        // 16 threads per token with max_top_k = 8), so `idx` can point past the rows of
        // shared_exp_id / shared_scales. Read shared memory only for real top-k slots.
        if (ti2 < token_num && idx < top_k) {
            const int   expert_id = smem.shared_exp_id[idx][bti2];
            const float scale     = smem.shared_scales[idx][bti2];

            masks[expert_id * token_num_padded + ti2] = idx;
            scales[idx * token_num + ti2]             = scale * routed_scale;
            atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1);
        }
    }

    __syncthreads();

    for (int i = 0; i < max_expert_num * max_tiles; i += block_dim) {
        int t = (threadIdx.x + i) % max_tiles;
        int e = (threadIdx.x + i) / max_tiles;
        if (e < expert_num && t < tiles) {
            atomicAdd(accum + e * tiles + t, smem.shared_accum[t][e]);
        }
    }
}

// Variable template producing a `std::integral_constant` object for a compile-time integer,
// so that dispatch code can pass compile-time sizes as ordinary arguments (used as `_Int<8>`
// in invokeMoeGate_V2).
template
inline constexpr std::integral_constant _Int{};

// Top-k gating (softmax or raw-logit scores) followed by the dispatch-table scan.
//
// 1. Chooses a tile of `1 << log_tile` tokens (at least 512), growing it until the padded token
//    count spans at most kMoeGateMaxTiles tiles.
// 2. Selects a MoeGateKernel_v8 instantiation for (experts, experts_per_token). Unsupported
//    combinations make `dispatch` return false and fail the TM_CHECK below. `norm_topk`
//    without `softmax` is rejected up front.
// 3. Launches MoeScanKernel_v2 on a (tiles, experts + 1) grid. It reads `masks`/`accum` and
//    receives f2n, f2E, en2f and offsets (presumably as outputs; its body is not in this chunk).
//
// NOTE(review): unlike invokeMoeGate_NoAuxTC, `accum` is not zeroed here although the gate
// kernel only adds to it. Presumably the caller clears it; confirm.
void invokeMoeGate_V2(int*         f2n,            // [e*n] -> n
                      int*         f2E,            // [e*n] -> E
                      int*         en2f,           // [e,n] -> n*e
                      int*         offsets,        // [E+1]
                      float*       scales,         // [e,n]
                      void*        masks,          // [E,n_padded], int8 slot index or -1
                      int*         accum,          // [E,tiles], added to by the gate kernel
                      const float* logits,         // [n,E]
                      int          tokens,         //  n
                      int          tokens_padded,  //  round_up(n, 4)
                      int          experts,        //  E
                      int          experts_per_token,
                      bool         softmax,
                      bool         norm_topk,
                      float        routed_scale,
                      cudaStream_t st)
{
    constexpr int base_log_tile = 9;

    // Grow the tile until the number of tiles fits the kernel's fixed shared-memory budget.
    int log_tile = base_log_tile;
    while (((tokens_padded + (1 << log_tile) - 1) >> log_tile) > kMoeGateMaxTiles) {
        ++log_tile;
    }
    const int tiles = ceil_div(tokens_padded, 1 << log_tile);

    // std::cout << log_tile << " " << tiles << "\n";

    auto invoke = [&](auto max_expert_num, auto top_k, auto items_per_thread, auto vec_size) {
        constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;
        constexpr int threads      = 256;
        // Each token is handled by `thrs_per_tok` threads.
        const int     blocks       = ceil_div(tokens, threads / thrs_per_tok);

        // Every byte 0xFF: all (expert, token) slots start as -1; the kernel writes k for selected ones.
        cudaMemsetAsync(masks, -1, sizeof(int8_t) * experts * tokens_padded, st);

        MoeGateKernel_v8
            <<>>(  //
                scales,
                (int8_t*)masks,
                accum,
                logits,
                log_tile,
                tiles,
                tokens,
                tokens_padded,
                experts,
                experts_per_token,
                softmax,
                norm_topk,
                routed_scale);

        return true;
    };

    if (!softmax && norm_topk) {
        // norm top-k is part of softmax impl
        TM_CHECK(0) << softmax << " " << norm_topk;
    }

    // Arguments: (max_expert_num, max_top_k, items_per_thread, vec_size).
    auto dispatch = [&] {
        if (experts <= 8) {
            if (experts_per_token <= 2) {
                return invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
            }
            else {
                return invoke(_Int<8>, _Int<8>, _Int<8>, _Int<4>);
            }
        }
        else if (experts <= 64) {
            if (experts_per_token <= 4) {
                return invoke(_Int<64>, _Int<4>, _Int<16>, _Int<4>);
            }
            else if (experts_per_token <= 8) {
                return invoke(_Int<64>, _Int<8>, _Int<16>, _Int<4>);
            }
        }
        else if (experts <= 128) {
            if (experts_per_token <= 8) {
                return invoke(_Int<128>, _Int<8>, _Int<16>, _Int<4>);
            }
        }
        else if (experts <= 160) {
            if (experts_per_token <= 8) {
                return invoke(_Int<160>, _Int<8>, _Int<10>, _Int<2>);
            }
        }
        else if (experts <= 512) {
            if (experts_per_token <= 8) {
                return invoke(_Int<512>, _Int<8>, _Int<16>, _Int<4>);
            }
        }
        return false;
    };

    auto success = dispatch();

    sync_check_cuda_error();

    TM_CHECK(success) << "unsupported moe config: expert_num=" << experts << ", top_k=" << experts_per_token
                      << ", softmax=" << softmax << ", norm_topk=" << norm_topk;

    {
        constexpr int threads = (1 << base_log_tile) / kMoeGateVecSize;
        // One extra row of blocks in y beyond the experts.
        const dim3    blocks(tiles, experts + 1);

        MoeScanKernel_v2<<>>(f2n,  //
                                                              f2E,
                                                              en2f,
                                                              offsets,
                                                              (int8_t*)masks,
                                                              accum,
                                                              log_tile,
                                                              tiles,
                                                              tokens,
                                                              tokens_padded,
                                                              experts);
    }
}

// noaux_tc: scores = scoring_func(logits), scores_for_choice = scores + correction_bias,
// top-k on scores_for_choice, weights from scores; renormalize; apply routed_scale.
// Threading: one token per block, threads cooperate over expert dimension.
//
// Dynamic shared memory must hold 2 * experts floats: `scores` followed by
// `scores_for_choice`. Top-k selection runs serially on thread 0 with 32-entry local arrays,
// so `top_k` must lie in [1, min(32, experts)] (checked by invokeMoeGate_NoAuxTC).
// Outputs per token ti and rank k with selected expert e:
//   masks[e * tokens_padded + ti] = k, scales[k * tokens + ti] = weight * routed_scale,
//   accum[e * tiles + (ti >> log_tile)] += 1.
__global__ void MoeGateNoAuxTCKernel(float*       scales,  // [top_k, tokens]
                                     int8_t*      masks,   // [experts, tokens_padded]
                                     int*         accum,   // [experts, tiles]
                                     const float* logits,  // [tokens, experts]
                                     const float* bias,    // [experts] or nullptr
                                     int          tokens,
                                     int          tokens_padded,
                                     int          experts,
                                     int          top_k,
                                     bool         norm_topk,
                                     float        routed_scale,
                                     int          log_tile,
                                     int          tiles,
                                     bool         use_sigmoid)
{
    const int ti = blockIdx.x;  // one token per block
    if (ti >= tokens) {
        return;
    }

    extern __shared__ char smem[];
    float*                 scores            = (float*)smem;
    float*                 scores_for_choice = scores + experts;

    const float* row = logits + ti * experts;

    if (use_sigmoid) {
        // Sigmoid scoring: scores[e] = 1 / (1 + exp(-logit[e]))
        for (int e = threadIdx.x; e < experts; e += blockDim.x) {
            float s              = 1.0f / (1.0f + expf(-row[e]));
            scores[e]            = s;
            scores_for_choice[e] = s + (bias ? bias[e] : 0.f);
        }
    }
    else {
        // Softmax scoring: scores[e] = exp(logit[e] - max) / sum(exp)
        float max_logit = -1e30f;
        for (int e = threadIdx.x; e < experts; e += blockDim.x) {
            float v = row[e];
            if (v > max_logit) {
                max_logit = v;
            }
        }
        max_logit = blockReduceMax(max_logit);
        __syncthreads();

        float sum_exp = 0.f;
        for (int e = threadIdx.x; e < experts; e += blockDim.x) {
            float s   = expf(row[e] - max_logit);
            scores[e] = s;
            sum_exp += s;
        }
        sum_exp = blockReduceSum(sum_exp);
        __syncthreads();

        for (int e = threadIdx.x; e < experts; e += blockDim.x) {
            float s              = scores[e] / (sum_exp + 1e-20f);
            scores[e]            = s;
            scores_for_choice[e] = s + (bias ? bias[e] : 0.f);
        }
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        // Top-k on scores_for_choice: one linear scan per round. A selected expert is removed from
        // later rounds by overwriting its scores_for_choice entry with -INFINITY. Past the barrier
        // above only this thread touches shared memory, so the in-place update is race free.
        int   topk_idx[32];
        float topk_val[32];
        for (int k = 0; k < top_k; k++) {
            int   best_e = -1;
            float best_v = -INFINITY;
            for (int e = 0; e < experts; e++) {
                float v = scores_for_choice[e];
                if (!isfinite(v)) {
                    v = -INFINITY;  // NaN / inf candidates never win
                }
                if (v > best_v) {
                    best_v = v;
                    best_e = e;
                }
            }
            if (best_e >= 0) {
                topk_val[k]               = scores[best_e];
                scores_for_choice[best_e] = -INFINITY;
            }
            else {
                // No finite candidate left: take the lowest-index expert not selected yet, with zero
                // weight. Selecting an expert twice would overwrite its mask slot while still
                // counting it twice in `accum`, desynchronizing the dispatch tables.
                for (best_e = 0; best_e < experts; ++best_e) {
                    bool taken = false;
                    for (int j = 0; j < k; ++j) {
                        if (topk_idx[j] == best_e) {
                            taken = true;
                            break;
                        }
                    }
                    if (!taken) {
                        break;
                    }
                }
                if (best_e >= experts) {
                    best_e = 0;  // unreachable while top_k <= experts
                }
                topk_val[k] = 0.f;
            }
            topk_idx[k] = best_e;
        }

        float wsum = 0.f;
        for (int k = 0; k < top_k; k++) {
            wsum += topk_val[k];
        }
        if (norm_topk && wsum > 1e-20f) {
            for (int k = 0; k < top_k; k++) {
                topk_val[k] /= wsum;
            }
        }
        for (int k = 0; k < top_k; k++) {
            scales[k * tokens + ti] = topk_val[k] * routed_scale;
        }

        for (int k = 0; k < top_k; k++) {
            masks[topk_idx[k] * tokens_padded + ti] = (int8_t)k;
        }

        const int tile_id = ti >> log_tile;
        for (int k = 0; k < top_k; k++) {
            const int e = topk_idx[k];
            atomicAdd(&accum[e * tiles + tile_id], 1);
        }
    }
}

// Gating for "noaux_tc" routers: sigmoid or softmax scores, an optional correction bias used
// only for expert selection, optional top-k renormalization and `routed_scale`. It is followed
// by the same MoeScanKernel_v2 pass as invokeMoeGate_V2.
void invokeMoeGate_NoAuxTC(int*         f2n,
                           int*         f2E,
                           int*         en2f,
                           int*         offsets,
                           float*       scales,
                           void*        masks,
                           int*         accum,
                           const float* logits,
                           const float* correction_bias,
                           int          tokens,
                           int          tokens_padded,
                           int          experts,
                           int          exp_per_tok,
                           bool         norm_topk_prob,
                           float        routed_scale,
                           bool         use_sigmoid,
                           cudaStream_t st)
{
    TM_CHECK(exp_per_tok > 0);
    // The kernel keeps the per-token selection in fixed 32-entry arrays.
    TM_CHECK_LE(exp_per_tok, 32);
    TM_CHECK_LE(exp_per_tok, experts);

    constexpr int base_log_tile = 9;
    int           log_tile      = base_log_tile;
    while (((tokens_padded + (1 << log_tile) - 1) >> log_tile) > kMoeGateMaxTiles) {
        ++log_tile;
    }
    const int tiles = ceil_div(tokens_padded, 1 << log_tile);

    // The gate kernel only adds to `accum` and only writes the mask slots of selected experts.
    cudaMemsetAsync(accum, 0, sizeof(int) * experts * kMoeGateMaxTiles, st);
    cudaMemsetAsync(masks, -1, sizeof(int8_t) * experts * tokens_padded, st);

    // One token per block: threads cooperate over expert dimension. The block spans at least one
    // full warp (power of two, capped at 256) because the softmax path calls blockReduceMax /
    // blockReduceSum, and warp-shuffle reductions are not well defined on a partially populated
    // warp. NOTE(review): assumes those helpers (defined elsewhere) shuffle with a full-warp mask.
    // Surplus threads skip the expert loops and contribute neutral values (-1e30 / 0).
    int block_dim = WARP_SIZE;
    while (block_dim < experts && block_dim < 256) {
        block_dim *= 2;  // next power of 2
    }
    const int    blocks = tokens;
    const size_t smem   = sizeof(float) * experts * 2;  // scores + scores_for_choice

    MoeGateNoAuxTCKernel<<>>(scales,
                                                          (int8_t*)masks,
                                                          accum,
                                                          logits,
                                                          correction_bias,
                                                          tokens,
                                                          tokens_padded,
                                                          experts,
                                                          exp_per_tok,
                                                          norm_topk_prob,
                                                          routed_scale,
                                                          log_tile,
                                                          tiles,
                                                          use_sigmoid);

    // Surface launch errors of the gate kernel before the scan, as invokeMoeGate_V2 does.
    sync_check_cuda_error();

    constexpr int scan_threads = (1 << base_log_tile) / kMoeGateVecSize;
    const dim3    scan_blocks(tiles, experts + 1);
    MoeScanKernel_v2<<>>(
        f2n, f2E, en2f, offsets, (int8_t*)masks, accum, log_tile, tiles, tokens, tokens_padded, experts);
}

// Expands token rows into the flat (expert-major) buffer: row `bi` of `dst` is a copy of row
// `f2n[bi]` of `src`. One block per flat row; `dims` is the row length in vectors of
// `vec_size` elements, copied with 16-byte-style vector loads/stores.
template
__global__ void MoeGatherKernel(T*         dst,  // [e*n, d]
                                const T*   src,  // [  n, d]
                                const int* f2n,  // [e*n] :: e*n -> n
                                int        dims)
{
    using Vec        = Array;
    const int64_t bi = blockIdx.x;
    // Promote the source row to 64 bits like `bi`; `dims * f2n[bi]` in `int` could overflow.
    const int64_t ti = f2n[bi];

    auto src_ptr = (const Vec*)src + dims * ti;
    auto dst_ptr = (/* */ Vec*)dst + dims * bi;
    for (int i = threadIdx.x; i < dims; i += block_dim) {
        Vec v;
        Ldg(v, src_ptr[i].data());
        Store(dst_ptr[i].data(), v);
    }
}

// Gathers the rows of `src` ([num, dim]) into `out` following `f2n` (flat index -> token).
// Supports 1- and 2-byte element types; rows are moved as 16-byte vectors.
void invokeMoeDispatch(Ref out_, const Tensor& src, const int* f2n, int expert_per_token, cudaStream_t st)
{
    auto& out    = out_.get();
    auto  invoke = [&](auto t) {
        using T                = decltype(t);
        auto [num, dim]        = src.shapes(0, 1);
        constexpr int threads  = 256;
        constexpr int vec_size = 16 / sizeof(T);
        // The kernel addresses rows in whole vectors: a row length that is not a multiple of
        // `vec_size` would mis-stride every row and drop the tail elements.
        TM_CHECK(dim % vec_size == 0) << "row size " << dim << " is not a multiple of " << vec_size;
        // std::cout << num * expert_per_token << " " << dim << "\n";
        MoeGatherKernel<<>>(  //
            (T*)out.raw_data(),
            (const T*)src.raw_data(),
            f2n,
            dim / vec_size);
    };
    TM_CHECK_EQ(src.dtype(), out.dtype());
    const auto elem_size = byte_size(src.dtype());
    if (elem_size == sizeof(uint16_t)) {
        return invoke(uint16_t{});
    }
    else if (elem_size == sizeof(uint8_t)) {
        return invoke(uint8_t{});
    }
    TM_CHECK(0) << "unsupported data type: " << src.dtype();
}

// For flat index `bi` (one block each), copies column `f2n[bi]` of `src` into column `ti` of
// `dst` for all `dim` rows (both indexed with row stride `stride`). `ti` is `bi` re-based onto
// an `alignment`-rounded start for the expert group `g` with offsets[g] <= bi < offsets[g+1].
//
// NOTE(review): within this chunk only MoeDispatchScalesNonaligned is launched (by
// invokeMoeDispatchScales); this kernel looks unused, and `dst_offsets` is never read.
// NOTE(review): the group search advances with `++g` rather than `g += block_dim`, so
// threads rescan overlapping ranges (they presumably all store the same g).
// NOTE(review): for g == 0 the base computation reads `offsets[-1]`, and `shared_g` stays
// unset if no group contains `bi`. Verify the intended base formula before reusing this.
template
__global__ void MoeDispatchScales(
    T* dst, int* dst_offsets, const T* src, const int* f2n, const int* offsets, int dim, int expert_num, int stride)
{
    int bi = blockIdx.x;

    __shared__ int shared_g;

    // Find the expert group that contains flat index `bi`.
    for (int g = threadIdx.x; g < expert_num; ++g) {
        if (offsets[g] <= bi && bi < offsets[g + 1]) {
            shared_g = g;
        }
    }

    __syncthreads();

    int g = shared_g;

    const int base = (offsets[g - 1] + alignment * (g - 1)) / alignment * alignment;
    const int ti   = base + bi - offsets[g];

    // From here on `bi` is the source token index.
    bi = f2n[bi];

    // ! strided access
    for (int di = threadIdx.x; di < dim; di += block_dim) {
        dst[di * stride + ti] = src[di * stride + bi];
    }
}

// Column gather, one block per destination column: column `blockIdx.x` of `dst` receives
// column `f2n[blockIdx.x]` of `src`. Thread `r` copies row `r`; threads at or beyond `dim` idle.
template
__global__ void
MoeDispatchScalesNonaligned(T* dst, const T* src, int dst_stride, int src_stride, const int* f2n, int dim)
{
    const auto row = threadIdx.x;  // kept unsigned, matching the original index arithmetic
    if (row >= dim) {
        return;
    }

    const int dst_col = blockIdx.x;
    const int src_col = f2n[dst_col];

    dst[row * dst_stride + dst_col] = src[row * src_stride + src_col];
}

// Gathers the per-token scale columns of `src` ([dim, num], element (r, t) at
// r * src.stride(0) + t) into `out` ([dim, num * expert_per_token]) so that column `i` of
// `out` is column `f2n[i]` of `src`. When `out` is empty it is allocated with a row stride
// rounded up to 16 bytes; otherwise its shape and layout are validated.
void invokeMoeDispatchScales(Ref out_, const Tensor& src, const int* f2n, int expert_per_token, cudaStream_t st)
{
    using T                 = float;
    constexpr int alignment = 16 / sizeof(T);

    auto [dim, num] = src.shapes(0, 1);

    const int size         = num * expert_per_token;
    const int aligned_size = round_up(size, alignment);

    auto& out = out_.get();

    if (!out) {
        out = Tensor_{{{dim, size}, {aligned_size, 1}}, kDEVICE};
    }
    else {
        TM_CHECK(std::make_tuple(dim, size) == out.shapes(0, 1));
        TM_CHECK(out.stride(1) == 1);
        TM_CHECK(out.stride(0) % alignment == 0);
    }

    // One thread per row, so all rows must fit into a single block.
    TM_CHECK_LE(dim, 1024);

    // Nothing to gather. A launch with zero blocks or zero threads is an invalid configuration
    // and would surface as a spurious error at the next CUDA error check.
    if (size == 0 || dim == 0) {
        return;
    }

    const int threads = round_up(dim, WARP_SIZE);
    const int blocks  = size;

    // std::cout << src << " " << out << "\n";

    MoeDispatchScalesNonaligned<<>>((T*)out.raw_data(),  //
                                                            (const T*)src.raw_data(),
                                                            out.stride(0),
                                                            src.stride(0),
                                                            f2n,
                                                            dim);
}

// Combines the expert outputs of one token per block:
//   dst[ti] = dst_scale * dst[ti] + sum_e scale[e] * (src[en2f[e, ti]] + bscale * bias[f2E[...]])
// The bias term exists only when `has_bias`, and `scale[e]` is 1 when `scales` is null. When
// `dst_scales` is given, `dst_scale` is replaced by sigmoid(dst_scales[ti]). The previous
// contents of `dst` are read only when `dst_scale` is nonzero.
// Accumulation is in float over `vec_size`-wide vectors. NOTE(review): `dim` is assumed to be a
// multiple of `vec_size` (the loop loads whole vectors) — confirm at callers.
// The whole body is compiled only when TURBOMIND_ARCH_DTYPE_GUARD accepts T.
template
__global__ void MoeReduceKernel(T*           dst,         // [  n, d]
                                const T*     src,         // [e*n, d]
                                const T*     bias,        // [  E, d]
                                const float* scales,      // [  e, n]
                                const int*   en2f,        // [  e, n] :: (e,n) -> e*n
                                const int*   f2E,         // [  e* n]
                                const float* dst_scales,  // [n]
                                int          dim,
                                int          tokens,
                                T            bscale,
                                float        dst_scale)
{
    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v)) {
        const int64_t ti = blockIdx.x;

        dst += dim * ti;

        if (dst_scales) {
            dst_scale = dst_scales[ti];
            dst_scale = fdividef(1.f, 1.f + expf(-dst_scale));
        }

        // Should be warp uniforms
        const T* src_[exp_k];
        const T* bias_[exp_k];

        float scale[exp_k];

        PRAGMA_UNROLL
        for (int e = 0; e < exp_k; ++e) {
            int fid = __ldg(&en2f[e * tokens + ti]);
            src_[e] = src + dim * fid;
            if constexpr (has_bias) {
                bias_[e] = bias + __ldg(&f2E[fid]) * dim;
            }
            scale[e] = scales ? __ldg(&scales[e * tokens + ti]) : 1.f;
        }

        using Vec = Array;

        for (int i = threadIdx.x * vec_size; i < dim; i += block_dim * vec_size) {
            Array accum{};
            if (dst_scale) {
                Vec v;
                Load(v, &dst[i]);
                using namespace ops;
                accum = cast(v) * dst_scale;
            }
            PRAGMA_UNROLL
            for (int e = 0; e < exp_k; ++e) {
                Vec v;
                Load(v, src_[e] + i);
                using namespace ops;
                if constexpr (has_bias) {
                    Vec b;
                    Load(b, bias_[e] + i);
                    // `j` (not `i`): the outer loop's element offset `i` is still in use.
                    PRAGMA_UNROLL
                    for (int j = 0; j < vec_size; ++j) {
                        v[j] = __hfma(b[j], bscale, v[j]);
                    }
                }
                const auto x = cast(v) * scale[e];
                accum        = accum + x;
            }
            Store(&dst[i], cast(accum));
        }
    }
}

// Host launcher for MoeReduceKernel. It turns the runtime `experts_per_token` into a
// compile-time constant and uses 256-thread blocks with vectors of 16 bytes
// (`vec_size = 16 / sizeof(T)`).
// Supported values are 1, 2, 4, 6 and 8; anything else prints an error and aborts.
// NOTE(review): the gate launchers in this file accept other top-k values
// (invokeMoeGate_V2 up to 8, invokeMoeGate_NoAuxTC up to 32), which would abort here.
template
void invokeMoeReduce(T*           dst,
                     const T*     src,
                     const T*     bias,
                     const float* scales,
                     const int*   en2f,
                     const int*   f2E,
                     const float* dst_scales,
                     int          tokens,
                     int          experts_per_token,
                     int          dim,
                     T            bscale,
                     float        dst_scale,
                     cudaStream_t st)
{
    // std::cout << __PRETTY_FUNCTION__ << std::endl;

    const auto invoke = [&](auto e) {
        constexpr int threads     = 256;
        constexpr int vec_size    = 16 / sizeof(T);
        constexpr int exp_per_tok = decltype(e)::value;
        MoeReduceKernel<<>>(  //
            dst,
            src,
            bias,
            scales,
            en2f,
            f2E,
            dst_scales,
            dim,
            tokens,
            bscale,
            dst_scale);
    };

    switch (experts_per_token) {
        case 1:
            return invoke(std::integral_constant{});
        case 2:
            return invoke(std::integral_constant{});
        case 4:
            return invoke(std::integral_constant{});
        case 6:
            return invoke(std::integral_constant{});
        case 8:
            return invoke(std::integral_constant{});
        default:
            fprintf(stderr, "Unsupported experts_per_token %d\n", experts_per_token);
            std::abort();
    }
}

// Weighted combine of the expert outputs in `src` back into `out` ([tokens, dim]).
// `src` must hold tokens * experts_per_token rows (checked). When `bias` is non-empty, the
// per-expert bias rows are scaled by `bscale` and added before weighting; they are looked up
// through `f2E`, which must then be non-null. The kernel is dispatched on the dtype of `src`.
void invokeMoeCombine(Ref   out_,
                      const Tensor& src,
                      const Tensor& bias,
                      const float*  scales,
                      const int*    en2f,
                      const int*    f2E,
                      const float*  dst_scales,
                      int           experts_per_token,
                      float         bscale,
                      float         dst_scale,
                      cudaStream_t  st)
{
    auto& out = out_.get();

    const int tokens = out.shape(0);
    TM_CHECK_EQ(src.shape(0), tokens * experts_per_token);

    auto invoke = [&](auto has_bias, auto t) {
        using T = decltype(t);
        return invokeMoeReduce(out.data(),
                                               src.data(),
                                               bias.data_or((T*)nullptr),
                                               scales,
                                               en2f,
                                               f2E,
                                               dst_scales,
                                               tokens,
                                               experts_per_token,
                                               src.shape(1),
                                               (T)bscale,
                                               dst_scale,
                                               st);
    };

    // Select the bias/no-bias instantiation at compile time.
    auto dispatch_dtype = [&](auto t) {
        if (bias) {
            TM_CHECK_NOTNULL(f2E);
            return invoke(std::true_type{}, t);
        }
        else {
            return invoke(std::false_type{}, t);
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(src.dtype(), dispatch_dtype);
}

// Draws `exp_per_tok` distinct experts from [0, expert_num) independently for every token,
// using `g`. The result is token-major: the experts of token i occupy
// [i * exp_per_tok, (i + 1) * exp_per_tok). std::sample keeps them in ascending order.
std::vector SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g)
{
    // std::sample emits only min(exp_per_tok, expert_num) items; with too few experts every
    // later token would be shifted and the tail left zero-filled.
    assert(exp_per_tok <= expert_num);
    std::vector idxs((size_t)token_num * exp_per_tok);
    std::vector r(expert_num);
    std::iota(r.begin(), r.end(), 0);
    auto it = idxs.begin();
    for (int i = 0; i < token_num; ++i) {
        it = std::sample(r.cbegin(), r.cend(), it, exp_per_tok, g);
    }
    assert(it == idxs.end());
    return idxs;
}

// Draws `exp_per_tok` distinct experts per token while keeping expert usage balanced. Experts
// are handed out from successive random permutations of [0, expert_num); experts left over from
// one permutation are kept out of the next token's fresh draw so a token never repeats an
// expert. The tokens are finally shuffled to decorrelate neighbours. The result layout is
// token-major, like SampleUniform.
std::vector SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g)
{
    assert(exp_per_tok <= expert_num);
    std::vector idxs((size_t)token_num * exp_per_tok);
    std::vector q;

    std::vector r(expert_num);
    std::iota(r.begin(), r.end(), 0);

    auto it = idxs.begin();
    for (int i = 0; i < token_num; ++i) {
        if ((int)q.size() < exp_per_tok) {
            const int k = q.size();
            // prepend the experts: [xxx] -> [yyy | xxx]
            q.insert(q.begin(), r.cbegin(), r.cend());
            // move duplicated experts to the front: [yyy | xxx] -> [xxx' | yyy' | xxx]
            // A partition is used because swapping q[p] with q[x] breaks once an earlier swap has
            // moved expert x away from index x: a leftover expert could then stay in yyy' and be
            // drawn twice by the same token.
            const auto tail = q.end() - k;
            const auto mid  = std::partition(
                q.begin(), tail, [&](const auto& x) { return std::find(tail, q.end(), x) != q.end(); });
            // shuffle unique experts yyy'
            std::shuffle(mid, tail, g);
        }
        it = std::copy(q.end() - exp_per_tok, q.end(), it);
        // remove used experts [xxx' | yyy' | xxx ] -> [xxx' | zzz]
        q.resize(q.size() - exp_per_tok);
        // alias [xxx] <- [xxx' | zzz]
    }
    assert(it == idxs.end());

    // shuffle to decorrelate adjacent tokens
    r.resize(token_num);
    std::iota(r.begin(), r.end(), 0);
    std::shuffle(r.begin(), r.end(), g);
    std::vector ret(idxs.size());
    it = ret.begin();
    for (const auto& i : r) {
        it = std::copy_n(idxs.begin() + i * exp_per_tok, exp_per_tok, it);
    }
    assert(it == ret.end());
    return ret;
}

// Group-limited softmax over MoE router logits, computed in place.
//
// `logits` is indexed as logits[ti * expert_num + e], i.e. one row of `expert_num` logits per token.
// The row is overwritten with softmax(row) in which every group that is not among the `top_k` groups
// with the largest maximum logit is zeroed. The surviving entries are NOT renormalized: the
// denominator is the sum over all experts of the token.
//
// Work split: `threads_per_token` consecutive threads cooperate on one token. Thread `ei` owns the
// `items_per_thread` consecutive logits starting at ei * items_per_thread (one group per thread).
// These are loaded/stored with vector width `access_size`. Threads whose token index is out of range
// still run all shuffles and ballots, because those use the full-warp mask.
// Assumes the lanes of one token lie in one warp (threads_per_token a power of two <= WARP_SIZE);
// only the power-of-two part is static_asserted here.
template
__global__ void MoeSoftmaxMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)
{
    constexpr int threads_per_token = max_expert_num / items_per_thread;

    static_assert((threads_per_token & (threads_per_token - 1)) == 0);
    static_assert(items_per_thread % access_size == 0);

    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;

    const int ti = thread_idx / threads_per_token;  // token index
    const int ei = thread_idx % threads_per_token;  // lane within the token == group index

    float data[items_per_thread];
    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        data[i] = -std::numeric_limits::infinity();
    }
    // max logit in the group
    float max_val = -std::numeric_limits::infinity();
    if (ti < token_num) {
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; i += access_size) {
            const int e = ei * items_per_thread + i;  // blocked partition
            if (e < expert_num) {
                Ldg((Array&)data[i], &logits[ti * expert_num + e]);
                PRAGMA_UNROLL
                for (int c = 0; c < access_size; ++c) {
                    max_val = fmaxf(max_val, data[i + c]);
                }
            }
        }
    }

    // Lane offset of this token's first thread inside the warp (used to extract its ballot bits).
    const int warp_ti        = threadIdx.x % WARP_SIZE / threads_per_token;
    const int warp_ti_offset = warp_ti * threads_per_token;

    bool  alive     = false;  // whether this thread's group is among the selected top_k groups
    float max_logit = 0;      // max over the whole token row (set in the first round)

    // Select top_k groups: each round finds the group with the largest remaining max, marks it alive
    // and removes it from the next round by resetting its max to -inf.
    for (int k = 0; k < top_k; ++k) {
        int   g_max_ei  = ei;
        float g_max_val = max_val;
        // xor-butterfly max-reduction across the token's lanes
        PRAGMA_UNROLL
        for (int m = threads_per_token / 2; m >= 1; m /= 2) {
            g_max_val = fmaxf(g_max_val, __shfl_xor_sync((uint32_t)-1, g_max_val, m));
        }
        // tie breaking: the lowest lane of this token holding the max wins
        const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);
        g_max_ei          = __ffs(active >> (unsigned)warp_ti_offset) - 1;
        if (k == 0) {
            max_logit = g_max_val;
        }
        if (ei == g_max_ei) {
            alive   = true;
            max_val = -std::numeric_limits::infinity();
        }
    }

    float sum_prob{};

    // NOTE(review): with PRAGMA_NO_UNROLL, `data[i]` is indexed by a runtime value, which typically
    // forces `data` out of registers into local memory; confirm this is intended (vs PRAGMA_UNROLL).
    PRAGMA_NO_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        data[i] = expf(data[i] - max_logit);  // -inf padding maps to 0
        sum_prob += data[i];
    }

    // sum of exp over the whole token row (all groups, selected or not)
    PRAGMA_UNROLL
    for (int m = threads_per_token / 2; m >= 1; m /= 2) {
        sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
    }

    // mask dead logits
    sum_prob = alive ? fdividef(1.f, sum_prob) : 0;

    PRAGMA_UNROLL
    for (int i = 0; i < items_per_thread; ++i) {
        data[i] *= sum_prob;
    }

    if (ti < token_num) {
        PRAGMA_UNROLL
        for (int i = 0; i < items_per_thread; i += access_size) {
            const int e = ei * items_per_thread + i;
            if (e < expert_num) {
                Store(&logits[ti * expert_num + e], (Array&)data[i]);
            }
        }
    }
}

// Host launcher for MoeSoftmaxMaskTopKGroups on stream `st`.
//
// Only one configuration is instantiated: expert_num == 160 with group_size == 20. That gives
// 160 / 20 = 8 threads per token, 20 logits per thread and 4-wide vector access. `top_k` is the
// number of groups kept per token. Any other (expert_num, group_size) pair prints a diagnostic to
// std::cerr and calls std::abort().
void invokeMoeSoftmaxMaskTopKGroups(
    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)
{
    auto invoke = [&](auto max_expert_num, auto items_per_thread, auto vec_size) {
        constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;
        constexpr int threads      = 256;
        // each block of 256 threads covers threads / thrs_per_tok tokens
        const int     blocks       = ceil_div(token_num, threads / thrs_per_tok);
        MoeSoftmaxMaskTopKGroups
            <<>>(logits, token_num, expert_num, top_k);
    };

    if (expert_num == 160 && group_size == 20) {
        return invoke(_Int<160>, _Int<20>, _Int<4>);
    }

    std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << expert_num
              << ", group_size=" << group_size << "\n";
    std::abort();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/moe_utils_v2.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// NOTE(review): tile/vector-width constants for the MoE gate kernels; their exact use is defined in
// the implementation file, not visible here.
constexpr int kMoeGateMaxTiles = 16;
constexpr int kMoeGateVecSize  = 4;

void invokeMoeGate_V2(int*         f2n,
                      int*         f2E,
                      int*         en2f,
                      int*         offsets,
                      float*       scales,
                      void*        masks,
                      int*         accum,
                      const float* logits,
                      int          tokens,
                      int          tokens_padded,
                      int          experts,
                      int          exp_per_tok,
                      bool         softmax,
                      bool         norm_topk,
                      float        routed_scale,
                      cudaStream_t st);

void invokeMoeDispatch(Ref   out_,  //
                       const Tensor& src,
                       const int*    f2n,
                       int           expert_per_token,
                       cudaStream_t  st);

void invokeMoeDispatchScales(Ref   out_,  //
                             const Tensor& src,
                             const int*    f2n,
                             int           expert_per_token,
                             cudaStream_t  st);

void invokeMoeCombine(Ref   out_,
                      const Tensor& src,
                      const Tensor& bias,
                      const float*  scales,
                      const int*    en2f,
                      const int*    f2E,
                      const float*  dst_scales,
                      int           experts_per_token,
                      float         bscale,
                      float         dst_scale,
                      cudaStream_t  st);

/// In-place group-limited softmax over `logits` (token_num rows of expert_num values): softmax over
/// the full row, then every expert outside the `top_k` groups (groups of `group_size` experts, ranked
/// by their max logit) is zeroed. Only expert_num == 160, group_size == 20 is supported; other
/// configurations abort.
void invokeMoeSoftmaxMaskTopKGroups(
    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);

/// noaux_tc routing: scores = scoring_func(logits), scores_for_choice = scores + correction_bias,
/// top-k on scores_for_choice, weights from scores; renormalize if norm_topk_prob; always apply routed_scale.
/// correction_bias may be nullptr (then treated as 0).
/// use_sigmoid: if true, scores = sigmoid(logits); if false, scores = softmax(logits).
void invokeMoeGate_NoAuxTC(int*         f2n,
                           int*         f2E,
                           int*         en2f,
                           int*         offsets,
                           float*       scales,
                           void*        masks,
                           int*         accum,
                           const float* logits,
                           const float* correction_bias,
                           int          tokens,
                           int          tokens_padded,
                           int          experts,
                           int          exp_per_tok,
                           bool         norm_topk_prob,
                           float        routed_scale,
                           bool         use_sigmoid,
                           cudaStream_t st);

// Sample `e` from `E` experts uniformly for every token
// (host-side helper; returns token_num * exp_per_tok expert ids, token-major)
std::vector SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);

// Sample `e` distinct experts per token such that every expert is used (almost) equally often across
// all tokens; token order is shuffled afterwards. Requires exp_per_tok <= expert_num.
std::vector SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/operand.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/iterator.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Placeholder operand: int dtype, no packing, column-major, a 1x1 shared-memory layout and a
// VoidSmemCopyAtom. Its GetGmemIter returns a fixed type_c regardless of arguments.
// Presumably used where a GEMM has no such operand (e.g. no scales) — TODO confirm at call sites.
struct VoidOperand {
    using Dtype = int;

    static constexpr Pack  kPack  = 0;
    static constexpr Order kOrder = Order::kColMajor;

    struct GetSmemLayout {
        static constexpr SmemLayoutV2<1, 1> apply(...)
        {
            return {};
        }
    };

    using SmemCopyAtom = VoidSmemCopyAtom;

    struct GetGmemIter {
        static constexpr auto apply(...)
        {
            return type_c;
        }
    };
};

/// TODO: fix AlignC, AlignS
/// TODO: fix GroupSize
// Derives the concrete types of one GEMM operand from the `Operand` traits: dtype, packing, order,
// shared-memory layout/accessor, global-memory iterator and smem copy atom. The K extent is divided
// by `GroupSize` (kPackMK uses ceil_div(K, kGroupSize)).
template
struct MakeOperand {

    using Dtype = typename Operand::Dtype;

    static constexpr Pack  kPack      = Operand::kPack;
    static constexpr Order kOrder     = Operand::kOrder;
    static constexpr int   kGroupSize = GroupSize;

    static constexpr int2 kPackMK = Packing_v2::apply({M, ceil_div(K, kGroupSize)});

    static constexpr pair kShapeMK{};

    using SmemLayout   = decltype(Operand::GetSmemLayout::apply(kShapeMK));
    using SmemAccessor = SmemAccessorV2;

    using GmemIter = typename decltype(Operand::GetGmemIter::apply(
        type_c, type_c, type_c, kShapeMK, constant{}))::type;

    using SmemCopyAtom = typename Operand::SmemCopyAtom;
};

// CPO for getting specific operand templates
// (primary template is std::false_type; presumably specialized elsewhere for supported operands)
template
struct GetOperand: std::false_type {
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/predicate.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

namespace turbomind::gemm {

// Per-thread copy predicate stored as a bitmask: bit (s * kSizeC + c) tells whether access (s, c)
// is in bounds. When AlignedC is set, only one bit per s is kept (kSizeC == 1), presumably because
// all c of a row share the same predicate — TODO confirm at call sites.
template
struct Predicate {

    static constexpr int kSizeC = AlignedC ? 1 : C;

    static_assert(S * kSizeC <= 32);  // must fit in the 32-bit mask

    static constexpr bool is_active = true;

    uint32_t pred_{};

    __device__ int operator()(int s, int c) const
    {
        return (pred_ & (1 << (s * kSizeC + c))) != 0;
    }

    __device__ void set(int s, int c)
    {
        pred_ |= (1 << (s * kSizeC + c));
    }

    __device__ void clear()
    {
        pred_ = 0;
    }
};

// Inactive specialization: holds no state and returns a compile-time constant from operator();
// presumably std::integral_constant<bool, true> so every access is treated as valid — confirm.
// set()/clear() are no-ops.
template
struct Predicate {

    static constexpr bool is_active = false;

    __device__ constexpr std::integral_constant operator()(int, int) const
    {
        return {};
    }

    __device__ void set(int, int) {}

    __device__ void clear()
    {
        // pred_ = 0;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/registry.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/arch.h"
#include "src/turbomind/kernels/gemm/registry.h"

namespace turbomind::gemm {

// Builds the kernel registry for one device.
// arch_ encodes the compute capability as major * 100 + minor * 10 (e.g. 8.6 -> 860). The calls below
// register every kernel family; each goes through Add(), which drops kernels the device cannot run.
Registry::Registry(std::shared_ptr device_prop):
    device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10}
{
    sm90_16816_4();
    sm90_16816_8();
    sm90_16816_16();

    sm80_16816_4();
    sm80_16816_8();
    sm80_16816_16();

    sm75_16816_4();
    sm75_16816_8();
    sm75_16816_16();

    sm70_884_4();
    sm70_884_8();
    sm70_884_16();

    sm90_64n32_8();

    cublas_float();
}

// Registers `kernel` if the current device can run it.
// A kernel is rejected when its arch is not compatible with arch_, or when it needs more shared
// memory than sharedMemPerBlockOptin. An accepted kernel is stored twice: first as a transposed copy,
// then as the original. Both are owned by kernels_ and exposed through ptrs_.
// NOTE(review): always returns true, even when the kernel was rejected.
bool Registry::Add(std::unique_ptr kernel)
{
    bool is_valid = true;

    if (!is_arch_compatible(kernel->arch(), arch_)) {
        is_valid = false;
    }

    // if (is_valid) {
    //     std::cout << "register: " << kernel->name()                                        //
    //               << ", shared: " << (kernel->smem_size() >> 10) << " KB"                  //
    //               << ", regs: " << kernel->info().attr.numRegs                             //
    //               << ", local: " << (float)kernel->info().attr.localSizeBytes << " bytes"  //
    //               << ", max_active_ctas: " << kernel->info().max_active_ctas << " \n";
    // }

    if ((int)device_prop_->sharedMemPerBlockOptin < kernel->smem_size()) {
        is_valid = false;
    }

    if (is_valid) {
        // transpose(*kernel) must run before `kernel` is moved from on the next line
        ptrs_.push_back(kernels_.emplace_back(transpose(*kernel)).get());
        ptrs_.push_back(kernels_.emplace_back(std::move(kernel)).get());
    }

    return true;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/registry.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/kernel_impl.h"
#include 

namespace turbomind::gemm {

// Holds the GEMM kernels usable on one device.
// The constructor populates the registry via the per-architecture sm*_*() functions, which call Add();
// kernels() returns non-owning pointers to the accepted kernels.
class Registry {
public:
    explicit Registry(std::shared_ptr device_prop);

    /// TODO: remove this
    template
    [[maybe_unused]] bool Add()
    {
        return Add(std::make_unique>());
    }

    // Non-owning view of the registered kernels; ownership stays with kernels_.
    [[nodiscard]] const std::vector& kernels() const
    {
        return ptrs_;
    }

private:
    bool Add(std::unique_ptr kernel);

    void sm90_16816_4();
    void sm90_16816_8();
    void sm90_16816_16();

    void sm80_16816_4();
    void sm80_16816_8();
    void sm80_16816_16();

    void sm75_16816_4();
    void sm75_16816_8();
    void sm75_16816_16();

    void sm70_884_4();
    void sm70_884_8();
    void sm70_884_16();

    void sm90_64n32_8();

    void cublas_float();

private:
    std::shared_ptr      device_prop_;
    int                                  arch_;     // major * 100 + minor * 10
    std::vector> kernels_;  // owning storage
    std::vector                 ptrs_;     // raw pointers into kernels_
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h
================================================
#pragma once

#include 

#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/sm90_utils.h"

namespace turbomind::gemm {

// SM90 FP8 GMMA (wgmma) helper for TN layout with per-row (U) and per-column (V) scale factors.
// Raw GMMA results for each batch are accumulated into a separate accumulator as
// accum += (U * V) * frag_C (see scale_batch_to_accum), so each batch is scaled before accumulation.
// The tile is walked as ITER x PIPE x BATCH in both M and N.
template
struct ScaledGmmaFP8_TN {

    // Pick the widest GMMA N-extent out of {256, 224, 192, 160, 128, 96, 64} that divides
    // N = TILE_N / (BATCH_N * PIPE_N); M = TILE_M / (BATCH_M * PIPE_M) must be a multiple of 64.
    static constexpr auto select_gmma_operation()
    {
        static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);
        static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);

        constexpr int M = TILE_M / (BATCH_M * PIPE_M);
        constexpr int N = TILE_N / (BATCH_N * PIPE_N);

        static_assert(M % 64 == 0);

        using namespace cute::SM90::GMMA;

        if constexpr (N % 256 == 0) {
            return type_c>;
        }
        else if constexpr (N % 224 == 0) {
            return type_c>;
        }
        else if constexpr (N % 192 == 0) {
            return type_c>;
        }
        else if constexpr (N % 160 == 0) {
            return type_c>;
        }
        else if constexpr (N % 128 == 0) {
            return type_c>;
        }
        else if constexpr (N % 96 == 0) {
            return type_c>;
        }
        else if constexpr (N % 64 == 0) {
            return type_c>;
        }
        else {
            static_assert(N == 0, "unsupported configuration");
        }
    }

    using Operation = typename decltype(select_gmma_operation())::type;

    static constexpr typename cute::MMA_Traits::Shape_MNK OP_Shape{};

    static constexpr int OP_M = cute::get<0>(OP_Shape);
    static constexpr int OP_N = cute::get<1>(OP_Shape);
    static constexpr int OP_K = cute::get<2>(OP_Shape);

    // number of pipe-sized steps needed to cover the tile
    static constexpr int ITER_M = TILE_M / OP_M / BATCH_M / PIPE_M;
    static constexpr int ITER_N = TILE_N / OP_N / BATCH_N / PIPE_N;

    // two row scales per m (selected by accumulator row half), two candidate column scales
    using FragU = float[ITER_M][PIPE_M][BATCH_M][2];
    using FragV = float[2];

    using FragC = typename Operation::CRegisters[PIPE_M][PIPE_N][BATCH_M][BATCH_N];

    using AccumC = FragC[ITER_M][ITER_N];

    // Smem iterator strides; the `>> 4` presumably converts bytes to 16-byte descriptor units — confirm.
    static constexpr int kStepMA = (OP_M * TILE_K) >> 4;
    static constexpr int kStepNB = (OP_N * TILE_K) >> 4;
    static constexpr int kStepKA = (OP_K) >> 4;
    static constexpr int kStepKB = (OP_K) >> 4;

    // Column chunk sharing one column-scale choice; pred_V[i] selects frag_V[1] over frag_V[0].
    static constexpr int OUTER_N = std::gcd(TILE_N, 128);

    // accum_C += scale * frag_C for one batch, where scale = U[row half] * V[p] and p = pred_V of the
    // OUTER_N-wide column chunk at (offset_V + c0). Each 8 columns map to 4 accumulator registers:
    // the first two use U[..][0], the last two use U[..][1].
    template
    __device__ static void scale_batch_to_accum(AccumC&      accum_C,
                                                const FragC& frag_C,
                                                const FragU& frag_U,
                                                const FragV& frag_V,
                                                const PredV& pred_V,
                                                int          offset_V)
    {
        PRAGMA_UNROLL
        for (int m = 0; m < BATCH_M; ++m) {
            float scales[2][2];
            // TODO: check the compiler's ability to avoid re-computing this
            scales[0][0] = frag_U[m][0] * frag_V[0];
            scales[1][0] = frag_U[m][1] * frag_V[0];
            scales[0][1] = frag_U[m][0] * frag_V[1];
            scales[1][1] = frag_U[m][1] * frag_V[1];
            PRAGMA_UNROLL
            for (int n = 0; n < BATCH_N; ++n) {
                PRAGMA_UNROLL
                for (int c0 = 0; c0 < OP_N; c0 += OUTER_N) {
                    int  i = (offset_V + c0) / OUTER_N;
                    bool p = pred_V[i];
                    PRAGMA_UNROLL
                    for (int c1 = 0; c1 < OUTER_N; c1 += 8) {
                        int c = c0 + c1;
                        accum_C[m][n][c / 2 + 0] += (p ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 0];
                        accum_C[m][n][c / 2 + 1] += (p ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 1];
                        accum_C[m][n][c / 2 + 2] += (p ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 2];
                        accum_C[m][n][c / 2 + 3] += (p ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 3];
                    }
                }
            }
        }
    }

    // Runtime -> compile-time dispatch of cute::warpgroup_wait<n> for n in [0, 7].
    // NOTE(review): any other n waits for nothing; gmma_pipe passes up to PIPE_M * PIPE_N - 1, so
    // PIPE_M * PIPE_N > 8 would read unfinished results — confirm this is excluded by configuration.
    __device__ static void warpgroup_wait(int n)
    {
        if (n == 0) {
            cute::warpgroup_wait<0>();
        }
        else if (n == 1) {
            cute::warpgroup_wait<1>();
        }
        else if (n == 2) {
            cute::warpgroup_wait<2>();
        }
        else if (n == 3) {
            cute::warpgroup_wait<3>();
        }
        else if (n == 4) {
            cute::warpgroup_wait<4>();
        }
        else if (n == 5) {
            cute::warpgroup_wait<5>();
        }
        else if (n == 6) {
            cute::warpgroup_wait<6>();
        }
        else if (n == 7) {
            cute::warpgroup_wait<7>();
        }
    }

    // Issue BATCH_K x BATCH_M x BATCH_N wgmma ops over the K tile into frag_C and commit them as one
    // batch. The last wgmma argument (k == 0) presumably clears the accumulator on the first k-step.
    // Iterators are restored to their starting position on return.
    template
    __device__ static void gmma_batch(SmemIterA& iter_A, SmemIterB& iter_B, FragC& frag_C)
    {
        constexpr int BATCH_K = TILE_K / OP_K;
        PRAGMA_UNROLL
        for (int k = 0; k < BATCH_K; ++k) {
            PRAGMA_UNROLL
            for (int m = 0; m < BATCH_M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < BATCH_N; ++n) {
                    wgmma(iter_A, iter_B, frag_C[m][n], k == 0);
                    iter_B += kStepNB;
                }
                iter_B -= kStepNB * BATCH_N;
                iter_A += kStepMA;
            }
            iter_A -= kStepMA * BATCH_M;
            iter_A += kStepKA;
            iter_B += kStepKB;
        }
        iter_A -= kStepKA * BATCH_K;
        iter_B -= kStepKB * BATCH_K;
        cute::warpgroup_commit_batch();
    }

    // Issue PIPE_M x PIPE_N committed batches, then retire them in issue order. Waiting until only
    // PIPE_M * PIPE_N - i - 1 batches are outstanding lets scaling of batch i overlap later batches.
    template
    __device__ static void gmma_pipe(AccumC&      accum_C,
                                     SmemIterA&   iter_A,
                                     SmemIterB&   iter_B,
                                     FragC&       frag_C,
                                     const FragU& frag_U,
                                     const FragV& frag_V,
                                     const PredV& pred_V,
                                     int          offset_V)
    {
        cute::warpgroup_arrive();
        PRAGMA_UNROLL
        for (int m = 0; m < PIPE_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < PIPE_N; ++n) {
                gmma_batch(iter_A, iter_B, frag_C[m][n]);
                iter_B += kStepNB * BATCH_N;
            }
            iter_B -= kStepNB * BATCH_N * PIPE_N;
            iter_A += kStepMA * BATCH_M;
        }
        iter_A -= kStepMA * BATCH_M * PIPE_M;

        int i = 0;
        PRAGMA_UNROLL
        for (int m = 0; m < PIPE_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < PIPE_N; ++n, ++i) {
                warpgroup_wait(PIPE_M * PIPE_N - i - 1);
                int offset = offset_V + n * BATCH_N * OP_N;
                scale_batch_to_accum(accum_C[m][n], frag_C[m][n], frag_U[m], frag_V, pred_V, offset);
            }
        }
    }

    // Full tile: loop ITER_M x ITER_N pipes, advancing the smem iterators and restoring them at the end.
    // NOTE(review): accum_C[m][n] and frag_U[m] are passed where gmma_pipe declares AccumC& / FragU&;
    // verify the types line up once template arguments are known.
    template
    __device__ static void apply(SmemIterA&   iter_A,
                                 SmemIterB&   iter_B,
                                 FragC&       frag_C,
                                 AccumC&      accum_C,
                                 const FragU& frag_U,
                                 const FragV& frag_V,
                                 const PredV& pred_V)
    {
        PRAGMA_UNROLL
        for (int m = 0; m < ITER_M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < ITER_N; ++n) {
                int offset_V = n * PIPE_N * BATCH_N * OP_N;
                gmma_pipe(accum_C[m][n], iter_A, iter_B, frag_C, frag_U[m], frag_V, pred_V, offset_V);
                iter_B += kStepNB * BATCH_N * PIPE_N;
            }
            iter_B -= kStepNB * BATCH_N * PIPE_N * ITER_N;
            iter_A += kStepMA * BATCH_M * PIPE_M;
        }
        iter_A -= kStepMA * BATCH_M * PIPE_M * ITER_M;
    }

    // Visit every accumulator fragment with flattened op-tile indices (m, n).
    // NOTE(review): m = i_m * PIPE_M + p_m * BATCH_M + b_m is not the row-major flattening
    // ((i_m * PIPE_M + p_m) * BATCH_M + b_m) of [ITER_M][PIPE_M][BATCH_M]; indices collide once
    // BATCH_M > 1 (e.g. ITER_M = 2, PIPE_M = 1, BATCH_M = 2 yields 0, 1, 1, 2). Same for n. Confirm.
    template
    __device__ static void foreach_C(Frag& frag, Func&& func)
    {
        PRAGMA_UNROLL
        for (int i_m = 0; i_m < ITER_M; ++i_m) {
            PRAGMA_UNROLL
            for (int i_n = 0; i_n < ITER_N; ++i_n) {
                PRAGMA_UNROLL
                for (int p_m = 0; p_m < PIPE_M; ++p_m) {
                    PRAGMA_UNROLL
                    for (int p_n = 0; p_n < PIPE_N; ++p_n) {
                        PRAGMA_UNROLL
                        for (int b_m = 0; b_m < BATCH_M; ++b_m) {
                            PRAGMA_UNROLL
                            for (int b_n = 0; b_n < BATCH_N; ++b_n) {
                                int m = ((i_m * PIPE_M) + p_m * BATCH_M) + b_m;
                                int n = ((i_n * PIPE_N) + p_n * BATCH_N) + b_n;
                                func(frag[i_m][i_n][p_m][p_n][b_m][b_n], m, n);
                            }  // BATCH_N
                        }      // BATCH_M
                    }          // PIPE_N
                }              // PIPE_M
            }                  // ITER_N
        }                      // ITER_M
    }

    // Visit every per-row fragment with a flattened m index (same flattening as foreach_C; see note there).
    template
    __device__ static void foreach_m(Frag& frag, Func&& func)
    {
        PRAGMA_UNROLL
        for (int i_m = 0; i_m < ITER_M; ++i_m) {
            PRAGMA_UNROLL
            for (int p_m = 0; p_m < PIPE_M; ++p_m) {
                PRAGMA_UNROLL
                for (int b_m = 0; b_m < BATCH_M; ++b_m) {
                    int m = ((i_m * PIPE_M) + p_m * BATCH_M) + b_m;
                    func(frag[i_m][p_m][b_m], m);
                }
            }
        }
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/scheduler.cuh
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "cutlass/fast_math.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/types.h"

#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

namespace turbomind::gemm {

// Arrive (release semantics, cluster scope) on the mbarrier at the same shared-memory offset as `mbar`,
// but in CTA `cta_id` of the cluster. `mapa` translates the local address to that CTA's copy.
// Does nothing when `pred` is zero.
TM_DEVICE void mbarrier_arrive_cluster(uint64_t* mbar, int cta_id, int pred)
{
    uint32_t smem_addr = cast_smem_ptr_to_uint(mbar);
    if (pred) {
        asm volatile("{\n"
                     ".reg .b32 remAddr32;\n"
                     "mapa.shared::cluster.u32  remAddr32, %0, %1;\n"
                     "mbarrier.arrive.release.cluster.shared::cluster.b64  _, [remAddr32];\n"
                     "}"
                     :
                     : "r"(smem_addr), "r"(cta_id));
    }
}

// Block until the local mbarrier `mbar` completes the phase with parity `phase`, using acquire
// semantics at cluster scope. It spins on mbarrier.try_wait.parity; `ticks` (0x989680 == 10,000,000)
// is passed as the try_wait suspend-time hint (per the PTX ISA, presumably in nanoseconds).
// Labels are enclosed in a { } scope so multiple inlined copies do not clash.
TM_DEVICE void mbarrier_wait_cluster(uint64_t* mbar, uint32_t phase)
{
    uint32_t smem_addr = cast_smem_ptr_to_uint(mbar);
    uint32_t ticks     = 0x989680;
    asm volatile("{\n"
                 ".reg .pred       P1; \n"
                 "LAB_WAIT: \n"
                 "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1, %2; \n"
                 "@P1 bra DONE; \n"
                 "bra     LAB_WAIT; \n"
                 "DONE: \n"
                 "}"
                 :
                 : "r"(smem_addr), "r"(phase), "r"(ticks));
}

// Map a generic pointer to the corresponding address in CTA `cta_id` of the same cluster (PTX `mapa`).
TM_DEVICE void* map_to_cta(void* ptr, int cta_id)
{
    void* ret;
    asm volatile("mapa.u64 %0, %1, %2;\n" : "=l"(ret) : "l"(ptr), "r"(cta_id));
    return ret;
}

// Store a 32-bit int to a 32-bit shared::cluster address `ptr`, which may refer to another CTA's
// shared memory, e.g. one produced by `mapa.shared::cluster`.
TM_DEVICE void st_shared_cluster(uint32_t ptr, int value)
{
    asm volatile("st.shared::cluster.s32 [%0], %1;\n" ::"r"(ptr), "r"(value));
}

// Byte offset of `member` within T (offsetof-style, via a null object pointer).
// NOTE(review): dereferencing a null pointer is formally undefined behaviour. Also, reinterpret_cast
// is not allowed in constant evaluation, so despite `constexpr` this cannot be used in a constant
// expression.
template
constexpr int member_offset(M T::*member)
{
    return reinterpret_cast(&(reinterpret_cast(0)->*member));
}

template
struct TileScheduler {

    static constexpr bool is_dynamic = 1;  // is_grouped_gemm;
    static constexpr int  Stages     = Stages_;

    static constexpr int2 tile_{tile_m, tile_n};
    static constexpr int2 cluster_tile_{tile_m * Cluster::M, tile_n* Cluster::N};

    int4 gemm_shape_;
    int2 tiled_shape_;

    int log_tile_;
    int k_iters_;

    int2 tile_offset_;
    int2 iter_k_range_;

    int clusters_;

    //////// v2 /////
    int2 swizzle_unit_;
    int2 cluster_tiles_;
    int2 padded_cluster_tiles_;
    int2 swizzled_cluster_tiles_;

    cutlass::FastDivmod swizzle_tile_x_;
    /////////////

    const int* offsets_;

    int* next_cluster_id_;

    using PipelineState = cutlass::PipelineState;

    struct Tile0 {
        int is_valid_cta;
        int is_valid_cluster;
        int offset_m;
        int offset_n;
        int alive;
    };

    struct Tile1 {
        int is_valid_cta;
        int is_valid_cluster;
        int offset_m;
        int offset_n;
        int alive;
        int group_idx;
        int m0;
        int m1;
    };

    using Tile = std::conditional_t;

    struct Storage {
        Tile tile[Stages];
        __align__(8) uint64_t producer_bar[Stages];
        __align__(8) uint64_t consumer_bar[Stages];
    };

    struct ConsumerState {
        PipelineState  pipe;
        Storage&       store;
        TileScheduler& sched;

        TM_DEVICE bool acquire(Tile*& tile)
        {
            return sched.acquire(*this, tile);
        }

        TM_DEVICE void release(int step = 1)
        {
            return sched.release(*this, step);
        }
    };

    struct ProducerState {
        PipelineState  pipe;
        int            group_id_offset;
        int            cluster_idx;
        Storage&       store;
        TileScheduler& sched;

        TM_DEVICE bool next()
        {
            return sched.next(*this);
        }
    };

public:
    TM_DEVICE void init_dyanmic(Storage& store, int consumer_num)
    {
        for (int i = 0; i < Stages; ++i) {
            cutlass::arch::ClusterBarrier::init(&store.producer_bar[i], 1);
            cutlass::arch::ClusterBarrier::init(&store.consumer_bar[i], consumer_num);
        }
        // cutlass::arch::ClusterBarrier::init(&store.sync_bar, 1);
    }

    TM_HOST_DEVICE void init(int4 gemm_shape, int log_tile, int3 tile_shape)
    {
        gemm_shape_ = gemm_shape;

        // printf("gemm shape: %d %d %d\n", gemm_shape.x, gemm_shape.y, gemm_shape.z);

        log_tile_ = log_tile;
        k_iters_  = cdiv(gemm_shape_.z, tile_shape.z);

        tiled_shape_.x = cdiv(gemm_shape.x, tile_.x);
        tiled_shape_.y = cdiv(gemm_shape.y, tile_.y);

        cluster_tiles_.x = cdiv(gemm_shape.x, cluster_tile_.x);  // useless
        cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y);

        // printf("cluster tiles: %d %d\n", cluster_tiles_.x, cluster_tiles_.y);

        if constexpr (is_grouped_gemm) {
            {
                int2 unit     = get_swizzled_shape({1, 1}, log_tile);
                swizzle_unit_ = order == kColMajor ? int2{unit.y, unit.x} : int2{unit.x, unit.y};
            }

            // col {8, 1}, row {1, 8}
            // printf("swizzle unit: %d %d\n", swizzle_unit_.x, swizzle_unit_.y);

            swizzle_tile_x_ = cluster_tile_.x * swizzle_unit_.x;

            int num = gemm_shape_.w;

            // num of tiles won't change after swizzle
            padded_cluster_tiles_.x = (num + gemm_shape.x / (cluster_tile_.x * swizzle_unit_.x)) * swizzle_unit_.x;
            padded_cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y * swizzle_unit_.y) * swizzle_unit_.y;

            // printf("padded   cluster tiles: %d %d\n", padded_cluster_tiles_.x, padded_cluster_tiles_.y);

            swizzled_cluster_tiles_ = get_swizzled_shape(padded_cluster_tiles_, log_tile);

            // printf("swizzled cluster tiles: %d %d\n", swizzled_cluster_tiles_.x, swizzled_cluster_tiles_.y);

            clusters_ = padded_cluster_tiles_.x * padded_cluster_tiles_.y;

            // printf("clusters = %d\n", clusters_);
            // M is runtime value
        }
        else {
            tiled_shape_.x = cdiv(gemm_shape.x, tile_.x);
            tiled_shape_.y = cdiv(gemm_shape.y, tile_.y);

            cluster_tiles_.x = cdiv(gemm_shape.x, cluster_tile_.x);
            cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y);

            swizzled_cluster_tiles_ = get_swizzled_shape(cluster_tiles_, log_tile);

            swizzle_tile_x_ = swizzled_cluster_tiles_.x;

            clusters_ = swizzled_cluster_tiles_.x * swizzled_cluster_tiles_.y;
        }
    }

    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)
    {
        return gemm::get_log_tile(order == kColMajor ? tiled_mn.y : tiled_mn.x, tile_size);
    }

    TM_HOST_DEVICE static int2 get_swizzled_shape(int2 tiled_shape, int log_tile)
    {
        const int tile = 1 << log_tile;

        if constexpr (order == kColMajor) {
            return {tiled_shape.x * tile, (tiled_shape.y + tile - 1) >> log_tile};
        }
        else {
            return {tiled_shape.y * tile, (tiled_shape.x + tile - 1) >> log_tile};
        }
    }

    TM_DEVICE ProducerState init_producer(Storage& store)
    {
        int cluster_id = 0;
        if constexpr (!is_dynamic) {
            cluster_id = (int)cute::cluster_id_in_grid().x;
        }
        return {
            PipelineState{0, 1, 0},
            0,
            cluster_id,
            store,
            *this,
        };
    }

    TM_DEVICE ConsumerState init_consumer(Storage& store)
    {
        return {
            PipelineState{},
            store,
            *this,
        };
    }

    // Decode a linear cluster index plus this CTA's rank inside the cluster into the output tile
    // the CTA must compute. The result (offsets and validity flags) is written into `tile`.
    //   cluster_idx   - linear cluster index (group-relative for grouped GEMM, see `next()`)
    //   cta_id        - rank of the CTA inside its cluster, decoded with `Cluster::cta_mn`
    //   cta_tiles     - CTA tile counts along (m, n), used for the per-CTA bounds check
    //   cluster_tiles - cluster tile counts along (m, n), used for the per-cluster bounds check
    //   swizzle_tiles - swizzled cluster-grid shape; only `.x` is read (grouped GEMM path)
    TM_DEVICE void
    unswizzle(Tile& tile, int cluster_idx, int cta_id, int2 cta_tiles, int2 cluster_tiles, int2 swizzle_tiles) const
    {
        int cluster_idx_x, cluster_idx_y;

        if constexpr (is_grouped_gemm) {
            // The swizzled shape differs per group, so use a plain div/mod by its x extent
            cluster_idx_x = cluster_idx % swizzle_tiles.x;
            cluster_idx_y = cluster_idx / swizzle_tiles.x;
        }
        else {
            // NOTE(review): presumably a precomputed fast divmod writing quotient -> cluster_idx_y and
            // remainder -> cluster_idx_x — confirm against the divider type
            swizzle_tile_x_(cluster_idx_y, cluster_idx_x, cluster_idx);
        }

        // Position of this CTA within the cluster along m and n
        auto [cluster_cta_m, cluster_cta_n] = Cluster::cta_mn(cta_id);

        // Striped: the CTAs of one cluster are `cluster_tiles` apart; otherwise they are adjacent
        const int offset_x = cluster_cta_m * (striped_m ? cluster_tiles.x : 1);
        const int offset_y = cluster_cta_n * (striped_n ? cluster_tiles.y : 1);

        int2 cluster_tile_offset;

        // Undo the swizzle: the high bits of x (x >> log_tile_) index one axis; y together with the
        // low log_tile_ bits of x index the other (n for column-major, m for row-major)
        if constexpr (order == kColMajor) {
            cluster_tile_offset = {(cluster_idx_x >> log_tile_),
                                   (cluster_idx_y << log_tile_) + (cluster_idx_x & ((1 << log_tile_) - 1))};
        }
        else {
            cluster_tile_offset = {(cluster_idx_y << log_tile_) + (cluster_idx_x & ((1 << log_tile_) - 1)),
                                   (cluster_idx_x >> log_tile_)};
        }

        // `tile` may be on DSMEM
        int tile_idx_x        = offset_x + cluster_tile_offset.x * (striped_m ? 1 : Cluster::M);
        int tile_idx_y        = offset_y + cluster_tile_offset.y * (striped_n ? 1 : Cluster::N);
        tile.offset_m         = tile_idx_x * tile_.x;
        tile.offset_n         = tile_idx_y * tile_.y;
        // A cluster tile can be in range while some of its CTAs overhang the CTA tile grid
        int valid_cluster_p   = cluster_tile_offset.x < cluster_tiles.x && cluster_tile_offset.y < cluster_tiles.y;
        tile.is_valid_cta     = valid_cluster_p && tile_idx_x < cta_tiles.x && tile_idx_y < cta_tiles.y;
        tile.is_valid_cluster = valid_cluster_p;
    }

    TM_DEVICE int get_start_index(int g) const
    {
        // Linear index of the first cluster assigned to group `g`. The original reference formula
        // was (offsets_[g] / (cluster_tile_.x * swizzle_unit_.x) + g) * swizzle_unit_.x
        // * padded_cluster_tiles_.y. The `+ g` term reserves one extra swizzle unit per preceding
        // group, presumably to absorb each group's partial final unit.
        const int units_before = swizzle_tile_x_.div(__ldg(&offsets_[g])) + g;
        return units_before * swizzle_unit_.x * padded_cluster_tiles_.y;
    }

    // Warp-cooperative lookup of the group containing linear cluster index `cluster_idx`
    // (grouped GEMM), then loads that group's parameters. All 32 lanes must call it, because it
    // uses a full-mask __ballot_sync.
    //   group_id_offset - in/out: first group index of the current 32-wide search window. It is
    //                     only ever advanced, so later calls resume from it. This relies on
    //                     `cluster_idx` not decreasing between calls (TODO confirm for all callers).
    //   group_idx       - out: index of the group containing `cluster_idx`
    //   group_beg       - out: get_start_index(group_idx), the group's first linear cluster index
    //   group_m0/m1     - out: offsets_[group_idx] and offsets_[group_idx + 1]
    //   tiled_shape     - in/out: `.x` overwritten with the group's CTA tile count along m
    //   cluster_tiles   - in/out: `.x` overwritten with the group's cluster tile count along m
    //   swizzled_tiles  - out: swizzled shape of the group's cluster tiles
    // Always returns true.
    // NOTE(review): offsets_ must hold gemm_shape_.w + 1 entries (indices up to gemm_shape_.w are read)
    TM_DEVICE bool update_sync(int   cluster_idx,
                               int&  group_id_offset,
                               int&  group_idx,
                               int&  group_beg,
                               int&  group_m0,
                               int&  group_m1,
                               int2& tiled_shape,
                               int2& cluster_tiles,
                               int2& swizzled_tiles) const
    {
        const int lane_id = threadIdx.x % WARP_SIZE;

        uint32_t mask;
        while (true) {
            // Lane i tests group e = group_id_offset + i: pred is set when e lies past the last group
            // (the short-circuit avoids reading offsets_ there) or starts after cluster_idx.
            // With non-decreasing start indices, pred is false for a prefix of lanes and true after.
            int e    = group_id_offset + lane_id;
            int pred = e > gemm_shape_.w || cluster_idx < get_start_index(e);
            mask     = __ballot_sync((uint32_t)-1, pred);
            if (mask) {
                break;
            }
            // Every group in this window starts at or before cluster_idx: slide the window
            group_id_offset += WARP_SIZE;
        }

        // Highest lane with pred == false, i.e. the last group starting at or before cluster_idx.
        // If all lanes are set (~mask == 0, __clz returns 32) this yields group_id_offset - 1,
        // the last group of the previous window.
        // 32 - clz(~mask) - 1
        group_idx = group_id_offset + 31 - __clz(~mask);

        group_m0 = __ldg(&offsets_[group_idx]);
        group_m1 = __ldg(&offsets_[group_idx + 1]);
        int m    = group_m1 - group_m0;

        group_beg = get_start_index(group_idx);

        tiled_shape.x   = cdiv(m, tile_.x);
        cluster_tiles.x = cdiv(m, cluster_tile_.x);

        swizzled_tiles = get_swizzled_shape(cluster_tiles, log_tile_);

        return true;
    }

    // Producer step, executed by one full warp:
    //   1. Obtain the next linear cluster index.
    //   2. Decode it into a tile for every CTA of the cluster.
    //   3. Publish the tiles into the current pipeline stage and signal the consumers.
    // Returns whether the index is in range (cluster_idx < clusters_). A non-alive tile is still
    // published, so consumers observe the end of work.
    TM_DEVICE bool next(ProducerState& state)
    {
        const int lane_id = cutlass::canonical_lane_idx();

        auto& store = state.store;
        auto& pipe  = state.pipe;

        int cluster_idx{};

        if constexpr (is_dynamic) {
            // Lane 0 waits for consumers to release this stage, claims the next index from the
            // global counter, and broadcasts it to the warp
            if (lane_id == 0) {
                cutlass::arch::ClusterBarrier::wait(&store.consumer_bar[pipe.index()], pipe.phase());
                cluster_idx = atomicAdd(next_cluster_id_, 1);
            }
            cluster_idx = __shfl_sync((uint32_t)-1, cluster_idx, 0);
        }
        else {
            // Static schedule: indices advance by the number of clusters in the grid
            cutlass::arch::ClusterBarrier::wait(&store.consumer_bar[pipe.index()], pipe.phase());
            cluster_idx = state.cluster_idx;
            state.cluster_idx += (int)cute::cluster_grid_dims().x;
        }

        Tile* tile{};

        if constexpr (Cluster::size == 1) {
            tile = &store.tile[pipe.index()];
        }
        else {
            // Lane i targets this stage's tile slot of CTA i in the cluster (presumably via DSMEM
            // mapping); lanes >= Cluster::size keep nullptr and never dereference it
            if (lane_id < Cluster::size) {
                tile = (Tile*)map_to_cta(&store.tile[pipe.index()], lane_id);
            }
        }

        // Uniform across the warp (broadcast in the dynamic path, derived from the cluster id in the
        // static path), so all lanes reach the warp-collective `update_sync` together
        const int alive = cluster_idx < clusters_;

        if (alive) {
            int  group_id      = 0;
            int  group_beg     = 0;
            int  group_m0      = 0;
            int  group_m1      = 0;
            auto cta_tiles     = tiled_shape_;
            auto cluster_tiles = cluster_tiles_;
            auto swizzle_tiles = swizzled_cluster_tiles_;
            if constexpr (is_grouped_gemm) {
                // Replace the whole-problem shapes with those of the group owning cluster_idx
                update_sync(cluster_idx,  //
                            state.group_id_offset,
                            group_id,
                            group_beg,
                            group_m0,
                            group_m1,
                            cta_tiles,
                            cluster_tiles,
                            swizzle_tiles);
            }
            if (lane_id < Cluster::size) {
                // Index relative to the group start (group_beg stays 0 for non-grouped GEMM)
                unswizzle(*tile,  //
                          cluster_idx - group_beg,
                          lane_id,
                          cta_tiles,
                          cluster_tiles,
                          swizzle_tiles);
                if constexpr (is_grouped_gemm) {
                    tile->group_idx = group_id;
                    tile->m0        = group_m0;
                    tile->m1        = group_m1;
                }
            }
        }

        if (lane_id < Cluster::size) {
            tile->alive = alive;
        }

        // Signal that the stage is filled
        if constexpr (Cluster::size == 1) {
            if (lane_id == 0) {
                cutlass::arch::ClusterBarrier::arrive(&store.producer_bar[pipe.index()]);
            }
        }
        else {
            // NOTE(review): presumably each lane < Cluster::size arrives on the producer barrier of
            // CTA `lane_id` — confirm against mbarrier_arrive_cluster
            mbarrier_arrive_cluster(&store.producer_bar[pipe.index()], lane_id, lane_id < Cluster::size);
        }

        ++pipe;

        return alive;
    }

    TM_DEVICE void tail(ProducerState& state)
    {
        // Multi-CTA clusters only: before the producer finishes, wait once more on every pipeline
        // stage until the consumers have released it.
        // NOTE(review): presumably this keeps this CTA's barriers alive until all remote cluster
        // arrivals have landed — confirm.
        if constexpr (Cluster::size > 1) {
            auto& store = state.store;
            auto& pipe  = state.pipe;
            for (int remaining = Stages; remaining > 0; --remaining) {
                cutlass::arch::ClusterBarrier::wait(&store.consumer_bar[pipe.index()], pipe.phase());
                ++pipe;
            }
        }
    }

    TM_DEVICE bool acquire(ConsumerState& state, Tile*& tile)
    {
        // Block until the producer has published the current stage, then hand out its tile.
        // Returns the tile's `alive` flag (false once the producer has run out of work).
        auto&      store = state.store;
        auto&      pipe  = state.pipe;
        const auto stage = pipe.index();

        if constexpr (Cluster::size != 1) {
            mbarrier_wait_cluster(&store.producer_bar[stage], pipe.phase());
        }
        else {
            cutlass::arch::ClusterBarrier::wait(&store.producer_bar[stage], pipe.phase());
        }

        tile = &store.tile[stage];
        return tile->alive;
    }

    TM_DEVICE void release(ConsumerState& state, int step)
    {
        // Tell the producer the current stage has been consumed, then advance the pipeline by
        // `step` stages. After the warp reconverges, one elected lane performs the arrive. With a
        // multi-CTA cluster the arrive targets the barrier of CTA 0 in the cluster.
        auto& pipe = state.pipe;
        auto* bar  = &state.store.consumer_bar[pipe.index()];

        __syncwarp();

        if constexpr (Cluster::size != 1) {
            cutlass::arch::ClusterBarrier::arrive(bar, 0, cutlass::elect_one_sync());
        }
        else if (cutlass::elect_one_sync()) {
            cutlass::arch::ClusterBarrier::arrive(bar);
        }

        pipe.advance(step);
    }

    // Problem shape held by the scheduler; `.w` is used as the group count by `update_sync`.
    TM_DEVICE int4 gemm_shape() const { return gemm_shape_; }

    // CTA tile counts along (m, n) for the whole (non-grouped) problem.
    TM_DEVICE int2 tiled_shape() const { return tiled_shape_; }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/scheduler_sm70.cuh
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/kernels/gemm/cta_map.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

template
// Tile scheduler for the SM70 GEMM path. It maps the CTA grid to output tiles along (m, n), where
// the grid may swizzle 2^log_tile tiles along axis `1 - order`. It also assigns split-K ranges
// along k. With `group_axis >= 0` and `offsets_` set, it maps tiles to groups concatenated along
// `group_axis` (grouped GEMM).
struct SchedulerSm70 {

    // Axis along which grouped problems are concatenated (0 = m, 1 = n); negative disables grouping
    static constexpr int group_axis = group_axis_;

    // CTA tile extents along (m, n, k)
    static constexpr Array tile_shape{tile_m, tile_n, tile_k};

    // k is processed in chunks of chunk_k; split-K distributes whole chunks over the splits
    static_assert(chunk_k % tile_k == 0);
    static constexpr int chunk_iters = chunk_k / tile_k;  // tile_k iterations per chunk

    // (m, n, k, group count); index 3 is read as the number of groups
    Array gemm_shape_;
    // (tiles along m, tiles along n, k splits). With grouping, the grouped axis holds a padded
    // upper bound instead of the exact tile count.
    Array tiles_;

    // log2 of the swizzle width, in tiles, along axis `1 - order`
    int log_tile_;

    int split_chunks_;  // k chunks that every split receives
    int chunk_offset_;  // splits with id >= chunk_offset_ receive one extra chunk

    // Per-group prefix offsets along `group_axis`; entries [0, group count] are read. The
    // constructor does not set this, yet `init` tests it for null, so default it to nullptr
    // (callers that use grouping must assign it).
    const int* offsets_{};

    struct Tile {
        Array tile_id;  // (m tile, n tile, split id); the grouped axis is made group-local by `init`
        Array shape;    // extents (m, n, k); the grouped axis is replaced by the group's extent
        Array k_iters;  // {first tile_k iteration, number of tile_k iterations}

        int group_id;
        int linear_tile_id;  // global tile index with the `order` axis varying fastest
    };

    // Result of the block-wide group lookup, shared through shared memory by `find_group`
    struct SharedStorage {
        int group_id;
        int dynamic_dim;
        int base_tile_id;
    };

    // Launch grid: the swizzled tile shape
    __host__ dim3 get_grid_shape()
    {
        auto shape = get_swizzled_shape(tiles_, log_tile_);
        return dim3(shape[0], shape[1], shape[2]);
    }

    __host__ SchedulerSm70(Array gemm_shape, int log_tile = 0, int splits = 1):
        gemm_shape_{gemm_shape}, log_tile_{log_tile}
    {
        tiles_[0] = cdiv(gemm_shape[0], tile_m);
        tiles_[1] = cdiv(gemm_shape[1], tile_n);
        tiles_[2] = splits;

        // The swizzle groups 2^log_tile tiles along axis `1 - order`; pad that axis to whole units
        Array log_unit{};
        log_unit[1 - (int)order] = log_tile;

        tiles_[0] = round_up(tiles_[0], 1 << log_unit[0]);
        tiles_[1] = round_up(tiles_[1], 1 << log_unit[1]);

        // printf("gemm shape: %d %d %d %d\n", gemm_shape_[0], gemm_shape_[1], gemm_shape_[2], gemm_shape_[3]);
        // printf("tile shape: %d %d %d\n", tile_shape[0], tile_shape[1], tile_shape[2]);

        if constexpr (group_axis >= 0) {
            constexpr int i = group_axis;
            // overwrite dynamic axis <- estimated upper bound
            // Same formula as `get_group_offset` evaluated at g = group count. Each group is padded by
            // at most one swizzle unit, so this bounds the tile space of all groups (assuming the
            // groups' total extent equals gemm_shape_[i]).
            tiles_[i] = ((gemm_shape_[i] / tile_shape[i] >> log_unit[i]) + gemm_shape_[3]) << log_unit[i];
        }

        // Every split gets chunks / splits chunks; the last (chunks % splits) splits get one more
        int chunks    = cdiv(gemm_shape[2], chunk_k);
        split_chunks_ = chunks / splits;
        chunk_offset_ = splits - chunks % splits;
    }

    // Returns {offset of group g along group_axis, index of the group's first tile in the padded tile space}
    __device__ int2 get_group_offset(int g)
    {
        constexpr int i = group_axis;

        Array log_unit{};
        log_unit[1 - (int)order] = log_tile_;

        int offset      = __ldg(offsets_ + g);
        int tile_offset = ((offset / tile_shape[i] >> log_unit[i]) + g) << log_unit[i];

        return {offset, tile_offset};
    }

    // Block-cooperative search for the group whose tile range contains tile_id[group_axis].
    // Every thread of the block must call it, because it ends in __syncthreads_or. The matching
    // thread writes the group's id, extent and first tile into `storage`. Returns non-zero if any
    // thread found the group.
    __device__ int find_group(Array& tile_id, SharedStorage& storage)
    {
        constexpr int axis = group_axis;

        int success = 0;

        const int block_dim = blockDim.x;

        for (int g = threadIdx.x; g < gemm_shape_[3]; g += block_dim) {
            auto [beg, beg_tile] = get_group_offset(g);
            auto [end, end_tile] = get_group_offset(g + 1);

            if (beg_tile <= tile_id[axis] && tile_id[axis] < end_tile) {
                storage.group_id     = g;
                storage.dynamic_dim  = end - beg;
                storage.base_tile_id = beg_tile;
                success              = 1;
            }

            // Offsets are presumably non-decreasing: later groups cannot contain the tile
            if (tile_id[axis] < end_tile) {
                break;
            }
        }

        return __syncthreads_or(success);
    }

    // Fill `tile` for this CTA. Returns false if no group contains the CTA's tile (no work).
    // When Reinit::value is true the group lookup is skipped and `storage` from an earlier call is
    // reused. Must be called by every thread of the block when grouping is active (__syncthreads).
    template
    __device__ int init(Tile& tile, SharedStorage& storage, Reinit)
    {
        Array cta_id{(int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z};
        Array tile_id = unswizzle(cta_id);
        Array shape{gemm_shape_[0], gemm_shape_[1], gemm_shape_[2]};

        tile.group_id       = 0;
        tile.linear_tile_id = tile_id[1 - (int)order] * tiles_[(int)order] + tile_id[(int)order];

        constexpr int axis = group_axis;

        if constexpr (axis >= 0) {
            if (offsets_) {
                if constexpr (!Reinit::value) {
                    if (!find_group(tile_id, storage)) {
                        return false;
                    }
                }
                // Make the grouped axis group-local
                tile_id[axis] -= storage.base_tile_id;
                shape[axis]   = storage.dynamic_dim;
                tile.group_id = storage.group_id;
                // Crucial for the values above to be recognized as warp uniform, `__syncwarp()`
                // does not prevent modifying CTA scope SMEM from other warps
                __syncthreads();
            }
        }

        if constexpr (split_k) {
            // Splits before chunk_offset_ get split_chunks_ chunks, later ones one more
            int split_id    = tile_id[2];
            int chunk_id    = split_id * split_chunks_ + max(split_id - chunk_offset_, 0);
            tile.k_iters[0] = chunk_id * chunk_iters;
            tile.k_iters[1] = (split_chunks_ + int(split_id >= chunk_offset_)) * chunk_iters;
        }
        else {
            tile.k_iters[0] = 0;
            tile.k_iters[1] = split_chunks_ * chunk_iters;
        }

        tile.tile_id = tile_id;
        tile.shape   = shape;

        return true;
    }

    // Inverse of the grid swizzle. The high bits of x index the `order` axis; y together with the
    // low log_tile_ bits of x index axis `1 - order`. z (the split id) passes through.
    __device__ Array unswizzle(Array cta_id)
    {
        int tile_c = cta_id[0] >> log_tile_;
        int tile_s = cta_id[1] << log_tile_ | (cta_id[0] & ((1 << log_tile_) - 1));

        Array tile_id;

        tile_id[(int)order]     = tile_c;
        tile_id[1 - (int)order] = tile_s;

        tile_id[2] = cta_id[2];

        return tile_id;
    }

    // Swizzled grid shape: the `order` axis is expanded by 2^log_tile, and axis `1 - order` is
    // divided by 2^log_tile (rounding up)
    __host__ __device__ static Array get_swizzled_shape(Array tiles, int log_tile)
    {
        constexpr int i = (int)order;  // expansion axis
        return {tiles[i] << log_tile, (tiles[1 - i] + (1 << log_tile) - 1) >> log_tile, tiles[2]};
    }

    // Largest useful log_tile for the swizzled axis. When that axis is also the grouped one, it is
    // based on the average tile count per group.
    __host__ int get_max_swizzle()
    {
        constexpr int axis = 1 - (int)order;

        int n = tiles_[axis];

        if (group_axis == axis) {
            n = cdiv(n, gemm_shape_[3]);
        }

        return get_log_tile(n);
    }

    // Heuristic swizzle width (log2) for `size` tiles along the swizzled axis
    __host__ __device__ static int get_log_tile(int size)
    {
        if (size >= 24)
            return 5;
        if (size >= 12)
            return 4;
        if (size >= 6)
            return 3;
        if (size >= 3)
            return 2;
        if (size >= 2)
            return 1;
        return 0;
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/simt.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind::gemm::simt {

// Extents (M, N, K) of one SIMT "MMA" operation.
// NOTE(review): meaning inferred from the names — confirm against the SIMT MMA implementation.
// Other configurations that were kept for reference: (2, 16, 4) and (4, 8, 8).
constexpr int OP_M = 1, OP_N = 32, OP_K = 8;

}  // namespace turbomind::gemm::simt


================================================
FILE: src/turbomind/kernels/gemm/sm90_utils.h
================================================


#pragma once

#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_traits.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/smem.h"

#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

namespace GMMA = cute::SM90::GMMA;

inline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)
{
    // Build a wgmma shared-memory matrix descriptor for `smem_ptr` with the given swizzle layout.
    // Address and offset fields are encoded in 16-byte units, hence the `>> 4`. The leading byte
    // offset and base offset are zero; the stride byte offset is fixed at 1024 bytes.
    constexpr int kStrideBytes = 1024;

    cute::GmmaDescriptor desc{};
    auto&                fields = desc.bitfield;

    fields.start_address_       = cast_smem_ptr_to_uint(smem_ptr) >> 4;
    fields.layout_type_         = layout_type;
    fields.leading_byte_offset_ = 0;
    fields.stride_byte_offset_  = kStrideBytes >> 4;
    fields.base_offset_         = 0;

    return desc;
}

template
struct SmemDescIterV2 {
    // The 64-bit descriptor, aliased as two 32-bit words; only the low word is ever modified
    union {
        uint32_t u32_[2];
        uint64_t u64_;
    };

    // Low word as constructed, i.e. the stage-0 value
    uint32_t base_;

    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}

    // Step to the next of `Stages` stages by adding `Step` to the low word. After the last stage
    // (stage == Stages - 1), wrap back to the stage-0 value.
    __device__ void Advance(int stage)
    {
        const bool wraps = stage == Stages - 1;
        u32_[0]          = wraps ? base_ : u32_[0] + Step;
    }

    // Jump straight to stage `stage`
    __device__ void Reset(int stage)
    {
        u32_[0] = base_ + stage * Step;
    }

    // Shift the low word by `offset` (same units as Step)
    __device__ SmemDescIterV2& operator+=(int offset)
    {
        u32_[0] = u32_[0] + offset;
        return *this;
    }

    __device__ SmemDescIterV2& operator-=(int offset)
    {
        u32_[0] = u32_[0] - offset;
        return *this;
    }

    // The full 64-bit descriptor
    __device__ operator uint64_t()
    {
        return u64_;
    }
};

template
// Issue one MMA_Atom::fma with the two smem descriptors and the accumulator fragment expanded as
// individual register operands (one per index in the sequence). When `clear` is set, the
// accumulator input is scaled by zero (ScaleOut::Zero), i.e. C is overwritten rather than
// accumulated.
inline __device__ void
wgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence)
{
    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);
}

template
// Convenience wrapper: expand all N accumulator registers of `frag_C` into `wgmma_impl`
// (presumably via an index sequence of length N — template arguments not visible here).
inline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)
{
    return wgmma_impl(desc_a, desc_b, frag_C, clear, std::make_index_sequence{});
}

// Compiler fence on one register: the empty asm takes `reg` as a read-write operand and clobbers
// memory, so the compiler cannot move accesses to `reg` across this point.
// Presumably used around asynchronous wgmma, which updates its accumulators out of program order.
inline __device__ void warpgroup_fence_operand(float& reg)
{
    asm volatile("" : "+f"(reg)::"memory");
}

template
// Apply the single-register fence to every element of an M x N x K register array
inline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])
{
    PRAGMA_UNROLL
    for (int m = 0; m < M; ++m) {
        PRAGMA_UNROLL
        for (int n = 0; n < N; ++n) {
            PRAGMA_UNROLL
            for (int k = 0; k < K; ++k) {
                warpgroup_fence_operand(x[m][n][k]);
            }
        }
    }
}

template
// Apply the single-register fence to every element of an N x K register array
inline __device__ void warpgroup_fence_operand(float (&x)[N][K])
{
    PRAGMA_UNROLL
    for (int n = 0; n < N; ++n) {
        PRAGMA_UNROLL
        for (int k = 0; k < K; ++k) {
            warpgroup_fence_operand(x[n][k]);
        }
    }
}

template
// Compile-time loop: call `func` once per index of the sequence with a compile-time constant
// (presumably constant<Is>{} — template arguments not visible here), via a comma fold expression.
__device__ void for_(std::index_sequence, Func func)
{
    return (func(constant{}), ...);
}

namespace arch {

// Compile-time description of a thread-block cluster of M x N CTAs (M along the GEMM m axis, N
// along n). CTA ranks are laid out in (c, s) coordinates with (C, S) = mk2cs(M, N): rank =
// s * C + c.
template
struct Cluster {
    static constexpr int M = M_;
    static constexpr int N = N_;

    static constexpr int C = mk2cs(M, N).x;
    static constexpr int S = mk2cs(M, N).y;

    static constexpr int size = M * N;

    // Bits 0..C-1: the CTAs of one s-row
    static constexpr uint16_t kMaskC = (1 << C) - 1;
    // Bits 0, C, 2C, ..., (S-1)C: the CTAs of one c-column, since (2^(S*C) - 1) / (2^C - 1) =
    // sum over i of 2^(i*C)
    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;

    // (mask of CTAs sharing this CTA's c, mask of CTAs sharing its s)
    __device__ static ushort2 mask_cs(int cta_id)
    {
        const auto [c, s] = cta_cs(cta_id);
        return make_ushort2(kMaskS << c, kMaskC << s * C);
    }

    // Same masks reordered into (m, n) according to `order`
    // (presumably used as multicast masks — confirm at call sites)
    __device__ static ushort2 mask_mn(int cta_id)
    {
        auto [c, s] = mask_cs(cta_id);
        return order == kColMajor ? ushort2{c, s} : ushort2{s, c};
    }

    // Rank -> (c, s); a coordinate with extent 1 is the constant 0
    __device__ static int2 cta_cs(int cta_id)
    {
        return {C > 1 ? cta_id % C : 0, S > 1 ? cta_id / C : 0};
    }

    // Rank -> (m, n) position inside the cluster
    __device__ static int2 cta_mn(int cta_id)
    {
        return cs2mk(cta_cs(cta_id));
    }

    // Per-CTA cached position and masks
    int2    cta_mn_;
    ushort2 mask_mn_;

    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}

    __device__ int cta_m()
    {
        return cta_mn_.x;
    }

    __device__ int cta_n()
    {
        return cta_mn_.y;
    }

    __device__ uint16_t mask_m()
    {
        return mask_mn_.x;
    }

    __device__ uint16_t mask_n()
    {
        return mask_mn_.y;
    }
};

}  // namespace arch

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/smem_copy.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// No-op shared-memory copy atom: unit shape and a single fragment; `copy` does nothing, and both
// offset queries return zero. Presumably used for operands that need no smem -> register copy.
struct VoidSmemCopyAtom {

    static constexpr int M = 1;
    static constexpr int K = 1;

    static constexpr int kFragNum = 1;

    using Frag = Array;

    template
    __device__ static void copy(S, D, bool)
    {
    }

    __device__ static int2 get_offset(int)
    {
        return {};
    }

    __device__ static int2 unique(int thread_idx, int pack_idx)
    {
        return {};
    }
};

template
// Primary template is intentionally empty: only the specializations below are usable
// (NOTE(review): presumably specialized on the operand's memory order — confirm).
struct SmemAccessorV2 {
};

template
// Pass-through specialization: behaves exactly like SmemAccessor (inherits its constructors).
struct SmemAccessorV2: SmemAccessor {
    using SmemAccessor::SmemAccessor;
};

template
// Transposing specialization: (m, k) accesses are forwarded to the wrapped accessor as (k, m).
struct SmemAccessorV2 {
    SmemAccessor base_;

    __device__ SmemAccessorV2(get_pointer_type ptr): base_{ptr} {}
    __device__ T& operator()(int m, int k)
    {
        return base_(k, m);
    }
};

template
// Copy atom for packed operands: each lane loads one whole Frag from shared memory.
struct SmemCopyAtom_Pack_v2 {
    static constexpr int M = M_;
    static constexpr int K = K_;

    static constexpr int kFragNum = FragNum_;

    using Frag = Array;

    // Per-thread (m, k) offset. Every group of RepeatC consecutive lanes shares one Frag-sized
    // offset along the contiguous axis, which is k for row-major and m otherwise.
    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)
    {
        const int lane_id = thread_idx % WARP_SIZE;

        const int c = lane_id / RepeatC * Frag::size();

        return order == kRowMajor ? int2{0, c} : int2{c, 0};
    }

    // Load one Frag from `src_ptr` into `dst_ptr` when `mask` is set
    template
    __device__ static void copy(S src_ptr, D dst_ptr, bool mask)
    {
        auto dst_raw_ptr = (T*)dst_ptr;  // SubBytePtr -> T*
        if (mask) {
            Lds(*(Frag*)dst_raw_ptr, src_ptr);
        }
    }
};

template
// Packed copy atom built on an underlying CopyAtom: FragNum_ copies of CopyAtom along m. Each
// thread loads one whole Frag, at an offset derived from CopyAtom's per-thread `unique` index.
struct SmemCopyAtom_Pack_v3 {
    static constexpr int M = CopyAtom::M * FragNum_;
    static constexpr int K = CopyAtom::K;

    static constexpr int kFragNum = FragNum_;

    using Frag = Array;

    // Per-thread (m, k) offset along the contiguous axis (k for row-major, m otherwise)
    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)
    {
        const int c = CopyAtom::unique(thread_idx, 0).x * Frag::size();

        return order == kRowMajor ? int2{0, c} : int2{c, 0};
    }

    // Load one Frag from `src_ptr` into `dst_ptr` when `mask` is set
    template
    __device__ static void copy(S src_ptr, D dst_ptr, bool mask)
    {
        if (mask) {
            auto dst_raw_ptr = (T*)dst_ptr;  // SubBytePtr -> T*
            Lds(*(Frag*)dst_raw_ptr, src_ptr);
        }
    }
};

template
// Shared memory -> register copier for one operand. Each call loads ITER_M atom fragments for
// the k step `k`, spaced dM * kFragNum apart along m.
struct SmemCopy {
    using Atom = typename Operand::SmemCopyAtom;

    static constexpr int kFragNum = Atom::kFragNum;

    // Atom copies per call
    static constexpr int ITER_M = iM / Atom::kFragNum;

    static_assert(ITER_M > 0);

    using Frag = typename Atom::Frag[ITER_M];

    using Pack = Packing_v2;

    // Step between consecutive copies along (m, k), after packing
    static constexpr int2 delta = Pack::apply(int2{dM * kFragNum, dK});

    using Layout = typename Operand::SmemLayout;

    // Base tile (C0, S0) of the smem layout, expressed in (m, k)
    static constexpr int2 kMK0 = cs2mk(Layout::C0, Layout::S0);

    // Number of steps along m / k covered by one base tile. Per-thread offsets are precomputed for
    // one such period only.
    static constexpr int kPeriodM = ceil_div(kMK0.x, delta.x);
    static constexpr int kPeriodK = ceil_div(kMK0.y, delta.y);

    const int2 offset_;

    // Per-thread smem offsets (including `offset_` and the atom's thread offset) for one period
    int phases_[kPeriodK][kPeriodM];

    __device__ SmemCopy(int2 offset): offset_{offset}
    {
        const int2 thr = Atom::get_offset(threadIdx.x);
        PRAGMA_UNROLL
        for (int k = 0; k < kPeriodK; ++k) {
            PRAGMA_UNROLL
            for (int m = 0; m < kPeriodM; ++m) {
                const int2 pack = Pack::apply({offset.x + m * dM * kFragNum, offset.y + k * dK});
                const int2 cs   = mk2cs({pack.x + thr.x, pack.y + thr.y});
                phases_[k][m]   = Layout::apply(cs.y, cs.x);
            }
        }
    }

    // Load the fragments of k step `k` from `src_ptr` into `dst` (loads are skipped when `mask` is false)
    template
    __device__ void operator()(Pointer src_ptr, Frag& dst, int k, bool mask = true)
    {
        using Accessor = typename Operand::SmemAccessor;
        if constexpr (Operand::kGroupSize == 1) {
            PRAGMA_UNROLL
            for (int m = 0; m < ITER_M; ++m) {
                // Split the address into the start of the enclosing period (i0) plus the
                // precomputed in-period offset (i1). This presumes Layout::apply is additive across
                // whole periods — confirm for new layouts.
                const int  mm = m / kPeriodM * kPeriodM * dM * kFragNum;
                const int  kk = k / kPeriodK * kPeriodK * dK;
                const int2 cs = mk2cs(Pack::apply(int2{mm, kk}));
                const int  i0 = Layout::apply(cs.y, cs.x);
                const int  i1 = phases_[k % kPeriodK][m % kPeriodM];
                Atom::copy(&src_ptr[i0 + i1], dst[m].data(), mask);
            }
        }
        else {  // generic case
            // Grouped operand (e.g. per-group scales): the k coordinate is divided by the group size
            Accessor   smem{src_ptr};
            const int2 thr = Atom::get_offset(threadIdx.x);
            PRAGMA_UNROLL
            for (int m = 0; m < ITER_M; ++m) {
                const int  mm = offset_.x + m * dM * kFragNum;
                const int  kk = offset_.y + k * dK;  // Note: this forbids sub-tile group sizes
                const int2 mk = Pack::apply(int2{mm, kk / Operand::kGroupSize});
                Atom::copy(&smem(mk.x + thr.x, mk.y + thr.y), dst[m].data(), mask);
            }
        }
        // Disabled alternative implementation, kept for reference:
        // else if constexpr (Operand::kPack != 0 && Operand::kGroupSize != 1) {  // group size = 1, pack != 0
        //     const int  mask_k = Operand::kGroupSize == 1;
        //     const int2 pack   = Pack::apply(int2{offset_.x, offset_.y});
        //     const int2 thr    = Atom::get_offset(threadIdx.x);
        //     const int2 cs     = mk2cs({pack.x + thr.x, (pack.y + thr.y) * mask_k});
        //     auto       smem   = src_ptr + Layout::apply(cs.y, cs.x);
        //     PRAGMA_UNROLL
        //     for (int m = 0; m < ITER_M; ++m) {
        //         const int  mm  = m * dM * kFragNum;
        //         const int  kk  = k * dK;
        //         const int2 cs  = mk2cs(Pack::apply(int2{mm, kk * mask_k}));
        //         const int  idx = Layout::apply(cs.y, cs.x);
        //         Atom::copy(&smem[idx], dst[m].data(), mask);
        //     }
        // }
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/test/gemm_bench.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "nvbench/main.cuh"
#include "src/turbomind/kernels/gemm/operand.h"
#include "src/turbomind/kernels/gemm/test/models.h"
#include "src/turbomind/kernels/gemm/test/testbed.h"
#include 
#include 
#include 
#include 

// Benchmark one point of the (layer shape, batch size, tensor-parallel degree, MoE config) space.
// `idx` indexes `config`, where each model contributes 4 consecutive {output_dims, input_dims}
// entries:
//   - idx % 4 in {0, 2}: the output dim is split across `tp` ranks
//   - otherwise: the input dim is split across `tp` ranks
// Combinations that cannot run (placeholder shapes, dims not divisible by tp or by the
// quantization group size) are reported to nvbench as skipped.
void gemm_bench(nvbench::state& state)
{
    const auto idx = state.get_int64("idx");

    const auto bs = state.get_int64("bs");
    const auto tp = state.get_int64("tp");

    const auto expert_num  = state.get_int64("e_num");
    const auto exp_per_tok = state.get_int64("e_tok");

    auto [output_dims, input_dims] = config[idx];

    // `config` pads models that list fewer than 4 layers with {0, 0}. Running those would set up
    // a degenerate GEMM with a zero dimension.
    if (output_dims == 0 || input_dims == 0) {
        state.skip("placeholder shape with a zero dimension");
        return;
    }

    constexpr int group_size = 128;

    if (idx % 4 == 0 || idx % 4 == 2) {
        if (output_dims % tp) {
            state.skip("output dims not divisible by tp");
            return;
        }
        output_dims /= tp;
    }
    else {
        if (input_dims % tp) {
            state.skip("input dims not divisible by tp");
            return;
        }
        input_dims /= tp;
    }

    if (input_dims % group_size) {
        state.skip("input dims not divisible by the quantization group size");
        return;
    }

    using turbomind::gemm::get_test;

    {
        int m = bs;
        int n = output_dims;
        int k = input_dims;
        if (get_test().kBatchDim == 1) {
            std::swap(m, n);
        }
        std::cerr << "m" << m << "n" << n << "k" << k << "\n";

        get_test().Initialize(m, n, k, group_size, expert_num, exp_per_tok, state.get_cuda_stream());
    }

    state.add_element_count(get_test().get_element_count());

    // state.collect_dram_throughput();
    // state.collect_l2_hit_rates();

    // Set to true to time the cuBLAS reference instead of the turbomind kernel
    constexpr bool kBenchReference = false;

    if constexpr (!kBenchReference) {
        state.add_global_memory_reads(get_test().get_global_memory_reads());
        get_test().Run();  // untimed run before measurement (presumably warm-up / kernel selection)
        state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {  //
            get_test().Run();
        });
    }
    else {
        state.add_global_memory_reads(get_test().get_ref_global_memory_reads());
        state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {  //
            get_test().RunCublas();
        });
    }

    get_test().ctx_.reset();
}

// Register gemm_bench over:
//   - every `config` entry (nvbench::range is inclusive)
//   - batch sizes 2^0 .. 2^14
//   - tensor-parallel degrees 1, 2 and 4
//   - a single expert setting (e_num = 0, e_tok = 1), presumably meaning a dense, non-MoE GEMM
NVBENCH_BENCH(gemm_bench)
    .add_int64_axis("idx", nvbench::range(0, (int)config.size() - 1))
    .add_int64_power_of_two_axis("bs", nvbench::range(0, 14))
    .add_int64_axis("tp", {1, 2, 4})
    .add_int64_axis("e_num", {0})
    .add_int64_axis("e_tok", {1});

// Entry point: NVBENCH_MAIN_BODY parses the nvbench command line and runs the registered benchmarks.
int main(int argc, char* argv[])
{
    NVBENCH_MAIN_BODY(argc, argv);
    return 0;
}


================================================
FILE: src/turbomind/kernels/gemm/test/models.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 

// Benchmark GEMM shapes as {output_dims, input_dims} pairs (gemm_bench unpacks them in that
// order). Each row lists one model's four layers. Judging by the sizes, these are presumably the
// fused gate/up projection, down projection, QKV projection and attention output projection.
// {0, 0} is a placeholder for a layer that is not listed (MoE rows give only the expert FFN
// shapes). "E<n>e<k>" presumably means n experts with k routed per token.
static const std::vector> config{
    {11008 * 2, 4096}, {4096, 11008}, {12288, 4096}, {4096, 4096},  // llama2-7b
    {14336 * 2, 4096}, {4096, 14336}, {6144, 4096},  {4096, 4096},  // llama3-8b / internlm2.5-7b
    {16384 * 2, 6144}, {6144, 16384}, {8192, 6144},  {6144, 6144},  // internlm2-20b
    {13696 * 2, 4096}, {4096, 13696}, {4608, 4096},  {4096, 4096},  // glm4-9b
    {18944 * 2, 3584}, {3584, 18944}, {4608, 3584},  {3584, 3584},  // qwen2-7b
    {20480 * 2, 7168}, {7168, 20480}, {9216, 7168},  {7168, 7168},  // yi-34b
    {28672 * 2, 8192}, {8192, 28672}, {10240, 8192}, {8192, 8192},  // llama2-70b / llama3-70b
    {29696 * 2, 8192}, {8192, 29696}, {10240, 8192}, {8192, 8192},  // qwen2-72b-instruct-awq
    {14336 * 2, 4096}, {4096, 14336}, {6144, 4096},  {4096, 4096},  // mixtral-8x7b, E8e2
    {16384 * 2, 6144}, {6144, 16384}, {0, 0},        {0, 0},        // mixtral-8x22b, E8e2
    {1536 * 2, 5120},  {5120, 1536},  {0, 0},        {0, 0},        // deepseek-v2, E160e6
    {1536 * 2, 2048},  {2048, 1536},  {0, 0},        {0, 0},        // deepseek-v2-lite, E64e6
    {2560 * 2, 3840},  {3840, 2560},  {0, 0},        {0, 0},        // qwen2-a14b, E64e8
    {6400 * 2, 4096},  {4096, 6400},  {0, 0},        {0, 0},        // phi-3.5-MoE, E16e2
};

// static const std::map> moe_config{{32, {8, 2}}, {33, {8, 2}}};

// {29568 * 2, 8192}, {8192, 29568}, {10240, 8192}, {8192, 8192},  // qwen2-72b


================================================
FILE: src/turbomind/kernels/gemm/test/quantization.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/test/quantization_impl.h"

namespace turbomind::gemm {

// Explicit instantiations of Quantize, whose template definition comes from quantization_impl.h.
// Other translation units can link against them through the declaration in quantization.h.
// NOTE(review): presumably one instantiation per supported 16-bit element type (the quantize
// kernel static_asserts 16-bit inputs, "fp16 & bf16") — confirm.
template void Quantize(const thrust::universal_vector& x,
                                int                                   m,
                                int                                   k,
                                Order                                 order,
                                int                                   group_size,
                                thrust::universal_vector&       x_p,  // pseudo-quantized
                                thrust::universal_vector&   x_q,  // quantized ushort
                                thrust::universal_vector&       x_u,  // scales & zeros (always m-major)
                                cudaStream_t                          stream);

template void Quantize(const thrust::universal_vector& x,
                                int                                   m,
                                int                                   k,
                                Order                                 order,
                                int                                   group_size,
                                thrust::universal_vector&       x_p,  // pseudo-quantized
                                thrust::universal_vector&   x_q,  // quantized ushort
                                thrust::universal_vector&       x_u,  // scales & zeros (always m-major)
                                cudaStream_t                          stream);

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/test/quantization.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/types.h"
#include 
#include 

#pragma once

namespace turbomind::gemm {

template
// Group-wise quantization of the m x k matrix `x` (laid out per `order`) for tests, with one
// quantization group per `group_size` elements.
//   x_p    - output: pseudo-quantized copy (quantized then dequantized)
//   x_q    - output: quantized values stored as ushort
//   x_u    - output: per-group scales and zeros, always m-major
//   stream - CUDA stream the work is issued on
// Instantiated explicitly in quantization.cu.
void Quantize(const thrust::universal_vector&  x,
              int                                 m,
              int                                 k,
              Order                               order,
              int                                 group_size,
              thrust::universal_vector&        x_p,  // pseudo-quantized
              thrust::universal_vector& x_q,  // quantized ushort
              thrust::universal_vector&        x_u,  // scales & zeros (always m-major)
              cudaStream_t                        stream);

}


================================================
FILE: src/turbomind/kernels/gemm/test/quantization_impl.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/attention/quantization.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/kernels/gemm/types.h"

#include 
#include 

namespace turbomind::gemm {

// Test helpers for asymmetric group quantization with per-group `scale` and `zeros`.
template
// Compute per-group (min, max) of `src`, from which asymmetric quantization parameters are derived.
//   minmax - output; the pair for row n and group g is stored at minmax[g * N + n] (n-major)
//   src    - N rows of K contiguous elements (element k of row n at src[n * K + k])
//   G      - group size along K
// Launch: x threads index rows, blockIdx.y indexes groups. A group is scanned in vectors of 8
// elements, so its extent must be a multiple of 8.
__global__ void find_stats(Array* minmax, const T* src, int N, int K, int G)
{
    int n_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int k_idx = blockIdx.y;

    if (n_idx >= N || k_idx * G >= K) {
        return;
    }

    float minval = std::numeric_limits::infinity();
    float maxval = -minval;

    // Elements in this group: G, or fewer for a trailing partial group (this also covers K < G).
    // min(K, G) would run past the end of the row for a partial last group when K > G.
    const int L = min(G, K - k_idx * G);

    for (int k = 0; k < L; k += 8) {
        Array vec;
        Load(vec, &src[n_idx * K + k_idx * G + k]);
        PRAGMA_UNROLL
        for (int i = 0; i < vec.size(); ++i) {
            minval = __hmin(minval, vec[i]);
            maxval = __hmax(maxval, vec[i]);
        }
    }

    // store in n-major
    Store(minmax[k_idx * N + n_idx].data(), Array{minval, maxval});
}

// Convert per-group (min, max) statistics into (scale, zero) pairs for asymmetric quantization:
//   scale = (max - min) / (2^b - 1)   (b is the bit width given by `bitsof`; scale rounded through T)
//   zero  = min
// One thread per group; `count` is the number of groups. The pair for group i is stored at
// param[2 * i] (scale) and param[2 * i + 1] (zero).
// NOTE(review): a constant group (max == min) yields scale == 0. `quantize` then computes 1 / 0,
// giving inf/NaN codes. Confirm test inputs never produce constant groups.
template
__global__ void find_params(T* param, const Array* minmax, int count)
{
    int global_idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (global_idx >= count) {
        return;
    }
    auto        stats     = minmax[global_idx];
    const float inv_q_max = fdividef(1.f, (1 << bitsof)-1);

    // Only the asymmetric (scale + zero) scheme is implemented.
    static_assert(asym);

    float scale = (T)(((float)stats[1] - (float)stats[0]) * inv_q_max);

    // force trivial scale / zero for debugging
    if constexpr (0) {
        stats[0] = 0;
        scale    = 1.f;
    }

    Store(param + global_idx * 2, Array{scale, stats[0]});
}

// Quantize `src` group-wise using the (scale, zero) pairs in `stats`:
//   q = quant((x - zero) / scale, bits)   -> written to `dst` as uint16_t
// When `pseudo` is non-null, also write the dequantized value q * scale + zero, so tests can use
// a matrix that carries exactly the quantization error. Uses the same thread/grid mapping and
// row layout as find_stats. The pair for (k_idx, n_idx) is read from stats[(k_idx * N + n_idx) * 2].
// NOTE(review): like find_stats, this assumes G is a multiple of 8 and K is a multiple of G. A
// partial last group would write past the end of the row in `dst` / `pseudo`.
template
__global__ void quantize(uint16_t* dst, T* pseudo, const T* src, const T* stats, int N, int K, int G)
{
    static_assert(bitsof <= 16);
    static_assert(bitsof == 16);  // fp16 & bf16

    int n_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int k_idx = blockIdx.y;

    if (n_idx >= N || k_idx * G >= K) {
        return;
    }

    // param[0] = scale, param[1] = zero
    Array param;
    Load(param, stats + (k_idx * N + n_idx) * 2);

    float inv_scale = fdividef(1.f, param[0]);

    const int L = min(K, G);

    for (int k = 0; k < L; k += 8) {
        Array        vi;
        Array vo;
        Load(vi, &src[n_idx * K + k_idx * G + k]);

        PRAGMA_UNROLL
        for (int i = 0; i < 8; ++i) {
            float u = (static_cast(vi[i] - param[1])) * inv_scale;
            vo[i]   = quant(u, bitsof);
        }
        Store(&dst[n_idx * K + k_idx * G + k], vo);

        if (pseudo) {
            // Dequantize: q * scale + zero
            Array vf;
            PRAGMA_UNROLL
            for (int i = 0; i < 8; ++i) {
                vf[i] = __hfma(static_cast(vo[i]), param[0], param[1]);
            }
            Store(&pseudo[n_idx * K + k_idx * G + k], vf);
        }
    }
}

// Out-of-place transpose of an `s` x `c` row-major matrix into a `c` x `s` row-major matrix:
//   dst[cid * s + sid] = src[sid * c + cid]
// One thread per element; the x dimension covers columns (cid), the y dimension covers rows (sid).
template
__global__ void transpose(const T* src, T* dst, int s, int c)
{
    const int cid = threadIdx.x + blockIdx.x * blockDim.x;
    const int sid = threadIdx.y + blockIdx.y * blockDim.y;
    if (sid < s && cid < c) {
        dst[cid * s + sid] = src[sid * c + cid];
    }
}

// Host launcher for `transpose`: 32 x 16 thread blocks and a grid of
// ceil(c / 32) x ceil(s / 16) blocks, enqueued on `stream` (not synchronised here).
template
void invokeTranspose(const T* src, T* dst, int s, int c, cudaStream_t stream)
{
    const dim3 block{32, 16};
    const dim3 grid(ceil_div(c, block.x), ceil_div(s, block.y));

    transpose<<>>(src, dst, s, c);
}

// Group-quantize an m x k matrix `x` (stored in `order`) for GEMM tests.
//   x_q : quantized codes (uint16 values), same shape and order as `x`
//   x_p : pseudo-quantized x, i.e. the codes dequantized back (q * scale + zero)
//   x_u : (scale, zero) pairs, one per (row, group); pair index = group * m + row
// All work is done on a row-major m x k copy; column-major inputs are transposed in and the
// outputs are transposed back. Blocks until `stream` has drained.
// NOTE(review): m == 0 gives a zero thread-block size (launch error). A k that is not a multiple of
// group_size hits the partial-group issue noted on find_stats / quantize.
template
void Quantize(const thrust::universal_vector&  x,
              int                                 m,
              int                                 k,
              Order                               order,
              int                                 group_size,
              thrust::universal_vector&        x_p,  // pseudo-quantized
              thrust::universal_vector& x_q,  // quantized ushort
              thrust::universal_vector&        x_u,  // scales & zeros (always m-major)
              cudaStream_t                        stream)

{
    auto policy = thrust::device.on(stream);

    // Row-major (m rows of k) working buffers plus one (min, max) entry per (row, group).
    thrust::universal_vector           _x(x.size());
    thrust::universal_vector           _x_p(x.size());
    thrust::universal_vector    _x_q(x.size());
    thrust::universal_vector> stats(ceil_div(k, group_size) * m);

    x_p.resize(x.size());
    x_q.resize(x.size());
    // find_params writes two values (scale, zero) per stats entry, so stats.size() * 2 is what it fills.
    // NOTE(review): confirm what the FIXME below still refers to.
    /// FIXME: correct the size
    x_u.resize(stats.size() * 2);

    if (order == Order::kRowMajor) {
        thrust::copy(policy, x.begin(), x.end(), _x.begin());
    }
    else {
        // Column-major m x k input == row-major k x m; transpose to row-major m x k.
        invokeTranspose(x.data().get(), _x.data().get(), k, m, stream);
    }

    // One thread per row, one grid row (y) per group.
    const int  block = std::min(256, m);
    const dim3 grid(ceil_div(m, block), ceil_div(k, group_size));

    find_stats<<>>(stats.data().get(),  //
                                           _x.data().get(),
                                           m,
                                           k,
                                           group_size);

    find_params<<(stats.size(), 256), 256, 0, stream>>>(  //
        x_u.data().get(),
        stats.data().get(),
        stats.size());

    quantize<<>>(_x_q.data().get(),  //
                                            _x_p.data().get(),
                                            _x.data().get(),
                                            x_u.data().get(),
                                            m,
                                            k,
                                            group_size);

    // Return the results in the caller's storage order.
    if (order == Order::kRowMajor) {
        thrust::copy(policy, _x_p.begin(), _x_p.end(), x_p.begin());
        thrust::copy(policy, _x_q.begin(), _x_q.end(), x_q.begin());
    }
    else {
        invokeTranspose(_x_p.data().get(), x_p.data().get(), m, k, stream);
        invokeTranspose(_x_q.data().get(), x_q.data().get(), m, k, stream);
    }

    cudaStreamSynchronize(stream);

    // Compare(_x_p.data().get(), _x.data().get(), k, k, m);

    // Debugging hook: every statement in this loop is commented out, so it currently does nothing.
    // Uncomment to overwrite x_u with index patterns.
    const int kg = ceil_div(k, group_size);
    for (int i = 0; i < m * kg; ++i) {
        // int mi = i % m;
        // int ki = i / m;

        // x_u[i * 2]     = i;
        // x_u[i * 2 + 1] = i;

        // x_u[i * 2]     = i * 2;
        // x_u[i * 2 + 1] = i * 2 + 1;
    }
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/test/reference.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/test/reference.h"
#include 

namespace turbomind::gemm {

// Abort the process with the failed condition text and source location when `cond` is false.
// Wrapped in `do { } while (0)` so `CHECK(x);` behaves as a single statement (e.g. inside if/else).
// NOTE(review): macros ignore namespaces, so this can clash with other `CHECK` macros (e.g. from
// logging libraries) if such a header is included after this point.
#define CHECK(cond)                                                                                                    \
    do {                                                                                                               \
        if (!(cond)) {                                                                                                 \
            fprintf(stderr, "*** Check failed: (%s) @ %s:%d\n", #cond, __FILE__, __LINE__);                            \
            std::abort();                                                                                              \
        }                                                                                                              \
    } while (0)

namespace {

// Describe the same memory viewed as the transposed matrix: rows and columns trade places and the
// storage order flips. All other fields (leading dimension, element type, ...) are copied unchanged.
MatrixLayout transpose(MatrixLayout x)
{
    MatrixLayout t = x;
    t.rows         = x.cols;
    t.cols         = x.rows;
    if (x.order == Order::kColMajor) {
        t.order = Order::kRowMajor;
    }
    else {
        t.order = Order::kColMajor;
    }
    return t;
}

// Map an operand's DataType to the matching cudaDataType for cublasGemmEx.
// Only fp16 and bf16 are supported; any other type aborts through CHECK.
cudaDataType to_cuda_dtype(DataType dtype)
{
    if (dtype == DataType::kFloat16) {
        return CUDA_R_16F;
    }
    if (dtype == DataType::kBfloat16) {
        return CUDA_R_16BF;
    }
    CHECK("unsupported data type" && 0);
    return {};
}

}  // namespace

Reference::Reference()
{
    cublasCreate(&handle_);

    cublasSetWorkspace(handle_, nullptr, 0);
    cublasSetMathMode(handle_, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}

// Release the cuBLAS handle, if one was created, and reset it to null.
Reference::~Reference()
{
    if (!handle_) {
        return;  // nothing was created, nothing to release
    }
    cublasDestroy(handle_);
    handle_ = {};
}

// Issue all subsequent cuBLAS work of this handle on `stream`.
// The status is checked, consistent with gemm(): a silent failure would leave the reference GEMM
// running on the previously bound stream and race with the code under test.
void Reference::set_stream(cudaStream_t stream)
{
    auto status = cublasSetStream(handle_, stream);
    TM_CHECK_EQ(status, CUBLAS_STATUS_SUCCESS);
}

void Reference::gemm(const void* A, MatrixLayout Adesc, const void* B, MatrixLayout Bdesc, void* C, MatrixLayout Cdesc)
{

    // Transpose the problem for C to be column major
    if (Cdesc.order == Order::kRowMajor) {
        std::swap(A, B);
        std::swap(Adesc, Bdesc);
        Adesc = transpose(Adesc);
        Bdesc = transpose(Bdesc);
        Cdesc = transpose(Cdesc);
        // (n, k) (k, m)
    }

    TM_CHECK_EQ(Adesc.cols, Bdesc.rows);

    // (m, k) (k, n)
    int m = Cdesc.rows;
    int n = Cdesc.cols;
    int k = Adesc.cols;

    TM_CHECK_EQ(Adesc.rows, m);
    TM_CHECK_EQ(Bdesc.cols, n);
    TM_CHECK_EQ(Bdesc.rows, k);

    float alpha = 1.f;
    float beta  = 0.f;

    auto to_cublas_op = [](Order o) { return o == Order::kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T; };

    auto status = cublasGemmEx(handle_,
                               to_cublas_op(Adesc.order),
                               to_cublas_op(Bdesc.order),
                               m,
                               n,
                               k,
                               &alpha,
                               A,
                               to_cuda_dtype(Adesc.type),
                               Adesc.ld,
                               B,
                               to_cuda_dtype(Bdesc.type),
                               Bdesc.ld,
                               &beta,
                               C,
                               to_cuda_dtype(Cdesc.type),
                               Cdesc.ld,
                               CUBLAS_COMPUTE_32F,
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    TM_CHECK_EQ(status, CUBLAS_STATUS_SUCCESS);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/test/reference.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/types.h"

#include 

namespace turbomind::gemm {

// cuBLAS-backed reference GEMM used to validate the kernels under test.
// Owns a cublasHandle_t that is created in the constructor and destroyed in the destructor.
class Reference {
public:
    Reference();
    ~Reference();

    // The handle is owned exclusively. A copy would destroy the same handle twice, so copying
    // (and, implicitly, moving) is disabled.
    Reference(const Reference&) = delete;
    Reference& operator=(const Reference&) = delete;

    // Bind subsequent cuBLAS work of this object to `stream`.
    void set_stream(cudaStream_t stream);

    // C = A * B. Each MatrixLayout supplies rows, cols, order, leading dimension and element type
    // of its operand.
    void gemm(const void* A, MatrixLayout Adesc, const void* B, MatrixLayout Bdesc, void* C, MatrixLayout Cdesc);

private:
    cublasHandle_t handle_{};  // null until the constructor creates it
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/test/test_gemm_v2.cc
================================================


#include "src/turbomind/core/context.h"
#include "src/turbomind/core/data_type.h"

#include "testbed_v3.h"

using namespace turbomind;

// Testbed_v3 parameters with the three data types and the quantization group size filled in.
// Every other field keeps the value from `Testbed_v3::Parameter{}`.
struct TestParameter: Testbed_v3::Parameter {
    TestParameter(DataType dtype, DataType wtype, DataType itype, int group_size = 128): Testbed_v3::Parameter{}
    {
        // The constructor argument `group_size` shadows the inherited member, hence `this->`.
        this->input_type  = itype;
        this->weight_type = wtype;
        this->data_type   = dtype;
        this->group_size  = group_size;
    }
};

int main()
{
    auto stream = core::Stream::create();

    core::ContextGuard ctx{stream, core::Allocator{kCPU}, core::Allocator{stream, false}};
    // clang-format off
    // TestParameter p{kHalf, kUint4      , kHalf, 128};
    // TestParameter p{kHalf, kFloat4_e2m1, kHalf,  32};
    // TestParameter p{kHalf, kFloat8_e4m3, kHalf, 128};
    // TestParameter p{kHalf, kHalf       , kHalf};

    // TestParameter p{kBfloat16, kBfloat16   , kBfloat16};
    // TestParameter p{kBfloat16, kFloat8_e4m3, kFloat8_e4m3, 128};
    TestParameter p{kBfloat16, kFloat8_e4m3, kBfloat16   , 128};
    // TestParameter p{kBfloat16, kFloat4_e2m1, kBfloat16   ,  32};
    // clang-format on

    // p.input_dim      = 512;
    // p.output_dim     = 1024;
    // p.max_batch_size = 256;

    // p.input_dim      = 1024;
    // p.output_dim     = 1024;
    // p.max_batch_size = 1024;

    // p.input_dim      = 12288;
    // p.output_dim     = 16384;
    // p.max_batch_size = 8192;

    // p.expert_num        = 1;
    // p.experts_per_token = 1;

    // p.input_dim      = 2880;
    // p.output_dim     = 2880;
    // p.max_batch_size = 64;

    // p.input_dim         = 7168;
    // p.output_dim        = 4096;
    // p.max_batch_size    = 16384;
    // p.expert_num        = 256;
    // p.experts_per_token = 8;

    // Qwen3-MoE
    p.expert_num        = 128;
    p.experts_per_token = 8;
    // 30B
    // p.input_dim  = 2048;
    // p.output_dim = 768 * 2;
    // 235B
    // p.input_dim  = 4096;
    // p.output_dim = 1536 * 2;
    // 480B
    p.input_dim  = 6144;
    p.output_dim = 2560 * 2;

    p.max_batch_size = 256;

    // p.input_dim         = 16384;
    // p.output_dim        = 16384;
    // p.max_batch_size    = 16384;

    // p.input_dim         = 2880;
    // p.output_dim        = 5760;
    // p.max_batch_size    = 16384;
    // p.expert_num        = 32;
    // p.experts_per_token = 4;

    // p.input_dim      = 128;
    // p.output_dim     = 32;
    // p.max_batch_size = 1;

    Testbed_v3 test{p};

    test.GetReference();
    test.Run();
    test.Compare();

    cudaDeviceSynchronize();

    return 0;
}


================================================
FILE: src/turbomind/kernels/gemm/test/test_moe_utils.cu
================================================
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/kernels/gemm/tuner/cache_utils.h"
#include "src/turbomind/kernels/gemm/types.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace turbomind;

// Print an m x k row-major array, one row per line, each value right-aligned in `width`
// columns. When `msg` is non-empty it is printed first as a "msg:" heading.
template<class T>
void print_vecs(const T* data, int m, int k, std::string msg, int width = 4)
{
    if (!msg.empty()) {
        std::cout << msg << ":\n";
    }
    const T* row = data;
    for (int r = 0; r < m; ++r, row += k) {
        for (int c = 0; c < k; ++c) {
            std::cout << std::setw(width) << row[c];
        }
        std::cout << "\n";
    }
}

// Compare two m x k row-major arrays element-wise. For every row, print "m=<row>: " followed by
// "col(actual, expected) " for each mismatching element. An optional "msg: [m, k]" heading
// comes first.
template<class T>
void diff_vecs(const T* data, const T* refs, int m, int k, std::string msg)
{
    if (!msg.empty()) {
        std::cout << msg << ": [" << m << ", " << k << "]\n";
    }
    for (int r = 0; r < m; ++r) {
        std::cout << "m=" << r << ": ";
        const T* got  = data + r * k;
        const T* want = refs + r * k;
        for (int c = 0; c < k; ++c) {
            if (got[c] != want[c]) {
                std::cout << c << "(" << got[c] << ", " << want[c] << ") ";
            }
        }
        std::cout << "\n";
    }
}

// Process-wide RNG shared by this test. It is a function-local static, so it is constructed on
// the first call.
RNG& gRNG()
{
    static RNG shared_rng{};
    return shared_rng;
}

using thrust::universal_vector;

// Host reference for the MoE gating kernel.
//
// For each token t (row t of `logits`, `expert_num` entries):
//   * select the top-`experts_per_token` experts by logit. Ties go to the lower expert id because
//     the sort is stable. The selected ids are then re-sorted ascending.
//   * scales = softmax over the selected logits only. The disabled `if constexpr (0)` branch
//     softmaxes over all experts first and renormalises the selected probabilities, which is the
//     same up to rounding.
// eids / scales are laid out [experts_per_token][tokens]: element (e, t) lives at e * tokens + t.
//
// Then build the expert-sorted dispatch maps over the flat index j = e * tokens + t:
//   f2en    : flat indices ordered by expert id, then by token
//   f2n[i]  : token of the i-th dispatched entry (f2en[i] % tokens)
//   en2f[j] : position of flat index j in that order (inverse permutation of f2en)
//   offsets : exclusive prefix sum of per-expert counts, offsets[0] = 0. The caller allocates
//             expert_num + 1 entries (see test_moe_gate).
void moe_gate_ref(int                            tokens,
                  int                            expert_num,
                  int                            experts_per_token,
                  const universal_vector& logits,
                  universal_vector&         offsets,
                  universal_vector&         eids,
                  universal_vector&         f2n,
                  universal_vector&         en2f,
                  universal_vector&       scales)
{
    // 0, 1, ..., expert_num - 1: the candidate list sorted per token
    std::vector eid_range(expert_num);
    std::iota(eid_range.begin(), eid_range.end(), 0);

    for (int t = 0; t < tokens; ++t) {
        const float* logit   = logits.data().get() + expert_num * t;
        // Subtracting the row max keeps std::exp from overflowing.
        const float  max_val = *std::max_element(logit, logit + expert_num);
        if constexpr (0) {
            std::vector probs(logit, logit + expert_num);
            float              sum = 0;
            for (auto& p : probs) {
                p = std::exp(p - max_val);
                sum += p;
            }
            for (auto& p : probs) {
                p /= sum;
            }
            std::vector idxs = eid_range;
            // Had to use stable sort since there is no `std::stable_nth_element`
            std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //
                return probs[i] > probs[j];
            });
            // Recover natural order in top-k
            std::sort(idxs.begin(), idxs.begin() + experts_per_token);
            idxs.resize(experts_per_token);
            sum = 0;
            for (int e = 0; e < experts_per_token; ++e) {
                eids[e * tokens + t] = idxs[e];
                sum += probs[idxs[e]];
            }
            for (int e = 0; e < experts_per_token; ++e) {
                scales[e * tokens + t] = probs[idxs[e]] / sum;
            }
        }
        else {
            std::vector idxs = eid_range;
            // Had to use stable sort since there is no `std::stable_nth_element`
            std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //
                return logit[i] > logit[j];
            });
            // Recover natural order in top-k
            std::sort(idxs.begin(), idxs.begin() + experts_per_token);
            idxs.resize(experts_per_token);
            std::vector probs(experts_per_token);
            float              sum = 0;
            for (int e = 0; e < experts_per_token; ++e) {
                eids[e * tokens + t] = idxs[e];
                probs[e]             = std::exp(logit[idxs[e]] - max_val);
                sum += probs[e];
            }
            for (int e = 0; e < experts_per_token; ++e) {
                scales[e * tokens + t] = probs[e] / sum;
            }
        }
    }

    // f2en: flat (e * tokens + t) indices, sorted by expert id and then by token
    std::vector f2en(eids.size());
    std::iota(f2en.begin(), f2en.end(), 0);

    std::stable_sort(f2en.begin(), f2en.end(), [&](int i, int j) {  //
        if (eids[i] != eids[j]) {
            return eids[i] < eids[j];
        }
        return i % tokens < j % tokens;
    });

    std::fill_n(offsets.begin(), offsets.size(), 0);
    // accum[e]: number of (token, slot) pairs routed to expert e
    std::vector accum(expert_num);

    for (size_t i = 0; i < f2en.size(); ++i) {
        f2n[i]        = f2en[i] % tokens;
        en2f[f2en[i]] = i;
        ++accum[eids[i]];
    }

    // Exclusive prefix sum of the per-expert counts.
    for (size_t i = 1; i < offsets.size(); ++i) {
        offsets[i] = offsets[i - 1] + accum[i - 1];
    }
}

// Rebuild the [experts_per_token][tokens] expert-id table from the kernel's routing mask.
// `masks` is laid out [expert_num][tokens_padded]. A non-negative entry v at (e, t) means
// token t uses expert e in slot v, so eids[v * tokens + t] = e. Negative entries are ignored.
// test_moe_gate clears the mask to -1 bytes with cudaMemset before the kernel runs.
// NOTE(review): t runs up to tokens_padded, so this relies on padding columns (t >= tokens) staying
// negative; otherwise it would write out of range in `eids`.
void mask2eids(universal_vector& masks, universal_vector& eids, int tokens, int expert_num)
{
    const int tokens_padded = masks.size() / expert_num;
    // std::cout << eids.size() << std::endl;
    for (int e = 0; e < expert_num; ++e) {
        for (int t = 0; t < tokens_padded; ++t) {
            if (auto v = masks[e * tokens_padded + t]; v >= 0) {
                // Debug aid: report suspicious mask entries.
                // if (v >= 2 || t >= 8193) {
                //     std::cerr << "unexpected mask entry " << v << " " << t << std::endl;
                // }
                eids[v * tokens + t] = e;
            }
        }
    }
}

// Problem/tile description passed to test_moe_gate. In this file it is only consumed by
// commented-out tiling/scheduling code. main() initialises it positionally, so the field order
// matters.
struct Tiling {
    int  output_dims;
    int  input_dims;
    int3 cta_tile;  // CTA tile sizes; .z is also passed separately in the commented-out call
};

// Run the MoE gating kernel (invokeMoeGate_V2) on `tokens` x `expert_num` logits and compare its
// routing outputs (offsets, eids, f2n, en2f) against the host reference moe_gate_ref.
// When `logits` is empty, uniform random logits are generated.
// Returns true if all four outputs match. On a mismatch the vectors are dumped to stdout;
// `scales` are printed but never compared.
// NOTE(review): `tape` and `tiling` are only used by the commented-out tiling/scheduling code.
bool test_moe_gate(int                     tokens,  //
                   int                     expert_num,
                   int                     experts_per_token,
                   gemm::Tape&             tape,
                   const Tiling&           tiling,
                   universal_vector logits = {})
{
    if (logits.empty()) {
        logits.resize(tokens * expert_num);
        gRNG().GenerateUniform(logits.data().get(), logits.size());
    }
    assert(logits.size() == tokens * expert_num);

    // Round tokens up to a multiple of the kernel's vector width.
    const int tokens_padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;
    // const int max_coords    = get_max_coords(tokens, expert_num, experts_per_token, tiling);

    universal_vector    offsets(expert_num + 1);
    universal_vector    accum(expert_num * kMoeGateMaxTiles);
    universal_vector masks(expert_num * tokens_padded);
    universal_vector    eids(experts_per_token * tokens);
    universal_vector    f2n(experts_per_token * tokens);
    universal_vector    f2E(experts_per_token * tokens);
    universal_vector    en2f(experts_per_token * tokens);
    universal_vector  scales(experts_per_token * tokens);
    // universal_vector  coords(max_coords);
    // thrust::fill(coords.begin(), coords.end(), int2{-1, 0});

    auto offsets_ref = offsets;
    auto eids_ref    = eids;
    auto f2n_ref     = f2n;
    auto en2f_ref    = en2f;
    auto scales_ref  = scales;

    moe_gate_ref(tokens, expert_num, experts_per_token, logits, offsets_ref, eids_ref, f2n_ref, en2f_ref, scales_ref);

    // Prefetch the (unified-memory) buffers to device 0. The device id is hard-coded.
    cudaMemPrefetchAsync(f2n.data().get(), sizeof(int) * f2n.size(), 0);
    cudaMemPrefetchAsync(f2E.data().get(), sizeof(int) * f2E.size(), 0);
    cudaMemPrefetchAsync(en2f.data().get(), sizeof(int) * en2f.size(), 0);
    cudaMemPrefetchAsync(offsets.data().get(), sizeof(int) * offsets.size(), 0);
    cudaMemPrefetchAsync(scales.data().get(), sizeof(float) * scales.size(), 0);
    cudaMemPrefetchAsync(logits.data().get(), sizeof(float) * logits.size(), 0);

    bool softmax = true;

    // NOTE(review): this transforms `logits` in place (arguments expert_num / 8 and 8 — presumably
    // the group count and the number of groups kept; confirm) and disables the gate kernel's own
    // softmax. It runs *after* the reference above was computed on the untouched logits, so the
    // comparison is only meaningful if the masking never changes the top-k selection.
    if (1) {
        invokeMoeSoftmaxMaskTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 8, nullptr);
        softmax = false;
    }

    for (int i = 0; i < 1; ++i) {
        gemm::CacheFlushing::flush();
        // Reset kernel scratch: per-expert tile counters to 0, every mask byte to -1.
        cudaMemset(accum.data().get(), 0, sizeof(int) * accum.size());
        cudaMemset(masks.data().get(), -1, sizeof(int8_t) * masks.size());
        invokeMoeGate_V2(f2n.data().get(),
                         f2E.data().get(),
                         en2f.data().get(),
                         offsets.data().get(),
                         scales.data().get(),
                         masks.data().get(),
                         accum.data().get(),
                         logits.data().get(),
                         tokens,
                         tokens_padded,
                         expert_num,
                         experts_per_token,
                         softmax,
                         false,
                         1.f,
                         nullptr);
    }

    // invokeMoeTiling(coords.data().get(), offsets.data().get(), expert_num, coords.size(), &tiling, 1, 0);

    // gemm::scheduleGemmMoe(tape,
    //                       offsets.data().get(),
    //                       tokens,
    //                       experts_per_token,
    //                       expert_num,
    //                       tiling.output_dims,
    //                       tiling.input_dims,
    //                       tiling.cta_tile,
    //                       tiling.cta_tile.z,
    //                       1,
    //                       0,
    //                       0);

    if (auto err = cudaDeviceSynchronize(); err != cudaSuccess) {
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::abort();
    }

    // The kernel reports expert ids only through `masks`; rebuild `eids` so it can be compared.
    // print_vecs(masks.data().get(), expert_num, tokens_padded, "masks");
    mask2eids(masks, eids, tokens, expert_num);

    bool success = true;

    // success = offsets == offsets_ref && eids == eids_ref && f2n == f2n_ref && en2f == en2f_ref;

    if (offsets != offsets_ref) {
        std::cerr << "offset\n";
        success = false;
    }
    if (eids != eids_ref) {
        std::cerr << "eids\n";
        success = false;
    }
    if (f2n != f2n_ref) {
        std::cerr << "f2n\n";
        success = false;
    }
    if (en2f != en2f_ref) {
        std::cerr << "en2f\n";
        success = false;
    }

    // print_vecs(logits.data().get(), tokens, expert_num, "logits", 12);

    // Verbose dump on failure. Change `&& 1` to `&& 0` to silence it.
    if (!success && 1) {

        diff_vecs(eids.data().get(), eids_ref.data().get(), experts_per_token, tokens, "eids");

        print_vecs(offsets_ref.data().get(), 1, expert_num + 1, "offsets_ref");
        print_vecs(offsets.data().get(), 1, expert_num + 1, "offsets");

        print_vecs(eids_ref.data().get(), experts_per_token, tokens, "eids_ref");
        print_vecs(eids.data().get(), experts_per_token, tokens, "eids");

        print_vecs(f2n_ref.data().get(), 1, experts_per_token * tokens, "f2n_ref");
        print_vecs(f2n.data().get(), 1, experts_per_token * tokens, "f2n");

        print_vecs(en2f_ref.data().get(), experts_per_token, tokens, "en2f_ref");
        print_vecs(en2f.data().get(), experts_per_token, tokens, "en2f");

        print_vecs(scales_ref.data().get(), experts_per_token, tokens, "scales_ref", 12);
        print_vecs(scales.data().get(), experts_per_token, tokens, "scales", 12);

        // Per-token sum of the kernel's scales (expected to be ~1).
        for (int i = 0; i < tokens; ++i) {
            float sum = 0;
            for (int j = 0; j < experts_per_token; ++j) {
                sum += scales[j * tokens + i];
            }
            std::cout << sum << " ";
        }
        std::cout << "\n";

        // print_vecs(accum.data().get(), expert_num, 1, "accum");

        // print_vecs(coords.data().get(), 1, max_coords, "coords");

        // thrust::host_vector tile_offsets(tape.max_ctas);
        // std::cout << tape.max_ctas << std::endl;
        // cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(),
        // cudaMemcpyDefault); cudaDeviceSynchronize();

        // std::cout << "coords:\n";
        // int last = -1;
        // for (int i = 0; i < tape.max_ctas; ++i) {
        //     auto& c = tile_offsets[i];
        //     if (last >= 0 && c.w != last) {
        //         std::cout << "\n";
        //     }
        //     if (c.w == -1) {
        //         std::cout << i << "\n";
        //         break;
        //     }
        //     last = c.w;
        //     std::stringstream ss;
        //     ss << c.x << "," << c.y;
        //     std::cout << std::setw(6) << ss.str();
        // }
        // std::cout << "\n";
    }

    return success;
}

// MoE gating test driver. Runs one fixed configuration and, optionally, a sweep over token counts.
// Returns 0 if every executed check passed and 1 otherwise, so CI can detect failures.
int main()
{
    gemm::Tape       tape{};
    constexpr Tiling tiling{14336, 128, {128, 128, 32}};

    // test_moe_gate(32768 * 4, 60, 4, tape, tiling);
    // test_moe_gate(32768, 64, 8, tape, tiling);
    // test_moe_gate(8, 60, 4, tape, tiling);

    bool ok = test_moe_gate(16, 160, 6, tape, tiling);

    // Exhaustive sweep over token counts. Slow, so disabled by default; flip the flag to enable it.
    constexpr bool kSweepTokens = false;
    if constexpr (kSweepTokens) {
        for (int i = 1; i < 16384; ++i) {
            // std::cerr << i << std::endl;
            auto success = test_moe_gate(i, 8, 2, tape, tiling);
            if (!success) {
                std::cerr << i << std::endl;
                ok = false;
                // std::abort();
            }
            // break;
        }
    }

    return ok ? 0 : 1;
}


================================================
FILE: src/turbomind/kernels/gemm/test/test_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/core/core.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include 
#include 
#include 
#include 
#include 
#include 

#define _CG_ABI_EXPERIMENTAL
#include 
#include 
#include 

#include 
#include 
#include 
#include 

namespace turbomind {

// Global cuBLAS handle / stream, value-initialised to null.
// NOTE(review): neither is referenced in the visible part of this file; confirm whether other
// test code still uses them before removing.
cublasHandle_t cublas_handle{};
cudaStream_t   cublas_stream{};

// Host-side element-wise comparison of `src` against `ref`.
// Both hold `bsz` rows of `dims` elements; row r starts at offset r * stride.
// An element is an outlier unless |x - y| <= atol + rtol * |y| (NaNs therefore count as
// outliers); outliers are printed when `show` is set. One summary line is printed:
//   abs_diff = mean over rows of the per-row mean |x - y|
//   rel_diff = same for |x - y| / (max(|x|, |y|) + 1e-8)
//   outliers = outlier count / bsz
template<class T>
void Compare(const T* src, const T* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol)
{
    float row_abs_means = 0.f;
    float row_rel_means = 0.f;
    int   n_outliers    = 0;

    for (int row = 0; row < bsz; ++row) {
        const T* s = src + row * stride;
        const T* r = ref + row * stride;

        float abs_acc = 0.f;
        float rel_acc = 0.f;
        for (int col = 0; col < dims; ++col) {
            const float x     = float(s[col]);
            const float y     = float(r[col]);
            const float diff  = std::abs(x - y);
            const float denom = std::max(std::abs(y), std::abs(x)) + 1e-8f;
            const bool  close = diff <= atol + rtol * std::abs(y);
            if (!close) {
                ++n_outliers;
                if (show) {
                    std::cout << row << "," << col << "\t" << x << "\t" << y << std::endl;
                }
            }
            abs_acc += diff;
            rel_acc += diff / denom;
        }
        row_abs_means += abs_acc / dims;
        row_rel_means += rel_acc / dims;
    }

    const float abs_diff = row_abs_means / bsz;
    const float rel_diff = row_rel_means / bsz;
    const float outlier  = n_outliers / (float)bsz;
    std::cout << "abs_diff = " << abs_diff << " rel_diff = " << rel_diff << " outliers = " << outlier << std::endl;
}

// Explicit instantiations of the host Compare for the element types used by the tests
// (nv_bfloat16 only when ENABLE_BF16 is set).
template void
Compare(const half* src, const half* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol);
template void
Compare(const float* src, const float* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol);
#if ENABLE_BF16
template void Compare(const nv_bfloat16* src,
                      const nv_bfloat16* ref,
                      size_t             stride,
                      int                dims,
                      int                bsz,
                      bool               show,
                      float              rtol,
                      float              atol);
#endif

// Type-erased Compare: recover the element type from `dtype` (half or bfloat16) and forward
// to the typed host comparison.
void Compare(
    const void* x, const void* r, DataType dtype, size_t stride, int dim, int bsz, bool show, float rtol, float atol)
{
    auto typed_compare = [&](auto tag) {
        using Elem = decltype(tag);
        Compare(static_cast<const Elem*>(x), static_cast<const Elem*>(r), stride, dim, bsz, show, rtol, atol);
    };
    TM_DISPATCH_DTYPES(dtype, typed_compare, half_t, bfloat16_t);
}

// Device-side comparison of `bsz` x `dims` contiguous elements of `src` against `ref`, using a
// single thrust::transform_reduce on `stream`. Both pointers must be accessible from the device.
// Returns 7 values:
//   [0] mean |src|, [1] mean |ref|, [2] mean |src - ref|, [3] max |src - ref|,
//   [4] mean relative diff, [5] max relative diff, [6] outliers per row (outlier count / bsz)
// The relative diff is |s - r| / (max(|s|, |r|) + 1e-8). An outlier is any element where
// !(|s - r| <= atol + rtol * |r|), so NaNs count as outliers.
template
std::vector
FastCompare(const T* src, const T* ref, int dims, int bsz, cudaStream_t stream, float rtol, float atol)
{
    auto       zip_iter = thrust::make_zip_iterator(src, ref);
    const auto count    = (size_t)dims * bsz;
    // nvcc-11.8: __host__ __device__ lambda can't be generic
    using Tuple = thrust::tuple;
    auto res    = thrust::transform_reduce(
        thrust::cuda::par.on(stream),
        zip_iter,
        zip_iter + count,
        // Per-element statistics; slots 2/3 and 4/5 carry the same value and are reduced
        // with sum and max respectively.
        [=] __host__ __device__(thrust::tuple tup) -> Tuple {
            float   s        = thrust::get<0>(tup);
            float   r        = thrust::get<1>(tup);
            float   abs_diff = fabsf(s - r);
            float   abs_s    = fabsf(s);
            float   abs_r    = fabsf(r);
            float   rel_diff = abs_diff / (fmaxf(abs_r, abs_s) + 1e-8f);
            int64_t outlier  = !(abs_diff <= (atol + rtol * abs_r));
            return thrust::make_tuple(abs_s, abs_r, abs_diff, abs_diff, rel_diff, rel_diff, outlier);
        },
        thrust::make_tuple(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0LL),
        [] __host__ __device__(const Tuple& a, const Tuple& b) {  // `__host__`: compiler needs the return type
            return thrust::make_tuple(thrust::get<0>(a) + thrust::get<0>(b),
                                      thrust::get<1>(a) + thrust::get<1>(b),
                                      thrust::get<2>(a) + thrust::get<2>(b),
                                      fmaxf(thrust::get<3>(a), thrust::get<3>(b)),
                                      thrust::get<4>(a) + thrust::get<4>(b),
                                      fmaxf(thrust::get<5>(a), thrust::get<5>(b)),
                                      thrust::get<6>(a) + thrust::get<6>(b));
        });
    return {thrust::get<0>(res) / dims / bsz,   // avg abs src
            thrust::get<1>(res) / dims / bsz,   // avg abs ref
            thrust::get<2>(res) / dims / bsz,   // avg abs diff
            thrust::get<3>(res),                // max abs diff
            thrust::get<4>(res) / dims / bsz,   // avg rel diff
            thrust::get<5>(res),                // max rel diff
            (float)thrust::get<6>(res) / bsz};  // outliers per row (outlier count / bsz)
}

// Explicit instantiations of FastCompare for half, bfloat16 and float inputs.
template std::vector FastCompare(const half*  src,  //
                                        const half*  ref,
                                        int          dims,
                                        int          bsz,
                                        cudaStream_t stream,
                                        float        rtol,
                                        float        atol);

template std::vector FastCompare(const nv_bfloat16* src,  //
                                        const nv_bfloat16* ref,
                                        int                dims,
                                        int                bsz,
                                        cudaStream_t       stream,
                                        float              rtol,
                                        float              atol);

template std::vector FastCompare(const float* src,  //
                                        const float* ref,
                                        int          dims,
                                        int          bsz,
                                        cudaStream_t stream,
                                        float        rtol,
                                        float        atol);

// Tensor overload of FastCompare. `x` must be 2-D and contiguous; `r` must have the same
// layout and dtype. The comparison treats x.shapes(1, 0) as (dim, bsz), presumably shape[1] and
// shape[0], i.e. rows of `dim` elements (TODO confirm shapes() semantics). Dispatches on
// half / bfloat16 / float.
std::vector FastCompare(const Tensor& x, const Tensor& r, cudaStream_t stream, float rtol, float atol)
{
    TM_CHECK_EQ(x.ndim(), 2);
    TM_CHECK(x.is_contiguous());
    TM_CHECK(x.layout() == r.layout());
    TM_CHECK(x.dtype() == r.dtype());

    auto invoke = [&](auto t) {
        using T         = decltype(t);
        auto [dim, bsz] = x.shapes(1, 0);
        return FastCompare(x.data(), r.data(), dim, bsz, stream, rtol, atol);
    };

    TM_DISPATCH_DTYPES_RET(x.dtype(), invoke, half_t, bfloat16_t, float);
}

// Print the column titles for FC_Print, each right-aligned in a 16-character field.
void FC_Header()
{
    static const char* const kColumns[] = {
        "amean", "amean_ref", "absdiff", "absdiff_max", "reldiff", "reldiff_max", "#outlier"};
    for (const char* title : kColumns) {
        printf("%16s", title);
    }
    printf("\n");
}

// Print one row of FastCompare statistics under FC_Header. Reads d[0]..d[6], so `d` is
// expected to be the 7-element vector returned by FastCompare.
void FC_Print(const std::vector& d)
{
    printf("%16f%16f%16f%16f%16f%16f%16f\n", d[0], d[1], d[2], d[3], d[4], d[5], d[6]);
}

// Read `size` bytes from the binary file at `path` into `dst`.
//
// Aborts if the file cannot be opened. If the file length differs from `size`, a warning is printed and
// loading continues. A short read (file smaller than `size`) is reported explicitly; bytes of `dst` past
// what was actually read are left untouched.
void LoadBinary(const std::string& path, size_t size, void* dst)
{
    std::ifstream ifs(path, std::ios::binary | std::ios::in);
    if (!ifs.is_open()) {
        std::cerr << "failed to open " << path << "\n";
        std::abort();
    }

    // Probe the file length. tellg() yields -1 on failure; comparing in a signed type reports that case as a
    // mismatch instead of wrapping it into a huge unsigned value.
    ifs.seekg(0, ifs.end);
    const std::streamoff actual_size_in_bytes = ifs.tellg();
    // Clear any error state left by the probe so the rewind and read below are not silently skipped.
    ifs.clear();
    ifs.seekg(0, ifs.beg);
    if (actual_size_in_bytes != static_cast<std::streamoff>(size)) {
        std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
                  << " bytes is requested\n";
    }

    ifs.read(static_cast<char*>(dst), static_cast<std::streamsize>(size));
    // read() stops early at end of file; report how much of `dst` was actually filled.
    const std::streamsize bytes_read = ifs.gcount();
    if (static_cast<size_t>(bytes_read) != size) {
        std::cerr << "[warning] only " << bytes_read << " of " << size << " requested bytes were read from " << path
                  << "\n";
    }
    std::cerr << "[info] " << path << " " << size << "\n";
}

namespace cg = cooperative_groups;

// Initialize one curandState per launched thread.
// Every state uses the same fixed seed, with the thread's grid rank as the subsequence and offset 0. The
// streams are therefore reproducible across runs and distinct per thread. `state` must hold at least as many
// entries as threads in the launch.
__global__ void curand_init(curandState* state)
{
    auto tid = cg::this_grid().thread_rank();
    curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);
}

// Fill result[0, count) with `scale * u + shift`, where u comes from cuRAND's uniform generator, and
// convert the value to T.
// Grid-stride loop: each thread handles elements rank, rank + grid.size(), ... and keeps advancing its own
// state (indexed by grid rank), so `state` must be initialized for every launched thread.
template
__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
{
    auto grid = cg::this_grid();
    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
        float tmp = curand_uniform(state + grid.thread_rank());
        result[i] = T(scale * tmp + shift);
    }
}

// Fill result[0, count) with `scale * z + shift`, where z comes from cuRAND's standard normal generator.
// `scale` therefore acts as the standard deviation and `shift` as the mean. The value is converted to T.
// Uses the same grid-stride / per-thread-state scheme as the uniform kernel.
template
__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
{
    auto grid = cg::this_grid();
    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
        float tmp = curand_normal(state + grid.thread_rank());
        result[i] = T(scale * tmp + shift);
    }
}

// Fill result[0, count) with raw 32-bit values from curand(), using a grid-stride loop with one state per
// launched thread.
__global__ void curand_bytes(curandState* state, size_t count, uint* result)
{
    auto grid = cg::this_grid();
    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
        result[i] = curand(state + grid.thread_rank());
    }
}

// Implementation of RNG: owns one device-side curandState per launched thread and issues the generator
// kernels on `stream_`.
struct RNG::Impl {

    // Launch geometry shared by the init kernel and every generator kernel. The generators index `states` by
    // grid thread rank, so they must never launch more threads than were initialized. Deriving every launch
    // from these constants keeps the two in sync.
    static constexpr int kGridDim   = 64;
    static constexpr int kBlockDim  = 64;
    static constexpr int kNumStates = kGridDim * kBlockDim;

    // Device allocation holding kNumStates states.
    curandState* states{};

    Impl()
    {
        const cudaError_t alloc_err = cudaMalloc(&states, sizeof(curandState) * kNumStates);
        TM_CHECK(alloc_err == cudaSuccess) << cudaGetErrorString(alloc_err);

        // The init kernel runs on the default stream, but generation later runs on `stream_`. That stream may
        // be non-blocking and thus not implicitly ordered after the default stream. Wait for initialization so
        // generation never reads uninitialized states.
        curand_init<<<kGridDim, kBlockDim>>>(states);
        const cudaError_t init_err = cudaDeviceSynchronize();
        TM_CHECK(init_err == cudaSuccess) << cudaGetErrorString(init_err);
    }

    // Owns a raw device allocation; a copy would lead to a double cudaFree.
    Impl(const Impl&) = delete;
    Impl& operator=(const Impl&) = delete;

    ~Impl()
    {
        cudaFree(states);
    }

    // Fill out[0, count) with raw 32-bit random values on `stream_`.
    void GenerateUInt(uint* out, size_t count)
    {
        curand_bytes<<<kGridDim, kBlockDim, 0, stream_>>>(states, count, out);
    }

    // Fill out[0, count) with `scale * uniform + shift` on `stream_`.
    template<class T>
    void GenerateUniform(T* out, size_t count, float scale, float shift)
    {
        curand_uniform<T><<<kGridDim, kBlockDim, 0, stream_>>>(states, count, out, scale, shift);
    }

    // Fill out[0, count) with `scale * normal + shift` on `stream_`.
    template<class T>
    void GenerateNormal(T* out, size_t count, float scale, float shift)
    {
        curand_normal<T><<<kGridDim, kBlockDim, 0, stream_>>>(states, count, out, scale, shift);
    }

    // Stream used for all generator launches; 0 (default stream) until RNG::set_stream is called.
    cudaStream_t stream_{};
};

// Construct the implementation, which allocates and initializes the device curand states.
RNG::RNG(): impl_(std::make_unique()) {}

// Defined out of line, where Impl is a complete type, so the unique_ptr deleter can destroy it.
RNG::~RNG() = default;

// Fill out[0, count) with raw 32-bit random values on the current stream (see set_stream).
void RNG::GenerateUInt(uint* out, size_t count)
{
    impl_->GenerateUInt(out, count);
}

// Fill out[0, count) with `scale * uniform + shift` on the current stream.
// Instantiated below for half, float and nv_bfloat16.
template
void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
{
    impl_->GenerateUniform(out, count, scale, shift);
}

// Fill out[0, count) with `scale * normal + shift` (std-dev `scale`, mean `shift`) on the current stream.
// Instantiated below for half, float and nv_bfloat16.
template
void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
{
    impl_->GenerateNormal(out, count, scale, shift);
}

// Stream on which the Generate*/Random*/…Float methods enqueue their kernels.
cudaStream_t RNG::stream() const
{
    return impl_->stream_;
}

// Select the stream for subsequent generator launches. Launches already enqueued are not affected.
void RNG::set_stream(cudaStream_t stream)
{
    impl_->stream_ = stream;
}

// Explicit instantiations of the typed generators for the element types used by the tests.
template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
template void RNG::GenerateUniform(nv_bfloat16* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(nv_bfloat16* out, size_t count, float scale, float shift);

// Overwrite every byte of `out` with random bits, regardless of its dtype.
// Requires `out` to be dense (size equals layout cosize) and its byte size to be a multiple of sizeof(uint),
// because the fill is done in uint-sized words.
void RNG::RandomBytes(Ref out_)
{
    auto& out = out_.get();
    TM_CHECK(out.size() == out.layout().cosize());
    TM_CHECK(out.byte_size() % sizeof(uint) == 0);
    GenerateUInt((uint*)out.raw_data(), out.byte_size() / sizeof(uint));
}

// Fill the dense tensor `out` with `scale * uniform + shift`, dispatching on its dtype
// (float, half_t or bfloat16_t; other dtypes are rejected by the dispatch macro).
void RNG::UniformFloat(Ref out_, float scale, float shift)
{
    auto& out = out_.get();
    TM_CHECK(out.size() == out.layout().cosize());
    auto invoke = [&](auto t) {
        using T = decltype(t);
        GenerateUniform(out.data(), out.size(), scale, shift);
    };
    TM_DISPATCH_DTYPES(out.dtype(), invoke, float, half_t, bfloat16_t);
}

// Fill the dense tensor `out` with `scale * normal + shift` (std-dev `scale`, mean `shift`), dispatching on
// its dtype (float, half_t or bfloat16_t).
void RNG::NormalFloat(Ref out_, float scale, float shift)
{
    auto& out = out_.get();
    TM_CHECK(out.size() == out.layout().cosize());
    auto invoke = [&](auto t) {
        using T = decltype(t);
        GenerateNormal(out.data(), out.size(), scale, shift);
    };
    TM_DISPATCH_DTYPES(out.dtype(), invoke, float, half_t, bfloat16_t);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/test/test_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/macro.h"
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Compare `bsz` rows of `dims` elements between `src` and a reference `ref`.
// `stride` is presumably the distance between rows and `show` presumably enables verbose output; the
// implementation is not visible here (TODO confirm).
template
void Compare(const T* src,
             const T* ref,
             size_t   stride,
             int      dims,
             int      bsz,
             bool     show = false,
             float    rtol = 1e-2,
             float    atol = 1e-4);

// Type-erased variant of Compare; `dtype` presumably selects the typed implementation (TODO confirm).
void Compare(const void* x,
             const void* r,
             DataType    dtype,
             size_t      stride,
             int         dim,
             int         bsz,
             bool        show,
             float       rtol = 1e-2,
             float       atol = 1e-4);

// Compute comparison statistics between `src` and `ref` (`bsz` rows of `dims` elements) on `stream`.
// The result is meant for FC_Print, whose columns FC_Header names: amean, amean_ref, absdiff, absdiff_max,
// reldiff, reldiff_max, #outlier. rtol/atol presumably define what counts as an outlier.
template
std::vector FastCompare(const T*     src,  //
                               const T*     ref,
                               int          dims,
                               int          bsz,
                               cudaStream_t stream,
                               float        rtol = 1e-2,
                               float        atol = 1e-4);

// Tensor overload: `x` must be a contiguous 2-D tensor and `r` must have the same layout and dtype.
// Rows are shape(0) and the row length is shape(1).
std::vector FastCompare(const Tensor& x,  //
                               const Tensor& r,
                               cudaStream_t  stream,
                               float         rtol = 1e-2,
                               float         atol = 1e-4);

// Print the column header for FC_Print.
void FC_Header();

// Print one row of FastCompare statistics; `d` must hold at least seven values.
void FC_Print(const std::vector& d);

// Read `size` bytes of the binary file `path` into `dst`; aborts if the file cannot be opened.
void LoadBinary(const std::string& path, size_t size, void* dst);

// Device-side random number generator for tests, backed by cuRAND states owned by a pimpl.
// Generation kernels are enqueued on the stream selected with set_stream (0 until then).
class RNG {
public:
    RNG();
    ~RNG();
    // Fill out[0, count) with raw 32-bit random values.
    void GenerateUInt(uint* out, size_t count);

    // Fill out[0, count) with `scale * uniform + shift`.
    template
    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);

    // Fill out[0, count) with `scale * normal + shift` (std-dev `scale`, mean `shift`).
    template
    void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);

    // Overwrite a dense tensor with random bits (byte size must be a multiple of sizeof(uint)).
    void RandomBytes(Ref out_);

    // Fill a dense float/half/bfloat16 tensor with `scale * uniform + shift`.
    void UniformFloat(Ref out_, float scale = 1.f, float shift = 0.f);

    // Fill a dense float/half/bfloat16 tensor with `scale * normal + shift`.
    void NormalFloat(Ref out_, float scale = 1.f, float shift = 0.f);

    // Stream used for generator launches.
    cudaStream_t stream() const;

    // Select the stream for subsequent generator launches.
    void set_stream(cudaStream_t stream);

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/test/testbed_v3.h
================================================

#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/core.h"

#include "src/turbomind/core/tensor.h"
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/kernels/gemm/test/reference.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/quantization.h"

#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"

#include "src/turbomind/kernels/gpt_kernels.h"

namespace turbomind {

// Short aliases used throughout the testbed.
using std::vector;
using std::unique_ptr;

using DenseWeight = LlamaDenseWeight;
using Linear      = LlamaLinear;

using namespace gemm;

// Configuration of one GEMM / MoE test case; Testbed_v3 inherits from it.
struct Parameter {
    int input_dim;   // K: columns of the input, rows of the weight
    int output_dim;  // N: columns of the weight and of the output

    DataType data_type;    // dtype of activations, outputs and the unquantized weights
    DataType weight_type;  // dtype of the quantized weight (GenerateWeight handles data_type, fp8 e4m3, uint4, fp4 e2m1)
    DataType input_type;   // dtype the input is quantized to for the dequant reference (data_type or fp8 e4m3)

    int group_size;  // quantization group size forwarded to weight emplacement / groupwise quantization

    int max_batch_size;  // number of input rows (tokens)

    int  expert_num;         // number of MoE experts; 0 means a dense (non-MoE) test
    int  experts_per_token;  // experts selected per token when expert_num > 0
    bool combine_experts;    // if true, per-expert outputs are reduced back to per-token rows
};

/// Return a transposed copy of the 2-D tensor `src`, i.e. result(j, i) == src(i, j).
///
/// If `out` is provided it must be a contiguous 2-D tensor of shape (src.shape(1), src.shape(0)) with
/// src's dtype. Otherwise a new tensor of that shape is allocated on src's device. Elements are moved as
/// opaque 8/16/32-bit words, so any dtype of one of those widths works; other widths (e.g. sub-byte types)
/// abort.
///
/// TODO: add a generic copy / casting for non-sub-byte Tensor
static Tensor CopyTransposed(const Tensor& src, Tensor out = {})
{
    // invokeTransposeAxis01 receives only raw pointers and the two extents (no strides), so both buffers
    // must be dense. A strided view (such as the result of `.t()`) would otherwise be silently misread or
    // miswritten.
    TM_CHECK_EQ(src.ndim(), 2);
    TM_CHECK(src.is_contiguous()) << src;

    if (out) {
        TM_CHECK_EQ(out.ndim(), 2);
        TM_CHECK(out.is_contiguous()) << out;
        TM_CHECK(out.shapes(0, 1) == src.shapes(1, 0)) << src << " vs " << out;
        TM_CHECK_EQ(out.dtype(), src.dtype());
    }
    else {
        out = {{src.shape(1), src.shape(0)}, src.dtype(), src.device()};
    }

    // A pure copy does not care about the value type, only its width, so transpose as unsigned words.
    auto invoke = [&](auto t) {
        using T = decltype(t);
        invokeTransposeAxis01(
            (T*)out.raw_data(), (T*)src.raw_data(), src.shape(0), src.shape(1), 1, core::Context::stream().handle());
    };

    // The byte size of 8 elements equals the number of bits per element.
    const int bits = byte_size(src.dtype(), 8);
    if (bits == 8) {
        invoke(uint8_t{});
    }
    else if (bits == 16) {
        invoke(uint16_t{});
    }
    else if (bits == 32) {
        invoke(int{});
    }
    else {
        TM_CHECK(0) << "Not implemented. bits = " << bits;
    }

    return out;
}

// End-to-end GEMM / MoE-GEMM test harness.
//
// For a Parameter set it builds three versions of the weights (original, quantized, dequantized) and of the
// input. It runs `Linear` on the quantized weights (Run) and computes reference GEMMs for the original and
// dequantized versions (GetReference). Compare then prints error statistics between the three outputs.
// Tuning records can be imported/exported through the TM_GEMM_IMPORT / TM_GEMM_TUNE / TM_GEMM_EXPORT
// environment variables.
struct Testbed_v3: Parameter {

    // Generates weights and inputs for `param`. For MoE (expert_num > 0) it also links the per-expert weights
    // into the aggregate weights and builds a random routing.
    Testbed_v3(const Parameter& param): Parameter{param}, stream_{core::Context::stream().handle()}, linear_{}
    {
        rng_.set_stream(stream_);
        ref_.set_stream(stream_);

        // Importing tuning records disables both tuning and exporting (see the import_file_.empty() checks).
        if (auto str = std::getenv("TM_GEMM_IMPORT")) {
            import_file_ = str;
            // NOTE(review): the stream is not checked for is_open() before Import is called.
            std::ifstream ifs(import_file_, std::ios::binary);
            auto          n = linear_.Import(ifs);
            std::cout << "Records imported: " << n << "\n";
        }
        if (auto str = std::getenv("TM_GEMM_TUNE"); str && import_file_.empty()) {
            tuning_ = true;
            std::cout << "Enable tuning\n";
        }
        if (auto str = std::getenv("TM_GEMM_EXPORT"); str && import_file_.empty()) {
            export_file_ = str;
        }

        // Always queries device 0. prop_ is not read elsewhere in this struct.
        cudaGetDeviceProperties(&prop_, 0);

        w_original_ = std::make_unique();
        w_quant_    = std::make_unique();
        w_dequant_  = std::make_unique();

        for (int i = 0; i < expert_num; ++i) {
            e_original_.push_back(std::make_unique());
            e_quant_.push_back(std::make_unique());
            e_dequant_.push_back(std::make_unique());
        }

        GenerateWeight();
        GenerateInput();

        if (expert_num) {
            // LinkExperts is declared elsewhere; presumably it attaches the per-expert weights to the
            // aggregate DenseWeight used by the MoE forward (TODO confirm).
            LinkExperts([&](int i) { return e_original_[i].get(); }, expert_num, *w_original_);
            LinkExperts([&](int i) { return e_quant_[i].get(); }, expert_num, *w_quant_);
            LinkExperts([&](int i) { return e_dequant_[i].get(); }, expert_num, *w_dequant_);
            Route();
        }
    }

    // Writes the tuning records to TM_GEMM_EXPORT when it was set; silently skips if the file cannot be opened.
    ~Testbed_v3()
    {
        if (!export_file_.empty()) {
            std::cerr << "export file: " << export_file_ << "\n";
            std::ofstream ofs(export_file_, std::ios::binary);
            if (ofs.is_open()) {
                auto n = linear_.Export(ofs);
                std::cout << "Records exported: " << n << "\n";
            }
        }
    }

    // Build the input x_original_ of shape (max_batch_size, input_dim), filled with normal noise
    // (std 1, mean 1). Also derive x_quant_ / x_dequant_ according to input_type.
    void GenerateInput()
    {
        x_original_ = Tensor{{max_batch_size, input_dim}, data_type, kDEVICE};
        rng_.NormalFloat(x_original_, 1., 1.);

        if (input_type == data_type) {
            // No input quantization: both variants are plain copies.
            x_quant_   = empty_like(x_original_);
            x_dequant_ = empty_like(x_original_);
            Copy(x_original_, x_quant_);
            Copy(x_original_, x_dequant_);
        }
        else if (input_type == kFloat8_e4m3) {
            // Symmetric fp8 quantization, then dequantization back for the reference path.
            QuantizeSymm(x_quant_, x_scale_, x_original_, stream_);
            DequantizeSymm(x_dequant_, x_quant_, x_scale_, stream_);
        }
        else {
            TM_CHECK(0) << "Not implemented for input type " << to_string(input_type);
        }
    }

    // Build a random MoE routing (deterministic: default-seeded mt19937) and upload it to the device.
    //
    // expert_ids is indexed by flat index x, with token x / experts_per_token and slot x % experts_per_token.
    // Results:
    //   scales_  [slot * bsz + token] : per-token gate weights, normalized to sum to 1
    //   offsets_ [expert_num + 1]     : exclusive prefix sum of per-expert row counts (also kept in h_offsets_)
    //   f2n_     [row]                : token index of each row in expert-sorted order
    //   en2f_    [slot * bsz + token] : position of that (token, slot) pair in expert-sorted order
    void Route()
    {
        const int bsz = max_batch_size;

        std::mt19937 g{};

        /// TODO: Control the distribution
        auto expert_ids = SampleUniform(bsz, expert_num, experts_per_token, g);

        std::uniform_real_distribution dist(1e-3, 1.f);

        Buffer_ tmp(experts_per_token, kCPU);
        Buffer_ scales(bsz * experts_per_token, kCPU);

        for (int i = 0; i < bsz; ++i) {
            float sum{};
            for (auto& x : tmp) {
                x = dist(g);
                sum += x;
            }
            for (int e = 0; e < experts_per_token; ++e) {
                scales[e * bsz + i] = tmp[e] / sum;
            }
        }

        // Group flat indices by expert.
        vector         count(expert_num);
        vector> f2i(expert_num);
        for (int i = 0; i < (int)expert_ids.size(); ++i) {
            ++count[expert_ids[i]];
            f2i[expert_ids[i]].push_back(i);
        }

        Buffer_ offsets(expert_num + 1, kCPU);
        offsets[0] = 0;
        for (int i = 0; i < expert_num; ++i) {
            offsets[i + 1] = offsets[i] + count[i];
        }

        // Print the number of rows routed to each expert.
        for (const auto& x : count) {
            std::cout << x << " ";
        }
        std::cout << "\n";

        Buffer_ f2n(expert_ids.size(), kCPU);
        Buffer_ en2f(expert_ids.size(), kCPU);
        for (int e = 0, i = 0; e < expert_num; ++e) {
            for (auto x : f2i[e]) {
                f2n[i]   = x / experts_per_token;
                int en   = x % experts_per_token * bsz + x / experts_per_token;
                en2f[en] = i;
                ++i;
            }
        }

        f2n_ = {f2n.size(), kDEVICE};
        Copy(f2n, f2n_);

        en2f_ = {en2f.size(), kDEVICE};
        Copy(en2f, en2f_);

        scales_ = {scales.size(), kDEVICE};
        Copy(scales, scales_);

        offsets_ = {offsets.size(), kDEVICE};
        Copy(offsets, offsets_);
        h_offsets_ = offsets;

        // Presumably ensures the host->device copies finish before the local host buffers are released.
        core::Context::stream().Sync();
    }

    // Generate weights for every expert (MoE) or for the single dense weight.
    void GenerateWeight()
    {
        if (expert_num) {
            for (int i = 0; i < expert_num; ++i) {
                GenerateWeight(*e_original_[i], *e_quant_[i], *e_dequant_[i]);
            }
        }
        else {
            GenerateWeight(*w_original_, *w_quant_, *w_dequant_);
        }
    }

    // Fill `original` with normal noise (std 1, mean 0.1), quantize it to weight_type into `quant`, and
    // store the dequantized values in `dequant`.
    // - quantize weight
    // - dequantize weight
    void GenerateWeight(DenseWeight& original, DenseWeight& quant, DenseWeight& dequant)
    {
        original.emplace(input_dim, output_dim, data_type, false, data_type, group_size);
        rng_.NormalFloat(original.weight, 1., .1);

        quant.emplace(input_dim, output_dim, data_type, false, weight_type, group_size);
        dequant.emplace(input_dim, output_dim, data_type, false, data_type, group_size);

        // Left empty: random-bit generation is disabled, so the fp4 path below receives no random bits.
        Buffer_ rbits;
        // rbits = {original.weight.size(), kDEVICE};
        // rng_.RandomBytes(Tensor{rbits});

        /// Weights are allocated in MN-major, but some quantization requires K-major tensor

        if (weight_type == data_type) {
            Copy(original.weight, quant.weight);
            Copy(original.weight, dequant.weight);
        }
        else if (weight_type == kFloat8_e4m3) {
            QuantizeSymmBlock(quant.weight, quant.scales, original.weight, stream_);
            DequantizeSymmBlock(dequant.weight, quant.weight, quant.scales, stream_);
        }
        else if (weight_type == kUint4) {
            /// Weights are allocated in (M,N), quantization needs K-major tensor
            // QuantizeGroupwise writes the quantized weight, scales, zeros and the dequantized weight at once.
            QuantizeGroupwise(quant.weight.t(),
                              quant.scales.t(),
                              quant.zeros.t(),
                              dequant.weight.t(),
                              original.weight.t(),
                              {},
                              group_size);
        }
        else if (weight_type == kFloat4_e2m1) {
            // Same as above but without zero points, and with the (empty) rbits buffer.
            QuantizeGroupwise(quant.weight.t(),  //
                              quant.scales.t(),
                              {},
                              dequant.weight.t(),
                              original.weight.t(),
                              rbits,
                              group_size);
        }
        else {
            TM_CHECK(0);
        }

        // The quantized weight's prepare() flag is set only for MoE runs (meaning defined by DenseWeight).
        original.prepare(0);
        quant.prepare(expert_num > 0);
        dequant.prepare(0);
    }

    // Compute d_original_ = x_original * w_original and d_dequant_ = x_dequant * w_dequant with the reference GEMM.
    void GetReference()
    {
        if (expert_num) {
            GetReference(x_original_, e_original_, d_original_);
            GetReference(x_dequant_, e_dequant_, d_dequant_);
        }
        else {
            GetReference(x_original_, w_original_, d_original_);
            GetReference(x_dequant_, w_dequant_, d_dequant_);
        }
    }

    // Dense reference: d = x * dense->weight, where x is (m, k) row-major and d is (m, output_dim).
    // `d` is allocated only if it is empty.
    void GetReference(const Tensor& x, const unique_ptr& dense, Ref d_)
    {
        auto& d = d_.get();
        if (!d) {
            d = Tensor{{x.shape(0), dense->output_dim}, x.dtype(), x.device()};
        }
        /// TODO: refactor reference API
        const MatrixLayout desc_A{x.dtype(), kRowMajor, (int)x.shape(0), (int)x.shape(1), (int)x.stride(0)};  // m,k
        const MatrixLayout desc_D{d.dtype(), kRowMajor, (int)d.shape(0), (int)d.shape(1), (int)d.stride(0)};  // m,n
        ref_.gemm(x.raw_data(), desc_A, dense->weight.raw_data(), dense->k_desc, d.raw_data(), desc_D);
    }

    // MoE reference:
    //   1. gather token rows into expert-sorted order with f2n_;
    //   2. run one dense reference GEMM per expert on rows [h_offsets_[i], h_offsets_[i + 1]);
    //   3. optionally combine the rows back per token with scales_ / en2f_.
    void GetReference(const Tensor& x, const vector>& experts, Ref d_)
    {
        Tensor xe{{x.shape(0) * experts_per_token, input_dim}, data_type, kDEVICE};
        Tensor de{{x.shape(0) * experts_per_token, output_dim}, data_type, kDEVICE};

        invokeMoeDispatch(xe, x, f2n_.data(), experts_per_token, stream_);

        for (int i = 0; i < expert_num; ++i) {
            const int base = h_offsets_[i], size = h_offsets_[i + 1] - base;
            GetReference(xe.slice(base, size), experts[i], de.slice(base, size));
        }

        auto& d = d_.get();
        if (combine_experts) {
            d = Tensor{{x.shape(0), output_dim}, data_type, kDEVICE};
            invokeMoeCombine(d,  //
                             de,
                             {},
                             scales_.data(),
                             en2f_.data(),
                             nullptr,
                             nullptr,
                             experts_per_token,
                             1.,
                             0.,
                             stream_);
        }
        else {
            d = de;
        }
    }

    // Run the kernel under test: d_quant_ = Linear(x_original_, w_quant_). Any quantization of x is done
    // inside Linear. With tuning enabled, measurement is turned on for the duration of the call.
    void Run()
    {
        if (tuning_) {
            linear_.set_measure(true);
        }
        if (expert_num) {
            auto de = linear_.Forward(x_original_, *w_quant_, f2n_, offsets_);
            if (combine_experts) {
                d_quant_ = Tensor{{x_original_.shape(0), output_dim}, data_type, kDEVICE};
                invokeMoeCombine(d_quant_,
                                 de,
                                 {},
                                 scales_.data(),
                                 en2f_.data(),
                                 nullptr,
                                 nullptr,
                                 experts_per_token,
                                 1.,
                                 0.,
                                 stream_);
            }
            else {
                d_quant_ = de;
            }
        }
        else {
            d_quant_ = linear_.Forward(x_original_, *w_quant_);
        }
        if (tuning_) {
            linear_.set_measure(false);
        }
    }

    // NOTE(review): empty overload that does nothing; presumably a placeholder.
    void Run(const Tensor& x, const vector>& experts) {}

    // Print FastCompare statistics between the three outputs (plus the weight round-trip error in the
    // dense case). Expects Run() and GetReference() to have been called first.
    void Compare()
    {
        // Buffer_ h(16 * 16, kCPU);
        // Buffer_ x(linear_.buf, 16 * 16, kDEVICE);
        // Copy(x, h);

        // auto y = empty_like(w_dequant_->weight, kCPU);
        // Copy(w_dequant_->weight, y);

        // clang-format off
        printf("%20s", ""); FC_Header();
        if (!expert_num) {
            printf("%20s", "w_dequant v w_origi"); FC_Print(FastCompare(w_dequant_->weight, w_original_->weight, stream_));
        }
        printf("%20s", "quant   vs  dequant"); FC_Print(FastCompare(d_quant_, d_dequant_, stream_));
        printf("%20s", "quant   vs original"); FC_Print(FastCompare(d_quant_, d_original_, stream_));
        printf("%20s", "dequant vs original"); FC_Print(FastCompare(d_dequant_, d_original_, stream_));
        // clang-format on

        // for (int m = 0; m < 16; ++m) {
        //     for (int k = 0; k < 16; ++k) {
        //         printf("%5.1f", h[m * 16 + k]);
        //     }
        //     printf("\n");
        // }

        // printf("\n");

        // for (int m = 0; m < 16; ++m) {
        //     for (int k = 0; k < 16; ++k) {
        //         printf("%5.1f", (float)y.data()[k * output_dim + m]);
        //     }
        //     printf("\n");
        // }
    }

    // Stream taken from core::Context at construction; used for all kernels here.
    cudaStream_t stream_;

    cudaDeviceProp prop_;

    // The GEMM front-end under test; also holds tuning records for import/export.
    Linear linear_;

    // ! weights are non-movable
    unique_ptr w_original_;
    unique_ptr w_quant_;
    unique_ptr w_dequant_;

    Tensor x_original_;
    Tensor x_quant_, x_scale_;
    Tensor x_dequant_;

    Tensor d_original_;  // x_original * w_original
    Tensor d_quant_;     // x_original * w_quant, quant for X done by `Linear`
    Tensor d_dequant_;   // x_dequant  * w_dequant

    // Per-expert weights (empty when expert_num == 0).
    vector> e_original_;
    vector> e_quant_;
    vector> e_dequant_;

    // Routing tables built by Route() (device copies).
    Buffer_ f2n_;
    Buffer_ en2f_;

    Buffer_   offsets_;
    Buffer_ scales_;

    // Host copy of offsets_, used to slice per-expert rows in the reference path.
    Buffer_ h_offsets_;

    bool tuning_{};

    std::string import_file_;
    std::string export_file_;

    RNG       rng_;
    Reference ref_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/thread_group_map.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/thread_map.h"

#include 

namespace turbomind::gemm {

// Distributes an (M, N, K) problem over GM x GN x GK thread groups that operate on TM x TN x TK tiles.
// Each group iterates kIter* times per dimension, stepping kDelta* between iterations.
template
struct RakedThreadGroupMap {
    static constexpr int M = M_;
    static constexpr int N = N_;
    static constexpr int K = K_;

    static constexpr int TileM = TM;
    static constexpr int TileN = TN;
    static constexpr int TileK = TK;

    static constexpr int kGroupM = GM;
    static constexpr int kGroupN = GN;
    static constexpr int kGroupK = GK;

    static constexpr int kGroupCount = GM * GN * GK;

    // Extent covered by one tile of every group along each dimension.
    static constexpr int M1 = GM * TM;
    static constexpr int N1 = GN * TN;
    static constexpr int K1 = GK * TK;

    static constexpr int kIterM = M / M1;
    static constexpr int kIterN = N / N1;
    static constexpr int kIterK = K / K1;

    static constexpr int kFootprintM = kIterM * TM;
    static constexpr int kFootprintN = kIterN * TN;
    static constexpr int kFootprintK = kIterK * TK;

    static constexpr int kDeltaM = TM;
    static constexpr int kDeltaN = TN;
    static constexpr int kDeltaK = TK;

    // Decompose group_id with m fastest, then n, then k, and return the group's starting (m, n, k) offset.
    // NOTE(review): offsets are index * kFootprint with a step of one tile, i.e. each group covers a
    // contiguous span. Confirm this matches the intended "raked" arrangement.
    __device__ static int3 get_offset(int group_id)
    {
        const int m = group_id % GM;
        const int n = group_id / GM % GN;
        const int k = group_id / GM / GN;
        return {m * kFootprintM, n * kFootprintN, k * kFootprintK};
    }
};

// Maps MMA tiles (tM_ x tN_ x tK_) of an (M_, N_, K_) problem onto thread groups.
// The M/N placement is delegated to ArrangementMN (which provides gM/gN, step factors dM/dN and partition
// kinds pM/pN). K is split across gK groups, either raked (rK: groups interleave tiles) or blocked (each
// group takes a contiguous run of kIterK tiles).
template
struct MMA_Map {
    static constexpr int M = M_;
    static constexpr int N = N_;
    static constexpr int K = K_;

    static constexpr int TileM = tM_;
    static constexpr int TileN = tN_;
    static constexpr int TileK = tK_;

    static constexpr int kGroupM = ArrangementMN::gM;
    static constexpr int kGroupN = ArrangementMN::gN;
    static constexpr int kGroupK = gK;

    static constexpr int kGroupCount = kGroupM * kGroupN * kGroupK;

    static constexpr int kIterM = M / tM_ / kGroupM;
    static constexpr int kIterN = N / tN_ / kGroupN;
    static constexpr int kIterK = K / tK_ / kGroupK;

    static constexpr int kFootprintM = kIterM * tM_;
    static constexpr int kFootprintN = kIterN * tN_;
    static constexpr int kFootprintK = kIterK * tK_;

    // Distance between consecutive tiles of one group: with a raked layout, the other groups' tiles lie between them.
    static constexpr int kDeltaM = tM_ * ArrangementMN::dM;
    static constexpr int kDeltaN = tN_ * ArrangementMN::dN;
    static constexpr int kDeltaK = tK_ * (rK ? gK : 1);

    static constexpr auto kPartitionM = ArrangementMN::pM;
    static constexpr auto kPartitionN = ArrangementMN::pN;
    static constexpr auto kPartitionK = rK ? Partition::kRaked : Partition::kBlocked;

    // group_id % (gM * gN) selects the M/N position via the arrangement (returned in tile units);
    // group_id / (gM * gN) selects the K group.
    __device__ static int3 get_offset(int group_id)
    {
        constexpr int kGroupMN = kGroupM * kGroupN;

        const auto mn = ArrangementMN::get_offset(group_id % kGroupMN, pair{});
        const int  k  = group_id / kGroupMN;

        // Raked K starts group k at tile k; blocked K starts it after k full runs of kIterK tiles.
        return {mn.x * tM_, mn.y * tN_, k * tK_ * (rK ? 1 : kIterK)};
    }
};

namespace {

// Debug helper: dump the compile-time constants of a thread-group map (RakedThreadGroupMap / MMA_Map) to stdout.
template
void Print_(TMap)
{
    std::cout << "M, N, K = " << TMap::M << " " << TMap::N << " " << TMap::K << "\n";
    std::cout << "TM, TN, TK = " << TMap::TileM << " " << TMap::TileN << " " << TMap::TileK << "\n";
    std::cout << "group count = " << TMap::kGroupCount << "\n";
    // Presumably disabled because MMA_Map does not define M1/N1/K1.
    // std::cout << "M1, N1, K1 = " << TMap::M1 << " " << TMap::N1 << " " << TMap::K1 << "\n";
    std::cout << "itM, itN, itK = " << TMap::kIterM << " " << TMap::kIterN << " " << TMap::kIterK << "\n";
    std::cout << "fpM, fpN, fpK = " << TMap::kFootprintM << " " << TMap::kFootprintN << " " << TMap::kFootprintK
              << "\n";
    std::cout << "dM, dN, dK = " << TMap::kDeltaM << " " << TMap::kDeltaN << " " << TMap::kDeltaK << "\n";
}

}  // namespace

/// TODO: Striped partition?

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/thread_map.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/kernels/gemm/types.h"

#include 

namespace turbomind::gemm {

// Assigns the threads of a CTA to a DimC x DimS tile.
// C is the axis each access covers kAccessC elements of (presumably the contiguous axis); S is the strided
// axis. Lanes of a warp form a kWarpThreadC x kWarpThreadS grid, and warps form a kWarpC x kWarpS grid.
template
struct ThreadMap {
    static constexpr int kDimC = DimC;
    static constexpr int kDimS = DimS;

    static constexpr int kWarpCount = WarpCount;
    static constexpr int kAccessC   = AccessC;

    static constexpr int kWarpThreadC = WarpThreadC;
    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;

    static_assert(kWarpThreadC <= WARP_SIZE);

    // Elements covered by one warp-wide access along each axis.
    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
    static constexpr int kWarpAccessS = kWarpThreadS;

    // Warp-wide accesses needed to cover the whole tile along each axis.
    static constexpr int kWarpIterC = ceil_div(kDimC, kWarpAccessC);
    static constexpr int kWarpIterS = ceil_div(kDimS, kWarpAccessS);

    // Partition warps along the strided axis first to reduce strided iters
    static constexpr int kWarpS = kWarpIterS >= kWarpCount ? kWarpCount : kWarpIterS;
    static constexpr int kWarpC = kWarpCount > kWarpIterS ? kWarpCount / kWarpS : 1;

    // Per-warp iteration counts.
    static constexpr int kIterC = ceil_div(kWarpIterC, kWarpC);
    static constexpr int kIterS = ceil_div(kWarpIterS, kWarpS);

    // Allow partial tile when there is ONLY 1 iteration
    static_assert(kDimC % kWarpAccessC == 0 || kIterC == 1);

    // static_assert(kIterC > 0);
    // static_assert(kIterS > 0);

    // True when the tile is covered exactly with no partial accesses or idle warps along that axis
    // (presumably lets callers skip bounds predicates).
    static constexpr bool kAlignedC = (kDimC % kWarpAccessC == 0) && (kWarpIterC % kWarpC == 0);
    static constexpr bool kAlignedS = (kDimS % kWarpAccessS == 0) && (kWarpIterS % kWarpS == 0);

    // Region owned by one warp, and the step between its successive accesses.
    static constexpr int kFootprintC = kWarpAccessC * kIterC;
    static constexpr int kFootprintS = kWarpAccessS * kIterS;

    static constexpr int kDeltaC = kWarpAccessC;
    static constexpr int kDeltaS = kWarpAccessS;

    // static constexpr int kDeltaC = kWarpAccessC * kWarpC;
    // static constexpr int kDeltaS = kWarpAccessS * kWarpS;

    // (c, s) offset of this thread's first access. Warps are laid out C-fastest, and each owns a contiguous
    // footprint; lanes are laid out C-fastest inside the warp.
    __device__ static int2 get_offset(int warp_id, int lane_id)
    {
        int warp_offset_c = warp_id % kWarpC;
        int warp_offset_s = warp_id / kWarpC;

        int warp_thread_offset_c = lane_id % kWarpThreadC;
        int warp_thread_offset_s = lane_id / kWarpThreadC;

        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;

        // int cta_thread_offset_c = kWarpAccessC * warp_offset_c + warp_thread_offset_c * kAccessC;
        // int cta_thread_offset_s = kWarpAccessS * warp_offset_s + warp_thread_offset_s;

        return {cta_thread_offset_c, cta_thread_offset_s};
    }
};

// Convert a linear index into (m, k) coordinates of an M x K grid.
// Column-major order makes m vary fastest; otherwise (row-major) k varies fastest.
template
__host__ __device__ static constexpr int2 idx2mk(int idx, pair)
{
    if constexpr (order == kColMajor) {
        return {idx % M, idx / M};
    }
    else {
        return {idx / K, idx % K};
    }
}

// How work along one axis is split among groups:
// kBlocked gives each group a contiguous range; kRaked interleaves the groups' items.
enum class Partition
{
    kBlocked,
    kRaked,
};

// Warp arrangement that splits an M x N (in access units) region into a gM x gN grid of contiguous blocks.
// Each warp starts at its grid coordinate times the block size and steps by one unit (dM = dN = 1).
template
struct Blocked {
    static constexpr int gM = gM_;
    static constexpr int gN = gN_;

    // static_assert((gM - 1) * sM + (gN - 1) * sN == gM * gN - 1);

    static constexpr int dM = 1;
    static constexpr int dN = 1;

    static constexpr Partition pM = Partition::kBlocked;
    static constexpr Partition pN = Partition::kBlocked;

    template
    __device__ static int2 get_offset(int idx, pair)
    {
        // Block extent per warp along each axis.
        constexpr int iM = ceil_div(M, gM);
        constexpr int iN = ceil_div(N, gN);

        // const int mi = idx / sM % gM;
        // const int ni = idx / sN % gN;

        const int2 mn = idx2mk(idx, pair{});
        return {mn.x * iM, mn.y * iN};
    }
};

// Warp arrangement that interleaves warps along both axes: a warp starts at its gM x gN grid coordinate
// and steps by gM / gN units (dM = gM, dN = gN).
template
struct Raked {
    static constexpr int gM = gM_;
    static constexpr int gN = gN_;

    // static_assert((gM - 1) * sM + (gN - 1) * sN == gM * gN - 1);

    static constexpr int dM = gM;
    static constexpr int dN = gN;

    static constexpr Partition pM = Partition::kRaked;
    static constexpr Partition pN = Partition::kRaked;

    // NOTE(review): takes the shape argument as `Shape` while the sibling arrangements take `pair`;
    // the value is unused, only the grid coordinates of idx are returned.
    template
    __device__ static int2 get_offset(int idx, Shape)
    {
        return idx2mk(idx, pair{});
    }
};

// Mixed warp arrangement: blocked along the first axis (contiguous ranges of ceil_div(M, gM) units, step 1)
// and raked along the second (start at the grid coordinate, step gN).
template
struct Blocked_C_Raked_S {
    static constexpr int gM = gM_;
    static constexpr int gN = gN_;

    static constexpr int dM = 1;
    static constexpr int dN = gN;

    static constexpr Partition pM = Partition::kBlocked;
    static constexpr Partition pN = Partition::kRaked;

    template
    __device__ static int2 get_offset(int idx, pair)
    {
        constexpr int iM = ceil_div(M, gM);

        const int2 mn = idx2mk(idx, pair{});
        return {mn.x * iM, mn.y};
    }
};

// Variant of ThreadMap where warp placement is delegated to an Arrangement (Blocked / Raked /
// Blocked_C_Raked_S). By default a warp spans min(WARP_SIZE, C / AccessC) threads along C.
template
         typename Arrangement_,
         int WarpCount,
         int WarpThrC = std::min(WARP_SIZE, C / AccessC)>
struct ThreadMap_V2 {
    static constexpr int kDimC = C;
    static constexpr int kDimS = S;

    static constexpr int kWarpCount = WarpCount;
    static constexpr int kAccessC   = AccessC;

    static_assert(WarpThrC <= WARP_SIZE);

    static constexpr int kWarpThreadC = WarpThrC;
    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;

    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
    static constexpr int kWarpAccessS = kWarpThreadS;

    static constexpr int kWarpIterC = ceil_div(kDimC, kWarpAccessC);
    static constexpr int kWarpIterS = ceil_div(kDimS, kWarpAccessS);

    // Same warp split as ThreadMap: fill the strided axis first.
    static constexpr int kWarpS = kWarpIterS >= kWarpCount ? kWarpCount : kWarpIterS;
    static constexpr int kWarpC = kWarpCount > kWarpIterS ? kWarpCount / kWarpS : 1;

    using Arrangement = Arrangement_;

    static constexpr auto kPartitionM = Arrangement::pM;
    static constexpr auto kPartitionN = Arrangement::pN;

    static constexpr int kIterC = ceil_div(kWarpIterC, kWarpC);
    static constexpr int kIterS = ceil_div(kWarpIterS, kWarpS);

    static constexpr bool kAlignedC = (kDimC % kWarpAccessC == 0) && (kWarpIterC % kWarpC == 0);
    static constexpr bool kAlignedS = (kDimS % kWarpAccessS == 0) && (kWarpIterS % kWarpS == 0);

    static constexpr int kFootprintC = kWarpAccessC * kIterC;
    static constexpr int kFootprintS = kWarpAccessS * kIterS;

    // Step between a warp's successive accesses, scaled by the arrangement's interleave factor.
    static constexpr int kDeltaC = kWarpAccessC * Arrangement::dM;
    static constexpr int kDeltaS = kWarpAccessS * Arrangement::dN;

    // (c, s) offset of this thread's first access. The arrangement returns the warp's position in units of
    // one warp-wide access; the lane adds its position inside the warp.
    __device__ static int2 get_offset(int warp_id, int lane_id)
    {
        const int2 warp_offset = Arrangement::get_offset(warp_id, pair{});

        int warp_thr_offset_c = lane_id % kWarpThreadC;
        int warp_thr_offset_s = lane_id / kWarpThreadC;

        // Equivalent to the generic computation above when lane_id < WARP_SIZE; presumably spelled out so
        // the compiler sees s == 0 as a constant.
        if constexpr (kWarpThreadC == WARP_SIZE) {
            warp_thr_offset_c = lane_id;
            warp_thr_offset_s = 0;
        }

        const int offset_c = warp_offset.x * kWarpAccessC + warp_thr_offset_c * kAccessC;
        const int offset_s = warp_offset.y * kWarpAccessS + warp_thr_offset_s;

        return {offset_c, offset_s};
    }
};

namespace {

// Debug helper: writes every derived constant of the thread map `TMap` to stdout, one labelled row each.
template<class TMap>
void Print(TMap)
{
    std::ostream& out = std::cout;
    out << "     warps: " << TMap::kWarpCount << "\n"
        << "     shape: (" << TMap::kDimC << ", " << TMap::kDimS << ")\n"
        // the S component of an access is always a single row
        << "    access: (" << TMap::kAccessC << ", " << 1 << ")\n"
        << "warpThread: (" << TMap::kWarpThreadC << ", " << TMap::kWarpThreadS << ")\n"
        << "warpAccess: (" << TMap::kWarpAccessC << ", " << TMap::kWarpAccessS << ")\n"
        << "  warpIter: (" << TMap::kWarpIterC << ", " << TMap::kWarpIterS << ")\n"
        << "      warp: (" << TMap::kWarpC << ", " << TMap::kWarpS << ")\n"
        << "      iter: (" << TMap::kIterC << ", " << TMap::kIterS << ")\n"
        << " footprint: (" << TMap::kFootprintC << ", " << TMap::kFootprintS << ")\n"
        << "     delta: (" << TMap::kDeltaC << ", " << TMap::kDeltaS << ")\n"
        << "   aligned: (" << TMap::kAlignedC << "," << TMap::kAlignedS << ")\n";
}

}  // namespace

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tiled_mma.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/core/mma.h"
#include "src/turbomind/kernels/core/smem.h"
#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/simt.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/thread_map.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"

namespace turbomind::gemm {

// Tiles an MMA atom (`Atom`) over a thread-group map (`Map`).
// Exposes the map's tile shape (M, N, K), the total thread count, and the number of atom-level MMAs a thread
// issues along each axis (kMmaIter* = map tile iterations x atoms per map tile).
template
struct Tiled_MMA_v2 {
    using Atom = MMA_Atom_;
    using Map  = MMA_Map_;

    static constexpr int M = Map::M;
    static constexpr int N = Map::N;
    static constexpr int K = Map::K;

    // Threads = number of groups in the map x threads per atom
    static constexpr int kGroupCount  = Map::kGroupCount;
    static constexpr int kThreadCount = kGroupCount * Atom::kThreadCount;

    // Iterations over map tiles along each axis
    static constexpr int kTileIterM = Map::kIterM;
    static constexpr int kTileIterN = Map::kIterN;
    static constexpr int kTileIterK = Map::kIterK;

    static constexpr int kDeltaM = Map::kDeltaM;
    static constexpr int kDeltaN = Map::kDeltaN;
    static constexpr int kDeltaK = Map::kDeltaK;

    // Atoms needed to cover one map tile along each axis
    static constexpr int kAtomM = Map::TileM / Atom::M;
    static constexpr int kAtomN = Map::TileN / Atom::N;
    static constexpr int kAtomK = Map::TileK / Atom::K;

    // Total atom-level MMA iterations per thread along each axis
    static constexpr int kMmaIterM = kTileIterM * kAtomM;
    static constexpr int kMmaIterN = kTileIterN * kAtomN;
    static constexpr int kMmaIterK = kTileIterK * kAtomK;

    // (m, n, k) offset of the calling thread: its atom group id mapped through `Map`.
    __device__ static int3 get_offset(int thread_idx)
    {
        return Map::get_offset(Atom::get_group_id(thread_idx));
    }

    // (M,N)
    // One k-step: issues Atom::fma(frag_D[m][n], frag_A[m], frag_B[n], frag_C[m][n]) for every (m, n).
    // `order_` selects the loop nesting. In the column-major order the m loop runs backwards on odd n
    // (serpentine traversal) — presumably so consecutive MMAs share an A fragment; TODO confirm intent.
    template
    __device__ static void mma_k_iter(FragD& frag_D, const FragA& frag_A, const FragB& frag_B, const FragC& frag_C)
    {
        if constexpr (order_ == kColMajor) {
            PRAGMA_UNROLL
            for (int n = 0; n < kMmaIterN; ++n) {
                PRAGMA_UNROLL
                for (int m = 0; m < kMmaIterM; ++m) {
                    int mm = n % 2 ? (kMmaIterM - m - 1) : m;
                    Atom::fma(frag_D[mm][n], frag_A[mm], frag_B[n], frag_C[mm][n]);
                }
            }
        }
        else {
            // Row-major order: plain m-outer / n-inner traversal
            PRAGMA_UNROLL
            for (int m = 0; m < kMmaIterM; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < kMmaIterN; ++n) {
                    int nn = n;
                    int mm = m;
                    Atom::fma(frag_D[mm][nn], frag_A[mm], frag_B[nn], frag_C[mm][nn]);
                }
            }
        }
    }
};

// Writes the accumulator fragments of a tiled MMA into a shared-memory tile through `smem_C`.
// Each element is placed with `Layout::apply`. When the map splits K across thread groups (Map::kGroupK > 1),
// the partial results are summed in shared memory:
//   - k-group 0 stores first;
//   - each later group in turn loads, adds its own fragment and stores back;
//   - a __syncthreads() separates the passes.
template
struct Rearrange {
    using Map  = typename MMA::Map;
    using Atom = typename MMA::Atom;

    // frag_C    : per-thread accumulators indexed [m][n] over the MMA iterations
    // smem_C    : shared-memory accessor; &smem_C(0, 0) is used as the base pointer for all offsets
    // offset_mn : (m, n) origin of the shared tile, subtracted from every thread position
    //             (NOTE(review): presumably the tile's offset inside the CTA tile — confirm against callers)
    // TM, TN    : extents used by the write masks, ignored when they cover Map::M / Map::N
    template
    __device__ static void
    apply(Array (&frag_C)[M][N], SmemAccessorV2& smem_C, int2 offset_mn, pair)
    {
        const int3 offset_mnk = MMA::get_offset(threadIdx.x);
        // Which k-group this thread belongs to
        const int  group_id_k = offset_mnk.z / Map::kFootprintK;

        constexpr bool kRakedM = Map::kPartitionM == Partition::kRaked;
        constexpr bool kRakedN = Map::kPartitionN == Partition::kRaked;

        // (C0, S0) of the layout, expressed in (m, n)
        static constexpr int2 kMN0 = cs2mk(Layout::C0, Layout::S0);

        // Number of map steps / atom steps that fit in one (C0, S0) period
        constexpr int kPeriodM  = ceil_div(kMN0.x, Map::kDeltaM);
        constexpr int kPeriodN  = ceil_div(kMN0.y, Map::kDeltaN);
        constexpr int kPeriodM1 = ceil_div(kMN0.x, Atom::M);
        constexpr int kPeriodN1 = ceil_div(kMN0.y, Atom::N);

        constexpr auto offset_C = Atom::static_offset_C();
        const int2     thr      = Atom::thread_offset_C();

        // Contract: none of these indices take part in swizzling.
        // phases[] holds the layout offsets of this thread's elements for one period. The final offset below
        // is (period-aligned base) + phase, which presumably requires Layout::apply to be additive across
        // period boundaries — TODO confirm.
        int phases[kPeriodM][kPeriodN][kPeriodM1][kPeriodN1][offset_C.size()];
        PRAGMA_UNROLL
        for (int m = 0; m < kPeriodM; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < kPeriodN; ++n) {
                PRAGMA_UNROLL
                for (int m1 = 0; m1 < kPeriodM1; ++m1) {
                    PRAGMA_UNROLL
                    for (int n1 = 0; n1 < kPeriodN1; ++n1) {
                        const int mm = offset_mnk.x + m * Map::kDeltaM + m1 * Atom::M + thr.x;
                        const int nn = offset_mnk.y + n * Map::kDeltaN + n1 * Atom::N + thr.y;
                        PRAGMA_UNROLL
                        for (int i = 0; i < offset_C.size(); ++i) {
                            const int2 cs           = mk2cs(mm + offset_C[i].x, nn + offset_C[i].y);
                            phases[m][n][m1][n1][i] = Layout::apply(cs.y, cs.x);
                        }
                    }
                }
            }
        }

        constexpr int K = Map::kGroupK;
        constexpr int C = offset_C.size();

        // Per k-group, per fragment, per C-vector: shared-memory offset and whether this thread writes it
        int offsets[K][M][N][C];
        int masks[K][M][N][C];

        PRAGMA_UNROLL
        for (int k = 0; k < K; ++k) {
            PRAGMA_UNROLL
            for (int m = 0; m < M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < N; ++n) {
                    // Split the fragment index into map-tile index (m0/n0) and atom index within the tile (m1/n1)
                    int m0 = m / MMA::kAtomM, m1 = m % MMA::kAtomM, n0 = n / MMA::kAtomN, n1 = n % MMA::kAtomN;
                    // Period-aligned base position, relative to offset_mn
                    int m01 =
                        m0 / kPeriodM * kPeriodM * Map::kDeltaM + m1 / kPeriodM1 * kPeriodM1 * Atom::M - offset_mn.x;
                    int n01 =
                        n0 / kPeriodN * kPeriodN * Map::kDeltaN + n1 / kPeriodN1 * kPeriodN1 * Atom::N - offset_mn.y;
                    const int2 cs       = mk2cs(m01, n01);
                    int        offset_0 = Layout::apply(cs.y, cs.x);
                    PRAGMA_UNROLL
                    for (int i = 0; i < offset_C.size(); ++i) {
                        int offset_1        = phases[m0 % kPeriodM][n0 % kPeriodN][m1 % kPeriodM1][n1 % kPeriodN1][i];
                        offsets[k][m][n][i] = offset_0 + offset_1;
                        // Exact (m, n) position of this thread's element, relative to offset_mn
                        const int bm        = offset_mnk.x - offset_mn.x + m0 * Map::kDeltaM + m1 * Atom::M + thr.x;
                        const int bn        = offset_mnk.y - offset_mn.y + n0 * Map::kDeltaN + n1 * Atom::N + thr.y;
                        // Raked partitions test the period base; otherwise the exact position is tested
                        const int mm        = kRakedM ? m01 : bm;
                        const int nn        = kRakedN ? n01 : bn;
                        // Write only for this thread's own k-group, and only inside (TM, TN) unless it covers the map
                        masks[k][m][n][i]   = (Map::kGroupK == 1 || group_id_k == k)
                                            && (TM >= Map::M || (0 <= mm && mm < TM))
                                            && (TN >= Map::N || (0 <= nn && nn < TN));
                    }
                }
            }
        }

        // Row-major: one vector Store at `offset`; otherwise scatter element i to offset + Layout::apply(i, 0)
        auto _store = [](auto ptr, auto offset, auto vec) {
            if constexpr (order == kRowMajor) {
                Store(&ptr[offset], vec);
            }
            else {
                for (int i = 0; i < vec.size(); ++i) {
                    ptr[offset + Layout::apply(i, 0)] = vec[i];
                }
            }
        };

        typename Atom::FragC_ reshape_C;

        auto ptr = &smem_C(0, 0);

        // Pass for k-group 0: plain stores
        PRAGMA_UNROLL
        for (int m = 0; m < M; ++m) {
            PRAGMA_UNROLL
            for (int n = 0; n < N; ++n) {
                Atom::ReshapeC(frag_C[m][n], reshape_C);
                PRAGMA_UNROLL
                for (int c = 0; c < C; ++c) {
                    auto& vec    = reshape_C[c];
                    int   offset = offsets[0][m][n][c];
                    if (masks[0][m][n][c]) {
                        _store(ptr, offset, vec);
                    }
                }
            }
        }

        __syncthreads();

#if 1
        // Counterpart of _store: vector Load for row-major, element-wise gather otherwise
        auto _load = [](auto ptr, auto offset, auto& vec) {
            if constexpr (order == kRowMajor) {
                Load(vec, &ptr[offset]);
            }
            else {
                for (int i = 0; i < vec.size(); ++i) {
                    vec[i] = ptr[offset + Layout::apply(i, 0)];
                }
            }
        };

        // Passes for k-groups 1..K-1: read-modify-write accumulation, one group at a time (barrier per group)
        PRAGMA_UNROLL
        for (int k = 1; k < K; ++k) {
            PRAGMA_UNROLL
            for (int m = 0; m < M; ++m) {
                PRAGMA_UNROLL
                for (int n = 0; n < N; ++n) {
                    Atom::ReshapeC(frag_C[m][n], reshape_C);
                    PRAGMA_UNROLL
                    for (int c = 0; c < C; ++c) {
                        auto& vec    = reshape_C[c];
                        int   offset = offsets[k][m][n][c];
                        if (masks[k][m][n][c]) {
                            std::remove_reference_t tmp;
                            _load(ptr, offset, tmp);
                            {
                                using namespace ops;
                                vec = vec + tmp;
                            }
                            _store(ptr, offset, vec);
                        }
                    }
                }
            }
            __syncthreads();
        }
#endif
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tma.cu
================================================

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/cuda_data_type.h"
#include "src/turbomind/kernels/gemm/tma.h"

namespace turbomind::gemm {

#if __CUDACC_VER_MAJOR__ >= 12

#if (CUDA_VERSION >= 13000) && (!defined(PFN_cuTensorMapEncodeTiled))
// PFN_cuTensorMapEncodeTiled not defined in cuda 13 headers.
#define PFN_cuTensorMapEncodeTiled PFN_cuTensorMapEncodeTiled_v12000
#endif

namespace {

// Returns the driver entry point `cuTensorMapEncodeTiled`.
// It is resolved once through the runtime's driver entry-point query and cached in a function-local static.
// Fails via TM_CHECK_EQ when the query call itself errors or when the driver reports the symbol as unavailable.
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
{
    static const auto ptr = [] {
        // Seed with a failure value, so a status the query never wrote cannot pass the check below
        cudaDriverEntryPointQueryResult driver_status              = cudaDriverEntryPointSymbolNotFound;
        void*                           cuTensorMapEncodeTiled_ptr = nullptr;

// https://github.com/NVIDIA/cutlass/pull/2086
#if CUDA_VERSION >= 13000
        const cudaError_t query_status = cudaGetDriverEntryPointByVersion(
            "cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, cudaEnableDefault, &driver_status);
#else
        const cudaError_t query_status = cudaGetDriverEntryPoint(
            "cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, &driver_status);
#endif
        // The runtime call can fail on its own (e.g. no usable driver); check it before trusting driver_status
        TM_CHECK_EQ(query_status, cudaSuccess);
        TM_CHECK_EQ(driver_status, cudaDriverEntryPointSuccess);
        return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
    }();
    return ptr;
}

// Encodes a rank-2 TMA descriptor for a global tensor.
// gmem_dims / smem_dims are given innermost (contiguous) dimension first; stride_in_bytes is the byte pitch
// between consecutive items of the outer dimension. Fails via TM_CHECK_EQ if the driver rejects the parameters.
CUtensorMap make_2d_tma_desc(void*              global_address,
                             DataType           data_type,
                             uint64_t           gmem_dims[2],
                             uint64_t           stride_in_bytes,
                             uint32_t           smem_dims[2],
                             CUtensorMapSwizzle swizzle)
{
    constexpr uint32_t kRank = 2;

    // A rank-N map takes N-1 strides: dimension 0 is implicitly contiguous
    uint64_t strides[kRank - 1] = {stride_in_bytes};
    // Traverse every element of the box in both dimensions
    uint32_t box_steps[kRank] = {1, 1};

    CUtensorMap desc = {};

    const auto encode = get_cuTensorMapEncodeTiled();
    const auto status = encode(&desc,
                               to_CUtensorMap_dtype(data_type),
                               kRank,
                               global_address,
                               gmem_dims,
                               strides,
                               smem_dims,
                               box_steps,
                               CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
                               swizzle,
                               CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
                               CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    TM_CHECK_EQ(status, CUDA_SUCCESS);

    return desc;
}

}  // namespace

// Builds a 2-D TMA descriptor from a (rows, cols) matrix and a (rows, cols) shared-memory box.
// `order` decides which axis is contiguous. `ld` is the leading dimension in elements; 0 means the matrix is
// densely packed (ld = cols for row-major, rows for column-major).
CUtensorMap make_2d_tma_desc(void*              global_address,
                             DataType           data_type,
                             uint32_t           gmem_rows,
                             uint32_t           gmem_cols,
                             uint32_t           smem_rows,
                             uint32_t           smem_cols,
                             Order              order,
                             CUtensorMapSwizzle swizzle,
                             int                ld)
{
    // TMA lists the contiguous ("inner") dimension first
    const bool row_major = order == kRowMajor;

    const uint32_t gmem_inner = row_major ? gmem_cols : gmem_rows;
    const uint32_t gmem_outer = row_major ? gmem_rows : gmem_cols;
    const uint32_t smem_inner = row_major ? smem_cols : smem_rows;
    const uint32_t smem_outer = row_major ? smem_rows : smem_cols;

    uint64_t gmem_dims[] = {gmem_inner, gmem_outer};
    uint32_t smem_dims[] = {smem_inner, smem_outer};

    const auto pitch = ld ? ld : gmem_inner;
    return make_2d_tma_desc(global_address, data_type, gmem_dims, byte_size(data_type, pitch), smem_dims, swizzle);
}

// Convenience overload: takes type, extents, order and leading dimension from `desc`.
// smem_shape is (rows, cols) of the shared-memory box.
CUtensorMap make_2d_tma_desc(void* ptr, const MatrixLayout& desc, uint2 smem_shape, CUtensorMapSwizzle swizzle)
{
    const uint32_t box_rows = smem_shape.x;
    const uint32_t box_cols = smem_shape.y;
    return make_2d_tma_desc(ptr, desc.type, desc.rows, desc.cols, box_rows, box_cols, desc.order, swizzle, desc.ld);
}

#endif

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tma.h
================================================
#include 
#include 
#include 
#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

#if __CUDACC_VER_MAJOR__ >= 12

CUtensorMap make_2d_tma_desc(void*              global_address,
                             DataType           data_type,
                             uint32_t           gmem_rows,
                             uint32_t           gmem_cols,
                             uint32_t           smem_rows,
                             uint32_t           smem_cols,
                             Order              order,
                             CUtensorMapSwizzle swizzle,
                             int                ld = 0);

CUtensorMap make_2d_tma_desc(void* ptr, const MatrixLayout& desc, uint2 smem_shape, CUtensorMapSwizzle swizzle);

// Maps a swizzle span in bytes (0, 16, 32, 64 or 128) to the matching TMA swizzle mode.
// Throws std::logic_error for any other value.
constexpr CUtensorMapSwizzle get_tma_swizzle(int bytes)
{
    if (bytes == 128) {
        return CU_TENSOR_MAP_SWIZZLE_128B;
    }
    if (bytes == 64) {
        return CU_TENSOR_MAP_SWIZZLE_64B;
    }
    if (bytes == 32) {
        return CU_TENSOR_MAP_SWIZZLE_32B;
    }
    // A 16-byte swizzle unit permutes nothing, so it is the same as no swizzle
    if (bytes == 16 || bytes == 0) {
        return CU_TENSOR_MAP_SWIZZLE_NONE;
    }
    throw std::logic_error("unsupported swizzle type: " + std::to_string(bytes));
}

#endif

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/transform.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/attention/quantization.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/kernels/gemm/smem_copy.h"
#include "src/turbomind/kernels/gemm/tiled_mma.h"

namespace turbomind::gemm {

// Identity transform: copies the loaded data of k-slice `k` into the MMA fragment without conversion.
struct Transform_Default {
    // frag : destination fragments [K][Mf]; only frag[k] is written
    // data : source registers [K][Md]; only data[k] is read
    // The stat argument (S&) and `div` are accepted for interface compatibility and are unused here.
    template
    __device__ static void apply(Array (&frag)[K][Mf], int k, Array (&data)[K][Md], S&, int div)
    {
        // Fragment and data must describe the same storage, just shaped differently
        static_assert(Nf * Mf == Nd * Md);
        static_assert(Nd % Nf == 0 && Mf % Md == 0);
        static_assert(sizeof(frag) == sizeof(data));

        // Alignment must be manually enforced for `reinterpret_cast`
        // View frag[k] with data's shape so it can be assigned element-for-element
        auto& frag_k = reinterpret_cast(&)[Md]>(frag[k]);
        auto& data_k = data[k];

        PRAGMA_UNROLL
        for (int i = 0; i < std::size(frag_k); ++i) {
            frag_k[i] = data_k[i];
        }
    }
};

// Converts quantized data for k-slice `k` into MMA fragments and dequantizes them with per-group stats.
// StatStepS / StatStepC (template parameters) select the stat entry for each pair of elements.
template
struct Transform_HMMA_16816 {
    // frag : destination fragments; only frag[k] is written
    // data : quantized source registers; data[k] is converted with ConvertKvCache
    // stat : quantization statistics; slice k / div is used, i.e. one stat slice serves `div` k-slices
    template
    __device__ static void
    apply(Array (&frag)[K][Mf], int k, Array (&data)[K][Md], Array (&stat)[Ks][Ms], int div)
    {
        static_assert(Nf * Mf == Nd * Md);
        static_assert(Nd % Nf == 0 && Mf % Md == 0);
        // One stat entry per 4 fragment elements
        static_assert(Nf * Mf == Ns * Ms * 4);

        auto& frag_k = reinterpret_cast(&)[Md]>(frag[k]);
        auto& stat_k = reinterpret_cast(&)[Ns * Ms]>(stat[k / div]);
        auto& data_k = data[k];

        PRAGMA_UNROLL
        for (int m = 0; m < Md; ++m) {
            auto tmp = ConvertKvCache::convert(data_k[m]);
            static_assert(Nd % 8 == 0);
            // Each group of 8 converted elements is dequantized as 4 pairs at offsets s * 4 + c * 2
            PRAGMA_UNROLL
            for (int i = 0; i < Nd; i += 8) {
                PRAGMA_UNROLL
                for (int s = 0; s < 2; ++s) {
                    PRAGMA_UNROLL
                    for (int c = 0; c < 2; ++c) {
                        const int idx = (m * Nd + i) / 8 * 2 + s * StatStepS + c * StatStepC;
                        dequant((Array&)tmp[i + s * 4 + c * 2], stat_k[idx]);
                    }
                }
            }

            frag_k[m] = tmp;
        }
    }

    // Scale-and-offset: x = x * s[0] + s[1], with `s` reinterpreted as a pair of values of x's element type
    template
    __device__ static void dequant(Array& x, Array s)
    {
        Array& _s = (Array&)s;
        x[0]            = __hfma(x[0], _s[0], _s[1]);
        x[1]            = __hfma(x[1], _s[0], _s[1]);
    }

    // bf16 power-of-two scale: s[0] is shifted into the bf16 exponent field (bits 7..14) with zero mantissa,
    // giving 2^(s[0] - 127) — assumes s[0] fits in 8 bits
    __device__ static void dequant(Array& x, Array s)
    {
        bfloat16_t s1 = __ushort_as_bfloat16((uint16_t)s[0] << 7);
        x[0]          = __hmul(x[0], s1);
        x[1]          = __hmul(x[1], s1);
    }

    // half power-of-two scale: s[0] is shifted into the half exponent field (bits 10..14)
    __device__ static void dequant(Array& x, Array s)
    {
        // half_t s1 = __ushort_as_half(((uint16_t)s[0] + 15 - 127) << 10);
        // Adjusted in `AdjustUe8m0ScaleForHalf`
        half_t s1 = __ushort_as_half((uint16_t)s[0] << 10);
        x[0]      = __hmul(x[0], s1);
        x[1]      = __hmul(x[1], s1);
    }

    // Plain bf16 scale: the bits of s[0] are reinterpreted as a bf16 multiplier
    __device__ static void dequant(Array& x, Array s)
    {
        auto s1 = __ushort_as_bfloat16(s[0]);
        x[0]    = __hmul(x[0], s1);
        x[1]    = __hmul(x[1], s1);
    }

    // Plain half scale: the bits of s[0] are reinterpreted as a half multiplier
    __device__ static void dequant(Array& x, Array s)
    {
        auto s1 = __ushort_as_half(s[0]);
        x[0]    = __hmul(x[0], s1);
        x[1]    = __hmul(x[1], s1);
    }
};

// Used by SM70 MMA
// Converts quantized B data for k-slice `k` into fragments and dequantizes each pair of elements with
// stat_k[(m * Nd + i) / Nf].
struct Transform_HMMA_SIMT_B {
    // frag : destination fragments; only frag[k] is written
    // data : quantized source registers; data[k] is converted with ConvertKvCache
    // stat : quantization statistics; slice k / div is used
    template
    __device__ static void
    apply(Array (&frag)[K][Mf], int k, Array (&data)[K][Md], Array (&stat)[Ks][Ms], int div)
    {
        static_assert(Nf * Mf == Nd * Md);
        static_assert(Nd % Nf == 0 && Mf % Md == 0);

        auto& frag_k = reinterpret_cast(&)[Md]>(frag[k]);
        auto& stat_k = reinterpret_cast(&)[Ns * Ms]>(stat[k / div]);
        auto& data_k = data[k];

        // static_assert(Nf != Nf);

        PRAGMA_UNROLL
        for (int m = 0; m < Md; ++m) {
            auto tmp = ConvertKvCache::convert(data_k[m]);
            PRAGMA_UNROLL
            for (int i = 0; i < Nd; i += 2) {
                dequant((Array&)tmp[i], stat_k[(m * Nd + i) / Nf]);
            }
            frag_k[m] = tmp;
        }
    }

    // Scale-and-offset: x = x * s[0] + s[1], with `s` reinterpreted as a pair of values of x's element type
    template
    __device__ static void dequant(Array& x, Array s)
    {
        Array& _s = (Array&)s;

        x[0] = __hfma(x[0], _s[0], _s[1]);
        x[1] = __hfma(x[1], _s[0], _s[1]);
    }

    // half power-of-two scale: s[0] is shifted into the half exponent field (bits 10..14)
    __device__ static void dequant(Array& x, Array s)
    {
        // half_t s1 = __ushort_as_half(((uint16_t)s[0] + 15 - 127) << 10);
        // Adjusted in `AdjustUe8m0ScaleForHalf`
        half_t s1 = __ushort_as_half((uint16_t)s[0] << 10);
        x[0]      = __hmul(x[0], s1);
        x[1]      = __hmul(x[1], s1);
    }

    // Plain half scale: the bits of s[0] are reinterpreted as a half multiplier
    __device__ static void dequant(Array& x, Array s)
    {
        auto s1 = __ushort_as_half(s[0]);
        x[0]    = __hmul(x[0], s1);
        x[1]    = __hmul(x[1], s1);
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/cache_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/tuner/cache_utils.h"

namespace turbomind::gemm {

CacheFlushing::CacheFlushing()
{
    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, 0);

    size_ = props.l2CacheSize;

    cudaMalloc(&buffer_, size_);
}

// Enqueues an L2 flush on `stream` using this host thread's scratch buffer.
void CacheFlushing::flush(cudaStream_t stream)
{
    // One instance per host thread, constructed (and its buffer allocated) on the thread's first call
    thread_local CacheFlushing flusher;
    flusher(stream);
}

// Overwrites the whole L2-sized scratch buffer asynchronously on `stream`.
void CacheFlushing::operator()(cudaStream_t stream) const
{
    const int fill_byte = 0;
    cudaMemsetAsync(buffer_, fill_byte, size_, stream);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/cache_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

namespace turbomind::gemm {

// Approximates an L2 cache flush by memsetting a device buffer whose size equals the device's L2 capacity.
// Only reachable through the static `flush`, which keeps one instance per host thread.
class CacheFlushing {
public:
    // Enqueues the flush on `stream` (a value-initialized stream handle by default)
    static void flush(cudaStream_t stream = {});

private:
    // Allocates the scratch buffer
    CacheFlushing();
    // Issues the memset on `stream`
    void operator()(cudaStream_t stream) const;

    uint32_t* buffer_;  // device scratch buffer
    size_t    size_;    // buffer size in bytes (the L2 cache size)
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/measurer.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/tuner/cache_utils.h"
#include "src/turbomind/kernels/gemm/tuner/measurer.h"
#include 

namespace turbomind::gemm {

// Takes ownership of `stop_criterion`, which MeasureOne consults after every sample, and creates the two
// CUDA events used to time each run.
Measurer::Measurer(std::unique_ptr stop_criterion): stop_criterion_{std::move(stop_criterion)}
{
    cudaEventCreate(&ev_beg_);
    cudaEventCreate(&ev_end_);
}

// Destroys both timing events and clears the handles.
Measurer::~Measurer()
{
    cudaEventDestroy(ev_beg_);
    ev_beg_ = {};
    cudaEventDestroy(ev_end_);
    ev_end_ = {};
}

std::vector
Measurer::Measure(const std::vector& specs, const Launcher& launcher, cudaStream_t stream)
{
    std::vector m;
    m.reserve(specs.size());
    for (const auto& spec : specs) {
        auto measure = MeasureOne(spec, launcher, stream);
        if (measure.sample_count) {
            m.push_back(measure);
        }
        /// TODO: report error
    }
    return m;
}

// Repeatedly times `spec` from a cold cache until a run fails or the stopping criterion is met.
// Returns the last status together with the sample count, mean and variance.
Measurement Measurer::MeasureOne(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream)
{
    Stats       samples{};
    cudaError_t last_status = cudaSuccess;

    for (;;) {
        const auto [elapsed_ms, run_status] = ColdRun(spec, launcher, stream);
        last_status                         = run_status;
        if (last_status != cudaSuccess) {
            break;
        }
        samples.add_sample(elapsed_ms);
        if (stop_criterion_->should_stop(samples)) {
            break;
        }
    }

    return Measurement{last_status, samples.count(), samples.mean(), samples.get_variance()};
}

std::pair Measurer::ColdRun(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream)
{
    CacheFlushing::flush(stream);

    cudaEventRecord(ev_beg_, stream);

    // std::cout << spec.kernel->name() << " " << spec.splits << " " << spec.swizzle << std::endl;

    launcher(spec, stream);

    cudaEventRecord(ev_end_, stream);
    cudaEventSynchronize(ev_end_);

    const auto status = cudaGetLastError();
    float      ms{};

    if (status == cudaSuccess) {
        cudaEventElapsedTime(&ms, ev_beg_, ev_end_);
    }
    else {
        TM_CHECK(status == cudaSuccess) << cudaGetErrorString(status);
    }

    return {ms, status};
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/measurer.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/tuner/stopping_criterion.h"
#include 
#include 
#include 
#include 

namespace turbomind::gemm {

// Result of timing one launch configuration (produced by Measurer::MeasureOne).
struct Measurement {
    cudaError_t status;        // last CUDA status observed while sampling
    int         sample_count;  // number of successful timed runs
    float       mean;          // mean elapsed time in milliseconds (from cudaEventElapsedTime)
    float       variance;      // variance of the samples (NaN with fewer than two samples, per Stats)
};

// Launches the kernel described by a LaunchSpec on the given stream (called as `launcher(spec, stream)`).
using Launcher = std::function;

// Times kernel launch configurations with CUDA events, flushing L2 before every run.
// A pluggable stopping criterion decides how many samples to take.
class Measurer {
public:
    Measurer(std::unique_ptr stop_criterion);

    ~Measurer();

    // Measures each spec; specs whose measurement has no samples are omitted from the result
    std::vector
    Measure(const std::vector& specs, const Launcher& launcher, cudaStream_t stream);

private:
    // Samples one spec until failure or until stop_criterion_ says to stop
    Measurement MeasureOne(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream);

    // One cold-cache timed run: (elapsed ms, CUDA status)
    std::pair ColdRun(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream);

private:
    cudaEvent_t                        ev_beg_;  // recorded before the launch
    cudaEvent_t                        ev_end_;  // recorded after the launch
    std::unique_ptr stop_criterion_;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/params.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/tuner/params.h"
#include "src/turbomind/utils/parser.h"
#include 
#include 
#include 

namespace turbomind::gemm {

// Overrides fields of `params` from a key/value argument string (see params.h for an example).
// Each recognized key is echoed to stdout before being parsed. Unknown keys are ignored, and `seq` is
// expanded by ParseTuningSequence.
void ParseTuningParams(TuningParams& params, const std::string& str)
{
    const auto args = ParseArgsList(str);

    // Locate the argument named `key`; yields args.end() when it is absent
    auto find_arg = [&args](const auto& key) {
        return std::find_if(args.begin(), args.end(), [&key](const auto& kv) { return kv.first == key; });
    };

    auto parse_if_present = [&](auto& field, const auto& key) {
        const auto pos = find_arg(key);
        if (pos == args.end()) {
            return;
        }
        std::cout << key << " " << pos->second << "\n";
        Parse(field, pos->second);
    };

    parse_if_present(params.max_splits, "max_splits");
    parse_if_present(params.max_waves, "max_waves");
    parse_if_present(params.swizzle, "swizzle");
    parse_if_present(params.top_k, "top_k");
    parse_if_present(params.clusters, "clusters");
    parse_if_present(params.min_iter, "min_iter");
    parse_if_present(params.max_iter, "max_iter");
    parse_if_present(params.max_time, "max_time");

    if (const auto pos = find_arg("seq"); pos != args.end()) {
        params.seq = ParseTuningSequence(pos->second);
    }
}

std::vector ParseTuningSequence(const std::string& str)
{
    const std::regex triplet(R"((\d+)-(\d+)-(\d+))");

    std::vector> generators;

    const auto tokens = ParseListOrTuple(str);

    for (const auto& token : tokens) {
        std::smatch match;
        if (std::regex_match(token, match, triplet)) {
            generators.push_back({std::stoi(match[1].str()),  //
                                  std::stoi(match[2].str()),
                                  std::stoi(match[3].str())});
        }
        else {  // must be an integer string
            generators.push_back({std::stoi(token), 0, 0});
        }
    }

    if (generators.size() == 1) {  // Replace sentinel of the default generators
        auto fallback   = GetDefaultTuningGenerators();
        fallback.back() = {generators.front().front(), 0, 0};
        generators      = std::move(fallback);
    }

    return GenerateTuningSequence(generators);
}

// Expands tuning generators into an increasing list of sizes.
//
// Each entry is {curr, next, step}; the last entry is a sentinel `(max_bs, 0, 0)` whose first element (`last`)
// is the upper bound and is always appended. For every other entry:
//   - step <= 0: emit `curr` as a single value (covers `{v, 0, 0}` tokens);
//   - otherwise: emit curr, curr + step, ... while below the next entry's start (capped at `last`); whenever
//     `curr` reaches `next`, both `step` and `next` double, giving a geometric tail.
// An entry starting at or beyond `last` ends the expansion. An empty input yields an empty list.
std::vector<int> GenerateTuningSequence(const std::vector<std::array<int, 3>>& generators)
{
    std::vector<int> ret;
    if (generators.empty()) {
        return ret;
    }
    const int last = generators.back().front();
    for (int i = 0; i < (int)generators.size() - 1; ++i) {
        auto [curr, next, step] = generators[i];
        if (curr >= last) {
            break;
        }
        // A non-positive step can never advance `curr`, e.g. `16-32-0`, which the parser accepts.
        // Previously only `next == 0 && step == 0` was a single value, and such inputs looped forever.
        if (step <= 0) {
            ret.push_back(curr);
        }
        else {  // generator
            const int end = std::min(generators[i + 1][0], last);
            while (curr < end) {
                ret.push_back(curr);
                if (curr == next) {
                    step *= 2;
                    next *= 2;
                }
                curr += step;
            }
        }
    }
    ret.push_back(last);
    return ret;
}

// Default generators: step 8 up to 16, then step 16 doubling geometrically from 64, up to the 65536 sentinel.
std::vector<std::array<int, 3>> GetDefaultTuningGenerators()
{
    /// TODO: set generators based on device
    std::vector<std::array<int, 3>> generators;
    generators.push_back({8, 16, 8});
    generators.push_back({16, 64, 16});
    generators.push_back({65536, 0, 0});  // sentinel: upper bound of the sequence
    return generators;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/params.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

namespace turbomind::gemm {

// Knobs for GEMM kernel tuning.
// Every field can be overridden by ParseTuningParams using a key equal to the field name; `seq` is expanded
// by ParseTuningSequence.
struct TuningParams {
    // Split-k params
    int max_splits = 8;
    int max_waves  = 10;

    // Swizzling params
    std::vector swizzle{0, 3};

    // Sampling params for hierarchical kernel selection
    // NOTE(review): presumably `clusters` feeds Sampler's k_clusters and max_time feeds the stopping
    // criterion's max_ms — confirm at the call site
    float top_k    = 0;
    int   clusters = 5;
    int   min_iter = 1;
    int   max_iter = 10;
    float max_time = 1.f;

    // Sizes to tune for, as produced by GenerateTuningSequence (its sentinel is named `max_bs`)
    std::vector seq;
};

// Overrides fields of `params` from a comma-separated key=value string.
// Recognized keys: max_splits, max_waves, swizzle, top_k, clusters, min_iter, max_iter, max_time, seq.
// Other keys are ignored; e.g. `top_splits` in the example below has no effect.
// example
//   max_splits=8,top_splits=5,max_waves=16,top_k=10,swizzle=[2,3,4],clusters=5,max_iter=10,min_iter=1,max_time=10.0
void ParseTuningParams(TuningParams& params, const std::string& str);

// Tokens are `start-next-step` generators or single integers; the last token is the upper bound.
// A single integer alone replaces the upper bound of the default generators.
// example
//   16-16-128,256-128-1024,8192
std::vector ParseTuningSequence(const std::string& str);

// Expands generators into the sequence; the last generator is the sentinel `(max_bs, 0, 0)`.
std::vector GenerateTuningSequence(const std::vector>& generators);

// Generators used when only an upper bound is given.
std::vector> GetDefaultTuningGenerators();

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/sampler.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/kernel.h"
#include "src/turbomind/kernels/gemm/tuner/sampler.h"
#include 
#include 
#include 
#include 

namespace turbomind::gemm {

// Returns the indices 0..size-1 ordered by `cmp(i, j)`.
// The sort is stable, so ties keep ascending index order.
template
static std::vector ArgSort(size_t size, const Cmp& cmp)
{
    std::vector idxs(size);
    std::iota(idxs.begin(), idxs.end(), 0);
    std::stable_sort(idxs.begin(), idxs.end(), cmp);
    return idxs;
}

// Measures launch specs and returns them sorted by measured mean time (fastest first), with `measured` set.
// With k_clusters_ != 0 the search is hierarchical:
//   1. measure one leader per cluster;
//   2. keep the k_clusters_ fastest clusters;
//   3. measure their remaining members.
// Specs in discarded clusters are not measured and are absent from the result.
// With k_clusters_ == 0 every spec is measured.
std::vector Sampler::Run(std::vector specs, const Launcher& launcher, cudaStream_t stream)
{
    // One entry per cluster; element 0 is the cluster leader.
    // NOTE(review): entries appear to be LaunchSpec copies, not pointers into `specs`
    // (clusters.push_back({s}) and s_1[i].measured) — confirm
    std::vector> clusters;
    if (k_clusters_) {
        clusters = Cluster(specs, ClusteringParam{true, true});
    }
    else {
        for (auto& s : specs) {
            clusters.push_back({s});
        }
    }
    // std::cout << "k_clusters=" << k_clusters_ << ", #specs" << specs.size() << ", #clusters" << clusters.size() <<
    // "\n";

    // First pass: cluster leaders
    std::vector s_1;
    for (const auto& c : clusters) {
        s_1.push_back(c.front());
    }

    auto m_1 = measurer_.Measure(s_1, launcher, stream);

    // NOTE(review): Measurer::Measure omits specs whose measurement has no samples. Indices into m_1 then only
    // line up with `clusters` / `s_1` when nothing was dropped. Today ColdRun fails a TM_CHECK on error, so
    // this is presumably unreachable — confirm before relying on it.
    auto idxs = ArgSort(m_1.size(), [&](int i, int j) { return m_1[i].mean < m_1[j].mean; });

    if (k_clusters_) {
        // Keep the k_clusters_ fastest clusters
        const auto top_k = std::min(k_clusters_, (int)idxs.size());
        idxs.resize(top_k);

        // Second pass: the non-leader members of the selected clusters
        std::vector s_2;
        for (const auto& idx : idxs) {
            auto& cluster = clusters[idx];
            // Skip cluster leader
            for (size_t j = 1; j < cluster.size(); ++j) {
                s_2.push_back(cluster[j]);
            }
        }

        // std::cout << "#s_2=" << s_2.size() << "\n";

        auto m_2 = measurer_.Measure(s_2, launcher, stream);
        // Merge measurements of the 2 runs
        // (members first, then all leaders; specs and measurements are appended in the same order)
        m_2.insert(m_2.end(), m_1.begin(), m_1.end());
        s_2.insert(s_2.end(), s_1.begin(), s_1.end());
        m_1.swap(m_2);
        s_1.swap(s_2);
    }

    // Final ranking over everything measured
    idxs = ArgSort(m_1.size(), [&](int i, int j) { return m_1[i].mean < m_1[j].mean; });

    std::vector ret;
    for (const auto& i : idxs) {
        s_1[i].measured = m_1[i].mean;
        ret.push_back(s_1[i]);
    }

    return ret;
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/sampler.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/desc.h"
#include "src/turbomind/kernels/gemm/tuner/measurer.h"

#include 

namespace turbomind::gemm {

// Ranks candidate launch specs by timing them with a `Measurer`.
// `k_clusters` enables the two-stage clustered search in `Run`; 0 measures every spec.
class Sampler {
public:
    explicit Sampler(Measurer& measurer, int k_clusters): measurer_{measurer}, k_clusters_{k_clusters} {}

    // Measure `specs` on `stream` through `launcher`; returns the measured specs sorted by mean time.
    std::vector Run(std::vector specs, const Launcher& launcher, cudaStream_t stream);

private:
    Measurer& measurer_;  // held by reference: the Measurer must outlive this Sampler
    int       k_clusters_;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/stats.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

namespace turbomind::gemm {

/// Online accumulator of sample count, mean and variance (Welford's update).
///
/// NOTE(review): stats.h has no `#pragma once` / include guard, but stopping_criterion.h
/// includes it; a translation unit that includes both headers would define this class twice.
class Stats {
public:
    Stats() = default;

    /// Running mean of all samples (0 before any sample).
    float mean() const noexcept { return mean_; }

    /// Sum of all samples, reconstructed as mean * count.
    float sum() const noexcept { return mean_ * count_; }

    /// Number of samples added so far.
    int count() const noexcept { return count_; }

    /// Population variance (M2 / n); NaN until at least two samples are present.
    float get_variance() const noexcept
    {
        if (count_ < 2) {
            return std::numeric_limits<float>::quiet_NaN();
        }
        return m2_ / count_;
    }

    /// Fold one sample into the running statistics.
    void add_sample(float x) noexcept
    {
        const float dev_before = x - mean_;  // deviation from the old mean
        count_ += 1;
        mean_ += dev_before / count_;
        const float dev_after = x - mean_;  // deviation from the updated mean
        m2_ += dev_before * dev_after;
    }

private:
    int   count_{0};
    float mean_{0.f};
    float m2_{0.f};  // sum of squared deviations from the mean
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/stopping_criterion.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/tuner/stopping_criterion.h"
#include 

namespace turbomind::gemm {

namespace stopping_criterions {

// Stop once at least `min_iter` samples were taken AND either the sample count reached
// `max_iter` or the accumulated time (Stats::sum) reached `max_ms`.
// `min_iter` is clamped to >= 1; a non-positive `max_iter` / `max_ms` disables that limit.
class Optimistic: public StoppingCriterion {
public:
    Optimistic(int min_iter, int max_iter, float max_ms):
        min_iter_{std::max(min_iter, 1)},
        max_iter_{max_iter > 0 ? max_iter : std::numeric_limits<int>::max()},
        max_ms_{max_ms > 0 ? max_ms : std::numeric_limits<float>::infinity()}
    {
    }

    bool should_stop(const Stats& stats) override
    {
        const int n = stats.count();
        if (n < min_iter_) {
            return false;  // never stop before the minimum number of samples
        }
        const bool iteration_cap_hit = n >= max_iter_;
        const bool time_budget_hit   = stats.sum() >= max_ms_;
        return iteration_cap_hit || time_budget_hit;
    }

private:
    int   min_iter_;
    int   max_iter_;
    float max_ms_;
};

}  // namespace stopping_criterions

// Factory for the tuner's stopping criterion. Presumably constructs
// `stopping_criterions::Optimistic`, the only criterion defined in this file (TODO confirm).
// Under Optimistic's rules `min_iter` is clamped to >= 1, and a non-positive `max_iter` or
// `max_ms` disables that limit.
std::unique_ptr CreateStoppingCriterion(int min_iter, int max_iter, float max_ms)
{
    return std::make_unique(min_iter, max_iter, max_ms);
}

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/tuner/stopping_criterion.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/gemm/tuner/stats.h"
#include 

namespace turbomind::gemm {

// Interface deciding when enough timing samples have been collected.
// NOTE(review): neither this header nor stats.h carries `#pragma once`; including both (or
// this header twice) in one translation unit would redefine these types.
class StoppingCriterion {
public:
    virtual ~StoppingCriterion()                 = default;
    // Return true once sampling should stop, given the statistics gathered so far.
    virtual bool should_stop(const Stats& stats) = 0;
};

// Create the default stopping criterion (defined in stopping_criterion.cc).
std::unique_ptr CreateStoppingCriterion(int min_iter, int max_iter, float max_ms);

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/types.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/core/data_type.h"
#include 
#include 

#if ENABLE_BF16
#include 
#endif

namespace turbomind::gemm {

// Storage order of a 2-D matrix.
enum class Order : int
{
    kColMajor = 0,
    kRowMajor = 1,
};

// Namespace-scope shorthands.
inline constexpr Order kColMajor = Order::kColMajor;
inline constexpr Order kRowMajor = Order::kRowMajor;

// Flip between column-major and row-major.
constexpr Order operator~(Order a)
{
    if (a == kColMajor) {
        return kRowMajor;
    }
    return kColMajor;
}

// Short name used in logs: "Col", "Row", or "" for an out-of-range value.
constexpr const char* to_string(Order order)
{
    return order == kColMajor ? "Col" : (order == kRowMajor ? "Row" : "");
}

// A Pack packs three fields into its low 12 bits: 0xMON ->
//   M (bits 8..11): MMA_Tag, O (bits 4..7): Op_Tag, N (bits 0..3): pack count.
using Pack = uint32_t;

enum MMA_Tag
{
    HMMA_16816 = 0x100,  // sm80+
    HMMA_1688  = 0x200,  // sm75
    HMMA_884   = 0x300,  // sm70
    HMMA_SIMT  = 0x400,  // sm75-
};

enum Op_Tag
{
    OPERAND_A = 0x010,
    OPERAND_B = 0x020,
    OPERAND_U = 0x030,
    OPERAND_V = 0x040,
    OPERAND_C = 0x050,
    OPERAND_D = 0x060,
};

// Extract the MMA field (bits 8..11).
constexpr MMA_Tag get_mma_tag(Pack pack)
{
    constexpr Pack kMmaBits = 0xf00;
    return static_cast<MMA_Tag>(pack & kMmaBits);
}

// Extract the operand field (bits 4..7).
constexpr Op_Tag get_operand_tag(Pack pack)
{
    constexpr Pack kOperandBits = 0x0f0;
    return static_cast<Op_Tag>(pack & kOperandBits);
}

// Extract the pack count (bits 0..3).
constexpr int get_pack_num(Pack pack)
{
    constexpr Pack kNumBits = 0x00f;
    return pack & kNumBits;
}

// How consecutive sub-matrices of a batched operand are laid out.
enum class Striding : int
{
    kFlat,     // [1111,2222,3333]
    kRagged,   // [11,2222222,333]  [0 , 2      , 9  ]
    kIndexed,  // [xx xxxxxxx xxx], [01, 2345678, 9ab]
    kBlocked,  // [11][22222][333]
};

// One-letter tag for kernel names / logs; "unknown" for out-of-range values.
inline const char* to_string(Striding striding)
{
    if (striding == Striding::kFlat) {
        return "f";
    }
    if (striding == Striding::kRagged) {
        return "r";
    }
    if (striding == Striding::kIndexed) {
        return "i";
    }
    if (striding == Striding::kBlocked) {
        return "b";
    }
    return "unknown";
}

// Quantization granularity of an operand.
enum class QuantType : int
{
    kNone    = 0,
    kK       = 1,
    kM       = 2,
    kB       = 3,
    kDefault = kK,
};

// One-letter tag for a QuantType; "unknown" for out-of-range values.
inline const char* to_string(QuantType q)
{
    if (q == QuantType::kNone) {
        return "none";
    }
    if (q == QuantType::kK) {
        return "k";
    }
    if (q == QuantType::kM) {
        return "m";
    }
    if (q == QuantType::kB) {
        return "b";
    }
    return "unknown";
}

enum class Epilogue : int
{
    kNone               = 0,
    kChannelCombination = 0x1,
    kGatedSilu          = 0x2,
};

// Quantization type plus group size.
struct QuantDesc {
    QuantType type;
    int       group_size;

    // True when either field is set (a non-kNone type or a non-zero group size).
    operator bool() const noexcept
    {
        return type != QuantType::kNone || group_size != 0;
    }
};

// Type tag, followed by the group size when the descriptor is set (e.g. "k128"; "none" when empty).
inline std::string to_string(QuantDesc desc)
{
    std::string name = to_string(desc.type);
    if (desc) {
        name += std::to_string(desc.group_size);
    }
    return name;
}

// Dispatch behaviour; kAppend (3) combines the kMeasure (1) and kReuse (2) bits.
enum class DispatchPolicy : int
{
    kDefault = 0,
    kMeasure = 1,
    kReuse   = 2,
    kAppend  = 3,
};

// True when the two policies share at least one bit.
constexpr bool operator&(const DispatchPolicy& a, const DispatchPolicy& b)
{
    const int shared_bits = static_cast<int>(a) & static_cast<int>(b);
    return shared_bits != 0;
}

class Kernel;
class Context;

// Buffers describing a precomputed work schedule ("tape").
// NOTE(review): field semantics are not established in this file. The names suggest per-GEMM
// shapes, tiling, per-tile offsets, K-iteration ranges and tile ids in a shared `buffer`,
// bounded by `max_num` / `max_ctas`; confirm against the code that fills it.
struct Tape {
    int   ctas;
    int   max_num;
    int   max_ctas;
    char* buffer;
    int4* gemm_shapes;
    int4* tiled_shapes;
    int4* tile_offsets;
    int2* iter_k_ranges;
    int*  tile_ids;
};

struct Operation {
    DispatchPolicy dispatch;
    Epilogue       epilogue;
    QuantDesc      quant_a;
    QuantDesc      quant_b;
    int            batch_dim;
    // void*          reserved;
};

// Describe the same GEMM with the A and B roles exchanged: the quantization descriptors swap
// and `batch_dim` flips between 0 and 1. Dispatch policy and epilogue carry over unchanged.
inline Operation transpose(Operation o)
{
    Operation t = o;
    t.quant_a   = o.quant_b;
    t.quant_b   = o.quant_a;
    t.batch_dim = 1 - o.batch_dim;
    return t;
}

// Description of one (possibly batched) matrix operand.
struct MatrixLayout {
    DataType type;     // element data type
    Order    order;    // column- or row-major
    int      rows;     // logical extents; utils.h `transpose` swaps rows/cols and flips `order`
    int      cols;
    int      ld;       // leading dimension; get_mode treats ld == 0 as Striding::kBlocked
    Pack     pack;     // packing descriptor (see get_mma_tag / get_operand_tag / get_pack_num)
    int      num;      // NOTE(review): presumably the number of sub-matrices in a batch; confirm
    int*     offsets;  // non-null -> get_mode reports Striding::kBlocked
    int*     idxs;     // non-null -> get_mode reports Striding::kIndexed
};

// Debug print: "<type> <order> <rows> <cols> <num> <ld>".
inline std::ostream& operator<<(std::ostream& os, const MatrixLayout& x)
{
    return os << x.type << " " << to_string(x.order) << " " << x.rows << " " << x.cols << " " << x.num << " "
              << x.ld;
}

// Bytes occupied by rows * cols elements of `m.type` (ld, num and padding are not considered).
inline int64_t byte_size(const MatrixLayout& m)
{
    const int64_t element_count = static_cast<int64_t>(m.rows) * m.cols;
    return byte_size(m.type, element_count);
}

// Classify the layout: indices take precedence, then blocked (ld == 0 or offsets), else flat.
inline Striding get_mode(const MatrixLayout& m)
{
    if (m.idxs) {
        return Striding::kIndexed;
    }
    const bool blocked = m.ld == 0 || m.offsets != nullptr;
    return blocked ? Striding::kBlocked : Striding::kFlat;
}

// Scratch memory handed to GEMM kernels: each region is a pointer plus its size in bytes
// (presumably bytes; TODO confirm), plus a `flags` array.
struct Workspace {
    void*  barriers;
    size_t barriers_size;
    void*  partials;
    size_t partials_size;
    void*  tensormaps;
    size_t tensormaps_size;
    int*   flags;
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gemm/unpack.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/data_type.h"
#include 

namespace turbomind {

namespace {

// Atomically overwrite nibble `index` (0..7, least-significant first) of the 32-bit word at
// `address` with the low 4 bits of `value`, leaving the other seven nibbles intact.
// An atomicCAS retry loop is used because other threads (see permute_u4) may be writing
// neighbouring nibbles of the same word concurrently.
__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
    const uint32_t shift = index * 4u;
    const uint32_t mask  = 0xfu << shift;
    // Mask `value` so an out-of-range input cannot clobber neighbouring nibbles.
    const uint32_t bits = (value & 0xfu) << shift;

    uint32_t old = *address;
    uint32_t assumed;
    do {
        assumed = old;
        old     = atomicCAS(address, assumed, (assumed & ~mask) | bits);
    } while (assumed != old);
}

// Read nibble `index` (0..7, least-significant first) of the 32-bit word at `address`.
__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
    return (*address >> (index * 4u)) & 0xfu;
}

// Permute the axes of an N-D tensor of packed 4-bit values (8 per 32-bit word).
// `dims` are the source extents, last axis fastest-varying. Destination axis j is source axis
// order[j] (order = Ds...), so the destination is row-major with extents
// (dims[order[0]], ..., dims[order[N-1]]).
// Each thread walks the flat source index with a grid-stride loop and writes its nibble via
// atomic_assign_u4, because several threads may target the same destination word.
// NOTE(review): `i`, `ii` and `index` are int while `count` is size_t; tensors with more than
// INT_MAX nibbles would overflow.
template
__global__ void permute_u4(uint* dst, const uint* src, Array dims)
{
    constexpr int N = sizeof...(Ds);

    // Total number of nibbles.
    size_t count = 1;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        count *= dims[i];
    }

    constexpr int order[] = {Ds...};

    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {

        int indices[N]{};

        // Decompose the flat source index into per-axis coordinates.
        PRAGMA_UNROLL
        for (int j = N - 1, ii = i; j >= 0; --j) {
            indices[j] = ii % dims[j];
            ii /= dims[j];
        }

        auto data = read_u4(src + i / 8, i % 8);

        int index = 0;

        // Re-linearize the coordinates in the permuted axis order.
        PRAGMA_UNROLL
        for (int j = N - 1, stride = 1; j >= 0; --j) {
            index += indices[order[j]] * stride;
            stride *= dims[order[j]];
        }

        atomic_assign_u4(dst + index / 8, index % 8, data);
    }
}

}  // namespace

// col-major interleaved
// Reorder the nibbles inside every 32-bit word. The source is viewed as [cols][rows / 8][2][4]
// and the last two axes are swapped, so the source nibble at position 4*a + b moves to
// position 2*b + a (word order [0,4,1,5,2,6,3,7] of the source nibbles).
// NOTE(review): `rows / 8` truncates; assumes rows % 8 == 0 (not checked here). Launch size is
// fixed at 512x512 threads; permute_u4's grid-stride loop covers larger tensors.
void unpack_awq_gemm(uint4_t* dst, const uint4_t* src, int rows, int cols, cudaStream_t st)
{
    Array shape{cols, rows / 8, 2, 4};
    permute_u4<0, 1, 3, 2><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape);
}

// Transpose an s x c row-major matrix of packed 4-bit values into a c x s row-major matrix.
// Each thread handles one 8x8 tile of nibbles: it loads 8 words (one per source row, 8 nibbles
// each), transposes the nibbles in registers and stores 8 words.
// Requires s % 8 == 0 and c % 8 == 0 (validated by the host launcher `transpose_u4`).
__global__ void transpose_u4_kernel(uint4_t* dst, const uint4_t* src, int s, int c)
{
    const int idx_c = 8 * (threadIdx.x + blockIdx.x * blockDim.x);
    const int idx_s = 8 * (threadIdx.y + blockIdx.y * blockDim.y);
    if (idx_c >= c || idx_s >= s) {
        return;
    }

    // Word offsets are computed in 64-bit: s * c can exceed INT_MAX for large matrices.
    const int64_t s64 = s;
    const int64_t c64 = c;

    const uint32_t* src_words = (const uint32_t*)src;
    uint32_t*       dst_words = (uint32_t*)dst;

    // ivec[i] holds source row idx_s + i, columns idx_c .. idx_c + 7 (nibble k = column idx_c + k).
    uint32_t ivec[8];
    PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
        ivec[i] = src_words[((idx_s + i) * c64 + idx_c) / 8];
    }

    // ovec[i] gathers nibble i of every ivec[j] into nibble j -> destination row idx_c + i.
    uint32_t ovec[8]{};
    PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
        PRAGMA_UNROLL
        for (int j = 0; j < 8; ++j) {
            ovec[i] |= (((ivec[j] >> (i * 4)) & 0xfu) << (j * 4));
        }
    }

    PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
        dst_words[((idx_c + i) * s64 + idx_s) / 8] = ovec[i];
    }
}

// Host launcher for `transpose_u4_kernel`: transposes an s x c matrix of packed 4-bit values
// from `src` into `dst` on stream `st`.
// Both dimensions must be multiples of 8; otherwise an error is printed and nothing is launched.
void transpose_u4(uint4_t* dst, const uint4_t* src, int s, int c, cudaStream_t st)
{
    if (s % 8 || c % 8) {
        std::cerr << "transpose_u4: invalid shape (" << s << "," << c << "), must be multiple of 8" << std::endl;
        return;
    }
    if (s == 0 || c == 0) {
        return;  // nothing to transpose; a zero-sized grid would be an invalid launch configuration
    }

    // Each kernel thread handles an 8x8 tile of nibbles, so only (c / 8) x (s / 8) threads are
    // needed. Sizing the grid by `c` and `s` directly launched ~64x more blocks than necessary,
    // and the extra blocks did nothing except exit on the kernel's bounds check.
    constexpr int kTile    = 8;
    constexpr int kThreads = 16;

    const dim3 block(kThreads, kThreads);
    const dim3 grid((c / kTile + kThreads - 1) / kThreads, (s / kTile + kThreads - 1) / kThreads);
    transpose_u4_kernel<<<grid, block, 0, st>>>(dst, src, s, c);
}

// load -> unpack -> extend_to_u8 -> manipulation -> compat_to_u4 -> store
// load -> extend_to_u16 -> convert -> run

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gemm/utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/simt.h"
#include "src/turbomind/kernels/gemm/types.h"

namespace turbomind::gemm {

// Swap between column-major and row-major.
__host__ __device__ constexpr Order transpose(Order order)
{
    if (order == Order::kColMajor) {
        return Order::kRowMajor;
    }
    return Order::kColMajor;
}

// Swap the logical extents and flip the storage order; every other field
// (type, ld, pack, num, offsets, idxs) is carried over unchanged.
__host__ __device__ constexpr MatrixLayout transpose(MatrixLayout x)
{
    const int old_rows = x.rows;  // `std::swap` is not constexpr (before C++20)
    x.rows             = x.cols;
    x.cols             = old_rows;
    x.order            = transpose(x.order);
    return x;
}

// Coordinate conversions between logical (m, k) / (k, n) pairs and (c, s) pairs, where
// c is the unit-stride coordinate and s is the coordinate strided by the leading dimension
// (cs2idx below computes `ld * cs.y + cs.x`).

// (m, k) -> (c, s): row-major makes k contiguous, column-major makes m contiguous.
template
__host__ __device__ constexpr int2 mk2cs(int m, int k)
{
    if constexpr (order == Order::kRowMajor) {
        return {k, m};
    }
    else {
        return {m, k};
    }
}

// int2 convenience overload of mk2cs.
template
__host__ __device__ constexpr int2 mk2cs(int2 mk)
{
    return mk2cs(mk.x, mk.y);
}

// Inverse of mk2cs: (c, s) -> (m, k).
template
__host__ __device__ constexpr int2 cs2mk(int c, int s)
{
    if constexpr (order == Order::kRowMajor) {
        return {s, c};
    }
    else {
        return {c, s};
    }
}

// int2 convenience overload of cs2mk.
template
__host__ __device__ constexpr int2 cs2mk(int2 cs)
{
    return cs2mk(cs.x, cs.y);
}

// (k, n) -> (c, s): column-major makes k contiguous, row-major makes n contiguous.
template
__host__ __device__ constexpr int2 _kn2cs(int k, int n)
{
    if constexpr (order == Order::kColMajor) {
        return {k, n};
    }
    else {
        return {n, k};
    }
}

// Linear offset of (c, s) with leading dimension `ld`: ld * s + c, computed in type `Index`.
template
__host__ __device__ constexpr Index cs2idx(int2 cs, Index ld)
{
    return ld * cs.y + cs.x;
}

// Same as above, with the strided coordinate shifted by `s0`.
template
__host__ __device__ constexpr Index cs2idx(int2 cs, Index ld, int s0)
{
    return ld * (cs.y + s0) + cs.x;
}

// 2-D dot product of integer pairs (int result).
__host__ __device__ constexpr auto dot(int2 a, int2 b)
{
    const auto along_x = a.x * b.x;
    const auto along_y = a.y * b.y;
    return along_x + along_y;
}

// 2-D dot product with long coefficients (int * long promotes, so the result is long).
__host__ __device__ constexpr auto dot(int2 a, long2 b)
{
    const auto along_x = a.x * b.x;
    const auto along_y = a.y * b.y;
    return along_x + along_y;
}

// Primary template: no packing, extents are returned unchanged.
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        return mk;
    }
};

template
struct Packing_v2: PackingImpl {
};

/// TODO: move packing utility to arch/smem_copy_xxx

// The specializations below rescale an (x, y) extent pair for a packed layout: x is divided and
// y multiplied by the same factor (or the reverse), so x * y is preserved when the division is exact.

// x / (16 * num), y * (16 * num).
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        return {mk.x / 16 / num, mk.y * 16 * num};
    }
};

// Reverse direction: x * 16, y / 16.
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        return {mk.x * 16, mk.y / 16};
    }
};

// Reuses another specialization's mapping.
template
struct PackingImpl: PackingImpl {
};

// SIMT: factor simt::OP_M * num.
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        return {mk.x / (simt::OP_M * num), mk.y * simt::OP_M * num};
    }
};

// SIMT: factor simt::OP_N * num.
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        return {mk.x / (simt::OP_N * num), mk.y * simt::OP_N * num};
    }
};

// Factor 32 * num (the commented line shows an earlier factor of 16 * num).
template
struct PackingImpl {
    __host__ __device__ static constexpr int2 apply(int2 mk)
    {
        // return {mk.x / (16 * num), mk.y * 16 * num};
        return {mk.x / (32 * num), mk.y * 32 * num};
    }
};

}  // namespace turbomind::gemm


================================================
FILE: src/turbomind/kernels/gpt_kernels.cu
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {

// Gather embedding rows: block `ti` copies row ids[ti] of `src` into row `ti` of `dst`,
// `vec_size` elements per vector load/store. `dim` must be a multiple of vec_size (checked by
// the host launcher). `num` is not used in the body.
// NOTE(review): `ti * dst_stride` is an int product (the source offset uses 64-bit `idx`); very
// large outputs could overflow. Token ids are not bounds-checked against the table.
template
__global__ void
embeddingLookupKernel(T* dst, int dst_stride, const T* src, int src_stride, const int* ids, int num, int dim)
{
    const int ti = blockIdx.x;

    const int64_t idx = ids[ti];

    src += idx * src_stride;
    dst += ti * dst_stride;

    for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {
        Array vec;
        Ldg(vec, &src[di]);
        Store(&dst[di], vec);
    }
}

// Gather rows of `embedding_table` selected by `token_ids` into `out` ([num, dim]), one thread
// block per token and up to 1024 threads per block, each moving 16-byte (uint4) vectors.
// Only the 16-bit path (`invoke(uint16_t{})`) is implemented; other element sizes fail with
// "not implemented".
// NOTE(review): only the column counts of `out` and `embedding_table` are compared, not their
// dtypes. dim == 0 would yield a 0-thread launch.
void invokeEmbeddingLookup(Ref         out_,
                           const Buffer_& token_ids,
                           const Tensor&       embedding_table,
                           cudaStream_t        st)
{
    auto& out = out_.get();

    TM_CHECK_EQ(out.shape(0), token_ids.size());
    TM_CHECK_EQ(out.shape(1), embedding_table.shape(1));

    int num, dim;
    std::tie(num, dim) = out.shapes(0, 1);

    auto invoke = [&](auto t) {
        using T                = decltype(t);
        // One uint4 (16 bytes) per vector access.
        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        TM_CHECK(dim % vec_size == 0) << dim << " " << vec_size;
        const int threads = std::min(dim / vec_size, 1024);
        const int blocks  = num;
        TM_CHECK(out_.get());
        TM_CHECK(token_ids);
        TM_CHECK(embedding_table);
        embeddingLookupKernel<<>>((T*)out.raw_data(),
                                                                       out.stride(0),
                                                                       (const T*)embedding_table.raw_data(),
                                                                       embedding_table.stride(0),
                                                                       token_ids.data(),
                                                                       num,
                                                                       dim);
    };

    if (byte_size(out.dtype()) == byte_size()) {
        return invoke(uint16_t{});
    }
    TM_CHECK(0) << "not implemented";
}

// TODO Add half2 implementation
// Swap the first two axes: in [dim0, dim1, dim2] -> out [dim1, dim0, dim2], one thread per element.
// NOTE(review): indices are int, so dim0 * dim1 * dim2 must stay below INT_MAX.
template
__global__ void transposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2)
{
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index < dim0 * dim1 * dim2) {
        // Decompose the flat input index into (dim0, dim1, dim2) coordinates.
        const int input_dim2_index = index % dim2;
        index                      = (index - input_dim2_index) / dim2;
        const int input_dim1_index = index % dim1;
        index                      = (index - input_dim1_index) / dim1;
        const int input_dim0_index = index % dim0;

        out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + input_dim2_index] =
            in[input_dim0_index * dim1 * dim2 + input_dim1_index * dim2 + input_dim2_index];
    }
}

// Launch the 3-D transposeAxis01 with 512 threads per block and enough blocks for one thread per element.
// NOTE(review): dim0 * dim1 * dim2 is evaluated in int before the floating-point ceil.
template
void invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)
{
    dim3 block(512);
    dim3 grid((int)(ceil(dim0 * dim1 * dim2 / 512.)));
    transposeAxis01<<>>(out, in, dim0, dim1, dim2);
}

// Explicit instantiations of the 3-D invokeTransposeAxis01 for the supported element types.
template void
invokeTransposeAxis01(float* out, float* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

template void
invokeTransposeAxis01(half* out, half* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

template void
invokeTransposeAxis01(int* out, int* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

template void
invokeTransposeAxis01(uint16_t* out, uint16_t* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

template void
invokeTransposeAxis01(uint8_t* out, uint8_t* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

// NOTE(review): this uses `#ifdef ENABLE_BF16` while gemm/types.h uses `#if ENABLE_BF16`; the two
// disagree if the macro is defined as 0.
#ifdef ENABLE_BF16
template void invokeTransposeAxis01(
    __nv_bfloat16* out, __nv_bfloat16* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
#endif

// 2-D transpose, in [dim0, dim1] -> out [dim1, dim0], one thread per element.
// When `in_skipping_dim1` is non-null, output row r reads from `in` shifted by
// in_skipping_dim1[r] whole rows (in_skipping_dim1[r] * dim1 elements); nullptr means no shift.
template
__global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1)
{
    // out: [dim1, dim0]
    // in: [dim0, dim1]
    // in_skipping_dim1: [dim1]

    int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index < dim0 * dim1) {
        const int input_dim1_index = index % dim1;
        index                      = (index - input_dim1_index) / dim1;
        const int input_dim0_index = index % dim0;
        const int in_offset        = in_skipping_dim1 == nullptr ? 0 : in_skipping_dim1[input_dim1_index] * dim1;

        out[input_dim1_index * dim0 + input_dim0_index] = in[in_offset + input_dim0_index * dim1 + input_dim1_index];
    }
}

// Launch the 2-D (row-skipping) transposeAxis01 with 512 threads per block, one thread per element.
template
void invokeTransposeAxis01(
    T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream)
{
    dim3 block(512);
    dim3 grid((int)(ceil(dim0 * dim1 / 512.)));
    transposeAxis01<<>>(out, in, in_skipping_dim1, dim0, dim1);
}

// Only int is instantiated for the 2-D variant.
template void invokeTransposeAxis01(
    int* out, int* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);

// Tiled 2-D transpose: src [rows, cols] -> dst [cols, rows].
// Each block stages a TILE_DIM x TILE_DIM tile in shared memory; threads cover the tile in
// BLOCK_ROWS-row steps. The extra shared-memory column (TILE_DIM + 1) is the usual padding that
// avoids bank conflicts on the column-wise reads. `swap_xy` tells the kernel that the launcher
// exchanged grid.x and grid.y (to stay within the grid.y limit).
// NOTE(review): linear offsets are int, so rows * cols must stay below INT_MAX.
template
__global__ void transpose_2d_kernel(T* __restrict__ dst, const T* __restrict__ src, int rows, int cols, bool swap_xy)
{
    __shared__ T smem[TILE_DIM][TILE_DIM + 1];

    const int block_idx_x = swap_xy ? blockIdx.y : blockIdx.x;
    const int block_idx_y = swap_xy ? blockIdx.x : blockIdx.y;

    // Load phase: coalesced reads of a source tile into smem.
    {
        const int j = block_idx_x * TILE_DIM + threadIdx.x;
        const int i = block_idx_y * TILE_DIM + threadIdx.y;

#pragma unroll
        for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {
            if (i + y < rows && j < cols) {
                smem[threadIdx.y + y][threadIdx.x] = src[(i + y) * cols + j];
            }
        }
    }

    __syncthreads();

    // Store phase: read smem transposed and write coalesced rows of dst.
    {
        const int j = block_idx_y * TILE_DIM + threadIdx.x;
        const int i = block_idx_x * TILE_DIM + threadIdx.y;

#pragma unroll
        for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {
            if (i + y < cols && j < rows) {
                dst[(i + y) * rows + j] = smem[threadIdx.x][threadIdx.y + y];
            }
        }
    }
}

// Launch transpose_2d_kernel with 32x8-thread blocks over 32x32 tiles. If the tile-row count
// exceeds the grid.y limit (65535), the grid axes are swapped and the kernel is told via `swap_xy`.
// NOTE(review): after the swap, grid.y holds the former grid.x; if that also exceeds 65535 the
// launch still fails — not handled here.
template
void invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st)
{
    constexpr int TILE_DIM   = 32;  // warp size
    constexpr int BLOCK_ROWS = 8;

    const dim3 block(TILE_DIM, BLOCK_ROWS);

    dim3 grid((cols + TILE_DIM - 1) / TILE_DIM,  //
              (rows + TILE_DIM - 1) / TILE_DIM);
    bool swap_xy = false;

    if (grid.y > 65535) {  // max dim for grid.y
        std::swap(grid.x, grid.y);
        swap_xy = true;
    }

    transpose_2d_kernel<<>>(dst, src, rows, cols, swap_xy);
}

// Only 4-byte elements are instantiated (invokeTranspose2D reinterprets 4-byte types as uint32_t).
template void invokeTranspose2D_(uint32_t*, const uint32_t*, int, int, cudaStream_t);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/gpt_kernels.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 
#include 
#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {

// Argument bundle for invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt.
template
struct inputIdsEmbeddingLookupPosEncodingSoftPromptParam {
    T*           from_tensor;
    int*         output_ids;
    int*         input_lengths;
    const T*     embedding_table;
    const T*     pos_table;
    const float* prefix_soft_prompt_embedding;
    const int*   prefix_soft_prompt_lengths;
    int*         input_ids;
    int          start_step;
    int          max_input_length;
    int          max_prefix_soft_prompt_length;
    int          batch_size;
    int          beam_width;
    int          hidden_units;
    cudaStream_t stream;
};

// P-tuning / prompt-tuning settings for invokeInputIdsEmbeddingLookupPosEncoding.
template
struct pPromptTuningParam {
    // One pointer per batch entry, each pointing to that sequence's p/prompt-tuning weights
    const T** p_prompt_tuning_batch_weights = nullptr;
    // The start id of p_prompt_tuning token ids (based on the tokenizer)
    // PROMPT_0 --> p_prompt_tuning_id_start; PROMPT_1 --> p_prompt_tuning_id_start + 1; ...
    const int p_prompt_tuning_id_start = 0;
    // Maximum length of the request prompt embeddings
    const int request_prompt_max_length = 0;
    // Whether to use the request prompt embeddings
    const bool use_request_p_prompt_embedding = false;
    // Request prompt embeddings
    const T* request_prompt_embedding = nullptr;
};

// NOTE(review): of the declarations in this block, only invokeTransposeAxis01 is defined in
// gpt_kernels.cu; confirm the others are still defined elsewhere or remove them.
template
void invokeInputIdsEmbeddingLookupPosEncoding(T*                    from_tensor,
                                              int*                  output_ids,
                                              const T*              embedding_table,
                                              const T*              pos_table,
                                              pPromptTuningParam prompt_param,
                                              const int*            input_ids,
                                              const int             start_step,
                                              const int             length,
                                              const int             max_length,
                                              const int             batch_size,
                                              const int             hidden_units,
                                              cudaStream_t          stream);

template
void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam param);

// [dim0, dim1, dim2] -> [dim1, dim0, dim2]
template
void invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);

// [dim0, dim1] -> [dim1, dim0], with optional per-row source shift `in_skipping_dim1`
template
void invokeTransposeAxis01(
    T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);

template
void invokeBuildDecoderAttentionMask(T*           attention_mask,
                                     const int*   sequence_lengths,
                                     const int*   prefix_prompt_lengths,
                                     const int    batch_size,
                                     const int    max_seq_len,
                                     const int    max_prompt_length,
                                     cudaStream_t stream);

template
void invokeLookupHiddenStateOfLastToken(T*           from_tensor,
                                        const T*     hidden_state,
                                        const int*   input_lengths,
                                        const int    max_input_length,
                                        const int    batch_size,
                                        const int    hidden_units,
                                        cudaStream_t stream);

// NOTE(review): none of the functions declared in this block is defined in gpt_kernels.cu;
// confirm the definitions still exist elsewhere before relying on them (calls would fail to link).
void invokeTileGptPromptInputs(int*         tiled_input_ids,
                               int*         tiled_input_lengths,
                               int*         tiled_prompt_lengths,
                               const int*   input_ids,
                               const int*   input_lengths,
                               const int*   prefix_prompt_lengths,
                               const int    batch_size,
                               const int    beam_width,
                               const int    max_input_length,
                               cudaStream_t stream);

void invokeTileGptInputs(int*         tiled_input_ids,
                         int*         tiled_input_lengths,
                         const int*   input_ids,
                         const int*   input_lengths,
                         const int    batch_size,
                         const int    beam_width,
                         const int    max_input_length,
                         cudaStream_t stream);

void invokeFindContextDups(int*         shared_contexts,
                           int*         batch_to_compact,
                           int*         compact_to_batch,
                           int*         compact_size,
                           const int*   input_ids,
                           const size_t batch_size,
                           const size_t input_seq_len,
                           cudaStream_t stream = 0);

template
void invokeCompactInputs(T*           compact_input,
                         T*           compact_attention_mask,
                         int*         compact_input_lengths,
                         const T*     decoder_input,
                         const T*     decoder_mask,
                         const int*   input_lengths,
                         const int*   compact_idx,
                         size_t       compact_size,
                         size_t       seq_len,
                         size_t       hidden_dimension,
                         cudaStream_t stream = 0);

template
void invokeUnCompactOutputs(T*           uncompact_buffer,
                            const T*     compact_buffer,
                            const int*   batch_to_compact_idx,
                            size_t       batch_size,
                            size_t       buffer_stride,
                            cudaStream_t stream = 0);

template
void invokeUnCompactCaches(T*           uncompact_k_cache,
                           T*           uncompact_v_cache,
                           const T*     compact_k_cache,
                           const T*     compact_v_cache,
                           const int*   batch_to_compact_idx,
                           size_t       batch_size,
                           size_t       num_heads,
                           size_t       max_seq_len,
                           size_t       seq_len,
                           size_t       size_per_head,
                           size_t       local_batch_size,
                           size_t       ite,
                           cudaStream_t stream = 0);

// Full overload; see the convenience overload below for the no-prefix-prompt case.
void invokeUpdatePaddingCount(int*         total_padding_count,
                              const int*   input_lengths,
                              const int*   tiled_prompt_lengths,
                              size_t       max_input_length,
                              size_t       max_prompt_length,
                              size_t       batch_size,
                              size_t       beam_width,
                              cudaStream_t stream = 0);

// Convenience overload for requests without a prefix prompt: forwards a null
// `tiled_prompt_lengths` and a zero `max_prompt_length` to the full overload.
inline void invokeUpdatePaddingCount(int*         total_padding_count,
                                     const int*   input_lengths,
                                     size_t       max_input_length,
                                     size_t       batch_size,
                                     size_t       beam_width,
                                     cudaStream_t stream = 0)
{
    const int* const no_prompt_lengths    = nullptr;
    const size_t     no_max_prompt_length = 0;
    invokeUpdatePaddingCount(total_padding_count,
                             input_lengths,
                             no_prompt_lengths,
                             max_input_length,
                             no_max_prompt_length,
                             batch_size,
                             beam_width,
                             stream);
}

// Full overload; the inline overload below passes a null `tiled_prefix_prompt_lengths`.
// NOTE(review): no definition is visible in gpt_kernels.cu; confirm it is defined elsewhere.
void invokeMaskPaddingTokens(bool*        masked_tokens,
                             const int*   input_lengths,
                             const int*   tiled_prefix_prompt_lengths,
                             const size_t memory_len,
                             const size_t max_input_length,
                             const size_t initial_step,
                             size_t       batch_size,
                             size_t       beam_width,
                             cudaStream_t stream = 0);

// Convenience overload for requests without a prefix prompt: forwards a null
// `tiled_prefix_prompt_lengths` to the full overload; all other arguments pass through.
inline void invokeMaskPaddingTokens(bool*        masked_tokens,
                                    const int*   input_lengths,
                                    const size_t memory_len,
                                    const size_t max_input_length,
                                    const size_t initial_step,
                                    size_t       batch_size,
                                    size_t       beam_width,
                                    cudaStream_t stream = 0)
{
    const int* const no_prefix_prompt_lengths = nullptr;
    invokeMaskPaddingTokens(masked_tokens,
                            input_lengths,
                            no_prefix_prompt_lengths,
                            memory_len,
                            max_input_length,
                            initial_step,
                            batch_size,
                            beam_width,
                            stream);
}

template
void invokeSumLengthDimension(float*       out_buf,
                              const T*     in_buf,
                              const size_t batch_size,
                              const size_t input_length,
                              const size_t hidden_dim,
                              cudaStream_t stream = 0);

// Type-erased transpose entry point (defined elsewhere). invokeTranspose2D forwards 4-byte
// element types here reinterpreted as uint32_t.
template
void invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st);

template
void invokeTranspose2D(T* dst, const T* src, int rows, int cols, cudaStream_t st)
{
    // Only 4-byte element types are supported: they are reinterpreted as uint32_t and handed
    // to the type-erased implementation. Any other element width fails the check at runtime.
    if constexpr (sizeof(T) != 4) {
        FT_CHECK(0);
    }
    else {
        uint32_t*       dst_words = (uint32_t*)dst;
        const uint32_t* src_words = (const uint32_t*)src;
        invokeTranspose2D_(dst_words, src_words, rows, cols, st);
    }
}

// Embedding gather, declared here and defined elsewhere. Presumably writes the rows of
// `embedding_table` selected by `token_ids` into `out_` on stream `st` — confirm against the
// definition.
void invokeEmbeddingLookup(Ref         out_,
                           const Buffer_& token_ids,
                           const Tensor&       embedding_table,
                           cudaStream_t        st);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/logprob_kernels.cu
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 
#include 
#include 

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include 
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "src/turbomind/kernels/logprob_kernels.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/macro.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

template
__global__ void log_probs_kernel(float*       log_probs,
                                 const T*     logits,
                                 const int*   ids,
                                 const int*   lengths,
                                 const size_t max_input_length,
                                 const size_t batch_size,
                                 const size_t vocab_size,
                                 const size_t vocab_size_padded,
                                 bool         batch_first)
{
    // Calculate the log probability from logits.
    //   log_probs[t, :] = log(softmax(logits))[ids[t + 1, :]]
    //
    // One block per (sequence, step) pair; the block's threads stride over the vocab dimension.
    //
    // log_probs: [max_length - 1, batch_size] or, when batch_first, [batch_size, max_length - 1],
    //     log probabilities of each token.
    // logits: [max_length, batch_size, vocab_size_padded] or [batch_size, max_length, vocab_size_padded]
    // lengths: [batch_size], sequence lengths
    // ids: [max_length, batch_size] or, when batch_first, [batch_size, max_length], token ids.
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.
    // vocab_size: [1], vocab_size, number of entries per logits row that are actually read.
    // vocab_size_padded: [1], padded vocab size, i.e. the row stride of logits.
    //
    // NOTE(review): step_offset / batch_offset below are computed in size_t but stored in a
    // 32-bit int, so they truncate once the logits tensor exceeds 2^31 elements; widening them
    // to size_t would be safer.

    const bool IS_FP16   = std::is_same::value;
    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    int tidx = threadIdx.x;                            // vocab dim
    int bidx = batch_first ? blockIdx.x : blockIdx.y;  // batch dim
    int step = batch_first ? blockIdx.y : blockIdx.x;  // step dim

    __shared__ float s_max_logit;

    // The condition depends only on blockIdx, so it is uniform across the block and the
    // __syncthreads() / block-wide reductions inside are reached by every thread or by none.
    if (bidx < batch_size && step < lengths[bidx] - 1) {
        // reposition logits to data for the current batch.
        int step_offset  = batch_first ? step * vocab_size_padded : step * batch_size * vocab_size_padded;
        int batch_offset = batch_first ? bidx * max_input_length * vocab_size_padded : bidx * vocab_size_padded;
        logits += step_offset + batch_offset;

        // Find max(logits).
        float local_max = -MAX_T_VAL;
        float val       = -MAX_T_VAL;
        for (int i = tidx; i < vocab_size; i += blockDim.x) {
            val       = static_cast(logits[i]);
            local_max = fmax(local_max, val);
        }

        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max);
        if (tidx == 0) {
            s_max_logit = max_val;
        }
        __syncthreads();

        // Calculate the denominator: sum_i exp(logits[i])
        // (shifted by the row max so that __expf cannot overflow).
        float local_sum_exp = 0.0f;
        for (int i = tidx; i < vocab_size; i += blockDim.x) {
            val = __expf(static_cast(logits[i]) - s_max_logit);
            local_sum_exp += val;
        }

        float sum_exp = blockDim.x <= 32 ? warpReduceSum(local_sum_exp) : blockReduceSum(local_sum_exp);
        if (tidx == 0) {
            int idx = batch_first ? step + bidx * (max_input_length - 1) : step * batch_size + bidx;
            // log_probs[step, ...] is the log probability of a token at step t + 1.
            // log_softmax(x)[k] = x[k] - max - log(sum exp(x - max)); 1e-9f guards log(0).
            int token_idx  = batch_first ? step + 1 + bidx * max_input_length : (step + 1) * batch_size + bidx;
            log_probs[idx] = static_cast(logits[ids[token_idx]]) - s_max_logit - __logf(sum_exp + 1e-9f);
        }
    }
}

__global__ void accumulate_log_probs(float*       cum_log_probs,
                                     const float* log_probs,
                                     const int*   lengths,
                                     const size_t max_input_length,
                                     const size_t batch_size,
                                     const bool   batch_first)
{
    // Reduce per-token log probabilities to one value per sequence:
    //   cum_log_probs[b] = sum_{t < lengths[b] - 1} log_probs[t, b]
    //
    // Block blockIdx.x handles sequence b; its threads stride over the step dimension.
    //   cum_log_probs: [batch_size]
    //   log_probs:     [max_input_length - 1, batch_size], or [batch_size, max_input_length - 1]
    //                  when batch_first
    //   lengths:       [batch_size], sequence lengths
    //   batch_size:    batch x beam when beam search is used.
    const int seq = blockIdx.x;

    // `seq` is identical for every thread of the block, so the whole block leaves together and
    // no thread is left waiting at a barrier inside the block reduction.
    if (seq >= batch_size) {
        return;
    }

    const int    num_steps     = lengths[seq] - 1;
    const int    step_stride   = batch_first ? 1 : batch_size;
    const float* seq_log_probs = log_probs + (batch_first ? seq * (max_input_length - 1) : seq);

    float partial = 0.0f;
    for (int t = threadIdx.x; t < num_steps; t += blockDim.x) {
        partial += seq_log_probs[t * step_stride];
    }

    const float total = blockDim.x <= 32 ? warpReduceSum(partial) : blockReduceSum(partial);
    if (threadIdx.x == 0) {
        cum_log_probs[seq] = total;
    }
}

template
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first)
{
    // A batched version of log prob computation. For every sequence b:
    //   cum_log_probs[b] = sum_{t < input_lengths[b] - 1} log(softmax(logits[t, b]))[input_ids[t + 1, b]]
    //
    // cum_log_probs: [batch_size]
    // logits: [max_input_length, batch_size, vocab_size_padded] or, when batch_first,
    //         [batch_size, max_input_length, vocab_size_padded]
    // input_ids: [max_input_length, batch_size] or, when batch_first, [batch_size, max_input_length]
    // input_lengths: [batch_size]
    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.
    //   It receives the per-token log probs written by the first kernel and read by the second.

    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    // Empty batch: nothing to compute, and a zero grid dimension is an invalid launch configuration.
    if (batch_size == 0) {
        return;
    }

    // block_size should be multiple of 32 to use warpReduceMax.
    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
    assert(block_size % 32 == 0);
    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
    assert(vocab_size <= vocab_size_padded);

    float* log_probs = reinterpret_cast(workspace);

    // With max_input_length <= 1 no sequence has a next token to score, and the per-token grid
    // would get a zero (or, through size_t wrap-around, negative) dimension, which is an invalid
    // launch configuration. Skip that kernel; the accumulation below then sums zero terms and
    // writes 0 for every sequence whose length is <= 1.
    if (max_input_length > 1) {
        int  gx = batch_first ? batch_size : max_input_length - 1;
        int  gy = batch_first ? max_input_length - 1 : batch_size;
        dim3 grid(gx, gy);
        log_probs_kernel<<>>(log_probs,
                                                             logits,
                                                             input_ids,
                                                             input_lengths,
                                                             max_input_length,
                                                             batch_size,
                                                             vocab_size,
                                                             vocab_size_padded,
                                                             batch_first);
    }
    accumulate_log_probs<<>>(
        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}

// Explicit instantiations: this translation unit provides invokeLogProbFromLogits for float
// and half logits only.
template void invokeLogProbFromLogits(float*       cum_log_probs,
                                      const float* logits,
                                      const int*   input_ids,
                                      const int*   input_lengths,
                                      const size_t max_input_length,
                                      const size_t batch_size,
                                      const size_t vocab_size,
                                      const size_t vocab_size_padded,
                                      void*        workspace,
                                      const size_t workspace_size,
                                      cudaStream_t stream,
                                      const bool   batch_first);

template void invokeLogProbFromLogits(float*       cum_log_probs,
                                      const half*  logits,
                                      const int*   input_ids,
                                      const int*   input_lengths,
                                      const size_t max_input_length,
                                      const size_t batch_size,
                                      const size_t vocab_size,
                                      const size_t vocab_size_padded,
                                      void*        workspace,
                                      const size_t workspace_size,
                                      cudaStream_t stream,
                                      const bool   batch_first);
}  // end of namespace turbomind


================================================
FILE: src/turbomind/kernels/logprob_kernels.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace turbomind {

// For each sequence b, writes to cum_log_probs[b] the sum over steps t < input_lengths[b] - 1
// of log(softmax(logits at step t))[input_ids at step t + 1].
// `workspace` must hold at least sizeof(float) * max_input_length * batch_size bytes.
// batch_first selects [batch, time]-major layouts for logits/input_ids; the default is
// [time, batch]-major. Explicitly instantiated in logprob_kernels.cu for float and half.
// NOTE(review): this header includes nothing itself; it relies on the includer to have
// declared size_t and cudaStream_t.
template
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first = false);
}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/norm/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

add_library(rms_norm rms_norm.cu)
# Position-independent code lets this library be linked into shared objects; device symbols
# are resolved (device-linked) at this target because it contains CUDA device code.
set_target_properties(rms_norm PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON)


================================================
FILE: src/turbomind/kernels/norm/rms_norm.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "cub/block/block_reduce.cuh"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/kernels/norm/rms_norm.h"

namespace turbomind {

namespace kernel {

template
__global__ void RMSNorm(T*       dst,
                        int      dst_ld,
                        const T* src,
                        int      src_ld,
                        const T* __restrict__ weights,
                        int   dims,
                        int   num,
                        float eps,
                        float inv_dims)
{
    // One block per row ti:
    //   dst[ti, :] = (T)(src[ti, :] * rsqrt(sum(src[ti, :]^2) * inv_dims + eps)) * weights
    // Rows are addressed through the leading dimensions src_ld / dst_ld. Each thread handles
    // vec_size contiguous elements per iteration and strides by block_dim * vec_size.
    // NOTE(review): there is no scalar tail, so dims is assumed to be a multiple of vec_size
    // (and the rows suitably aligned for vector loads); callers must guarantee this.
    const int ti = blockIdx.x;
    const int di = threadIdx.x * vec_size;

    if (ti >= num) {
        return;
    }

    src += src_ld * ti;

    Array accum{};
    Array     vec;

    // Pass 1: per-thread partial sums of squares.
    for (int i = di; i < dims; i += block_dim * vec_size) {
        Load(vec, &src[i]);
        Array tmp = cast(vec);
        using namespace ops;
        accum = accum + tmp * tmp;
    }

    float sum{};
    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        sum += accum[i];
    }

    using BlockReduce = cub::BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    sum = BlockReduce{temp_storage}.Sum(sum);

    __shared__ float shared_sum;

    // The block reduction result is only valid in thread 0; broadcast the inverse RMS to the
    // rest of the block through shared memory.
    if (threadIdx.x == 0) {
        shared_sum = rsqrtf(sum * inv_dims + eps);
    }

    __syncthreads();

    sum = shared_sum;

    dst += dst_ld * ti;

    // Pass 2: re-read the row, scale by the inverse RMS (rounded to T), then apply the weights.
    Array sv;
    for (int i = di; i < dims; i += block_dim * vec_size) {
        Load(vec, &src[i]);
        Ldg(sv, &weights[i]);
        PRAGMA_UNROLL
        for (int c = 0; c < vec_size; ++c) {
            vec[c] = (T)((float)vec[c] * sum) * sv[c];
            // Alternative rounding order (unused): apply the weight in float before rounding to T.
            // vec[c] = (T)((float)vec[c] * sum * (float)sv[c]);
        }
        Store(&dst[i], vec);
    }
}

}  // namespace kernel

void invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st)
{
    // Row-wise RMSNorm of the 2D tensor x into out (same shape and dtype), one block per row:
    //   out[i, :] = x[i, :] * rsqrt(mean(x[i, :]^2) + eps) * w
    if (x.size() == 0) {
        return;
    }

    TM_CHECK(x.ndim() == 2);
    TM_CHECK(out.shape() == x.shape());
    TM_CHECK(out.dtype() == x.dtype());
    TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1));

    auto invoke = [&](auto t) {
        using T = decltype(t);

        const auto [num, dim] = x.shapes(0, 1);

        constexpr int vec_size = 16 / sizeof(T);

        // The kernel moves each row in whole 16-byte vectors with no scalar tail; a row length
        // that is not a multiple of vec_size would read and write past the end of the row.
        TM_CHECK(dim % vec_size == 0);

        constexpr int threads = 512;
        const int     blocks  = num;

        kernel::RMSNorm<<>>((T*)out.raw_data(),  //
                                                                                 out.stride(0),
                                                                                 (const T*)x.raw_data(),
                                                                                 x.stride(0),
                                                                                 (const T*)w.raw_data(),
                                                                                 dim,
                                                                                 num,
                                                                                 eps,
                                                                                 1.f / dim);
    };

    TM_DISPATCH_PRIMARY_DTYPES(x.dtype(), invoke);
}

namespace kernel {

template
__global__ void RMSNormQK(T*       data,  //
                          int      ld,
                          const T* weight,
                          int      dim,
                          int      n,
                          int      token_num,
                          float    eps,
                          float    inv_dim)
{
    // In-place per-head RMSNorm. Head hi of token ti lives at data[ti * ld + hi * dim] and is
    // handled by thr_per_qk consecutive threads, each owning vec_size contiguous elements.
    // dim may be smaller than max_dim: lanes with di >= dim contribute zeros and store nothing.
    static_assert((max_dim & (max_dim - 1)) == 0);

    constexpr int thr_per_qk = max_dim / vec_size;

    const int bi = (threadIdx.x + blockIdx.x * blockDim.x) / thr_per_qk;
    const int di = threadIdx.x % thr_per_qk * vec_size;
    const int ti = bi / n;
    const int hi = bi % n;

    // All threads of one head share bi, so whole head-groups exit together here; the XOR
    // partners used below stay inside a group and are therefore never exited threads.
    if (bi >= token_num * n) {
        return;
    }

    data += ti * ld + hi * dim;

    Array vec{};
    if (di < dim) {
        Load(vec, &data[di]);
    }

    using namespace ops;
    auto acc = cast(vec);
    acc      = acc * acc;

    float sum{};
    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        sum += acc[i];
    }

    // Butterfly reduction over the thr_per_qk lanes of this head.
    // NOTE(review): this assumes thr_per_qk <= 32 (one warp). For a 4-byte T with max_dim 256
    // it would be 64; presumably such types are not dispatched here — confirm.
    PRAGMA_UNROLL
    for (int mask = thr_per_qk / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync((uint32_t)-1, sum, mask);
    }

    sum = rsqrtf(sum * inv_dim + eps);

    Array w;
    if (di < dim) {
        Ldg(w, &weight[di]);
        PRAGMA_UNROLL
        for (int i = 0; i < vec_size; ++i) {
            vec[i] = (T)((float)vec[i] * sum) * w[i];
        }
        Store(&data[di], vec);
    }
}

}  // namespace kernel

void invokeQkRMSNorm(void*        data,
                     int          ld,
                     const void*  weight,
                     DataType     dtype,
                     int          head_dim,
                     int          n,
                     int          token_num,
                     float        eps,
                     cudaStream_t stream)
{
    // In-place per-head RMSNorm on raw pointers: `data` holds token_num rows (row stride ld),
    // each with n heads of head_dim contiguous elements; `weight` supplies head_dim scales.

    // Empty input: the computed grid would be zero-sized, which is an invalid launch configuration.
    if (token_num == 0 || n == 0) {
        return;
    }

    auto invoke = [&](auto t) {
        using T = decltype(t);

        auto launch = [&](auto max_dim_c) {
            constexpr int kMaxDim = std::decay_t::value;
            TM_CHECK_LE(head_dim, kMaxDim);

            constexpr int vec_size   = sizeof(uint4) / sizeof(T);
            constexpr int thr_per_qk = kMaxDim / vec_size;

            // Heads are processed in whole 16-byte vectors.
            TM_CHECK(head_dim % vec_size == 0);

            const int threads   = thr_per_qk * n * (int64_t)token_num;
            const int block_dim = 512;
            const int grid_dim  = cdiv(threads, block_dim);

            kernel::RMSNormQK<<>>(
                (T*)data, ld, (const T*)weight, head_dim, n, token_num, eps, 1.f / head_dim);
        };

        // Use the smaller compile-time head-dim bound when it suffices.
        if (head_dim <= 128) {
            launch(constant<128>{});
        }
        else {
            launch(constant<256>{});
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);
}

void invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st)
{
    // In-place per-head RMSNorm of x: [token_num, head_num, head_dim] with contiguous heads;
    // every head is normalized over head_dim and scaled element-wise by w.

    // Empty input: the grid below would be zero-sized (an invalid launch configuration).
    if (x.size() == 0) {
        return;
    }

    TM_CHECK(x.ndim() == 3);
    // The kernel reads w through a pointer of x's element type.
    TM_CHECK(w.dtype() == x.dtype());

    int token_num, head_num, head_dim;
    std::tie(token_num, head_num, head_dim) = x.shapes(0, 1, 2);

    TM_CHECK(x.stride(1) == head_dim);

    auto data   = x.raw_data();
    auto stride = x.stride(0);

    auto invoke = [&](auto t) {
        using T = decltype(t);

        auto launch = [&](auto max_dim_c) {
            constexpr int kMaxDim = std::decay_t::value;
            TM_CHECK_LE(head_dim, kMaxDim);

            constexpr int vec_size   = sizeof(uint4) / sizeof(T);
            constexpr int thr_per_qk = kMaxDim / vec_size;

            TM_CHECK(head_dim % vec_size == 0);

            const int threads   = token_num * head_num * thr_per_qk;
            const int block_dim = 512;
            const int grid_dim  = cdiv(threads, block_dim);

            kernel::RMSNormQK<<>>(
                (T*)data, stride, (const T*)w.raw_data(), head_dim, head_num, token_num, eps, 1.f / head_dim);
        };

        // Use the smaller compile-time head-dim bound when it suffices.
        if (head_dim <= 128) {
            launch(constant<128>{});
        }
        else {
            launch(constant<256>{});
        }
    };

    TM_DISPATCH_PRIMARY_DTYPES(x.dtype(), invoke);
}

// Fused residual add + RMSNorm, one block per row ti (rows are contiguous, stride dims):
//   r' <- r + (h + b)        (b is skipped when bias is null; r' is written back to residual)
//   h' <- norm(r') * w       (h' is written to hidden_states)
// NOTE(review): there is no scalar tail, so dims is assumed to be a multiple of vec_size.
// Row offsets (dims * ti) are computed in int — very large num * dims could overflow.
template
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual,
                                          T* __restrict__ hidden_states,
                                          const T* __restrict__ weights,
                                          const T* __restrict__ bias,
                                          int   dims,
                                          int   num,
                                          float eps,
                                          float inv_dims)
{
    const int ti = blockIdx.x;
    const int di = threadIdx.x * vec_size;

    if (ti >= num) {
        return;
    }

    residual += dims * ti;
    hidden_states += dims * ti;

    Array accum{};

    Array r_vec;
    Array h_vec;
    Array b_vec;

    // Pass 1: update the residual in place and accumulate its sum of squares.
    for (int i = di; i < dims; i += block_dim * vec_size) {
        Load(r_vec, &residual[i]);
        Load(h_vec, &hidden_states[i]);

        using namespace ops;
        r_vec = r_vec + h_vec;

        if (bias) {
            Ldg(b_vec, &bias[i]);
            r_vec = r_vec + b_vec;
        }

        Store(&residual[i], r_vec);

        Array tmp = cast(r_vec);

        accum = accum + tmp * tmp;
    }

    float sum{};
    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        sum += accum[i];
    }

    using BlockReduce = cub::BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    sum = BlockReduce{temp_storage}.Sum(sum);

    __shared__ float shared_sum;

    // The block reduction result is only valid in thread 0; broadcast the inverse RMS.
    if (threadIdx.x == 0) {
        shared_sum = rsqrtf(sum * inv_dims + eps);
    }

    // Also orders pass 1's residual stores before the re-reads below (each thread re-reads
    // only the elements it wrote itself).
    __syncthreads();

    sum = shared_sum;

    // Pass 2: normalize the updated residual and write the result into hidden_states.
    Array w_vec;
    for (int i = di; i < dims; i += block_dim * vec_size) {
        Load(r_vec, &residual[i]);
        Ldg(w_vec, &weights[i]);
        PRAGMA_UNROLL
        for (int c = 0; c < vec_size; ++c) {
            r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
            // Alternative rounding order (unused): apply the weight in float before rounding to T.
            // r_vec[c] = (T)((float)r_vec[c] * sum * (float)w_vec[c]);
        }
        Store(&hidden_states[i], r_vec);
    }
}

template
void invokeBiasResidualRMSNorm(
    T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st)
{
    // residual <- residual + hidden_states (+ bias, when non-null);
    // hidden_states <- RMSNorm(residual) * weights. One block per row of `dims` elements.

    // Empty input: a zero-sized grid is an invalid launch configuration.
    if (num == 0) {
        return;
    }

    constexpr int vec_size = 16 / sizeof(T);
    constexpr int threads  = 512;
    const int     blocks   = num;

    // The kernel moves each row in whole 16-byte vectors with no scalar tail.
    TM_CHECK(dims % vec_size == 0);

    BiasResidualRMSNormKernel<<>>(residual,  //
                                                                                       hidden_states,
                                                                                       weights,
                                                                                       bias,
                                                                                       dims,
                                                                                       num,
                                                                                       eps,
                                                                                       1.f / dims);
}

// Explicit instantiations: half is always provided; nv_bfloat16 only when ENABLE_BF16 is set.
template void invokeBiasResidualRMSNorm(half*        residual,
                                        half*        hidden_states,
                                        const half*  weights,
                                        const half*  bias,
                                        int          dims,
                                        int          num,
                                        float        eps,
                                        cudaStream_t st);

#if ENABLE_BF16
template void invokeBiasResidualRMSNorm(nv_bfloat16*       residual,
                                        nv_bfloat16*       hidden_states,
                                        const nv_bfloat16* weights,
                                        const nv_bfloat16* bias,
                                        int                dims,
                                        int                num,
                                        float              eps,
                                        cudaStream_t       st);
#endif

void invokeResidualBiasRMSNorm(void*        hidden_states,
                               void*        residual,
                               const void*  weights,
                               const void*  bias,
                               DataType     dtype,
                               int          dims,
                               int          num,
                               float        eps,
                               cudaStream_t st)
{
    // Type-erased entry point for the fused kernel:
    //   residual      <- residual + hidden_states (+ bias, when non-null)
    //   hidden_states <- RMSNorm(residual) * weights
    // Note the argument order: hidden_states comes first here but second in the kernel.
    if (num == 0) {
        return;
    }
    auto invoke = [&](auto t) {
        using T                = decltype(t);
        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        constexpr int threads  = 512;
        const int     blocks   = num;
        // The kernel moves each row in whole 16-byte vectors with no scalar tail.
        TM_CHECK(dims % vec_size == 0);
        BiasResidualRMSNormKernel<<>>((T*)residual,  //
                                                                                           (T*)hidden_states,
                                                                                           (const T*)weights,
                                                                                           (const T*)bias,
                                                                                           dims,
                                                                                           num,
                                                                                           eps,
                                                                                           1.f / dims);
    };

    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);
}

// In place: data[ti, :] += bias, one block per row and vec_size elements per thread. The
// bias element type B may differ from T; it is cast before the add.
// NOTE(review): there is no bounds check on di, so the launch must use exactly
// dim / vec_size threads per block (ApplyBias does).
template
__global__ void biasKernel(T* data, const B* bias, int num, int dim)
{
    int ti = blockIdx.x;
    int di = threadIdx.x * vec_size;

    Array b;
    Ldg(b, bias + di);

    Array x;
    Load(x, data + ti * dim + di);
    using namespace ops;
    x = x + cast(b);
    Store(data + ti * dim + di, x);
}

void ApplyBias(Tensor& data, const Tensor& bias, cudaStream_t st)
{
    // In place: data[i, :] += bias for every row i of the 2D tensor `data`; no-op when `bias`
    // is not set. Float data accepts a bias of any primary dtype; for half/bf16 data the bias
    // is accessed as the data's own type (half and bf16 are never mixed).
    if (!bias) {
        return;
    }

    const int num = data.shape(0);
    const int dim = data.shape(1);

    TM_CHECK_EQ(dim, bias.shape(-1));

    // Empty input: a zero-sized grid or block is an invalid launch configuration.
    if (num == 0 || dim == 0) {
        return;
    }

    auto invoke0 = [&](auto t) {
        using T      = decltype(t);
        auto invoke1 = [&](auto b) {
            using B                = decltype(b);
            constexpr int vec_size = sizeof(uint4) / std::max(sizeof(T), sizeof(B));
            TM_CHECK(dim % vec_size == 0);
            const int blocks  = num;
            const int threads = dim / vec_size;
            TM_CHECK_LE(threads, 1024);
            biasKernel<<>>(data.data(),  //
                                                                   bias.data(),
                                                                   num,
                                                                   dim);
        };
        if constexpr (data_type_v == kFloat) {
            TM_DISPATCH_PRIMARY_DTYPES(bias.dtype(), invoke1);
        }
        else {  // skip mixing half and bf16
            invoke1(t);
        }
    };
    TM_DISPATCH_DTYPES(data.dtype(), invoke0, float, half, nv_bfloat16);
}

// Grouped bias, in place: row ti belongs to the group g with offsets[g] <= ti < offsets[g + 1]
// and receives bias[g, :] * scale (the scaled bias is rounded to T before the add). The first
// `groups` threads of the block locate g and publish it through s_idx.
// NOTE(review): if ti lies outside every [offsets[g], offsets[g + 1]) range, s_idx is never
// written and the bias read below uses an uninitialized index. Callers must make the offsets
// cover [0, num); initializing s_idx to a sentinel and skipping such rows would be safer.
template
__global__ void biasKernel(T* data, const T* bias, const int* offsets, int num, int dim, int groups, float scale)
{
    int ti = blockIdx.x;
    int di = threadIdx.x * vec_size;

    __shared__ int s_idx;

    if (int tid = threadIdx.x; tid < groups) {
        int b = __ldg(&offsets[tid]);
        int e = __ldg(&offsets[tid + 1]);
        if (b <= ti && ti < e) {
            s_idx = tid;
        }
    }

    data += ti * dim;

    // Make the group index written above visible to every thread before it is used.
    __syncthreads();

    bias += s_idx * dim;

    // The block may be wider than dim / vec_size (it also has to cover `groups` lanes).
    if (di >= dim) {
        return;
    }

    Array b;
    Ldg(b, bias + di);

    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        b[i] = (T)((float)b[i] * scale);
    }

    Array x;
    Load(x, data + di);

    using namespace ops;
    x = x + b;

    Store(data + di, x);
}

void ApplyBias(Tensor& data, const Tensor& bias, const Buffer_& offsets, float scale, cudaStream_t st)
{
    // Grouped bias, in place: row i of `data` gets bias[g, :] * scale, where g is the group with
    // offsets[g] <= i < offsets[g + 1]. No-op when `bias` is not set.
    if (!bias) {
        return;
    }

    const int num    = data.shape(0);
    const int dim    = data.shape(1);
    const int groups = offsets.size() - 1;

    TM_CHECK_EQ(dim, bias.shape(-1));

    // Empty input: a zero-sized grid is an invalid launch configuration.
    if (num == 0) {
        return;
    }

    // Without at least one [offsets[g], offsets[g + 1]) range the kernel could never resolve a
    // row's group and would index the bias with an uninitialized value.
    TM_CHECK(groups > 0);

    auto invoke = [&](auto t) {
        using T = decltype(t);

        constexpr int vec_size = sizeof(uint4) / sizeof(T);
        TM_CHECK(dim % vec_size == 0);

        // Enough threads both to cover the row in vectors and to scan every group boundary.
        const int blocks  = num;
        const int threads = std::max(dim / vec_size, groups);

        TM_CHECK_LE(threads, 1024);

        biasKernel<<>>(data.data(),  //
                                                            bias.data(),
                                                            offsets.data(),
                                                            num,
                                                            dim,
                                                            groups,
                                                            scale);
    };

    TM_DISPATCH_PRIMARY_DTYPES(data.dtype(), invoke);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/norm/rms_norm.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// NOTE(review): unlike logprob_kernels.h and penalty_types.h, this header has no `#pragma once`.
// Repeating these declarations is legal, but adding a guard would be conventional.

// Row-wise RMSNorm of the 2D tensor x into out: out = x * rsqrt(mean(x^2) + eps) * w.
void invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st);

// In-place per-head RMSNorm of x: [token_num, head_num, head_dim], scaled by w.
void invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st);

// Fused residual add + RMSNorm on raw pointers:
//   residual <- residual + hidden_states (+ bias, if non-null)
//   hidden_states <- RMSNorm(residual) * weights
template
void invokeBiasResidualRMSNorm(
    T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st);

// Type-erased variant of the fused residual add + RMSNorm, dispatched on dtype.
// Note the parameter order differs from invokeBiasResidualRMSNorm: hidden_states comes first.
void invokeResidualBiasRMSNorm(void*        hidden_states,
                               void*        residual,
                               const void*  weights,
                               const void*  bias,
                               DataType     dtype,
                               int          dims,
                               int          num,
                               float        eps,
                               cudaStream_t st);

// Grouped bias, in place: row i of x gets bias[g, :] * scale for the group g with
// offsets[g] <= i < offsets[g + 1].
void ApplyBias(Tensor& x, const Tensor& bias, const Buffer_& offsets, float scale, cudaStream_t st);

// In place: x[i, :] += bias for every row i; no-op when bias is not set.
void ApplyBias(Tensor& x, const Tensor& bias, cudaStream_t st);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/penalty_types.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 
#include 

#include "src/turbomind/utils/string_utils.h"

namespace turbomind {

// Kind of repetition penalty to apply.
enum class RepetitionPenaltyType
{
    Additive,        // the presence penalty
    Multiplicative,  // the repetition penalty
    None             // No repetition penalty.
};

// Default value for a penalty of the given kind: 1.0 for a multiplicative penalty, 0.0 for an
// additive penalty and for None (the identity value of each operation).
inline float getDefaultPenaltyValue(RepetitionPenaltyType penalty_type)
{
    const bool is_multiplicative = (penalty_type == RepetitionPenaltyType::Multiplicative);
    return is_multiplicative ? 1.0f : 0.0f;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/quantization.cu
================================================


#include 

#include 
#include 
#include 

#include 

#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/floating_point.h"
#include "src/turbomind/kernels/core/math.h"

#include "src/turbomind/kernels/quantization.cuh"
#include "src/turbomind/kernels/quantization.h"

#include "src/turbomind/kernels/attention/quantization.h"

namespace turbomind {

// Row-wise symmetric quantization kernel.
//
// Rows `ti` are visited with stride gridDim.x. Columns are visited vec_size at a time with stride
// blockDim.x * vec_size. Each group of `group_size` columns in a row gets:
//   - one scale = absmax / qmax (absmax clamped to >= 1e-8), written column-major at
//     scales[(di / group_size) * scales_ld + ti];
//   - quantized elements x * (qmax / absmax), converted to Tout.
// The body is compiled only when TURBOMIND_ARCH_SM90 is true; otherwise the kernel is empty.
// NOTE(review): row offsets use int arithmetic (ti * src_ld); very large tensors could overflow.
// The other kernels in this file cast to int64_t.
// NOTE(review): a full vec_size vector is loaded whenever di < dim, so dim presumably must be a
// multiple of vec_size — confirm with callers.
template
__global__ void quant_symm_row(
    Tout* out, int out_ld, Tscale* scales, int scales_ld, const T* src, int src_ld, int num, int dim, Tscale qmax)
{
#if TURBOMIND_ARCH_SM90
    static_assert(group_size % vec_size == 0);
    // number of lanes sharing one group (vec_size elements each)
    constexpr int threads = group_size / vec_size;
    // The loop bound is padded to a whole warp's span so every lane of a warp stays in the loop and
    // takes part in the shuffles inside find_absmax; lanes past `dim` contribute zeros.
    const int     dim1    = round_up(dim, WARP_SIZE * vec_size);
    for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {
        for (int di = threadIdx.x * vec_size; di < dim1; di += blockDim.x * vec_size) {
            Array vec{};
            if (di < dim) {
                Ldg(vec, src + ti * src_ld + di);
            }
            // group-wide |x| max (find_absmax's template arguments are not visible in this copy)
            auto         absmax    = fmaxf(static_cast(find_absmax(vec)), 1e-8f);
            const Tscale scale     = absmax / qmax;
            const Tscale inv_scale = qmax / absmax;
            if (threadIdx.x % threads == 0 && di < dim) {
                // column-major: scales for one column group are contiguous across rows
                scales[(di / group_size) * scales_ld + ti] = scale;
            }
            Array tmp;
            PRAGMA_UNROLL
            for (int c = 0; c < vec_size; ++c) {
                tmp[c] = Tout(static_cast(vec[c]) * inv_scale);
            }
            if (di < dim) {
                Store(out + ti * out_ld + di, tmp);
            }
        }
    }
#endif
}

// Quantizes a 2-D row-major bf16 `src` (num x dim) to fp8 e4m3, with one float scale per group of
// 128 consecutive columns in each row.
//   out   : allocated on device with src's shape when empty; otherwise its shape must equal src's.
//   scale : when empty, allocated as (cdiv(dim, 128), num) with unit stride along `num` and a row
//           stride padded to a multiple of 16 bytes. Otherwise its shape and strides are checked.
void QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st)
{
    TM_CHECK_EQ(src.ndim(), 2);
    TM_CHECK_EQ(src.stride(1), 1);  // row-major

    const auto [num, dim] = src.shapes(0, 1);

    using T      = bfloat16_t;
    using Tout   = fp8_e4m3_t;
    using Tscale = float;

    constexpr int group_size = 128;
    constexpr int vec_size   = 8;

    // scale rows are padded so their stride is a multiple of 16 bytes
    constexpr int alignment = 16 / sizeof(Tscale);

    if (!out) {
        out = Tensor_{src.shape(), kDEVICE};
    }
    else {
        TM_CHECK(out.shape() == src.shape());
    }

    const int aligned_num = round_up(num, alignment);

    const int s_dim = cdiv(dim, group_size);

    if (!scale) {
        scale = Tensor_({{s_dim, num}, {aligned_num, 1}}, kDEVICE);
    }
    else {
        TM_CHECK(std::make_tuple(s_dim, num) == scale.shapes(0, 1));
        TM_CHECK(scale.stride(1) == 1);
        TM_CHECK(scale.stride(0) % alignment == 0);
    }

    constexpr int block_dim = 512;

    // The last argument, 448.f, is qmax (presumably the largest finite fp8 e4m3 value).
    quant_symm_row<<>>(out.data(),  //
                                                                    out.stride(0),
                                                                    scale.data(),
                                                                    scale.stride(0),
                                                                    src.data(),
                                                                    src.stride(0),
                                                                    num,
                                                                    dim,
                                                                    448.f);
}

// Inverse of quant_symm_row.
//
// Each vec_size chunk of row `ti` is multiplied by its group's scale, read column-major from
// scales[(di / group_size) * scales_ld + ti], and converted to Tout. Rows are visited with stride
// gridDim.x and columns with stride blockDim.x * vec_size.
// The body is compiled only when TURBOMIND_ARCH_SM90 is true.
// NOTE(review): unlike quant_symm_row there is no tail guard inside the loop, so dim presumably
// must be a multiple of vec_size.
template
__global__ void
dequant_symm_row(Tout* out, int out_ld, const T* src, int src_ld, const Tscale* scales, int scales_ld, int num, int dim)
{
#if TURBOMIND_ARCH_SM90
    static_assert(group_size % vec_size == 0);
    for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {
        for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {
            Array vec;
            Ldg(vec, src + ti * src_ld + di);
            const auto            scale = __ldg(&scales[(di / group_size) * scales_ld + ti]);
            Array tmp;
            PRAGMA_UNROLL
            for (int c = 0; c < vec_size; ++c) {
                tmp[c] = Tout(static_cast(vec[c]) * scale);
            }
            Store(out + ti * out_ld + di, tmp);
        }
    }
#endif
}

// Dequantizes an fp8 e4m3 `src` (num x dim) produced by QuantizeSymm back to bf16, using the
// column-major per-128-column `scale`. `out` is allocated with src's layout when empty; otherwise
// its layout must match src's.
// NOTE(review): unlike QuantizeSymm, this does not check src's rank/stride or scale's shape.
void DequantizeSymm(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st)
{
    using T      = fp8_e4m3_t;
    using Tout   = bfloat16_t;
    using Tscale = float;

    if (!out) {
        out = Tensor_{src.layout(), kDEVICE};
    }
    else {
        TM_CHECK(out.layout() == src.layout());
    }

    auto [num, dim] = src.shapes(0, 1);

    constexpr int group_size = 128;
    constexpr int vec_size   = 8;

    constexpr int block_dim = 512;

    dequant_symm_row<<>>(out.data(),  //
                                                                                       out.stride(0),
                                                                                       src.data(),
                                                                                       src.stride(0),
                                                                                       scale.data(),
                                                                                       scale.stride(0),
                                                                                       num,
                                                                                       dim);
}

// Block-wise symmetric quantization kernel.
//
// Each CUDA block owns one block_size x block_size tile of the row-major (num x dim) `src`. Tile
// rows start at blockIdx.x * block_size and tile columns at blockIdx.y * block_size.
// The tile's absolute maximum (clamped to >= 1e-8) yields one scale, absmax / qmax, stored at
// scales[blockIdx.x * gridDim.y + blockIdx.y]. Every element is multiplied by qmax / absmax and
// converted to Tout. Elements outside num x dim are neither loaded nor stored.
template
__global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale qmax, int num, int dim)
{
    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v)) {
        static_assert(block_size % vec_size == 0);
        // lanes needed to span one tile row, vec_size elements each
        constexpr int threads = block_size / vec_size;

        static_assert(cta_size % threads == 0);
        // tile rows covered by the CTA in one pass
        constexpr int rows = cta_size / threads;

        // passes needed to cover all block_size tile rows
        constexpr int S = cdiv(block_size, rows);

        using BlockReduce = cub::BlockReduce;
        __shared__ typename BlockReduce::TempStorage temp_storage;
        // The reciprocal is broadcast as Tscale (float), not T. With qmax = 448 (as passed by
        // QuantizeSymmBlock) and absmax clamped to 1e-8, qmax / absmax can reach 4.48e10.
        //   - Stored in a half T it overflows to inf, and 0 * inf = NaN for zero inputs.
        //   - Stored in a bf16 T it is rounded, so the quantized values no longer agree with the
        //     float scale written to `scales`.
        __shared__ Tscale                            shared_inv_scale;

        const int row = threadIdx.x / threads;
        const int col = threadIdx.x % threads;
        const int ti  = blockIdx.x * block_size;
        const int di  = blockIdx.y * block_size + col * vec_size;

        // Load the tile into registers (zero-filled outside the matrix) and track the local |x| max.
        T                  absmax{};
        Array xs[S]{};
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            if (auto r = ti + s * rows + row; r < num && di < dim) {
                Ldg(xs[s], src + (int64_t)r * dim + di);
            }
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                absmax = __hmax(absmax, __habs(xs[s][i]));
            }
        }

        // The block-wide reduction result is only valid in thread 0, which publishes the scale.
        absmax = BlockReduce{temp_storage}.Reduce(absmax, [](auto a, auto b) { return __hmax(a, b); });
        if (threadIdx.x == 0) {
            auto maxval                                 = fmaxf(static_cast(absmax), 1e-8f);
            scales[blockIdx.x * gridDim.y + blockIdx.y] = maxval / qmax;
            shared_inv_scale                            = qmax / maxval;
        }
        __syncthreads();
        const Tscale inv_scale = shared_inv_scale;

        Array ys[S];
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                ys[s][i] = Tout(static_cast(xs[s][i]) * inv_scale);
            }
            if (auto r = ti + s * rows + row; r < num && di < dim) {
                Store(out + (int64_t)r * dim + di, ys[s]);
            }
        }
    }
}

// Quantizes a contiguous 2-D `src` (num x dim, any primary dtype) to fp8 e4m3, with one float scale
// per 128 x 128 tile.
//   out_   : allocated with src's layout when empty; otherwise its layout must equal src's.
//   scale_ : allocated as a contiguous (cdiv(num, 128), cdiv(dim, 128)) tensor when empty;
//            otherwise its shape must match and it must be contiguous, because the kernel indexes
//            it densely as blockIdx.x * gridDim.y + blockIdx.y and ignores strides.
void QuantizeSymmBlock(Ref out_, Ref scale_, const Tensor& src, cudaStream_t st)
{
    TM_CHECK(src.is_contiguous());
    TM_CHECK_EQ(src.ndim(), 2);

    auto invoke = [&](auto t) {
        using T      = decltype(t);
        using Tout   = fp8_e4m3_t;
        using Tscale = float;

        constexpr int block_size = 128;
        constexpr int vec_size   = 8;

        const auto [num, dim] = src.shapes(0, 1);

        // number of tiles along each dimension
        const int bnum = cdiv(num, block_size);
        const int bdim = cdiv(dim, block_size);

        // one CTA per tile
        constexpr int cta_size = 1024;
        const dim3    grid(bnum, bdim);

        auto& out   = out_.get();
        auto& scale = scale_.get();

        if (!out) {
            out = Tensor_{src.layout(), kDEVICE};
        }
        else {
            TM_CHECK(out.layout() == src.layout());
        }

        if (!scale) {
            scale = Tensor_({bnum, bdim}, kDEVICE);
        }
        else {
            TM_CHECK(std::make_tuple(bnum, bdim) == scale.shapes(0, 1));
            // the kernel writes scales densely and ignores strides
            TM_CHECK(scale.is_contiguous());
        }

        // 448.f is passed as qmax (presumably the largest finite fp8 e4m3 value)
        quant_symm_block<<>>(  //
            out.data(),
            scale.data(),
            src.data(),
            448.f,
            num,
            dim);
    };

    TM_DISPATCH_PRIMARY_DTYPES(src.dtype(), invoke);
}

// Inverse of quant_symm_block.
//
// Each CUDA block handles one block_size x block_size tile of the row-major (num x dim) `src` and
// multiplies it by the tile's scale, read from scales[blockIdx.x * gridDim.y + blockIdx.y].
// Elements outside num x dim are skipped.
template
__global__ void dequant_symm_block(Tout* out, const T* src, const Tscale* scales, int num, int dim)
{
    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v)) {
        static_assert(block_size % vec_size == 0);
        // lanes spanning one tile row, vec_size elements each
        constexpr int threads = block_size / vec_size;
        static_assert(cta_size % threads == 0);
        // tile rows per pass, and passes per tile
        constexpr int rows  = cta_size / threads;
        constexpr int S     = cdiv(block_size, rows);
        const int     col   = threadIdx.x % threads;
        const int     row   = threadIdx.x / threads;
        const auto    scale = __ldg(&scales[blockIdx.x * gridDim.y + blockIdx.y]);
        // di / ti are unsigned (derived from blockIdx); num / dim are non-negative, so the comparisons are safe
        const auto    di    = blockIdx.y * block_size + col * vec_size;
        PRAGMA_UNROLL
        for (int s = 0; s < S; ++s) {
            const auto ti = blockIdx.x * block_size + s * rows + row;
            if (ti < num && di < dim) {
                Array x;
                Ldg(x, src + (int64_t)ti * dim + di);
                Array y;
                PRAGMA_UNROLL
                for (int i = 0; i < vec_size; ++i) {
                    y[i] = Tout(static_cast(x[i]) * scale);
                }
                Store(out + (int64_t)ti * dim + di, y);
            }
        }
    }
}

// Dequantizes an fp8 e4m3 `src_` (num x dim) produced by QuantizeSymmBlock, using one float scale
// per 128 x 128 tile.
//   - When `out_` is empty it is allocated with src's layout and Tout = nv_bfloat16.
//   - Otherwise the output dtype is taken from `out_` and its layout must equal src's.
// NOTE(review): `scale` shape/contiguity and `src_` contiguity are not checked here. The kernel
// indexes both densely (row * dim + col, blockIdx.x * gridDim.y + blockIdx.y).
void DequantizeSymmBlock(Ref out_, Ref src_, const Tensor& scale, cudaStream_t st)
{
    auto invoke = [&](auto tout) {
        using T      = fp8_e4m3_t;
        using Tout   = decltype(tout);
        using Tscale = float;

        constexpr int block_size = 128;
        constexpr int vec_size   = 8;

        auto& out = out_.get();
        auto& src = src_.get();

        if (!out) {
            out = Tensor_{src.layout(), kDEVICE};
        }
        else {
            TM_CHECK(out.layout() == src.layout());
        }

        const auto [num, dim] = src.shapes(0, 1);

        // number of tiles along each dimension; one CTA per tile
        const int bnum = cdiv(num, block_size);
        const int bdim = cdiv(dim, block_size);

        constexpr int cta_size = 1024;
        const dim3    grid(bnum, bdim);

        dequant_symm_block<<>>(  //
            out.data(),
            src.data(),
            scale.data(),
            num,
            dim);
    };

    // default output dtype is bf16 when no output tensor was provided
    if (!out_.get()) {
        return invoke(nv_bfloat16{});
    }

    TM_DISPATCH_PRIMARY_DTYPES(out_.get().dtype(), invoke);
}

// Bit-packs a 1-D array.
//
// Each thread takes vec_size consecutive elements of `s`, extracts bits [start_bit, end_bit) from
// each one, and packs them into a single D. Element i lands at bit offset i * bits. The result is
// written to d[idx].
// vec_size is computed from `bitsof` (its template argument is not visible in this copy; presumably
// the bit width of D) divided by the field width.
// NOTE(review): threads whose first element index is >= n exit early, but a partial last vector is
// still loaded in full. n is presumably a multiple of vec_size.
template
__global__ void Compact1D_Kernel(D* d, const T* s, int n)
{
    constexpr int bits     = end_bit - start_bit;
    constexpr int vec_size = bitsof / bits;

    const auto idx = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;

    if (idx * vec_size >= n) {
        return;
    }

    Array s_vec;

    Load(s_vec, &s[idx * vec_size]);

    // selects the [start_bit, end_bit) field of each source element
    constexpr T mask = ((1 << bits) - 1) << start_bit;

    D pack{};

    PRAGMA_UNROLL
    for (int i = 0; i < vec_size; ++i) {
        pack |= ((s_vec[i] & mask) >> start_bit) << (i * bits);
    }

    d[idx] = pack;
}

// Asymmetric (min/max) integer quantizer producing unsigned `bits`-bit codes with a scale and a
// zero point per group. For each group:
//   scale = max(max - min, 1e-5) / max_q
//   zero  = clamp(-round(min / scale), 0, max_q)
//   q     = clamp(round(x / scale) + zero, 0, max_q)
//   d     = (q - zero) * scale, computed in T
// NOTE(review): min/max are not widened to include 0. When a group does not straddle zero, the
// clamped zero point shifts the representable range and the far end of the group is clipped.
// Confirm this is intended.
template
struct IntegralQuantizer {

    using T = T_;
    using Q = Q_;

    using Scale = T;
    using Zero  = T;

    static constexpr int bits  = bits_;
    static constexpr int max_q = (1 << bits) - 1;

    // Quantizes this lane's N elements. Only elements with pred[i] set contribute to min/max.
    // `rbits` is unused by this quantizer. `threads` is the number of lanes sharing one group.
    template
    __device__ void operator()(const Array&    x,  //
                               const Array& pred,
                               const R&              rbits,
                               Array&          q,
                               Array&          d,
                               T&                    scale,
                               T&                    zero,
                               int                   threads) const
    {
        auto f = cast(x);

        float minval = std::numeric_limits::infinity();
        float maxval = -minval;

        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            if (pred[i]) {
                minval = fminf(minval, f[i]);
                maxval = fmaxf(maxval, f[i]);
            }
        }

        // Combine min/max across the `threads` lanes of the group. The full-warp mask means all 32
        // lanes must execute this loop.
        for (int offset = threads / 2; offset >= 1; offset /= 2) {
            minval = fminf(minval, __shfl_xor_sync((uint32_t)-1, minval, offset));
            maxval = fmaxf(maxval, __shfl_xor_sync((uint32_t)-1, maxval, offset));
        }

        auto clamp = [](int x, int a, int b) { return max(a, min(b, x)); };

        // 1e-5 lower bound avoids a zero scale for constant groups
        float scale_ = fmaxf(maxval - minval, 1e-5f) / (float)max_q;
        int   zero_  = clamp(-round(minval / scale_), 0, max_q);

        scale = (T)scale_;
        zero  = (T)zero_;

        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            q[i] = clamp(round(f[i] / scale_) + zero_, 0, max_q);
            // dequantized value is formed in T precision
            d[i] = (T)((int)q[i] - zero_) * (T)scale_;
        }
    }
};

// Floating-point quantizer with a power-of-two scale per group and no zero point (Zero = void).
//   - Codes come from traits::from_f32, with per-element random bits `rbits` passed through.
//   - The scale is stored as an 8-bit biased float32 exponent; scale_f32 is the float whose
//     exponent field equals it.
template
struct FloatingPointQuantizer {

    using T = T_;
    using Q = Q_;

    using Scale = uint8_t;
    using Zero  = void;

    using traits = FloatingPoint;

    static constexpr int bits = traits::bits;

    // Values are multiplied by this before rounding; dequantized values are divided by it afterwards.
    float pre_rounding_scale_;

    __host__ __device__ FloatingPointQuantizer(float pre_rounding_scale = 1.f): pre_rounding_scale_{pre_rounding_scale}
    {
    }

    // Quantizes this lane's N elements. Only elements with pred[i] set contribute to the group
    // absmax. `ignore` is the unused zero-point slot. `threads` is the number of lanes sharing one
    // group.
    template
    __device__ void operator()(const Array&    x,  //
                               const Array& pred,
                               const R&              rbits,
                               Array&          q,
                               Array&          d,
                               Scale&                scale,
                               Z                     ignore,
                               int                   threads) const
    {
        auto f = cast(x);

        float absmax = 0.f;

        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            if (pred[i]) {
                absmax = fmaxf(absmax, fabsf(f[i]));
            }
        }

        // Combine across the `threads` lanes of the group. The full-warp mask means all 32 lanes
        // must execute this loop.
        for (int offset = threads / 2; offset >= 1; offset /= 2) {
            absmax = fmaxf(absmax, __shfl_xor_sync((uint32_t)-1, absmax, offset));
        }

        // biased float32 exponent field (bits 23..30)
        auto get_exponent = [](float x) -> int { return (__float_as_uint(x) >> 23U) & 0xFFU; };

        // With this choice, absmax / scale falls in [2^(exponent_bias + 1), 2^(exponent_bias + 2)).
        int scale_i32 = get_exponent(absmax) - (traits::exponent_bias + 1);

        // NOTE(review): scale_i32 == 0 (including the flush branch below) makes scale_f32 == 0.0f.
        // The division in the loop then yields NaN for flushed (zeroed) groups and +/-inf otherwise.
        // d ends up 0 only if traits::to_f32 returns finite values — review.
        if (scale_i32 < 0) {  // absmax(group) < 2^(traits::exponent_bias - 126): flush to zero
            scale_i32 = 0;
            f         = {};
        }

        scale = scale_i32;

        float scale_f32 = __uint_as_float((uint32_t)scale_i32 << 23U);

        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            q[i] = traits::from_f32((f[i] * pre_rounding_scale_) / scale_f32, rbits[i]);
            d[i] = (traits::to_f32(q[i]) * scale_f32) / pre_rounding_scale_;
        }
    }
};

// Per-group quantization kernel.
//
// Work layout:
//   - blockIdx.x selects row `m`.
//   - Each thread handles vec_size consecutive columns starting at `k`, with stride
//     gridDim.y * blockDim.x * vec_size.
//   - The threads_per_group (= G / vec_size) lanes covering one group of G columns cooperate
//     inside `quantizer`, which reduces across them with warp shuffles.
//   - The first lane of each group then writes the group's scale, and its zero when Z is not void.
//
// Buffers:
//   q, d : quantized and dequantized outputs, addressed with stride_q and stride_d respectively
//   s, z : per-group scale and zero, addressed with stride_s at column k / G
//   x    : input, addressed with stride_x
//   r    : optional random bits, read densely as r[m * K + column] (one word per element)
template
__global__ void QuantizeGroupwise_Kernel(Quantizer       quantizer,
                                         Q*              q,
                                         S*              s,
                                         Z*              z,
                                         T*              d,
                                         const T*        x,
                                         const unsigned* r,
                                         Array   stride_q,
                                         Array   stride_s,
                                         Array   stride_d,
                                         Array   stride_x,
                                         int             M,
                                         int             K,
                                         int             G)
{
    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v)) {
        static constexpr bool has_zero = !std::is_void_v;

        int m = blockIdx.x;
        int k = threadIdx.x + blockIdx.y * blockDim.x;

        const int threads_per_group = G / vec_size;
        const int warp_k            = WARP_SIZE * vec_size;

        k *= vec_size;

        // The bound is rounded up to a whole warp's span. Every lane of a warp therefore keeps
        // iterating together and can join the quantizer's full-mask shuffles; lanes past K run with
        // all predicates false.
        for (; k < round_up(K, warp_k); k += gridDim.y * blockDim.x * vec_size) {

            Array    x_vec;
            Array p_vec;

            Array r_vec;

            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                p_vec[i] = k + i < K;
                x_vec[i] = p_vec[i] ? x[stride_x[0] * m + stride_x[1] * (k + i)] : T{0};
                if (r) {
                    // each element consumes its own random word
                    r_vec[i] = p_vec[i] ? r[m * K + k + i] : 0;
                }
            }

            Array q_vec;
            Array d_vec;

            S                                    scale;
            std::conditional_t zero{};

            auto invoke = [&](auto rbits) {
                quantizer(x_vec, p_vec, rbits, q_vec, d_vec, scale, zero, threads_per_group);
            };

            // without random bits the quantizer receives a value-initialized vector
            r ? invoke(r_vec) : invoke(Array{});

            PRAGMA_UNROLL
            for (int i = 0; i < vec_size; ++i) {
                if (p_vec[i]) {
                    // q and d may have different layouts, so each is addressed with its own strides
                    q[stride_q[0] * m + stride_q[1] * (k + i)] = q_vec[i];
                    d[stride_d[0] * m + stride_d[1] * (k + i)] = d_vec[i];
                }
            }
            // one lane per group stores the group's scale (and zero point)
            if (threadIdx.x % threads_per_group == 0) {
                const auto idx = stride_s[0] * m + stride_s[1] * (k / G);
                if (p_vec[0]) {
                    s[idx] = (S)scale;
                    if constexpr (has_zero) {
                        z[idx] = (Z)zero;
                    }
                }
            }
        }
    }
}

// Group-wise quantization of `src` (m x k) along k, in groups of `group_size` columns, on the
// current context stream.
//
// Supported (src, quant) dtype pairs:
//   - (half, uint4): IntegralQuantizer, producing scales and zeros.
//   - (bf16 or half, fp4 e2m1): FloatingPointQuantizer, producing 8-bit exponent scales and no zeros.
//
// `dequant` receives the dequantized values. `rbits`, when non-empty, supplies random bits that are
// forwarded to the quantizer.
// Quantized codes are first written one per element into a temporary `proxy` tensor. They are then
// bit-packed into `quant` by Compact1D_Kernel.
void QuantizeGroupwise(Tensor            quant,    // (m,k)
                       Tensor            scales,   // (m,k/g)
                       Tensor            zeros,    // (m,k/g)
                       Tensor            dequant,  // (m,k)
                       Tensor            src,      // (m,k)
                       Buffer_ rbits,    // (m*k)
                       int               group_size)
{
    if (zeros) {
        TM_CHECK(scales.layout() == zeros.layout());
    }
    TM_CHECK(quant.shape() == dequant.shape());
    // quant must be compact: Compact1D_Kernel packs it linearly through raw_data()
    TM_CHECK(quant.size() == quant.layout().cosize());

    auto stream = core::Context::stream().handle();

    // (stride along dim 0, stride along dim 1) of a 2-D tensor, as ints for the kernel
    auto stride_2d = [](const Tensor& t) {
        TM_CHECK_EQ(t.ndim(), 2);
        auto [a, b] = t.strides(0, 1);
        return Array{(int)a, (int)b};
    };

    const int m = src.shape(0);
    const int k = src.shape(1);

    auto invoke = [&](auto quantizer) {
        using Quantizer = decltype(quantizer);

        using T = typename Quantizer::T;
        using Q = typename Quantizer::Q;
        using S = typename Quantizer::Scale;
        using Z = typename Quantizer::Zero;

        constexpr int bits = Quantizer::bits;

        // one unpacked code per element, same layout as quant
        Tensor_ proxy = empty_like(quant, data_type_v);

        constexpr int vec = 8;

        // A group must map to a power-of-two number of lanes within one warp, because the quantizer
        // reduces across a group with xor shuffles.
        TM_CHECK((group_size & (group_size - 1)) == 0);
        TM_CHECK_GE(group_size, vec);
        TM_CHECK_LE(group_size, WARP_SIZE * vec);

        // one thread per `vec` columns, capped at 1024 and rounded up to a whole warp
        const int threads = round_up(std::min(cdiv(k, vec), 1024), WARP_SIZE);

        QuantizeGroupwise_Kernel<<>>(quantizer,
                                                                 proxy.data(),
                                                                 scales.data(),
                                                                 zeros.data_or((Z*)nullptr),
                                                                 dequant.data(),
                                                                 src.data(),
                                                                 rbits.data_or(nullptr),
                                                                 stride_2d(proxy),
                                                                 stride_2d(scales),
                                                                 stride_2d(dequant),
                                                                 stride_2d(src),
                                                                 m,
                                                                 k,
                                                                 group_size);

        Compact1D_Kernel<0, bits><<>>(
            (uint32_t*)quant.raw_data(), (Q*)proxy.raw_data(), quant.size());
    };

    if (0) {}
    else if (src.dtype() == kHalf && quant.dtype() == kUint4) {
        invoke(IntegralQuantizer{});
    }
    else if (src.dtype() == kBfloat16 && quant.dtype() == kFloat4_e2m1) {
        invoke(FloatingPointQuantizer{});
    }
    else if (src.dtype() == kHalf && quant.dtype() == kFloat4_e2m1) {
        invoke(FloatingPointQuantizer{});
    }
    else {
        TM_CHECK(0) << "Unsupported types: " << to_string(src.dtype()) << ", " << to_string(quant.dtype());
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/quantization.cuh
================================================

#pragma once

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/common.h"

namespace turbomind {

// Disabled helper, kept for reference.
// NOTE(review): if re-enabled, the first loop appears to assign hmin/hmax results to the whole
// `minmax` array instead of to minmax[0] / minmax[1]. Fix that first.
#if 0
template
__device__ Array find_minmax(const Array& a)
{
    static_assert((threads & (threads - 1)) == 0);
    static_assert(sizeof(Array) == sizeof(uint32_t));
    uint32_t data;
    auto&    minmax = reinterpret_cast&>(data);
    minmax          = {a[0], a[0]};
    PRAGMA_UNROLL
    for (int i = 1; i < N; ++i) {
        minmax = hmin(minmax[0], a[i]);
        minmax = hmax(minmax[1], a[i]);
    }
    PRAGMA_UNROLL
    for (int mask = threads / 2; mask > 0; mask /= 2) {
        uint32_t tmp = __shfl_xor_sync(uint32_t(-1), data, mask);
        auto&    vec = reinterpret_cast&>(tmp);
        minmax[0]    = hmin(minmax[0], vec[0]);
        minmax[1]    = hmax(minmax[1], vec[1]);
    }
    return minmax;
}
#endif

// Returns max |a[i]| over this thread's N elements, combined across `threads` lanes (a power of
// two) with xor shuffles at offsets threads/2, ..., 1.
// The shuffles use the full-warp mask, so all 32 lanes of the warp must call this together.
// The running value lives at the start of a 32-bit word that is what gets shuffled. T is presumably
// a 16-bit type (__habs/__hmax); the static_assert's template arguments are not visible in this copy.
template
__device__ T find_absmax(const Array& a)
{
    static_assert((threads & (threads - 1)) == 0);
    static_assert(sizeof(Array) == sizeof(uint32_t));
    uint32_t data;
    auto&    val = *reinterpret_cast(&data);
    val          = __habs(a[0]);
    PRAGMA_UNROLL
    for (int i = 1; i < N; ++i) {
        val = __hmax(val, __habs(a[i]));
    }
    PRAGMA_UNROLL
    for (int mask = threads / 2; mask > 0; mask /= 2) {
        uint32_t tmp = __shfl_xor_sync(uint32_t(-1), data, mask);
        auto&    x   = *reinterpret_cast(&tmp);
        val          = __hmax(val, x);
    }
    return val;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/quantization.h
================================================
#pragma once

#include "src/turbomind/core/core.h"

namespace turbomind {

// Row-wise symmetric FP8 (e4m3) quantization of a 2-D row-major bf16 `src` (num x dim), with one
// float scale per group of 128 columns. `out` and `scale` are allocated when empty. `scale` is laid
// out as (cdiv(dim, 128), num) with unit stride along `num`.
void QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st);

// Inverse of QuantizeSymm: fp8 e4m3 `src` times the per-group `scale`, written to a bf16 `out`
// (allocated with src's layout when empty).
void DequantizeSymm(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st);

// Block-wise symmetric FP8 (e4m3) quantization of a contiguous 2-D `src`, with one float scale per
// 128 x 128 tile. `scale` is (cdiv(num, 128), cdiv(dim, 128)). `out_` and `scale_` are allocated
// when empty.
void QuantizeSymmBlock(Ref out_, Ref scale_, const Tensor& src, cudaStream_t st);

// Inverse of QuantizeSymmBlock. When `out_` is empty it is allocated as bf16 with `src_`'s layout.
void DequantizeSymmBlock(Ref out_, Ref src_, const Tensor& scale, cudaStream_t st);

// Group-wise quantization of `src` along k in groups of `group_size`, on the current context stream.
//   - half -> uint4: with scales and zeros.
//   - half or bf16 -> fp4 e2m1: with 8-bit exponent scales and no zeros.
// `dequant` receives the round-tripped values. `rbits`, optional, supplies per-element random bits
// for the quantizer.
void QuantizeGroupwise(Tensor            quant,    // (m,k)
                       Tensor            scales,   // (m,k/g)
                       Tensor            zeros,    // (m,k/g)
                       Tensor            dequant,  // (m,k)
                       Tensor            src,      // (m,k)
                       Buffer_ rbits,    // (m*k)
                       int               group_size);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/reduce_kernel_utils.cuh
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include 
#include 
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include 
#else
#include 
#endif
#include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include 
#include 
#include 
#include 
#include 

namespace cg = cooperative_groups;

namespace turbomind {

// Maps a byte count (2, 4, 8 or 16) to a type of exactly that size. copy() uses it to move that
// many bytes with a single load/store.
template
struct BytesToType;

template<>
struct BytesToType<2> {
    using type = uint16_t;
};
template<>
struct BytesToType<4> {
    using type = uint32_t;
};
template<>
struct BytesToType<8> {
    using type = uint64_t;
};
template<>
struct BytesToType<16> {
    using type = float4;
};

// Largest finite value of T, as a device-side constant.
template
__device__ inline T getMaxValue();

template<>
__device__ inline float getMaxValue()
{
    return FLT_MAX;
}

template<>
__device__ inline half getMaxValue()
{
    // 0x7BFF is the largest finite half (65504)
    return __ushort_as_half((unsigned short)0x7BFFU);
}

#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 getMaxValue<__nv_bfloat16>()
{
#if __CUDA_ARCH__ >= 800
    // 0x7F7F is the largest finite bf16
    return __ushort_as_bfloat16((unsigned short)0x7F7FU);
#endif
    // NOTE(review): below SM80 this returns a value-initialized bf16 (presumably 0), not the maximum.
    return {};
}
#endif

// Positive infinity of T, as a device-side constant.
template
__device__ inline T getInfValue();

template<>
__device__ inline float getInfValue()
{
    return INFINITY;
}

template<>
__device__ inline half getInfValue()
{
    // 0x7C00 is +inf in half precision
    return __ushort_as_half((unsigned short)0x7C00U);
}

#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 getInfValue<__nv_bfloat16>()
{
#if __CUDA_ARCH__ >= 800
    // 0x7F80 is +inf in bf16
    return __ushort_as_bfloat16((unsigned short)0x7F80U);
#endif
    // NOTE(review): below SM80 this returns a value-initialized bf16 (presumably 0), not +inf.
    return {};
}
#endif

// Copies `Bytes` bytes from `local` to `data` as one value of BytesToType's mapped type.
// Both pointers presumably must be aligned to that type's size.
template
__device__ inline void copy(const void* local, void* data)
{
    using T = typename BytesToType::type;

    const T* in  = static_cast(local);
    T*       out = static_cast(data);
    *out         = *in;
}

#define HALF_FLT_MAX 65504.F
#define FINAL_MASK 0xffffffff

// Xor-shuffle (butterfly) sum over the 32 lanes of a warp; every lane ends up with the warp total.
// All 32 lanes must participate (FINAL_MASK).
template
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));  // NOTE: __shfl_sync on bf16 returns float when sm < 80
    return val;
}

/* Sum of `val` over all threads of the block.
 * Each warp reduces with warpReduceSum. Lane 0 of warp `wid` stores its partial in shared[wid], so
 * at most 32 warps (1024 threads) are supported. Warp 0 then sums the partials.
 * Only threads of warp 0 receive the block total; the other warps reduce zeros and return 0.
 * NOTE: `shared` is static and there is no trailing __syncthreads(). Back-to-back calls may need a
 * barrier in between. */
template
__inline__ __device__ T blockReduceSum(T val)
{
    static __shared__ T shared[32];
    int                 lane = threadIdx.x & 0x1f;
    int                 wid  = threadIdx.x >> 5;

    val = warpReduceSum(val);

    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Compare against blockDim.x / 32.f (a float) rather than blockDim.x >> 5. This way the partial
    // of a trailing, incomplete warp is included when blockDim.x is not a multiple of 32.
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum(val);

    return val;
}

// Xor-shuffle maximum over the 32 lanes of a warp; every lane ends up with the warp maximum.
template
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
    return val;
}

/* Maximum of `val` over all threads of the block (at most 32 warps).
 * Only threads of warp 0 receive the block maximum; the other warps reduce the -1e20f fallback.
 * Use blockAllReduceMax when every thread needs the result. */
template
__inline__ __device__ T blockReduceMax(T val)
{
    static __shared__ T shared[32];
    int                 lane = threadIdx.x & 0x1f;  // in-warp idx
    int                 wid  = threadIdx.x >> 5;    // warp idx

    val = warpReduceMax(val);  // max within each warp

    if (lane == 0)  // lane 0 records its warp's max at the warp index
        shared[wid] = val;

    __syncthreads();

    // Compare against blockDim.x / 32.f (a float) rather than blockDim.x >> 5, so that a trailing
    // partial warp is included when blockDim.x is not a multiple of 32.
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
    val = warpReduceMax(val);

    return val;
}

/* Maximum of `val` over all threads of the block, returned to every thread (at most 32 warps).
 * Unlike blockReduceMax, the second stage selects by `lane`, so every warp reads all the per-warp
 * partials and reduces them itself. */
template
__inline__ __device__ T blockAllReduceMax(T val)
{
    static __shared__ T shared[32];
    int                 lane = threadIdx.x & 0x1f;  // in-warp idx
    int                 wid  = threadIdx.x >> 5;    // warp idx

    val = warpReduceMax(val);  // max within each warp

    if (lane == 0)  // lane 0 records its warp's max at the warp index
        shared[wid] = val;

    __syncthreads();

    // Compare against blockDim.x / 32.f (a float) rather than blockDim.x >> 5, so that a trailing
    // partial warp is included when blockDim.x is not a multiple of 32.
    val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
    val = warpReduceMax(val);

    return val;
}

// Warp-wide sum of each of the NUM values in val[], done in place: every lane ends up with the
// totals. The return value is always 0; the results are in val[].
template
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
    for (int i = 0; i < NUM; i++) {
#pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1)
            val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
    }
    return (T)(0.0f);
}

// Block-wide sum of each of the NUM values in val[], done in place (at most 32 warps).
// Only threads of warp 0 end up with the block totals. The return value is always 0.
// The inner dimension 33 (one more than the 32 warp slots used) is presumably padding against
// shared-memory bank conflicts.
template
__inline__ __device__ T blockReduceSumV2(T* val)
{
    static __shared__ T shared[NUM][33];
    int                 lane = threadIdx.x & 0x1f;
    int                 wid  = threadIdx.x >> 5;

    warpReduceSumV2(val);

    if (lane == 0) {
#pragma unroll
        for (int i = 0; i < NUM; i++) {
            shared[i][wid] = val[i];
        }
    }

    __syncthreads();

    // float division so a trailing partial warp's partial is included
    bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
    for (int i = 0; i < NUM; i++) {
        val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
    }
    warpReduceSumV2(val);
    return (T)0.0f;
}

// Warp-wide maximum of each of the NUM values in val[], done in place: every lane ends up with the
// maxima. The return value is always 0; the results are in val[].
template
__inline__ __device__ T warpReduceMaxV2(T* val)
{
#pragma unroll
    for (int i = 0; i < NUM; i++) {
#pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1)
            val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
    }
    return (T)(0.0f);
}

// Block-wide maximum of each of the NUM values in val[], done in place (at most 32 warps).
// Only threads of warp 0 end up with the block maxima; the other warps reduce the -1e20f fallback.
// The return value is always 0.
template
__inline__ __device__ T blockReduceMaxV2(T* val)
{
    static __shared__ T shared[32][NUM];
    int                 lane = threadIdx.x & 0x1f;  // in-warp idx
    int                 wid  = threadIdx.x >> 5;    // warp idx

    warpReduceMaxV2(val);  // max within each warp

    if (lane == 0)  // lane 0 records its warp's maxima at the warp index
    {
#pragma unroll
        for (int i = 0; i < NUM; i++) {
            shared[wid][i] = val[i];
        }
    }

    __syncthreads();

    // Compare against blockDim.x / 32.f (a float) rather than blockDim.x >> 5, so that a trailing
    // partial warp is included when blockDim.x is not a multiple of 32.
    bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
    for (int i = 0; i < NUM; i++) {
        val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
    }
    warpReduceMaxV2(val);

    return (T)0.0f;
}

// Block-wide sum of each of the NUM floats in element_list[], using cooperative groups.
//   - `cgBlockReduceSumElements_shm` must hold NUM * blockDim.x floats (indexed i * blockDim.x + tid).
//   - Each 32-thread tile reduces with cg::reduce, and every tile member writes the tile sum.
//   - Thread 0 then adds the entry of the first thread of each tile.
//   - Only thread 0's element_list[] receives the block totals.
// NOTE(review): the stride-32 summation presumably assumes blockDim.x is a multiple of 32.
template
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
{
    cg::thread_block          cta  = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);

    const int tid    = cta.thread_rank();
    const int blockz = blockDim.x;
    for (int i = 0; i < NUM; i++) {
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
        cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus());
#else
        // TODO: add an implementation for CUDA < 11 (cg::reduce is unavailable there)
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
            assert(false);
        }
#endif
    }
    cg::sync(cta);
    if (tid == 0) {
#pragma unroll
        for (int i = 0; i < NUM; i++) {
            float beta = 0.0f;
            for (int j = 0; j < blockz; j += 32) {
                beta += cgBlockReduceSumElements_shm[i * blockz + j];
            }
            element_list[i] = beta;
        }
    }
}

// Per-thread top-MAX_K list.
//   - u[] holds values sorted in descending order; p[] holds the matching ids.
//   - p == -1 marks an empty slot.
//   - Among equal values, the smaller id ranks first.
template
struct TopK {
    int p[MAX_K];
    T   u[MAX_K];

    // Inserts (elem, elem_id) if it beats the current last entry, then bubbles it up into place.
    __device__ __forceinline__ void insert(T elem, int elem_id)
    {
        // replace the last slot if elem is larger, the slot is empty, or it ties with a smaller id
        if (elem > u[MAX_K - 1] || (p[MAX_K - 1] == -1) || ((elem == u[MAX_K - 1]) && (elem_id < p[MAX_K - 1])))
        {
            u[MAX_K - 1] = elem;
            p[MAX_K - 1] = elem_id;
        }

        for (int k = MAX_K - 2; k >= 0; --k) {
            // swap upward past smaller values, empty slots, or equal values with larger ids
            if ((u[k + 1] > u[k]) || (p[k] == -1) || ((u[k + 1] == u[k]) && (p[k + 1] < p[k])))
            {
                T   u2   = u[k];
                int p2   = p[k];
                u[k]     = u[k + 1];
                p[k]     = p[k + 1];
                u[k + 1] = u2;
                p[k + 1] = p2;
            }
        }
    }

    // Marks all slots empty and fills values with the most negative finite value:
    // -HALF_FLT_MAX for half, -FLT_MAX otherwise.
    __device__ __forceinline__ void init()
    {
        const bool IS_FP16   = std::is_same::value;
        const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

        for (int i = 0; i < MAX_K; i++) {
            p[i] = -1;
            u[i] = -MAX_T_VAL;
        }
    }
};

// Binary merge operator for TopK lists (suitable for reductions): starts from a copy of `a` and
// inserts every one of b's MAX_K entries, including empty (-1) ones.
template
__device__ __forceinline__ TopK reduce_topk_op(const TopK& a, const TopK& b)
{
    TopK res = a;
    for (int i = 0; i < MAX_K; ++i)
        res.insert(b.u[i], b.p[i]);
    return res;
}

// Single-element (top-1) tracker: the largest value seen in `u` and its id in `p`.
// Starts from -getInfValue() (presumably -inf for T) with id 0; if nothing larger is ever inserted,
// `p` stays 0 as a placeholder.
template
struct TopK_2 {
    int p = 0;
    T   u = -getInfValue();

    // Keep (elem, elem_id) only if strictly larger, so on ties the earliest insertion wins.
    __device__ __forceinline__ void insert(T elem, int elem_id)
    {
        if (elem > u) {
            u = elem;
            p = elem_id;
        }
    }

    // Reset to the initial state: value -getInfValue(), id 0.
    __device__ __forceinline__ void init()
    {
        u = -getInfValue();
        p = 0;
    }
};

// Reduction operator for TopK_2: returns the entry with the larger value; on ties `b` is returned.
template
__device__ __forceinline__ TopK_2 reduce_topk_op_2(const TopK_2& a, const TopK_2& b)
{
    return a.u > b.u ? a : b;
}

// Generic version: converts `input` to T with no clamping. The half (and, with ENABLE_BF16, bfloat16)
// specializations follow below.
template
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
    return input;
}

template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
    // Clamp to +/-(HALF_FLT_MAX - 1000) before converting, so large magnitudes do not turn into inf
    // in fp16 (this is what enables fp16 training). HALF_FLT_MAX is presumably the largest finite half.
    const float limit = HALF_FLT_MAX - 1000;
    if (input > 0.0f) {
        return (half)min(input, limit);
    }
    return (half)max(input, -limit);
}

#ifdef ENABLE_BF16
// bfloat16 specialization: a plain __float2bfloat16 conversion. Unlike the half version, no clamping
// is applied (bfloat16 has the same exponent width as float32).
template<>
__device__ __forceinline__ __nv_bfloat16 clamp_inf_for_half(const float input)
{
    const __nv_bfloat16 converted = __float2bfloat16(input);
    return converted;
}
#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_kernels.cu
================================================
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include 
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/utils/constant.h"

namespace turbomind {

// Samples one token per batch row from an already-filtered candidate list (one thread block per row).
// Row `batch_id` has n = kept[batch_id] candidates: logits[i] hold their (presumably normalized)
// probabilities and indices[i] their token ids; both rows start at offset stride * batch_id.
// A uniform random number is drawn from curandstate[batch_id], and the first candidate whose inclusive
// prefix sum exceeds it is chosen (falling back to the last candidate if none does). The chosen id is
// written to output_ids[batch_id], and sequence_length[batch_id] is incremented by one.
// If all three sampled_* pointers are non-null, log(logits[i]) and indices[i] of the first
// min(n, kMaxLogProb) candidates are exported as well; when the chosen candidate lies beyond that window
// it replaces the last exported slot.
// NOTE(review): with kept[batch_id] == 0 the scan loop never runs and output_ids is not written —
// confirm n >= 1 is guaranteed upstream.
template
__global__ void sampling(const T*       logits,
                         const int      stride,
                         const int*     indices,
                         const int*     kept,
                         curandState_t* curandstate,
                         int*           output_ids,
                         int*           sequence_length,
                         T*             sampled_logprobs,
                         int*           sampled_indexes,
                         int*           sampled_nums)
{
    int tid      = threadIdx.x;
    int batch_id = blockIdx.x;
    int n        = kept[batch_id];

    logits += stride * batch_id;
    indices += stride * batch_id;

    // Thread 0 draws the random number (advancing this row's RNG state); it is shared so every thread
    // compares against the same threshold.
    __shared__ float rand_num_s;
    __shared__ int   selected;
    if (tid == 0) {
        rand_num_s = curand_uniform(curandstate + batch_id);
    }
    __syncthreads();

    typedef cub::BlockScan  BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;

    // Candidates are scanned BLOCK_SIZE at a time; prefix_op presumably carries the running total
    // across chunks. `end` rounds n up so every thread runs the same number of iterations.
    float                 local_rand = rand_num_s;
    float                 prefix_sum = 0.f;
    BlockPrefixCallbackOp prefix_op{0};
    int                   end = (n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
    for (int i = tid; i < end; i += BLOCK_SIZE) {
        // Padding lanes (i >= n) add 0 so they do not move the prefix sum.
        float thread_logit = (i < n) ? static_cast(logits[i]) : 0.f;
        BlockScan(temp_storage).InclusiveSum(thread_logit, prefix_sum, prefix_op);
        // __syncthreads_count is also a block barrier, which makes temp_storage safe to reuse next
        // iteration. Its result, like the last-chunk test below, is the same for all threads, so the
        // whole block breaks together.
        auto count = __syncthreads_count(prefix_sum > local_rand);
        if (count != 0 || (i + BLOCK_SIZE) >= end) {
            // With non-decreasing prefix sums the `count` threads above the threshold are the highest
            // tids, so the first of them is BLOCK_SIZE - count. With count == 0 (last chunk only) the
            // last thread picks, and min(i, n - 1) keeps the choice inside the candidate list.
            if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {
                selected             = min(i, n - 1);
                output_ids[batch_id] = indices[selected];
            }
            break;
        }
    }

    if (tid == 0) {
        sequence_length[batch_id] += 1;
    }

    // The condition depends only on kernel arguments, so the whole block takes the same branch.
    if (sampled_logprobs != nullptr && sampled_indexes != nullptr && sampled_nums != nullptr) {
        // Publish `selected`, written by a single thread above, to the whole block.
        __syncthreads();
        sampled_logprobs += batch_id * kMaxLogProb;
        sampled_indexes += batch_id * kMaxLogProb;
        int end = min(n, kMaxLogProb);
        for (int i = tid; i < end; i += BLOCK_SIZE) {
            sampled_logprobs[i] = logf(logits[i]);
            sampled_indexes[i]  = indices[i];
        }
        if (n > kMaxLogProb && selected >= kMaxLogProb) {
            // Only tid == (kMaxLogProb - 1) % BLOCK_SIZE passes this test: the same thread that wrote slot
            // kMaxLogProb - 1 in the loop above, so this overwrite cannot race with that write.
            if ((kMaxLogProb - 1 + BLOCK_SIZE - tid) % BLOCK_SIZE == 0) {
                sampled_logprobs[kMaxLogProb - 1] = logf(logits[selected]);
                sampled_indexes[kMaxLogProb - 1]  = indices[selected];
            }
        }
        // Every thread stores the same value here.
        sampled_nums[batch_id] = min(n, kMaxLogProb);
    }
}

// Host launcher for the `sampling` kernel: one 256-thread block per batch row. params.logits and
// params.sampled_logprobs are reinterpreted as T*.
// NOTE(review): params.batch_size (size_t) is narrowed to int for the grid size.
template
void invokeSampling(SamplingParams& params, cudaStream_t stream)
{
    const int grid  = params.batch_size;
    const int block = 256;
    sampling<<>>((T*)params.logits,
                                                   params.stride,
                                                   params.indices,
                                                   params.kept,
                                                   params.curandstate,
                                                   params.output_ids,
                                                   params.sequence_length,
                                                   (T*)params.sampled_logprobs,
                                                   params.sampled_indexes,
                                                   params.sampled_nums);
}

// Explicit instantiation of invokeSampling.
template void invokeSampling(SamplingParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_kernels.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 

#include 
#include 

namespace turbomind {

// Arguments for invokeSampling. Field meanings below follow how the `sampling` kernel in
// sampling_kernels.cu consumes them.
struct SamplingParams {
    void*          logits;            // per-candidate values, cast to T*; row b starts at b * stride
    int            stride;            // row stride (in elements) of `logits` and `indices`
    int*           indices;           // token id of each candidate; row b starts at b * stride
    int*           kept;              // number of valid candidates in each row
    curandState_t* curandstate;       // one RNG state per row, advanced by each draw
    size_t         batch_size;        // number of rows (one thread block each)
    int*           output_ids;        // receives the sampled token id of each row
    int*           sequence_length;   // incremented by one for each row
    void*          sampled_logprobs;  // optional, cast to T*; kMaxLogProb log-values per row
    int*           sampled_indexes;   // optional; kMaxLogProb token ids per row
    int*           sampled_nums;      // optional; number of exported entries per row
};

// Samples one token per row from the filtered candidates described by `params` (see SamplingParams).
template
void invokeSampling(SamplingParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_penalty_kernels.cu
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 
#include 

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/sampling_penalty_kernels.h"

namespace turbomind {

// In place, multiplies each logit of row `blockIdx.y` by scale = 1 / (temperatures[row] + 1e-6) and sets
// the padded tail [vocab_size, vocab_size_padded) to -getInfValue() (presumably -inf).
// x-threads cover the row vec_size elements at a time with a grid-stride loop.
// `bias` and `batch_size` are accepted but not used.
// Assumes vocab_size_padded is a multiple of vec_size (the launcher picks vec_size that way) and,
// presumably, that the row base pointer is suitably aligned for the vectorized Load/Store.
template
__global__ void batchApplyTemperaturePenalty_v2(T*           logits,
                                                const T*     bias,
                                                const float* temperatures,
                                                const int    batch_size,
                                                const int    vocab_size,
                                                const int    vocab_size_padded)
{
    const int vi = blockIdx.x * blockDim.x + threadIdx.x;
    const int bi = blockIdx.y;

    __shared__ float shared_scale;

    // Computed once per block; fdividef is CUDA's fast (approximate) float division. The 1e-6 avoids
    // dividing by a zero temperature.
    if (threadIdx.x == 0) {
        shared_scale = fdividef(1.f, temperatures[bi] + 1e-6f);
    }

    __syncthreads();

    const float scale = shared_scale;

    logits += (size_t)bi * vocab_size_padded;

    const int step = gridDim.x * blockDim.x * vec_size;

    for (int i = vi * vec_size; i < vocab_size_padded; i += step) {
        Array vec;
        // load: vectorized when the vector spans at least 4 bytes, element-wise otherwise
        if constexpr (sizeof(vec) >= sizeof(uint)) {
            Load(vec, logits + i);
        }
        else {
            PRAGMA_UNROLL
            for (int j = 0; j < vec_size; ++j) {
                vec[j] = logits[i + j];
            }
        }

        // process: scale real vocabulary entries, mask the padding
        PRAGMA_UNROLL
        for (int c = 0; c < vec_size; ++c) {
            if (i + c < vocab_size) {
                vec[c] = (float)vec[c] * scale;
            }
            else {
                vec[c] = -getInfValue();
            }
        }

        // store
        if constexpr (sizeof(vec) >= sizeof(uint)) {
            Store(logits + i, vec);
        }
        else {
            PRAGMA_UNROLL
            for (int j = 0; j < vec_size; ++j) {
                logits[i + j] = vec[j];
            }
        }
    }
}

// Host launcher for batchApplyTemperaturePenalty_v2. It uses 256 threads per block and a grid of
// (ceil(vocab_size_padded / (256 * vec_size)), batch_size). The vector width is chosen from the
// divisibility of vocab_size_padded by 4, then 2, else the scalar path (presumably vec_size 4 / 2 / 1).
// NOTE(review): batch_size == 0 or vocab_size_padded == 0 produces a zero-sized grid, which CUDA rejects
// as an invalid launch configuration.
template
void invokeBatchApplyTemperaturePenalty_v2(T*           logits,
                                           const T*     bias,
                                           const float* temperatures,
                                           const int    batch_size,
                                           const int    vocab_size,
                                           const int    vocab_size_padded,
                                           cudaStream_t stream)
{

    auto invoke = [&](auto vec_size) {
        constexpr int threads        = 256;
        const int     blocks_per_tok = (vocab_size_padded + threads * vec_size - 1) / (threads * vec_size);
        const dim3    blocks(blocks_per_tok, batch_size);
        batchApplyTemperaturePenalty_v2<<>>(  //
            logits,
            bias,
            temperatures,
            batch_size,
            vocab_size,
            vocab_size_padded);
    };

    if (vocab_size_padded % 4 == 0) {
        invoke(std::integral_constant{});
    }
    else if (vocab_size_padded % 2 == 0) {
        invoke(std::integral_constant{});
    }
    else {
        invoke(std::integral_constant{});
    }
}

// Emits an explicit instantiation of invokeBatchApplyTemperaturePenalty_v2 for element type T;
// only float is instantiated.
#define INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(T)                                                       \
    template void invokeBatchApplyTemperaturePenalty_v2(T*           logits,                                           \
                                                        const T*     bias,                                             \
                                                        const float* temperatures,                                     \
                                                        const int    batch_size,                                       \
                                                        const int    vocab_size,                                       \
                                                        const int    vocab_size_padded,                                \
                                                        cudaStream_t stream);

INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(float);

// Applies a repetition penalty in place, one thread block per batch row `bi`.
// Each token id found in the first sequence_length[bi] entries of token_ids_ptrs[bi] is marked in a
// shared-memory bitmask (one bit per vocabulary entry, mask_size 32-bit words). Every marked logit is
// then multiplied by penalties[bi] when negative and divided by it otherwise.
// Token ids outside [0, vocab_size) are ignored. Requires mask_size >= ceil(vocab_size / 32) words of
// dynamic shared memory.
template<class T>
__global__ void RepetitionPenaltyKernel(T*                logits,
                                        const float*      penalties,
                                        const int* const* token_ids_ptrs,
                                        const int*        sequence_length,
                                        int               vocab_size,
                                        int               mask_size)
{
    const int bi = blockIdx.x;

    const int  seq_len   = sequence_length[bi];
    const int* token_ids = token_ids_ptrs[bi];

    extern __shared__ uint32_t masks[];  // up to 512k vocab size on 64k smem devices

    for (int i = threadIdx.x; i < mask_size; i += blockDim.x) {
        masks[i] = 0;
    }

    __syncthreads();

    for (int ti = threadIdx.x; ti < seq_len; ti += blockDim.x) {
        const int token_id = token_ids[ti];
        // An id outside the vocabulary has no logit to penalize. Indexing the bitmask with it would write
        // past the shared-memory buffer, and a negative id would also shift by a negative amount (UB).
        if (token_id < 0 || token_id >= vocab_size) {
            continue;
        }
        atomicOr(&masks[token_id / 32], 1U << (token_id % 32));
    }

    __syncthreads();

    logits += bi * (int64_t)vocab_size;

    const float penalty = penalties[bi];

    for (int di = threadIdx.x; di < vocab_size; di += blockDim.x) {
        if (masks[di / 32] & (1U << (di % 32))) {
            const float logit = logits[di];
            logits[di]        = logit < 0.f ? logit * penalty : logit / penalty;
        }
    }
}

// Host entry point for RepetitionPenaltyKernel, operating in place on 2-D `logits` of shape
// (bsz, vocab_size). For each row b, logits of token ids that occur in the first sequence_length[b]
// entries of token_ids_ptrs[b] are multiplied by penalties[b] when negative and divided by it otherwise.
// Only the float kernel is dispatched.
// NOTE(review): the dtype of `logits` is not checked against float here — confirm callers guarantee it.
void ApplyRepetitionPenalty(Tensor&               logits,
                            const Buffer_& penalties,
                            const Buffer_&  token_ids_ptrs,
                            const Buffer_&   sequence_length,
                            cudaStream_t          stream)
{
    TM_CHECK_EQ(logits.ndim(), 2);
    auto invoke = [&](auto dtype) {
        using T                      = decltype(dtype);
        const auto [bsz, vocab_size] = logits.shapes(0, 1);
        // One bit per vocabulary entry, packed into 32-bit words of dynamic shared memory.
        const int mask_size          = cdiv((int)vocab_size, 32);
        const int smem_size          = sizeof(uint32_t) * mask_size;
        auto      func               = RepetitionPenaltyKernel;
        // Beyond 48 KiB of dynamic shared memory the kernel must be opted in explicitly.
        if (smem_size > (48 << 10)) {
            TM_CHECK_EQ(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size), 0);
        }
        TM_LOG_DEBUG("smem_size = %d", smem_size);
        func<<>>(
            logits.data(), penalties.data(), token_ids_ptrs.data(), sequence_length.data(), vocab_size, mask_size);
    };
    invoke(float{});
}

// One thread per (batch row, end id) pair: while sequence_lengths[bid] + 1 < min_lengths[bid], the
// logit of each end id of row `bid` is overwritten with -getMaxValue() (presumably the most negative
// finite T) so that the sequence is very unlikely to stop early.
// Only end ids > 0 are masked, so ids <= 0 are skipped (presumably padding).
// NOTE(review): this also skips a real end id of 0, and end_id is not bounds-checked against
// vocab_size_padded — confirm both against callers.
template
__global__ void batchApplyMinLengthPenalty(T* __restrict__ logits,
                                           const int* __restrict__ min_lengths,
                                           const int* __restrict__ sequence_lengths,
                                           const int vocab_size_padded,
                                           const int batch_size,
                                           const int* __restrict__ end_ids,
                                           const int end_ids_size)
{
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int bid = tid / end_ids_size;  // batch row
    int eid = tid % end_ids_size;  // index into that row's end ids
    if (bid < batch_size) {
        int end_id = end_ids[bid * end_ids_size + eid];
        if (end_id > 0 && sequence_lengths[bid] + 1 < min_lengths[bid]) {
            T mask_val                               = -getMaxValue();
            logits[bid * vocab_size_padded + end_id] = mask_val;
        }
    }
}

// Host launcher for batchApplyMinLengthPenalty. It uses one thread per (batch row, end id) pair, with
// at most 1024 threads per block. Does nothing when there are no rows or no end ids.
template<typename T>
void invokeMinLengthPenalty(T*           logits,
                            const int*   min_lengths,
                            const int*   sequence_lengths,
                            const int    vocab_size_padded,
                            const int    batch_size,
                            const int*   end_ids,
                            const int    end_ids_size,
                            cudaStream_t stream)
{
    // With no (row, end id) pairs the block size below would be 0, so the grid computation would divide by
    // zero on the host; a zero-sized launch is also an invalid configuration.
    if (batch_size <= 0 || end_ids_size <= 0) {
        return;
    }
    const int  num_pairs = batch_size * end_ids_size;
    const dim3 block(std::min(num_pairs, 1024));
    const dim3 grid((num_pairs + block.x - 1) / block.x);
    batchApplyMinLengthPenalty<<<grid, block, 0, stream>>>(
        logits, min_lengths, sequence_lengths, vocab_size_padded, batch_size, end_ids, end_ids_size);
}

// Emits an explicit instantiation of invokeMinLengthPenalty for element type T; only float is
// instantiated.
#define INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(T)                                                                       \
    template void invokeMinLengthPenalty(T*           logits,                                                          \
                                         const int*   min_lengths,                                                     \
                                         const int*   sequnece_lengths,                                                \
                                         const int    vocab_size_padded,                                               \
                                         const int    batch_size,                                                      \
                                         const int*   end_ids,                                                         \
                                         const int    end_ids_size,                                                    \
                                         cudaStream_t stream);

INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(float);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_penalty_kernels.h
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include 

#include "src/turbomind/utils/cuda_utils.h"

#include "src/turbomind/core/core.h"

namespace turbomind {

// Applies the repetition penalty in place to 2-D `logits` (rows x vocab). For each row b, logits of
// token ids found in the first sequence_length[b] entries of token_ids_ptrs[b] are multiplied by
// penalties[b] when negative and divided by it otherwise.
void ApplyRepetitionPenalty(Tensor&               logits,
                            const Buffer_& penalties,
                            const Buffer_&  token_ids_ptrs,
                            const Buffer_&   sequence_length,
                            cudaStream_t          stream);

template
void invokeBatchApplyTemperaturePenalty_v2(T*           logits,
                                           const T*     bias,
                                           const float* temperatures,
                                           const int    batch_size,
                                           const int    vocab_size,
                                           const int    vocab_size_padd,
                                           cudaStream_t stream);

template
void invokeMinLengthPenalty(T*           logits,
                            const int*   min_lengths,
                            const int*   sequnece_lengths,
                            const int    vocab_size_padded,
                            const int    batch_size,
                            const int*   end_ids,
                            const int    end_ids_size,
                            cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_topk_kernels.cu
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include 
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "src/turbomind/core/core.h"

#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/sampling_topk_kernels.h"

#include "src/turbomind/utils/constant.h"

namespace turbomind {

// __global__ void curandInitialize(curandState_t* state, const int size, const unsigned long long random_seed)
// {
//     if (threadIdx.x + blockIdx.x * blockDim.x < size) {
//         curand_init(random_seed, 0, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]);
//     }
// }

// void invokeCurandInitialize(curandState_t*           state,
//                             const size_t             batch_size,
//                             const unsigned long long random_seed,
//                             cudaStream_t             stream)
// {
//     dim3 block(256);
//     dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
//     curandInitialize<<>>(state, batch_size, random_seed);
// }

// __global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* random_seeds)
// {
//     int idx = threadIdx.x + blockIdx.x * blockDim.x;
//     if (idx < size) {
//         curand_init(random_seeds[idx], 0, 0, &states[idx]);
//     }
// }

// void invokeCurandBatchInitialize(curandState_t*  states,
//                                  const size_t    batch_size,
//                                  const uint64_t* random_seeds,
//                                  cudaStream_t    stream)
// {
//     dim3 block(256);
//     dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
//     static_assert(sizeof(uint64_t) == sizeof(unsigned long long));
//     curandBatchInitialize<<>>(states, batch_size, (unsigned long long*)random_seeds);
// }

// Seeds one curand state per batch slot. The thread for slot `slot` initializes states[slot] from
// random_seeds[slot], but only when the slot is in range and mask[slot] is set; every other state is
// left untouched. Subsequence and offset are both 0, so each state depends on its seed alone.
__global__ void InitializeRandomStates_Kernel(curandState_t*            states,
                                              const unsigned long long* random_seeds,
                                              const bool*               mask,
                                              const size_t              size)
{
    const size_t slot = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (slot >= size || !mask[slot]) {
        return;
    }
    curand_init(random_seeds[slot], 0, 0, &states[slot]);
}

// Enqueues seeding of states[i] from random_seeds[i] for every i < batch_size with mask[i] set (one
// thread per slot, 128 threads per block). Does nothing for an empty batch.
void InitializeRandomStates(
    curandState_t* states, const uint64_t* random_seeds, const bool* mask, size_t batch_size, cudaStream_t stream)
{
    // The kernel takes unsigned long long seeds; the seed array is reinterpreted, so sizes must match.
    static_assert(sizeof(uint64_t) == sizeof(unsigned long long));

    // A zero-block grid is an invalid launch configuration, and there is nothing to seed anyway.
    if (batch_size == 0) {
        return;
    }

    constexpr int threads = 128;
    const int     blocks  = (batch_size + threads - 1) / threads;

    InitializeRandomStates_Kernel<<<blocks, threads, 0, stream>>>(
        states, (const unsigned long long*)random_seeds, mask, batch_size);
}

// Top-k stage 1: BLOCKS_PER_BEAM blocks per batch row. Each block scans a strided slice of the row's
// first vocab_size logits and runs k = top_ks[row] rounds of block-wide argmax (TopK_2 + cub
// BlockReduce). Each round's winner (value, id) is appended to this block's section of the temp
// buffers, at offset row * BLOCKS_PER_BEAM * max_top_k + block_lane * k, so a row's candidates are
// contiguous in [0, k * BLOCKS_PER_BEAM). Rows with k == 0 are skipped.
// Side effect: every selected logit is overwritten in place with -MAX_T_VAL so that the next round
// skips it, which makes the input logits unusable afterwards. If a slice runs out of values above -inf,
// the recorded entries are (-inf, placeholder id).
// Assumes top_ks[row] <= max_top_k.
// NOTE(review): a logit that already equals -MAX_T_VAL can be picked again in later rounds, because
// overwriting it does not change its value.
template
__global__ void topKSortStage1(T*         logits,
                               int*       topk_tmp_id_buf,
                               T*         topk_tmp_val_buf,
                               const int  max_top_k,
                               const int* top_ks,
                               const int  vocab_size,
                               const int  vocab_size_padded)
{
    typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage    temp_storage;

    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    const int block_lane = bid % BLOCKS_PER_BEAM;  // block id for a beam
    const int batch_id   = bid / BLOCKS_PER_BEAM;  // row id for log_probs
    const int k          = top_ks[batch_id];
    if (k == 0) {
        return;
    }

    logits += batch_id * vocab_size_padded;
    topk_tmp_id_buf += batch_id * BLOCKS_PER_BEAM * max_top_k + block_lane * k;
    topk_tmp_val_buf += batch_id * BLOCKS_PER_BEAM * max_top_k + block_lane * k;

    TopK_2 partial;
    const T   MAX_T_VAL = getMaxValue();

    for (int ite = 0; ite < k; ite++) {
        partial.init();
#pragma unroll
        for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size;
             elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) {
            partial.insert(logits[elem_id], elem_id);
        }

        TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2);

        if (tid == 0) {
            topk_tmp_id_buf[ite]  = total.p;
            topk_tmp_val_buf[ite] = total.u;
            if (total.u != -getInfValue()) {
                logits[total.p] = -MAX_T_VAL;
            }
        }
        // Makes thread 0's logit overwrite visible and frees temp_storage before the next round.
        __syncthreads();
    }
}

// Top-k stage 2: one block per batch row. It merges the k * BLOCKS_PER_BEAM candidates written by
// stage 1 through k rounds of block-wide argmax, so the output is in descending order. The results are
// softmax-normalized over the selected k: sorted_logits[i] = exp(u_i - u_max) / sum_j exp(u_j - u_max),
// and sorted_indices[i] holds the matching token id. Output rows are vocab_size_padded apart.
// kept[row] is reduced to at most k. Rows with k == 0 are skipped.
// Dynamic shared memory must hold k ints (s_id) followed by k floats (s_val2).
// Note: despite its name, s_val points into the global topk_tmp_val_buf, which is overwritten in place
// (selected entries are set to -MAX_T_VAL).
template
__global__ void topKSortStage2(const int* top_ks,
                               const int  max_top_k,
                               const int* topk_tmp_id_buf,
                               T*         topk_tmp_val_buf,
                               const int  vocab_size_padded,
                               T*         sorted_logits,
                               int*       sorted_indices,
                               int*       kept)
{
    const T MAX_T_VAL = getMaxValue();

    const int tid      = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int k        = top_ks[batch_id];

    if (k == 0) {
        return;
    }

    sorted_indices += batch_id * vocab_size_padded;
    sorted_logits += batch_id * vocab_size_padded;
    const int size   = k * BLOCKS_PER_BEAM;          // valid candidates of this row
    const int stride = max_top_k * BLOCKS_PER_BEAM;  // row pitch of the stage-1 temp buffers

    typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage        temp_storage;
    extern __shared__ char                              array[];
    __shared__ float                                    s_sum;
    __shared__ float                                    s_max;
    T*                                                  s_val  = topk_tmp_val_buf + batch_id * stride;
    int*                                                s_id   = reinterpret_cast(array);
    float*                                              s_val2 = reinterpret_cast(s_id + k);

    if (tid == 0) {
        kept[batch_id] = min(kept[batch_id], k);
        s_sum          = 0.0f;
    }

    TopK_2 partial;
    for (int ite = 0; ite < k; ite++) {
        partial.init();
#pragma unroll
        for (int i = tid; i < size; i += BLOCK_SIZE) {
            partial.insert((float)s_val[i], i);
        }

        TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2);

        if (tid == 0) {
            // The first winner is the row maximum; it is subtracted before exponentiating.
            if (ite == 0) {
                s_max = total.u;
            }
            s_id[ite]      = total.p;
            s_val[total.p] = -MAX_T_VAL;
            total.u        = __expf(total.u - s_max);
            s_val2[ite]    = total.u;
            s_sum += total.u;
        }
        __syncthreads();
    }

    // norm selected: the final __syncthreads above makes s_sum, s_id and s_val2 visible to all threads
    float thread_sum = s_sum;
    topk_tmp_id_buf += batch_id * stride;
    for (int i = tid; i < k; i += BLOCK_SIZE) {
        sorted_logits[i]  = (T)(s_val2[i] / thread_sum);
        sorted_indices[i] = topk_tmp_id_buf[s_id[i]];
    }
}

// Expands to the two top-k launches used by invokeTopKSortFilter: topKSortStage1 (presumably
// BLOCK_SIZE_1 threads and BLOCKS_PER_BEAM blocks per row), then topKSortStage2 (presumably
// BLOCK_SIZE_2 threads and one block per row). It relies on `params`, `max_top_k`,
// `topk_tmp_ids_buf`, `topk_tmp_val_buf` and `T` from the enclosing scope.
// BLOCKS_PER_BEAM must not exceed the max_block_per_beam used to size the temp buffers.
#define CASE_K(K_MAX, BLOCK_SIZE_1, BLOCK_SIZE_2, BLOCKS_PER_BEAM)                                                     \
    topKSortStage1                                                                   \
        <<>>((T*)params.logits,                                 \
                                                                    topk_tmp_ids_buf,                                  \
                                                                    topk_tmp_val_buf,                                  \
                                                                    max_top_k,                                         \
                                                                    params.top_ks,                                     \
                                                                    params.vocab_size,                                 \
                                                                    params.vocab_size_padded);                         \
    topKSortStage2                                                                   \
        <<>>(params.top_ks,             \
                                                                                            params.max_top_k,          \
                                                                                            topk_tmp_ids_buf,          \
                                                                                            topk_tmp_val_buf,          \
                                                                                            params.vocab_size_padded,  \
                                                                                            (T*)params.sorted_logits,  \
                                                                                            params.sorted_indices,     \
                                                                                            params.kept);

// Top-k filter: per row, writes the top_ks[row] best logits (softmax-normalized over those k, in
// descending order) and their token ids to params.sorted_logits / params.sorted_indices, and caps
// params.kept. params.logits is modified in place by stage 1.
// Temporary candidate buffers hold batch_size * max_top_k * max_block_per_beam entries (ids and values)
// and are presumably released when they go out of scope.
// Throws std::domain_error when max_top_k > 1024; the message states 1 <= k, but a max_top_k of 0
// (or below) is not rejected here.
// Assumes every params.top_ks[row] <= params.max_top_k.
template
void invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream)
{
    const int max_top_k             = params.max_top_k;
    const int batch_size            = params.batch_size;
    const int max_block_per_beam    = 8;  // must cover the BLOCKS_PER_BEAM passed to CASE_K below
    int       topk_tmp_ids_buf_size = batch_size * max_top_k * max_block_per_beam;  // type int
    int       topk_tmp_val_buf_size = batch_size * max_top_k * max_block_per_beam;  // type T

    // The caller's stream must be the current context stream (presumably the one the temp buffers below
    // are allocated on).
    TM_CHECK(core::Context::stream().handle() == stream);

    Buffer_ topk_tmp_ids(round_up(topk_tmp_ids_buf_size, 32), kDEVICE);
    Buffer_   topk_tmp_val(round_up(topk_tmp_val_buf_size, 32), kDEVICE);

    auto topk_tmp_ids_buf = topk_tmp_ids.data();
    auto topk_tmp_val_buf = topk_tmp_val.data();

    // CASE_K(K_MAX, stage-1 block size, stage-2 block size, blocks per row)
    if (max_top_k <= 16) {
        CASE_K(16, 128, 128, 8);
    }
    else if (max_top_k <= 32) {
        CASE_K(32, 256, 128, 8);
    }
    else if (max_top_k <= 64) {
        CASE_K(64, 256, 256, 8);
    }
    else if (max_top_k <= 1024) {
        CASE_K(1024, 256, 256, 8);
    }
    else {
        throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
    }
}

template void invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_topk_kernels.h
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "src/turbomind/utils/logger.h"
#include 
namespace turbomind {

// Batched top-k sampling entry point (declaration only).
// NOTE(review): sampling_topk_kernels.cu defines only InitializeRandomStates and invokeTopKSortFilter,
// not this function — confirm whether a definition exists elsewhere; otherwise this declaration is stale
// and can be removed.
template
void invokeBatchTopKSampling(void*          workspace,
                             size_t&        workspace_size,
                             const T*       log_probs,
                             int*           ids,
                             int*           sequence_length,
                             bool*          finished,
                             float*         cum_log_probs,
                             float*         output_log_probs,
                             float*         sampled_logprobs,
                             uint32_t*      sampled_indexes,
                             uint32_t*      sampled_nums,
                             curandState_t* curandstate,
                             const int      max_top_k,
                             const int*     top_ks,
                             const float    top_p,
                             const float*   top_ps,
                             const int      vocab_size_padded,
                             const int*     end_ids,
                             cudaStream_t   stream,
                             const int      batch_size,
                             const bool*    skip_decode);

// void invokeCurandInitialize(curandState_t*     state,
//                             const size_t       batch_size,
//                             unsigned long long random_seed,
//                             cudaStream_t       stream);

// void invokeCurandBatchInitialize(curandState_t*  states,
//                                  const size_t    batch_size,
//                                  const uint64_t* random_seeds,
//                                  cudaStream_t    stream);

// Presumably seeds the first `batch_size` cuRAND states in `states` from `random_seeds`
// (declaration only; the definition is not in this header).
// NOTE(review): `mask` presumably selects which entries are (re)initialized — confirm against
// the implementation.
void InitializeRandomStates(curandState_t*  states,  //
                            const uint64_t* random_seeds,
                            const bool*     mask,
                            size_t          batch_size,
                            cudaStream_t    stream);

// Arguments for invokeTopKSortFilter.
// NOTE(review): the per-field notes below are inferred from field names and from the analogous
// TopPSortParams used by the top-p path — confirm against the implementation.
struct TopKSortFilterParams {
    void* logits;             // presumably input rows, type-erased (read as T* by the implementation)
    void* sorted_logits;      // presumably output: the top-k values of each row
    int*  sorted_indices;     // presumably output: vocabulary indices of those values
    int*  kept;               // presumably per-row number of entries kept
    int*  top_ks;             // per-row k
    int   max_top_k;          // presumably the largest k in the batch; implementation dispatches on it (k <= 1024)
    int   batch_size;
    int   vocab_size;
    int   vocab_size_padded;  // presumably the row stride of the logits buffers
};

template
void invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_topp_kernels.cu
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include 
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "src/turbomind/core/core.h"

#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/sampling_topp_kernels.h"

#include "src/turbomind/utils/constant.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

__global__ void topPSortInitialize(const int    vocab_size_padded,
                                   const int    vocab_size,
                                   const size_t batch_size,
                                   const int*   top_ks,
                                   int*         topp_id_val_buf,
                                   int*         begin_offset_buf,
                                   int*         end_offset_buf)
{
    // Segment offsets for cub::DeviceSegmentedRadixSort, written by block 0 only.
    // Per https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
    // `num_items` must match the largest element of `[d_end_offsets, d_end_offsets + num_segments)`,
    // so a row that must not be sorted is emptied by moving its begin offset, never its end offset.
    if (blockIdx.x == 0) {
        for (int row = threadIdx.x; row < batch_size; row += blockDim.x) {
            const int row_begin = row * vocab_size_padded;
            const int row_end   = row_begin + vocab_size;
            end_offset_buf[row] = row_end;
            // top_ks[row] > 0: the row was already sorted by top-k, so give it an empty interval.
            begin_offset_buf[row] = top_ks[row] > 0 ? row_end : row_begin;
        }
    }

    // Every block: grid-stride over the [batch_size, vocab_size_padded] index buffer and store the
    // column index as the sort payload for rows that will be sorted by top-p (top_ks == 0).
    const int stride = blockDim.x * gridDim.x;
    for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < batch_size * vocab_size_padded; idx += stride) {
        const int row = idx / vocab_size_padded;
        if (top_ks[row] == 0) {
            topp_id_val_buf[idx] = idx % vocab_size_padded;
        }
    }
}

// Host launcher for topPSortInitialize: 512-thread blocks, with enough blocks to cover all
// batch_size * vocab_size_padded index slots (block 0 additionally writes the segment offsets).
// NOTE(review): batch_size == 0 makes grid_size 0, which is not a valid launch configuration.
void invokeTopPSortInitialize(const int    vocab_size_padded,
                              const int    vocab_size,
                              const size_t batch_size,
                              const int*   top_ks,
                              int*         topp_id_val_buf,
                              int*         begin_offset_buf,
                              int*         end_offset_buf,
                              cudaStream_t stream)
{
    const size_t block_size = 512;
    // Ceil-divide so every element of the index buffer is covered by some thread.
    const size_t grid_size  = (batch_size * vocab_size_padded + block_size - 1) / block_size;
    topPSortInitialize<<>>(
        vocab_size_padded, vocab_size, batch_size, top_ks, topp_id_val_buf, begin_offset_buf, end_offset_buf);
}

// In-place, numerically stable softmax over the first `vocab_size` entries of one row per thread
// block (blockIdx.x selects the row; rows are `vocab_size_padded` apart).
// Rows with kept[row] != vocab_size are skipped: they were already processed by the top-k path.
template
static __global__ void softmax(T* logits, const int vocab_size_padded, const int vocab_size, const int* kept)
{
    int bid = blockIdx.x;
    int n   = kept[bid];
    // skip softmax as it was already done by topk (kept[row] == vocab_size marks an untouched row)
    if (n != vocab_size) {
        return;
    }
    logits += bid * vocab_size_padded;

    float            max_val = -1 * FLT_MAX;
    __shared__ float s_max_val;
    __shared__ float s_sum_val;

    // Pass 1: row maximum (subtracted below so __expf cannot overflow).
    for (int tid = threadIdx.x; tid < vocab_size; tid += blockDim.x) {
        max_val = max(max_val, (float)logits[tid]);
    }

    max_val = blockReduceMax((float)max_val);
    // Broadcast the block-wide result through shared memory.
    if (threadIdx.x == 0) {
        s_max_val = max_val;
    }
    __syncthreads();

    // Pass 2: exponentiate in place and accumulate the partition sum.
    max_val       = s_max_val;
    float sum_val = 0.0f;
    for (int tid = threadIdx.x; tid < vocab_size; tid += blockDim.x) {
        logits[tid] = __expf((float)logits[tid] - max_val);
        sum_val += (float)logits[tid];
    }

    sum_val = blockReduceSum(sum_val);
    if (threadIdx.x == 0) {
        s_sum_val = sum_val;
    }
    __syncthreads();

    // Pass 3: normalize.
    sum_val = s_sum_val;
    for (int tid = threadIdx.x; tid < vocab_size; tid += blockDim.x) {
        logits[tid] = ((float)logits[tid] / sum_val);
    }
}

// Host launcher for the row-wise softmax kernel: one block per batch row,
// min(vocab_size_padded, 1024) threads per block.
template
void invokeSoftmax(T*           logits,
                   const int    vocab_size_padded,
                   const int    vocab_size,
                   const int    batch_size,
                   const int*   kept,
                   cudaStream_t stream)
{
    dim3 grid(batch_size);
    dim3 block(std::min(vocab_size_padded, 1024));
    softmax<<>>(logits, vocab_size_padded, vocab_size, kept);
}

// Emits an explicit instantiation of invokeSoftmax for element type T; only float is instantiated below.
#define INSTANTIATE_INVOKE_SOFTMAX(T)                                                                                  \
    template void invokeSoftmax(T * logits,                                                                         \
                                   const int    vocab_size_padded,                                                     \
                                   const int    vocab_size,                                                            \
                                   const int    batch_size,                                                            \
                                   const int*   kept,                                                                  \
                                   cudaStream_t stream);

INSTANTIATE_INVOKE_SOFTMAX(float);

// Fast path for top-p. One block per row (blockIdx.x); rows with top_ks[row] > 0 are skipped.
// The block collects the row's MAX_K best candidates (presumably the largest values, via
// TopK::insert / reduce_topk_op) with a block-wide reduction. If their sum already reaches
// top_ps[row], thread 0 resolves the row directly:
//   - writes the candidates (divided by their sum) and their indices to sorted_logits / sorted_indices,
//   - sets kept[row] = MAX_K,
//   - collapses the row's sort segment (begin = end) so the following segmented radix sort skips it.
// Otherwise nothing is written and the row is left to the full sort.
// The values are summed and compared with a top-p threshold, so they are expected to be
// probabilities (e.g. after invokeSoftmax) — confirm at the call site.
// NOTE(review): the written values are renormalized by sum_prob, so a later top-p prefix scan over
// sorted_logits (as in topPMinPFilter) sees renormalized rather than original probabilities —
// confirm this matches the intended top-p semantics.
template
__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T*     logits,
                                                                          T*           sorted_logits,
                                                                          int*         sorted_indices,
                                                                          int*         kept,
                                                                          const int    vocab_size,
                                                                          const int    vocab_size_padded,
                                                                          int*         begin_offset_buf,
                                                                          int*         end_offset_buf,
                                                                          const float* top_ps,
                                                                          const int*   top_ks)
{
    int thread_id = threadIdx.x;
    int batch_id  = blockIdx.x;
    // Rows handled by top-k are not touched here.
    if (top_ks[batch_id] > 0) {
        return;
    }

    logits += batch_id * vocab_size_padded;
    sorted_logits += batch_id * vocab_size_padded;
    sorted_indices += batch_id * vocab_size_padded;
    float p_threshold = top_ps[batch_id];

    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage               temp_storage;
    TopK                                             partial;

    const T MAX_T_VAL = getMaxValue();

    // Start every slot empty (index -1, value -max).
#pragma unroll
    for (int i = 0; i < MAX_K; ++i) {
        partial.p[i] = -1;
        partial.u[i] = -MAX_T_VAL;
    }

    // Each thread scans a strided slice of the row into its private candidate list.
#pragma unroll
    for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
        partial.insert(logits[elem_id], elem_id);
    }

    // Merge the per-thread lists; the merged result is only valid in thread 0.
    TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op);

    if (thread_id == 0) {
        float sum_prob = 0.f;

#pragma unroll
        for (int i = 0; i < MAX_K; i++) {
            sum_prob += (float)total.u[i];
        }

        if (sum_prob >= p_threshold) {
            // Empty segment: the segmented radix sort will leave this row alone.
            begin_offset_buf[batch_id] = end_offset_buf[batch_id];
            kept[batch_id]             = MAX_K;

#pragma unroll
            for (int i = 0; i < MAX_K; ++i) {
                sorted_logits[i]  = (float)total.u[i] / sum_prob;
                sorted_indices[i] = total.p[i];
            }
        }
    }
}

// Sorts each row of params.logits in descending order (keys = values, payload = vocabulary
// indices) for top-p sampling, writing params.sorted_logits / params.sorted_indices.
//
// - Rows with top_ks > 0 are skipped: invokeTopPSortInitialize gives them an empty sort segment.
// - Rows whose MAX_K largest values already reach top_ps are resolved by topp_beam_topk_kernel,
//   which writes their output directly and also empties their segment.
// - Only the remaining rows go through the CUB segmented radix sort.
template
void invokeTopPSort(TopPSortParams& params, cudaStream_t stream)
{
    // CUB requires num_items to match the largest end offset, i.e. the end of the last row's valid
    // range (see the note in topPSortInitialize).
    const int num_items = params.vocab_size_padded * (params.batch_size - 1) + params.vocab_size;

    // A call with a null temp-storage pointer only queries the required scratch size.
    size_t cub_temp_storage_size{};
    check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
                                                                        cub_temp_storage_size,
                                                                        (T*)nullptr,
                                                                        (T*)nullptr,
                                                                        (int*)nullptr,
                                                                        (int*)nullptr,
                                                                        num_items,
                                                                        params.batch_size,
                                                                        (int*)nullptr,
                                                                        (int*)nullptr,
                                                                        0,              // begin_bit
                                                                        sizeof(T) * 8,  // end_bit = sizeof(KeyT) * 8
                                                                        stream));       // cudaStream_t

    // NOTE(review): the scratch buffers below are allocated through the current core::Context;
    // presumably that is why the caller's stream must be the context stream — confirm.
    TM_CHECK(core::Context::stream().handle() == stream);

    Buffer_ cub_temp_storage(cub_temp_storage_size, kDEVICE);

    // Per-row vocabulary indices (sort payload) and per-row segment [begin, end) offsets.
    Buffer_ topp_ids(params.batch_size * params.vocab_size_padded, kDEVICE);
    Buffer_ beg_offset(params.batch_size, kDEVICE);
    Buffer_ end_offset(params.batch_size, kDEVICE);

    auto topp_ids_buf   = topp_ids.data();
    auto beg_offset_buf = beg_offset.data();
    auto end_offset_buf = end_offset.data();

    invokeTopPSortInitialize(params.vocab_size_padded,
                             params.vocab_size,
                             params.batch_size,
                             params.top_ks,
                             topp_ids_buf,
                             beg_offset_buf,
                             end_offset_buf,
                             stream);

    // Fast path: resolves rows whose top MAX_K values already cover top_p and empties their segments.
    topp_beam_topk_kernel<<>>((T*)params.logits,
                                                                            (T*)params.sorted_logits,
                                                                            params.sorted_indices,
                                                                            params.kept,
                                                                            params.vocab_size,
                                                                            params.vocab_size_padded,
                                                                            beg_offset_buf,
                                                                            end_offset_buf,
                                                                            params.top_ps,
                                                                            params.top_ks);

    // Full descending sort of every segment that is still non-empty.
    check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage.data(),
                                                                        cub_temp_storage_size,
                                                                        (T*)params.logits,
                                                                        (T*)params.sorted_logits,
                                                                        topp_ids_buf,
                                                                        params.sorted_indices,
                                                                        num_items,
                                                                        params.batch_size,
                                                                        beg_offset_buf,
                                                                        end_offset_buf,
                                                                        0,              // begin_bit
                                                                        sizeof(T) * 8,  // end_bit = sizeof(KeyT) * 8
                                                                        stream));       // cudaStream_t
}

template void invokeTopPSort(TopPSortParams& params, cudaStream_t stream);

// Applies top-p and then min-p filtering to one row per thread block (blockIdx.x is the row) and
// renormalizes the surviving prefix in place.
//
// Expected input: each row of sorted_logits (row stride vocab_size_padded) is sorted in descending
// order, and kept[row] is the number of valid leading entries. The running sums are compared with
// top_p, so the values are treated as probabilities.
//
// - top-p (skipped when top_ps[row] == 1): keep the shortest prefix whose inclusive sum exceeds
//   top_p, or all kept[row] entries if none does.
// - min-p (skipped when min_ps[row] == 0): within that prefix, cut at the first entry whose
//   probability (renormalized over the prefix) is below min_p times the largest one. Because the
//   row is sorted, everything after it is below the threshold too.
// - If either filter ran, kept[row] is overwritten with the new count, and the kept entries are
//   divided by their total mass (s_sum).
//
// sorted_indices is offset to the current row but otherwise unused here.
template
__global__ void topPMinPFilter(T*           sorted_logits,
                               int*         sorted_indices,
                               int*         kept,
                               const int    vocab_size_padded,
                               const float* top_ps,
                               const float* min_ps)
{
    int   tid        = threadIdx.x;
    int   bid        = blockIdx.x;
    int   n          = kept[bid];
    float sum_logits = 1.f;
    float top_p      = top_ps[bid];
    float min_p      = min_ps[bid];
    sorted_logits += bid * vocab_size_padded;
    sorted_indices += bid * vocab_size_padded;

    const float kEps = 1e-6f;

    // s_kept: number of entries kept so far; s_sum: their total (unnormalized) mass.
    __shared__ int   s_kept;
    __shared__ float s_sum;

    if (tid == 0) {
        s_kept = n;
        s_sum  = 1.f;
    }
    __syncthreads();

    if (top_p != 1.0f) {
        typedef cub::BlockScan  BlockScan;
        __shared__ typename BlockScan::TempStorage temp_storage;
        // Initialize running total
        BlockPrefixCallbackOp prefix_op(0);
        // topp: scan the row tile by tile (BLOCK_SIZE entries per tile), carrying the prefix across tiles
        int   end        = ((n + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
        float prefix_sum = 0.f;
        for (int i = tid; i < end; i += BLOCK_SIZE) {
            float thread_count = (i < n) ? (float)sorted_logits[i] : 0.f;
            BlockScan(temp_storage).InclusiveSum(thread_count, prefix_sum, prefix_op);
            // Barrier + block-uniform count. Prefix sums are non-decreasing, so the threads above top_p
            // are the last `count` threads of this tile, and all threads take the same branch below.
            auto count = __syncthreads_count(prefix_sum > top_p);
            if (count != 0 || (i + BLOCK_SIZE >= end)) {
                // The first thread past the threshold (or the last thread of the final tile) records the cut.
                if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {
                    s_kept = min(i + 1, n);
                    s_sum  = prefix_sum;
                }
                break;
            }
        };
        __syncthreads();
    }

    if (min_p != 0.f) {
        n          = s_kept;
        sum_logits = s_sum;

        typedef cub::BlockScan  BlockScan;
        __shared__ typename BlockScan::TempStorage temp_storage;
        // Initialize running total
        BlockPrefixCallbackOp prefix_op(0);
        // minp: threshold is min_p times the largest (first) renormalized probability
        float scaled_min_p = (float)sorted_logits[0] / (sum_logits + kEps) * min_p;
        int   end          = ((n + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
        float prefix_sum   = 0.f;
        for (int i = tid; i < end; i += BLOCK_SIZE) {
            float thread_count = (i < n) ? (float)sorted_logits[i] / (sum_logits + kEps) : 0.f;
            BlockScan(temp_storage).ExclusiveSum(thread_count, prefix_sum, prefix_op);
            // Descending order: below-threshold entries (and the zero padding past n) form a suffix of the tile.
            auto count = __syncthreads_count(thread_count < scaled_min_p);
            if (count != 0 || (i + BLOCK_SIZE >= end)) {
                if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {
                    // Final tile with nothing below the threshold: cut after this thread and add its own
                    // mass, which the exclusive scan left out.
                    if (count == 0) {
                        ++i;
                        prefix_sum += thread_count;
                    }
                    s_kept = min(i, n);
                    // prefix_sum is the kept mass relative to s_sum; scale back to the unnormalized total.
                    s_sum *= prefix_sum;
                }
                break;
            }
        };
        __syncthreads();
    }

    if (top_p != 1.f || min_p != 0.f) {
        n          = s_kept;
        sum_logits = s_sum;
        if (tid == 0) {
            kept[bid] = n;
        }
        // norm
        for (int i = tid; i < n; i += BLOCK_SIZE) {
            sorted_logits[i] = (float)sorted_logits[i] / sum_logits;
        }
    }
}

// Host launcher for topPMinPFilter. The kernel treats blockIdx.x as the row index, so the grid is
// expected to provide one block per batch row — confirm against the launch configuration.
// params.vocab_size is not passed to the kernel; rows are addressed via vocab_size_padded and kept[].
template
void invokeTopPMinPFilter(TopPMinPFilterParams& params, cudaStream_t stream)
{
    topPMinPFilter<<>>((T*)params.sorted_logits,
                                                                  params.sorted_indices,
                                                                  params.kept,
                                                                  params.vocab_size_padded,
                                                                  params.top_ps,
                                                                  params.min_ps);
}

template void invokeTopPMinPFilter(TopPMinPFilterParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/sampling_topp_kernels.h
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include 

namespace turbomind {

// Prepares a segmented top-p sort over batch_size rows (row stride vocab_size_padded):
//   - writes begin/end offsets of each row's valid range [row * vocab_size_padded, + vocab_size);
//     rows with top_ks[row] > 0 get an empty segment (begin == end) so the sort leaves them alone;
//   - for rows with top_ks[row] == 0, fills topp_id_val_buf[row][j] = j as the sort payload.
void invokeTopPSortInitialize(const int    vocab_size_padded,
                              const int    vocab_size,
                              const size_t batch_size,
                              const int*   top_ks,
                              int*         topp_id_val_buf,
                              int*         begin_offset_buf,
                              int*         end_offset_buf,
                              cudaStream_t stream);

// In-place softmax over the first vocab_size entries of each of batch_size rows (row stride
// vocab_size_padded). Rows with kept[row] != vocab_size are left untouched.
template
void invokeSoftmax(T*           logits,
                   const int    vocab_size_padded,
                   const int    vocab_size,
                   const int    batch_size,
                   const int*   kept,
                   cudaStream_t stream);

struct BlockPrefixCallbackOp {
    // Sum of every block aggregate consumed so far, i.e. the carry-in for the next tile.
    float running_total;

    // Start the running prefix at `initial_total` (callers in this project pass 0).
    __device__ BlockPrefixCallbackOp(float initial_total): running_total(initial_total) {}

    // Block-prefix callback for cub::BlockScan. It is entered by the first warp of the block with the
    // aggregate of the tile being scanned; the value returned by thread 0 seeds that tile's scan.
    // Returns the current carry-in and advances it by `block_aggregate`.
    __device__ float operator()(float block_aggregate)
    {
        const float carry_in = running_total;
        running_total        = carry_in + block_aggregate;
        return carry_in;
    }
};

// Arguments for invokeTopPSort. Per-row buffers are indexed with row stride vocab_size_padded.
struct TopPSortParams {
    void*  logits;             // sort keys, read as T*; summed against top_ps, so expected to be probabilities
    void*  sorted_logits;      // output: keys in descending order, same layout as logits
    int*   sorted_indices;     // output: vocabulary index of each sorted key
    int*   kept;               // set to MAX_K for rows resolved by the top-k fast path (otherwise not written)
    int*   top_ks;             // per-row top-k; rows with top_ks > 0 are skipped (already sorted by top-k)
    float* top_ps;             // per-row top-p threshold
    int    batch_size;
    int    vocab_size;         // number of valid entries per row
    int    vocab_size_padded;  // row stride
};

template
void invokeTopPSort(TopPSortParams& params, cudaStream_t stream);

// Arguments for invokeTopPMinPFilter. Per-row buffers are indexed with row stride vocab_size_padded.
struct TopPMinPFilterParams {
    void*  sorted_logits;      // descending-sorted probabilities, read and renormalized in place as T*
    int*   sorted_indices;     // NOTE(review): passed to the kernel but neither read nor written by it
    int*   kept;               // in: valid entries per row; out: entries kept (updated only if a filter is active)
    float* top_ps;             // per-row top-p; 1.0f disables top-p filtering
    float* min_ps;             // per-row min-p; 0.0f disables min-p filtering
    int    batch_size;
    int    vocab_size;         // NOTE(review): not forwarded to the kernel by invokeTopPMinPFilter
    int    vocab_size_padded;  // row stride
};

template
void invokeTopPMinPFilter(TopPMinPFilterParams& params, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/stop_criteria_kernels.cu
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/stop_criteria_kernels.h"

namespace turbomind {

// Marks finished[b] = true when the last tokens of sequence b equal one of its stop words.
// Grid: x covers stop-word slots (one thread per slot), y (blockIdx.y) selects the sequence.
//
// Per-sequence layout of `stop_words` (2 * stop_words_len ints):
//   [0, stop_words_len)                  tokens of all stop words, concatenated
//   [stop_words_len, 2 * stop_words_len) cumulative end offset of each stop word; negative = unused
// `finished` is only ever set, never cleared.
__global__ void stop_words_criterion_v2(const int** token_ids_ptrs,
                                        const int*  sequence_length,
                                        const int*  stop_words,
                                        bool*       finished,
                                        int         stop_words_len,
                                        int         batch_size)
{
    const int id        = blockIdx.x * blockDim.x + threadIdx.x;
    const int batch_idx = blockIdx.y;

    const int* base_stop_words = stop_words + batch_idx * 2 * stop_words_len;
    const int* base_offsets    = base_stop_words + stop_words_len;

    if (id >= stop_words_len || base_offsets[id] < 0) {
        return;
    }

    const int item_end   = base_offsets[id];
    const int item_start = (id > 0) ? base_offsets[id - 1] : 0;
    const int item_size  = item_end - item_start;

    // An empty (or malformed) stop word must never match. Without this guard the comparison loop below
    // would be skipped and the sequence marked finished unconditionally, and a negative start offset
    // would index before the token table.
    if (item_start < 0 || item_size <= 0) {
        return;
    }

    const int  seq_len   = sequence_length[batch_idx];
    const int* token_ids = token_ids_ptrs[batch_idx];

    /* Enough previously generated tokens to look for a match */
    if (seq_len >= item_size) {
        // token_ids[seq_len - 1] is the last token; compare the stop word back-to-front.
        for (int token_idx = item_size - 1, offset = seq_len - 1; token_idx >= 0; token_idx--, offset--) {
            if (token_ids[offset] != base_stop_words[item_start + token_idx]) {
                return;
            }
        }
        finished[batch_idx] = true;
    }
}

// Host launcher for stop_words_criterion_v2: one thread per stop-word slot (x, block size capped at
// 256) times one grid row per sequence (y).
void invokeStopWordsCriterion_v2(const int**  token_ids_ptrs,
                                 const int*   sequence_length,
                                 const int*   stop_words,
                                 bool*        finished,
                                 int          stop_words_len,
                                 int          batch_size,
                                 cudaStream_t stream)
{
    // Check if we have sampled a word from the stop_words list. If so, stop the sequence.

    // Nothing to check for an empty stop-word table or an empty batch. Returning early also avoids a
    // block size of 0 (round_up(0, 32)), which would be the cdiv divisor below and is not a valid
    // launch configuration, as well as a zero-sized grid.
    if (stop_words_len <= 0 || batch_size <= 0) {
        return;
    }

    const int  block = std::min(round_up(stop_words_len, 32), 256);
    const dim3 grid(cdiv(stop_words_len, block), batch_size);

    stop_words_criterion_v2<<>>(
        token_ids_ptrs, sequence_length, stop_words, finished, stop_words_len, batch_size);
}

// One thread per sequence: flags sequences whose length has reached their limit. The flag is only
// ever set, so a sequence already marked finished stays finished.
__global__ void length_criterion_v2(bool*      finished,  //
                                    const int* sequence_length,
                                    const int* sequence_length_limit,
                                    int        batch_size)
{
    const int seq = blockIdx.x * blockDim.x + threadIdx.x;
    if (seq < batch_size && sequence_length[seq] >= sequence_length_limit[seq]) {
        finished[seq] = true;
    }
}

// Host launcher for length_criterion_v2: 256 threads per block, one thread per sequence.
void invokeLengthCriterion_v2(bool*        finished,  //
                              const int*   sequence_length,
                              const int*   sequence_length_limit,
                              int          batch_size,
                              cudaStream_t stream)
{
    // Check if we have attained the sequence length limit. If so, stop the sequence.

    // An empty batch has nothing to check. Returning early also avoids launching a grid of zero
    // blocks, which CUDA rejects as an invalid configuration.
    if (batch_size <= 0) {
        return;
    }

    constexpr int block = 256;
    const int     grid  = cdiv(batch_size, block);

    length_criterion_v2<<>>(finished, sequence_length, sequence_length_limit, batch_size);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/stop_criteria_kernels.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include 

#include 

namespace turbomind {

// Sets finished[b] = true when the tail of token_ids_ptrs[b][0, sequence_length[b]) matches any
// stop word of sequence b. For each sequence, `stop_words` holds 2 * stop_words_len ints: the
// concatenated stop-word tokens followed by cumulative end offsets (negative = unused slot).
void invokeStopWordsCriterion_v2(const int**  token_ids_ptrs,
                                 const int*   sequence_length,
                                 const int*   stop_words,
                                 bool*        finished,
                                 int          stop_words_len,
                                 int          batch_size,
                                 cudaStream_t stream);

// Sets finished[i] = true for every i < batch_size with sequence_length[i] >= sequence_length_limit[i];
// never clears a flag.
void invokeLengthCriterion_v2(bool*        finished,  //
                              const int*   sequence_length,
                              const int*   sequence_length_limit,
                              int          batch_size,
                              cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/test_quantization.cc
================================================


#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/stream.h"

#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/kernels/quantization.h"

using namespace turbomind;

// Standalone round-trip check for the symmetric quantization kernels. For QuantizeSymm /
// DequantizeSymm and then for the block-wise QuantizeSymmBlock / DequantizeSymmBlock, it prints
// FastCompare statistics for:
//   1. dequant(quant(x)) vs x on random input, and
//   2. a second round trip on the already-dequantized values, which should be (near) lossless.
int main()
{
    // Presumably installs a fresh stream plus CPU/device allocators for the rest of main — confirm.
    core::ContextGuard ctx{core::Stream::create(), core::Allocator{kCPU}, core::Allocator{kDEVICE}};

    auto stream = core::Context::stream().handle();

    // NOTE(review): `gs` is referenced only in commented-out code below.
    const int m = 1024, n = 2048, gs = 128;

    // NOTE(review): h_x / h_x_f are filled via Copy() below but never read afterwards.
    Tensor_ h_x{{m, n}, kCPU};
    Tensor_ h_x_f{{m, n}, kCPU};

    Tensor_ x{{m, n}, kDEVICE};
    Tensor_ x_f{{m, n}, kDEVICE};
    Tensor_ x_q{{m, n}, kDEVICE};

    // Scales start empty; presumably the quantize call allocates them with the shape it needs — confirm.
    // Tensor_ x_s{{{m, n / gs}, {1, round_up(m, 4)}}, kDEVICE};
    Tensor_ x_s;

    RNG r;
    r.set_stream(stream);

    /////////////////////////////////////////////////////////////////////////////////////
    // round trip of dequant(quant(x)) with QuantizeSymm / DequantizeSymm
    r.UniformFloat(x, 2.f, 2.f);  // [-1, +1]
    Copy(x, h_x);
    QuantizeSymm(x_q, x_s, x, stream);
    DequantizeSymm(x_f, x_q, x_s, stream);
    Copy(x_f, h_x_f);
    FC_Header();
    FC_Print(FastCompare(x_f, x, stream));

    /////////////////////////////////////////////////////////////////////////////////////
    // round trip of dequant(quant(dequant(quant(x)))), aligned representable values:
    // x now holds already-dequantized values, so this pass should be (near) exact
    Copy(x_f, x);
    Clear(x_f);
    QuantizeSymm(x_q, x_s, x, stream);
    DequantizeSymm(x_f, x_q, x_s, stream);
    FC_Print(FastCompare(x_f, x, stream));

    /////////////////////////////////////////////////////////////////////////////////////
    // round trip of dequant(quant(x)) with the block-wise QuantizeSymmBlock / DequantizeSymmBlock;
    // x_s is reset so the block-wise scales are produced fresh
    // x_s = {{cdiv(m, gs), cdiv(n, gs)}, kDEVICE};
    x_s = {};
    r.UniformFloat(x, 2.f, 2.f);  // [-1, +1]
    Copy(x, h_x);
    QuantizeSymmBlock(x_q, x_s, x, stream);
    DequantizeSymmBlock(x_f, x_q, x_s, stream);
    FC_Print(FastCompare(x_f, x, stream));

    /////////////////////////////////////////////////////////////////////////////////////
    // block-wise round trip of dequant(quant(dequant(quant(x)))), aligned representable values
    Copy(x_f, x);
    Clear(x_f);
    QuantizeSymmBlock(x_q, x_s, x, stream);
    DequantizeSymmBlock(x_f, x_q, x_s, stream);
    FC_Print(FastCompare(x_f, x, stream));

    return 0;
}


================================================
FILE: src/turbomind/kernels/unfused_attention_kernels.cu
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 
#include 

#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/unfused_attention_kernels.h"

#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Row-wise masked softmax over the key dimension, reading fp32 logits `qk` and writing `attn_score` in T.
//
// Work mapping: blockIdx.y is the batch index and blockIdx.z the head index; each block strides over query
// rows qi = blockIdx.x, blockIdx.x + gridDim.x, ...  Within a row, thread t owns the keys
// ki = blockDim.x * i + t for i in [0, ITEMS_PER_THREAD).
//
// Keys whose mask entry is zero receive a -inf logit. When `sinks` is non-null, sinks[head] acts as an
// extra logit: it takes part in the row maximum and adds exp(sink - max) to the normalizer, but has no
// output column, so the written probabilities of a row may sum to less than one.
//
// NOTE(review): if every key of a row is masked and no sink is given, the row maximum is -inf and
// expf(-inf - (-inf)) produces NaN; callers presumably guarantee at least one visible key per row — confirm.
template
__global__ void __launch_bounds__(1024) softmax_kernel(T*           attn_score,
                                                       const float* qk,
                                                       const T*     attn_mask,
                                                       const T*     sinks,
                                                       const int    batch_size,
                                                       const int    head_num,
                                                       const int    q_length,
                                                       const int    k_length)
{
    // attn_score [batch_size, num_heads, q_length, k_length]
    // qk         [batch_size, num_heads, q_length, k_length]
    // attn_mask  [batch_size,            q_length, k_length]

    const long bi = blockIdx.y;  // Batch index.
    const int  hi = blockIdx.z;  // Head index.

    // Row statistics broadcast from thread 0 to the whole block.
    // Despite its name, s_mean holds the reciprocal of the normalizer, i.e. 1 / sum(exp(x - max)).
    __shared__ float s_mean, s_max;

    // Without sinks the sink logit is -inf, so it contributes nothing to the max or to the sum below.
    float sink = -std::numeric_limits::infinity();
    if (sinks) {
        sink = sinks[hi];
    }

    // Loop along with Q dimension.
    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {

        // data[i] for ki >= k_length stays uninitialised; every later loop applies the same bound, so it is
        // never read.
        float data[ITEMS_PER_THREAD];
        long  qk_offset;
        float local_max = -std::numeric_limits::infinity();

        // Loop along with K dimension.
        for (int i = 0; i < ITEMS_PER_THREAD; i++) {
            if (int ki = blockDim.x * i + threadIdx.x; ki < k_length) {  // Index of K dimension.

                qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + ki;

                float qk_val  = static_cast(qk[qk_offset]);
                float qk_bias = 0.0f;

                // The mask is shared across heads and is read without a null check.
                long  mask_offset = (bi * q_length + qi) * k_length + ki;
                float mask_val    = static_cast(ldg(&attn_mask[mask_offset]));

                // A zero mask entry hides the key entirely.
                if (!mask_val) {
                    qk_bias -= std::numeric_limits::infinity();
                }

                data[i]   = qk_val + qk_bias;
                local_max = fmaxf(local_max, data[i]);
            }
        }

        // Single-warp blocks use the warp-level helper; larger blocks use the block-level one
        // (both from reduce_kernel_utils.cuh).
        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max);

        if (threadIdx.x == 0) {
            // The sink logit also participates in the maximum used for numerical stability.
            s_max = fmaxf(max_val, sink);
        }

        // Publish s_max to every thread of the block.
        __syncthreads();

        float local_sum = 0;

        for (int i = 0; i < ITEMS_PER_THREAD; i++) {
            if (blockDim.x * i + threadIdx.x < k_length) {
                data[i] = expf(data[i] - s_max);
                local_sum += data[i];
            }
        }

        float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum);

        if (threadIdx.x == 0) {
            // The sink adds to the denominator only; it has no output column.
            sum_val += expf(sink - s_max);
            s_mean = sum_val;
            s_mean = fdividef(1.f, s_mean);
        }
        // Publish the reciprocal normalizer.
        __syncthreads();

        for (int i = 0; i < ITEMS_PER_THREAD; i++) {
            if (blockDim.x * i + threadIdx.x < k_length) {
                qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + blockDim.x * i + threadIdx.x;
                attn_score[qk_offset] = (T)(data[i] * s_mean);
            }
        }
    }
}

template
void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream)
{
    // attention_score,    (batch_size, head_num, q_length, k_length), softmax output.
    // qk,                 (batch_size, head_num, q_length, k_length), QK^T.
    // attention_mask,     (batch_size, q_length, k_length), attention mask.

    dim3 grid(param.q_length, param.batch_size, param.num_heads);

    auto invoke = [&](auto items_per_thread) {
        const int block = round_up(cdiv(param.k_length, items_per_thread.value), WARP_SIZE);
        FT_CHECK(block <= 1024);
        softmax_kernel<<>>(param.attention_score,
                                                                              param.qk,
                                                                              param.attention_mask,
                                                                              param.sinks,
                                                                              param.batch_size,
                                                                              param.num_heads,
                                                                              param.q_length,
                                                                              param.k_length);
    };

    const auto k = param.k_length;

    if (k <= 1024) {
        invoke(std::integral_constant{});
    }
    else if (k <= 2048) {
        invoke(std::integral_constant{});
    }
    else if (k <= 4096) {
        invoke(std::integral_constant{});
    }
    else if (k <= 8192) {
        invoke(std::integral_constant{});
    }
    else if (k <= 16384) {
        invoke(std::integral_constant{});
    }
    else if (k <= 32768) {
        invoke(std::integral_constant{});
    }
    else if (k <= 65536) {
        invoke(std::integral_constant{});
    }
    else if (k <= 131072) {
        invoke(std::integral_constant{});
    }
    else {
        throw std::runtime_error("not impelmented");
    }
}

// Explicit instantiations of invokeMaskedSoftmax for the supported activation types.
// All optional types are guarded with `#ifdef`, the same form used for the other ENABLE_FP32 / ENABLE_BF16
// guards in this file. `#if ENABLE_FP32` would fail to preprocess if the macro were defined with no value.
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half>& param, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16>& param, cudaStream_t stream);
#endif
#ifdef ENABLE_FP32
template void invokeMaskedSoftmax(MaskedSoftmaxParam<float>& param, cudaStream_t stream);
#endif

// clang-format off
// Compile-time helpers relating scalar element types to their packed vector forms.
// num_elems and packed_as are used by transpose_remove_padding below to build the packed int8 / float types
// for its quantized output path.

// packed_type<T>::type: the packed vector type used for scalar T.
template struct packed_type;
template <>          struct packed_type         { using type = float; }; // we don't need to pack float by default
template <>          struct packed_type          { using type = half2; };

#ifdef ENABLE_BF16
template<>
struct packed_type<__nv_bfloat16> {
    using type = __nv_bfloat162;
};
#endif

// num_elems<T>::value: number of scalar lanes carried by T (1 for scalars, 2 or 4 for vector types).
template struct num_elems;
template <>          struct num_elems           { static constexpr int value = 1; };
template <>          struct num_elems          { static constexpr int value = 2; };
template <>          struct num_elems          { static constexpr int value = 4; };
template <>          struct num_elems            { static constexpr int value = 1; };
template <>          struct num_elems           { static constexpr int value = 2; };
#ifdef ENABLE_BF16
template <>          struct num_elems<__nv_bfloat16>   { static constexpr int value = 1; };
template <>          struct num_elems<__nv_bfloat162>  { static constexpr int value = 2; };
#endif

// packed_as<T, N>::type: the type holding N lanes of scalar T. The generic partial specialization maps to T
// itself (presumably the N == 1 case — the template arguments are not visible here); the explicit
// specializations map scalar <-> 2-lane vector types.
template struct packed_as;
template          struct packed_as              { using type = T; };
template<>                    struct packed_as          { using type = half2; };
template<>                    struct packed_as         { using type = float2; };
template<>                    struct packed_as         { using type = int16_t; };
template<>                    struct packed_as        { using type = int2; };
template<>                    struct packed_as          { using type = half; };
#ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16,  2> { using type = __nv_bfloat162; };
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16;  };
#endif

// Component-wise float2 products (vector * vector and vector * scalar). Presumably needed for the
// packed-float scaling in transpose_remove_padding when the packed type is float2.
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __device__ float2 operator*(float2 a, float  b) { return make_float2(a.x * b, a.y * b); }
// clang-format on

template
__global__ void transpose_remove_padding(const T*     src,
                                         T*           dst,
                                         const int    batch_size,
                                         const int    seq_len,
                                         const int    head_num,
                                         const int    size_per_head,
                                         const int*   mask_offset,
                                         const float* scale,
                                         const int    int8_mode)
{
    // TODO: optimize this kernel?
    // do remove_sequence_length_padding
    const int bid = blockIdx.x;  // batch * seq_len or valid_word_num

    const int token_offset = mask_offset ? mask_offset[bid] : 0;

    const int src_batch_id = (bid + token_offset) / seq_len;
    const int src_seq_id   = (bid + token_offset) % seq_len;

    const int dst_seq_id = bid;

    const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head;
    const int dst_offset_base = dst_seq_id * head_num * size_per_head;

    using Int8_Packed_T  = typename packed_as::value>::type;
    using Float_Packed_T = typename packed_as::value>::type;
    const Float_Packed_T scale_val =
        int8_mode == 2 ? cuda_cast(*scale) : cuda_cast(0.0f);

    for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) {
        const int head_id   = idx / size_per_head;
        const int hidden_id = idx % size_per_head;
        const T   src_elem  = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]);
        if (int8_mode == 2) {
            reinterpret_cast(dst)[dst_offset_base + idx] =
                cuda_cast(cuda_cast(src_elem) * scale_val);
        }
        else {
            dst[dst_offset_base + idx] = src_elem;
        }
    }
}

// clang-format off
template
void invokeTransposeAttentionOutRemovePadding(T*           src,
                                              T*           dst,
                                              const int    valid_word_num,
                                              const int    batch_size,
                                              const int    seq_len,
                                              const int    head_num,
                                              const int    size_per_head,
                                              const int*   mask_offset,
                                              const float* scale,
                                              const int    int8_mode,
                                              cudaStream_t stream)
{
#ifdef ENABLE_BF16
    bool is_half2 = (std::is_same::value || std::is_same::value) && (size_per_head % 2 == 0);
#else
    bool is_half2 = (std::is_same::value) && (size_per_head % 2 == 0);
#endif
    using T2       = typename TypeConverter::Type;  // fp16 to half2, bf16 to bf162
    int block_size = head_num * size_per_head;
    if (is_half2) {
        while (block_size > 512) {
            if (block_size % 2 == 0) {
                block_size /= 2;
            }
            else {
                is_half2   = false;
                block_size = std::min(block_size, 1024);
                break;
            }
        }
    }
    else {
        block_size = std::min(block_size, 1024);
    }

    if (is_half2) {
        transpose_remove_padding<<>>(
            (T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset, scale, int8_mode);
    }
    else {
        transpose_remove_padding<<>>(
            src, dst, batch_size, seq_len, head_num, size_per_head, mask_offset, scale, int8_mode);
    }
}
// clang-format on

// Explicitly instantiate invokeTransposeAttentionOutRemovePadding for every enabled element type.
// The helper macro is local to this section and is undefined right after use.
#define TM_INSTANTIATE_TRANSPOSE_OUT_REMOVE_PADDING(ElemT)                                                           \
    template void invokeTransposeAttentionOutRemovePadding(ElemT*,                                                   \
                                                           ElemT*,                                                   \
                                                           const int,                                                \
                                                           const int,                                                \
                                                           const int,                                                \
                                                           const int,                                                \
                                                           const int,                                                \
                                                           const int*,                                               \
                                                           const float*,                                             \
                                                           const int,                                                \
                                                           cudaStream_t)
TM_INSTANTIATE_TRANSPOSE_OUT_REMOVE_PADDING(half);
#ifdef ENABLE_BF16
TM_INSTANTIATE_TRANSPOSE_OUT_REMOVE_PADDING(__nv_bfloat16);
#endif
#ifdef ENABLE_FP32
TM_INSTANTIATE_TRANSPOSE_OUT_REMOVE_PADDING(float);
#endif
#undef TM_INSTANTIATE_TRANSPOSE_OUT_REMOVE_PADDING

template
__global__ void addRelativeAttentionBias(
    T* qk_buf, const T* relative_attention_bias, const int batch_size, const int head_num, const int seq_len)
{
    for (int i = threadIdx.x; i < batch_size * seq_len; i += blockDim.x) {
        int batch_id = i / seq_len;
        int seq_id   = i % seq_len;

        const int bias_index = blockIdx.x * seq_len + seq_id;
        const int qk_index   = batch_id * gridDim.x * seq_len + bias_index;
        qk_buf[qk_index]     = add(qk_buf[qk_index], relative_attention_bias[bias_index]);
    }
}

template
void invokeAddRelativeAttentionBias(T*           qk_buf,
                                    const T*     relative_attention_bias,
                                    const int    batch_size,
                                    const int    head_num,
                                    const int    seq_len,
                                    cudaStream_t stream)
{
    // qk_buf: [batch_size, head_num, seq_len, seq_len]
    // relative_attention_bias: [1, head_num, seq_len, seq_len]
    dim3 grid(head_num * seq_len);
    dim3 block(512);
    using T2 = typename TypeConverter::Type;
#ifdef ENABLE_BF16
    const bool is_half2 = (std::is_same::value || std::is_same::value) && (seq_len % 2 == 0);
#else
    const bool is_half2 = (std::is_same::value) && (seq_len % 2 == 0);
#endif
    if (is_half2) {
        addRelativeAttentionBias<<>>(
            (T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2);
    }
    else {
        addRelativeAttentionBias<<>>(
            qk_buf, relative_attention_bias, batch_size, head_num, seq_len);
    }
}

#define INSTANTIATEADDRELATIVEATTENTIONBIAS(T)                                                                         \
    template void invokeAddRelativeAttentionBias(T*           qk_buf,                                                  \
                                                 const T*     relative_attention_bias,                                 \
                                                 const int    batch_size,                                              \
                                                 const int    head_num,                                                \
                                                 const int    seq_len,                                                 \
                                                 cudaStream_t stream)
// The explicit instantiations are currently disabled, so this translation unit emits no
// invokeAddRelativeAttentionBias symbols.
#if 0
#ifdef ENABLE_FP32
INSTANTIATEADDRELATIVEATTENTIONBIAS(float);
#endif
INSTANTIATEADDRELATIVEATTENTIONBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDRELATIVEATTENTIONBIAS(__nv_bfloat16);
#endif
#endif
// Undefined outside the `#if 0` above. Previously the #undef sat inside the disabled region, so the helper
// macro was never removed.
#undef INSTANTIATEADDRELATIVEATTENTIONBIAS

}  // namespace turbomind


================================================
FILE: src/turbomind/kernels/unfused_attention_kernels.h
================================================
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

namespace turbomind {

// Arguments for invokeMaskedSoftmax.
// The supported range is k_length <= 131072; the launcher throws for longer key dimensions.
template
struct MaskedSoftmaxParam {
    // Common parameters.
    T*           attention_score = nullptr;  // (batch_size, head_num, q_length, k_length)  softmax output
    const float* qk              = nullptr;  // (batch_size, head_num, q_length, k_length)  fp32 logits
    // (batch_size, q_length, k_length), shared across heads. A zero entry masks the key out.
    // The kernel reads it without a null check, so it must be provided.
    const T*     attention_mask  = nullptr;  // (batch_size, q_length, k_length)
    int          batch_size      = 0;
    int          q_length        = 0;
    int          k_length        = 0;
    int          num_heads       = 0;
    // Optional per-head "sink" logit, read as sinks[head] (presumably num_heads entries).
    // It joins the softmax normalizer without producing an output column; nullptr disables it.
    const T*     sinks           = nullptr;
};

// Row-wise masked softmax of param.qk into param.attention_score (see MaskedSoftmaxParam).
// Defined and explicitly instantiated in unfused_attention_kernels.cu.
template
void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream);

// NOTE(review): no definition of invokeTransposeQKV appears in unfused_attention_kernels.cu; confirm it is
// defined elsewhere or drop the declaration.
template
void invokeTransposeQKV(T*           dst,
                        T*           src,
                        const int    batch_size,
                        const int    seq_len,
                        const int    head_num,
                        const int    size_per_head,
                        const float* scale,
                        const int    int8_mode,
                        cudaStream_t stream);

// Converts attention output from [batch_size, head_num, seq_len, size_per_head] to a token-major
// [valid_word_num, head_num * size_per_head] layout, skipping padding via `mask_offset` when it is non-null.
// With int8_mode == 2 the values are scaled by *scale and written as int8 into `dst`.
// Defined in unfused_attention_kernels.cu.
template
void invokeTransposeAttentionOutRemovePadding(T*           src,
                                              T*           dst,
                                              const int    valid_word_num,
                                              const int    batch_size,
                                              const int    seq_len,
                                              const int    head_num,
                                              const int    size_per_head,
                                              const int*   mask_offset,
                                              const float* scale,
                                              const int    int8_mode,
                                              cudaStream_t stream);

// NOTE(review): not defined in unfused_attention_kernels.cu; confirm it is defined elsewhere.
template
void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                    T*           k_buf,
                                    T*           v_buf,
                                    T*           QKV,
                                    const T*     qkv_bias,
                                    const int*   padding_offset,
                                    const int*   context_length,
                                    const int*   input_length,
                                    const float* rope_theta,
                                    const int    batch_size,
                                    const int    seq_len,
                                    const int    token_num,
                                    const int    head_num,
                                    const int    kv_head_num,
                                    const int    size_per_head,
                                    const int    rotary_embedding_dim,
                                    float        rotary_embedding_base,
                                    int          max_position_embeddings,
                                    bool         use_dynamic_ntk,
                                    bool         use_logn_attn,
                                    cudaStream_t stream);

// NOTE(review): not defined in unfused_attention_kernels.cu; confirm it is defined elsewhere.
template
void invokeTranspose4d(T*           dst,
                       T*           src,
                       const int    local_batch_size,
                       const int    seq_len,
                       const int    size_per_head,
                       const int    local_hidden_units,
                       const int    local_head_num,
                       const int    batch_size,
                       const int    ite,
                       cudaStream_t stream);

// NOTE(review): not defined in unfused_attention_kernels.cu; confirm it is defined elsewhere.
template
void invokeTranspose4dBatchMajor(T*           k_dst,
                                 T*           v_dst,
                                 const T*     k_src,
                                 const T*     v_src,
                                 const int    local_batch_size,
                                 const int    seq_len,
                                 const int    max_seq_len,
                                 const int    size_per_head,
                                 const int    local_head_num,
                                 cudaStream_t stream);

// Adds relative_attention_bias ([1, head_num, seq_len, seq_len]) to every batch of
// qk_buf ([batch_size, head_num, seq_len, seq_len]).
// NOTE(review): its explicit instantiations in unfused_attention_kernels.cu are wrapped in `#if 0`, so calls
// presumably fail to link unless it is instantiated elsewhere.
template
void invokeAddRelativeAttentionBias(T*           qk_buf,
                                    const T*     relative_attention_bias,
                                    const int    batch_size,
                                    const int    head_num,
                                    const int    seq_len,
                                    cudaStream_t stream);

// NOTE(review): not defined in unfused_attention_kernels.cu; confirm it is defined elsewhere.
template
void invokeAddHead3SizeQKVBias(const T*     mm_qkv,
                               const T*     bias_qkv,
                               T*           q_buf_,
                               T*           k_buf_,
                               T*           v_buf_,
                               const int    batch,
                               const int    window_num,
                               const int    window_len,
                               const int    head_num,
                               const int    size_per_head,
                               cudaStream_t stream);

// NOTE(review): not defined in unfused_attention_kernels.cu; confirm it is defined elsewhere.
template
void invokeMaskedSoftMaxWithRelPosBias(T*           qk_buf,
                                       const T*     attn_mask,
                                       const T*     relative_pos_bias,
                                       const int    batch_size,
                                       const int    num_head,
                                       const int    window_num,
                                       const int    window_len,
                                       const float  qk_scale,
                                       cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/macro.h
================================================
#pragma once

// Make __PRETTY_FUNCTION__ usable on compilers that are not GCC-compatible (e.g. MSVC) by aliasing it to
// MSVC's __FUNCSIG__.
// GCC and Clang provide __PRETTY_FUNCTION__ as a predefined identifier rather than a preprocessor macro,
// so `defined(__PRETTY_FUNCTION__)` is false for them too. It is the `__GNUC__` test that keeps this branch
// disabled on those compilers.
#if !defined(__PRETTY_FUNCTION__) && !defined(__GNUC__)

#define __PRETTY_FUNCTION__ __FUNCSIG__

#endif

// Short alias for `unsigned int`. Some system headers (e.g. glibc's <sys/types.h>) declare the same typedef;
// repeating a typedef for the same type is valid C++.
typedef unsigned int uint;


================================================
FILE: src/turbomind/models/CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.8)

# Static library with the model-side components: request input/output processing, the llama-family
# layers and weights, and their CUDA kernels.
add_library(models STATIC
    language_model.cc
    input_processor.cc
    output_processor.cc
    llama/LlamaLinear.cu
    llama/BlockManager.cc
    llama/BlockTrie.cc
    llama/SequenceManager.cc
    llama/LlamaWeight.cc
    llama/LlamaDenseWeight.cc
    llama/LlamaDecoderLayerWeight.cc
    llama/LlamaFfnLayer.cc
    llama/moe_ffn_layer.cc
    llama/unified_decoder.cc
    llama/unified_attention_layer.cc
    llama/llama_kernels.cu
    llama/llama_utils.cu
    llama/mla_utils.cu
    llama/GatedDeltaNetWeight.cc
    llama/GatedDeltaNetLayer.cc
    llama/gated_delta_net_kernels.cu
)

set_target_properties(models PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
)

# Order matters for static-library symbol resolution; keep it unchanged.
target_link_libraries(models PUBLIC
    generation
    core
    gemm2
    rms_norm
    CUDA::cublas
    CUDA::cudart
    nvidia::cutlass::cutlass
    activation_kernels
    activation
    attention
    decoding_kernels
    quantization_kernels
    unfused_attention_kernels
    gpt_kernels
    memory_utils
    cuda_utils
    logger
    anomaly_handler
)

# CUDA sources only: verbose ptxas resource usage, line info for profilers, and parallel nvcc compilation.
target_compile_options(models PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-v --generate-line-info --threads 8>)

if(BUILD_TEST)
    # Standalone kernel benchmarks; each one is built from llama/<name>.cc and linked against `models`.
    foreach(bench IN ITEMS bench_gated_delta_net bench_conv1d_silu)
        add_executable(${bench} llama/${bench}.cc)
        target_link_libraries(${bench} PRIVATE models CUDA::cudart)
    endforeach()
endif()


================================================
FILE: src/turbomind/models/input_processor.cc
================================================

#include "src/turbomind/models/input_processor.h"

#include "src/turbomind/core/check.h"
#include "src/turbomind/core/core.h"

#include "src/turbomind/engine/request.h"

#include "src/turbomind/models/llama/SequenceManager.h"

namespace turbomind {

using std::vector;

struct InputProcessor::Impl {
public:
    // Allocates pinned host staging buffers, shared by all phases. For each of `phases` it also creates
    // device copies of them, a CPU `autoreg_ids_pos` buffer, and a pinned staging area for input embeddings
    // of shape [max_forward_token_num, hidden_units] in the model data type.
    Impl(const EngineParam& engine, const ModelParam& model, int phases):
        max_batch_size_{engine.max_batch_size}, max_forward_token_num_{engine.max_forward_token_num}
    {
        input_ids_buf_         = {max_forward_token_num_, kCPUpinned};
        input_ids_offsets_buf_ = {max_batch_size_ + 1, kCPUpinned};
        decode_token_pos_buf_  = {max_batch_size_, kCPUpinned};

        // Reserved up front: Setup hands out a pointer to Data::input_token_num, so Data must never move.
        data_.reserve(phases);
        for (int i = 0; i < phases; ++i) {
            auto& d              = data_.emplace_back();
            d.input_ids          = empty_like(input_ids_buf_, kDEVICE);
            d.input_ids_offsets  = empty_like(input_ids_offsets_buf_, kDEVICE);
            d.selected_token_pos = empty_like(decode_token_pos_buf_, kDEVICE);

            d.autoreg_ids_pos = {max_batch_size_, kCPU};  // ! CPU buffer

            /// TODO: initialize only when required
            d.input_embeds_buf = {{max_forward_token_num_, (int)model.hidden_units}, model.data_type, kCPUpinned};
        }
    }

    // Registers the input embeddings of one request with its sequence.
    //
    // 1. Drops cached embeddings, or their tails, that lie beyond the sequence's current token count.
    // 2. If the request carries "input_embedding_ranges" ([n, 2] begin/end pairs relative to c.step0) and
    //    "input_embeddings" ([k, d]), it validates them and appends one slice of the embeddings per range.
    //
    // Returns 0 on success or Request::kInvalid. Every range is validated before anything is appended, so a
    // rejected request leaves no partially registered embeddings on the sequence.
    int Add(RequestCache& c)
    {
        const auto& [r, s] = std::tie(*c.req, *c.seq);

        // trim input embeds
        if (!s.input_embeds_offsets.empty()) {
            Interval l{0, (int)s.tokens.size()};
            using Size    = Interval::Size;
            auto& embeds  = s.input_embeds;
            auto& offsets = s.input_embeds_offsets;
            int   i       = embeds.size() - 1;
            for (; i >= 0; --i) {
                // Named `span` to avoid shadowing the request `r` bound above.
                Interval span{offsets[i], Size{(int)embeds[i].shape(0)}};
                if (auto o = span & l) {
                    if (o.end() < span.end()) {
                        embeds[i] = embeds[i].slice(0, o.end() - span.begin());
                    }
                    break;
                }
            }
            embeds.resize(i + 1);
            offsets.resize(i + 1);
        }

        if (auto ranges_ptr = r.inputs.try_("input_embedding_ranges")) {  // [n, 2]
            auto embeds = r.inputs.at("input_embeddings");                // [k, d]
            if (ranges_ptr->ndim() != 2 || embeds.ndim() != 2 || ranges_ptr->shape(1) != 2) {
                /// TODO: reject for invalid shapes
                return Request::kInvalid;
            }

            // Setup copies these embeddings byte-for-byte into `input_embeds_buf`, and PatchEmbedding copies
            // them from there into the model's embedding rows. The dtype and row width must therefore match
            // the staging buffer; otherwise the data is misinterpreted or the staging buffer is overrun.
            if (!data_.empty()) {
                const auto& staging = data_.front().input_embeds_buf;
                if (embeds.dtype() != staging.dtype() || embeds.shape(1) != staging.shape(1)) {
                    /// TODO: reject for dtype / hidden size mismatch
                    return Request::kInvalid;
                }
            }

            const auto sum    = embeds.shape(0);
            const auto n      = ranges_ptr->shape(0);
            const auto ranges = ranges_ptr->data();

            // Validate every range before touching the sequence.
            {
                int offset = 0;
                int last   = c.step0;
                for (int i = 0; i < n; ++i) {
                    Interval range{c.step0 + ranges[i * 2], c.step0 + ranges[i * 2 + 1]};
                    auto     size = (int)range.size();
                    if (range.begin() < last) {
                        /// TODO: reject for non-sorted ranges
                        return Request::kInvalid;
                    }
                    if (range.end() > c.seq_len) {
                        /// TODO: reject for dst range OOB
                        return Request::kInvalid;
                    }
                    if (offset + size > sum) {
                        /// TODO: reject for src range OOB
                        return Request::kInvalid;
                    }
                    offset += size;
                    last = range.end();
                }
            }

            // clone the embeds if the request persists
            if (!r.session.end_flag) {
                auto tmp = std::exchange(embeds, empty_like(embeds));
                std::copy_n((const uint8_t*)tmp.raw_data(), tmp.byte_size(), (uint8_t*)embeds.raw_data());
            }

            // Commit: consecutive rows of `embeds` are assigned to the ranges in order.
            int offset = 0;
            for (int i = 0; i < n; ++i) {
                Interval range{c.step0 + ranges[i * 2], c.step0 + ranges[i * 2 + 1]};
                auto     size = (int)range.size();
                s.input_embeds_offsets.push_back(range.begin());
                s.input_embeds.push_back(embeds.slice(offset, size));  // reference into `embeds`
                offset += size;
            }
        }

        return 0;
    }

    // Runs Add(RequestCache&) for every entry of env["requests"] whose status is still 0, and stores the
    // result in its status.
    void Add(int phase, TensorMap& env)
    {
        const Buffer_ rc = env.at("requests").buffer();
        for (int i = 0; i < rc.size(); ++i) {
            auto& c = *TM_CHECK_NOTNULL(rc[i]);
            if (c.status == 0) {
                c.status = Add(c);
            }
        }
    }

    // Builds the host-side inputs of the batch for `phase` and issues their copies to the phase's device
    // buffers via env["copy"]:
    //  - input ids: the full chunk for non-autoregressive requests, and one placeholder slot per
    //    autoregressive request (filled later in Prepare);
    //  - per-request offsets into the id buffer, and the position of each request's last token;
    //  - staged input embeddings that overlap the tokens processed in this step.
    // Produces "token_num" in `env`.
    // NOTE(review): the loops run over rc.size() while the copies use b.bsz; presumably the two are equal — confirm.
    void Setup(int phase, TensorMap& env)
    {
        auto& d    = data_.at(phase);
        auto& b    = *env.at("batch").data()[0];
        auto& copy = *env.at("copy").data()[0];

        const auto& rc = b.rc;

        input_ids_offsets_buf_[0] = 0;
        for (int i = 0; i < rc.size(); ++i) {
            input_ids_offsets_buf_[i + 1] = input_ids_offsets_buf_[i];
            if (const auto& c = *rc[i]; TM_UNLIKELY(!c.autoregres)) {
                const auto src = c.token_ids + c.history_len + c.alpha;
                std::copy_n(src, c.input_len, input_ids_buf_.data() + input_ids_offsets_buf_[i]);
                // dbg(std::vector(src, src + c.input_len));
                d.autoreg_ids_pos[i] = -1;
                input_ids_offsets_buf_[i + 1] += c.input_len;
            }
            else {
                d.autoreg_ids_pos[i] = input_ids_offsets_buf_[i];
                input_ids_offsets_buf_[i + 1] += 1;
            }
            decode_token_pos_buf_[i] = input_ids_offsets_buf_[i + 1] - 1;
        }

        // dbg(core::to_vector(input_ids_offsets_buf_.slice(0, bsz + 1)));
        // dbg(core::to_vector(decode_token_pos_buf_.slice(0, bsz)));

        copy(input_ids_buf_, input_ids_offsets_buf_[b.bsz], d.input_ids);
        copy(decode_token_pos_buf_, b.bsz, d.selected_token_pos);
        copy(input_ids_offsets_buf_, b.bsz + 1, d.input_ids_offsets);

        // dbg(decode_token_pos_buf_[0]);

        d.input_token_num = input_ids_offsets_buf_[b.bsz];
        // dbg(d.input_token_num);

        env.produce("token_num", Buffer{&d.input_token_num, 1, kCPU});

        ////////////////////////////////////////////////////////////////
        /// input embeddings
        // Pack every overlapping embedding segment contiguously into input_embeds_buf, and record
        // (size, destination token position) so PatchEmbedding can scatter them in the same order.
        d.input_embeds_coords.clear();
        auto embed_ptr = (uint8_t*)d.input_embeds_buf.raw_data();
        for (int k = 0; k < rc.size(); ++k) {
            if (auto& c = *rc[k]; !c.autoregres) {
                const auto& embeds  = c.seq->input_embeds;
                const auto& offsets = c.seq->input_embeds_offsets;
                Interval    p{input_ids_offsets_buf_[k], input_ids_offsets_buf_[k + 1]};
                Interval    s{c.history_len + c.alpha, p.size()};
                for (int i = (int)offsets.size() - 1; i >= 0; --i) {
                    Interval r{offsets[i], Interval::Size{(int)embeds[i].shape(0)}};
                    auto     o = r & s;
                    if (auto size = (int)o.size()) {
                        auto src  = embeds[i].slice(o.begin() - r.begin(), size);
                        embed_ptr = std::copy_n((const uint8_t*)src.raw_data(), src.byte_size(), embed_ptr);
                        d.input_embeds_coords.emplace_back(size, p.begin() + (o.begin() - s.begin()));
                    }
                }
            }
        }
    }

    // Writes each autoregressive request's next input id, autoreg_ids[b.perm[i]], into the placeholder slot
    // reserved by Setup. Then produces "input_ids", "q_offsets" and "selected_token_pos" for this phase.
    void Prepare(int phase, TensorMap& env)
    {
        auto& d    = data_.at(phase);
        auto& b    = *env.at("batch").data()[0];
        auto& copy = *env.at("copy").data()[0];

        // last output token + draft tokens
        const Buffer_ autoreg_ids = env.at("autoreg_ids").buffer();

        // core::CopyT copy{};

        if (auto g = copy.group()) {
            for (int i = 0; i < b.bsz; ++i) {
                if (auto pos = d.autoreg_ids_pos[i]; pos >= 0) {
                    TM_CHECK_LT(b.perm[i], b.bs0);
                    copy(autoreg_ids.data() + b.perm[i], 1, &d.input_ids[pos]);
                }
            }
        }

        env.produce("input_ids", d.input_ids.slice(0, d.input_token_num));
        env.produce("q_offsets", d.input_ids_offsets.slice(0, b.bsz + 1));
        env.produce("selected_token_pos", d.selected_token_pos.slice(0, b.bsz));
    }

    // Overwrites rows of `embeds` with the segments staged by Setup for `phase`, at the recorded token
    // positions. The row stride is taken from `embeds`.
    void PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy)
    {
        auto&      d           = data_.at(phase);
        const auto byte_stride = byte_size(embeds.dtype(), embeds.stride(0));
        int        offset      = 0;
        for (const auto& [size, pos] : d.input_embeds_coords) {
            auto src = d.input_embeds_buf.slice(offset, size);
            copy((uint8_t*)src.raw_data(), src.byte_size(), (uint8_t*)embeds.raw_data() + byte_stride * pos);
            offset += size;
        }
    }

private:
    // Per-phase buffers.
    struct Data {
        Buffer_ input_ids;
        Buffer_ input_ids_offsets;
        int          input_token_num;  // env["token_num"] points here (see Setup)

        Buffer_ selected_token_pos;

        Buffer_ autoreg_ids_pos;  // slot in input_ids per request, -1 for non-autoregressive requests

        Tensor                      input_embeds_buf;
        vector> input_embeds_coords;  // (size, pos)
    };

private:
    const int max_batch_size_;
    const int max_forward_token_num_;

    vector data_;

    // Host staging buffers, shared by all phases and rebuilt on every Setup.
    Buffer_ input_ids_buf_;
    Buffer_ input_ids_offsets_buf_;

    Buffer_ decode_token_pos_buf_;
};

InputProcessor::InputProcessor(const EngineParam& engine, const ModelParam& model, int phases):
    impl_(std::make_unique<Impl>(engine, model, phases))
{
}

// Defined out of line, after Impl is complete, so that std::unique_ptr<Impl> can destroy it.
InputProcessor::~InputProcessor() = default;

void InputProcessor::Run(BatchOp op, int phase, TensorMap& env)
{
    // Only kAdd, kSetup and kPrepare involve the input processor; every other batch op is a no-op here.
    if (op == BatchOp::kAdd) {
        impl_->Add(phase, env);
    }
    else if (op == BatchOp::kSetup) {
        impl_->Setup(phase, env);
    }
    else if (op == BatchOp::kPrepare) {
        impl_->Prepare(phase, env);
    }
}

void InputProcessor::PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy)
{
    // Forwarder to the implementation, which queues the embedding-row copies on `copy`.
    return impl_->PatchEmbedding(phase, embeds, copy);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/input_processor.h
================================================
#pragma once

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Builds the per-step model inputs ("input_ids", "q_offsets", "selected_token_pos") and patches
// externally supplied embeddings into the embedding tensor. State lives behind a pimpl.
class InputProcessor {
public:
    ~InputProcessor();

    InputProcessor(const EngineParam& engine, const ModelParam& model, int phases);

    // Handles BatchOp::kAdd, kSetup and kPrepare for `phase`; other ops are ignored.
    void Run(BatchOp op, int phase, TensorMap& env);

    // Queues copies of the staged external embeddings for `phase` into rows of `embeds` on `copy`.
    void PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy);

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/language_model.cc
================================================

#include "src/turbomind/models/language_model.h"

#include 

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/copy.h"
#include "src/turbomind/core/interval.h"
#include "src/turbomind/core/state.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/engine/request.h"
#include "src/turbomind/generation/generation.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/models/input_processor.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/models/output_processor.h"
#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/cuda_utils.h"

// #include "dbg.h"

namespace turbomind {

using std::vector;
using std::unique_ptr;
using std::shared_ptr;

// Implementation of LanguageModel: owns the per-batch state of the language model and drives the
// input processor, decoder, token generator and output processor through the batch operations.
struct LanguageModel::Impl {
    const DataType       dtype_;
    const ModelParam     param_;
    const AttentionParam attn_param_;
    const Communicators& comm_;
    const LlamaWeight&   weights_;
    LlamaLinear&         linear_;

    const int  tp_size_;
    const int  tp_rank_;
    const bool use_ag2d_;  // device comm supports AllGather2D

    const bool debug_;

    // Cleared device buffer; Prepare copies from it to reset `finished` for slots new to the batch.
    Buffer_ false_;

    // mutable state, double-buffered: Prepare writes the back buffer in the new batch order, then swaps
    State finished_;
    State sequence_length_;  // length of known tokens
    // Single (not swapped) buffer: Forward overwrites it with the step's "output_ids", and Prepare
    // publishes it as "autoreg_ids" for the input processor of the next step.
    Buffer_ autoreg_ids_;
    // Buffer_ autoreg_ids_offsets_;

    // Symmetric buffer for holding global hidden states or logits
    // (only allocated when a device communicator exists; empty otherwise)
    Buffer_ symm_buf_;

    // Max chunk size for compute / output full logits
    int max_logits_len_ = 0;

    // Pinned host staging buffers used by Setup (sequence lengths) and Fetch (lengths and flags).
    Buffer_  sequence_length_buf_;
    Buffer_ finished_buf_;

    // Per-phase state.
    struct Data {
        Buffer_  sequence_length;  // device copy of lengths staged in Setup / saved in Unprep
        Buffer_ finished;          // device copy of finished flags saved in Unprep

        Buffer_ autoregres;  // host, per batch slot: request continues autoregressively
        Buffer_ generating;  // host, per batch slot: request is generating tokens

        int n_generating;  // number of generating slots; Forward skips generation when 0
    };

    vector data_;

    std::optional   input_processor_;
    std::unique_ptr unified_decoder_;
    std::optional  output_processor_;
    std::unique_ptr     generation_;  // token generator

    // Dispatch a batch operation. Ops not handled by a dedicated method are forwarded to every
    // sub-module in pipeline order.
    void Run(BatchOp op, int phase, TensorMap& env)
    {
        switch (op) {
            case BatchOp::kSetup:
                return Setup(phase, env);
            case BatchOp::kPrepare:
                return Prepare(phase, env);
            case BatchOp::kForward:
                return Forward(phase, env);
            case BatchOp::kUnprep:
                return Unprep(phase, env);
            case BatchOp::kFetch:
                return Fetch(phase, env);
            default:
                input_processor_->Run(op, phase, env);
                unified_decoder_->Run(op, phase, env);
                generation_->Run(op, phase, env);
                output_processor_->Run(op, phase, env);
        }
    }

    Impl(DataType              dtype,
         const ModelParam&     model,
         const EngineParam&    engine,
         const AttentionParam& attn,
         const MoeParam&       moe,
         const Context&        ctx,
         const LlamaWeight&    weights,
         int                   phases);

    // Token ids -> input embeddings (all-gathered across TP ranks when the table is split).
    Tensor LookupEmbedding(const Buffer_& input_ids, Buffer symm_buf);
    // Hidden states -> full-vocabulary logits (all-gathered across TP ranks when split).
    Tensor PostEmbedding(const Tensor& features, Buffer symm_buf);

    void Setup(int phase, TensorMap& env);
    void Prepare(int phase, TensorMap& env);
    void Forward(int phase, TensorMap& env);
    void Unprep(int phase, TensorMap& env);
    void Fetch(int phase, TensorMap& env);
};

// Allocates the batch state buffers, creates the sub-modules, and sizes the symmetric
// communication buffer and the logits chunk length.
LanguageModel::Impl::Impl(DataType              dtype,
                          const ModelParam&     model,
                          const EngineParam&    engine,
                          const AttentionParam& attn,
                          const MoeParam&       moe,
                          const Context&        ctx,
                          const LlamaWeight&    weights,
                          int                   phases):
    dtype_{dtype},
    param_{model},
    attn_param_{attn},
    comm_{ctx.comm},
    weights_{weights},
    linear_{*ctx.linear},
    tp_size_{comm_.h_tp_group->n_ranks()},
    tp_rank_{comm_.h_tp_group->rank()},
    use_ag2d_{comm_.d_comm && comm_.d_comm->Query(comm::kHasAllGather2D)},
    debug_{isDebug()}
{

    // Source of "not finished" values for slots that are new to a batch (see Prepare).
    false_ = {engine.max_batch_size, kDEVICE};
    Clear(false_);

    finished_buf_ = {engine.max_batch_size, kCPUpinned};
    finished_     = {{engine.max_batch_size}, kBool, kDEVICE};

    autoreg_ids_ = {engine.max_batch_size, kDEVICE};
    // autoreg_ids_offsets_ = {engine.max_batch_size + 1, kCPU};
    // std::fill_n(autoreg_ids_offsets_.data(), autoreg_ids_offsets_.size(), 0);

    sequence_length_buf_ = {engine.max_batch_size, kCPUpinned};
    sequence_length_     = {{engine.max_batch_size}, kInt, kDEVICE};
    // One state slot per phase: device copies of lengths/flags plus host-side request flags.
    for (int i = 0; i < phases; ++i) {
        auto& d           = data_.emplace_back();
        d.sequence_length = empty_like(sequence_length_buf_, kDEVICE);
        d.finished        = empty_like(finished_buf_, kDEVICE);
        d.autoregres      = {engine.max_batch_size, kCPU};
        d.generating      = {engine.max_batch_size, kCPU};
    }

    input_processor_.emplace(engine, param_, phases);

    unified_decoder_ = std::make_unique(model, engine, attn, moe, ctx, phases);

    // NOTE(review): the leading kFloat32 is presumably the generator's working dtype for logits — confirm.
    generation_ = std::make_unique(kFloat32,
                                               engine.max_batch_size,
                                               engine.session_len,
                                               model.vocab_size,
                                               weights.post_decoder_embedding.output_dim * tp_size_,
                                               comm_.h_tp_group,
                                               phases);

    // Full width of the output projection across all TP ranks (may differ from model.vocab_size,
    // e.g. if the vocabulary is padded — confirm).
    const int     vocab_size     = weights_.post_decoder_embedding.output_dim * tp_size_;
    const ssize_t max_fwd_tokens = engine.max_forward_token_num;

    if (ctx.comm.d_comm) {
        auto symm_alloc = GetSymmAllocator(ctx.comm.d_comm);
        // Native comm fuses allreduce & rmsnorm in token granularity
        TM_CHECK(engine.max_forward_token_num % tp_size_ == 0);

        // Large enough for either the gathered hidden states of `attn_dp_size` ranks' forward tokens
        // or the full logits of max_batch_size rows.
        ssize_t bytes{};
        bytes = std::max(bytes, byte_size(dtype_, max_fwd_tokens * engine.attn_dp_size * model.hidden_units));
        bytes = std::max(bytes, byte_size(dtype_, engine.max_batch_size * vocab_size));

        symm_buf_ = {bytes, symm_alloc};
        // Compute max logits length based on symm buffer size
        max_logits_len_ = symm_buf_.view(dtype_).size() / vocab_size;
    }
    else {
        // No symmetric buffer: bound a logits chunk by the activation budget
        // (max_fwd_tokens * hidden_units elements), but never below max_batch_size rows.
        max_logits_len_ = std::max(max_fwd_tokens * model.hidden_units / vocab_size, engine.max_batch_size);
    }

    // The output processor computes logits through this callback, which captures `this`.
    output_processor_.emplace(param_, max_logits_len_, tp_rank_, phases, [this](const Tensor& hstate) {
        return PostEmbedding(hstate, symm_buf_);
    });
}

// Looks up input embeddings for `input_ids` and returns a new [token_num, hidden_units] tensor.
// The embedding table is split along the hidden dimension across TP ranks (checked below). With
// tp_size_ > 1 each rank looks up its slice into `symm_buf` and the slices are all-gathered.
Tensor LanguageModel::Impl::LookupEmbedding(const Buffer_& input_ids, Buffer symm_buf)
{
    const auto st = core::Context::stream().handle();

    const int hidden_units = param_.hidden_units;

    const auto& embedding_table = weights_.pre_decoder_embedding.weight;
    TM_CHECK_EQ(embedding_table.shape(1) * tp_size_, hidden_units);

    const int token_num = input_ids.size();

    Tensor input_embeds{{token_num, hidden_units}, dtype_, kDEVICE};

    if (token_num == 0) {
        return input_embeds;
    }

    if (tp_size_ == 1) {
        invokeEmbeddingLookup(input_embeds, input_ids, embedding_table, st);
        sync_check_cuda_error();
    }
    else if (use_ag2d_) {
        const auto local_hidden_units = embedding_table.shape(1);

        // Layout [token, tp, local_hidden]: this rank writes its column slice, the 2D all-gather
        // fills the other ranks' columns, and the result is copied into `input_embeds`.
        Tensor temp{symm_buf.view(dtype_), {token_num, tp_size_, local_hidden_units}};
        Tensor local{temp.slice({0, tp_rank_, 0}, {-1, 1, -1}).squeeze(1)};

        invokeEmbeddingLookup(local, input_ids, embedding_table, st);
        sync_check_cuda_error();

        comm_.d_comm->AllGather2D(local.raw_data(),
                                  temp.raw_data(),
                                  hidden_units,
                                  local_hidden_units,
                                  local_hidden_units,
                                  token_num,
                                  local.dtype(),
                                  {true, true},
                                  comm_.d_tp_group,
                                  st);
        sync_check_cuda_error();

        Copy(temp.buffer(), input_embeds.buffer());
    }
    else {
        const auto local_hidden_units = embedding_table.shape(1);

        // Layout [tp, token, local_hidden]: this rank writes its contiguous slab, a flat all-gather
        // collects the others, and axes 0/1 are transposed into `input_embeds`.
        Tensor temp{symm_buf.view(dtype_), {tp_size_, token_num, local_hidden_units}};
        Tensor local{temp.slice(tp_rank_).squeeze(0)};

        invokeEmbeddingLookup(local, input_ids, embedding_table, st);
        sync_check_cuda_error();

        comm_.d_comm->AllGather(local.raw_data(), temp.raw_data(), local.size(), dtype_, comm_.d_tp_group, st);
        sync_check_cuda_error();

        // NOTE(review): elements are moved as uint16_t, i.e. this assumes a 2-byte dtype_ — confirm.
        invokeInPlaceTranspose102((uint16_t*)input_embeds.raw_data(),
                                  (uint16_t*)temp.raw_data(),
                                  tp_size_,
                                  token_num,
                                  local_hidden_units,
                                  false,
                                  st);
        sync_check_cuda_error();
    }

    return input_embeds;
}

Tensor LanguageModel::Impl::PostEmbedding(const Tensor& features, Buffer symm_buf)
{
    // Projects `features` (one row per token) to logits over the full vocabulary with the output
    // embedding. Each TP rank holds a `local_vocab_size`-wide slice of the projection; with
    // tp_size_ > 1 the per-rank results are staged in `symm_buf` and all-gathered, so every rank
    // returns the complete [bsz, vocab_size] logits, always typed `dtype_`.
    NvtxScope scope("postDecodeEmbedding");

    const auto st = core::Context::stream().handle();

    const int bsz              = features.shape(0);
    const int local_vocab_size = weights_.post_decoder_embedding.output_dim;
    const int vocab_size       = local_vocab_size * tp_size_;

    if (bsz == 0) {
        return Tensor{{0, vocab_size}, dtype_, kDEVICE};
    }

    if (tp_size_ == 1) {
        // Single rank: compute the full logits into a freshly allocated tensor.
        Tensor logits{{bsz, vocab_size}, dtype_, kDEVICE};
        linear_.Forward(features, weights_.post_decoder_embedding, logits);
        sync_check_cuda_error();
        TM_DEBUG_TENSOR(logits, "logits", 1);
        return logits;
    }
    else if (use_ag2d_) {
        // Layout [bsz, tp, local_vocab]: this rank fills its column slice and the 2D all-gather fills
        // the rest in place. The result is a [bsz, vocab_size] view of `symm_buf` (no copy), so callers
        // presumably must consume it before `symm_buf` is reused.
        Tensor logits{symm_buf.view(dtype_), {bsz, tp_size_, local_vocab_size}};
        Tensor local = logits.slice({0, tp_rank_, 0}, {-1, 1, -1});
        linear_.Forward(features, weights_.post_decoder_embedding, local.squeeze(1));
        sync_check_cuda_error();
        comm_.d_comm->AllGather2D(local.raw_data(),
                                  logits.raw_data(),
                                  vocab_size,
                                  local_vocab_size,
                                  local_vocab_size,
                                  bsz,
                                  logits.dtype(),
                                  {true, true},
                                  comm_.d_tp_group,
                                  st);
        sync_check_cuda_error();
        return logits.view({bsz, -1});
    }
    else {
        // Layout [tp, bsz, local_vocab]: this rank fills its contiguous slab, a flat all-gather collects
        // the others, and axes 0/1 are transposed into a new [bsz, vocab_size] tensor.
        Tensor logits{symm_buf.view(dtype_), {tp_size_, bsz, local_vocab_size}};
        Tensor local = logits.slice({tp_rank_, 0, 0}, {1, -1, -1});
        linear_.Forward(features, weights_.post_decoder_embedding, local.squeeze(0));
        sync_check_cuda_error();
        comm_.d_comm->AllGather(local.raw_data(), logits.raw_data(), local.size(), local.dtype(), comm_.d_tp_group, st);
        sync_check_cuda_error();
        // The gathered logits were written through `symm_buf.view(dtype_)`, so the output must be typed
        // `dtype_` on the device, as in the other branches, rather than inheriting `features`' dtype/device.
        Tensor out{{bsz, vocab_size}, dtype_, kDEVICE};
        // NOTE(review): the transpose moves 16-bit elements, i.e. assumes a 2-byte dtype_ — confirm.
        invokeTransposeAxis01(
            (uint16_t*)out.raw_data(), (uint16_t*)logits.raw_data(), tp_size_, bsz, local_vocab_size, st);
        sync_check_cuda_error();
        return out;
    }
}

void LanguageModel::Impl::Setup(int phase, TensorMap& env)
{
    // The input processor performs its setup before this model's per-request bookkeeping.
    input_processor_->Run(BatchOp::kSetup, phase, env);

    auto& state      = data_.at(phase);
    auto& batch_copy = *env.at("copy").data()[0];

    const auto& requests = env.at("batch").data()[0]->rc;

    // Record per-slot request flags for this phase. For requests that are not autoregressive,
    // recompute the host-side sequence length as history + alpha + input length.
    state.n_generating = 0;
    for (int slot = 0; slot < requests.size(); ++slot) {
        auto& req = *requests[slot];

        state.autoregres[slot] = req.autoregres;
        state.generating[slot] = req.generating;
        state.n_generating += req.generating;

        if (TM_UNLIKELY(!req.autoregres)) {
            sequence_length_buf_[slot] = req.history_len + req.alpha + req.input_len;
        }
    }

    // Queue the host lengths for transfer into this phase's device buffer.
    batch_copy(sequence_length_buf_, requests.size(), state.sequence_length);

    unified_decoder_->Run(BatchOp::kSetup, phase, env);
    generation_->Run(BatchOp::kSetup, phase, env);
    output_processor_->Run(BatchOp::kSetup, phase, env);
}

// Reorders the double-buffered per-slot state into the new batch order and publishes
// "finished", "sequence_length" and "k_offsets" before the sub-modules prepare.
void LanguageModel::Impl::Prepare(int phase, TensorMap& env)
{
    // Must be in `env` before the input processor prepares: it copies entries of "autoreg_ids"
    // into its input_ids.
    env.emplace("autoreg_ids", autoreg_ids_);

    input_processor_->Run(BatchOp::kPrepare, phase, env);

    auto& d = data_.at(phase);

    auto& b    = *env.at("batch").data()[0];
    auto& copy = *env.at("copy").data()[0];

    // finished: slots whose b.perm entry is < b.bs0 copy their flag from that index of the current
    // front buffer (presumably the slot's position in the previous batch); other slots start
    // "not finished" (copied from the cleared `false_` buffer).
    if (auto group = copy.group()) {
        for (int i = 0; i < b.bsz; ++i) {
            if (const int j = b.perm[i]; j < b.bs0) {
                copy(finished_.front().data() + j, 1, finished_.back().data() + i);
            }
            else {
                copy(false_.data() + i, 1, finished_.back().data() + i);
            }
        }
        finished_.Swap();
    }

    if (auto group = copy.group()) {
        // sequence_length = history_len + input_len
        // Carried-over autoregressive slots keep their device-side length; all others take the
        // length staged by Setup in `d.sequence_length`.
        for (int i = 0; i < b.bsz; ++i) {
            if (const int j = b.perm[i]; j < b.bs0 && d.autoregres[i]) {
                copy(sequence_length_.front().data() + j, 1, sequence_length_.back().data() + i);
            }
            else {
                copy(d.sequence_length.data() + i, 1, sequence_length_.back().data() + i);
            }
        }
        sequence_length_.Swap();
    }

    // Allocated here with bsz + 1 entries; filled in Forward by PrefixSum over sequence_length_.
    Buffer_ k_offsets{b.bsz + 1, kDEVICE};

    env.produce("finished", finished_.front());
    env.produce("sequence_length", sequence_length_.front());
    env.produce("k_offsets", k_offsets);

    unified_decoder_->Run(BatchOp::kPrepare, phase, env);
    generation_->Run(BatchOp::kPrepare, phase, env);
    output_processor_->Run(BatchOp::kPrepare, phase, env);
}

// Runs one forward step: key offsets, input embeddings (with external embeddings patched in),
// the decoder, output logits, and, if any slot is generating, token generation.
void LanguageModel::Impl::Forward(int phase, TensorMap& env)
{

    auto& d = data_.at(phase);
    auto& b = *env.at("batch").data()[0];

    {
        // Fill the "k_offsets" buffer allocated in Prepare with the prefix sum of sequence lengths.
        Buffer_ k_offsets = env.at("k_offsets").buffer();
        PrefixSum(sequence_length_.front().data(), b.bsz, k_offsets.data(), core::Context::stream().handle());
    }

    {  // compute input embeddings
        auto input_ids = env.at("input_ids").buffer();

        Tensor input_embeds = LookupEmbedding(input_ids, symm_buf_);
        TM_DEBUG_TENSOR(input_embeds, "embeddings", 1);

        // Overwrite rows with externally supplied embeddings, then execute the queued copies.
        auto& copy = *env.at("copy").data()[0];
        input_processor_->PatchEmbedding(phase, input_embeds, copy);
        copy.Run();

        env.produce("input_embeds", std::move(input_embeds));
        // dbg(env);
    }

    if (symm_buf_) {
        env.produce("symm_buf", symm_buf_);
    }

    env.produce("output_norm_weight", weights_.output_norm_weight);

    unified_decoder_->Forward(phase, env, weights_.decoder_layer_weights);

    // env.at("batch").data()[0]->Notify();

    // NOTE(review): the meaning of the last argument (2 here, 1 after logits) is defined by the
    // output processor — presumably the stage before/after logits exist; confirm.
    output_processor_->OutputHiddenStatesAndLogits(phase, env, 2);

    auto& hidden_states = env.at("hidden_states");

    env.produce("logits", PostEmbedding(hidden_states, symm_buf_));

    output_processor_->OutputHiddenStatesAndLogits(phase, env, 1);

    if (d.n_generating) {
        generation_->Run(BatchOp::kForward, phase, env);
        // Keep the generated ids for the next step's "autoreg_ids".
        Copy(env.at("output_ids").buffer(), autoreg_ids_);
    }
}

void LanguageModel::Impl::Unprep(int phase, TensorMap& env)
{
    // Save the current device-side sequence lengths and finished flags into this phase's buffers,
    // from which Fetch later copies them to the pinned host buffers.
    auto& state      = data_.at(phase);
    auto& batch_copy = *env.at("copy").data()[0];

    batch_copy(sequence_length_.front().buffer(), state.sequence_length.size(), state.sequence_length);
    batch_copy(finished_.front().buffer(), state.finished.size(), state.finished);

    generation_->Run(BatchOp::kUnprep, phase, env);
}

void LanguageModel::Impl::Fetch(int phase, TensorMap& env)
{
    // Move this phase's saved lengths and finished flags into the pinned host buffers and publish
    // them, together with the per-slot "generating" flags recorded by Setup.
    auto& state      = data_.at(phase);
    auto& batch_copy = *env.at("copy").data()[0];

    batch_copy(state.sequence_length, state.sequence_length.size(), sequence_length_buf_);
    env.produce("sequence_length", sequence_length_buf_);

    batch_copy(state.finished, state.finished.size(), finished_buf_);
    env.produce("finished", finished_buf_);

    env.produce("generating", state.generating);

    generation_->Run(BatchOp::kFetch, phase, env);
}

// Special members are defaulted out of line, where `Impl` is complete, as the pimpl requires.
LanguageModel::~LanguageModel() = default;

LanguageModel::LanguageModel(LanguageModel&&) noexcept = default;

// Builds the implementation; all arguments are forwarded to Impl unchanged.
LanguageModel::LanguageModel(DataType              dtype,
                             const ModelParam&     model,
                             const EngineParam&    engine,
                             const AttentionParam& attn,
                             const MoeParam&       moe,
                             const Context&        ctx,
                             const LlamaWeight&    weights,
                             int                   phases)
{
    impl_ = std::make_unique(dtype, model, engine, attn, moe, ctx, weights, phases);
}

void LanguageModel::Run(BatchOp op, int phase, TensorMap& env)
{
    // Forward to the implementation; TM_CHECK_NOTNULL rejects a default-constructed or moved-from model.
    TM_CHECK_NOTNULL(impl_)->Run(op, phase, env);
}

// Accessors for the parameters captured at construction. Both require a constructed model;
// NOTE(review): if TM_CHECK_NOTNULL reports failure by throwing, `noexcept` turns that into
// std::terminate — confirm this is intended.
const ModelParam& LanguageModel::model_param() const noexcept
{
    return TM_CHECK_NOTNULL(impl_)->param_;
}

const AttentionParam& LanguageModel::attn_param() const noexcept
{
    return TM_CHECK_NOTNULL(impl_)->attn_param_;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/language_model.h
================================================
#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

class LlamaWeight;

// Language model driven by batch operations (BatchOp). A default-constructed instance is empty
// (operator bool is false) and must not be used until it is assigned a constructed one.
class LanguageModel {
public:
    ~LanguageModel();

    LanguageModel() = default;

    LanguageModel(LanguageModel&&) noexcept;

    // True when the model holds an implementation.
    explicit operator bool() const noexcept
    {
        return static_cast(impl_);
    }

    LanguageModel(DataType              dtype,
                  const ModelParam&     model,
                  const EngineParam&    engine,
                  const AttentionParam& attn,
                  const MoeParam&       moe,
                  const Context&        ctx,
                  const LlamaWeight&    weights,
                  int                   phases);

    // Executes batch operation `op` for `phase`, reading and producing tensors in `env`.
    void Run(BatchOp op, int phase, TensorMap& env);

    const ModelParam&     model_param() const noexcept;
    const AttentionParam& attn_param() const noexcept;

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/Barrier.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#ifndef _MSC_VER
#include 
#endif

namespace turbomind {

#ifdef _MSC_VER

class Barrier {
public:
    // MSVC stand-in: only a single participant is supported, so there is never anyone to wait for.
    Barrier(unsigned count)
    {
        TM_LOG_INFO("Barrier(%d)", static_cast<int>(count));
        FT_CHECK(count == 1);
    }

    // Neither copyable nor movable.
    Barrier(const Barrier&)                = delete;
    Barrier(Barrier&&) noexcept            = delete;
    Barrier& operator=(const Barrier&)     = delete;
    Barrier& operator=(Barrier&&) noexcept = delete;

    // No-op: the single participant passes straight through.
    void wait() {}

    ~Barrier() {}
};

#else

class Barrier {
public:
    Barrier(unsigned count): count_(count)
    {
        if (count_ > 1) {
            pthread_barrier_init(&barrier_, nullptr, count);
        }
    }

    Barrier(const Barrier&) = delete;
    Barrier& operator=(const Barrier&) = delete;
    Barrier(Barrier&&) noexcept        = delete;
    Barrier& operator=(Barrier&&) noexcept = delete;

    void wait()
    {
        if (count_ > 1) {
            pthread_barrier_wait(&barrier_);
        }
    }

    ~Barrier()
    {
        if (count_ > 1) {
            pthread_barrier_destroy(&barrier_);
        }
    }

private:
    const int         count_;
    pthread_barrier_t barrier_{};
};

#endif

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/BlockManager.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/string_utils.h"

namespace turbomind {

// block_count < 1 is treated as a fraction of the memory reported by `get_free_size`, otherwise as
// an absolute block count (truncated to int). chunk_size controls how many blocks each allocation
// holds: 0 -> sqrt(max_block_count), negative -> everything in one chunk, positive -> as given.
BlockManager::BlockManager(
    size_t block_size, double block_count, int chunk_size, core::Allocator allocator, GetFreeMemSize get_free_size):
    block_size_(block_size), allocator_(allocator)
{
    if (block_count < 1.) {
        max_block_count_ = GetBlockCount(block_size, block_count, get_free_size);
    }
    else {
        max_block_count_ = block_count;
    }

    if (chunk_size == 0) {
        chunk_size_ = static_cast(std::sqrt(max_block_count_));
    }
    else if (chunk_size < 0) {
        chunk_size_ = max_block_count_;
    }
    else {
        chunk_size_ = chunk_size;
    }

    TM_LOG_INFO("[BlockManager] block_size = %.3f MB", (float)block_size_ / (1 << 20));
    TM_LOG_INFO("[BlockManager] max_block_count = %d", max_block_count_);
    TM_LOG_INFO("[BlockManager] chunk_size = %d", chunk_size_);

    // Reserve up front so no reallocation happens as chunks are added.
    blocks_.reserve(max_block_count_);

    active_ids_.reserve(max_block_count_);
    cached_ids_.reserve(max_block_count_);
    free_ids_.reserve(max_block_count_);

    // pre-allocate first chunk
    // (the result is ignored; Allocate() retries Malloc() on demand and throws if it keeps failing)
    Malloc();
    dbg(free_ids_);
}

BlockManager::~BlockManager()
{
    // Return every chunk with the same byte count Malloc() requested for it. Malloc() appends a chunk
    // and its blocks in lockstep, and every chunk except possibly the last holds exactly `chunk_size_`
    // blocks (only the final one can be clamped to the remaining `max_block_count_`). So chunk `i`
    // owns blocks [i * chunk_size_, min((i + 1) * chunk_size_, blocks_.size())).
    const size_t total_blocks = blocks_.size();
    const size_t per_chunk    = static_cast<size_t>(chunk_size_);
    for (size_t i = 0; i < chunks_.size(); ++i) {
        const size_t first = i * per_chunk;
        const size_t count = std::min(per_chunk, total_blocks - first);
        allocator_->deallocate(chunks_[i], block_size_ * count);
    }
}

// Allocates one more chunk (up to `chunk_size_` blocks, clamped so the pool never exceeds
// `max_block_count_`) and appends its blocks to the free set. Returns false when the pool is
// already at capacity or the allocator returns null.
bool BlockManager::Malloc()
{
    auto chunk_size = std::min(chunk_size_, max_block_count_ - blocks_.size());

    if (!chunk_size) {
        return false;
    }

    auto ptr = (std::byte*)allocator_->allocate(block_size_ * chunk_size);
    if (!ptr) {
        return false;
    }

    chunks_.push_back(ptr);

    for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
        // emplace_back() value-initializes the Block, so fields not set below (unique_id) start at 0.
        auto& block     = blocks_.emplace_back();
        block.use_count = 0;
        block.id        = (int)blocks_.size() - 1;
        block.timestamp = 0;
        block.data      = ptr;

        // New ids exceed every existing id, so free_ids_ stays sorted as Move() requires.
        free_ids_.push_back(block.id);
    }

    return true;
}

// Number of whole blocks fitting in `ratio` of the amount reported by `get_free_size`
// (presumably free device memory in bytes — confirm).
size_t BlockManager::GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size)
{
    size_t free = get_free_size();
    return static_cast(free * ratio) / block_size;
}

// Moves the ids in `delta` from set `src` to set `dst`. All three must be sorted (required by the
// std::set_* algorithms). The checks enforce that `delta` is a subset of `src` and disjoint from
// `dst`: otherwise the output would be shorter than the precomputed size.
void BlockManager::Move(std::vector& src, const std::vector& delta, std::vector& dst)
{
    TM_CHECK_GE(src.size(), delta.size());
    std::vector src1(src.size() - delta.size());
    {
        auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
        TM_CHECK(end == src1.end());
    }
    src.swap(src1);

    std::vector dst1(dst.size() + delta.size());
    {
        auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
        TM_CHECK(end == dst1.end());
    }
    dst.swap(dst1);
}

// Takes `count` free blocks (the lowest free ids), marks them active with use_count 1 and a fresh
// unique id and timestamp, and returns their ids with matching unique ids. Grows the pool chunk by
// chunk as needed; throws std::runtime_error("out of memory") when it cannot grow further.
auto BlockManager::Allocate(int count) -> std::pair
{
    while (free_ids_.size() < count) {
        if (!Malloc()) {
            throw std::runtime_error("out of memory");
        }
    }

    BlockIds  block_ids(count);
    UniqueIds unique_ids(count);

    for (int i = 0; i < count; ++i) {
        int   idx = free_ids_[i];
        auto& b   = blocks_[idx];
        TM_CHECK(is_free(b));  // pre-condition: uc == 0 && ts == 0
        b.use_count = 1;
        b.unique_id = unique_id_++;
        b.timestamp = timestamp_++;
        TM_CHECK(is_active(b));  // post-condition
        block_ids[i]  = idx;
        unique_ids[i] = b.unique_id;
    }

    // block_ids is a sorted prefix of the sorted free_ids_, as Move() requires.
    Move(free_ids_, block_ids, active_ids_);

    dbg(free_ids_, active_ids_);

    return {block_ids, unique_ids};
}

// Frees the `count` cached blocks with the oldest (smallest) timestamps, clearing their unique ids so
// later Verify() calls treat them as invalidated.
void BlockManager::Evict(int count)
{
    TM_CHECK_LE(count, cached_ids_.size());
    std::vector idxs(cached_ids_);
    // get first `count` cached ids according to timestamp
    std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
        return blocks_[i].timestamp < blocks_[j].timestamp;
    });
    idxs.resize(count);

    // sort the retrieved ids (Move() requires sorted input)
    std::sort(idxs.begin(), idxs.end());

    // set as free
    for (const auto& idx : idxs) {
        auto& b = blocks_[idx];
        TM_CHECK(is_cached(b));  // pre-condition
        b.unique_id = 0;
        b.timestamp = 0;
        TM_CHECK(is_free(b));  // post-condition
    }

    Move(cached_ids_, idxs, free_ids_);

    dbg(cached_ids_, free_ids_);
}

void BlockManager::Free(BlockIds ids)
{
    // Returns the given cached (unreferenced) blocks to the free set and clears their identity.
    // `ids` is taken by value so it can be sorted in place for Move().
    std::sort(ids.begin(), ids.end());

    for (size_t k = 0; k < ids.size(); ++k) {
        auto& b = blocks_[ids[k]];
        TM_CHECK(is_cached(b));  // pre-condition
        b.unique_id = 0;
        b.timestamp = 0;
        TM_CHECK(is_free(b));  // post-condition
    }

    Move(cached_ids_, ids, free_ids_);
}

int BlockManager::Unlock(const BlockIds& ids)
{
    BlockIds unlock;
    unlock.reserve(ids.size());

    for (const auto& i : ids) {
        auto& b = blocks_[i];
        TM_CHECK(is_active(b));  // pre-condition
        if (--b.use_count == 0) {
            unlock.push_back(b.id);
            TM_CHECK(is_cached(b));  // post-condition
        }
    }

    std::sort(unlock.begin(), unlock.end());

    Move(active_ids_, unlock, cached_ids_);

    dbg(active_ids_, cached_ids_);
    return unlock.size();
}

int BlockManager::Lock(const BlockIds& ids)
{
    BlockIds lock;
    lock.reserve(ids.size());

    for (const auto& i : ids) {
        auto& b = blocks_[i];
        if (++b.use_count == 1) {
            lock.push_back(i);
            TM_CHECK(is_active(b));  // post-condition
        }
    }

    std::sort(lock.begin(), lock.end());

    Move(cached_ids_, lock, active_ids_);

    // dbg(cached_ids_, active_ids_);

    return lock.size();
}

void BlockManager::Touch(const BlockIds& ids)
{
    std::for_each(ids.crbegin(), ids.crend(), [this](int i) {
        TM_CHECK(is_active(blocks_[i]));
        blocks_[i].timestamp = timestamp_++;
    });
}

int BlockManager::Verify(const std::vector& block_ids, const std::vector& unique_ids)
{
    TM_CHECK_EQ(block_ids.size(), unique_ids.size());
    int valid = block_ids.size();
    for (int i = 0; i < block_ids.size(); ++i) {
        if (unique_id(block_ids[i]) != unique_ids[i]) {
            valid = i;
            break;
        }
    }
    int miss = 0;
    for (int i = valid; i < block_ids.size(); ++i) {
        miss += (unique_id(block_ids[i]) != unique_ids[i]);
    }
    // All later blocks should have been invalidated
    TM_CHECK_EQ(miss, (int)block_ids.size() - valid)
        << fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss);
    return valid;
}

// Captures the sizes of the active/cached/free sets and a per-block use_count table indexed by block
// id (entries of non-active blocks are left at 0).
Snapshot BlockManager::TakeSnapshot()
{
    std::vector use_count(blocks_.size());
    for (const auto& idx : active_ids_) {
        use_count[idx] = blocks_[idx].use_count;
    }
    return {active_count(), cached_count(), free_count(), std::move(use_count)};
}

std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
{
    // One-line summary of the pool configuration and the sizes of its bookkeeping sets.
    return os << "block_size: " << manager.block_size_ << ", "
              << "max_block_count: " << manager.max_block_count_ << ", "
              << "chunk_size: " << manager.chunk_size_ << ", "
              << "chunks: " << manager.chunks_.size() << ", "
              << "active_ids: " << manager.active_ids_.size() << ", "
              << "cached_ids: " << manager.cached_ids_.size() << ", "
              << "free_ids: " << manager.free_ids_.size() << ", "
              << "blocks: " << manager.blocks_.size() << ", "
              << "unique_id: " << manager.unique_id_ << ", "
              << "timestamp: " << manager.timestamp_;
}

std::ostream& operator<<(std::ostream& os, const Block& block)
{
    // Prints every field of the block as comma-separated key=value pairs.
    os << "id=" << block.id;
    os << ", use_count=" << block.use_count;
    os << ", unique_id=" << block.unique_id;
    os << ", timestamp=" << block.timestamp;
    os << ", data=" << block.data;
    return os;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/BlockManager.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

namespace turbomind {

// [L, H, S, D]

// [L, S/x, H, x, D]

// One fixed-size slot of the block pool managed by BlockManager.
struct Block {
    int      id;         // fixed linear id in the pool
    int      use_count;  // active sequences using the block
    uint64_t unique_id;  // unique for every block allocation (reset to 0 when freed/evicted)
    uint64_t timestamp;  // recency stamp; 0 distinguishes free from cached blocks (see is_free/is_cached)
    void*    data;       // start of this block's storage inside one of the manager's chunks

    friend std::ostream& operator<<(std::ostream& os, const Block& block);
    friend std::string   to_string(const Block& b)
    {
        std::stringstream ss;
        ss << b;
        return ss.str();
    }
};

using BlockIds  = std::vector;
using UniqueIds = std::vector;

// Block state predicates:
//   active - referenced by at least one sequence (use_count > 0)
//   cached - unreferenced but still stamped (use_count == 0, timestamp != 0)
//   free   - unreferenced and reset (use_count == 0, timestamp == 0)
inline bool is_active(const Block& block)
{
    // The timestamp is ignored: a just-allocated block may not have been written yet.
    return 0 < block.use_count;
}

inline bool is_cached(const Block& block)
{
    if (block.use_count != 0) {
        return false;
    }
    return block.timestamp != 0;
}

inline bool is_free(const Block& block)
{
    if (block.use_count != 0) {
        return false;
    }
    return block.timestamp == 0;
}

// Point-in-time view of the block pool produced by BlockManager::TakeSnapshot().
struct Snapshot {
    int              active;     // number of active blocks
    int              cached;     // number of cached blocks
    int              free;       // number of free blocks
    std::vector use_count;  // per block id; 0 for blocks that are not active
};

using GetFreeMemSize = std::function;

class BlockManager {
public:
    explicit BlockManager(
        size_t block_size, double block_count, int chunk_size, core::Allocator allocator, GetFreeMemSize get_free_size);

    ~BlockManager();

    // free -> active (use_count = 1, ref_count = 1)
    [[nodiscard]] std::pair Allocate(int count);

    // cached -> active (use_count += 1)
    [[maybe_unused]] int Lock(const BlockIds& ids);

    // active -> cached (use_count -= 1)
    [[maybe_unused]] int Unlock(const BlockIds& ids);

    // cached -> free (ref_count = 0)
    void Evict(int count);

    // cached -> free (ref_count -= 1)
    void Free(BlockIds bs);

    // increase timestamp in reversed order
    void Touch(const BlockIds& bs);

    [[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);

    Snapshot TakeSnapshot();

    int max_block_count() const noexcept
    {
        return max_block_count_;
    }

    int total_count() const noexcept
    {
        return blocks_.size();
    }

    int active_count() const noexcept
    {
        return active_ids_.size();
    }

    int cached_count() const noexcept
    {
        return cached_ids_.size();
    }

    int free_count() const noexcept
    {
        return free_ids_.size();
    }

    Block& block(int idx)
    {
        return blocks_[idx];
    }

    int unique_id(int idx)
    {
        return blocks_[idx].unique_id;
    }

    friend std::ostream& operator<<(std::ostream& os, const BlockManager&);

private:
    static size_t GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size);

    // move indices between sets
    static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);

    // allocate a chunk of blocks
    bool Malloc();

private:
    size_t block_size_;
    int    max_block_count_{};
    int    chunk_size_{};

    core::Allocator allocator_;

    std::vector chunks_;

    BlockIds active_ids_;
    BlockIds cached_ids_;
    BlockIds free_ids_;

    std::vector blocks_;  // < 100k

    uint64_t unique_id_{1};
    uint64_t timestamp_{1};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/BlockTrie.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/BlockTrie.h"
#include "src/turbomind/models/llama/SequenceManager.h"

namespace turbomind {

size_t hash(const std::vector<int>& vec)
{
    // boost::hash_combine-style fold over the token ids, seeded with the length so that
    // token lists of different sizes start from different states.
    size_t h = vec.size();
    for (auto it = vec.begin(); it != vec.end(); ++it) {
        const size_t token_hash = std::hash<int>{}(*it);
        h ^= token_hash + 0x9e3779b9 + (h << 6) + (h >> 2);
    }
    return h;
}

BlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager):
    block_seq_len_(block_len), block_manager_(block_manager), root_(std::make_shared<TrieNode>())
{
    // The root is an empty sentinel; every real node below it stands for one block of
    // `block_seq_len_` tokens.
}

std::tuple BlockTrie::Match(const Sequence& seq)
{
    BlockIds  block_ids;
    UniqueIds unique_ids;

    auto node  = root_;
    auto first = seq.prompt.begin();

    // Warning: Do not use "<=" operator even when seq.prompt length is evenly
    // divisible by block_seq_len_. The model needs at least one input token to generate output.
    while (first + block_seq_len_ < seq.prompt.end()) {
        const std::vector segment{first, first + block_seq_len_};
        const size_t           hash_key = hash(segment);
        if (const auto it = node->children.find(hash_key); it != node->children.end()) {
            if (segment == it->second->tokens) {
                block_ids.push_back(it->second->block_id);
                unique_ids.push_back(it->second->block_unique_id);
                node = it->second;
                first += block_seq_len_;
            }
            else {
                TM_LOG_WARNING("hash collision detected");
                break;
            }
        }
        else {
            break;
        }
    }

    return std::make_tuple(block_ids, unique_ids);
}

std::tuple BlockTrie::Cache(const Sequence& seq, const std::vector& tokens)
{
    // Ensure the seq is active or locked so that all cache blocks must be valid
    TM_CHECK_NE(seq.status, Sequence::kCached);
    TM_CHECK_LE(seq.cache_len, seq.blocks.size() * block_seq_len_);

    auto node = root_;

    BlockIds  cache_block_ids;
    UniqueIds cache_block_unique_ids;

    const int n_blocks = std::min(seq.cache_len, (int)tokens.size()) / block_seq_len_;

    int new_cached = 0;

    for (int idx = 0; idx < n_blocks; ++idx) {
        auto start = tokens.begin() + idx * block_seq_len_;
        auto end   = start + block_seq_len_;

        const std::vector segment(start, end);
        const size_t           hash_key = hash(segment);  // TODO(lvhan): add salt to ensure the hash security

        int      block_id        = seq.blocks[idx];
        uint64_t block_unique_id = seq.block_unique_ids[idx];

        if (auto it = node->children.find(hash_key); it != node->children.end()) {
            if (segment == it->second->tokens) {  // fast-forward
                node                  = it->second;
                node->block_id        = block_id;
                node->block_unique_id = block_unique_id;
            }
            else {
                TM_LOG_WARNING("[BlockTrie][cache] Hash collision detected");
                break;
            }
        }
        else {
            // insert new node
            node                  = node->children.emplace_hint(it, hash_key, std::make_shared())->second;
            node->hash_key        = hash_key;
            node->tokens          = segment;
            node->block_id        = block_id;
            node->block_unique_id = block_unique_id;
            new_cached += block_seq_len_;
        }
        cache_block_ids.emplace_back(block_id);
        cache_block_unique_ids.emplace_back(block_unique_id);
    }

    TM_LOG_INFO("[BlockTrie][cache] %d new tokens cached", new_cached);

    return std::make_tuple(cache_block_ids, cache_block_unique_ids);
}

void BlockTrie::Verify()
{
    DFS(root_);
}

void BlockTrie::DFS(std::shared_ptr& node)
{
    for (auto it = node->children.begin(); it != node->children.end();) {
        if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {
            // child invalid
            it = node->children.erase(it);
        }
        else {
            DFS(it->second);
            it++;
        }
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/BlockTrie.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/models/llama/BlockManager.h"
#include 
#include 
#include 

namespace turbomind {

struct Sequence;

// One node of the prefix trie: represents a block of `block_seq_len_` prompt tokens
// following the token blocks of its ancestors.
struct TrieNode {
    // Children keyed by hash(tokens); collisions are detected by comparing `tokens`.
    std::unordered_map> children;
    // Hash of `tokens` (same value as this node's key in the parent's `children`).
    size_t                                                hash_key;
    // The exact token ids covered by this node's block.
    std::vector                                      tokens;
    // Block holding the KV data for `tokens`, plus the allocation-unique id it had when cached;
    // a mismatch with the block manager's current id means the node is stale (see BlockTrie::DFS).
    int                                                   block_id;
    uint64_t                                              block_unique_id;
    // NOTE(review): not referenced in BlockTrie.cc — confirm whether still needed.
    int                                                   num_matched;
};

// Prefix cache over token blocks: maps fixed-length token segments (block_len tokens each)
// to the KV-cache blocks that hold them, so a new sequence can reuse the blocks of a shared prefix.
class BlockTrie {
public:
    /**
     * @param block_len     Number of tokens per cache block.
     * @param block_manager Pool whose blocks the trie nodes refer to; used to detect stale nodes.
     */
    explicit BlockTrie(size_t block_len, std::shared_ptr block_manager);

    /**
     * @brief Attempt to match cached key-value (KV) blocks for a given sequence.
     *
     * This function walks the prompt tokens of the sequence block by block and matches them
     * against the cached KV blocks. It returns the IDs and unique IDs of the longest
     * matched prefix.
     *
     * @param seq The sequence whose tokens are to be matched against the cached KV blocks.
     * @return A tuple containing the following:
     *         - BlockIds: A list of IDs of the matched blocks.
     *         - UniqueIds: A list of unique IDs of the matched blocks.
     *
     * @note If no blocks are matched, all containers in the returned tuple will be empty.
     * @note The block containing the final prompt token is never matched, so at least one
     *       token is always left as model input. Matching stops at a hash collision.
     * @note Returned nodes may be stale; callers should re-check the unique IDs.
     */
    std::tuple Match(const Sequence& seq);

    /**
     * @brief Cache the key-value (KV) blocks of a given sequence.
     *
     * This function caches the KV blocks of the specified sequence. Only the full blocks
     * within min(seq.cache_len, tokens.size()) are cached, and only for a sequence whose
     * status is NOT `Sequence::kCached`. Caching stops at the first hash collision.
     *
     * @param seq The sequence whose KV blocks are to be cached.
     * @param tokens The token list corresponding to the KV blocks
     * @return A tuple containing the following:
     *         - BlockIds: A list of IDs of the cached blocks.
     *         - UniqueIds: A list of unique IDs of the cached blocks.
     */
    std::tuple Cache(const Sequence& seq, const std::vector& tokens);

    /**
     * @brief Remove invalid nodes, i.e. nodes whose block's current unique ID differs from
     *        the one recorded when it was cached (the subtree below them is dropped too).
     */
    void Verify();

private:
    // Recursive helper of Verify().
    void DFS(std::shared_ptr& node);

private:
    // Tokens per block.
    size_t block_seq_len_;

    std::shared_ptr block_manager_;

    // Sentinel root; its children are the first blocks of cached prefixes.
    std::shared_ptr root_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)
# NOTE(review): find_package(CUDAToolkit) relies on the FindCUDAToolkit module, which was added in
# CMake 3.17; the declared minimum of 3.11 is presumably covered by the top-level project — confirm.

find_package(CUDAToolkit REQUIRED)

# Static library with the LLaMA-family model implementation (weights, layers, batching,
# KV-block management and the GatedDeltaNet linear-attention layer).
add_library(Llama STATIC
    LlamaV2.cc
    LlamaBatch.cc
    LlamaLinear.cu
    BlockManager.cc
    BlockTrie.cc
    SequenceManager.cc
    LlamaWeight.cc
    LlamaDenseWeight.cc
    LlamaDecoderLayerWeight.cc
    LlamaFfnLayer.cc
    moe_ffn_layer.cc
    unified_decoder.cc
    unified_attention_layer.cc
    llama_kernels.cu
    llama_utils.cu
    mla_utils.cu
    GatedDeltaNetWeight.cc
    GatedDeltaNetLayer.cc
    gated_delta_net_kernels.cu
)

set_target_properties(Llama PROPERTIES
    POSITION_INDEPENDENT_CODE   ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON)

# Link order is kept as-is: it matters for static libraries.
target_link_libraries(Llama PUBLIC
    CUDA::cudart
    engine
    core
    gemm2
    CUDA::cublas
    nvidia::cutlass::cutlass
    rms_norm
    DynamicDecodeLayer
    activation_kernels
    activation
    attention
    decoding_kernels
    quantization_kernels
    unfused_attention_kernels
    gpt_kernels
    memory_utils
    cuda_utils
    logger
    anomaly_handler)


================================================
FILE: src/turbomind/models/llama/GatedDeltaNetLayer.cc
================================================
#include "src/turbomind/models/llama/GatedDeltaNetLayer.h"
#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/gated_delta_net_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

GatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam&     model,
                                       const AttentionParam& attn,
                                       const EngineParam&    engine,
                                       int                   tp_size,
                                       const Context&        ctx,
                                       int                   phases):
    hidden_units_(model.hidden_units),
    num_k_heads_(model.linear_num_key_heads / tp_size),
    num_v_heads_(model.linear_num_value_heads / tp_size),
    key_head_dim_(model.linear_key_head_dim > 0 ? model.linear_key_head_dim : model.head_dim),
    value_head_dim_(model.linear_value_head_dim > 0 ? model.linear_value_head_dim : model.head_dim),
    d_conv_(model.linear_conv_kernel_dim > 0 ? model.linear_conv_kernel_dim : 4),
    key_dim_(num_k_heads_ * key_head_dim_),
    value_dim_(num_v_heads_ * value_head_dim_),
    conv_dim_(key_dim_ * 2 + value_dim_),
    norm_eps_(model.norm_eps),
    dtype_(model.data_type),
    state_dtype_(model.linear_state_dtype),
    linear_(*ctx.linear)
{
    // Layer type 1 marks a linear-attention layer; their count is used to size the per-sequence
    // state and decides whether the pointer staging buffers are needed at all.
    layer_types_ = model.layer_types;

    int linear_count = 0;
    for (const auto type : layer_types_) {
        linear_count += (type == 1) ? 1 : 0;
    }
    num_linear_layers_ = linear_count;

    TM_LOG_INFO("GatedDeltaNetLayer: num_k=%d num_v=%d k_dim=%d v_dim=%d "
                "conv_dim=%d d_conv=%d num_linear_layers=%d",
                num_k_heads_,
                num_v_heads_,
                key_dim_,
                value_dim_,
                conv_dim_,
                d_conv_,
                num_linear_layers_);

    const bool has_linear_layers = num_linear_layers_ > 0;

    // Pinned host staging for per-request state pointers (filled in Setup) and, per phase,
    // a device-side buffer of the same size that receives the copy.
    if (has_linear_layers) {
        conv_state_ptrs_buf_      = {engine.max_batch_size, kCPUpinned};
        recurrent_state_ptrs_buf_ = {engine.max_batch_size, kCPUpinned};
    }

    for (int phase = 0; phase < phases; ++phase) {
        Data& slot = data_.emplace_back();
        if (has_linear_layers) {
            slot.conv_state_ptrs      = empty_like(conv_state_ptrs_buf_, kDEVICE);
            slot.recurrent_state_ptrs = empty_like(recurrent_state_ptrs_buf_, kDEVICE);
        }
    }

    // SM count of the current device and a 1-element device counter, both handed to the kernel launchers.
    int device_id = 0;
    cudaGetDevice(&device_id);
    cudaDeviceGetAttribute(&sm_count_, cudaDevAttrMultiProcessorCount, device_id);
    work_counter_ = {1, kDEVICE};

    // Non-blocking auxiliary stream requested at priority -1 (numerically lower = higher priority
    // in CUDA) plus fork/join events used by Forward(); timing is disabled since they only order work.
    check_cuda_error(cudaStreamCreateWithPriority(&aux_stream_, cudaStreamNonBlocking, -1));
    check_cuda_error(cudaEventCreateWithFlags(&ev_before_, cudaEventDisableTiming));
    check_cuda_error(cudaEventCreateWithFlags(&ev_after_, cudaEventDisableTiming));
}

GatedDeltaNetLayer::~GatedDeltaNetLayer()
{
    // Release the fork/join events and the auxiliary stream created in the constructor.
    // The three handles are independent, so the release order does not matter.
    cudaEventDestroy(ev_before_);
    cudaEventDestroy(ev_after_);
    cudaStreamDestroy(aux_stream_);
}

void GatedDeltaNetLayer::Run(BatchOp op, int phase, TensorMap& env)
{
    // Per-batch lifecycle hook. Only kSetup and kPrepare carry work for this layer. Other ops
    // (including kAdd) are no-ops: Setup() requires the state slots to be bound to the
    // sequences already, and it resets newly assigned slots lazily.
    if (op == BatchOp::kSetup) {
        // Capture the per-request state pointers for this phase.
        Setup(phase, env);
    }
    else if (op == BatchOp::kPrepare) {
        // Borrow the cumulative query/key token offsets used by the kernels in Forward().
        auto& d     = data_.at(phase);
        d.q_offsets = env.at("q_offsets").buffer().borrow();
        d.k_offsets = env.at("k_offsets").buffer().borrow();
    }
}

// Snapshot the current batch for `phase`: borrow each request, record its input length, bind
// the sequence-owned conv/recurrent state tensors, and upload their raw pointers to the
// phase's device pointer arrays (via the pinned host staging buffers).
void GatedDeltaNetLayer::Setup(int phase, TensorMap& env)
{
    auto&       d = data_.at(phase);
    const auto& b = *env.at("batch").data()[0];

    d.batch_size = b.rc.size();
    d.rc.resize(d.batch_size);
    d.input_lens.resize(d.batch_size);

    d.conv_states.resize(d.batch_size);
    d.recurrent_states.resize(d.batch_size);

    // NOTE(review): the staging buffers below hold engine.max_batch_size entries and are only
    // allocated when num_linear_layers_ > 0; nothing here checks i < max_batch_size — confirm
    // that callers guarantee both.
    for (int i = 0; i < d.batch_size; ++i) {
        d.rc[i]         = b.rc[i].get();
        d.input_lens[i] = b.rc[i]->input_len;

        auto& s = *b.rc[i]->seq;
        TM_CHECK(s.conv_states && s.recurrent_states)
            << "Linear-attention state slot is not bound for sequence " << s.id;
        if (s.linear_states_need_reset) {
            // Reset newly assigned pooled slot state on first use. Keep GPU-side
            // state initialization out of SequenceManager.
            Clear(s.conv_states);
            Clear(s.recurrent_states);
            s.linear_states_need_reset = false;
        }

        // Linear-attention requests are restricted to stateless execution, so
        // the sequence-owned states can be passed directly here.
        d.conv_states[i]      = s.conv_states;
        d.recurrent_states[i] = s.recurrent_states;

        conv_state_ptrs_buf_[i]      = d.conv_states[i].raw_data();
        recurrent_state_ptrs_buf_[i] = d.recurrent_states[i].raw_data();
    }

    // Upload the first batch_size pointers to this phase's device-side pointer arrays.
    Copy(conv_state_ptrs_buf_, d.batch_size, d.conv_state_ptrs);
    Copy(recurrent_state_ptrs_buf_, d.batch_size, d.recurrent_state_ptrs);
}

// Index of `layer_id` among the linear-attention layers: the number of entries with type 1
// strictly before position `layer_id` (clamped to the list length; 0 for layer_id <= 0).
static int linear_layer_index(int layer_id, const std::vector<int>& layer_types)
{
    int preceding = 0;
    int position  = 0;
    for (const int type : layer_types) {
        if (position++ >= layer_id) {
            break;
        }
        if (type == 1) {
            ++preceding;
        }
    }
    return preceding;
}

void GatedDeltaNetLayer::Forward(ForwardParam p)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    const int token_num = p.input.shape(0);
    if (token_num == 0)
        return;

    const auto  dtype   = p.input.dtype();
    const auto  device  = p.input.device();
    const auto  stream  = core::Context::stream().handle();
    const auto& weights = *p.weights;

    auto& pd = data_.at(p.phase);

    // Nothing below depends on the element type at compile time: every launcher receives
    // Tensors/Buffers carrying their dtype at runtime. So a plain check replaces a per-type
    // dispatch; only fp16 / bf16 activations are supported.
    if (dtype != kHalf && dtype != kBfloat16) {
        TM_CHECK(0) << "Unsupported dtype for GatedDeltaNetLayer";
    }

    // =================================================================
    // 1. Single fused input projection: reads p.input once from HBM.
    //    Output columns are ordered: [qkv | z | b | a]
    //    with widths conv_dim_, value_dim_, v_heads_tp, v_heads_tp.
    // =================================================================
    const int v_heads_tp = num_v_heads_;  // already TP-sharded
    Tensor    all_proj   = linear_.Forward(p.input, weights.in_proj_all);
    sync_check_cuda_error();

    // Column ranges within each (token-major) row of all_proj:
    //   [0, conv_dim_)                                 -> mixed_qkv
    //   [conv_dim_, +value_dim_)                       -> z (gate)
    //   [conv_dim_+value_dim_, +v_heads_tp)            -> b (beta logit)
    //   [conv_dim_+value_dim_+v_heads_tp, +v_heads_tp) -> a (alpha/dt)
    const int b_offset = conv_dim_ + value_dim_;  // column offset to b logits
    const int a_offset = b_offset + v_heads_tp;   // column offset to a logits

    // =================================================================
    // 2. Compute beta and g for all tokens from strided column views of all_proj.
    // =================================================================
    Tensor beta{{token_num, num_v_heads_}, dtype, device};
    Tensor g{{token_num, num_v_heads_}, dtype, device};

    auto b = all_proj.slice({0, b_offset}, {-1, v_heads_tp});
    auto a = all_proj.slice({0, a_offset}, {-1, v_heads_tp});

    ComputeBetaG_v2(beta, g, b, a, weights.A_log, weights.dt_bias, stream);

    // =================================================================
    // 3. Process all requests at once via batched kernel launches
    // =================================================================
    Tensor attn_out{{token_num, value_dim_}, dtype, device};
    Tensor conv_out{{token_num, conv_dim_}, dtype, device};

    // Per-sequence state holds all linear layers back to back; select this layer's slice.
    const int state_layer_idx              = linear_layer_index(p.layer_id, layer_types_);
    const int conv_state_layer_offset      = state_layer_idx * (conv_dim_ * d_conv_);
    const int recurrent_state_layer_offset = state_layer_idx * (num_v_heads_ * key_head_dim_ * value_head_dim_);

    // ----- 3a. Fused Causal Conv1d + SiLU (all requests) -----
    // The whole all_proj tensor is passed; per the launcher contract, the qkv columns are read
    // with the row stride taken from all_proj.stride(0).
    invokeFusedConv1dSiLU(conv_out,
                          all_proj,
                          weights.conv1d,
                          Tensor{},
                          pd.conv_state_ptrs,
                          pd.q_offsets,
                          pd.k_offsets,
                          pd.batch_size,
                          conv_state_layer_offset,
                          sm_count_,
                          work_counter_.data(),
                          stream);
    sync_check_cuda_error();

    // ----- 3b. Gated Delta Rule -----
    // Requests are expected to be ordered decode (input_len <= 1) first, prefill last. Count the
    // leading decode run to find the split, and dispatch each half to its kernel. When both halves
    // are present they run concurrently (decode on the main stream, prefill on aux_stream_).
    int decode_count = 0;
    while (decode_count < pd.batch_size && pd.input_lens[decode_count] <= 1) {
        ++decode_count;
    }
    const int prefill_count = pd.batch_size - decode_count;

    // Single-token recurrent update for requests [0, decode_count).
    auto launch_decode = [&](cudaStream_t st) {
        auto state_slice = pd.recurrent_state_ptrs.slice(0, decode_count);
        auto q_slice     = pd.q_offsets.slice(0, decode_count + 1);
        invokeGatedDeltaRuleBatched_v3(attn_out,
                                       conv_out,
                                       beta,
                                       g,
                                       state_slice,
                                       q_slice,
                                       decode_count,
                                       num_k_heads_,
                                       recurrent_state_layer_offset,
                                       state_dtype_,
                                       sm_count_,
                                       work_counter_.data(),
                                       st);
    };

    // Chunked update for requests [decode_count, batch_size).
    // NOTE(review): this and launch_decode receive the same work_counter_. If the launchers use it
    // for atomic work claiming (as its member comment says), the concurrent fork below lets the two
    // kernels race on one counter — confirm, and give aux_stream_ its own counter if needed.
    auto launch_prefill = [&](cudaStream_t st) {
        auto state_slice = pd.recurrent_state_ptrs.slice(decode_count, prefill_count);
        auto q_slice     = pd.q_offsets.slice(decode_count, prefill_count + 1);
        invokeChunkedGatedDeltaRuleBatched(attn_out,
                                           conv_out,
                                           beta,
                                           g,
                                           state_slice,
                                           q_slice,
                                           prefill_count,
                                           num_k_heads_,
                                           recurrent_state_layer_offset,
                                           state_dtype_,
                                           sm_count_,
                                           work_counter_.data(),
                                           st);
    };

    if (decode_count > 0 && prefill_count > 0) {
        // Fork: aux_stream (high priority) waits for prior work on main stream
        check_cuda_error(cudaEventRecord(ev_before_, stream));
        check_cuda_error(cudaStreamWaitEvent(aux_stream_, ev_before_));

        launch_decode(stream);         // decode on main stream
        launch_prefill(aux_stream_);  // prefill on aux stream (higher priority)

        // Join: main stream waits for prefill to finish
        check_cuda_error(cudaEventRecord(ev_after_, aux_stream_));
        check_cuda_error(cudaStreamWaitEvent(stream, ev_after_));
    }
    else if (decode_count > 0) {
        launch_decode(stream);
    }
    else if (prefill_count > 0) {
        launch_prefill(stream);
    }
    sync_check_cuda_error();

    // ----- 3c. RMSNormGated (all tokens at once) -----
    // Gate (z) lives at column conv_dim_ of all_proj, keeping all_proj's row stride.
    Tensor gate        = all_proj.slice({0, conv_dim_}, {-1, value_dim_});
    Tensor hidden_view = attn_out.view({token_num * num_v_heads_, value_head_dim_});
    invokeRMSNormGated(hidden_view, gate, weights.norm, norm_eps_, stream);
    sync_check_cuda_error();

    // =================================================================
    // 4. Output projection (all tokens at once)
    // =================================================================
    (void)linear_.Forward(attn_out, weights.out_proj, p.output);
    sync_check_cuda_error();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/GatedDeltaNetLayer.h
================================================
#pragma once

#include "src/turbomind/core/tensor.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/models/llama/GatedDeltaNetWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// GatedDeltaNet (linear-attention) layer: fused input projection, causal conv1d + SiLU, gated
// delta-rule recurrence (decode and chunked-prefill kernels), gated RMSNorm and output projection.
// Recurrent/conv state lives in the sequences and is bound per batch in Setup().
class GatedDeltaNetLayer {
public:
    struct ForwardParam {
        int                        phase;
        Tensor                     input;
        Tensor                     output;
        const GatedDeltaNetWeight* weights;
        int                        layer_id;
    };

    GatedDeltaNetLayer(const ModelParam&     model,
                       const AttentionParam& attn,
                       const EngineParam&    engine,
                       int                   tp_size,
                       const Context&        ctx,
                       int                   phases);

    ~GatedDeltaNetLayer();

    // Batch lifecycle hook (kSetup -> Setup, kPrepare -> borrow q/k offsets).
    void Run(BatchOp op, int phase, TensorMap& env);

    void Forward(ForwardParam p);

private:
    void Setup(int phase, TensorMap& env);

    // Model dimensions
    int              hidden_units_;
    int              num_k_heads_;
    int              num_v_heads_;
    int              key_head_dim_;
    int              value_head_dim_;
    int              d_conv_;
    int              key_dim_;            // num_k_heads * key_head_dim
    int              value_dim_;          // num_v_heads * value_head_dim
    int              conv_dim_;           // key_dim * 2 + value_dim
    int              num_linear_layers_;  // count of linear attention layers for state sizing
    std::vector layer_types_;        // model layer types for index mapping

    float    norm_eps_;
    DataType dtype_;
    DataType state_dtype_;  // recurrent state dtype (may differ from dtype_ for float32 state)

    LlamaLinear& linear_;

    // Per-phase batch data (mirrors UnifiedAttentionLayer pattern)
    struct Data {
        std::vector rc;          // borrowed batch RequestCache pointers
        std::vector           input_lens;  // snapshot of input_len per request (captured at Setup time)
        int                        batch_size = 0;
        Buffer_               q_offsets;  // cumulative input-token offsets, device buffer
        Buffer_               k_offsets;  // cumulative key (history+input) offsets, device buffer
        std::vector        conv_states;
        std::vector        recurrent_states;
        Buffer_             conv_state_ptrs;       // device copy of the conv-state pointers
        Buffer_             recurrent_state_ptrs;  // device copy of the recurrent-state pointers
    };
    std::vector data_;

    // staging buffers (pinned host, engine.max_batch_size entries; allocated only when the
    // model has linear-attention layers)
    Buffer_ conv_state_ptrs_buf_;
    Buffer_ recurrent_state_ptrs_buf_;

    // Queried once at construction; passed to all three kernel launchers.
    // NOTE(review): the decode and prefill launchers may run concurrently on different streams
    // while sharing work_counter_ — confirm this cannot race.
    int          sm_count_{1};
    Buffer_ work_counter_;  // 1-element device int for v3 atomic claiming

    // Dual-stream dispatch: prefill on high-priority aux stream, decode on main
    cudaStream_t aux_stream_{};
    cudaEvent_t  ev_before_{};  // main→aux: prior work done
    cudaEvent_t  ev_after_{};   // aux→main: prefill done
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/GatedDeltaNetWeight.cc
================================================
#include "src/turbomind/models/llama/GatedDeltaNetWeight.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

GatedDeltaNetWeight::GatedDeltaNetWeight(int      hidden_dim,
                                         int      num_k_heads,
                                         int      num_v_heads,
                                         int      key_head_dim,
                                         int      value_head_dim,
                                         int      d_conv,
                                         bool     bias,
                                         int      tp_size,
                                         int      tp_rank,
                                         DataType data_type,
                                         DataType weight_type,
                                         int      group_size):
    tp_rank_(tp_rank), tp_size_(tp_size)
{
    // Per-rank (tensor-parallel) widths.
    const int v_heads_tp = num_v_heads / tp_size;
    const int value_dim  = num_v_heads * value_head_dim / tp_size;
    const int key_dim    = num_k_heads * key_head_dim / tp_size;
    const int conv_dim   = 2 * key_dim + value_dim;  // q, k and v channels fed to the conv

    // The projections are stored as plain dense weights in the checkpoint. Using the activation
    // dtype as weight dtype with group size 0 keeps them off the quantization path
    // (weight_type / group_size are intentionally not used here).
    const DataType dense_wtype = data_type;
    const int      dense_gsz   = 0;

    // Projections (created, then registered for checkpoint loading; registration order unchanged).
    in_proj_qkv.emplace(hidden_dim, conv_dim, data_type, bias, dense_wtype, dense_gsz);
    register_module("in_proj_qkv", in_proj_qkv, tp_rank_);

    in_proj_z.emplace(hidden_dim, value_dim, data_type, bias, dense_wtype, dense_gsz);
    register_module("in_proj_z", in_proj_z, tp_rank_);

    in_proj_b.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz);
    register_module("in_proj_b", in_proj_b, tp_rank_);

    in_proj_a.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz);
    register_module("in_proj_a", in_proj_a, tp_rank_);

    out_proj.emplace(value_dim, hidden_dim, data_type, bias, dense_wtype, dense_gsz);
    register_module("out_proj", out_proj, tp_rank_);

    // Depthwise conv1d kernel, shape (conv_dim, d_conv).
    conv1d = Tensor{{conv_dim, d_conv}, data_type, kDEVICE};
    register_parameter("conv1d." + std::to_string(tp_rank_) + ".weight", conv1d);

    // Log-space decay, one value per local value head.
    A_log = Tensor{{v_heads_tp}, data_type, kDEVICE};
    register_parameter("A_log." + std::to_string(tp_rank_) + ".weight", A_log);

    // dt bias, one value per local value head.
    dt_bias = Tensor{{v_heads_tp}, data_type, kDEVICE};
    register_parameter("dt_bias." + std::to_string(tp_rank_) + ".weight", dt_bias);

    // RMSNormGated weight, shape (value_head_dim,); not sharded, hence no rank in its name.
    norm = Tensor{{value_head_dim}, data_type, kDEVICE};
    register_parameter("norm.weight", norm);
}

// ---------------------------------------------------------------------------
// Row-wise concatenation of 4 weight matrices into a single pre-allocated
// destination tensor.
//
// Each source weight has shape (input_dim, out_dim_i) in row-major storage.
// The destination has shape (input_dim, sum_i out_dim_i) and rows are filled
// by concatenating the corresponding source rows in order.
//
// Implemented with cudaMemcpy2DAsync so that no extra temporary is needed:
// each source "column block" is scattered into the correct column range of
// the destination in one pass per source.
// ---------------------------------------------------------------------------
static void
concat_weights_4(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, Tensor& dst, cudaStream_t st)
{
    // Tensors are (K=input_dim, M=output_dim) in row-major order.
    // Each row of `dst` is [a_row | b_row | c_row | d_row].
    const int K       = dst.shape(0);
    const int M_a     = a.shape(1);
    const int M_b     = b.shape(1);
    const int M_c     = c.shape(1);
    const int M_d     = d.shape(1);
    const int M_dst   = dst.shape(1);  // M_a + M_b + M_c + M_d
    const int elem_sz = byte_size(dst.dtype(), 1);

    // Pitch of the destination row in bytes
    const size_t dst_pitch   = (size_t)M_dst * elem_sz;
    const size_t src_pitch_a = (size_t)M_a * elem_sz;
    const size_t src_pitch_b = (size_t)M_b * elem_sz;
    const size_t src_pitch_c = (size_t)M_c * elem_sz;
    const size_t src_pitch_d = (size_t)M_d * elem_sz;

    char* dst_ptr = reinterpret_cast(dst.raw_data());

    // Columns [0, M_a)
    check_cuda_error(
        cudaMemcpy2DAsync(dst_ptr, dst_pitch, a.raw_data(), src_pitch_a, src_pitch_a, K, cudaMemcpyDefault, st));

    // Columns [M_a, M_a+M_b)
    check_cuda_error(cudaMemcpy2DAsync(
        dst_ptr + src_pitch_a, dst_pitch, b.raw_data(), src_pitch_b, src_pitch_b, K, cudaMemcpyDefault, st));

    // Columns [M_a+M_b, M_a+M_b+M_c)
    check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b,
                                       dst_pitch,
                                       c.raw_data(),
                                       src_pitch_c,
                                       src_pitch_c,
                                       K,
                                       cudaMemcpyDefault,
                                       st));

    // Columns [M_a+M_b+M_c, M_dst)
    check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b + src_pitch_c,
                                       dst_pitch,
                                       d.raw_data(),
                                       src_pitch_d,
                                       src_pitch_d,
                                       K,
                                       cudaMemcpyDefault,
                                       st));
    sync_check_cuda_error();
}

void GatedDeltaNetWeight::prepare()
{
    auto stream = core::Context::stream().handle();

    // Preprocess individual weights (converts blockscale FP8, etc.)
    in_proj_qkv.preprocess();
    in_proj_z.preprocess();
    in_proj_b.preprocess();
    in_proj_a.preprocess();
    out_proj.preprocess();
    out_proj.prepare();

    // Build the fused input projection weight:
    //   shape (hidden_dim,  conv_dim + value_dim + 2*v_heads_tp)
    //   = [in_proj_qkv | in_proj_z | in_proj_b | in_proj_a]  (column-wise)
    const int out_all = in_proj_qkv.output_dim  //
                        + in_proj_z.output_dim  //
                        + in_proj_b.output_dim  //
                        + in_proj_a.output_dim;

    in_proj_all.emplace(in_proj_qkv.input_dim,
                        out_all,
                        in_proj_qkv.data_type,
                        /*bias=*/false,
                        in_proj_qkv.weight_type,
                        in_proj_qkv.group_size);

    concat_weights_4(
        in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, in_proj_all.weight, stream);

    // Prepare (convert/repack) the fused weight for GEMM
    in_proj_all.prepare();

    // Release the now-redundant individual weight tensors to free HBM
    in_proj_qkv = {};
    in_proj_z   = {};
    in_proj_b   = {};
    in_proj_a   = {};

    // Transpose conv1d from checkpoint layout [conv_dim, d_conv] to kernel layout [d_conv, conv_dim]
    {
        const int rows = conv1d.shape(0);  // conv_dim
        const int cols = conv1d.shape(1);  // d_conv

        Tensor conv1d_t{{cols, rows}, conv1d.dtype(), kDEVICE};
        invokeTransposeAxis01((uint16_t*)conv1d_t.raw_data(), (uint16_t*)conv1d.raw_data(), rows, cols, 1, stream);
        sync_check_cuda_error();
        conv1d = std::move(conv1d_t);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/GatedDeltaNetWeight.h
================================================
#pragma once

#include "src/turbomind/core/core.h"
#include "src/turbomind/core/module.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"

namespace turbomind {

struct GatedDeltaNetWeight: public core::Module {

    GatedDeltaNetWeight() = default;

    GatedDeltaNetWeight(int      hidden_dim,
                        int      num_k_heads,
                        int      num_v_heads,
                        int      key_head_dim,
                        int      value_head_dim,
                        int      d_conv,
                        bool     bias,
                        int      tp_size,
                        int      tp_rank,
                        DataType data_type,
                        DataType weight_type,
                        int      group_size);

    // Preprocesses the projections, fuses in_proj_{qkv,z,b,a} into in_proj_all,
    // releases the four individual input projections, and transposes conv1d into
    // the layout used by the kernel.
    void prepare();

    // Individual projections – populated at load time from the checkpoint.
    // After prepare() completes the four input projections are released (null-ed)
    // to free HBM. Dimensions below are per tensor-parallel rank.
    LlamaDenseWeight in_proj_qkv;  // hidden -> key_dim*2 + value_dim
    LlamaDenseWeight in_proj_z;    // hidden -> value_dim (output gate)
    LlamaDenseWeight in_proj_b;    // hidden -> num_v_heads/tp (beta, per-head scalar)
    LlamaDenseWeight in_proj_a;    // hidden -> num_v_heads/tp (alpha/dt, per-head scalar)

    // Fused projection: hidden -> (conv_dim + value_dim + 2*v_heads_tp).
    // Built from the four above in prepare(); used for all inference GEMMs.
    // Reduces p.input HBM reads from 4× to 1× per forward pass.
    LlamaDenseWeight in_proj_all;

    LlamaDenseWeight out_proj;  // value_dim -> hidden

    // Non-dense parameters (per tensor-parallel rank)
    Tensor conv1d;   // depthwise conv weights: (conv_dim, d_conv) as loaded, (d_conv, conv_dim) after prepare()
    Tensor A_log;    // log-space decay: (num_v_heads/tp,)
    Tensor dt_bias;  // dt bias: (num_v_heads/tp,)
    Tensor norm;     // RMSNormGated weight: (value_head_dim,)

    // Default-initialized so a default-constructed object never holds indeterminate values.
    int tp_rank_ = 0;
    int tp_size_ = 1;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc

#include 

#include 
#include 

#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

// Returns whether the SiLU activation should be fused, as controlled by the
// TM_FUSE_SILU_ACT environment variable. The variable is read and parsed only
// once (on first call); an unset or unparsable value means "enabled".
static bool is_fuse_silu_act()
{
    static const bool cached = [] {
        bool enabled = true;  // default when unset or unparsable
        if (const char* env = std::getenv("TM_FUSE_SILU_ACT")) {
            try {
                const bool parsed = std::stoi(env) != 0;
                TM_LOG_INFO("TM_FUSE_SILU_ACT=%d", (int)parsed);
                enabled = parsed;
            }
            catch (...) {
                // Non-numeric / out-of-range value: keep the default.
                enabled = true;
            }
        }
        return enabled;
    }();
    return cached;
}

// Builds the weights of decoder layer `layer_id`:
//  - attention: GatedDeltaNet (linear attention) when model.layer_types[layer_id] == 1,
//    standard LlamaAttentionWeight otherwise;
//  - dense FFN when the layer's inter_size is non-zero;
//  - routed-expert MoE when moe_param lists experts for this layer;
//  - the two per-layer norm weights.
LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(
    DataType data_type, int layer_id, const ModelParam& model, const EngineParam& engine, const MoeParam& moe_param):
    head_num_(model.head_num),
    kv_head_num_(model.kv_head_num),
    size_per_head_(model.head_dim),
    hidden_units_(model.hidden_units),
    inter_size_(model.inter_size.at(layer_id)),
    data_type_{data_type},
    weight_type_(model.weight_type),
    expert_weight_type_(model.expert_weight_type),
    attn_bias_(model.attn_bias),
    attn_tp_size_(engine.attn_tp_size),
    attn_tp_rank_(engine.attn_tp_rank),
    mlp_tp_size_(engine.mlp_tp_size),
    mlp_tp_rank_(engine.mlp_tp_rank)
{
    // layer_types[layer_id] == 1 selects linear attention; layers beyond the end of
    // layer_types use standard attention.
    const bool is_linear_attention =
        layer_id < (int)model.layer_types.size() && model.layer_types[layer_id] == 1;

    if (is_linear_attention) {
        linear_attn_weights.reset(
            new GatedDeltaNetWeight{hidden_units_,
                                    model.linear_num_key_heads,
                                    model.linear_num_value_heads,
                                    model.linear_key_head_dim,
                                    model.linear_value_head_dim,
                                    // fall back to a kernel size of 4 when the config does not provide one
                                    model.linear_conv_kernel_dim > 0 ? model.linear_conv_kernel_dim : 4,
                                    attn_bias_,
                                    attn_tp_size_,
                                    attn_tp_rank_,
                                    data_type_,
                                    weight_type_,
                                    model.group_size});
        register_module("linear_attn", *linear_attn_weights);
    }
    else {
        // Attention uses weight_type (fp16 in mixed quant scenarios)
        self_attn_weights.reset(new LlamaAttentionWeight{hidden_units_,
                                                         size_per_head_,
                                                         head_num_,
                                                         kv_head_num_,
                                                         model.mla,
                                                         attn_bias_,
                                                         model.qk_norm,
                                                         attn_tp_size_,
                                                         attn_tp_rank_,
                                                         data_type_,
                                                         weight_type_,
                                                         model.group_size,
                                                         model.window_size.empty() ? 0 : model.window_size.at(layer_id),
                                                         model.attn_sink,
                                                         model.attn_output_gate});
        register_module("attention", *self_attn_weights);
    }

    // FFN uses ffn_weight_type, except for layers fully excluded from
    // quantization (e.g. 'model.layers.0.' in modules_to_not_convert)
    // where all weights—including FFN—are in data_type (fp16).
    if (inter_size_) {
        const DataType ffn_wtype = model.unquantized_expert_layers.count(layer_id) ? data_type_ : model.ffn_weight_type;
        // 16-bit weights take the cuBLAS path, where SiLU fusion is not used.
        const bool is_cublas_gemm = byte_size(ffn_wtype, 8) == 16;
        ffn_weights.reset(new LlamaFfnWeight{
            hidden_units_,
            inter_size_,
            model.mlp_bias,
            mlp_tp_size_,
            mlp_tp_rank_,
            data_type_,
            ffn_wtype,
            model.group_size,
            model.act_type,
            is_fuse_silu_act() && !is_cublas_gemm,
        });
        register_module("feed_forward", *ffn_weights);
    }

    // MoE routed experts use expert_weight_type (int4 for AWQ, e2m1 for mxfp4)
    // unless the layer is in unquantized_expert_layers (e.g. layer 0 excluded
    // from quantization via modules_to_not_convert).
    // Cast to int, matching the layer_types bound check above (avoids a signed/unsigned comparison).
    if (layer_id < (int)moe_param.expert_num.size() && moe_param.expert_num[layer_id]) {
        const DataType moe_wtype = model.unquantized_expert_layers.count(layer_id) ? data_type_ : expert_weight_type_;
        moe_weights.reset(new MoeFfnWeight{layer_id,
                                           moe_param,
                                           hidden_units_,
                                           model.mlp_bias,
                                           data_type_,
                                           moe_wtype,
                                           model.group_size,
                                           mlp_tp_size_,
                                           mlp_tp_rank_,
                                           model.act_type,
                                           is_fuse_silu_act()});
        register_module("moe_ffn", *moe_weights);
    }

    self_attn_norm = Tensor{{hidden_units_}, data_type_, kDEVICE};
    ffn_norm       = Tensor{{hidden_units_}, data_type_, kDEVICE};
    register_parameter("attention_norm.weight", self_attn_norm);
    register_parameter("ffn_norm.weight", ffn_norm);
}

// Defaulted out-of-line destructor; the owned sub-module weights are released by
// their std::unique_ptr members.
LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() = default;

// Prepares every sub-module that exists for this layer, in the order:
// attention (standard or linear), dense FFN, MoE. `prop` and `st` are not
// used here; the sub-modules are prepared without them.
void LlamaDecoderLayerWeight::prepare(const cudaDeviceProp& prop, cudaStream_t st)
{
    if (auto* attn = self_attn_weights.get()) {
        attn->prepare();
    }

    if (auto* linear_attn = linear_attn_weights.get()) {
        linear_attn->prepare();
    }

    if (auto* ffn = ffn_weights.get()) {
        // NOTE(review): meaning of the boolean is defined by LlamaFfnWeight::prepare — confirm there.
        ffn->prepare(false);
    }

    if (auto* moe = moe_weights.get()) {
        moe->prepare();
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaDecoderLayerWeight.h
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h

#pragma once

#include "src/turbomind/core/core.h"

#include "src/turbomind/models/llama/GatedDeltaNetWeight.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

struct LlamaDecoderLayerWeight: core::Module {
public:
    LlamaDecoderLayerWeight() = delete;

    LlamaDecoderLayerWeight(DataType           data_type,
                            int                layer_id,
                            const ModelParam&  model,
                            const EngineParam& engine,
                            const MoeParam&    moe_param);

    ~LlamaDecoderLayerWeight();
    LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight&) = delete;
    LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight&) = delete;

    void prepare(const cudaDeviceProp& prop, cudaStream_t st);

    Tensor self_attn_norm;
    Tensor ffn_norm;

    std::unique_ptr self_attn_weights;
    std::unique_ptr  linear_attn_weights;

    std::unique_ptr ffn_weights;
    std::unique_ptr   moe_weights;

private:
    int head_num_;
    int kv_head_num_;
    int size_per_head_;
    int hidden_units_;
    int inter_size_;

    DataType data_type_;
    DataType weight_type_;
    DataType expert_weight_type_;

    int  bit_size_;
    bool attn_bias_;
    int  attn_tp_size_;
    int  attn_tp_rank_;
    int  mlp_tp_size_;
    int  mlp_tp_rank_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaDenseWeight.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/models/llama/LlamaDenseWeight.h"

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/kernels/gemm/cast.h"
#include "src/turbomind/kernels/gemm/convert.h"
#include "src/turbomind/kernels/gemm/gemm.h"
#include "src/turbomind/kernels/gemm/types.h"
#include "src/turbomind/kernels/gemm/utils.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Allocates and registers the device tensors of a dense projection with logical
// shape (input_dim, output_dim), plus optional bias and the quantization side
// tensors implied by `weight_type`:
//  - kFloat8_e4m3: fp32 block scales (group_size must be 128); on SM90 the input
//    is also quantized to e4m3;
//  - kFloat4_e2m1: uint8 per-group scales along the input dimension;
//  - kUint4 / kUint8: per-group scales and zeros in data_type ("qweight").
// k_desc is reset to a row-major descriptor of the weight and q_desc is cleared.
void LlamaDenseWeight::emplace(
    int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size)
{
    this->data_type   = data_type;
    this->input_type  = data_type;
    this->weight_type = weight_type;
    this->input_dim   = input_dim;
    this->output_dim  = output_dim;
    this->group_size  = group_size;

    const bool is_qweight = weight_type == kUint4 || weight_type == kUint8;

    weight = Tensor({input_dim, output_dim}, weight_type, kDEVICE);
    register_parameter(is_qweight ? "qweight" : "weight", weight);

    if (bias) {
        this->bias = Tensor{{output_dim}, data_type, kDEVICE};
        register_parameter("bias", this->bias);
    }

    if (weight_type == kFloat8_e4m3) {
        TM_CHECK_EQ(group_size, 128);
        scales       = Tensor{{cdiv(input_dim, group_size), cdiv(output_dim, group_size)}, kFloat, kDEVICE};
        weight_quant = QuantDesc{gemm::QuantType::kB, group_size};
        if (getSMVersion() == 90) {
            input_type  = kFloat8_e4m3;
            input_quant = QuantDesc{gemm::QuantType::kK, group_size};
        }
        register_parameter("scales", scales);
    }
    else if (weight_type == kFloat4_e2m1) {
        // cdiv() below divides by group_size; reject 0 instead of faulting.
        TM_CHECK(group_size > 0) << group_size;
        scales       = Tensor{{cdiv(input_dim, group_size), output_dim}, kUint8, kDEVICE};
        input_type   = data_type;
        weight_quant = QuantDesc{gemm::QuantType::kK, group_size};
        register_parameter("scales", scales);
    }
    else if (is_qweight) {
        // `%` and `/` by group_size below are undefined for 0.
        TM_CHECK(group_size > 0) << group_size;
        TM_CHECK(input_dim % group_size == 0) << input_dim << " " << group_size;
        scales       = Tensor{{input_dim / group_size, output_dim}, data_type, kDEVICE};
        zeros        = Tensor{{input_dim / group_size, output_dim}, data_type, kDEVICE};
        weight_quant = QuantDesc{gemm::QuantType::kK, group_size};
        register_parameter("scales", scales);
        register_parameter("zeros", zeros);
    }

    k_desc = {};
    q_desc = {};

    // default case: floating point, N-major
    k_desc.type  = weight.dtype();
    k_desc.order = gemm::kRowMajor;
    k_desc.rows  = input_dim;
    k_desc.cols  = output_dim;
    k_desc.ld    = output_dim;
}

// When the weight carries blockwise (kB) scales but the input is not quantized
// (input_quant kNone), convert the scales to groupwise (kK) form. Projections
// without a weight tensor are left untouched.
void LlamaDenseWeight::preprocess()
{
    const bool needs_groupscale = weight && weight_quant.type == gemm::QuantType::kB
                                  && input_quant.type == gemm::QuantType::kNone;
    if (!needs_groupscale) {
        return;
    }

    weight_quant.type = gemm::QuantType::kK;
    scales            = BlockscaleToGroupscale(scales, data_type, weight_quant.group_size);
}

// Repacks `dense.weight` (and, when a scale converter exists, `dense.scales`) in
// place into the layout produced by the converters that GetConverters returns for
// (data_type, weight_type, input_type, is_grouped, SM version), and records the
// resulting descriptors in dense.k_desc / dense.q_desc.
//
// dense.weight is read as the loaded (input_dim, output_dim) matrix and then
// overwritten, so this must run only once per weight.
// `is_grouped` is forwarded to GetConverters (LlamaDenseWeight::prepare passes
// its `fused_moe` flag here).
static void Convert(LlamaDenseWeight& dense, bool is_grouped, cudaStream_t st)
{
    using namespace gemm;

    auto [conv_w, conv_s] =
        GetConverters(dense.data_type, dense.weight_type, dense.input_type, is_grouped, getSMVersion());

    if (conv_w) {
        const auto order_w = conv_w->order;
        const bool is_A    = get_operand_tag(conv_w->pack) == OPERAND_A;
        const bool is_B    = !is_A;

        const int bits = byte_size(dense.weight_type, 8);

        // 16-bit staging copy of the weight, in the loaded (input_dim, output_dim) shape.
        Tensor_ tmp{{dense.input_dim, dense.output_dim}, kDEVICE};

        if (bits == 4) {  // u4 -> u16
            extend_to_u16(tmp.data(), (const uint4_t*)dense.weight.raw_data(), tmp.size(), st);
            sync_check_cuda_error();
        }
        else if (bits == 8) {  // u8 -> u16
            extend_to_u16(tmp.data(), (const uint8_t*)dense.weight.raw_data(), tmp.size(), st);
            sync_check_cuda_error();
        }
        else if (bits == 16) {
            check_cuda_error(
                cudaMemcpyAsync(tmp.raw_data(), dense.weight.raw_data(), tmp.byte_size(), cudaMemcpyDefault, st));
        }

        // Row-major converters consume the (output_dim, input_dim) matrix.
        if (order_w == kRowMajor) {  // (k,m) -> (m,k)
            Tensor_ trans{{dense.output_dim, dense.input_dim}, kDEVICE};
            invokeTransposeAxis01(trans.data(), tmp.data(), dense.input_dim, dense.output_dim, 1, st);
            tmp = trans;
        }

        // Source descriptor of the staging buffer: M = output_dim, K = input_dim.
        MatrixLayout w_desc{
            dense.data_type,
            order_w,
            (int)dense.output_dim,  // M
            (int)dense.input_dim,   // K
            order_w == kRowMajor ? (int)dense.input_dim : (int)dense.output_dim,
        };

        // For a B-operand converter, view the same memory as the transposed matrix.
        if (is_B) {
            std::swap(w_desc.rows, w_desc.cols);
            w_desc.order = ~w_desc.order;
        }

        // Destination descriptor: same geometry, packed with the converter's packing.
        MatrixLayout k_desc = w_desc;
        k_desc.type         = dense.weight_type;
        // Converter does not recognize e2m1 / e4m3
        if (bits == 4) {
            k_desc.type = data_type_v;
        }
        else if (bits == 8) {
            k_desc.type = data_type_v;
        }
        k_desc.pack = conv_w->pack;

        // Clear the destination before the converter writes the packed weight into it.
        check_cuda_error(cudaMemsetAsync(dense.weight.raw_data(), 0, dense.weight.byte_size(), st));

        TM_CHECK(conv_w->Convert(tmp.data(), w_desc, dense.weight.raw_data(), k_desc, st) == 0);

        sync_check_cuda_error();

        // Restore the real weight dtype in the stored descriptor; A-operand layouts are
        // transposed so that the weight is described as operand B.
        k_desc.type = dense.weight_type;
        if (is_A) {
            k_desc = transpose(k_desc);
        }
        dense.k_desc = k_desc;
    }

    if (conv_s) {
        const auto order_s = conv_s->order;
        const auto pack_s  = conv_s->pack;
        const bool is_A    = get_operand_tag(conv_s->pack) == OPERAND_U;
        const bool is_B    = !is_A;

        // tmp_q: staging copy of the (possibly fused) scales; converted back into dense.scales below.
        Tensor   tmp_q;
        DataType scale_type;

        if (dense.zeros) {  // AWQ/GPTQ fuse scales and zeros
            tmp_q = {{dense.scales.size(), 2}, kHalf, kDEVICE};
            fuse_scales_and_zeros(
                tmp_q.data(), dense.scales.data(), dense.zeros.data(), dense.scales.size(), st);
            scale_type   = kUint32;  // half2
            // zeros are now folded into tmp_q; dense.scales is reallocated to the fused shape.
            dense.zeros  = {};
            dense.scales = empty_like(tmp_q);
        }
        else if (dense.weight_type == kFloat8_e4m3) {  // e4m3
            tmp_q = empty_like(dense.scales);
            Copy(dense.scales, tmp_q);
            scale_type = kUint16;  // bf16
        }
        else {  // mxfp4
            tmp_q = empty_like(dense.scales);
            Copy(dense.scales, tmp_q);
            scale_type = kUint8;  // ue8m0
        }

        if (dense.data_type == kHalf && dense.weight_type == kFloat4_e2m1) {  // mxfp4
            AdjustUe8m0ScaleForHalf(tmp_q.data(), tmp_q.size(), st);
            sync_check_cuda_error();
        }

        // Scale matrix: M = output_dim, K = number of input groups.
        MatrixLayout s_desc{
            scale_type,
            order_s,
            (int)dense.output_dim,                    // M
            (int)dense.input_dim / dense.group_size,  // K
            (int)dense.output_dim,                    // always MN-major
        };

        if (is_B) {
            std::swap(s_desc.rows, s_desc.cols);
            s_desc.order = ~s_desc.order;
        }

        MatrixLayout q_desc = s_desc;
        q_desc.pack         = pack_s;

        TM_CHECK(conv_s->Convert(tmp_q.raw_data(), s_desc, dense.scales.raw_data(), q_desc, st) == 0);
        sync_check_cuda_error();

        // weight is placed at B in `Linear`
        if (is_A) {
            q_desc = transpose(q_desc);
        }
        dense.q_desc = q_desc;
    }
}

// Prepares a blockscale-FP8 weight for the native FP8 x FP8 path (requires SM90+).
// Both the e4m3 weight and its fp32 block scales are physically transposed on
// `stream`, and dense.k_desc / dense.q_desc are set to column-major layouts that
// describe the original (rows, cols) matrices over the transposed storage.
static void ConvertBlockscaleFP8Native(LlamaDenseWeight& dense, cudaStream_t stream)
{
    using namespace gemm;

    TM_CHECK_GE(getSMVersion(), 90);
    // The native path supports a single activation data type only.
    TM_CHECK_EQ(dense.data_type, data_type_v);

    // Replace x (rows, cols) with its transpose (cols, rows) and set d to a
    // column-major (rows x cols) layout with leading dimension x.stride(0).
    // `dtype` is only a tag selecting the element type T used for the transpose.
    auto process = [&](Tensor& x, MatrixLayout& d, auto dtype) {
        using T = decltype(dtype);
        Tensor trans{{x.shape(1), x.shape(0)}, x.dtype(), kDEVICE};
        invokeTransposeAxis01((T*)trans.raw_data(), (T*)x.raw_data(), x.shape(0), x.shape(1), 1, stream);
        x = std::move(trans);
        d = MatrixLayout{x.dtype(),  //
                         kColMajor,
                         (int)x.shape(1),
                         (int)x.shape(0),
                         (int)x.stride(0)};
    };

    // e4m3 weight: transposed as 1-byte elements.
    TM_CHECK_EQ(dense.weight.dtype(), kFloat8_e4m3);
    process(dense.weight, dense.k_desc, uint8_t{});

    // fp32 block scales: transposed as 4-byte elements.
    TM_CHECK_EQ(dense.scales.dtype(), kFloat);
    process(dense.scales, dense.q_desc, float{});
}

// Converts the loaded weight into its GEMM-ready layout on the context stream.
// FP8 weights with FP8 inputs take the native blockscale path; everything else
// goes through the generic converters (`fused_moe` selects the grouped variant).
// A projection without a weight tensor is a no-op.
void LlamaDenseWeight::prepare(bool fused_moe)
{
    if (!weight) {
        return;
    }

    const auto st = core::Context::stream().handle();

    const bool native_fp8 = weight_type == kFloat8_e4m3 && input_type == kFloat8_e4m3;
    if (native_fp8) {
        ConvertBlockscaleFP8Native(*this, st);
        return;
    }

    Convert(*this, fused_moe, st);
}

// Creates the attention projections for one layer, sharded over `tp_size` ranks:
//  - standard attention (mla.kv_lora_rank == 0): fused w_qkv (+ optional q/k norms);
//  - MLA: q_a_proj/q_b_proj (or q_proj when q_lora_rank is 0) and kv_a_proj;
//  - output projection "wo" and, when `sink` is set, per-head attention sinks.
LlamaAttentionWeight::LlamaAttentionWeight(int      hidden_dim,
                                           int      head_dim,
                                           int      head_num,
                                           int      kv_head_num,
                                           MLAParam mla,
                                           bool     bias,
                                           bool     qk_norm,
                                           int      tp_size,
                                           int      tp_rank,
                                           DataType data_type,
                                           DataType weight_type,
                                           int      group_size,
                                           int      window_size,
                                           bool     sink,
                                           bool     attn_output_gate)
{
    this->window_size = window_size;

    // attn_output_gate doubles Q dimension (extra gate projection fused into Q)
    const int q_factor = attn_output_gate ? 2 : 1;

    if (mla.kv_lora_rank == 0) {
        qkv.emplace(hidden_dim,
                    (head_num * q_factor + 2 * kv_head_num) * head_dim / tp_size,
                    data_type,
                    bias,
                    weight_type,
                    group_size);
        register_module("w_qkv", qkv, tp_rank);
        if (qk_norm) {
            q_a_layernorm  = Tensor{{head_dim}, data_type, kDEVICE};
            kv_a_layernorm = Tensor{{head_dim}, data_type, kDEVICE};
            register_parameter("q_norm", q_a_layernorm);
            register_parameter("k_norm", kv_a_layernorm);
        }
    }
    else {
        if (mla.q_lora_rank) {
            q_a_proj.emplace(hidden_dim, mla.q_lora_rank, data_type, false, weight_type, group_size);
            q_b_proj.emplace(mla.q_lora_rank, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);
            q_a_layernorm = Tensor{{q_b_proj.input_dim}, data_type, kDEVICE};
            register_module("q_a_proj", q_a_proj);
            register_module("q_b_proj", q_b_proj, tp_rank);
            register_parameter("q_a_layernorm", q_a_layernorm);
        }
        else {
            q_proj.emplace(hidden_dim, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);
            register_module("q_proj", q_proj, tp_rank);
        }
        kv_a_proj.emplace(hidden_dim, mla.kv_lora_rank + mla.qk_rope_dim, data_type, false, weight_type, group_size);
        // kv_b_proj is currently disabled. If re-enabled, define
        //   const int qk_nope_dim = head_dim - mla.qk_rope_dim;
        // kv_b_proj.emplace(mla.kv_lora_rank,
        //                   head_num * (qk_nope_dim + mla.v_head_dim) / tp_size,
        //                   data_type,
        //                   false,
        //                   weight_type,
        //                   group_size);

        kv_a_layernorm = Tensor{{mla.kv_lora_rank}, data_type, kDEVICE};
        register_module("kv_a_proj", kv_a_proj);
        // register_module("kv_b_proj", kv_b_proj, tp_rank);
        register_parameter("kv_a_layernorm", kv_a_layernorm);
    }
    output.emplace((head_num * head_dim) / tp_size, hidden_dim, data_type, bias, weight_type, group_size);
    register_module("wo", output, tp_rank);

    if (sink) {
        sinks = Tensor{{head_num / tp_size}, data_type, kDEVICE};
        register_parameter(std::to_string(tp_rank) + ".sinks", sinks);
    }
}

void LlamaAttentionWeight::prepare()
{
    std::vector weights{
        &qkv, &output, &q_a_proj, &q_a_proj, &q_b_proj, &kv_a_proj  // &kv_b_proj,
    };
    for (auto& w : weights) {
        w->preprocess();
        w->prepare();
    }
}

// Creates the gated FFN projections for one tensor-parallel rank:
// w1 (gating) and w3 (intermediate) map hidden -> inter_size / tp_size,
// and w2 (output) maps inter_size / tp_size -> hidden.
// SiLU fusion is enabled only if requested, the activation is SiLU, and the
// gating input is not quantized to e4m3 (SM90 FP8*FP8 GEMM cannot fuse it).
LlamaFfnWeight::LlamaFfnWeight(int            hidden_dim,
                               int            inter_size,
                               bool           bias,
                               int            tp_size,
                               int            tp_rank,
                               DataType       data_type,
                               DataType       weight_type,
                               int            group_size,
                               ActivationType act_type,
                               bool           fuse_silu_act)
{
    TM_CHECK(inter_size % tp_size == 0) << inter_size << " " << tp_size;
    const int local_inter = inter_size / tp_size;

    this->inter_size = local_inter;
    this->tp_rank    = tp_rank;
    this->act_type   = act_type;

    gating.emplace(hidden_dim, local_inter, data_type, bias, weight_type, group_size);
    intermediate.emplace(hidden_dim, local_inter, data_type, bias, weight_type, group_size);
    output.emplace(local_inter, hidden_dim, data_type, bias, weight_type, group_size);

    const bool fp8_input = gating.input_type == kFloat8_e4m3;
    this->is_fused_silu  = fuse_silu_act && this->act_type == ActivationType::kSilu && !fp8_input;

    register_module("w1", gating, tp_rank);
    register_module("w3", intermediate, tp_rank);
    register_module("w2", output, tp_rank);
}

// Combines `a` and `b` along the output dimension into `c` using
// interleave_output_dims (the exact interleave pattern is defined by that kernel).
// A 2-D tensor is treated as (K, M) = (shape(0), shape(1)); a 1-D tensor (e.g. a
// bias) is treated as a single row with M = shape(0). `a` and `b` must have the
// same layout. Supported element widths: 4, 8, 16, and 32 bits. Any other width
// fails the TM_CHECK.
static void Interleave(const Tensor& a, const Tensor& b, Tensor& c, cudaStream_t st)
{
    TM_CHECK(a.layout() == b.layout());
    int M, K;
    if (a.ndim() == 2) {
        std::tie(K, M) = a.shapes(0, 1);
    }
    else {
        M = a.shape(0);
        K = 1;
    }
    auto a_ = a.raw_data();
    auto b_ = b.raw_data();
    auto c_ = c.raw_data();

    const int bits = byte_size(a.dtype(), 8);
    if (bits == 4) {
        // 4-bit values are widened to one byte each, interleaved, then packed back to 4 bits.
        // NOTE(review): ta/tb/tc are released at scope exit while these kernels may still be
        // in flight on `st`; this is safe only if Buffer_ deallocation is stream-ordered — confirm.
        Buffer_ ta{a.size(), kDEVICE};
        Buffer_ tb{b.size(), kDEVICE};
        Buffer_ tc{c.size(), kDEVICE};
        extend_to_u8(ta.data(), (uint4_t*)a_, a.size(), st);
        extend_to_u8(tb.data(), (uint4_t*)b_, b.size(), st);
        interleave_output_dims(tc.data(), ta.data(), tb.data(), M, K, st);
        compact_to_u4((uint4_t*)c_, tc.data(), c.size(), st);
    }
    else if (bits == 8) {
        interleave_output_dims((uint8_t*)c_, (uint8_t*)a_, (uint8_t*)b_, M, K, st);
    }
    else if (bits == 16) {
        interleave_output_dims((uint16_t*)c_, (uint16_t*)a_, (uint16_t*)b_, M, K, st);
    }
    else if (bits == 32) {
        interleave_output_dims((uint32_t*)c_, (uint32_t*)a_, (uint32_t*)b_, M, K, st);
    }
    else {
        TM_CHECK(0);
    }
}

// Fuse two dense weights `a` and `b` (same input dim, same group size) into `c`, whose output
// dim is twice theirs, by interleaving their output dimensions. Weight, scales, zeros and bias
// are each interleaved when present. `data_type` is not used here.
void interleave(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, DataType data_type, cudaStream_t st)
{
    TM_CHECK_EQ(c.input_dim, a.input_dim);
    TM_CHECK_EQ(c.input_dim, b.input_dim);
    TM_CHECK_EQ(c.output_dim, a.output_dim * 2);
    TM_CHECK_EQ(c.output_dim, b.output_dim * 2);
    TM_CHECK_EQ(c.group_size, a.group_size);
    TM_CHECK_EQ(c.group_size, b.group_size);

    // Presence of the optional tensors is decided by `a` below; `b` must match, otherwise
    // Interleave would read from an empty tensor of `b`.
    TM_CHECK(static_cast<bool>(a.scales) == static_cast<bool>(b.scales));
    TM_CHECK(static_cast<bool>(a.zeros) == static_cast<bool>(b.zeros));
    TM_CHECK(static_cast<bool>(a.bias) == static_cast<bool>(b.bias));

    Interleave(a.weight, b.weight, c.weight, st);
    sync_check_cuda_error();

    if (a.scales) {
        Interleave(a.scales, b.scales, c.scales, st);
        sync_check_cuda_error();
    }
    if (a.zeros) {
        Interleave(a.zeros, b.zeros, c.zeros, st);
        sync_check_cuda_error();
    }
    if (a.bias) {
        Interleave(a.bias, b.bias, c.bias, st);
        sync_check_cuda_error();
    }
}

static void Chunk(const Tensor& a, const Tensor& b, Tensor& c, cudaStream_t st)
{
    TM_CHECK(a.layout() == b.layout());
    int M, K, spitch, dpitch;
    if (a.ndim() == 2) {
        std::tie(K, M) = a.shapes(0, 1);
        spitch         = byte_size(a.dtype(), a.stride(0));
        dpitch         = byte_size(c.dtype(), c.stride(0));
    }
    else {
        M      = a.shape(0);
        K      = 1;
        spitch = byte_size(a.dtype(), M);
        dpitch = byte_size(c.dtype(), c.shape(0));
    }
    int height = K;
    int width  = byte_size(a.dtype(), M);
    check_cuda_error(cudaMemcpy2DAsync((char*)c.raw_data(),  //
                                       dpitch,
                                       (const char*)a.raw_data(),
                                       spitch,
                                       width,
                                       height,
                                       cudaMemcpyDefault,
                                       st));
    check_cuda_error(cudaMemcpy2DAsync((char*)c.raw_data() + width,  //
                                       dpitch,
                                       (const char*)b.raw_data(),
                                       spitch,
                                       width,
                                       height,
                                       cudaMemcpyDefault,
                                       st));
}

// Fuse two dense weights `a` and `b` (same input dim, same group size) into `c`, whose output
// dim is twice theirs, by placing `a`'s outputs first and `b`'s second. Weight, scales, zeros
// and bias are each concatenated when present. `data_type` is not used here.
void chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, DataType data_type, cudaStream_t st)
{
    TM_CHECK_EQ(c.input_dim, a.input_dim);
    TM_CHECK_EQ(c.input_dim, b.input_dim);
    TM_CHECK_EQ(c.output_dim, a.output_dim * 2);
    TM_CHECK_EQ(c.output_dim, b.output_dim * 2);
    TM_CHECK_EQ(c.group_size, a.group_size);
    TM_CHECK_EQ(c.group_size, b.group_size);

    // Presence of the optional tensors is decided by `a` below; `b` must match, otherwise
    // Chunk would copy from an empty tensor of `b`.
    TM_CHECK(static_cast<bool>(a.scales) == static_cast<bool>(b.scales));
    TM_CHECK(static_cast<bool>(a.zeros) == static_cast<bool>(b.zeros));
    TM_CHECK(static_cast<bool>(a.bias) == static_cast<bool>(b.bias));

    Chunk(a.weight, b.weight, c.weight, st);
    sync_check_cuda_error();

    if (a.scales) {
        Chunk(a.scales, b.scales, c.scales, st);
        sync_check_cuda_error();
    }
    if (a.zeros) {
        Chunk(a.zeros, b.zeros, c.zeros, st);
        sync_check_cuda_error();
    }
    if (a.bias) {
        Chunk(a.bias, b.bias, c.bias, st);
        sync_check_cuda_error();
    }
}

// Preprocess and prepare the FFN weights. When `fuse_up_and_gate` is set, w1 (gating) and
// w3 (intermediate) are merged into `fused_gating_intermediate` ("w1w3"):
//  - with fused SiLU, their output dims are interleaved and the GEMM epilogue is kGatedSilu;
//  - otherwise they are concatenated (gate first, up second).
// The separate w1/w3 weights are then reset. w2 (output) is always prepared on its own.
void LlamaFfnWeight::prepare(bool fused_moe)
{
    // Read before `gating` is reset below
    const auto dtype  = gating.data_type;
    auto       stream = core::Context().stream().handle();

    gating.preprocess();
    intermediate.preprocess();

    if (!fuse_up_and_gate) {
        gating.prepare(fused_moe);
        intermediate.prepare(fused_moe);
    }
    else {
        auto& fused = fused_gating_intermediate;

        fused.emplace(gating.input_dim,  //
                      gating.output_dim * 2,
                      gating.data_type,
                      (bool)gating.bias,
                      gating.weight_type,
                      gating.group_size);
        fused.preprocess();
        register_module("w1w3", fused, this->tp_rank);

        if (!is_fused_silu) {
            chunk(fused, gating, intermediate, dtype, stream);
        }
        else {
            interleave(fused, gating, intermediate, dtype, stream);
            fused.epilogue = gemm::Epilogue::kGatedSilu;
        }

        fused.prepare(fused_moe);

        // The separate projections are no longer needed once fused
        gating       = {};
        intermediate = {};
    }

    output.preprocess();
    output.prepare(fused_moe);
}

// Build the MoE weights for `layer_id`: router gate, optional noaux_tc score-correction bias,
// `expert_num` per-expert FFN weights, and an optional 1-output shared gate. Layers with no
// entry in `param.expert_num`, or with an entry of 0, are left empty.
MoeFfnWeight::MoeFfnWeight(int             layer_id,
                           const MoeParam& param,
                           int             hidden_dim,
                           bool            mlp_bias,
                           DataType        data_type,
                           DataType        weight_type,
                           int             group_size,
                           int             tp_size,
                           int             tp_rank,
                           ActivationType  act_type,
                           bool            fuse_silu_act)
{
    const bool layer_listed = layer_id < (int)param.expert_num.size();
    if (!layer_listed || param.expert_num[layer_id] == 0) {
        return;
    }

    const int expert_num = param.expert_num[layer_id];

    gate.emplace(hidden_dim, expert_num, data_type, param.router_bias, data_type, 1);
    register_module("gate", gate);

    if (param.topk_method == "noaux_tc") {
        score_correction_bias = Tensor{{expert_num}, kFloat, kDEVICE};
        register_parameter("gate.score_correction_bias", score_correction_bias);
    }

    method = param.method;

    // SiLU fusion is turned off for the naive 16-bit (cuBLAS) path and for biased MLPs
    const bool uses_cublas = method == MoeParam::kNaive && byte_size(weight_type, 8) == 16;
    const bool fuse_silu   = fuse_silu_act && !uses_cublas && !mlp_bias;

    experts.reserve(expert_num);
    for (int idx = 0; idx < expert_num; ++idx) {
        experts.emplace_back(new LlamaFfnWeight{hidden_dim,
                                                param.inter_size,
                                                mlp_bias,
                                                tp_size,
                                                tp_rank,
                                                data_type,
                                                weight_type,
                                                group_size,
                                                act_type,
                                                fuse_silu});
        register_module("experts", *experts.back(), idx);
    }

    if (param.shared_gate) {
        shared_gate.emplace(hidden_dim, 1, data_type, false, data_type, 1);
        register_module("shared_gate", shared_gate);
    }
}

// Prepare the router gates and every expert, then build `block`: a grouped view over all
// experts' fused w1w3 and w2 weights (via LinkExperts) plus MLP properties copied from expert 0.
void MoeFfnWeight::prepare()
{
    // A layer without experts (the constructor returned early) has nothing to prepare or link.
    // Without this guard LinkExperts would index expert 0 of an empty vector.
    if (experts.empty()) {
        return;
    }

    const auto fused_moe = method == MoeParam::kFused;

    gate.prepare();
    shared_gate.prepare();

    for (auto& e : experts) {
        e->prepare(fused_moe);
    }

    const int n = experts.size();
    LinkExperts([&](int i) { return &experts[i]->fused_gating_intermediate; }, n, block.fused_gating_intermediate);
    LinkExperts([&](int i) { return &experts[i]->output; }, n, block.output);

    auto& e = *experts.at(0);
    // Copy MLP properties
    block.inter_size    = e.inter_size;
    block.is_fused_silu = e.is_fused_silu;
    block.act_type      = e.act_type;
}

// Build a grouped dense weight `d` over `n` experts, where `experts(i)` yields expert i's weight.
// Expert weights and scales are NOT copied. Their device pointers and leading dims are packed
// into device pointer arrays (blocked for FP8 x FP8, strided otherwise), and those arrays back
// `d.weight` / `d.scales` with shape {n}. Biases, if present, ARE copied into a new
// [n, output_dim] device tensor.
// NOTE(review): metadata is taken from expert 0 only; all experts are assumed to share it
// (not checked here).
void LinkExperts(std::function experts, int n, LlamaDenseWeight& d)
{
    const auto& e = *experts(0);

    d.input_dim    = e.input_dim;
    d.output_dim   = e.output_dim;
    d.group_size   = e.group_size;
    d.data_type    = e.data_type;
    d.input_type   = e.input_type;
    d.weight_type  = e.weight_type;
    d.input_quant  = e.input_quant;
    d.weight_quant = e.weight_quant;
    d.k_desc       = e.k_desc;
    d.q_desc       = e.q_desc;
    d.epilogue     = e.epilogue;

    // Each descriptor now describes a batch of n matrices
    d.k_desc.num = d.q_desc.num = n;

    if (e.bias) {
        d.bias = Tensor{{n, e.output_dim}, e.bias.dtype(), kDEVICE};
    }

    // (pointer, leading dim) pairs per expert
    std::vector> weights;
    std::vector> scales;

    for (int i = 0; i < n; ++i) {
        // Shadows the outer `e` (expert 0) inside the loop only
        auto& e = *experts(i);
        weights.emplace_back(e.weight.raw_data(), e.k_desc.ld);
        if (e.scales) {
            scales.emplace_back(e.scales.raw_data(), e.q_desc.ld);
        }
        if (e.bias) {
            Copy(e.bias, d.bias.slice(i, 1).squeeze(0));
        }
    }

    auto stream = core::Context::stream().handle();

    if (d.weight_type == kFloat8_e4m3 && d.input_type == kFloat8_e4m3) {
        // Pointer arrays are released with cudaFree when the last Tensor referencing them drops
        auto make_blocked_ptr = [&](const auto& ptrs) {
            return std::shared_ptr{gemm::MakeBlockedPtrs(ptrs, stream), [](auto p) { cudaFree(p); }};
        };
        // Assumes the FP8 path always has scales; `scales` is used unconditionally here
        d.weight = Tensor{make_blocked_ptr(weights), {n}, e.weight.dtype(), kDEVICE};
        d.scales = Tensor{make_blocked_ptr(scales), {n}, e.scales.dtype(), kDEVICE};
        // A non-null sentinel `offsets` value makes the GEMM recognize blocked-striding mode
        d.k_desc.offsets = d.q_desc.offsets = (int*)1;
    }
    else {
        auto make_strided_ptr = [&](const auto& ptrs) {
            return std::shared_ptr{gemm::MakeStridedPtrs(ptrs, stream), [](auto p) { cudaFree(p); }};
        };
        d.weight = Tensor{make_strided_ptr(weights), {n}, d.weight_type, kDEVICE};
        if (e.scales) {
            d.scales = Tensor{make_strided_ptr(scales), {n}, e.scales.dtype(), kDEVICE};
        }
        // Pre-SM90 grouped GEMM needs `ld == 0` to resolve the strided pointers
        d.k_desc.ld = d.q_desc.ld = 0;
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaDenseWeight.h
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/DenseWeight.h

#pragma once

#include "src/turbomind/core/core.h"
#include "src/turbomind/core/module.h"

#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/kernels/gemm/types.h"

#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

using gemm::QuantDesc;
using gemm::MatrixLayout;
using gemm::Epilogue;

// A (possibly quantized) linear-layer weight together with the metadata LlamaLinear needs to
// run it as a GEMM operand (dtypes, quantization descriptors, epilogue, matrix layouts).
struct LlamaDenseWeight: public core::Module {

    LlamaDenseWeight():
        data_type{}, weight_type{}, input_type{}, weight_quant{}, input_quant{}, epilogue{}, k_desc{}, q_desc{}
    {
    }

    // Configure an input_dim -> output_dim projection (definition not in this chunk)
    void emplace(int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size);

    void preprocess();

    // `fused_moe` is forwarded from LlamaFfnWeight::prepare for experts of a fused MoE
    void prepare(bool fused_moe = false);

    // Reset to the default-constructed (empty) state by destroying and re-constructing in place
    LlamaDenseWeight& operator=(std::nullptr_t)
    {
        this->~LlamaDenseWeight();
        new (this) LlamaDenseWeight{};
        return *this;
    }

    // True when a weight tensor is present
    operator bool() const noexcept
    {
        return static_cast(weight);
    }

    int input_dim  = 0;
    int output_dim = 0;
    int group_size = 1;

    Tensor weight;
    Tensor bias;

    // Quantization parameters; presumably empty for unquantized weights (TODO confirm)
    Tensor scales;
    Tensor zeros;

    // dtype of the GEMM output when LlamaLinear allocates it
    DataType data_type;

    DataType weight_type;
    // Activations whose dtype differs from this are quantized by LlamaLinear before the GEMM
    DataType input_type;

    QuantDesc weight_quant;
    QuantDesc input_quant;

    // e.g. kGatedSilu for an interleaved w1w3 weight (halves the output columns)
    Epilogue epilogue;

    // Layouts of `weight` (k_desc) and `scales` (q_desc) as GEMM operand B and its scales
    MatrixLayout k_desc;
    MatrixLayout q_desc;
};

// Weights of one attention block. The constructor and prepare() are defined outside this chunk.
struct LlamaAttentionWeight: public core::Module {

    LlamaAttentionWeight() = default;

    LlamaAttentionWeight(int      hidden_dim,
                         int      head_dim,
                         int      head_num,
                         int      kv_head_num,
                         MLAParam mla,
                         bool     bias,
                         bool     qk_norm,
                         int      tp_size,
                         int      tp_rank,
                         DataType data_type,
                         DataType weight_type,
                         int      group_size,
                         int      window_size,
                         bool     sink,
                         bool     attn_output_gate = false);

    void prepare();

    // Presumably the fused Q/K/V input projection and the attention output projection
    LlamaDenseWeight qkv;
    LlamaDenseWeight output;

    // Presumably populated when the `sink` constructor flag is set (TODO confirm)
    Tensor sinks;

    // Presumably used for MLA (`mla` constructor param) low-rank Q / KV projections (TODO confirm)
    LlamaDenseWeight q_proj;
    LlamaDenseWeight q_a_proj;
    LlamaDenseWeight q_b_proj;
    LlamaDenseWeight kv_a_proj;
    // LlamaDenseWeight kv_b_proj;

    Tensor q_a_layernorm;
    Tensor kv_a_layernorm;

    int window_size{};
};

// Weights of a gated FFN: w1 (gating), w3 (intermediate/up) and w2 (output/down).
// prepare() may merge w1 and w3 into `fused_gating_intermediate` ("w1w3").
struct LlamaFfnWeight: core::Module {

    LlamaFfnWeight() = default;

    LlamaFfnWeight(int            hidden_dim,
                   int            inter_size,
                   bool           bias,
                   int            tp_size,
                   int            tp_rank,
                   DataType       data_type,
                   DataType       weight_type,
                   int            group_size,
                   ActivationType act_type,
                   bool           fuse_silu_act);

    // Always merge w1/w3 into a single GEMM in prepare()
    static constexpr bool fuse_up_and_gate = true;

    void prepare(bool fused_moe);

    LlamaDenseWeight gating;
    LlamaDenseWeight intermediate;
    LlamaDenseWeight output;
    LlamaDenseWeight fused_gating_intermediate;

    // Value-initialized like the sibling fields so default-constructed objects are well defined
    ActivationType act_type{};

    int  inter_size{};  // per-rank intermediate size
    bool is_fused_silu{};

    int tp_rank{};
};

// Weights of a mixture-of-experts FFN layer: router gate, optional shared gate, per-expert
// FFN weights, and a grouped `block` view over the experts built by prepare().
struct MoeFfnWeight: core::Module {

    MoeFfnWeight() = default;

    MoeFfnWeight(int             layer_id,
                 const MoeParam& param,
                 int             hidden_dim,
                 bool            mlp_bias,
                 DataType        data_type,
                 DataType        weight_type,
                 int             group_size,
                 int             tp_size,
                 int             tp_rank,
                 ActivationType  act_type,
                 bool            fuse_silu_act);

    void prepare();

    // Router projection (hidden_dim -> expert_num) and a single-output shared gate
    LlamaDenseWeight gate;
    LlamaDenseWeight shared_gate;

    /// Per-expert score correction bias for noaux_tc routing (optional; used when topk_method == "noaux_tc")
    Tensor score_correction_bias;

    std::vector> experts;

    // reference into `experts`
    LlamaFfnWeight block;

    MoeParam::Method method{};
};

// Build grouped weight `d` referencing the weights of `n` experts, where `experts(i)` yields expert i.
void LinkExperts(std::function experts, int n, LlamaDenseWeight& d);

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaFfnLayer.cc
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.h

#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/anomaly_handler.h"

namespace turbomind {

// Run the FFN: gate/up projections (fused w1w3 when available, else separate w1 and w3),
// activation (skipped when SiLU is fused into the GEMM epilogue), then w2 into `param.output`.
void LlamaFfnLayer::forward(ForwardParam param)
{
    NvtxScope scope("ffn");

    const auto& w = *param.weights;

    const int  token_num  = param.input.shape(0);
    const int  inter_size = w.inter_size;
    const int  layer_id   = param.layer_id;
    const auto stream     = core::Context::stream().handle();

    Tensor gating;
    Tensor inter;

    if (!w.fused_gating_intermediate.weight) {
        gating = linear_.Forward(param.input, w.gating);
        sync_check_cuda_error();
        TM_DEBUG_TENSOR(gating, Concat("w1", layer_id), 3);

        inter = linear_.Forward(param.input, w.intermediate);
        sync_check_cuda_error();
        TM_DEBUG_TENSOR(inter, Concat("w3", layer_id), 3);
    }
    else {
        auto fused = linear_.Forward(param.input, w.fused_gating_intermediate);
        sync_check_cuda_error();

        // Gate occupies the first inter_size columns; up (if not fused away) the next inter_size
        gating = fused.slice({0, 0}, {(int)token_num, inter_size});
        if (!w.is_fused_silu) {
            inter = fused.slice({0, inter_size}, {(ssize_t)token_num, inter_size});
        }
    }

    if (!w.is_fused_silu) {
        // gate' = act(gate) * up, written in place into `gating`
        Activation(gating, inter, w.act_type, stream);
        sync_check_cuda_error();
        TM_DEBUG_TENSOR(gating, Concat("act", layer_id), 3);
    }

    {  // w2(x)
        NvtxScope scope("w2");
        linear_.Forward(gating, w.output, param.output);
        sync_check_cuda_error();
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaFfnLayer.h
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.cc

#pragma once

#include "src/turbomind/core/core.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Executes a gated FFN (see LlamaFfnWeight) using a shared LlamaLinear.
class LlamaFfnLayer {
public:
    // Stores a reference to the LlamaLinear owned by `ctx`; `ctx.linear` must outlive this layer.
    LlamaFfnLayer(const ModelParam& model, const Context& ctx): hidden_units_(model.hidden_units), linear_(*ctx.linear)
    {
    }

    // Arguments of one forward call
    struct ForwardParam {
        Tensor                input;     // fed to w1/w3 (or the fused w1w3)
        Tensor                output;    // receives the w2 projection
        const LlamaFfnWeight* weights;   // non-owning
        int                   layer_id;  // used in debug tensor names
    };

    void forward(ForwardParam param);

private:
    // NOTE(review): not read by forward() in LlamaFfnLayer.cc
    const size_t hidden_units_;
    LlamaLinear& linear_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaLinear.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/core.h"
#include "src/turbomind/core/cuda_data_type.h"
#include "src/turbomind/core/data_type.h"

#include "src/turbomind/kernels/gemm/gemm.h"
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/kernels/gemm/types.h"

#include "src/turbomind/kernels/quantization.h"

#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"

#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

using namespace gemm;

// Owns the GEMM dispatcher and its device workspace, and maps LlamaDenseWeight/activation
// tensors onto gemm::Gemm::Run operands.
struct LlamaLinear::Impl {

    // Allocate the GEMM workspace on the current context stream and wait for it to be ready
    explicit Impl()
    {
        workspace_ = {};

        workspace_.barriers_size   = gemm::Gemm::kBarriersSize;
        workspace_.partials_size   = gemm::Gemm::kPartialsSize;
        // 8192 * 128 bytes = 1 MiB. NOTE(review): the previous comment said "maximum 4096 tensor
        // maps"; a CUtensorMap is 128 bytes, so this holds 8192 — confirm the intended limit.
        workspace_.tensormaps_size = 8192 * 128;

        auto st = core::Context::stream().handle();

        check_cuda_error(cudaMallocAsync(&workspace_.barriers, workspace_.barriers_size, st));
        check_cuda_error(cudaMallocAsync(&workspace_.partials, workspace_.partials_size, st));
        // Sized by tensormaps_size (was partials_size, which does not match the recorded size)
        check_cuda_error(cudaMallocAsync(&workspace_.tensormaps, workspace_.tensormaps_size, st));
        // Barriers start zeroed
        check_cuda_error(cudaMemsetAsync(workspace_.barriers, 0, workspace_.barriers_size, st));
        check_cuda_error(cudaMallocAsync(&workspace_.flags, sizeof(int), st));

        core::Context::stream().Sync();
    }

    // Free the workspace in stream order; errors are deliberately not checked in the destructor
    ~Impl()
    {
        auto st = core::Context::stream().handle();

        cudaFreeAsync(workspace_.barriers, st);
        cudaFreeAsync(workspace_.partials, st);
        cudaFreeAsync(workspace_.tensormaps, st);
        cudaFreeAsync(workspace_.flags, st);
        workspace_ = {};
    }

    // Operand B is the weight (layout k_desc); V is its scales (layout q_desc)
    std::tuple GetOperandB(const LlamaDenseWeight& dense)
    {
        const Tensor& B      = dense.weight;
        const Tensor& V      = dense.scales;
        MatrixLayout  desc_B = dense.k_desc;
        MatrixLayout  desc_V = dense.q_desc;
        return {B, desc_B, V, desc_V};
    }

    // Operand A is the activation (quantized when its dtype differs from dense.input_type);
    // U holds its scales when quantized. `indices` gathers rows, `offsets` splits groups.
    std::tuple
    GetOperandA(const LlamaDenseWeight& dense, const Tensor& input, Buffer_ indices, const Buffer_& offsets)
    {
        auto st = core::Context::stream().handle();

        Tensor A;
        Tensor U;

        const int m = indices ? indices.size() : input.shape(0);

        // Currently, FP8 only; INT8 may be added later
        if (input.dtype() != dense.input_type) {
            QuantizeSymm(A, U, input, st);
            sync_check_cuda_error();
        }
        else {
            A = input;
        }

        // FP8 with gather indices: dispatch rows (and their scales) up front so the GEMM
        // consumes a dense [m, k] operand without indices
        if (indices && A.dtype() == kFloat8_e4m3) {
            const auto [bsz, k] = A.shapes(0, 1);
            const int e         = indices.size() / bsz;
            Tensor    A_e       = {{m, k}, A.dtype(), kDEVICE};
            invokeMoeDispatch(A_e, A, indices.data(), e, st);
            sync_check_cuda_error();
            Tensor U_e;
            invokeMoeDispatchScales(U_e, U, indices.data(), e, st);
            sync_check_cuda_error();
            A       = A_e;
            U       = U_e;
            indices = {};  // indices already applied
        }

        MatrixLayout desc_A{A.dtype(), gemm::Order::kRowMajor, m, (int)A.shape(1), (int)A.stride(0)};
        MatrixLayout desc_U{};
        if (U) {
            desc_U = {U.dtype(), kColMajor, (int)U.shape(1), (int)U.shape(0), (int)U.stride(0)};
        }
        if (offsets) {
            desc_A.num = desc_U.num = dense.k_desc.num;
            desc_A.offsets = desc_U.offsets = const_cast(offsets.data());
        }
        if (indices) {
            desc_A.idxs = desc_U.idxs = const_cast(indices.data());
        }

        return {A, desc_A, U, desc_U};
    }

    // D = A x B (+ epilogue). `output` is allocated here when empty.
    void Forward(Tensor&                 output,
                 const Tensor&           input,  //
                 const LlamaDenseWeight& dense,
                 const Buffer_&     indices,
                 const Buffer_&     offsets)
    {
        using namespace gemm;

        Operation op{};
        op.dispatch  = dispatch_policy_;
        op.epilogue  = dense.epilogue;
        op.quant_a   = dense.input_quant;
        op.quant_b   = dense.weight_quant;
        op.batch_dim = 0;

        auto&& [A, desc_A, U, desc_U] = GetOperandA(dense, input, indices, offsets);
        auto&& [B, desc_B, V, desc_V] = GetOperandB(dense);

        Tensor& D = output;
        if (!D) {
            // The gated-SiLU epilogue emits half of the GEMM's output columns
            int dim = dense.epilogue == Epilogue::kGatedSilu ? dense.output_dim / 2 : dense.output_dim;
            D       = Tensor{{desc_A.rows, dim}, dense.data_type, kDEVICE};
        }

        MatrixLayout desc_D{
            output.dtype(),
            kRowMajor,
            (int)output.shape(0),
            dense.output_dim,
            (int)output.stride(0),
        };

        if (offsets) {
            desc_D.num     = desc_B.num;
            desc_D.offsets = const_cast(offsets.data());
        }

        auto ec = gemm_.Run(op,
                            1.f,
                            A.raw_data(),
                            desc_A,
                            U.data_or((void*)nullptr),
                            desc_U,
                            B.raw_data(),
                            desc_B,
                            V.data_or((void*)nullptr),
                            desc_V,
                            0.f,
                            D.raw_data(),
                            desc_D,
                            D.raw_data(),
                            desc_D,
                            workspace_,
                            core::Context::stream().handle());

        if (ec) {
            TM_LOG_ERROR("%s: %d", __PRETTY_FUNCTION__, ec);
        }
    }

    gemm::Gemm           gemm_;
    gemm::DispatchPolicy dispatch_policy_{gemm::DispatchPolicy::kDefault};

    gemm::Workspace workspace_;
};

// Creates the shared Impl, which allocates the GEMM workspace on the current context stream.
LlamaLinear::LlamaLinear(): impl_{std::make_shared()} {}

// Dense forward without gather indices or group offsets; see the 5-argument overload.
Tensor LlamaLinear::Forward(const Tensor&           input,  //
                            const LlamaDenseWeight& weight,
                            std::optional   output)
{
    return Forward(input, weight, {}, {}, output);
}

// Compute input x weight. `input` (and `output`, if given) are viewed as 2-D matrices by
// collapsing all leading dims into rows. When `output` is not given, a new tensor is
// allocated and returned.
Tensor LlamaLinear::Forward(const Tensor&           input,  //
                            const LlamaDenseWeight& weight,
                            const Buffer_&     indices,
                            const Buffer_&     offsets,
                            std::optional   output)
{
    Tensor result = output ? output->view({-1, output->shape(-1)}) : Tensor{};
    Tensor rows   = input.view({-1, input.shape(-1)});

    impl_->Forward(result, rows, weight, indices, offsets);

    return result;
}

// Switch GEMM dispatch between kMeasure (when `measure`) and kReuse.
void LlamaLinear::set_measure(bool measure)
{
    if (measure) {
        impl_->dispatch_policy_ = gemm::DispatchPolicy::kMeasure;
    }
    else {
        impl_->dispatch_policy_ = gemm::DispatchPolicy::kReuse;
    }
}

// Write GEMM tuning records to `os`; returns the exporter's result, or 0 if the stream is bad.
int LlamaLinear::Export(std::ostream& os)
{
    return os ? impl_->gemm_.Export(os) : 0;
}

// Read GEMM tuning records from `is` and return how many were loaded (0 for a bad stream).
// Loading any record switches dispatch to kReuse.
int LlamaLinear::Import(std::istream& is)
{
    const int n_records = is ? impl_->gemm_.Import(is) : 0;

    if (n_records != 0) {
        impl_->dispatch_policy_ = gemm::DispatchPolicy::kReuse;
    }

    return n_records;
}

// Forwarded to gemm::Gemm::GetTuningSeq().
std::vector LlamaLinear::GetTuningSeq() const
{
    return impl_->gemm_.GetTuningSeq();
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaLinear.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"

namespace turbomind {

// Runs LlamaDenseWeight projections through gemm::Gemm. Copies share one Impl (and workspace).
class LlamaLinear {
public:
    explicit LlamaLinear();

    // input x weight; `output`, if given, is written in place, otherwise a tensor is allocated
    Tensor Forward(const Tensor&           input,  //
                   const LlamaDenseWeight& weight,
                   std::optional   output = {});

    // As above, with optional gather `indices` and group `offsets` (used for grouped/MoE GEMMs)
    Tensor Forward(const Tensor&           input,
                   const LlamaDenseWeight& weight,
                   const Buffer_&     indices,
                   const Buffer_&     offsets,
                   std::optional   output = {});

    // true -> DispatchPolicy::kMeasure, false -> DispatchPolicy::kReuse
    void set_measure(bool measure);

    // Serialize / load GEMM tuning records; Import switches to kReuse when records are loaded
    [[maybe_unused]] int Export(std::ostream& os);

    [[maybe_unused]] int Import(std::istream& is);

    std::vector GetTuningSeq() const;

private:
    struct Impl;
    std::shared_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaWeight.cc
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc

#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Record model/engine parameters, pad the vocab and embedding sizes to multiples of the
// attention TP x CP size, and allocate the (unfilled) weights via initialize().
LlamaWeight::LlamaWeight(DataType           data_type,
                         const ModelParam&  model,
                         const EngineParam& engine_param,
                         const MoeParam&    moe_param):
    model_param_{model},
    engine_param_{engine_param},
    moe_param_{moe_param},
    hidden_units_(model.hidden_units),
    inter_size_(model.inter_size),
    vocab_size_(model.vocab_size),
    vocab_size_padded_(model.vocab_size),
    embedding_size_(model.embedding_size),
    num_layer_(model.layer_num),
    data_type_{data_type},
    weight_type_{model.weight_type},
    tp_size_(engine_param.attn_tp_size * engine_param.attn_cp_size),
    tp_rank_(engine_param.attn_tp_rank * engine_param.attn_cp_size + engine_param.attn_cp_rank)
{
    // The LM head is split as vocab_size_padded_ / tp_size_ per rank
    if (vocab_size_padded_ % tp_size_ != 0) {
        vocab_size_padded_ = (vocab_size_ + tp_size_ - 1) / tp_size_ * tp_size_;
        TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
    }
    if (embedding_size_ % tp_size_ != 0) {
        // Keep the unpadded size so the warning reports the real "from" value
        const auto unpadded_embedding_size = embedding_size_;
        embedding_size_ = (embedding_size_ + tp_size_ - 1) / tp_size_ * tp_size_;
        TM_LOG_WARNING("pad embed size from %d to %d", unpadded_embedding_size, embedding_size_);
    }
    FT_CHECK(hidden_units_ % tp_size_ == 0);
    TM_CHECK_EQ(vocab_size_padded_ % tp_size_, 0);
    TM_CHECK_EQ(hidden_units_ % tp_size_, 0);

    stream_ = core::Stream::create();
    alloca_ = core::Allocator{stream_, false};

    initialize();
}

// Frees all weights (see release()).
LlamaWeight::~LlamaWeight()
{
    release();
}

// True between initialize() and release().
bool LlamaWeight::is_initialized() const
{
    return initialized_;
}

// Allocate and register all model weights under this object's stream/allocator context.
// The embedding tables start on host (kCPU); prepare() later copies them to the device.
void LlamaWeight::initialize()
{
    core::ContextGuard guard = context();

    // Token embeddings are split along hidden units, the LM head along the (padded) vocab
    pre_decoder_embedding.emplace(embedding_size_, hidden_units_ / tp_size_, data_type_, false, data_type_, 1);
    post_decoder_embedding.emplace(hidden_units_, vocab_size_padded_ / tp_size_, data_type_, false, data_type_, 1);
    register_module("tok_embeddings", pre_decoder_embedding, tp_rank_);
    register_module("output", post_decoder_embedding, tp_rank_);

    /// Lower VRAM pressure on consumer grade GPUs
    /// TODO: Support token embeds on pinned host memory
    pre_decoder_embedding.weight  = empty_like(pre_decoder_embedding.weight, kCPU);
    post_decoder_embedding.weight = empty_like(post_decoder_embedding.weight, kCPU);

    decoder_layer_weights.reserve(num_layer_);
    for (int i = 0; i < num_layer_; ++i) {
        decoder_layer_weights.emplace_back(
            new LlamaDecoderLayerWeight(data_type_, i, model_param_, engine_param_, moe_param_));
        register_module("layers", *decoder_layer_weights.back(), i);
    }

    output_norm_weight = Tensor{{hidden_units_}, data_type_, kDEVICE};
    register_parameter("norm.weight", output_norm_weight);
    initialized_ = true;
}

// Drop every weight, including the pinned host copies kept by to_device(). Waits for the
// stream's pending deallocations, then trims the device allocator back to the OS.
void LlamaWeight::release()
{
    core::ContextGuard guard = context();

    pre_decoder_embedding  = {};
    post_decoder_embedding = {};
    output_norm_weight     = {};

    // Layer weights are owned through raw pointers created in initialize()
    for (auto& p : decoder_layer_weights) {
        delete p;
    }

    decoder_layer_weights.clear();
    pinned_weights_.clear();

    // Wait for deallocations
    core::Context::stream().Sync();

    // release memory back to os
    core::Context::device_alloc()->trim(0);
    initialized_ = false;
}

// Move every registered parameter between device memory and pinned host memory.
//   kCPU:    create a pinned host copy of each parameter (only if none exists yet), then reset the device
//            tensor. After synchronizing, trim the device allocator.
//   kDEVICE: allocate a device tensor for each parameter and copy it back from its pinned copy. A pinned
//            copy must already exist (i.e. a prior kCPU call), otherwise TM_CHECK fails.
// Pinned copies are kept after moving back to the device and are reused by the next kCPU call.
// NOTE(review): an existing pinned copy is never refreshed. Device-side changes made to a weight after its
// first offload would therefore not be captured by a later offload. Confirm weights are immutable between
// offload/onload cycles.
void LlamaWeight::to_device(const core::Device& device)
{
    TM_CHECK(device.type == kCPU || device.type == kDEVICE);
    // Same stream/allocator as context(), plus a kCPUpinned allocator (presumably serving the kCPUpinned
    // allocations below — TODO confirm ContextGuard semantics)
    core::ContextGuard guard{stream_, alloca_, Allocator{kCPUpinned}};

    auto tensor_ptr_map = get_parameters();
    for (auto& [name, tensor_ptr] : tensor_ptr_map) {
        if (device.type == kCPU) {
            if (pinned_weights_.find(name) == pinned_weights_.end()) {
                pinned_weights_[name] = empty_like(*tensor_ptr, kCPUpinned);
                Copy(*tensor_ptr, pinned_weights_[name]);
            }
            // The copy is presumably enqueued on the stream (synchronized below). Dropping the source right
            // away presumably relies on stream-ordered deallocation — confirm.
            *tensor_ptr = {};
        }
        else {
            TM_CHECK(pinned_weights_.find(name) != pinned_weights_.end());
            *tensor_ptr = empty_like(pinned_weights_[name], kDEVICE);
            Copy(pinned_weights_[name], *tensor_ptr);
        }
    }
    // Wait for the enqueued copies before returning
    core::Context::stream().Sync();
    if (device.type == kCPU) {
        core::Context::device_alloc()->trim(0);
    }
}

core::ContextGuard LlamaWeight::context() const
{
    // Bind this object's private stream and allocator for the lifetime of the returned guard.
    // Returned as a prvalue, so no copy or move of the guard is needed.
    return core::ContextGuard{this->stream_, this->alloca_};
}

void LlamaWeight::prepare(const cudaDeviceProp& prop)
{
    core::ContextGuard guard = context();

    // Wait for the weights to be filled externally
    check_cuda_error(cudaDeviceSynchronize());

    auto stream = core::Context::stream().handle();

    for (auto& layer : decoder_layer_weights) {
        layer->prepare(prop, stream);
    }

    auto to_device = [](Tensor& x) {
        auto tmp = std::exchange(x, empty_like(x, kDEVICE));
        Copy(tmp, x);
        return tmp;
    };

    // Keep the host tensor until stream synchronization
    auto tmp_token_embeds = to_device(pre_decoder_embedding.weight);
    auto tmp_lm_head      = to_device(post_decoder_embedding.weight);

    post_decoder_embedding.prepare();

    // Block until processing is done
    check_cuda_error(cudaStreamSynchronize(stream));
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/LlamaWeight.h
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h

#pragma once

#include 

#include "src/turbomind/core/context.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Model-level weight container for one tensor-parallel rank: token embedding, LM head, final norm and
// per-layer weights, all registered with the core::Module tree so they can be addressed by name.
struct LlamaWeight: core::Module {
    LlamaWeight() = default;

    LlamaWeight(DataType           data_type,
                const ModelParam&  model_param,
                const EngineParam& engine_param,
                const MoeParam&    moe_param);

    // Calls release()
    ~LlamaWeight();

    LlamaWeight(const LlamaWeight&) = delete;
    LlamaWeight& operator=(const LlamaWeight&) = delete;

    // Finalize externally-filled weights: per-layer prepare(), then move the embedding and LM head to device
    void prepare(const cudaDeviceProp& prop);

    bool is_initialized() const;

    // Allocate and register all weight tensors (embedding / LM head start on the host)
    void initialize();

    // Free all weights and pinned host copies, then trim the device allocator
    void release();

    // Offload parameters to pinned host memory (kCPU) or restore them from it (kDEVICE)
    void to_device(const core::Device& device);

    // Guard that makes stream_ / alloca_ the current context
    core::ContextGuard context() const;

    // Owning raw pointers: created in initialize(), deleted in release()
    std::vector decoder_layer_weights;

    LlamaDenseWeight pre_decoder_embedding;   // registered as "tok_embeddings"
    LlamaDenseWeight post_decoder_embedding;  // registered as "output" (LM head)

    Tensor output_norm_weight;  // registered as "norm.weight"

private:
    const ModelParam  model_param_;
    const EngineParam engine_param_;
    const MoeParam    moe_param_;

    int hidden_units_;
    int vocab_size_;
    int vocab_size_padded_;
    int embedding_size_;
    int num_layer_;

    DataType data_type_;
    DataType weight_type_;

    // Pinned host copies of parameters keyed by parameter name; populated by to_device(kCPU)
    std::unordered_map pinned_weights_;

    int tp_size_;  // this will follow attn tp param
    int tp_rank_;

    std::vector inter_size_;

    // Private stream and the allocator built over it in the constructor; bound by context()
    core::Stream    stream_;
    core::Allocator alloca_;
    bool            initialized_{false};
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/SequenceManager.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 

#include "src/turbomind/kernels/attention/block.h"
#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/utils/logger.h"

// #include "dbg.h"

namespace turbomind {

// Render `data` as a comma-separated list ("a, b, c"); an empty vector yields the literal "nil".
template<typename T>
std::string vector2string(const std::vector<T>& data)
{
    if (data.empty()) {
        return "nil";
    }
    std::stringstream ss;
    const char*       sep = "";
    for (const auto& item : data) {
        ss << sep << item;
        sep = ", ";
    }
    return ss.str();
}

// Build the KV-cache block manager (and, if enabled, the prefix-cache trie) for one attention TP rank.
// For hybrid models, layers with layer_types[i] == 1 are linear-attention layers. They hold no KV cache;
// instead, pooled conv/recurrent state buffers with max_batch_size slots are allocated for them.
// When block_count < 1 it is a fraction of free device memory. In that case the linear-state pools are
// charged against that budget, and block_count is converted into an absolute number of KV blocks.
SequenceManager::SequenceManager(const ModelParam& model_param,
                                 DataType          runtime_dtype,
                                 int               cache_block_seq_len,
                                 int               attn_tp_size,
                                 int               max_batch_size,
                                 double            block_count,
                                 int               chunk_size,
                                 bool              enable_prefix_caching,
                                 int               rank,
                                 int               attn_cp_size,
                                 core::Allocator   allocator,
                                 GetFreeMemSize    get_free_size):
    block_seq_len_(cache_block_seq_len), rank_(rank), attn_cp_size_(attn_cp_size)
{
    TM_CHECK_GT(attn_tp_size, 0);
    TM_CHECK_GT(cache_block_seq_len, 0);

    int cache_layer_num   = model_param.layer_num;
    int num_linear_layers = 0;
    for (const auto& type : model_param.layer_types) {
        if (type == 1) {
            --cache_layer_num;
            ++num_linear_layers;
        }
    }

    // Measure free memory before the linear-state pools are allocated, so they can be charged to the budget
    const size_t free_before = (block_count < 1. && num_linear_layers > 0) ? get_free_size() : 0;

    if (num_linear_layers > 0) {

        const int key_head_dim =
            model_param.linear_key_head_dim > 0 ? model_param.linear_key_head_dim : model_param.head_dim;
        const int value_head_dim =
            model_param.linear_value_head_dim > 0 ? model_param.linear_value_head_dim : model_param.head_dim;
        const int d_conv      = model_param.linear_conv_kernel_dim > 0 ? model_param.linear_conv_kernel_dim : 4;
        const int num_k_heads = model_param.linear_num_key_heads / attn_tp_size;
        const int num_v_heads = model_param.linear_num_value_heads / attn_tp_size;
        const int key_dim     = num_k_heads * key_head_dim;
        const int value_dim   = num_v_heads * value_head_dim;
        const int conv_dim    = key_dim * 2 + value_dim;

        TM_CHECK_GT(max_batch_size, 0);
        // One slot per concurrently running sequence; per-slot shapes are the trailing dims
        pooled_conv_states_ = {{max_batch_size, num_linear_layers, d_conv, conv_dim}, model_param.data_type, kDEVICE};
        pooled_recurrent_states_ = {{max_batch_size, num_linear_layers, num_v_heads, key_head_dim, value_head_dim},
                                    model_param.linear_state_dtype,
                                    kDEVICE};

        // Pushed in reverse so that back() hands out slot 0 first
        free_linear_state_slots_.reserve(max_batch_size);
        for (int slot = max_batch_size - 1; slot >= 0; --slot) {
            free_linear_state_slots_.push_back(slot);
        }
        TM_LOG_INFO("[SeqMgr] linear-state slot pool initialized: %d slots", max_batch_size);
        const auto   conv_one      = pooled_conv_states_.slice(0, 1).squeeze(0);
        const auto   recurrent_one = pooled_recurrent_states_.slice(0, 1).squeeze(0);
        const double mb            = 1.0 / (1024.0 * 1024.0);
        TM_LOG_INFO("[SeqMgr] linear-state per slot: conv %.2f MB + recurrent %.2f MB = %.2f MB",
                    conv_one.byte_size() * mb,
                    recurrent_one.byte_size() * mb,
                    (conv_one.byte_size() + recurrent_one.byte_size()) * mb);
        TM_LOG_INFO("[SeqMgr] linear-state combined total: %.2f MB",
                    (pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size()) * mb);
    }

    const int  dbits        = byte_size(runtime_dtype, 8);
    const auto quant_policy = model_param.quant_policy;
    const int  elem_bits    = quant_policy ? quant_policy : dbits;

    BlockConfig block_config{
        (int)model_param.head_dim,
        (int)model_param.kv_head_num / attn_tp_size,
        cache_block_seq_len,
        elem_bits == dbits ? 0 : dbits,
        elem_bits,
        model_param.head_dim == 576,  // share kv
    };

    block::Layout layout{block_config};
    // dump(layout);

    size_t block_size = layout.block_size(cache_layer_num);

    if (num_linear_layers > 0 && block_count < 1.) {
        const size_t linear_bytes = pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size();
        const size_t target_bytes = static_cast<size_t>(free_before * block_count);
        TM_LOG_INFO("[SeqMgr] Adjusting block_count: free_before %.2f MB, linear %.2f MB, target %.2f MB",
                    free_before / (1024. * 1024.),
                    linear_bytes / (1024. * 1024.),
                    target_bytes / (1024. * 1024.));
        if (target_bytes <= linear_bytes) {
            TM_LOG_ERROR("[SeqMgr] Linear-state memory (%.2f MB) >= cache budget (%.2f MB). ",
                         linear_bytes / (1024. * 1024.),
                         target_bytes / (1024. * 1024.));
            TM_CHECK(0)
                << "Please decrease max_batch_size to reduce total linear state size or increase cache_max_entry_count.";
        }
        const size_t cache_bytes = target_bytes - linear_bytes;
        block_count              = static_cast<double>(cache_bytes) / static_cast<double>(block_size);
        TM_LOG_INFO("[SeqMgr] Adjusted block_count to %.0f", block_count);
        // `block_count < 1` selects the "fraction of free memory" mode (see the condition above), so an
        // absolute count below one block would be silently reinterpreted as a ratio downstream.
        TM_CHECK_GE(block_count, 1.)
            << "KV cache budget left after linear states is smaller than one block. Please decrease "
               "max_batch_size to reduce total linear state size or increase cache_max_entry_count.";
    }

    block_manager_ = std::make_shared<BlockManager>(block_size, block_count, chunk_size, allocator, get_free_size);

    if (enable_prefix_caching) {
        block_trie_ = std::make_shared<BlockTrie>(block_config.block_len_, block_manager_);
    }
    TM_LOG_WARNING("[SeqMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");
}

const Sequence* SequenceManager::Create(uint64_t id)
{
    // A sequence that already uses `id` is erased (releasing its resources) before the new one is inserted
    auto pos = sequences_.find(id);
    if (pos != sequences_.end()) {
        if (rank_ == 0) {
            TM_LOG_WARNING("[SeqMgr][Create] Removing conflicting ID %llu", id);
        }
        // Erase() advances `pos` to the successor, which remains a valid insertion hint for `id`
        Erase(pos);
    }

    pos = sequences_.emplace_hint(pos, id, Sequence{id});

    if (rank_ == 0) {
        TM_LOG_INFO("[SeqMgr][Create] ID %llu", id);
    }
    return &pos->second;
}

const Sequence* SequenceManager::Get(uint64_t id)
{
    // nullptr when no sequence with `id` exists
    const auto it = sequences_.find(id);
    return it == sequences_.end() ? nullptr : &it->second;
}

bool SequenceManager::Contains(uint64_t id)
{
    // Keys are unique, so count() is either 0 or 1
    return sequences_.count(id) != 0;
}

// Remove the sequence at `it` and advance `it` to the next element (the value returned by map::erase).
// Block handling depends on status:
//   - kCached: the block list is truncated to the count returned by BlockManager::Verify.
//   - otherwise: the blocks are queued for a deferred unlock.
// Without prefix caching the blocks are then queued for freeing. With prefix caching they may be shared
// through the trie, so they are not freed here. Any linear-state slot held by the sequence is returned to
// the pool.
void SequenceManager::Erase(std::map::iterator& it)
{
    auto& seq = it->second;
    if (seq.status == Sequence::kCached) {
        // Unlocked blocks may have been invalidated; keep only the count Verify() reports
        const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
        seq.blocks.resize(count);
    }
    else {
        UpdateAndSetUnlock(seq);
    }
    // if prefix cache enabled, blocks will be shared by sequences, cannot be freed immediately
    if (!block_trie_) {
        freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
    }
    ReleaseLinearStateSlot(seq);
    it = sequences_.erase(it);
}

bool SequenceManager::Erase(uint64_t id)
{
    // Returns false when no sequence with `id` exists
    auto it = sequences_.find(id);
    if (it == sequences_.end()) {
        return false;
    }
    Erase(it);
    return true;
}

void SequenceManager::AcquireLinearStateSlot(const Sequence& sequence)
{
    if (!pooled_recurrent_states_) {
        return;
    }

    auto& seq = const_cast(sequence);

    auto slot_it = seq_to_linear_state_slot_.find(seq.id);
    if (slot_it != seq_to_linear_state_slot_.end()) {
        const int slot       = slot_it->second;
        seq.conv_states      = pooled_conv_states_.slice(slot).squeeze(0);
        seq.recurrent_states = pooled_recurrent_states_.slice(slot).squeeze(0);
        return;
    }

    TM_CHECK(!free_linear_state_slots_.empty()) << "No free linear-state slot for sequence " << seq.id
                                                << ", max_batch_size=" << pooled_recurrent_states_.shape(0);

    const int slot = free_linear_state_slots_.back();
    free_linear_state_slots_.pop_back();
    seq_to_linear_state_slot_.emplace(seq.id, slot);

    seq.conv_states              = pooled_conv_states_.slice(slot).squeeze(0);
    seq.recurrent_states         = pooled_recurrent_states_.slice(slot).squeeze(0);
    seq.linear_states_need_reset = true;
}

void SequenceManager::ReleaseLinearStateSlot(const Sequence& sequence)
{
    if (!pooled_recurrent_states_) {
        return;
    }

    auto& seq = const_cast(sequence);

    if (auto slot_it = seq_to_linear_state_slot_.find(seq.id); slot_it != seq_to_linear_state_slot_.end()) {
        free_linear_state_slots_.push_back(slot_it->second);
        seq_to_linear_state_slot_.erase(slot_it);
    }
    seq.conv_states              = {};
    seq.recurrent_states         = {};
    seq.linear_states_need_reset = false;
}

void SequenceManager::InvalidateStatesAndCache(const Sequence& sequence)
{
    InvalidateStatesAndCache(sequence, freed_);
}

// Discard everything cached for `sequence`: queue an unlock of its blocks if they are still locked,
// append all of its blocks to `freed_blocks`, reset its block lists, input_length and cache_len to empty,
// and return its linear-state slot (if any) to the pool.
void SequenceManager::InvalidateStatesAndCache(const Sequence& sequence, BlockIds& freed_blocks)
{
    auto& seq = const_cast(sequence);
    if (seq.status != Sequence::kCached) {
        // Also sets status to kCached
        UpdateAndSetUnlock(seq);
    }
    freed_blocks.insert(freed_blocks.end(), seq.blocks.begin(), seq.blocks.end());

    seq.blocks.clear();
    seq.block_unique_ids.clear();
    seq.input_length = 0;
    seq.cache_len    = 0;
    ReleaseLinearStateSlot(seq);
}

void SequenceManager::CachePrompt(const Sequences& sequences, int active_size)
{
    // Prefix caching disabled: nothing to insert into the trie
    if (!block_trie_) {
        return;
    }

    // Only the first `active_size` sequences are considered
    for (int i = 0; i < active_size; ++i) {
        auto& seq = *sequences[i];
        if (seq.prompt.empty()) {
            continue;
        }

        const auto& [block_ids, unique_ids] = block_trie_->Cache(seq, seq.prompt);
        if (rank_ == 0) {
            // clang-format off
            TM_LOG_INFO("[SeqMgr][CachePrompt] ID %llu, cached blocks %d, tokens %d", seq.id,
                        (int)block_ids.size(), (int)seq.prompt.size());
            TM_LOG_DEBUG("[SeqMgr][CachePrompt] ID %llu, cached block_ids %s, unique_ids %s", seq.id,
                         vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());
            // clang-format on
        }

        // Once the whole prompt is covered by the KV cache it is no longer needed
        if (seq.cache_len >= seq.prompt.size()) {
            seq.prompt.clear();
        }
    }
}

void SequenceManager::CacheGeneration(const Sequence& seq)
{
    // Insert the sequence's full token history into the prefix trie (no-op without prefix caching)
    if (block_trie_ == nullptr) {
        return;
    }

    const auto& [block_ids, unique_ids] = block_trie_->Cache(seq, seq.tokens);

    // Only rank 0 reports
    if (rank_ != 0) {
        return;
    }
    // clang-format off
    TM_LOG_INFO("[SeqMgr][CacheGeneration] ID %llu, cached blocks %d, tokens %d",
                seq.id, (int)block_ids.size(), (int)seq.tokens.size());
    TM_LOG_DEBUG("[SeqMgr][CacheGeneration] ID %llu, cached block_ids %s, unique_ids %s", seq.id,
                 vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());
    // clang-format on
}

// For every kCached sequence in `sequences`, truncate its blocks to the count BlockManager::Verify
// reports as still valid, clamp cache_len to the remaining blocks, and lock them (status -> kLocked).
// Sequences with linear-attention state that lost any blocks are fully invalidated instead: their
// remaining blocks are freed and a fresh linear-state slot is bound. This is presumably because the
// recurrent state covers the whole cached history and cannot be rolled back to a shorter KV prefix.
void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
{
    BlockIds valid_blocks;
    BlockIds freed_blocks;
    for (const auto& p : sequences) {
        auto& seq = const_cast(*p);
        if (seq.status != Sequence::kCached) {
            continue;
        }
        TM_CHECK_EQ(seq.blocks.size(), seq.block_unique_ids.size());
        // Verify cache blocks that may be invalidated
        const int original_count = seq.blocks.size();
        const int count          = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
        seq.blocks.resize(count);
        seq.block_unique_ids.resize(count);

        const bool has_linear_states = static_cast(seq.recurrent_states);
        if (has_linear_states && count < original_count) {
            // Status is kCached here, so no unlock is queued; the surviving blocks go to freed_blocks
            InvalidateStatesAndCache(seq, freed_blocks);
            // This request can still continue in the current scheduling round.
            // Rebind a slot immediately so GatedDeltaNetLayer::Setup always sees
            // valid linear-state views.
            AcquireLinearStateSlot(seq);
            continue;
        }

        valid_blocks.insert(valid_blocks.end(), seq.blocks.begin(), seq.blocks.end());
        seq.cache_len = std::min(seq.cache_len, seq.blocks.size() * block_seq_len_);
        seq.status    = Sequence::kLocked;
    }
    // Frees are applied immediately (not deferred) before locking the surviving blocks
    if (!freed_blocks.empty()) {
        block_manager_->Free(freed_blocks);
    }
    block_manager_->Lock(valid_blocks);
}

void SequenceManager::CommitUnlockAndFree()
{
    if (!unlocked_.empty()) {
        block_manager_->Unlock(unlocked_);
        unlocked_.clear();
    }

    if (!freed_.empty()) {
        block_manager_->Free(freed_);
        freed_.clear();
    }
}

void SequenceManager::UpdateAndSetUnlock(const Sequence& sequence)
{
    TM_CHECK_NE(sequence.status, Sequence::kCached);
    auto& seq = const_cast(sequence);
    block_manager_->Touch(seq.blocks);
    unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
    seq.status = Sequence::kCached;
}

namespace {

// Mutable state of one scheduling pass in Materialize(): remaining block and token budgets, the
// accumulated block operations, and the active / inactive / victim partitions being built.
struct Schedule {
    int free;    // free blocks still available (initialized from the block manager snapshot)
    int cached;  // cached (unlocked, evictable) blocks still available

    int allocate{};  // blocks to take from the free pool
    int evict{};     // cached blocks to evict and reuse
    int preempt{};   // blocks released to the cached pool by preempting victims

    int last;  // sequences at index >= last are not scheduled (lowered when victims are preempted)

    int max_fwd_tokens;  // remaining forward-token budget
    int max_tmp_tokens;  // remaining temp-buffer (flattened KV) token budget

    Sequences        active;
    std::vector block_counts;  // new blocks per sequence in `active`, same order
    Sequences        inactive;
    Sequences        victims;

    Schedule(Snapshot snapshot, int size, int max_fwd_tokens, int max_tmp_tokens):
        free{snapshot.free},
        cached{snapshot.cached},
        last{size},
        max_fwd_tokens{max_fwd_tokens},
        max_tmp_tokens{max_tmp_tokens},
        use_count_{std::move(snapshot.use_count)},
        unlocked_(size),  // ! This is a vector, DO NOT brace initialize it
        it_{size}
    {
    }

    // Number of blocks whose use count reaches zero when sequence `vidx` is unlocked, given that every
    // sequence with a higher index was unlocked before it. Sequences are processed from it_ - 1 down to
    // vidx; use_count_ is decremented as they are processed, and results are memoized in unlocked_.
    int Unlock(const Sequences& seqs, int vidx)
    {
        while (vidx < it_) {
            const auto& blocks = seqs[--it_]->blocks;
            int         count  = 0;
            for (const auto& bid : blocks) {
                count += static_cast(--use_count_[bid] == 0);
            }
            unlocked_[it_] = count;
        }
        return unlocked_[vidx];
    }

private:
    std::vector use_count_;  // per-block reference counts, from the snapshot
    std::vector unlocked_;   // memoized Unlock() results, indexed by sequence
    int              it_;    // lowest sequence index processed by Unlock() so far
};

// Print a vector as "[a,b,c]" (no spaces); "[]" when empty.
template<typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    bool first = true;
    for (const auto& x : v) {
        if (!first) {
            os << ",";
        }
        os << x;
        first = false;
    }
    os << "]";
    return os;
}

std::ostream& operator<<(std::ostream& os, const Schedule& s)
{
    // Resource counters first, then the sequence partitions
    os << "free=" << s.free << ", cached=" << s.cached;
    os << ", allocate=" << s.allocate << ", evict=" << s.evict << ", preempt=" << s.preempt;
    os << ", active=" << s.active << ", victims=" << s.victims;
    os << ", block_counts=" << s.block_counts << ", inactive=" << s.inactive;
    return os;
}

// Tentative admission of one sequence (`index_`) into a Schedule. Process() decides whether the sequence
// fits. If it does, Commit() folds the transaction's operations into the schedule. If not, the sequence
// is marked inactive with input_length = 0.
struct Transaction {
    int index_;        // position of the candidate in `sequences_`
    int block_count_;  // new blocks the candidate needs
    int input_len_;    // tokens to run this step (clamped to the forward budget on commit)
    int temp_len_;     // temp-buffer tokens the candidate needs

    int allocate_{};
    int evict_{};
    int preempt_{};

    Sequences victims_;

    const Sequences& sequences_;
    Schedule&        schedule_;

    explicit Transaction(
        const Sequences& sequences, int index, int block_count, int input_len, int temp_len, Schedule& sched):
        index_{index},
        block_count_{block_count},
        input_len_{input_len},
        temp_len_{temp_len},
        sequences_{sequences},
        schedule_{sched}
    {
    }

    // Requires forward budget left and enough temp budget. Blocks are sourced, in order, from:
    //   1) the free pool,
    //   2) the cached pool (eviction),
    //   3) preempting non-cached sequences from the tail of the priority order (indices > index_).
    // Preempted blocks become cached and are then evicted for this candidate.
    // NOTE: Schedule::Unlock's memo (use_count_/unlocked_) is not rolled back if the transaction fails.
    // This appears intentional, since unlock counts depend only on which tail suffix is unlocked.
    void Process()
    {
        if (schedule_.max_fwd_tokens > 0 && schedule_.max_tmp_tokens >= temp_len_) {
            int count = block_count_;

            int tmp = std::min(schedule_.free, count);
            count -= tmp;
            allocate_ += tmp;

            tmp = std::min(schedule_.cached, count);
            count -= tmp;
            evict_ += tmp;

            for (int vidx = schedule_.last - 1; count && vidx > index_; --vidx) {
                if (sequences_[vidx]->status == Sequence::kCached) {
                    continue;
                }
                victims_.push_back(sequences_[vidx]);
                preempt_ += schedule_.Unlock(sequences_, vidx);

                if (count <= preempt_) {
                    evict_ += count;
                    count -= count;
                    // Sequences from vidx onward are preempted and excluded from further scheduling
                    schedule_.last = vidx;  // ! modifying `schedule_.last` is part of commit
                    break;
                }
            }
            if (count == 0) {
                return Commit();
            }
        }

        // Not admitted this step
        const_cast(sequences_[index_])->input_length = 0;
        schedule_.inactive.push_back(sequences_[index_]);
    }

    void Commit()
    {
        // update available resources
        schedule_.free -= allocate_;
        TM_CHECK_GE(schedule_.free, 0);
        schedule_.cached += preempt_;
        schedule_.cached -= evict_;
        TM_CHECK_GE(schedule_.cached, 0);

        // update scheduled operations
        schedule_.allocate += allocate_;
        schedule_.evict += evict_;
        schedule_.preempt += preempt_;
        schedule_.victims.insert(schedule_.victims.end(), victims_.begin(), victims_.end());

        // update active sequences
        schedule_.active.push_back(sequences_[index_]);
        schedule_.block_counts.push_back(block_count_);

        // A long input may be only partially processed this step when the forward budget is short
        input_len_ = std::min(input_len_, schedule_.max_fwd_tokens);
        schedule_.max_fwd_tokens -= input_len_;
        const_cast(sequences_[index_])->input_length = input_len_;

        schedule_.max_tmp_tokens -= temp_len_;
    }
};

std::ostream& operator<<(std::ostream& os, const Transaction& trans)
{
    // One-line summary of a transaction for debugging
    os << "index=" << trans.index_ << ", block_count=" << trans.block_count_;
    os << ", allocate=" << trans.allocate_ << ", evict=" << trans.evict_;
    os << ", preempt=" << trans.preempt_ << ", victims=" << trans.victims_;
    return os;
}

}  // namespace

// Reorder every vector in `vals` so its elements follow ascending `keys`; `keys` itself is not modified.
// Each vector in `vals` is expected to have keys.size() elements.
template<class K, class... Vs>
static void SortByKey(const std::vector<K>& keys, std::vector<Vs>&... vals)
{
    std::vector<int> order(keys.size());
    for (size_t i = 0; i < order.size(); ++i) {
        order[i] = static_cast<int>(i);
    }
    std::sort(order.begin(), order.end(), [&keys](int a, int b) { return keys[a] < keys[b]; });

    const auto permute = [&order](auto& xs) {
        std::remove_reference_t<decltype(xs)> permuted(xs.size());
        for (size_t k = 0; k < xs.size(); ++k) {
            permuted[k] = xs[order[k]];
        }
        xs.swap(permuted);
    };
    (permute(vals), ...);
}

std::vector SequenceManager::CountRequiredBlocks(const Sequences&        sequences,
                                                      const std::vector& context_length)
{
    std::vector required(sequences.size());
    for (int i = 0; i < sequences.size(); ++i) {
        int length  = (context_length[i] + attn_cp_size_ - 1) / attn_cp_size_;
        int count   = (length + block_seq_len_ - 1) / block_seq_len_ - static_cast(sequences[i]->blocks.size());
        required[i] = std::max(0, count);
    }
    return required;
}

void SequenceManager::AssignAndActivate(const Sequences&        sequences,  //
                                        const std::vector& counts,
                                        const BlockIds&         blocks,
                                        const UniqueIds&        unique_ids)
{
    TM_CHECK_EQ(sequences.size(), counts.size());
    int first = 0;
    for (int i = 0; i < sequences.size(); ++i) {
        auto& s     = const_cast(*sequences[i]);
        auto  count = counts[i];
        int   last  = first + count;
        TM_CHECK_LE(last, blocks.size());
        s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);
        s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);
        s.status = Sequence::kActive;
        first    = last;
    }
}

// Prefix-cache lookup. For each sequence with alpha[i] == 0 whose prompt is not yet fully cached, find the
// matching block prefix in the trie. Matched blocks beyond the sequence's current cache_len replace its
// own tail blocks; the replaced tail is unlocked and the adopted blocks are locked. cache_len is then set
// to the number of tokens covered by the resulting blocks.
// No-op when prefix caching is disabled.
void SequenceManager::PrefixMatch(Sequences& sequences, const std::vector<int>& alpha)
{
    if (!block_trie_) {
        return;
    }

    for (int i = 0; i < sequences.size(); i++) {

        auto& seq = const_cast<Sequence&>(*sequences[i]);

        /// TODO: Is there a way to exploit the alpha[i] != 0 case?
        if (alpha[i] != 0 || seq.cache_len >= seq.prompt.size()) {
            continue;
        }

        const auto& [block_ids, unique_ids] = block_trie_->Match(seq);

        if (rank_ == 0) {
            // clang-format off
            TM_LOG_INFO("[SeqMgr][match] ID %llu, hit blocks %d, cache_len %d", seq.id, (int)block_ids.size(), seq.cache_len);
            TM_LOG_DEBUG("[SeqMgr][match] ID %llu, hit block_ids %s, unique_ids %s", seq.id,
                         vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());
            // clang-format on
        }

        /// TODO: `Unlock` and `Lock` can't be batched because there may be repeated blocks between sequences
        if (const int offset = seq.cache_len / block_seq_len_; offset < block_ids.size()) {
            // Drop (and unlock) the sequence's own blocks past the already-cached prefix
            if (BlockIds tail{seq.blocks.begin() + offset, seq.blocks.end()}; !tail.empty()) {
                block_manager_->Unlock(tail);
                seq.blocks.resize(offset);
                seq.block_unique_ids.resize(offset);
            }
            // Adopt the matched blocks from the trie and lock them
            seq.blocks.insert(seq.blocks.end(), block_ids.begin() + offset, block_ids.end());
            seq.block_unique_ids.insert(seq.block_unique_ids.end(), unique_ids.begin() + offset, unique_ids.end());
            seq.cache_len = seq.blocks.size() * block_seq_len_;
            block_manager_->Lock({block_ids.begin() + offset, block_ids.end()});
        }

        if (rank_ == 0) {
            // `%d` requires an int; passing size_t here was undefined and garbled the following argument
            // clang-format off
            TM_LOG_INFO("[SeqMgr][match] ID %llu, after matching, blocks %d, cache_len %d",
                        seq.id, (int)seq.blocks.size(), seq.cache_len);
            TM_LOG_DEBUG("[SeqMgr][match] ID %llu, after matching, block_ids %s, unique_ids %s", seq.id,
                         vector2string(seq.blocks).c_str(), vector2string(seq.block_unique_ids).c_str());
            // clang-format on
        }
    }
}

// Decide which sequences run this step and give them the KV blocks they need.
// Phases:
//   1) apply deferred unlock/free operations;
//   2) sort by ascending priority key, so earlier entries are scheduled first and are preempted last;
//   3) verify+lock cached sequences and apply prefix-cache matches;
//   4) plan each sequence with a Transaction (free -> evict cached -> preempt tail);
//   5) execute the plan: unlock victims, evict, allocate, assign, and demote unscheduled active sequences.
// Returns the number of blocks allocated (including evicted-and-reused ones), the number of scheduled
// sequences that were not previously active (swap_in), and the number of previously active sequences
// left unscheduled (swap_out).
auto SequenceManager::Materialize(Sequences             sequences,
                                  std::vector      context_length,
                                  std::vector      alpha,
                                  std::vector priorities,
                                  int                   max_fwd_tokens,
                                  int                   max_tmp_tokens) -> Outcome
{
    ////////////////////////////////////////////////////////////////////////////////
    /// Schedule the assignment of blocks to sequences

    // process deferred unlock and free operations
    CommitUnlockAndFree();

    // `priorities` is the key; the other three vectors are permuted consistently
    SortByKey(priorities, sequences, context_length, alpha);

    // Verify and lock cache sequences to avoid their blocks being evicted unnoticed
    // the blocks can still be preempted later
    VerifyAndLockCached(sequences);

    PrefixMatch(sequences, alpha);

    std::vector required = CountRequiredBlocks(sequences, context_length);

    Schedule schedule(block_manager_->TakeSnapshot(), sequences.size(), max_fwd_tokens, max_tmp_tokens);

    // `schedule.last` is decreasing in the loop
    for (int i = 0; i < schedule.last; ++i) {
        auto&     s         = *sequences[i];
        const int input_len = context_length[i] - alpha[i] - s.cache_len;
        // sanity check
        TM_CHECK_GT(input_len, 0) << "Logical error: " << context_length[i] << " " << alpha[i] << " " << s.cache_len
                                  << " " << s.status;
        // temp buffer for flatten KV cache
        const int temp_len = (input_len > 1 || s.status != Sequence::kActive) ? context_length[i] : 0;
        Transaction{sequences, i, required[i], input_len, temp_len, schedule}.Process();
    }

    // mark remaining sequences invalid
    for (int i = schedule.last; i < sequences.size(); ++i) {
        schedule.inactive.push_back(sequences[i]);
    }

    ////////////////////////////////////////////////////////////////////////////////
    /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)

    // combine allocate and evict since evicted blocks are reused by allocation
    schedule.allocate += schedule.evict;

    // if (schedule.allocate) {
    //     dbg(*block_manager_);
    // }

    Outcome outcome{};
    outcome.allocation = schedule.allocate;
    outcome.swap_in    = std::count_if(schedule.active.begin(), schedule.active.end(), [](auto p) {
        // if (p->status != Sequence::kActive) {
        //     dbg(*p);
        // }
        return p->status != Sequence::kActive;
    });
    outcome.swap_out = std::count_if(schedule.inactive.begin(), schedule.inactive.end(), [](auto p) {
        // if (p->status == Sequence::kActive) {
        //     dbg(*p);
        // }
        return p->status == Sequence::kActive;
    });

    // release preempted blocks -> cached
    if (!schedule.victims.empty()) {
        TM_LOG_INFO("[SeqMgr] #victim: %d", (int)schedule.victims.size());
        for (const auto& p : schedule.victims) {
            UpdateAndSetUnlock(*p);
        }
        // Unlocks are applied now (not deferred) so the victims' blocks can be evicted below
        CommitUnlockAndFree();
    }

    // evict cached blocks -> free
    if (schedule.evict) {
        block_manager_->Evict(schedule.evict);
    }

    // allocate & assign blocks
    {
        BlockIds  block_ids;
        UniqueIds unique_ids;
        if (schedule.allocate) {
            std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);
        }
        AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);
    }

    // active -> locked
    for (const auto& p : schedule.inactive) {
        if (p->status == Sequence::kActive) {
            const_cast(p)->status = Sequence::kLocked;
        }
    }

    // TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d",
    //              block_manager_->active_count(),
    //              block_manager_->cached_count(),
    //              block_manager_->free_count());
    if (block_trie_) {
        block_trie_->Verify();
    }

    return outcome;
}

std::tuple SequenceManager::seq_stats() const noexcept
{
    int total  = static_cast(sequences_.size());
    int active = 0;
    int cached = 0;
    for (const auto& p : sequences_) {
        if (p.second.status == Sequence::kActive) {
            ++active;
        }
        else if (p.second.status == Sequence::kCached) {
            ++cached;
        }
    }
    return std::make_tuple(total, active, cached);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/SequenceManager.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/core.h"

#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/models/llama/BlockTrie.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Per-session record kept by SequenceManager: the KV-cache blocks the session
// owns, its token history, and auxiliary state carried from one request to the next.
struct Sequence {

    // Status of the sequence with respect to the block scheduler. Materialize()
    // demotes kActive sequences that were not scheduled to kLocked; seq_stats()
    // counts kActive and kCached sequences separately.
    enum Status
    {
        kCached = 0,
        kLocked,
        kActive
    };

    uint64_t id;  // id passed to SequenceManager::Create
    Status   status = kCached;

    // KV-cache block ids held by this sequence, and their unique ids
    // (presumably index-aligned with `blocks` — TODO confirm).
    BlockIds  blocks;
    UniqueIds block_unique_ids;

    int input_length = 0;  // the number of tokens to be processed in each forward iter

    // Prompt tokens (presumably the input consumed by SequenceManager::CachePrompt — confirm).
    mutable std::vector prompt;

    mutable std::vector tokens;  // update by user or when the sequence is finished

    // Number of tokens whose KV entries are already cached (see the `cache` term in
    // the accounting diagram above SequenceManager::Materialize).
    mutable int cache_len = 0;

    // additional data kept round-to-round
    mutable std::vector random_state;  // update by user

    // NOTE(review): presumably a per-sequence RoPE base override, 0 meaning "unset" — confirm.
    mutable float rope_theta = 0.f;

    // embedding data
    mutable std::vector input_embeds;
    mutable std::vector    input_embeds_offsets;

    // Gated DeltaNet linear attention persistent states (e.g. Qwen3.5-MoE).
    // Allocated on first request, preserved across requests for the same session,
    // and freed automatically when the sequence is erased from the SequenceManager.
    //   conv_states:      (num_linear_layers, conv_dim, d_conv) — per-channel rolling conv history
    //   recurrent_states: (num_linear_layers, num_v_heads, key_head_dim, value_head_dim) — SSM state
    // NOTE(review): bench_conv1d_silu.cc's CPU reference lays out per-sequence conv
    // state as [d_conv, conv_dim]; confirm which order the conv_states shape above
    // should state. SequenceManager also keeps pooled_conv_states_ /
    // pooled_recurrent_states_ with Acquire/ReleaseLinearStateSlot; confirm whether
    // these tensors are views into that pool (which would affect the "freed" wording).
    mutable Tensor conv_states;
    mutable Tensor recurrent_states;
    mutable bool   linear_states_need_reset = false;  // presumably: states must be re-zeroed before next use

    explicit Sequence(uint64_t _id): id(_id) {}

    friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
};

using Sequences = std::vector;

inline std::ostream& operator<<(std::ostream& os, const Sequence& seq)
{
    // One-line debug summary. `status` is an unscoped enum, so it is printed as
    // its integer value (0 = kCached, 1 = kLocked, 2 = kActive).
    os << "id=" << seq.id;
    os << ", status=" << seq.status;
    os << ", token_count=" << seq.tokens.size();
    os << ", block_count=" << seq.blocks.size();
    os << ", cache_len=" << seq.cache_len;
    os << ", random_state_size=" << seq.random_state.size();
    os << ", input_length=" << seq.input_length;
    return os;
}

// Manages KV-cache blocks for all live sequences. Owns:
//   - the block pool (`block_manager_`);
//   - an optional prefix-cache trie (`block_trie_`, checked for null before use;
//     presumably created only when `enable_prefix_caching` is set — confirm);
//   - pooled linear-attention state buffers;
//   - the id -> Sequence table.
class SequenceManager {
public:
    // Describes one KV-cache block.
    // NOTE(review): t_bits/q_bits are presumably the stored element width and the
    // quantization width; confirm against BlockManager.
    // clang-format off
    struct BlockConfig {
        int head_dim_;
        int head_num_;
        int block_len_;
        int t_bits_;
        int q_bits_;
        bool share_kv_;
        int t_bits() const { return t_bits_; }
        int q_bits() const { return q_bits_; }
        int head_dim() const { return head_dim_; }
        int head_num() const { return head_num_; }
        int block_len() const { return block_len_; }
        bool is_share_kv() const { return share_kv_; }
    };
    // clang-format on

    // NOTE(review): `block_count` is a double. It is presumably either an absolute
    // block count or a fraction of the memory reported by `get_free_size`; confirm
    // in the implementation.
    explicit SequenceManager(const ModelParam& model_param,
                             DataType          runtime_dtype,
                             int               cache_block_seq_len,
                             int               attn_tp_size,
                             int               max_batch_size,
                             double            block_count,
                             int               chunk_size,
                             bool              enable_prefix_caching,
                             int               rank,
                             int               attn_cp_size,
                             core::Allocator   allocator,
                             GetFreeMemSize    get_free_size);

    // Non-copyable; move-constructible.
    SequenceManager(const SequenceManager&)     = delete;
    SequenceManager(SequenceManager&&) noexcept = default;

    // Creates and registers a new sequence with the given id.
    [[nodiscard]] const Sequence* Create(uint64_t id);

    // Looks up the sequence with the given id (presumably nullptr if absent — confirm).
    [[nodiscard]] const Sequence* Get(uint64_t id);

    // Whether a sequence with the given id is registered.
    [[nodiscard]] bool Contains(uint64_t id);

    // Removes the sequence with the given id; the result presumably reports whether it existed.
    [[nodiscard]] bool Erase(uint64_t id);

    // Bind / release a slot of the pooled linear-attention state buffers for `seq`
    // (presumably pooled_conv_states_ / pooled_recurrent_states_ — confirm).
    void AcquireLinearStateSlot(const Sequence& seq);

    void ReleaseLinearStateSlot(const Sequence& seq);

    // Per the name: discards the sequence's cached blocks and linear states — confirm.
    void InvalidateStatesAndCache(const Sequence& seq);

    // Materialize() calls this on every preempted ("victim") sequence and then
    // calls CommitUnlockAndFree().
    void UpdateAndSetUnlock(const Sequence& seq);

    // Counts summarizing one Materialize() call (presumably blocks allocated and
    // sequences swapped in / out — confirm).
    struct Outcome {
        int allocation;
        int swap_in;
        int swap_out;
    };

    using AdjustInputCount = std::function&)>;

    // Schedules `sequences` for the next step. The implementation, in order:
    //   1. unlocks preempted sequences and commits their blocks;
    //   2. evicts cached blocks;
    //   3. allocates blocks and assigns them to the sequences chosen to run
    //      (AssignAndActivate);
    //   4. demotes kActive sequences that were not chosen to kLocked.
    //
    // The diagram below relates the per-sequence quantities (the first row gives example values):
    //                50       1       0       50
    //    context = seq_len + beta = cache + alpha + input
    //     alpha' = input
    //      beta' = int(is_gen)
    //  -----------------------------------
    //   seq_len += output
    //     cache += input + output - 1  or  cache = seq_len - 1

    [[maybe_unused]] Outcome Materialize(Sequences             sequences,
                                         std::vector      context_length,
                                         std::vector      alpha,
                                         std::vector priorities,
                                         int                   max_fwd_tokens,
                                         int                   max_tmp_tokens);

    /** @brief cache the input prompt tokens of each seq in sequences[0:active_size-1]
     *
     * @param sequences The sequence list
     * @param active_size the number of active sequences in the list
     */
    void CachePrompt(const Sequences& sequences, int active_size);

    /** @brief cache the generated tokens of a given sequence
     *
     * @param sequence the given sequence
     *
     * @note This function can only be called after the sequence finish generation
     * and all tokens including the prompt tokens and generated tokens have been put to
     * `seq.tokens`
     */
    void CacheGeneration(const Sequence& sequence);

    // Raw pointer to the storage of cache block `block_id`.
    [[nodiscard]] void* GetBlockPtr(int block_id)
    {
        return block_manager_->block(block_id).data;
    }

    // Block-pool statistics, forwarded from the BlockManager.
    int max_block_count() const noexcept
    {
        return block_manager_->max_block_count();
    }

    int total_count() const noexcept
    {
        return block_manager_->total_count();
    }

    int active_count() const noexcept
    {
        return block_manager_->active_count();
    }

    int free_count() const noexcept
    {
        return block_manager_->free_count();
    }

    int cached_count() const noexcept
    {
        return block_manager_->cached_count();
    }

    // return #total_seq, #active_seq, #cached_seq
    std::tuple seq_stats() const noexcept;

private:
    void Erase(std::map::iterator& it);

    // Applies the block unlocks/frees staged in `unlocked_` / `freed_` (presumably — confirm).
    void CommitUnlockAndFree();

    void InvalidateStatesAndCache(const Sequence& seq, BlockIds& freed_blocks);

    void VerifyAndLockCached(const Sequences& sequences);

    std::vector CountRequiredBlocks(const Sequences&        sequences,  //
                                         const std::vector& context_length);

    // Hands out `blocks` / `unique_ids` to `sequences` according to `counts`.
    static void AssignAndActivate(const Sequences&        sequences,  //
                                  const std::vector& counts,
                                  const BlockIds&         blocks,
                                  const UniqueIds&        unique_ids);

    void PrefixMatch(Sequences& sequences, const std::vector& alpha);

private:
    int block_seq_len_;
    int rank_;
    int attn_cp_size_;

    // Use `std::map` to avoid reference invalidation
    std::map sequences_;

    std::shared_ptr block_manager_;
    std::shared_ptr    block_trie_;  // may be null; Materialize checks before Verify()

    // Pooled linear-attention states, a free-slot list, and the seq -> slot mapping.
    Tensor                            pooled_conv_states_;
    Tensor                            pooled_recurrent_states_;
    std::vector                  free_linear_state_slots_;
    std::unordered_map seq_to_linear_state_slot_;

    // Block ids staged for unlocking / freeing (presumably consumed by CommitUnlockAndFree).
    BlockIds unlocked_;
    BlockIds freed_;
};

inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
{
    // Renders as "allocation: A, swap-in: I, swap-out: O".
    os << "allocation: " << oc.allocation;
    os << ", swap-in: " << oc.swap_in;
    os << ", swap-out: " << oc.swap_out;
    return os;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/bench_conv1d_silu.cc
================================================

#include 
#include 
#include 
#include 
#include 
#include 

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/models/llama/gated_delta_net_kernels.h"

using namespace turbomind;
using namespace turbomind::core;

// Command-line options for the conv1d + SiLU benchmark, given as `--name value` pairs.
struct Args {
    int      batch_size  = 32;
    int      seq_len     = 1;
    int      num_v_heads = 64;
    int      num_k_heads = 16;
    int      d_conv      = 4;
    int      warmup      = 10;
    int      iters       = 100;
    DataType dtype       = kFloat16;

    // Maps a dtype name to its DataType; exits on an unrecognized name.
    static DataType ParseDtype(const char* s)
    {
        if (strcmp(s, "half") == 0 || strcmp(s, "fp16") == 0)
            return kFloat16;
        if (strcmp(s, "bf16") == 0)
            return kBfloat16;
        fprintf(stderr, "Unknown dtype: %s (expected half/fp16/bf16)\n", s);
        exit(1);
    }

    // Parses the value of integer option `flag`. The value must be a complete
    // base-10 integer that fits in an int and is >= min_value; otherwise this
    // prints an error and exits. atoi() would silently turn garbage into 0, and a
    // zero d_conv / iters later causes a modulo-by-zero in the CPU reference and a
    // division by zero in the per-iteration average.
    static int ParseInt(const char* flag, const char* s, int min_value)
    {
        char*      end = nullptr;
        const long v   = strtol(s, &end, 10);
        // Round-tripping through int rejects values outside int's range.
        const bool ok = end != s && *end == '\0' && v == static_cast<long>(static_cast<int>(v)) && v >= min_value;
        if (!ok) {
            fprintf(stderr, "Invalid value for %s: %s (expected an integer >= %d)\n", flag, s, min_value);
            exit(1);
        }
        return static_cast<int>(v);
    }

    // Parses `--name value` pairs; exits on a missing value, an unknown option,
    // or an invalid value.
    static Args Parse(int argc, char** argv)
    {
        Args a;
        for (int i = 1; i < argc; i += 2) {
            if (i + 1 >= argc) {
                fprintf(stderr, "Missing value for %s\n", argv[i]);
                exit(1);
            }
            const char* flag  = argv[i];
            const char* value = argv[i + 1];
            if (strcmp(flag, "--batch_size") == 0)
                a.batch_size = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--seq_len") == 0)
                a.seq_len = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--num_v_heads") == 0)
                a.num_v_heads = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--num_k_heads") == 0)
                a.num_k_heads = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--d_conv") == 0)
                a.d_conv = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--warmup") == 0)
                a.warmup = ParseInt(flag, value, 0);  // zero warmup iterations is allowed
            else if (strcmp(flag, "--iters") == 0)
                a.iters = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--dtype") == 0)
                a.dtype = ParseDtype(value);
            else {
                fprintf(stderr, "Unknown arg: %s\n", flag);
                exit(1);
            }
        }
        return a;
    }

    // Echoes the effective configuration.
    void Print() const
    {
        printf("batch_size=%d  seq_len=%d  num_v_heads=%d  num_k_heads=%d  d_conv=%d  "
               "warmup=%d  iters=%d  dtype=%s\n",
               batch_size,
               seq_len,
               num_v_heads,
               num_k_heads,
               d_conv,
               warmup,
               iters,
               to_string(dtype));
    }
};

static float
benchmark_kernel(const char* name, std::function launch, cudaStream_t stream, int warmup, int iters)
{
    for (int i = 0; i < warmup; ++i)
        launch();
    cudaStreamSynchronize(stream);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    for (int i = 0; i < iters; ++i)
        launch();
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);

    float ms = 0;
    cudaEventElapsedTime(&ms, start, stop);
    float avg_ms = ms / iters;

    printf("  %-45s  %8.3f ms (avg over %d iters)\n", name, avg_ms, iters);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return avg_ms;
}

// CPU reference for depthwise causal conv1d + SiLU.
//
//   y(t, c) = SiLU( sum_{d=0}^{D-1} w(d, c) * x(t - D + 1 + d, c) )
//
// where x(i, c) falls back to the conv state for i < 0 (history from the
// previous inference step). After the sequence, the last min(seq_len, D) inputs
// are written back into the state; main() assumes the same row count when it
// estimates state-write bandwidth.
//
// State is a ring buffer: slot j holds the input written at absolute time t
// where t % d_conv == j.  history_len = k_offsets[b+1] - k_offsets[b] - seq_len.
//
// Weight layout: [d_conv, conv_dim].  State layout: [d_conv, conv_dim] per batch.
// Input rows have stride `in_stride`, but only their first conv_dim channels are
// read. Output rows are dense with stride conv_dim.
// Accumulation is done in float; the SiLU result is converted back to T.
template
static void cpu_conv1d_silu(T*         h_out,
                            const T*   h_in,
                            const T*   h_weight,
                            T*         h_state,
                            const int* h_q_offsets,
                            const int* h_k_offsets,
                            int        batch_size,
                            int        conv_dim,
                            int        d_conv,
                            int        in_stride)
{
    for (int b = 0; b < batch_size; ++b) {
        const int seq_off     = h_q_offsets[b];
        const int seq_len     = h_q_offsets[b + 1] - seq_off;
        const int history_len = (h_k_offsets[b + 1] - h_k_offsets[b]) - seq_len;
        T*        state       = h_state + b * d_conv * conv_dim;

        // Input at relative time i: the current chunk for i >= 0, otherwise the
        // ring-buffer slot of absolute time history_len + i (the double modulo
        // keeps the index non-negative when history_len + i < 0).
        auto x = [&](int i, int c) -> float {
            if (i >= 0)
                return static_cast(h_in[(seq_off + i) * in_stride + c]);
            int ring_idx = ((history_len + i) % d_conv + d_conv) % d_conv;
            return static_cast(state[ring_idx * conv_dim + c]);
        };

        for (int t = 0; t < seq_len; ++t) {
            for (int c = 0; c < conv_dim; ++c) {
                float acc = 0.f;
                for (int d = 0; d < d_conv; ++d)
                    acc += static_cast(h_weight[d * conv_dim + c]) * x(t - d_conv + 1 + d, c);
                // SiLU(a) = a * sigmoid(a) = a / (1 + exp(-a))
                h_out[(seq_off + t) * conv_dim + c] = static_cast(acc / (1.f + std::exp(-acc)));
            }
        }

        // State update runs only after all outputs are computed, so the history
        // reads above see the pre-update state. src walks the last d_conv
        // positions of the chunk; positions before the chunk start (src < 0) are
        // skipped, leaving those slots untouched.
        for (int d = 0; d < d_conv; ++d) {
            int src = seq_len - d_conv + d;
            if (src >= 0) {
                int ring_d = (history_len + src) % d_conv;
                for (int c = 0; c < conv_dim; ++c)
                    state[ring_d * conv_dim + c] = h_in[(seq_off + src) * in_stride + c];
            }
        }
    }
}

// Benchmarks invokeFusedConv1dSiLU, reports its effective memory bandwidth, then
// checks one launch against cpu_conv1d_silu (output and updated state).
int main(int argc, char** argv)
{
    auto args = Args::Parse(argc, argv);
    args.Print();

    constexpr int kHeadDim = 128;  // fixed head size for both k and v heads in this bench

    const int num_v_heads = args.num_v_heads;
    const int num_k_heads = args.num_k_heads;
    const int batch_size  = args.batch_size;
    const int seq_len     = args.seq_len;
    const int d_conv      = args.d_conv;

    // Input rows are `in_stride` wide. The conv channels come first (conv_dim =
    // 2 * k_dim + v_dim); the trailing v_dim + 2 * num_v_heads columns are
    // presumably the remaining fused projections (TODO confirm). The CPU reference
    // reads only the first conv_dim columns.
    const int k_dim     = num_k_heads * kHeadDim;
    const int v_dim     = num_v_heads * kHeadDim;
    const int conv_dim  = 2 * k_dim + v_dim;
    const int in_stride = conv_dim + v_dim + 2 * num_v_heads;
    const int total_tok = batch_size * seq_len;

    const int      conv_state_size = conv_dim * d_conv;  // elements of conv state per sequence
    const DataType dtype           = args.dtype;
    const auto     elem_bytes      = byte_size(dtype);

    auto         stream = Stream::create();
    ContextGuard ctx{stream, Allocator{kCPU}, Allocator{kCPUpinned}, Allocator{stream, false}};
    cudaStream_t cu_stream = stream.handle();

    // SM count and a device work counter are passed to the kernel (presumably for a
    // persistent launch — confirm). The counter is never cleared in this file.
    int sm_count = 1;
    {
        int device = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    }
    Buffer_ work_counter{1, kDEVICE};

    printf("\nconv_dim=%d  d_conv=%d  in_stride=%d  total_tokens=%d\n", conv_dim, d_conv, in_stride, total_tok);

    Tensor all_proj{Layout{{total_tok, in_stride}}, dtype, kDEVICE};
    Tensor weight{Layout{{d_conv, conv_dim}}, dtype, kDEVICE};

    // *_ref tensors receive the CPU reference results; *_v2 are written by the kernel.
    Tensor out_ref{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};
    Tensor out_v2{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};

    Tensor state_ref{Layout{{batch_size, conv_state_size}}, dtype, kDEVICE};
    Tensor state_v2{Layout{{batch_size, conv_state_size}}, dtype, kDEVICE};

    // One device pointer per sequence into state_v2.
    Buffer_ state_ptrs_v2_host{batch_size, kCPUpinned};
    Buffer_ state_ptrs_v2_dev{batch_size, kDEVICE};

    Buffer_ q_offsets_host{batch_size + 1, kCPUpinned};
    Buffer_ q_offsets_dev{batch_size + 1, kDEVICE};
    Buffer_ k_offsets_dev{batch_size + 1, kDEVICE};

    RNG rng;
    rng.UniformFloat(all_proj, 0.1f);
    rng.UniformFloat(weight, 0.1f);

    // Uniform-length sequences. k_offsets == q_offsets means history_len == 0, so
    // the (zeroed) state only supplies the left padding of the causal window.
    for (int i = 0; i <= batch_size; ++i)
        q_offsets_host.data()[i] = i * seq_len;
    Copy(q_offsets_host, batch_size + 1, q_offsets_dev);
    Copy(q_offsets_host, batch_size + 1, k_offsets_dev);  // no history in bench

    for (int i = 0; i < batch_size; ++i) {
        state_ptrs_v2_host.data()[i] = (char*)state_v2.raw_data() + i * conv_state_size * elem_bytes;
    }
    Copy(state_ptrs_v2_host, batch_size, state_ptrs_v2_dev);
    stream.Sync();

    // NOTE(review): the empty Tensor{} and the literal 0 are presumably "no bias"
    // and a default offset/flag; check the invokeFusedConv1dSiLU declaration.
    auto launch_v2 = [&] {
        invokeFusedConv1dSiLU(out_v2,
                              all_proj,
                              weight,
                              Tensor{},
                              state_ptrs_v2_dev,
                              q_offsets_dev,
                              k_offsets_dev,
                              batch_size,
                              0,
                              sm_count,
                              work_counter.data(),
                              cu_stream);
    };

    // === Benchmark ===
    printf("\n=== Benchmark ===\n");
    float v2_ms = benchmark_kernel("v2   (templated + vectorized)", launch_v2, cu_stream, args.warmup, args.iters);

    // === Bandwidth ===
    // The traffic model counts:
    //   - only the conv_dim input columns;
    //   - the output and the weights;
    //   - a full state read;
    //   - min(seq_len, d_conv) state rows written (matching cpu_conv1d_silu's update loop).
    // bytes / (ms * 1e6) == GB/s.
    {
        double in_bytes         = (double)total_tok * conv_dim * elem_bytes;
        double out_bytes        = (double)total_tok * conv_dim * elem_bytes;
        double wt_bytes         = (double)conv_dim * d_conv * elem_bytes;
        int    state_write_rows = std::min(seq_len, d_conv);
        double state_rd_bytes   = (double)batch_size * d_conv * conv_dim * elem_bytes;
        double state_wr_bytes   = (double)batch_size * state_write_rows * conv_dim * elem_bytes;
        double state_bytes      = state_rd_bytes + state_wr_bytes;
        double total_bytes      = in_bytes + out_bytes + wt_bytes + state_bytes;

        printf("\n=== Bandwidth ===\n");
        printf("  in:     %.1f MB\n", in_bytes / 1e6);
        printf("  out:    %.1f MB\n", out_bytes / 1e6);
        printf("  weight: %.3f MB\n", wt_bytes / 1e6);
        printf("  state:  %.1f MB  (R %.1f + W %.1f)\n", state_bytes / 1e6, state_rd_bytes / 1e6, state_wr_bytes / 1e6);
        printf("  total:  %.1f MB\n", total_bytes / 1e6);
        printf("  v2  BW: %.1f GB/s\n", total_bytes / (v2_ms * 1e6));
    }

    // === Cross-comparison (correctness): CPU ref vs GPU v2 ===
    printf("\n=== Cross-comparison (CPU ref vs GPU v2) ===\n");

    // The benchmark loop mutated state_v2; reset everything to zero so both sides
    // start from the same initial state.
    Clear(state_ref);
    Clear(state_v2);
    Clear(out_ref);
    Clear(out_v2);
    stream.Sync();

    // Run GPU kernel
    launch_v2();
    stream.Sync();

    // Run CPU reference
    // Raw byte copies host <-> device. The synchronous cudaMemcpy calls follow the
    // stream.Sync() above, so the device data is already final.
    {
        const size_t in_bytes    = (size_t)total_tok * in_stride * elem_bytes;
        const size_t wt_bytes    = (size_t)d_conv * conv_dim * elem_bytes;
        const size_t state_bytes = (size_t)batch_size * conv_state_size * elem_bytes;
        const size_t out_bytes   = (size_t)total_tok * conv_dim * elem_bytes;

        std::vector h_in(in_bytes), h_wt(wt_bytes), h_state(state_bytes), h_out(out_bytes);

        cudaMemcpy(h_in.data(), all_proj.raw_data(), in_bytes, cudaMemcpyDeviceToHost);
        cudaMemcpy(h_wt.data(), weight.raw_data(), wt_bytes, cudaMemcpyDeviceToHost);
        std::memset(h_state.data(), 0, state_bytes);
        std::memset(h_out.data(), 0, out_bytes);

        // Generic lambda: the tag argument's type selects the element type T.
        auto run_cpu = [&](auto t) {
            using T = decltype(t);
            cpu_conv1d_silu((T*)h_out.data(),
                            (const T*)h_in.data(),
                            (const T*)h_wt.data(),
                            (T*)h_state.data(),
                            q_offsets_host.data(),
                            q_offsets_host.data(),  // k_offsets == q_offsets (no history in bench)
                            batch_size,
                            conv_dim,
                            d_conv,
                            in_stride);
        };

        // Args::ParseDtype only yields fp16 or bf16, so anything else is bf16 here.
        if (dtype == kFloat16)
            run_cpu(half{});
        else
            run_cpu(nv_bfloat16{});

        cudaMemcpy(out_ref.raw_data(), h_out.data(), out_bytes, cudaMemcpyHostToDevice);
        cudaMemcpy(state_ref.raw_data(), h_state.data(), state_bytes, cudaMemcpyHostToDevice);
    }

    printf("  output comparison:\n");
    FC_Header();
    FC_Print(FastCompare(out_ref, out_v2, cu_stream));

    printf("  state comparison:\n");
    FC_Header();
    FC_Print(FastCompare(state_ref, state_v2, cu_stream));

    printf("\nDone.\n");
    return 0;
}


================================================
FILE: src/turbomind/models/llama/bench_gated_delta_net.cc
================================================

#include 
#include 
#include 

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/models/llama/gated_delta_net_kernels.h"

using namespace turbomind;
using namespace turbomind::core;

// Command-line options for the gated-delta-rule benchmark, given as `--name value` pairs.
struct Args {
    int      batch_size  = 32;
    int      seq_len     = 64;
    int      num_v_heads = 16;
    int      num_k_heads = 4;
    int      warmup      = 10;
    int      iters       = 100;
    DataType dtype       = kFloat16;
    DataType state_dtype = kFloat32;

    // Maps a dtype name to its DataType; exits on an unrecognized name.
    static DataType ParseDtype(const char* s)
    {
        if (strcmp(s, "half") == 0 || strcmp(s, "fp16") == 0)
            return kFloat16;
        if (strcmp(s, "bf16") == 0)
            return kBfloat16;
        if (strcmp(s, "fp32") == 0 || strcmp(s, "float") == 0)
            return kFloat32;
        fprintf(stderr, "Unknown dtype: %s (expected half/fp16/bf16/fp32/float)\n", s);
        exit(1);
    }

    // Parses the value of integer option `flag`. The value must be a complete
    // base-10 integer that fits in an int and is >= min_value; otherwise this
    // prints an error and exits. atoi() would silently turn garbage into 0; a
    // zero-sized batch/sequence or zero `iters` yields empty buffers or a
    // division by zero in the per-iteration average.
    static int ParseInt(const char* flag, const char* s, int min_value)
    {
        char*      end = nullptr;
        const long v   = strtol(s, &end, 10);
        // Round-tripping through int rejects values outside int's range.
        const bool ok = end != s && *end == '\0' && v == static_cast<long>(static_cast<int>(v)) && v >= min_value;
        if (!ok) {
            fprintf(stderr, "Invalid value for %s: %s (expected an integer >= %d)\n", flag, s, min_value);
            exit(1);
        }
        return static_cast<int>(v);
    }

    // Parses `--name value` pairs; exits on a missing value, an unknown option,
    // or an invalid value.
    static Args Parse(int argc, char** argv)
    {
        Args a;
        for (int i = 1; i < argc; i += 2) {
            if (i + 1 >= argc) {
                fprintf(stderr, "Missing value for %s\n", argv[i]);
                exit(1);
            }
            const char* flag  = argv[i];
            const char* value = argv[i + 1];
            if (strcmp(flag, "--batch_size") == 0)
                a.batch_size = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--seq_len") == 0)
                a.seq_len = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--num_v_heads") == 0)
                a.num_v_heads = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--num_k_heads") == 0)
                a.num_k_heads = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--warmup") == 0)
                a.warmup = ParseInt(flag, value, 0);  // zero warmup iterations is allowed
            else if (strcmp(flag, "--iters") == 0)
                a.iters = ParseInt(flag, value, 1);
            else if (strcmp(flag, "--dtype") == 0)
                a.dtype = ParseDtype(value);
            else if (strcmp(flag, "--state_dtype") == 0)
                a.state_dtype = ParseDtype(value);
            else {
                fprintf(stderr, "Unknown arg: %s\n", flag);
                exit(1);
            }
        }
        return a;
    }

    // Echoes the effective configuration.
    void Print() const
    {
        printf(
            "batch_size=%d  seq_len=%d  num_v_heads=%d  num_k_heads=%d  warmup=%d  iters=%d  dtype=%s  state_dtype=%s\n",
            batch_size,
            seq_len,
            num_v_heads,
            num_k_heads,
            warmup,
            iters,
            to_string(dtype),
            to_string(state_dtype));
    }
};

static float
benchmark_kernel(const char* name, std::function launch, cudaStream_t stream, int warmup, int iters)
{
    for (int i = 0; i < warmup; ++i)
        launch();
    cudaStreamSynchronize(stream);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    for (int i = 0; i < iters; ++i)
        launch();
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);

    float ms = 0;
    cudaEventElapsedTime(&ms, start, stop);
    float avg_ms = ms / iters;

    printf("  %-45s  %8.3f ms (avg over %d iters)\n", name, avg_ms, iters);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return avg_ms;
}

// Benchmarks three gated-delta-rule kernels on identical random inputs:
//   - recurrent v2;
//   - chunked;
//   - persistent v3, which runs only when seq_len == 1.
// It then cross-checks their outputs and final states.
int main(int argc, char** argv)
{
    auto args = Args::Parse(argc, argv);
    args.Print();

    constexpr int kHeadDim = 128;  // used for both key and value head size

    const int num_v_heads = args.num_v_heads;
    const int num_k_heads = args.num_k_heads;
    const int batch_size  = args.batch_size;
    const int seq_len     = args.seq_len;

    const int k_dim     = num_k_heads * kHeadDim;
    const int v_dim     = num_v_heads * kHeadDim;
    const int conv_dim  = 2 * k_dim + v_dim;  // width of one qkv_in row
    const int total_tok = batch_size * seq_len;

    // Recurrent state: num_v_heads x kHeadDim x kHeadDim elements.
    const int state_size = num_v_heads * kHeadDim * kHeadDim;  // per request

    const DataType dtype       = args.dtype;
    const DataType state_dtype = args.state_dtype;

    // --- Context setup ---
    auto         stream = Stream::create();
    ContextGuard ctx{stream, Allocator{kCPU}, Allocator{kCPUpinned}, Allocator{stream, false}};
    cudaStream_t cu_stream = stream.handle();

    // v3 is only benchmarked and compared in the single-token (decode) case.
    const bool is_decode = (seq_len == 1);

    // --- Allocate tensors ---
    // NOTE(review): qkv_in presumably holds the post-conv q|k|v activations (conv_dim columns) — confirm.
    Tensor qkv_in{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};
    Tensor v_out_v2{Layout{{total_tok, v_dim}}, dtype, kDEVICE};
    Tensor v_out_chunked{Layout{{total_tok, v_dim}}, dtype, kDEVICE};
    Tensor v_out_v3{Layout{{total_tok, v_dim}}, dtype, kDEVICE};
    Tensor beta{Layout{{total_tok, num_v_heads}}, dtype, kDEVICE};
    Tensor g{Layout{{total_tok, num_v_heads}}, dtype, kDEVICE};

    // State buffers — all three kernels use state_dtype
    Tensor state_v2{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};
    Tensor state_chunked{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};
    Tensor state_v3{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};

    // State pointer arrays: host pinned + device
    Buffer_ state_ptrs_v2_host{batch_size, kCPUpinned};
    Buffer_ state_ptrs_v2_dev{batch_size, kDEVICE};
    Buffer_ state_ptrs_chunked_host{batch_size, kCPUpinned};
    Buffer_ state_ptrs_chunked_dev{batch_size, kDEVICE};
    Buffer_ state_ptrs_v3_host{batch_size, kCPUpinned};
    Buffer_ state_ptrs_v3_dev{batch_size, kDEVICE};

    // q_offsets: host + device
    Buffer_ q_offsets_host{batch_size + 1, kCPUpinned};
    Buffer_ q_offsets_dev{batch_size + 1, kDEVICE};

    // --- Fill random data ---
    RNG rng;
    rng.UniformFloat(qkv_in, 0.1f);
    rng.UniformFloat(beta, 1.0f);        // will be passed through sigmoid inside kernel
    rng.UniformFloat(g, 0.02f, -0.01f);  // small values around 0
    Clear(state_v2);
    Clear(state_chunked);
    Clear(state_v3);

    // --- Build q_offsets ---
    // Uniform-length sequences: sequence i covers tokens [i * seq_len, (i + 1) * seq_len).
    for (int i = 0; i <= batch_size; ++i)
        q_offsets_host.data()[i] = i * seq_len;
    Copy(q_offsets_host, batch_size + 1, q_offsets_dev);

    // --- Build state_ptrs ---
    // NOTE(review): state_elem_bytes_v3 is identical to state_elem_bytes (all three
    // kernels use state_dtype); the separate name is redundant.
    const auto state_elem_bytes    = byte_size(state_dtype);
    const auto state_elem_bytes_v3 = byte_size(state_dtype);
    for (int i = 0; i < batch_size; ++i) {
        state_ptrs_v2_host.data()[i]      = (char*)state_v2.raw_data() + i * state_size * state_elem_bytes;
        state_ptrs_chunked_host.data()[i] = (char*)state_chunked.raw_data() + i * state_size * state_elem_bytes;
        state_ptrs_v3_host.data()[i]      = (char*)state_v3.raw_data() + i * state_size * state_elem_bytes_v3;
    }
    Copy(state_ptrs_v2_host, batch_size, state_ptrs_v2_dev);
    Copy(state_ptrs_chunked_host, batch_size, state_ptrs_chunked_dev);
    Copy(state_ptrs_v3_host, batch_size, state_ptrs_v3_dev);
    stream.Sync();

    // Shared resources for all three kernel launchers
    // NOTE(review): one work counter is shared by every launch and never cleared
    // here; this presumes each kernel resets it itself — confirm.
    int sm_count = 1;
    {
        int device = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    }
    Buffer_ work_counter_buf{1, kDEVICE};
    int*         work_counter = work_counter_buf.data();

    // --- Benchmark recurrent (v2) kernel ---
    printf("\n=== Benchmarks ===\n");
    auto launch_v2 = [&] {
        invokeGatedDeltaRuleBatched_v2(v_out_v2,
                                       qkv_in,
                                       beta,
                                       g,
                                       state_ptrs_v2_dev,
                                       q_offsets_dev,
                                       batch_size,
                                       num_k_heads,
                                       0,
                                       state_dtype,
                                       sm_count,
                                       work_counter,
                                       cu_stream);
    };
    float v2_ms = benchmark_kernel("invokeGatedDeltaRuleBatched_v2", launch_v2, cu_stream, args.warmup, args.iters);

    // --- Benchmark chunked kernel ---
    auto launch_chunked = [&] {
        invokeChunkedGatedDeltaRuleBatched(v_out_chunked,
                                           qkv_in,
                                           beta,
                                           g,
                                           state_ptrs_chunked_dev,
                                           q_offsets_dev,
                                           batch_size,
                                           num_k_heads,
                                           0,
                                           state_dtype,
                                           sm_count,
                                           work_counter,
                                           cu_stream);
    };
    float chunked_ms =
        benchmark_kernel("invokeChunkedGatedDeltaRuleBatched", launch_chunked, cu_stream, args.warmup, args.iters);

    // --- Benchmark v3 persistent decode kernel (seq_len == 1 only) ---
    float v3_ms     = -1.f;  // sentinel; only read when is_decode
    auto  launch_v3 = [&] {
        invokeGatedDeltaRuleBatched_v3(v_out_v3,
                                       qkv_in,
                                       beta,
                                       g,
                                       state_ptrs_v3_dev,
                                       q_offsets_dev,
                                       batch_size,
                                       num_k_heads,
                                       0,
                                       state_dtype,
                                       sm_count,
                                       work_counter,
                                       cu_stream);
    };
    if (is_decode) {
        v3_ms = benchmark_kernel(
            "invokeGatedDeltaRuleBatched_v3 (persistent)", launch_v3, cu_stream, args.warmup, args.iters);
    }
    else {
        printf("  %-45s  (skipped — seq_len > 1)\n", "invokeGatedDeltaRuleBatched_v3 (persistent)");
    }

    // Ratios > 1 mean the second-named kernel is faster than v2.
    printf("\n  Speedup v2 / chunked:  %.2fx\n", v2_ms / chunked_ms);
    if (is_decode)
        printf("  Speedup v2 / v3:       %.2fx\n", v2_ms / v3_ms);

    // --- Bandwidth stats ---
    // Counts only state traffic: every per-request state is read once and written
    // once per launch (hence x2); bytes / (ms * 1e6) == GB/s.
    {
        double state_bytes    = (double)batch_size * state_size * state_elem_bytes * 2.0;
        double state_bytes_v3 = (double)batch_size * state_size * state_elem_bytes_v3 * 2.0;
        printf("\n=== Bandwidth ===\n");
        printf("  v2:      state BW = %.1f GB/s\n", state_bytes / (v2_ms * 1e6));
        printf("  chunked: state BW = %.1f GB/s\n", state_bytes / (chunked_ms * 1e6));
        if (is_decode)
            printf("  v3:      state BW = %.1f GB/s\n", state_bytes_v3 / (v3_ms * 1e6));
        printf("  total_tokens = %d\n", total_tok);
    }

    // === Cross-comparison: run both kernels on identical input, compare outputs ===
    printf("\n=== Cross-comparison (v2 vs chunked) ===\n");

    // Reset states to identical initial values (zero)
    Clear(state_v2);
    Clear(state_chunked);
    Clear(v_out_v2);
    Clear(v_out_chunked);
    stream.Sync();

    // Single invocation of each kernel
    launch_v2();
    launch_chunked();
    stream.Sync();

    // Compare v_out
    printf("  v_out comparison:\n");
    FC_Header();
    auto v_out_stats = FastCompare(v_out_v2, v_out_chunked, cu_stream);
    FC_Print(v_out_stats);

    // Compare final states
    printf("  state comparison:\n");
    FC_Header();
    auto state_stats = FastCompare(state_v2, state_chunked, cu_stream);
    FC_Print(state_stats);

    // === Cross-comparison: v2 vs v3 (decode only) ===
    if (is_decode) {
        printf("\n=== Cross-comparison (v2 vs v3, state_dtype=%s) ===\n", to_string(state_dtype));

        Clear(state_v2);
        Clear(state_v3);
        Clear(v_out_v2);
        Clear(v_out_v3);
        stream.Sync();

        launch_v2();
        launch_v3();
        stream.Sync();

        printf("  v_out comparison:\n");
        FC_Header();
        FC_Print(FastCompare(v_out_v2, v_out_v3, cu_stream));

        printf("  state comparison:\n");
        FC_Header();
        FC_Print(FastCompare(state_v2, state_v3, cu_stream));
    }

    printf("\nDone.\n");
    return 0;
}


================================================
FILE: src/turbomind/models/llama/context.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

#include 

#include 
#include 

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/core.h"
#include "src/turbomind/models/llama/LlamaLinear.h"

namespace turbomind {

// Per-rank bundle of host-side and device-side communicators.
// Context holds one of these as `comm`, value-initializes it (so the int members
// start at 0) and leaves it to be filled in later by its owner.
struct Communicators {
    // Host-side communicators. Judging by the names: global, a general-purpose one,
    // tensor-parallel group and data-parallel group. NOTE(review): scopes are not
    // evident from this file — confirm where they are created.
    comm::HostComm h_global;
    comm::HostComm h_comm;
    comm::HostComm h_tp_group;
    comm::HostComm h_dp_group;

    // Device-side communicator (see GetSymmAllocator for how its allocation API is used).
    comm::DeviceComm d_comm;
    // NOTE(review): presumably identifiers of the tensor-parallel / context-parallel
    // groups inside d_comm rather than communicators themselves — confirm with DeviceComm.
    int              d_tp_group;
    int              d_cp_group;
};

// Execution context for the model
//
// Bundles per-device execution state: a core stream and the raw cudaStream_t handle of
// that same stream, an allocator built on that stream, the `linear` object, the device
// properties and the communicators.
// Declaration order matters: `allocator` and `stream` are initialized from
// `core_stream`, so it must be declared (and therefore constructed) first.
struct Context {
    core::Stream                 core_stream;
    core::Allocator              allocator;  // built on core_stream; see core::Allocator for the `false` flag
    cudaStream_t                 stream;     // raw handle of core_stream (same stream, not a second one)
    std::unique_ptr linear;
    cudaDeviceProp               device_prop;  // properties of the `device_id` passed to the constructor
    Communicators                comm;  // initialize later
    std::unique_ptr         is_warm_up;  // NOTE(review): presumably a warm-up flag shared by address — confirm

    // NOTE(review): core_stream is created on whichever CUDA device is current; device_id is
    // only used to query device_prop. Presumably the caller has already selected device_id.
    Context(int device_id):
        core_stream{core::Stream::create()},
        allocator{core::Allocator(core_stream, false)},
        stream{core_stream.handle()},
        comm{},  // value initialize
        is_warm_up{std::make_unique()}
    {
        // Construct `linear` while core_stream is installed as the current context,
        // presumably so that it binds to this stream.
        core::ContextGuard guard{core_stream};
        linear = std::make_unique();
        check_cuda_error(cudaGetDeviceProperties(&device_prop, device_id));
    }
};

// Create a device Allocator backed by the device communicator's own allocation API.
// Every buffer is allocated through `comm` and immediately registered with it; on
// release it is deregistered first and then freed (reverse order of acquisition).
//
// `comm` must be valid (checked below). Both callbacks capture `comm` by reference,
// so the DeviceComm object must outlive the returned Allocator.
inline Allocator GetSymmAllocator(const comm::DeviceComm& comm)
{
    TM_CHECK(comm);

    auto allocate = [&comm](auto bytes) {
        auto ptr = comm->Allocate(bytes);
        comm->Register(ptr, bytes);
        return ptr;
    };

    // The buffer size is not needed to release the buffer.
    auto release = [&comm](void* ptr, auto /*bytes*/) {
        comm->Deregister(ptr);
        comm->Free(ptr);
    };

    return core::SimpleAllocator::Create(allocate, release, kDEVICE);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/gated_delta_net_kernels.cu
================================================

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/models/llama/gated_delta_net_kernels.h"

#include 
#include 
#include 
#include 

#include "src/turbomind/utils/cuda_utils.h"

#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/layout.h"
#include "src/turbomind/kernels/gemm/thread_map.h"

namespace turbomind {

using namespace gemm;

// Recurrent Gated Delta Rule kernel (v2): one thread block per (sequence b, value head h).
//
// blockIdx.x = b * num_v_heads + h. The block walks the tokens q_offsets[b] .. q_offsets[b + 1] - 1
// of sequence b in order and keeps the k_head_dim x v_head_dim state of head h in registers for the
// whole sequence. Per token (reductions accumulate in float):
//   k, q  = L2-normalize(K_t), L2-normalize(Q_t)      (1e-6 added inside the rsqrt)
//   S     = exp(g_t) * S
//   delta = beta_t * (v_t - S^T k)                    (one value per v column)
//   S     = S + k * delta^T
//   out_t = (S^T q) * rsqrt(k_head_dim)
// out_t is evaluated as (S_decayed^T q + delta * dot(k, q)) * scale, which is algebraically equal
// to the last line but avoids a second reduction over k.
//
// Indexing used below:
//   qkv_in row (stride conv_dim): [ Q: k_dim_total | K: k_dim_total | V: num_v_heads * v_head_dim ]
//   beta_in / g_in: element [token * num_v_heads + h];  v_out row stride: num_v_heads * v_head_dim
//   state: state_ptrs[b] + state_layer_offset + h * state_size, row-major [k_head_dim][v_head_dim]
//
// Shared memory is only a swizzled staging buffer for loading/storing the state (the launcher
// reserves kHeadDim * kHeadDim * sizeof(S) bytes). The state is loaded and written back even
// when seq_len == 0.
template
__global__ void recurrent_gated_delta_rule_kernel_v2(T*         v_out,
                                                     const T*   qkv_in,
                                                     const T*   beta_in,
                                                     const T*   g_in,
                                                     S* const*  state_ptrs,
                                                     const int* q_offsets,
                                                     int        num_v_heads,
                                                     int        num_k_heads,
                                                     int        k_dim_total,
                                                     int        state_layer_offset)
{
    const int bh    = blockIdx.x;
    const int b     = bh / num_v_heads;
    const int h     = bh % num_v_heads;
    // Groups of `ratio` consecutive value heads share key/query head kh.
    const int ratio = num_v_heads / num_k_heads;
    const int kh    = h / ratio;

    // Tokens of sequence b occupy rows [tok_off, tok_off + seq_len).
    const int tok_off    = q_offsets[b];
    const int seq_len    = q_offsets[b + 1] - tok_off;
    const int state_size = k_head_dim * v_head_dim;
    const int conv_dim   = 2 * k_dim_total + num_v_heads * v_head_dim;  // row stride of qkv_in
    const int v_dim      = num_v_heads * v_head_dim;                    // row stride of v_out

    S* s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;

    const float scale = rsqrtf((float)k_head_dim);

    // DimC = v_head_dim (memory-contiguous), DimS = k_head_dim (strided)
    using Map_S = ThreadMap_V2;

    extern __shared__ __align__(16) char smem_buf[];

    // XOR swizzle: bits [10,13] (offset_k) XOR into column access-group index
    constexpr int kBase  = (sizeof(S) == 4) ? 2 : 3;  // log2(kAccessC)
    constexpr int kShift = 10 - kBase;
    using Layout         = SmemLayoutV2>;
    SmemAccessor smem_S{(S*)smem_buf};

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // Register tiling: per v_iter each thread owns tile_k rows (k) x tile_v columns (v) of the
    // state. The numeric comments below assume k_head_dim = v_head_dim = 128.
    constexpr int tile_k = 16;
    constexpr int tile_v = 4;

    constexpr int k_tiles = k_head_dim / tile_k;  // 8
    constexpr int v_tiles = v_head_dim / tile_v;  // 32

    // The k_threads consecutive threads with the same offset_v split the k dimension of the same
    // v columns; their partial sums are combined with __shfl_xor_sync over masks k_threads/2 .. 1,
    // so k_threads must be a power of two no larger than WARP_SIZE.
    constexpr int k_threads = k_tiles;
    constexpr int v_threads = block_dim / k_threads;

    constexpr int v_iters = cdiv(v_tiles, v_threads);

    Array vec_S[v_iters][tile_k];

    const int offset_k = threadIdx.x % k_tiles;
    const int offset_v = threadIdx.x / k_tiles;

    constexpr int kAccessC = Map_S::kAccessC;

    // Load state, step 1: global -> swizzled smem using the Map_S thread map.
    PRAGMA_UNROLL
    for (int s = 0; s < Map_S::kIterS; ++s) {
        Array vec;
        PRAGMA_UNROLL
        for (int c = 0; c < Map_S::kIterC; ++c) {
            const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);
            const int final_vd  = vd + c * Map_S::kDeltaC;
            const int final_kd  = kd + s * Map_S::kDeltaS;
            Load(vec, s_ptr + final_kd * v_head_dim + final_vd);
            Store(&smem_S(final_kd, final_vd), vec);
        }
    }

    __syncthreads();

    // Load state, step 2: smem -> this thread's register tile (converted via cast).
    PRAGMA_UNROLL
    for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
        PRAGMA_UNROLL
        for (int k = 0; k < tile_k; ++k) {
            constexpr int kTileAccessC = (tile_v >= kAccessC) ? kAccessC : tile_v;
            static_assert(tile_v % kTileAccessC == 0);
            PRAGMA_UNROLL
            for (int c = 0; c < tile_v / kTileAccessC; ++c) {
                Array tmp;
                Load(tmp, &smem_S(offset_k * tile_k + k, (offset_v + v_iter * v_threads) * tile_v + c * kTileAccessC));
                (Array&)vec_S[v_iter][k][c * kTileAccessC] = cast(tmp);
            }
        }
    }

    // Sequential recurrence over the tokens of sequence b; state stays in registers.
    for (int t = 0; t < seq_len; ++t) {
        const int global_t = tok_off + t;

        const T* q_ptr = qkv_in + global_t * conv_dim + kh * k_head_dim;
        const T* k_ptr = qkv_in + global_t * conv_dim + k_dim_total + kh * k_head_dim;
        const T* v_ptr = qkv_in + global_t * conv_dim + 2 * k_dim_total + h * v_head_dim;
        T*       o_ptr = v_out + global_t * v_dim + h * v_head_dim;

        const float beta_val = (float)beta_in[global_t * num_v_heads + h];
        const float decay    = expf((float)g_in[global_t * num_v_heads + h]);

        Array vec_K;
        Array vec_Q;

        // --- In-kernel L2-normalize K/Q (Vectorized) ---
        {
            {
                Array tmp_K;
                Array tmp_Q;
                Load(tmp_K, &k_ptr[offset_k * tile_k]);
                Load(tmp_Q, &q_ptr[offset_k * tile_k]);
                vec_K = cast(tmp_K);
                vec_Q = cast(tmp_Q);
            }

            float k_sum = 0.f;
            float q_sum = 0.f;

            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                k_sum += vec_K[k] * vec_K[k];
                q_sum += vec_Q[k] * vec_Q[k];
            }

            // Combine the k_threads partial sums of squares into full-vector norms.
            PRAGMA_UNROLL
            for (int mask = k_threads / 2; mask > 0; mask /= 2) {
                k_sum += __shfl_xor_sync(0xffffffff, k_sum, mask);
                q_sum += __shfl_xor_sync(0xffffffff, q_sum, mask);
            }

            const float k_inv_norm = rsqrtf(k_sum + 1e-6f);
            const float q_inv_norm = rsqrtf(q_sum + 1e-6f);

            PRAGMA_UNROLL
            for (int i = 0; i < tile_k; ++i) {
                vec_K[i] = vec_K[i] * k_inv_norm;
                vec_Q[i] = vec_Q[i] * q_inv_norm;
            }
        }

        // Precompute KQ = dot(K, Q) — invariant across all v elements
        float KQ = 0.f;
        PRAGMA_UNROLL
        for (int k = 0; k < tile_k; ++k)
            KQ += vec_K[k] * vec_Q[k];
        PRAGMA_UNROLL
        for (int mask = k_threads / 2; mask > 0; mask /= 2)
            KQ += __shfl_xor_sync(0xffffffff, KQ, mask);

        Array vec_V[v_iters];

        PRAGMA_UNROLL
        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
            Load(vec_V[v_iter], &v_ptr[(offset_v + v_iter * v_threads) * tile_v]);
        }

        PRAGMA_UNROLL
        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
            Array vec_O;
            PRAGMA_UNROLL
            for (int v = 0; v < tile_v; ++v) {
                // Fused: decay + dual dot product (kv_mem and SQ simultaneously)
                float kv_mem = 0.f, SQ = 0.f;
                PRAGMA_UNROLL
                for (int k = 0; k < tile_k; ++k) {
                    float s_decayed     = vec_S[v_iter][k][v] * decay;
                    vec_S[v_iter][k][v] = s_decayed;
                    kv_mem += s_decayed * vec_K[k];
                    SQ += s_decayed * vec_Q[k];
                }

                // Single interleaved reduction (2 independent values -> good ILP)
                PRAGMA_UNROLL
                for (int mask = k_threads / 2; mask > 0; mask /= 2) {
                    kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);
                    SQ += __shfl_xor_sync(0xffffffff, SQ, mask);
                }

                const float delta = ((float)vec_V[v_iter][v] - kv_mem) * beta_val;

                // State update
                PRAGMA_UNROLL
                for (int k = 0; k < tile_k; ++k) {
                    vec_S[v_iter][k][v] += vec_K[k] * delta;
                }

                // Output: algebraic computation, NO reduction needed
                vec_O[v] = static_cast((SQ + delta * KQ) * scale);
            }
            // After the xor-shuffles every lane of the k-group holds the same values; one stores.
            if (offset_k == 0)
                Store(&o_ptr[(offset_v + v_iter * v_threads) * tile_v], vec_O);
        }
    }

    // NOTE(review): each thread writes back below only the smem cells it read in step 2 above,
    // so this barrier may be removable — confirm before changing.
    __syncthreads();

    // Store state, step 1: registers -> swizzled smem (converted back to S).
    PRAGMA_UNROLL
    for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
        PRAGMA_UNROLL
        for (int k = 0; k < tile_k; ++k) {
            constexpr int kTileAccessC = (tile_v >= kAccessC) ? kAccessC : tile_v;
            PRAGMA_UNROLL
            for (int c = 0; c < tile_v / kTileAccessC; ++c) {
                auto tmp = cast((Array&)vec_S[v_iter][k][c * kTileAccessC]);
                Store(&smem_S(offset_k * tile_k + k, (offset_v + v_iter * v_threads) * tile_v + c * kTileAccessC), tmp);
            }
        }
    }

    __syncthreads();

    // Store state, step 2: smem -> global using the Map_S thread map.
    PRAGMA_UNROLL
    for (int s = 0; s < Map_S::kIterS; ++s) {
        Array vec;
        PRAGMA_UNROLL
        for (int c = 0; c < Map_S::kIterC; ++c) {
            const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);
            const int final_vd  = vd + c * Map_S::kDeltaC;
            const int final_kd  = kd + s * Map_S::kDeltaS;
            Load(vec, &smem_S(final_kd, final_vd));
            Store(s_ptr + final_kd * v_head_dim + final_vd, vec);
        }
    }
}

// Launch recurrent_gated_delta_rule_kernel_v2: one block per (sequence b, value head h),
// i.e. batch_size * num_v_heads blocks. Sequences may have any length; token ranges come
// from q_offsets[b] .. q_offsets[b + 1].
//
// The recurrent state is kept in float when state_dtype == kFloat32, otherwise in the
// dtype of v_out. Head dims must both be 128. `sm_count` and `work_counter` exist only so
// the signature matches invokeGatedDeltaRuleBatched_v3; they are ignored here.
void invokeGatedDeltaRuleBatched_v2(Ref           v_out_,
                                    const Tensor&         qkv_in,
                                    const Tensor&         beta,
                                    const Tensor&         g,
                                    const Buffer_& state_ptrs,
                                    const Buffer_&   q_offsets,
                                    int                   batch_size,
                                    int                   num_k_heads,
                                    int                   state_layer_offset,
                                    DataType              state_dtype,
                                    int /*sm_count*/,
                                    int* /*work_counter*/,
                                    cudaStream_t stream)
{
    auto& v_out = v_out_.get();

    const int num_v_heads = beta.shape(1);
    const int v_dim       = v_out.shape(1);
    const int k_dim_total = (qkv_in.shape(1) - v_dim) / 2;

    // Must come before any division by num_v_heads.
    if (batch_size == 0 || num_v_heads == 0)
        return;

    constexpr int kHeadDim  = 128;
    constexpr int kBlockDim = 256;

    const int value_head_dim = v_dim / num_v_heads;

    TM_CHECK_EQ(value_head_dim, kHeadDim);
    // The kernel maps value head h to key head h / (num_v_heads / num_k_heads): the ratio must be
    // a positive integer, otherwise it divides by zero on device or reads past the K/Q heads.
    TM_CHECK(num_k_heads > 0);
    TM_CHECK_EQ(num_v_heads % num_k_heads, 0);
    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);

    const int num_blocks = batch_size * num_v_heads;

    auto invoke = [&](auto t) {
        using T     = decltype(t);
        auto launch = [&](auto s) {
            using S = decltype(s);

            auto kernel = recurrent_gated_delta_rule_kernel_v2;

            // Dynamic smem stages the whole kHeadDim x kHeadDim state; above the 48 KiB default
            // (e.g. fp32 state: 64 KiB) the kernel must opt in to the larger limit.
            const size_t smem_sz = kHeadDim * kHeadDim * sizeof(S);
            if (smem_sz > 48 << 10) {
                cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);
            }

            kernel<<>>(v_out.data(),
                                                               qkv_in.data(),
                                                               beta.data(),
                                                               g.data(),
                                                               (S* const*)state_ptrs.data(),
                                                               q_offsets.data(),
                                                               num_v_heads,
                                                               num_k_heads,
                                                               k_dim_total,
                                                               state_layer_offset);
        };
        if (state_dtype == kFloat32) {
            launch(float{});
        }
        else {
            launch(T{});
        }
    };
    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);
}

// =============================================================================
// Recurrent Gated Delta Rule — Persistent decode kernel (seq_len == 1 only).
//
// Designed for large-batch decode (e.g., bs=1024, 64 heads = 65536 work-items).
// Instead of launching one block per (b, h) pair, we launch only as many blocks
// as can be simultaneously resident (determined via the CUDA occupancy API), and
// each block iterates over multiple (b, h) work-items in a persistent loop.
// Work-items are claimed dynamically with atomicAdd on `work_counter`, which the
// host launcher resets to 0 before every launch.
//
// State is loaded/stored directly between global memory and registers (no smem
// staging). Each thread owns a [tile_k, tile_v] register tile and issues strided
// tile_v-element loads/stores (8 bytes for a 16-bit state, 16 bytes for fp32)
// directly to/from global memory. The only shared memory is the static
// `s_work_idx` broadcast slot, and the only block-wide barriers are the two per
// work-item that protect it.
//
// Only the token at q_offsets[b] is read: the kernel does not check seq_len and
// must only be used when every sequence in the batch has exactly one token.
// =============================================================================
template
__global__ __launch_bounds__(block_dim, 2) void recurrent_gated_delta_rule_kernel_v3(T*         v_out,
                                                                                     const T*   qkv_in,
                                                                                     const T*   beta_in,
                                                                                     const T*   g_in,
                                                                                     S* const*  state_ptrs,
                                                                                     const int* q_offsets,
                                                                                     int*       work_counter,
                                                                                     int        total_work,
                                                                                     int        num_v_heads,
                                                                                     int        num_k_heads,
                                                                                     int        k_dim_total,
                                                                                     int        state_layer_offset)
{
    constexpr int state_size = k_head_dim * v_head_dim;
    const int     conv_dim   = 2 * k_dim_total + num_v_heads * v_head_dim;
    const int     v_dim      = num_v_heads * v_head_dim;
    const float   scale      = rsqrtf((float)k_head_dim);

    // Compile-time thread partition (identical to v2)
    constexpr int tile_k    = 16;
    constexpr int tile_v    = 4;
    constexpr int k_tiles   = k_head_dim / tile_k;
    constexpr int v_tiles   = v_head_dim / tile_v;
    constexpr int k_threads = k_tiles;
    constexpr int v_threads = block_dim / k_threads;
    constexpr int v_iters   = cdiv(v_tiles, v_threads);

    const int offset_k = threadIdx.x % k_tiles;
    const int offset_v = threadIdx.x / k_tiles;

    // Persistent loop: each block atomically claims the next (b, h) work-item.
    // Thread 0 issues the atomic; result is broadcast to all threads via smem.
    __shared__ int s_work_idx;
    while (true) {
        if (threadIdx.x == 0)
            s_work_idx = atomicAdd(work_counter, 1);
        __syncthreads();  // make thread 0's write visible to the whole block
        const int work_idx = s_work_idx;
        // Every thread must have read s_work_idx before thread 0 overwrites it in the next
        // iteration. The loop body has no block-wide barrier, so without this a lagging warp
        // could read the *next* index: it would process the wrong (b, h) item, or break out of
        // the loop alone and leave the rest of the block at a divergent __syncthreads().
        __syncthreads();
        if (work_idx >= total_work)
            break;
        const int b     = work_idx / num_v_heads;
        const int h     = work_idx % num_v_heads;
        const int ratio = num_v_heads / num_k_heads;
        const int kh    = h / ratio;

        const int global_t = q_offsets[b];  // seq_len == 1 guaranteed

        S* s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;

        // --- Load state: global → registers (direct strided tile loads, tile_v contiguous) ---
        Array vec_S[v_iters][tile_k];

        PRAGMA_UNROLL
        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                Array tmp;
                Load(tmp, &s_ptr[(offset_k * tile_k + k) * v_head_dim + (offset_v + v_iter * v_threads) * tile_v]);
                vec_S[v_iter][k] = cast(tmp);
            }
        }

        // --- Process single token (seq_len == 1) ---
        {
            const T* q_ptr = qkv_in + global_t * conv_dim + kh * k_head_dim;
            const T* k_ptr = qkv_in + global_t * conv_dim + k_dim_total + kh * k_head_dim;
            const T* v_ptr = qkv_in + global_t * conv_dim + 2 * k_dim_total + h * v_head_dim;
            T*       o_ptr = v_out + global_t * v_dim + h * v_head_dim;

            const float beta_val = (float)beta_in[global_t * num_v_heads + h];
            const float decay    = expf((float)g_in[global_t * num_v_heads + h]);

            Array vec_K;
            Array vec_Q;

            // L2-normalize K and Q in registers
            {
                Array tmp_K;
                Array tmp_Q;
                Load(tmp_K, &k_ptr[offset_k * tile_k]);
                Load(tmp_Q, &q_ptr[offset_k * tile_k]);
                vec_K = cast(tmp_K);
                vec_Q = cast(tmp_Q);
            }

            float k_sum = 0.f, q_sum = 0.f;
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                k_sum += vec_K[k] * vec_K[k];
                q_sum += vec_Q[k] * vec_Q[k];
            }
            PRAGMA_UNROLL
            for (int mask = k_threads / 2; mask > 0; mask /= 2) {
                k_sum += __shfl_xor_sync(0xffffffff, k_sum, mask);
                q_sum += __shfl_xor_sync(0xffffffff, q_sum, mask);
            }
            const float k_inv_norm = rsqrtf(k_sum + 1e-6f);
            const float q_inv_norm = rsqrtf(q_sum + 1e-6f);
            PRAGMA_UNROLL
            for (int i = 0; i < tile_k; ++i) {
                vec_K[i] *= k_inv_norm;
                vec_Q[i] *= q_inv_norm;
            }

            // KQ dot product (invariant across v elements)
            float KQ = 0.f;
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k)
                KQ += vec_K[k] * vec_Q[k];
            PRAGMA_UNROLL
            for (int mask = k_threads / 2; mask > 0; mask /= 2)
                KQ += __shfl_xor_sync(0xffffffff, KQ, mask);

            Array vec_V[v_iters];
            PRAGMA_UNROLL
            for (int v_iter = 0; v_iter < v_iters; ++v_iter)
                Load(vec_V[v_iter], &v_ptr[(offset_v + v_iter * v_threads) * tile_v]);

            PRAGMA_UNROLL
            for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
                Array vec_O;
                PRAGMA_UNROLL
                for (int v = 0; v < tile_v; ++v) {
                    // Fused: decay + dual dot product (kv_mem and SQ simultaneously)
                    float kv_mem = 0.f, SQ = 0.f;
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k) {
                        float s_decayed     = vec_S[v_iter][k][v] * decay;
                        vec_S[v_iter][k][v] = s_decayed;
                        kv_mem += s_decayed * vec_K[k];
                        SQ += s_decayed * vec_Q[k];
                    }
                    PRAGMA_UNROLL
                    for (int mask = k_threads / 2; mask > 0; mask /= 2) {
                        kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);
                        SQ += __shfl_xor_sync(0xffffffff, SQ, mask);
                    }
                    const float delta = ((float)vec_V[v_iter][v] - kv_mem) * beta_val;
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k)
                        vec_S[v_iter][k][v] += vec_K[k] * delta;
                    // Equals (S_new^T q) * scale without a second reduction over k.
                    vec_O[v] = static_cast((SQ + delta * KQ) * scale);
                }
                if (offset_k == 0)
                    Store(&o_ptr[(offset_v + v_iter * v_threads) * tile_v], vec_O);
            }
        }

        // --- Store state: registers → global (direct strided tile stores, tile_v contiguous) ---
        PRAGMA_UNROLL
        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                auto tmp = cast(vec_S[v_iter][k]);
                Store(&s_ptr[(offset_k * tile_k + k) * v_head_dim + (offset_v + v_iter * v_threads) * tile_v], tmp);
            }
        }
    }
}

// Launch the persistent decode kernel recurrent_gated_delta_rule_kernel_v3 (one token per
// sequence). The grid is sized to the number of blocks that can be co-resident
// (occupancy * sm_count), capped at the number of (b, h) work-items. `work_counter` is a
// device int used by the kernel to claim work; it is reset to 0 on `stream` before launch.
// The recurrent state is kept in float when state_dtype == kFloat32, otherwise in the dtype
// of v_out. Head dims must both be 128.
void invokeGatedDeltaRuleBatched_v3(Ref           v_out_,
                                    const Tensor&         qkv_in,
                                    const Tensor&         beta,
                                    const Tensor&         g,
                                    const Buffer_& state_ptrs,
                                    const Buffer_&   q_offsets,
                                    int                   batch_size,
                                    int                   num_k_heads,
                                    int                   state_layer_offset,
                                    DataType              state_dtype,
                                    int                   sm_count,
                                    int*                  work_counter,
                                    cudaStream_t          stream)
{
    auto& v_out = v_out_.get();

    const int num_v_heads = beta.shape(1);
    const int v_dim       = v_out.shape(1);
    const int k_dim_total = (qkv_in.shape(1) - v_dim) / 2;

    if (batch_size == 0 || num_v_heads == 0)
        return;

    constexpr int kHeadDim  = 128;
    constexpr int kBlockDim = 256;

    TM_CHECK_EQ(v_dim / num_v_heads, kHeadDim);
    // The kernel maps value head h to key head h / (num_v_heads / num_k_heads): the ratio must be
    // a positive integer, otherwise it divides by zero on device or reads past the K/Q heads.
    TM_CHECK(num_k_heads > 0);
    TM_CHECK_EQ(num_v_heads % num_k_heads, 0);
    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);
    // The persistent kernel claims work through this device counter.
    TM_CHECK(work_counter != nullptr);

    const int total_work = batch_size * num_v_heads;

    auto invoke = [&](auto t) {
        using T     = decltype(t);
        auto launch = [&](auto s) {
            using S = decltype(s);

            auto         kernel        = recurrent_gated_delta_rule_kernel_v3;
            const size_t smem_sz       = sizeof(int);  // s_work_idx
            int          blocks_per_sm = 1;
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, kBlockDim, smem_sz);
            // Launch at least one block: a zero-sized grid (sm_count == 0, or an occupancy result of
            // 0) is an invalid launch and would leave v_out and the state untouched. The persistent
            // loop lets any grid of >= 1 block cover every work-item.
            const int grid_blocks = max(1, min(total_work, blocks_per_sm * sm_count));

            cudaMemsetAsync(work_counter, 0, sizeof(int), stream);
            kernel<<>>(v_out.data(),
                                                                qkv_in.data(),
                                                                beta.data(),
                                                                g.data(),
                                                                (S* const*)state_ptrs.data(),
                                                                q_offsets.data(),
                                                                work_counter,
                                                                total_work,
                                                                num_v_heads,
                                                                num_k_heads,
                                                                k_dim_total,
                                                                state_layer_offset);
        };
        if (state_dtype == kFloat32)
            launch(float{});
        else
            launch(T{});
    };
    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);
}

// =============================================================================
// Chunked Gated Delta Rule kernel — register-centric, small chunk size.
//
// Grid = batch_size * num_v_heads blocks, one block per (b, h) pair.
// Cooperative QKV load to smem per chunk, then sequential per-token
// processing (same recurrence as v2) reading from smem.
// State load/store uses the full swizzled smem buffer (same as v2).
// =============================================================================
template
__global__ void chunked_gated_delta_rule_kernel(T*         v_out,
                                                const T*   qkv_in,
                                                const T*   beta_in,
                                                const T*   g_in,
                                                S* const*  state_ptrs,
                                                const int* q_offsets,
                                                int        num_v_heads,
                                                int        num_k_heads,
                                                int        k_dim_total,
                                                int        state_layer_offset)
{
    constexpr int C = kChunkSize;
    constexpr int D = kHeadDim;

    const int bh    = blockIdx.x;
    const int b     = bh / num_v_heads;
    const int h     = bh % num_v_heads;
    const int ratio = num_v_heads / num_k_heads;
    const int kh    = h / ratio;

    const int tok_off    = q_offsets[b];
    const int seq_len    = q_offsets[b + 1] - tok_off;
    const int state_size = D * D;
    const int conv_dim   = 2 * k_dim_total + num_v_heads * D;
    const int v_dim      = num_v_heads * D;

    if (seq_len == 0)
        return;

    S*          s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;
    const float scale = rsqrtf((float)D);

    // ── State tiling (same as v2) ──
    constexpr int tile_k    = 8;
    constexpr int tile_v    = 8;
    constexpr int k_tiles   = D / tile_k;                // 16
    constexpr int k_threads = k_tiles;                   // 16
    constexpr int v_threads = kBlockDim / k_threads;     // 16
    constexpr int v_tiles   = D / tile_v;                // 16
    constexpr int v_iters   = cdiv(v_tiles, v_threads);  // 1

    const int offset_k = threadIdx.x % k_threads;
    const int offset_v = threadIdx.x / k_threads;

    Array vec_S[v_iters][tile_k];

    extern __shared__ __align__(16) char smem_buf[];

    // ================================================================
    //  LOAD STATE  global → smem (swizzled) → registers   (same as v2)
    // ================================================================
    {
        using Map_S          = ThreadMap_V2;
        constexpr int kBase  = (sizeof(S) == 4) ? 2 : 3;
        constexpr int kShift = 10 - kBase;
        using Layout         = SmemLayoutV2>;
        SmemAccessor smem_S{(S*)smem_buf};

        const int     warp_id  = threadIdx.x / WARP_SIZE;
        const int     lane_id  = threadIdx.x % WARP_SIZE;
        constexpr int kAccessC = Map_S::kAccessC;

        PRAGMA_UNROLL
        for (int s = 0; s < Map_S::kIterS; ++s) {
            Array vec;
            PRAGMA_UNROLL
            for (int c = 0; c < Map_S::kIterC; ++c) {
                const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);
                const int fvd       = vd + c * Map_S::kDeltaC;
                const int fkd       = kd + s * Map_S::kDeltaS;
                Load(vec, s_ptr + fkd * D + fvd);
                Store(&smem_S(fkd, fvd), vec);
            }
        }
        __syncthreads();

        PRAGMA_UNROLL
        for (int vi = 0; vi < v_iters; ++vi) {
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                static_assert(tile_v % Map_S::kAccessC == 0);
                PRAGMA_UNROLL
                for (int c = 0; c < tile_v / Map_S::kAccessC; ++c) {
                    Array tmp;
                    Load(tmp,
                         &smem_S(offset_k * tile_k + k, (offset_v + vi * v_threads) * tile_v + c * Map_S::kAccessC));
                    (Array&)vec_S[vi][k][c * Map_S::kAccessC] = cast(tmp);
                }
            }
        }
    }
    __syncthreads();

    // ================================================================
    //  CHUNK PROCESSING  — sequential per-token (same as v2) with
    //  smem-cached QKV.  Eliminates resolvent/intra-attention overhead.
    // ================================================================
    // Shared memory layout for chunk processing (overlaps state staging buffer):
    //   k_norm_smem[C][kSmemStride]  — pre-normalized K
    //   q_norm_smem[C][kSmemStride]  — pre-normalized Q
    //   v_smem[C][kSmemStride]       — raw V (as float)
    //   scalars[3*C]                 — beta[C], g[C], scratch[C]
    constexpr int kSmemStride = D + 4;  // pad rows by 4 to avoid 4-way bank conflicts

    float* k_norm_smem = (float*)smem_buf;
    float* q_norm_smem = k_norm_smem + C * kSmemStride;
    float* v_smem      = q_norm_smem + C * kSmemStride;
    float* beta_vals   = v_smem + C * kSmemStride;
    float* g_vals      = beta_vals + C;

    // Thread-to-token mapping for cooperative loads: 1 warp per token
    constexpr int kThreadsPerTok = kBlockDim / C;                 // 256/8 = 32
    constexpr int kElemsPerThr   = D / kThreadsPerTok;            // 128/32 = 4
    const int     load_tok       = threadIdx.x / kThreadsPerTok;  // which token (0..C-1)
    const int     load_lane      = threadIdx.x % kThreadsPerTok;  // lane within token's warp

    const int num_chunks = (seq_len + C - 1) / C;

    for (int ci = 0; ci < num_chunks; ++ci) {
        const int chunk_start = tok_off + ci * C;
        const int valid_len   = min(C, seq_len - ci * C);

        // ────────────────────────────────────────────────────
        //  Phase 0: Cooperative load K, Q, V → smem (pre-normalized)
        //  32 threads (1 warp) per token, 4 elements per thread.
        //  Norms computed via warp shuffle, K/Q normalized in registers
        //  before writing to smem → eliminates one __syncthreads.
        // ────────────────────────────────────────────────────
        {
            float K_reg[kElemsPerThr], Q_reg[kElemsPerThr];
            float k_sq = 0.f, q_sq = 0.f;
            if (load_tok < valid_len) {
                const int gt    = chunk_start + load_tok;
                const T*  k_ptr = qkv_in + gt * conv_dim + k_dim_total + kh * D;
                const T*  q_ptr = qkv_in + gt * conv_dim + kh * D;
                const T*  v_ptr = qkv_in + gt * conv_dim + 2 * k_dim_total + h * D;
                PRAGMA_UNROLL
                for (int e = 0; e < kElemsPerThr; ++e) {
                    const int d = load_lane * kElemsPerThr + e;
                    K_reg[e]    = (float)k_ptr[d];
                    Q_reg[e]    = (float)q_ptr[d];
                    k_sq += K_reg[e] * K_reg[e];
                    q_sq += Q_reg[e] * Q_reg[e];
                    v_smem[load_tok * kSmemStride + d] = (float)v_ptr[d];
                }
                if (load_lane == 0) {
                    beta_vals[load_tok] = (float)beta_in[gt * num_v_heads + h];
                    g_vals[load_tok]    = (float)g_in[gt * num_v_heads + h];
                }
            }
            else {
                PRAGMA_UNROLL
                for (int e = 0; e < kElemsPerThr; ++e) {
                    K_reg[e]                                                      = 0.f;
                    Q_reg[e]                                                      = 0.f;
                    v_smem[load_tok * kSmemStride + load_lane * kElemsPerThr + e] = 0.f;
                }
            }
            // Warp-reduce norms (32-thread warp per token)
            PRAGMA_UNROLL
            for (int mask = kThreadsPerTok / 2; mask > 0; mask >>= 1) {
                k_sq += __shfl_xor_sync(0xffffffff, k_sq, mask);
                q_sq += __shfl_xor_sync(0xffffffff, q_sq, mask);
            }
            const float k_inv = (load_tok < valid_len) ? rsqrtf(k_sq + 1e-6f) : 0.f;
            const float q_inv = (load_tok < valid_len) ? rsqrtf(q_sq + 1e-6f) : 0.f;
            // Write normalized K, Q to smem
            PRAGMA_UNROLL
            for (int e = 0; e < kElemsPerThr; ++e) {
                const int d                             = load_lane * kElemsPerThr + e;
                k_norm_smem[load_tok * kSmemStride + d] = K_reg[e] * k_inv;
                q_norm_smem[load_tok * kSmemStride + d] = Q_reg[e] * q_inv;
            }
        }
        __syncthreads();  // [sync 1] all smem data ready

        // ────────────────────────────────────────────────────
        //  Sequential per-token loop (same computation as v2)
        //  Reads K, Q, V from smem instead of global memory.
        // ────────────────────────────────────────────────────
        PRAGMA_UNROLL
        for (int t = 0; t < C; ++t) {
            if (t >= valid_len)
                break;

            const int   gt       = chunk_start + t;
            const float beta_val = beta_vals[t];
            const float decay    = expf(g_vals[t]);

            float vec_K[tile_k];
            float vec_Q[tile_k];
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                vec_K[k] = k_norm_smem[t * kSmemStride + offset_k * tile_k + k];
                vec_Q[k] = q_norm_smem[t * kSmemStride + offset_k * tile_k + k];
            }

            PRAGMA_UNROLL
            for (int vi = 0; vi < v_iters; ++vi) {
                const int v_base = (offset_v + vi * v_threads) * tile_v;

                float vec_V[tile_v];
                PRAGMA_UNROLL
                for (int v = 0; v < tile_v; ++v)
                    vec_V[v] = v_smem[t * kSmemStride + v_base + v];

                Array vec_O;
                PRAGMA_UNROLL
                for (int v = 0; v < tile_v; ++v) {
                    // Step 1: state *= decay
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k)
                        vec_S[vi][k][v] *= decay;

                    // Step 2: delta rule update
                    float kv_mem = 0.f;
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k)
                        kv_mem += vec_S[vi][k][v] * vec_K[k];
                    PRAGMA_UNROLL
                    for (int mask = k_threads / 2; mask > 0; mask /= 2)
                        kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);
                    const float delta = (vec_V[v] - kv_mem) * beta_val;
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k)
                        vec_S[vi][k][v] += vec_K[k] * delta;

                    // Step 3: output = (S^T @ q) * scale
                    float O = 0.f;
                    PRAGMA_UNROLL
                    for (int k = 0; k < tile_k; ++k)
                        O += vec_S[vi][k][v] * vec_Q[k];
                    PRAGMA_UNROLL
                    for (int mask = k_threads / 2; mask > 0; mask /= 2)
                        O += __shfl_xor_sync(0xffffffff, O, mask);
                    vec_O[v] = static_cast(O * scale);
                }
                if (offset_k == 0)
                    Store(&v_out[gt * v_dim + h * D + v_base], vec_O);
            }
        }
        __syncthreads();  // [sync 2] ensure all reads done before next chunk overwrites smem
    }                     // chunk loop

    // ================================================================
    //  STORE STATE  registers → smem (swizzled) → global   (same as v2)
    // ================================================================
    {
        using Map_S          = ThreadMap_V2;
        constexpr int kBase  = (sizeof(S) == 4) ? 2 : 3;
        constexpr int kShift = 10 - kBase;
        using Layout         = SmemLayoutV2>;
        SmemAccessor smem_S{(S*)smem_buf};
        constexpr int           kAccessC = Map_S::kAccessC;

        PRAGMA_UNROLL
        for (int vi = 0; vi < v_iters; ++vi) {
            PRAGMA_UNROLL
            for (int k = 0; k < tile_k; ++k) {
                PRAGMA_UNROLL
                for (int c = 0; c < tile_v / kAccessC; ++c) {
                    auto tmp = cast((Array&)vec_S[vi][k][c * kAccessC]);
                    Store(&smem_S(offset_k * tile_k + k, (offset_v + vi * v_threads) * tile_v + c * kAccessC), tmp);
                }
            }
        }
        __syncthreads();

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        PRAGMA_UNROLL
        for (int s = 0; s < Map_S::kIterS; ++s) {
            Array vec;
            PRAGMA_UNROLL
            for (int c = 0; c < Map_S::kIterC; ++c) {
                const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);
                const int fvd       = vd + c * Map_S::kDeltaC;
                const int fkd       = kd + s * Map_S::kDeltaS;
                Load(vec, &smem_S(fkd, fvd));
                Store(s_ptr + fkd * D + fvd, vec);
            }
        }
    }
}

// Host-side launcher
void invokeChunkedGatedDeltaRuleBatched(Ref           v_out_,
                                        const Tensor&         qkv_in,
                                        const Tensor&         beta,
                                        const Tensor&         g,
                                        const Buffer_& state_ptrs,
                                        const Buffer_&   q_offsets,
                                        int                   batch_size,
                                        int                   num_k_heads,
                                        int                   state_layer_offset,
                                        DataType              state_dtype,
                                        int /*sm_count*/,
                                        int* /*work_counter*/,
                                        cudaStream_t stream)
{
    auto& v_out = v_out_.get();

    const int num_v_heads    = beta.shape(1);
    const int v_dim          = v_out.shape(1);
    const int value_head_dim = v_dim / num_v_heads;
    const int k_dim_total    = (qkv_in.shape(1) - v_dim) / 2;

    if (batch_size == 0 || num_v_heads == 0)
        return;

    constexpr int kHeadDim   = 128;
    constexpr int kChunkSize = 16;
    constexpr int kBlockDim  = 256;

    TM_CHECK_EQ(value_head_dim, kHeadDim);
    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);

    const int num_blocks = batch_size * num_v_heads;

    auto invoke = [&](auto t) {
        using T     = decltype(t);
        auto launch = [&](auto s) {
            using S = decltype(s);

            auto kernel = chunked_gated_delta_rule_kernel;

            // smem = max(state staging, chunk working buffers)
            // State staging: D*D*sizeof(S) (64KB for fp32)
            // Chunk buffers: QKV cache [3*C*(D+4)] + scalars[2*C]
            const size_t state_smem  = kHeadDim * kHeadDim * sizeof(S);
            const int    kSmemStride = kHeadDim + 4;
            const size_t chunk_smem  = 3 * kChunkSize * kSmemStride * sizeof(float)  // k_norm, q_norm, v
                                      + 2 * kChunkSize * sizeof(float);              // beta, g
            const size_t smem_sz = state_smem > chunk_smem ? state_smem : chunk_smem;

            if (smem_sz > 48 << 10) {
                cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);
            }

            kernel<<>>(v_out.data(),
                                                               qkv_in.data(),
                                                               beta.data(),
                                                               g.data(),
                                                               (S* const*)state_ptrs.data(),
                                                               q_offsets.data(),
                                                               num_v_heads,
                                                               num_k_heads,
                                                               k_dim_total,
                                                               state_layer_offset);
        };
        if (state_dtype == kFloat32) {
            launch(float{});
        }
        else {
            launch(T{});
        }
    };
    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);
}

template
__global__ void compute_beta_g_kernel_v2(T*       beta_out,
                                         T*       g_out,
                                         const T* b_in,
                                         int      b_stride,
                                         const T* a_in,
                                         int      a_stride,
                                         const T* A_log,
                                         const T* dt_bias,
                                         int      total,
                                         int      num_v_heads)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= total)
        return;

    const int hi = idx % num_v_heads;
    const int ti = idx / num_v_heads;

    float b_val       = static_cast(b_in[ti * b_stride + hi]);
    float a_val       = static_cast(a_in[ti * a_stride + hi]);
    float A_log_val   = static_cast(A_log[hi]);
    float dt_bias_val = static_cast(dt_bias[hi]);

    float beta  = 1.0f / (1.0f + expf(-b_val));
    float sum   = a_val + dt_bias_val;
    float sp    = sum > 20.0f ? sum : logf(1.0f + expf(sum));
    float g_val = -expf(A_log_val) * sp;

    beta_out[idx] = static_cast(beta);
    g_out[idx]    = static_cast(g_val);
}

void ComputeBetaG_v2(Ref   beta_out_,
                     Ref   g_out_,
                     const Tensor& b_in,
                     const Tensor& a_in,
                     const Tensor& A_log,
                     const Tensor& dt_bias,
                     cudaStream_t  stream)
{

    auto& beta_out = beta_out_.get();
    auto& g_out    = g_out_.get();

    const int threads = 256;
    const int blocks  = cdiv(beta_out.size(), threads);

    auto invoke = [&](auto t) {
        using T = decltype(t);
        compute_beta_g_kernel_v2<<>>(beta_out.data(),
                                                                 g_out.data(),
                                                                 b_in.data(),
                                                                 b_in.stride(0),
                                                                 a_in.data(),
                                                                 a_in.stride(0),
                                                                 A_log.data(),
                                                                 dt_bias.data(),
                                                                 beta_out.size(),
                                                                 A_log.size());
    };

    TM_DISPATCH_PRIMARY_DTYPES(beta_out.dtype(), invoke);
}

// =============================================================================
// RMSNorm * SiLU-Gate (fused output normalization)
// =============================================================================
template
__global__ void rms_norm_gated_kernel(
    T* hidden, const T* gate, const T* weight, float eps, int N, int head_dim, int gate_stride, int num_heads)
{
    const int row = blockIdx.x;
    if (row >= N)
        return;

    T*        h         = hidden + row * head_dim;
    const int token_idx = row / num_heads;
    const int head_idx  = row % num_heads;
    const T*  g         = gate + token_idx * gate_stride + head_idx * head_dim;

    __shared__ float smem[32];
    float            sum_sq = 0.0f;
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        float val = static_cast(h[d]);
        sum_sq += val * val;
    }
    for (int mask = 16; mask > 0; mask >>= 1)
        sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);
    if ((threadIdx.x & 31) == 0)
        smem[threadIdx.x >> 5] = sum_sq;
    __syncthreads();
    if (threadIdx.x >> 5 == 0) {
        sum_sq = (threadIdx.x < (blockDim.x + 31) / 32) ? smem[threadIdx.x] : 0.0f;
        for (int mask = 16; mask > 0; mask >>= 1)
            sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);
        if (threadIdx.x == 0)
            smem[0] = sum_sq;
    }
    __syncthreads();
    sum_sq = smem[0];

    float inv_rms = rsqrtf(sum_sq / (float)head_dim + eps);
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        float h_val  = static_cast(h[d]) * inv_rms * static_cast(weight[d]);
        float g_val  = static_cast(g[d]);
        float silu_g = g_val / (1.0f + expf(-g_val));
        h[d]         = static_cast(h_val * silu_g);
    }
}

void invokeRMSNormGated(Ref hidden_, const Tensor& gate, const Tensor& weight, float eps, cudaStream_t stream)
{
    auto& hidden = hidden_.get();

    const int N           = hidden.shape(0);
    const int head_dim    = hidden.shape(1);
    const int token_num   = gate.shape(0);
    const int gate_stride = gate.stride(0);
    const int num_heads   = N / token_num;

    if (N == 0)
        return;

    const int threads = std::min(256, head_dim);

    auto invoke = [&](auto t) {
        using T = decltype(t);
        rms_norm_gated_kernel<<>>(
            hidden.data(), gate.data(), weight.data(), eps, N, head_dim, gate_stride, num_heads);
    };
    TM_DISPATCH_PRIMARY_DTYPES(hidden.dtype(), invoke);
}

// =============================================================================
// Fused Conv1d + SiLU — persistent batched kernel
//
// Weight layout: [d_conv, conv_dim], State layout: [d_conv, conv_dim] per batch.
//
// Persistent 1D grid. Each block has a fixed channel tile
// (blockIdx.x % num_ch_tiles) and atomically claims single-token work items
// via a global counter. Token-major work ordering with grid size a multiple of
// num_ch_tiles guarantees monotonically increasing tokens and a fixed channel
// tile per block.
// =============================================================================
template
__global__ void __launch_bounds__(BLOCK_DIM) fused_conv1d_batched_kernel_v2(T*           out,
                                                                            const T*     in,
                                                                            const T*     weight,
                                                                            const T*     bias,
                                                                            void* const* conv_state_ptrs,
                                                                            const int*   q_offsets,
                                                                            const int*   k_offsets,
                                                                            int*         work_counter,
                                                                            int          batch_size,
                                                                            int          conv_dim,
                                                                            int          in_stride,
                                                                            int          num_token_tiles,
                                                                            int          state_layer_offset,
                                                                            int          total_work,
                                                                            int          num_ch_tiles)
{
    static_assert(BLOCK_DIM * CHANNELS_PER_THREAD > 0);

    int prev_ch_tile = -1;
    int c_base       = 0;

    Array w_tap[D_CONV];
    Array bias_vals;

    __shared__ int  s_work_id;
    __shared__ int4 s_batch_info;
    int             b_start = 0;

    while (true) {
        if (threadIdx.x == 0)
            s_work_id = atomicAdd(work_counter, 1);
        __syncthreads();

        if (s_work_id >= total_work)
            break;

        const int t_tile  = s_work_id % num_token_tiles;
        const int ch_tile = s_work_id / num_token_tiles;

        if (ch_tile != prev_ch_tile) {
            prev_ch_tile = ch_tile;
            b_start      = 0;
        }

        c_base = (ch_tile * BLOCK_DIM + threadIdx.x) * CHANNELS_PER_THREAD;

        const bool ch_active = (c_base < conv_dim);

        if (ch_active) {
            PRAGMA_UNROLL
            for (int d = 0; d < D_CONV; ++d) {
                Load(w_tap[d], weight + d * conv_dim + c_base);
            }
            if (bias)
                Load(bias_vals, bias + c_base);
        }

        if constexpr (NUM_TOKENS == 1) {
            for (int b = b_start + threadIdx.x; b < batch_size; b += BLOCK_DIM) {
                int lo = __ldg(&q_offsets[b]);
                if (lo > t_tile)
                    break;
                int hi = __ldg(&q_offsets[b + 1]);
                if (t_tile < hi) {
                    int seq      = hi - lo;
                    int hist     = (__ldg(&k_offsets[b + 1]) - __ldg(&k_offsets[b])) - seq;
                    s_batch_info = make_int4(b, lo, seq, hist);
                }
            }
        }
        else {
            for (int b = b_start + threadIdx.x; b < batch_size; b += BLOCK_DIM) {
                int tile_off = __ldg(&q_offsets[b]) / NUM_TOKENS + b;
                if (tile_off > t_tile)
                    break;
                int tile_off_next = __ldg(&q_offsets[b + 1]) / NUM_TOKENS + b + 1;
                if (t_tile < tile_off_next) {
                    int lo       = __ldg(&q_offsets[b]);
                    int seq      = __ldg(&q_offsets[b + 1]) - lo;
                    int hist     = (__ldg(&k_offsets[b + 1]) - __ldg(&k_offsets[b])) - seq;
                    s_batch_info = make_int4(b, lo, seq, hist);
                }
            }
        }
        __syncthreads();

        b_start = s_batch_info.x;

        const int4 bi          = s_batch_info;
        const int  b           = bi.x;
        const int  seq_off     = bi.y;
        const int  seq_len     = bi.z;
        const int  history_len = bi.w;

        int t_local_start;
        int n_tokens;
        if constexpr (NUM_TOKENS == 1) {
            t_local_start = t_tile - seq_off;
            n_tokens      = 1;
        }
        else {
            const int tile_off_b = seq_off / NUM_TOKENS + b;
            t_local_start        = (t_tile - tile_off_b) * NUM_TOKENS;
            if (t_local_start >= seq_len)
                continue;
            n_tokens = min(NUM_TOKENS, seq_len - t_local_start);
        }

        const int ring_start = (history_len + t_local_start + 1) % D_CONV;
        T*        state_base = (T*)conv_state_ptrs[b] + state_layer_offset;

        if (ch_active) {
            constexpr int                 VALS_SIZE = NUM_TOKENS + D_CONV - 1;
            Array vals[VALS_SIZE];
            const int                     n_vals = n_tokens + D_CONV - 1;

            PRAGMA_UNROLL
            for (int i = 0; i < VALS_SIZE; ++i) {
                if (i < n_vals) {
                    int pos = t_local_start - (D_CONV - 1) + i;
                    if (pos >= 0) {
                        Load(vals[i], in + (seq_off + pos) * in_stride + c_base);
                    }
                    else {
                        int ring_d = (ring_start + i) % D_CONV;
                        Load(vals[i], state_base + ring_d * conv_dim + c_base);
                    }
                }
            }

            PRAGMA_UNROLL
            for (int tok = 0; tok < NUM_TOKENS; ++tok) {
                if (tok < n_tokens) {
                    float acc[CHANNELS_PER_THREAD] = {};
                    PRAGMA_UNROLL
                    for (int d = 0; d < D_CONV; ++d) {
                        PRAGMA_UNROLL
                        for (int ch = 0; ch < CHANNELS_PER_THREAD; ++ch) {
                            acc[ch] += static_cast(vals[tok + d][ch]) * static_cast(w_tap[d][ch]);
                        }
                    }

                    Array out_vals;
                    PRAGMA_UNROLL
                    for (int ch = 0; ch < CHANNELS_PER_THREAD; ++ch) {
                        if (bias)
                            acc[ch] += static_cast(bias_vals[ch]);
                        out_vals[ch] = static_cast(acc[ch] / (1.0f + expf(-acc[ch])));
                    }

                    Store(out + (seq_off + t_local_start + tok) * conv_dim + c_base, out_vals);
                }
            }

            if (t_local_start + n_tokens >= seq_len) {
                PRAGMA_UNROLL
                for (int i = 0; i < VALS_SIZE; ++i) {
                    int pos = t_local_start - (D_CONV - 1) + i;
                    if (pos >= 0 && pos >= seq_len - D_CONV && pos < seq_len) {
                        int ring_d = (ring_start + i) % D_CONV;
                        Store(state_base + ring_d * conv_dim + c_base, vals[i]);
                    }
                }
            }
        }
    }
}

void invokeFusedConv1dSiLU(Ref           out_,
                           const Tensor&         in,
                           const Tensor&         weight,
                           const Tensor&         bias,
                           const Buffer_& conv_state_ptrs,
                           const Buffer_&   q_offsets,
                           const Buffer_&   k_offsets,
                           int                   batch_size,
                           int                   state_layer_offset,
                           int                   sm_count,
                           int*                  work_counter,
                           cudaStream_t          stream)
{
    auto& out = out_.get();

    const int total_tokens = in.shape(0);
    const int d_conv       = weight.shape(0);
    const int conv_dim     = weight.shape(1);
    const int in_stride    = in.stride(0);

    constexpr int threads = 128;

    auto invoke = [&](auto t) {
        using T = decltype(t);
        if (d_conv == 4) {
            constexpr int kDConv     = 4;
            constexpr int kChPerT    = 8;
            const int     ch_per_blk = threads * kChPerT;
            TM_CHECK(conv_dim % kChPerT == 0);
            const int num_ch_tiles = cdiv(conv_dim, ch_per_blk);

            auto launch = [&](auto num_tok_tag) {
                constexpr int kNumTok         = decltype(num_tok_tag)::value;
                const int     num_token_tiles = (kNumTok == 1) ? total_tokens : total_tokens / kNumTok + batch_size;
                const int     total_work      = num_token_tiles * num_ch_tiles;

                auto kernel        = fused_conv1d_batched_kernel_v2;
                int  blocks_per_sm = 1;
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, threads, 0);
                int grid = min(total_work, blocks_per_sm * sm_count);

                cudaMemsetAsync(work_counter, 0, sizeof(int), stream);
                kernel<<>>(out.data(),
                                                     in.data(),
                                                     weight.data(),
                                                     bias ? bias.data() : (T*)nullptr,
                                                     conv_state_ptrs.data(),
                                                     q_offsets.data(),
                                                     k_offsets.data(),
                                                     work_counter,
                                                     batch_size,
                                                     conv_dim,
                                                     in_stride,
                                                     num_token_tiles,
                                                     state_layer_offset,
                                                     total_work,
                                                     num_ch_tiles);
            };

            int avg_seq = total_tokens / batch_size;
            if (avg_seq >= 4)
                launch(std::integral_constant{});
            else
                launch(std::integral_constant{});
        }
        else {
            TM_CHECK(0) << "Only d_conv == 4 is supported by fused_conv1d_batched_kernel_v2";
        }
    };
    TM_DISPATCH_PRIMARY_DTYPES(out.dtype(), invoke);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/gated_delta_net_kernels.h
================================================
#pragma once

#include 
#include 
#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Fused Conv1d + SiLU — unified batched launcher (row-major layout).
//
// Processes all requests in a single kernel launch.  Decode (seq_len == 1)
// and prefill (seq_len > 1) requests may be mixed freely within the batch.
//
// out:             (total_tokens, conv_dim)       row-major output
// in:              (total_tokens, in_stride)      non-contiguous slice of all_proj
// weight:          (d_conv, conv_dim)
// bias:            (conv_dim) or empty Tensor
// conv_state_ptrs: device array[batch_size] of per-request state pointers
// q_offsets:       device int[batch_size+1] cumulative token offsets
// k_offsets:       device int[batch_size+1] cumulative key (history+input) offsets
void invokeFusedConv1dSiLU(Ref           out,
                           const Tensor&         in,
                           const Tensor&         weight,
                           const Tensor&         bias,
                           const Buffer_& conv_state_ptrs,
                           const Buffer_&   q_offsets,
                           const Buffer_&   k_offsets,
                           int                   batch_size,
                           int                   state_layer_offset,
                           int                   sm_count,
                           int*                  work_counter,
                           cudaStream_t          stream);

// All three recurrent-rule launchers share the same trailing parameters for
// interface consistency:
//   sm_count      — multiprocessor count, queried once by the caller at init
//   work_counter  — device int* (1 element), owned by caller; v3 uses it for
//                   atomic workload claiming, v2/chunked ignore it
//   stream        — CUDA stream
//
// v2: standard one-block-per-(b,h) grid launch; sm_count and work_counter ignored.
void invokeGatedDeltaRuleBatched_v2(Ref           v_out,
                                    const Tensor&         qkv_in,
                                    const Tensor&         beta,
                                    const Tensor&         g,
                                    const Buffer_& state_ptrs,
                                    const Buffer_&   q_offsets,
                                    int                   batch_size,
                                    int                   num_k_heads,
                                    int                   state_layer_offset,
                                    DataType              state_dtype,
                                    int                   sm_count,
                                    int*                  work_counter,
                                    cudaStream_t          stream);

// v3: persistent decode kernel, seq_len == 1 only.
// Launches min(total_work, blocks_per_sm * sm_count) blocks; each block claims
// work items atomically via work_counter (zeroed via cudaMemsetAsync per launch).
// state_dtype controls state precision: kFloat32 → S=float, otherwise S=T.
void invokeGatedDeltaRuleBatched_v3(Ref           v_out,
                                    const Tensor&         qkv_in,
                                    const Tensor&         beta,
                                    const Tensor&         g,
                                    const Buffer_& state_ptrs,
                                    const Buffer_&   q_offsets,
                                    int                   batch_size,
                                    int                   num_k_heads,
                                    int                   state_layer_offset,
                                    DataType              state_dtype,
                                    int                   sm_count,
                                    int*                  work_counter,
                                    cudaStream_t          stream);

// =============================================================================
// Chunked Gated Delta Rule — for accelerating prefill
//
// Processes sequences in chunks of size C (default 64), parallelizing
// intra-chunk computation while maintaining sequential inter-chunk state
// updates. Reduces sequential depth from L to L/C.
//
// Same tensor layouts as invokeGatedDeltaRuleBatched_v2.
// sm_count and work_counter accepted for interface parity; ignored internally.
void invokeChunkedGatedDeltaRuleBatched(Ref           v_out,
                                        const Tensor&         qkv_in,
                                        const Tensor&         beta,
                                        const Tensor&         g,
                                        const Buffer_& state_ptrs,
                                        const Buffer_&   q_offsets,
                                        int                   batch_size,
                                        int                   num_k_heads,
                                        int                   state_layer_offset,
                                        DataType              state_dtype,
                                        int                   sm_count,
                                        int*                  work_counter,
                                        cudaStream_t          stream);

// =============================================================================
// Helper kernels
// =============================================================================

void ComputeBetaG_v2(Ref   beta_out_,
                     Ref   g_out_,
                     const Tensor& b_in,
                     const Tensor& a_in,
                     const Tensor& A_log,
                     const Tensor& dt_bias,
                     cudaStream_t  stream);

// RMSNorm * SiLU-gate (fused output normalization)
void invokeRMSNormGated(Ref hidden, const Tensor& gate, const Tensor& weight, float eps, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_kernels.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 

#include 
#include 

#include "src/turbomind/kernels/core/array.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/dispatch.h"

namespace turbomind {

// Gathers the token ids of one sequence per block from the step-major `ids` buffer
// (element [step * batch_size + b]) into row b of `output_ids` (row length max_output_len).
// Steps in [context_length[b], max_context_len) are padding and are dropped, so generated
// tokens are shifted left to directly follow the prompt. Writes past max_output_len are skipped.
__global__ void gatherOutput(int*       output_ids,
                             const int* ids,
                             const int* context_length,
                             int        max_context_len,
                             int        max_gen_step,
                             int        max_output_len,
                             int        batch_size)
{
    const int b   = blockIdx.x;
    const int len = context_length[b];
    const int gap = max_context_len - len;  // number of padding steps after this sequence's prompt
    int*      out = output_ids + b * max_output_len;

    for (int step = threadIdx.x; step < max_gen_step; step += blockDim.x) {
        const bool in_padding = step >= len && step < max_context_len;
        if (!in_padding) {
            const int pos = (step < len) ? step : step - gap;
            if (pos < max_output_len) {
                out[pos] = ids[step * batch_size + b];
            }
        }
    }
}

// Host launcher for gatherOutput. The locals suggest a grid of batch_size blocks (one per
// sequence) with 128 threads each on `stream` — launch configuration assumed, confirm.
void invokeGatherOutput(int*         output_ids,
                        const int*   ids,
                        const int*   context_length,
                        int          max_context_len,
                        int          max_gen_step,
                        int          max_output_len,
                        int          batch_size,
                        cudaStream_t stream)
{
    int block_size = 128;
    int grid_size  = batch_size;
    gatherOutput<<>>(
        output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
}

// One block per sequence b: copies up to request_output_ids_lens[b] ids from row b of
// output_ids (row stride max_session_len) into request_output_ids_ptrs[b], then stores the
// sequence length (plus one when token_generated) to *request_seqlen_ptrs[b].
// Every thread of the block stores the same length value.
__global__ void updateOutput(int**      request_output_ids_ptrs,
                             int**      request_seqlen_ptrs,
                             const int* output_ids,
                             const int* sequence_lengths,
                             const int* request_output_ids_lens,
                             int        max_session_len,
                             bool       token_generated)
{
    const int b = blockIdx.x;

    int* const       dst        = request_output_ids_ptrs[b];
    int* const       seqlen_out = request_seqlen_ptrs[b];
    const int* const src        = output_ids + max_session_len * b;

    const int len   = token_generated ? sequence_lengths[b] + 1 : sequence_lengths[b];
    const int limit = request_output_ids_lens[b];
    const int n     = len < limit ? len : limit;

    for (int k = threadIdx.x; k < n; k += blockDim.x) {
        dst[k] = src[k];
    }

    *seqlen_out = len;
}

// Host launcher for updateOutput. The locals suggest one 128-thread block per batch entry
// on `stream` — launch configuration assumed, confirm.
void invokeUpdateOutput(int**        request_output_ids_ptrs,
                        int**        request_seqlen_ptrs,
                        const int*   output_ids,
                        const int*   sequence_lengths,
                        const int*   request_output_ids_lens,
                        int          max_session_len,
                        bool         token_generated,
                        int          batch_size,
                        cudaStream_t stream)
{
    constexpr int block_size = 128;
    const int     grid_size  = batch_size;

    updateOutput<<>>(request_output_ids_ptrs,
                                                       request_seqlen_ptrs,
                                                       output_ids,
                                                       sequence_lengths,
                                                       request_output_ids_lens,
                                                       max_session_len,
                                                       token_generated);
}

template
__global__ void compactOutputIds(
    int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated)
{
    // Block `batch_idx` copies the first sequence_lengths[batch_idx] ids of its row (row stride
    // session_len) to cu_output_ids + sum(sequence_lengths[0, batch_idx)).
    // Assumes the block is launched with BLOCK_DIM threads. token_generated is not used.
    typedef cub::BlockReduce     BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    const int batch_idx = blockIdx.x;

    // Sum the lengths of all preceding sequences, one BLOCK_DIM-wide tile per iteration.
    int end   = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM;  // align to BLOCK_DIM boundary
    int count = 0;
    for (int i = threadIdx.x; i < end; i += blockDim.x) {
        // Use the strided index `i`: indexing with threadIdx.x re-read the first tile on every
        // iteration and produced wrong offsets whenever batch_idx > BLOCK_DIM.
        int x = i < batch_idx ? sequence_lengths[i] : 0;
        // The block-wide sum is only valid in thread 0, which is the thread that publishes it below.
        count += BlockReduce(temp_storage).Sum(x);
        // temp_storage is reused by the next iteration's reduction
        // https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html
        __syncthreads();
    }

    __shared__ int offset;

    if (threadIdx.x == 0) {
        offset = count;
    }

    __syncthreads();

    auto dst = cu_output_ids + offset;

    const int seq_len = sequence_lengths[batch_idx];

    for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
        dst[i] = output_ids[batch_idx * session_len + i];
    }
}

// Host launcher for compactOutputIds with BLOCK_DIM = 128, presumably one block per batch
// entry on `stream` (launch configuration assumed — confirm).
void invokeCompactOutputIds(int*         cu_output_ids,
                            const int*   output_ids,
                            const int*   sequence_lengths,
                            int          max_session_len,
                            bool         token_generated,
                            int          batch_size,
                            cudaStream_t stream)
{
    constexpr int BLOCK_DIM = 128;
    compactOutputIds<<>>(
        cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated);
}

// By-value kernel argument for indexedCopy.
//   src_ptr / dst_ptr : base pointers of the N buffers being copied
//   stride            : row length of each buffer, in elements of T (bytes / sizeof(T))
//   src_idx / dst_idx : row indices, one pair per block
//   max_stride        : largest entry of `stride` (bounds the thread loop)
// Its size is limited by the kernel-parameter budget (checked by a static_assert at the launch site).
template
struct IndexedCopyParam {
    Array src_ptr;
    Array dst_ptr;
    Array   stride;
    Array   src_idx;
    Array   dst_idx;
    int             max_stride;
};

// Block bi copies row src_idx[bi] to row dst_idx[bi] in each of the N buffers.
// Threads stride over columns up to max_stride; a column beyond a buffer's own stride is skipped.
template
__global__ void indexedCopy(IndexedCopyParam param)
{
    const int bi = blockIdx.x;
    const int si = param.src_idx[bi];
    const int di = param.dst_idx[bi];
    for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) {
        PRAGMA_UNROLL
        for (int k = 0; k < N; ++k) {
            if (i < param.stride[k]) {
                *((T*)param.dst_ptr[k] + param.stride[k] * di + i) =
                    *((const T*)param.src_ptr[k] + param.stride[k] * si + i);
            }
        }
    }
}

// Host-side driver for indexedCopy over N buffers.
//   - Picks a per-launch capacity C from the dispatch list (the predicate is count <= C).
//   - Converts each byte row size in h_elem_sz to elements of T, failing on misalignment.
//   - Launches in chunks of up to C index pairs; a null h_src_idx / h_dst_idx means identity
//     indices starting at the chunk offset (std::iota).
// All pointer/size/index arrays are host memory.
template
void invokeIndexedCopyImpl(void**       h_src_ptr,
                           void**       h_dst_ptr,
                           const int*   h_elem_sz,
                           const int*   h_src_idx,
                           const int*   h_dst_idx,
                           int          count,
                           cudaStream_t st)
{
    dispatch(  // dispatch for num of copy operations
        std::integer_sequence{},
        [&](auto C) { return count <= C; },
        [&](auto C) {
            // maximum parameter size: sm<70: 4kB, sm>=70: 32kB
            static_assert(sizeof(IndexedCopyParam) <= 4096);
            IndexedCopyParam param{};
            std::copy_n(h_src_ptr, N, param.src_ptr.data());
            std::copy_n(h_dst_ptr, N, param.dst_ptr.data());
            std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) {
                // Basic alignment check
                FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr("misalignment: %d %% %d", size, (int)sizeof(T)));
                return size / sizeof(T);
            });
            param.max_stride = *std::max_element(param.stride.begin(), param.stride.end());
            // Fill `n` indices from `src` starting at `offset`, or identity indices when src is null.
            auto copy_idx    = [](const int* src, int offset, int n, auto dst) {
                return src ? (void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset);
            };
            for (int c = 0; c < count; c += C) {
                int batch_size = std::min(count - c, (int)C);
                copy_idx(h_src_idx, c, batch_size, param.src_idx.data());
                copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data());
                indexedCopy<<>>(param);
            }
        });
}

// Runtime-to-compile-time bridge: selects the invokeIndexedCopyImpl instantiation whose N
// equals n_copys. FT_CHECK fails if n_copys is not one of the supported values.
void invokeIndexedCopy(void**       h_src_ptr,
                       void**       h_dst_ptr,
                       const int*   h_elem_sz,
                       const int*   h_src_idx,
                       const int*   h_dst_idx,
                       int          count,
                       int          n_copys,
                       cudaStream_t st)
{
    auto success = dispatch(std::integer_sequence{}, [&](auto N) {
        if (N == n_copys) {
            invokeIndexedCopyImpl(h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st);
            return true;
        }
        return false;
    });
    FT_CHECK(success);
}

// token_ids is step-major (element [step * batch_size + b]). For every sequence b, copy its
// last prompt token (step context_length[b] - 1) into step max_context_len - 1.
// Threads of the block stride over the batch.
__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size)
{
    const int last_row = (max_context_len - 1) * batch_size;
    for (int b = threadIdx.x; b < batch_size; b += blockDim.x) {
        const int src_row       = (context_length[b] - 1) * batch_size;
        token_ids[last_row + b] = token_ids[src_row + b];
    }
}

// Launch padLastTokenIds as a single block of 512 threads on `stream`; the kernel strides
// over the batch, so one block covers any batch_size.
void invokePadLastTokenIds(
    int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream)
{
    constexpr int kBlocks  = 1;
    constexpr int kThreads = 512;
    padLastTokenIds<<<kBlocks, kThreads, 0, stream>>>(token_ids, context_length, max_context_len, batch_size);
}

// Block bi copies the `dims`-element row of token cu_seqlens[bi + 1] - 1 from `input` to row bi
// of `output`. Assumes cu_seqlens holds cumulative sequence offsets, so that index is the last
// token of sequence bi — confirm against callers.
template
__global__ void getFeatureOfLastToken(T* output, const T* input, const int* cu_seqlens, int dims)
{
    int bi = blockIdx.x;
    int ti = cu_seqlens[bi + 1] - 1;
    for (int i = threadIdx.x; i < dims; i += blockDim.x) {
        output[dims * bi + i] = input[dims * ti + i];
    }
}

// Host launcher for getFeatureOfLastToken on 16-bit elements; presumably one block per
// batch entry (launch configuration not visible here — confirm).
void invokeGetFeatureOfLastToken(
    uint16_t* output, const uint16_t* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream)
{
    getFeatureOfLastToken<<>>(output, input, cu_seqlens, dims);
}

// By-value kernel argument for batchedCopy.
//   src_ptr / dst_ptr : per-copy source and destination pointers
//   size              : per-copy length in elements of T
//   count             : number of valid entries for this launch
template
struct BatchedCopyParam {
    Array  src_ptr;
    Array  dst_ptr;
    Array size;
    int           count;
};

// Each group of kThrPerCpy consecutive global threads performs one copy: copy bi copies
// size[bi] elements from src_ptr[bi] to dst_ptr[bi]. Threads whose copy index is
// >= param.count exit immediately, so over-sized grids are harmless.
template
__global__ void batchedCopy(BatchedCopyParam param)
{
    const int ti = threadIdx.x + blockIdx.x * blockDim.x;
    const int bi = ti / kThrPerCpy;
    if (bi >= param.count) {
        return;
    }
    const T* __restrict__ src = param.src_ptr[bi];
    T* __restrict__ dst       = param.dst_ptr[bi];
    int size                  = param.size[bi];
    for (int i = ti % kThrPerCpy; i < size; i += kThrPerCpy) {
        dst[i] = src[i];
    }
}

// MSVC does not like CUDA kernel launch inside nested lambdas
// Function object invoked by dispatch with the chosen threads-per-copy S.
// Uses 128-thread blocks, i.e. 128 / S copies per block, and enough blocks to cover `count` copies.
template
struct BatchedCopyLauncher {
    int          max_size;
    int          count;  // number of copies used to size the grid
    const P*     params;
    cudaStream_t st;

    template
    void operator()(std::integral_constant) const
    {
        constexpr int threads         = 128;
        constexpr int items_per_block = threads / S;
        const int     blocks          = (count + items_per_block - 1) / items_per_block;
        batchedCopy<<>>(*params);
    }
};

// Copies size[i] bytes from src_ptr[i] to dst_ptr[i] for i in [0, count), working in 4-byte words.
// Every size must be a multiple of 4 bytes (FT_CHECK). Processing:
//   - Copies are launched in chunks of up to C (chosen by dispatch with the predicate count <= C).
//   - Threads-per-copy S is chosen from the chunk's largest size.
// NOTE(review): the launcher receives the total `count`, not `bsz`, so the grid may cover more copies
// than the chunk holds; the extra threads exit via batchedCopy's `bi >= param.count` guard.
// NOTE(review): `params` is not reset between chunks, so max_element also sees size entries left over
// from an earlier chunk beyond `bsz`; this can only select a larger S than needed.
void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t st)
{
    dispatch(
        std::integer_sequence{},
        [&](auto C) { return count <= C; },
        [&](auto C) {
            using T = uint32_t;
            BatchedCopyParam params{};
            // TODO: on CUDA 12.1 and sm_70+ this can be 32K
            static_assert(sizeof(params) <= 4096);
            for (int c = 0; c < count; c += C) {
                const int bsz = std::min(count - c, C);
                params.count  = bsz;
                for (int i = 0; i < bsz; ++i) {
                    params.src_ptr[i] = (T*)src_ptr[c + i];
                    params.dst_ptr[i] = (T*)dst_ptr[c + i];
                    FT_CHECK(size[c + i] % sizeof(T) == 0);
                    params.size[i] = size[c + i] / sizeof(T);
                }
                const int max_size = *std::max_element(params.size.begin(), params.size.end());
                dispatch(
                    std::integer_sequence{},
                    [&](auto S) { return max_size <= S; },
                    BatchedCopyLauncher>{max_size, count, ¶ms, st});
            }
        });
}

// Block b handles row b (dim elements) of `output`: rows with mask[b] != 0 are kept,
// rows with mask[b] == 0 are overwritten with T() (zero).
template
__global__ void maskOutput(T* output, const int* mask, int dim)
{
    int batch_idx = blockIdx.x;
    output += dim * batch_idx;
    int masked = mask[batch_idx];
    for (int i = threadIdx.x; i < dim; i += blockDim.x) {
        output[i] = (masked) ? output[i] : T();
    }
}

// Host launcher for maskOutput; presumably one block per batch row (launch configuration not visible — confirm).
template
void invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_t stream)
{
    maskOutput<<>>(output, mask, dim);
}

// Explicit instantiations of invokeMask: half always; float and bfloat16 only when
// ENABLE_FP32 / ENABLE_BF16 are defined.
#ifdef ENABLE_FP32
template void invokeMask(float* output, const int* mask, int batch_size, int dim, cudaStream_t stream);
#endif
template void invokeMask(half* output, const int* mask, int batch_size, int dim, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeMask(__nv_bfloat16* output, const int* mask, int batch_size, int dim, cudaStream_t stream);
#endif

// Converts row blockIdx.y (length `channels`) of `input` to float in `output`.
// Each thread handles vec_size consecutive elements per step and strides over the row by
// gridDim.x * blockDim.x * vec_size. When the vector is at least 4 bytes, Load() is used
// (presumably a vectorized load); otherwise elements are read one by one.
// Requires channels % vec_size == 0 (the host launcher picks vec_size accordingly).
template
__global__ void castFloat2D(const T* input, float* output, int channels)
{
    const int vi = blockIdx.x * blockDim.x + threadIdx.x;
    const int bi = blockIdx.y;
    input += (size_t)bi * channels;
    output += (size_t)bi * channels;

    const int step = gridDim.x * blockDim.x * vec_size;

    for (int i = vi * vec_size; i < channels; i += step) {
        Array src;

        if constexpr (sizeof(src) >= sizeof(uint)) {
            Load(src, input + i);
        }
        else {
            PRAGMA_UNROLL
            for (int j = 0; j < vec_size; ++j) {
                src[j] = input[i + j];
            }
        }

        auto dst = cast(src);

        // store
        Store(output + i, dst);
    }
}

// Casts a contiguous tensor to float, treating it as [shape(0), shape(1)]. Requirements:
//   - src and dst must be contiguous and have equal shapes.
//   - src dtype must be float32, float16 or (with ENABLE_BF16) bfloat16; anything else hits TM_UNREACHABLE.
// The vector width is 4, 2 or 1, whichever divides the channel count first.
// Grid: (ceil(channels / (256 * vec_size)), rows) blocks of 256 threads.
// NOTE(review): dst's dtype is not checked explicitly here — confirm it is float32 at the call sites.
void invokeCastFloat2D(const core::Tensor& src, core::Tensor& dst, cudaStream_t stream)
{
    TM_CHECK(src.is_contiguous());
    TM_CHECK(dst.is_contiguous());
    TM_CHECK(src.shape() == dst.shape());

    auto batch_size = src.shape(0);
    auto channels   = src.shape(1);

    auto invoke = [&](auto t, auto vec_size) {
        using T                      = decltype(t);
        constexpr int threads        = 256;
        const int     blocks_per_tok = (channels + threads * vec_size - 1) / (threads * vec_size);
        const dim3    blocks(blocks_per_tok, batch_size);
        castFloat2D<<>>(  //
            src.data(),
            dst.data(),
            channels);
    };

    // The `break`s after each `return` are unreachable.
    auto dispatch_t = [&](auto vec_size) {
        switch (src.dtype()) {
            case kFloat32:
                return invoke(float{}, vec_size);
                break;
            case kFloat16:
                return invoke(half{}, vec_size);
                break;
#ifdef ENABLE_BF16
            case kBfloat16:
                return invoke(__nv_bfloat16{}, vec_size);
                break;
#endif
            default:
                TM_UNREACHABLE;
        }
    };

    if (channels % 4 == 0) {
        return dispatch_t(std::integral_constant{});
    }
    else if (channels % 2 == 0) {
        return dispatch_t(std::integral_constant{});
    }
    else {
        return dispatch_t(std::integral_constant{});
    }
}

// Block bi copies row idxs[bi] of `src` into row bi of `dst`; both rows are `dim` elements of T.
// A negative index skips the copy and leaves that dst row untouched.
template
__global__ void CollectHiddenStates_Kernel(const T* src, const int* idxs, T* dst, int dim)
{
    const int bi = blockIdx.x;
    const int ti = idxs[bi];

    if (ti < 0) {
        return;
    }

    src += ti * dim;
    dst += bi * dim;

    for (int di = threadIdx.x; di < dim; di += blockDim.x) {
        dst[di] = src[di];
    }
}

// Gathers rows of `src` selected by `idxs` into consecutive rows of `dst`, one block per index.
// The copy word is the widest of 16/8/4/2 bytes that divides src's row stride in bytes. Threads
// per block are min(dim, 1024) rounded up to a multiple of WARP_SIZE. An unsupported
// (odd-byte) stride fails TM_CHECK.
// NOTE(review): src's row stride is also used for dst rows — assumes dst rows have the same byte stride.
void CollectHiddenStates(const Tensor& src, const Buffer_& idxs, Ref dst, cudaStream_t st)
{
    const auto stride = byte_size(src.dtype(), src.stride(0));

    auto invoke = [&](auto t) {
        using T           = decltype(t);
        const int dim     = stride / sizeof(T);
        const int threads = round_up(min(dim, 1024), WARP_SIZE);
        const int blocks  = idxs.size();
        CollectHiddenStates_Kernel<<>>(
            (const T*)src.raw_data(), idxs.data(), (T*)dst.get().raw_data(), dim);
    };

    if (stride % sizeof(uint4) == 0) {
        invoke(uint4{});
    }
    else if (stride % sizeof(uint2) == 0) {
        invoke(uint2{});
    }
    else if (stride % sizeof(uint1) == 0) {
        invoke(uint1{});
    }
    else if (stride % sizeof(ushort) == 0) {
        invoke(ushort{});
    }
    else {
        TM_CHECK(0) << "unsupported byte stride: " << stride;
    }
}

// Block bi computes the exclusive prefix sum of srcs[bi][0, n) into dsts[bi][0, n) and writes
// the total to dsts[bi][n], so each dst needs n + 1 entries.
// Input is scanned in BLOCK_DIM-wide tiles while carrying the running total in `prefix`.
// Every thread runs the same number of iterations, so the __syncthreads() guarding reuse of
// temp_storage is reached uniformly.
template
__global__ void
BatchPrefixSumKernel(Array srcs, Array ns, Array dsts)
{
    const int  bi  = blockIdx.x;
    const int* src = srcs[bi];
    int*       dst = dsts[bi];
    const int  n   = ns[bi];

    using BlockScan = cub::BlockScan;

    __shared__ typename BlockScan::TempStorage temp_storage;

    int prefix{};
    for (int i = threadIdx.x; i < round_up(n, BLOCK_DIM); i += BLOCK_DIM) {
        if (i >= BLOCK_DIM) {
            __syncthreads();
        }
        int data = i < n ? src[i] : 0;
        int sum{};  // tile total (block aggregate)
        BlockScan{temp_storage}.ExclusiveSum(data, data, sum);
        if (i < n) {
            dst[i] = prefix + data;
        }
        prefix += sum;
    }

    if (threadIdx.x == 0) {
        dst[n] = prefix;
    }
}

// Exclusive prefix sums for `count` host-described arrays (srcs[i] of length ns[i] into dsts[i],
// which receives ns[i] + 1 entries, the last being the total). One block of 256 threads per array.
// At most max_count arrays are supported per call; the pointer/length arrays are host memory.
void BatchPrefixSum(const int** srcs, const int* ns, int** dsts, int count, cudaStream_t st)
{
    constexpr int max_count = 1;

    // Validate before filling the fixed-capacity argument arrays below: checking afterwards let an
    // oversize `count` write past their end before the check could fire.
    TM_CHECK_LE(count, max_count);

    // Nothing to scan; avoid launching a zero-sized grid (an invalid launch configuration).
    if (count <= 0) {
        return;
    }

    Array p_srcs{};
    Array       p_dsts{};
    Array        p_ns{};

    for (int i = 0; i < count; ++i) {
        p_srcs[i] = srcs[i];
        p_dsts[i] = dsts[i];
        p_ns[i]   = ns[i];
    }

    constexpr int block = 256;
    const int     grid  = count;

    BatchPrefixSumKernel<<>>(p_srcs, p_ns, p_dsts);
}

// One thread per sequence b: store output_ids[b] at index positions[b] of that sequence's
// token buffer token_ids_ptrs[b]. Threads past batch_size do nothing.
__global__ void AppendTokenIdsKernel(int** token_ids_ptrs, const int* output_ids, const int* positions, int batch_size)
{
    const int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= batch_size) {
        return;
    }
    int* const seq = token_ids_ptrs[b];
    const int  at  = positions[b];
    seq[at]        = output_ids[b];
}

// Host launcher for AppendTokenIdsKernel: 128-thread blocks, cdiv(batch_size, 128) of them
// (launch configuration assumed from the locals — confirm).
void AppendTokenIds(
    int** token_ids_ptrs, const int* output_ids, const int* positions, int batch_size, cudaStream_t stream)
{
    constexpr int block = 128;
    const int     grid  = cdiv(batch_size, block);
    AppendTokenIdsKernel<<>>(token_ids_ptrs, output_ids, positions, batch_size);
}

// attn[ti * dim + di] *= sigmoid(gate_base[ti * gate_stride + di]), computed in float using the
// fast-math __expf. blockIdx.x selects the token; blockIdx.y * blockDim.x + threadIdx.x the channel.
template
__global__ void SigmoidGateMultiplyKernel(T* attn, const T* gate_base, int dim, int gate_stride, int num_tokens)
{
    const int ti = blockIdx.x;
    const int di = threadIdx.x + blockIdx.y * blockDim.x;
    if (ti >= num_tokens || di >= dim) {
        return;
    }
    float g             = (float)gate_base[ti * gate_stride + di];
    float s             = 1.0f / (1.0f + __expf(-g));
    float a             = (float)attn[ti * dim + di];
    attn[ti * dim + di] = (T)(a * s);
}

// Host launcher for SigmoidGateMultiplyKernel. Grid is (num_tokens, cdiv(dim, 256)) with
// 256-thread blocks; the element type is chosen from `dtype` via TM_DISPATCH_PRIMARY_DTYPES.
void invokeSigmoidGateMultiply(
    void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream)
{
    constexpr int block = 256;
    const dim3    grid(num_tokens, cdiv(dim, block));

    auto invoke = [&](auto t) {
        using T = decltype(t);
        SigmoidGateMultiplyKernel<<>>(
            (T*)attn, (const T*)gate_base, dim, gate_stride, num_tokens);
    };

    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_kernels.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/core/core.h"

#include 

#include 
namespace turbomind {

// Gathers each sequence's ids from the step-major `ids` buffer (element [step * batch_size + b])
// into row b of output_ids (row length max_output_len). The padding steps between
// context_length[b] and max_context_len are dropped.
void invokeGatherOutput(int*         output_ids,
                        const int*   ids,
                        const int*   context_length,
                        int          max_context_len,
                        int          max_gen_step,
                        int          max_output_len,
                        int          batch_size,
                        cudaStream_t stream);

// For each sequence b, copies min(len, request_output_ids_lens[b]) ids from row b of output_ids
// (row stride max_session_len) into request_output_ids_ptrs[b], and stores len to
// *request_seqlen_ptrs[b]. Here len = sequence_lengths[b] + (token_generated ? 1 : 0).
void invokeUpdateOutput(int**        request_output_ids_ptrs,
                        int**        request_seqlen_ptrs,
                        const int*   output_ids,
                        const int*   sequence_lengths,
                        const int*   request_output_ids_lens,
                        int          max_session_len,
                        bool         token_generated,
                        int          batch_size,
                        cudaStream_t stream);

// [aaa, bbbb, cc, ddd] -> [aaabbbbccddd]
// Concatenates the first sequence_lengths[b] ids of each row of output_ids (row stride
// max_session_len) into cu_output_ids. token_generated is currently ignored by the kernel.
void invokeCompactOutputIds(int*         cu_output_ids,
                            const int*   output_ids,
                            const int*   sequence_lengths,
                            int          max_session_len,
                            bool         token_generated,
                            int          batch_size,
                            cudaStream_t stream);

// For each of `count` index pairs i, copies row h_src_idx[i] to row h_dst_idx[i] in each of the
// n_copys buffer pairs (h_src_ptr[k] -> h_dst_ptr[k]), where h_elem_sz[k] is buffer k's row size
// in bytes. A null index array means identity indices (0..count-1). All array arguments are host
// memory; an unsupported n_copys fails FT_CHECK.
void invokeIndexedCopy(void**       h_src_ptr,
                       void**       h_dst_ptr,
                       const int*   h_elem_sz,
                       const int*   h_src_idx,
                       const int*   h_dst_idx,
                       int          count,
                       int          n_copys,
                       cudaStream_t st);

// Copies size[i] bytes from src_ptr[i] to dst_ptr[i] for i in [0, count); every size must be a
// multiple of 4 bytes (FT_CHECK). The pointer and size arrays are read on the host.
void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t st);

// ABCDe            ABCDe     e
// ABCDEFGHIJk      ABCDEFGHIJk
// ABCDEFGHi    ->  ABCDEFGHi i
// ABCDEFGh         ABCDEFGh  h
// ABCd             ABCd      d
// token_ids is step-major ([step * batch_size + b]). Each sequence's token at step
// context_length[b] - 1 is copied to step max_context_len - 1.
void invokePadLastTokenIds(
    int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream);

// For each of batch_size sequences b, copies the `dims`-element row of token cu_seqlens[b + 1] - 1
// from `input` to row b of `output`; elements are handled as 16-bit values.
void invokeGetFeatureOfLastToken(
    uint16_t* output, const uint16_t* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream);

// Zeroes row b (dim elements) of `output` when mask[b] == 0; rows with a nonzero mask are kept.
template
void invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_t stream);

// Casts a contiguous 2-D tensor (float32 / float16 / bfloat16) to float; src and dst must share a shape.
void invokeCastFloat2D(const core::Tensor& src, core::Tensor& dst, cudaStream_t stream);

// Copies row idxs[i] of src into row i of dst; negative indices leave the dst row untouched.
void CollectHiddenStates(const Tensor& src, const Buffer_& idxs, Ref dst, cudaStream_t st);

// Exclusive prefix sum of srcs[i][0, ns[i]) into dsts[i]; dsts[i][ns[i]] receives the total, so each
// dst needs ns[i] + 1 entries. The implementation currently accepts at most one array per call.
void BatchPrefixSum(const int** srcs, const int* ns, int** dsts, int count, cudaStream_t st);

// Single-array convenience wrapper: exclusive prefix sum of src[0, n) into dst[0, n], with the
// total in dst[n]. Forwards to BatchPrefixSum as a batch of one.
inline void PrefixSum(const int* src, int n, int* dst, cudaStream_t st)
{
    const int* srcs[] = {src};
    int        lens[] = {n};
    int*       dsts[] = {dst};
    BatchPrefixSum(srcs, lens, dsts, 1, st);
}

// For each b in [0, batch_size), writes output_ids[b] to token_ids_ptrs[b][positions[b]].
void AppendTokenIds(int**        token_ids_ptrs,  //
                    const int*   output_ids,
                    const int*   positions,
                    int          batch_size,
                    cudaStream_t stream);

// Apply sigmoid gating: attn[i] *= sigmoid(gate[i])
// attn:        [num_tokens, dim], contiguous
// gate_base:   pointer to first gate element in QKV buffer
// gate_stride: stride between tokens in QKV buffer (elements)
// dtype:       element type of both attn and gate_base (float16 / bfloat16 / ...) — the math is done in float
void invokeSigmoidGateMultiply(
    void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_params.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/models/llama/llama_rope.h"

namespace turbomind {

// Dimensions for multi-head latent attention (MLA); presumably mirrors the checkpoint config
// fields of the same names (q_lora_rank, kv_lora_rank, ...) — confirm against the converter.
struct MLAParam {
    int q_lora_rank;
    int kv_lora_rank;
    int qk_rope_dim;
    int v_head_dim;
};

// Architecture and weight-format description of the loaded model.
// Members without a default initializer are indeterminate unless the struct is value-initialized.
struct ModelParam {
    size_t   head_num;
    size_t   head_dim;
    size_t   kv_head_num;
    size_t   hidden_units;
    size_t   layer_num;
    size_t   vocab_size;
    size_t   embedding_size;
    float    norm_eps;
    int      quant_policy;
    bool     attn_bias;
    bool     attn_sink;
    bool     mlp_bias;
    DataType data_type;

    // Weight types for mixed quantization support.
    // Models like mixed AWQ (e.g. QuantTrio GLM-4.7-Flash) quantize FFN/expert
    // weights to int4 but keep attention weights as fp16. GptOss mxfp4 quantizes
    // only MoE experts to e2m1 while keeping attention and shared experts as fp16.
    //
    //                  weight_type   ffn_weight_type   expert_weight_type
    //  Pure fp16       float16       float16           float16
    //  Full AWQ        int4          int4              int4
    //  Mixed AWQ       float16       int4              int4
    //  GptOss mxfp4    bfloat16      bfloat16          e2m1
    DataType weight_type;         // attention weights
    DataType expert_weight_type;  // MoE routed expert weights
    DataType ffn_weight_type;     // dense FFN / shared expert weights

    int      group_size;
    MLAParam mla;
    bool     qk_norm;
    int      tune_layer_num;

    ActivationType act_type;

    // Per-layer settings — presumably indexed by layer id; confirm against the loader.
    std::vector window_size;
    std::vector inter_size;
    // Layer kind per layer; HasLinearAttention treats the value 1 as a linear-attention layer.
    std::vector layer_types;

    // Qwen3.5 Gated DeltaNet linear attention params
    int linear_key_head_dim    = 0;
    int linear_value_head_dim  = 0;
    int linear_conv_kernel_dim = 0;
    int linear_num_key_heads   = 0;
    int linear_num_value_heads = 0;

    DataType linear_state_dtype = {};

    bool attn_output_gate = false;  // Qwen3.5: doubles Q projection in full-attention layers

    // Layer indices whose MoE experts use data_type (fp16) instead of
    // expert_weight_type (e.g. int4).  Populated from modules_to_not_convert
    // patterns like 'model.layers.0.'.
    std::set unquantized_expert_layers;
};

inline bool HasLinearAttention(const ModelParam& model_param)
{
    for (int type : model_param.layer_types) {
        if (type == 1) {
            return true;
        }
    }
    return false;
}

/// TODO: rename all `gate` in the context of MoE router to `router`
// Mixture-of-experts configuration. `method` selects between the naive and fused MoE paths.
// The string-valued fields (topk_method, scoring_func) are interpreted elsewhere — their accepted
// values are not visible here.
struct MoeParam {
    enum Method
    {
        kNaive,
        kFused
    } method;

    int   experts_per_token;
    int   inter_size;
    bool  norm_topk_prob;
    bool  shared_gate;
    float routed_scale;

    bool router_bias;

    int         topk_group;
    std::string topk_method;
    int         n_group;
    std::string scoring_func;
    int         router_n_groups;

    // Expert count — presumably one entry per layer; confirm against the loader.
    std::vector expert_num;
};

// Attention configuration: softmax scale, KV-cache block length, logn-attention and RoPE settings.
struct AttentionParam {
    float softmax_scale;
    int   cache_block_seq_len;
    // logn attention
    bool use_logn_attn;
    int  max_position_embeddings;
    // rotary embedding
    RopeParam rope;
};

// Runtime engine configuration: batching limits, KV-cache settings, chunked-prefill limits and the
// rank/size of each parallel group this process belongs to.
struct EngineParam {
    // batch params
    int max_batch_size;
    int session_len;
    int step_length;

    // cache params
    float cache_max_block_count;
    int   cache_chunk_size;
    bool  enable_prefix_caching;
    bool  enable_metrics;

    // chunking params
    int max_forward_token_num;
    int max_context_token_num;
    int num_tokens_per_iter;
    int max_prefill_iters;

    // parallel params
    int outer_dp_size;
    int outer_dp_rank;
    int attn_dp_size;
    int attn_dp_rank;
    int attn_tp_size;
    int attn_tp_rank;
    int attn_cp_size;
    int attn_cp_rank;
    int mlp_tp_size;
    int mlp_tp_rank;

    // multi-node
    int nnodes;
    int node_rank;

    // device ids used by this engine — presumably CUDA ordinals; confirm
    std::vector devices;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_rope.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include 

namespace turbomind {

// Rotary-embedding variants. kNull has no string name in GetRoPEType; the others map from
// "default", "linear", "dynamic", "yarn", "llama3" and "mrope".
enum class RopeType
{
    kNull,
    kDefault,
    kLinear,
    kDynamic,
    kYarn,
    kLlama3,
    kMrope,
};

// Maps a RoPE type name to RopeType. Unknown names throw std::out_of_range (from map::at);
// there is no name for RopeType::kNull. The lookup table is rebuilt on every call.
inline RopeType GetRoPEType(const std::string& type)
{
    std::map lookup = {{"default", RopeType::kDefault},
                                              {"linear", RopeType::kLinear},
                                              {"dynamic", RopeType::kDynamic},
                                              {"yarn", RopeType::kYarn},
                                              {"llama3", RopeType::kLlama3},
                                              {"mrope", RopeType::kMrope}};
    return lookup.at(type);
}

// YaRN config. beta_fast / beta_slow are the rotation counts used to derive the ramp range in
// init_rope_kernel_param; attention_factor is forwarded to the kernel unchanged.
struct YarnRopeParam {
    float attention_factor;
    float beta_fast;
    float beta_slow;
};

// Llama 3 RoPE config; converted to the alpha/beta smoothing coefficients by init_rope_kernel_param.
struct Llama3RopeParam {
    float low_freq_factor;
    float high_freq_factor;
    int   original_max_position_embeddings;
};

// Multimodal RoPE config: per-axis section sizes (x, y, z); init_rope_kernel_param doubles and accumulates them.
struct MropeRopeParam {
    int3 section;
};

// Model-level RoPE configuration. The union holds the type-specific parameters; only the member
// matching `type` (yarn for kYarn, llama3 for kLlama3, mrope for kMrope) is read by
// init_rope_kernel_param.
struct RopeParam {
    RopeType type;
    // common
    float base;
    int   dim;
    float factor;
    int   max_position_embeddings;
    // unique
    union {
        YarnRopeParam   yarn;
        Llama3RopeParam llama3;
        MropeRopeParam  mrope;
    };
};

// Precomputed YaRN values passed to the kernel. ramp_inv_factor_div_2 = 1 / (high - low) / 2 and
// ramp_inv_factor_mul_min = low / (high - low) are set by init_rope_kernel_param.
// NOTE(review): scale_factor is not assigned by init_rope_kernel_param — confirm whether it is used.
struct YarnRopeKernelParam {
    float scale_factor;
    float attention_factor;
    float ramp_inv_factor_div_2;
    float ramp_inv_factor_mul_min;
};

// Precomputed Llama 3 smoothing coefficients:
//   alpha = original_max_position_embeddings / (2 * pi * (high_freq_factor - low_freq_factor))
//   beta  = low_freq_factor / (high_freq_factor - low_freq_factor)
// NOTE(review): scale_factor is not assigned by init_rope_kernel_param — confirm whether it is used.
struct Llama3RopeKernelParam {
    float scale_factor;
    float alpha;
    float beta;
};

// Kernel-side M-RoPE parameters. `section` holds the cumulative, doubled section boundaries set by
// init_rope_kernel_param. The pointer fields and stride are filled elsewhere (default null / 0).
struct MropeRopeKernelParam {
    int3 section;

    int  stride{};
    int* position_ids{};
    int* position_delta{};
    int* length{};
};

// Everything the rotary kernel needs. init_rope_kernel_param sets the following from RopeParam:
//   - type, dim
//   - scale_factor = -log2(base) / dim
//   - inv_factor
//   - the per-type sub-struct
struct RopeKernelParam {
    RopeType type;

    float* base{};  // for dynamic ntk
    int    dim;
    float  scale_factor;
    float  inv_factor;

    YarnRopeKernelParam   yarn;
    Llama3RopeKernelParam llama3;
    MropeRopeKernelParam  mrope;
};

// Derives the kernel-side RoPE parameters from the model-level RopeParam.
// Common fields are always set. The type-specific block is filled only for YaRN, Llama 3 and
// M-RoPE; other types leave the sub-structs untouched.
inline void init_rope_kernel_param(const RopeParam& rope, RopeKernelParam& rope_kernel)
{
    rope_kernel.type         = rope.type;
    rope_kernel.dim          = rope.dim;
    rope_kernel.scale_factor = -std::log2(rope.base) / rope.dim;

    // inv_factor = 1 / factor, except dynamic NTK (always 1) and a zero factor (falls back to 1).
    rope_kernel.inv_factor = 1.f;
    if (rope.type != RopeType::kDynamic && rope.factor != 0.f) {
        rope_kernel.inv_factor = 1.0 / rope.factor;
    }

    switch (rope.type) {
        case RopeType::kYarn: {
            const auto&  cfg = rope.yarn;
            auto&        out = rope_kernel.yarn;
            const double PI  = 3.14159265358979323846;

            // Dimension index at which a frequency completes `num_rotations` turns over max_position_embeddings.
            auto correction_dim = [&](float num_rotations) {
                return (rope.dim * std::log(rope.max_position_embeddings / (num_rotations * 2 * PI)))
                       / (2 * std::log(rope.base));
            };

            // Ramp bounds, clamped to [0, dim - 1].
            float lo = std::floor(correction_dim(cfg.beta_fast));
            float hi = std::ceil(correction_dim(cfg.beta_slow));
            lo       = std::max(lo, 0.f);
            hi       = std::min(hi, rope.dim - 1.f);
            // Widen a zero-width ramp to avoid dividing by zero, as HF does:
            // https://github.com/huggingface/transformers/blob/6c3f168b36882f0beebaa9121eafa1928ba29633/src/transformers/modeling_rope_utils.py#L216
            if (lo == hi) {
                hi += 0.001f;
            }
            out.ramp_inv_factor_div_2   = 1.0 / (hi - lo) / 2.0;
            out.ramp_inv_factor_mul_min = 1.0 / (hi - lo) * lo;
            out.attention_factor        = cfg.attention_factor;
            break;
        }
        case RopeType::kLlama3: {
            const auto&  cfg = rope.llama3;
            const double PI  = 3.14159265358979323846;
            // alpha = L_orig / (2*pi*(high - low)), beta = low / (high - low)
            float inv_span           = 1.0 / (cfg.high_freq_factor - cfg.low_freq_factor);
            rope_kernel.llama3.alpha = cfg.original_max_position_embeddings / (2 * PI) * inv_span;
            rope_kernel.llama3.beta  = cfg.low_freq_factor * inv_span;
            break;
        }
        case RopeType::kMrope: {
            // Turn per-axis section sizes into cumulative boundaries, each size doubled.
            const auto& in  = rope.mrope.section;
            auto&       sec = rope_kernel.mrope.section;
            sec.x           = in.x * 2;
            sec.y           = in.y * 2 + sec.x;
            sec.z           = in.z * 2 + sec.y;
            break;
        }
        default:
            break;
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind {

// Global debugging switch, defaulting to kCmpRead. It presumably selects between comparing
// tensors against files under tmp/ (see CmpRead) and writing them — confirm at the use sites.
CmpMode compare_mode = kCmpRead;
// CmpMode compare_mode = kCmpWrite;

// Debug helper: copies `size` elements from `ptr` to the host, synchronizing `stream`, and prints
// the NaN count under `key` to stderr if any element converts to NaN. Prints nothing otherwise.
template
void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
{
    std::vector h_data(size);
    // Check the copy like the other transfers in this file, so a failed copy is reported instead of
    // being scanned as if it were valid data.
    check_cuda_error(cudaMemcpyAsync(h_data.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));

    check_cuda_error(cudaStreamSynchronize(stream));

    size_t nan_cnt = 0;
    for (const auto& x : h_data) {
        nan_cnt += std::isnan(static_cast(x));
    }
    if (nan_cnt) {
        std::cerr << key << ": NaN count " << nan_cnt << "\n";
    }
}

template
void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
{
    // read a from file
    std::vector h_a(size);
    {
        const auto    filename = "tmp/" + key + ".cmp";
        std::ifstream ifs(filename, std::ios::binary);
        if (!ifs.is_open()) {
            std::cerr << key << ": failed to open " + filename << "\n";
            return;
        }
        ifs.seekg(0, ifs.end);
        const auto actual_size_in_bytes = ifs.tellg();
        ifs.seekg(0, ifs.beg);
        const auto expect_size_in_bytes = sizeof(T) * size;
        if (actual_size_in_bytes != expect_size_in_bytes) {
            std::cerr << key << ": file size in bytes mismatch, expect " << expect_size_in_bytes << ", got "
                      << actual_size_in_bytes << "\n";
            return;
        }
        ifs.read((char*)h_a.data(), sizeof(T) * h_a.size());
    }
    std::vector h_b(size);
    check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));
    check_cuda_error(cudaStreamSynchronize(stream));

    using Tacc         = std::conditional_t, int64_t, float>;
    constexpr Tacc eps = std::is_integral_v ? 1 : 1e-8f;

    Tacc asum{};
    Tacc rsum{};
    Tacc amean_r{};
    Tacc amean_x{};
    for (size_t i = 0; i < size; ++i) {
        Tacc x        = (Tacc)h_b[i];
        Tacc r        = (Tacc)h_a[i];
        Tacc abs_diff = std::abs(x - r);
        Tacc rel_diff = abs_diff / std::max(std::max(std::abs(r), std::abs(x)), eps);
        asum += abs_diff;
        rsum += rel_diff;
        amean_x += std::abs(x);
        amean_r += std::abs(r);
    }

    fprintf(stderr,
            "%15s%15f%15f%15f%15f%15f\n",
            key.c_str(),
            (float)amean_x / (float)size,
            (float)amean_r / (float)size,
            (float)asum,
            (float)asum / (float)size,
            (float)rsum / (float)size);

    check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream));
    check_cuda_error(cudaStreamSynchronize(stream));
}

template
void CmpWrite(T* ptr, size_t size, std::string key, cudaStream_t stream)
{
    std::vector a(size);
    // copy a to host
    check_cuda_error(cudaMemcpyAsync(a.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));
    check_cuda_error(cudaStreamSynchronize(stream));
    // write to file
    {
        std::ofstream ofs("tmp/" + key + ".cmp", std::ios::binary);
        ofs.write((char*)a.data(), sizeof(T) * a.size());
    }
}

template
void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream)
{
    // std::cerr << "Comparing " << key << "\n";
    if (mode == kCmpRead) {
        CmpRead(ptr, size, key, stream);
    }
    else if (mode == kCmpWrite) {
        CmpWrite(ptr, size, key, stream);
    }
    else {
        // kCmpNone
    }
}

template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);
template void Compare(__nv_bfloat16* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);

template void CheckNan(const float* ptr, size_t size, std::string key, cudaStream_t stream);
template void CheckNan(const half* ptr, size_t size, std::string key, cudaStream_t stream);

/// Returns the size in bytes of one cuRAND `curandState_t` (presumably so that code which
/// does not include the cuRAND headers can size state buffers).
size_t curandStateGetSize()
{
    constexpr size_t kStateBytes = sizeof(curandState_t);
    return kStateBytes;
}

/// True iff the environment variable TM_DEBUG_LEVEL is exactly "DEBUG".
/// The variable is read on the first call only; the result is cached for the process.
bool isDebug()
{
    static const bool is_debug = [] {
        const char* level = std::getenv("TM_DEBUG_LEVEL");
        return level != nullptr && std::string(level) == "DEBUG";
    }();
    return is_debug;
}

/// Returns a reference to slot `batch_idx` of a thread-local int64 table. When `batch_idx`
/// is past the end, the table grows and the new slots are initialized to -1.
///
/// Throws std::out_of_range for a negative `batch_idx`; the stored ids are left unchanged.
/// NOTE: the table is a std::vector, so growing it invalidates references returned by earlier
/// calls. Do not hold a returned reference across another call.
int64_t& gSequenceIds(int batch_idx)
{
    thread_local std::vector<int64_t> ids{};
    // Check the sign before comparing with size(). Otherwise a negative index converts to a
    // huge unsigned value; for -1 the resize below becomes resize(0) and wipes every stored id
    // before at() throws.
    if (batch_idx >= 0 && static_cast<size_t>(batch_idx) >= ids.size()) {
        ids.resize(static_cast<size_t>(batch_idx) + 1, -1);
    }
    return ids.at(batch_idx);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/llama_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once
#include "src/turbomind/utils/nvtx_utils.h"
#include 
#include 
#include 
#include 

namespace turbomind {

// Quantization policy values. They are distinct powers of two, so presumably they are
// combined/tested as bit flags — TODO confirm at the use sites.
enum QuantPolicy
{
    kNone = 0x00,
    // reserve 0x01 and 0x02 for backward compatibility
    kReserve1 = 0x01,
    kReserve2 = 0x02,
    // quantize cache kv
    kCacheKVInt8 = 0x08,
    kCacheKVInt4 = 0x04,
};

// Operating mode of the debug `Compare` helper.
enum CmpMode
{
    kCmpNone,   // do nothing
    kCmpRead,   // compare against "tmp/<key>.cmp", then overwrite the data with that reference
    kCmpWrite,  // dump the data to "tmp/<key>.cmp"
};

// Global compare mode (defined in llama_utils.cu).
extern CmpMode compare_mode;

// Debug: compares (kCmpRead) or records (kCmpWrite) `size` elements at `ptr` under the name
// `key`; synchronizes `stream`. Explicitly instantiated in llama_utils.cu for int, float, half
// and __nv_bfloat16 only.
template
void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);

// Debug: counts NaNs among `size` elements at `ptr` and reports them to stderr; synchronizes
// `stream`. Explicitly instantiated in llama_utils.cu for float and half only.
template
void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream);

namespace detail {

// Converts one `Concat` argument to text.

// Arithmetic values: delegate to std::to_string.
template<class T>
std::string to_string(T x)
{
    return std::to_string(x);
}

// Strings pass through unchanged.
inline std::string to_string(std::string x)
{
    return x;
}

// C strings / string literals. Without this overload a `const char*` argument deduces the
// template above (an exact match beats the user-defined conversion to std::string), which then
// fails to compile because std::to_string has no `const char*` overload.
inline std::string to_string(const char* x)
{
    return x;
}

}  // namespace detail

template
std::string Concat(std::string key, Args&&... args)
{
    std::vector args_str{detail::to_string((Args &&) args)...};
    for (const auto& s : args_str) {
        key.append("_");
        key.append(s);
    }
    return key;
}

// Returns sizeof(curandState_t) (defined in llama_utils.cu).
size_t curandStateGetSize();

// True iff environment variable TM_DEBUG_LEVEL equals "DEBUG"; read once, then cached.
bool isDebug();

/// Scope guard: PUSH_RANGE(name) on construction, POP_RANGE on destruction.
struct NvtxScope {
    explicit NvtxScope(const std::string& name)
    {
        PUSH_RANGE(name.c_str());
    }

    // Each object must pop exactly once. A copy would pop a second time when destroyed, so
    // copying is forbidden (which also suppresses the implicit move operations).
    NvtxScope(const NvtxScope&) = delete;
    NvtxScope& operator=(const NvtxScope&) = delete;

    ~NvtxScope()
    {
        POP_RANGE;
    }
};

// Reference to slot `batch_idx` of a thread-local int64 table that grows on demand (new slots
// start at -1). Growing the table may invalidate previously returned references.
int64_t& gSequenceIds(int batch_idx);

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/mla_utils.cu
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/core/check.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/math.h"

namespace turbomind {

// Packs Q and the per-token compressed K/V into `qkv`, one token per blockIdx.x.
//   blockIdx.y == 0: copies the `head_num` Q heads of the token;
//   blockIdx.y == 1: copies `kv_a_k_pe` as one extra head at index `head_num`.
// Within a head, thread `di` moves one `vec_size`-element chunk, and the source offset is
// rotated: the output's first `rope_dim` channels come from source channels
// [kv_lora_rank, kv_lora_rank + rope_dim), and the remaining output channels come from source
// channels [0, kv_lora_rank).
// NOTE(review): the `qkv` comment below says `head_num + 2` heads per token, but the row stride
// used here is `(head_num + 1) * head_dim` — confirm which one matches the allocation.
// NOTE(review): assumes rope_dim and kv_lora_rank are multiples of vec_size, so that no chunk
// straddles the two regions — TODO confirm.
template
__global__ void mla_copy_qkv_kernel(T*       qkv,        // [s, head_num + 2, kv_lora_rank + rope_dim]
                                    const T* q,          // [s, head_num,     kv_lora_rank + rope_dim]
                                    const T* kv_a_k_pe,  // [s, kv_lora_rank + rope_dim]
                                    int      head_num,   // q head num
                                    int      head_dim,   // kv_lora_rank + rope_dim
                                    int      kv_lora_rank,
                                    int      rope_dim)
{
    const int type = blockIdx.y;

    const int64_t ti = blockIdx.x;
    const int     di = threadIdx.x;

    // Output chunks below `rope_dim` read from `kv_lora_rank` further ahead; the rest read
    // `rope_dim` further back.
    const int offset = di * vec_size < rope_dim ? kv_lora_rank : -rope_dim;

    Array data;

    if (type == 0) {  // Q
        // threadIdx.y walks the heads with step blockDim.y
        for (int hi = threadIdx.y; hi < head_num; hi += blockDim.y) {
            if (di * vec_size < head_dim) {
                Load(data, &q[ti * head_num * head_dim + hi * head_dim + di * vec_size + offset]);
                Store(&qkv[ti * (head_num + 1) * head_dim + hi * head_dim + di * vec_size], data);
            }
        }
    }
    else if (type == 1) {  // K/V
        // a single head, so only the threadIdx.y == 0 row participates
        if (threadIdx.y == 0) {
            if (di * vec_size < head_dim) {
                Ldg(data, &kv_a_k_pe[ti * head_dim + di * vec_size + offset]);
                Store(&qkv[ti * (head_num + 1) * head_dim + (head_num + 0) * head_dim + di * vec_size], data);
            }
        }
    }
}

// Launches mla_copy_qkv_kernel using 16-byte vectors (vec_size = 16 / sizeof(T)).
// Grid: (token_num, 2), i.e. per token one Q block and one K/V block.
// Block: x covers one head's chunks, rounded up to a multiple of WARP_SIZE; y covers heads and
// is halved until the block has at most 1024 threads.
// NOTE(review): `head_dim / vec_size` truncates, so head_dim is assumed to be a multiple of
// vec_size. If block.x alone exceeded 1024, block.y would reach 0 and the launch would fail —
// TODO consider a check.
template
void invokeMLACopyQKV(T*           qkv,
                      const T*     q,
                      const T*     kv_a_k_pe,
                      int          token_num,
                      int          head_num,
                      int          kv_lora_rank,
                      int          rope_dim,
                      cudaStream_t stream)
{
    constexpr int vec_size = 16 / sizeof(T);

    const int head_dim = kv_lora_rank + rope_dim;  // 512 + 64 = 576

    dim3 block(round_up(head_dim / vec_size, WARP_SIZE), head_num);

    // make sure block size <= 1024
    while (block.x * block.y > 1024) {
        block.y /= 2;
    }

    const dim3 grid(token_num, 2);

    mla_copy_qkv_kernel
        <<>>(qkv, q, kv_a_k_pe, head_num, head_dim, kv_lora_rank, rope_dim);
}

// Type-erased entry point for invokeMLACopyQKV.
// Only 2-byte element types are accepted (checked at runtime). The kernel just moves values,
// so every 2-byte dtype is handled through a single uint16_t instantiation.
void MLACopyQKV(DataType     dtype,
                void*        qkv,
                const void*  q,
                const void*  kv_a_k_pe,
                int          token_num,
                int          head_num,
                int          kv_lora_rank,
                int          rope_dim,
                cudaStream_t stream)
{
    TM_CHECK_EQ(byte_size(dtype, 1), 2) << "unsupported data type: " << dtype;

    using Bits = uint16_t;
    invokeMLACopyQKV(static_cast<Bits*>(qkv),
                     static_cast<const Bits*>(q),
                     static_cast<const Bits*>(kv_a_k_pe),
                     token_num,
                     head_num,
                     kv_lora_rank,
                     rope_dim,
                     stream);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/mla_utils.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once

#include 

#include "src/turbomind/core/data_type.h"

namespace turbomind {

// Packs Q (`head_num` heads per token) and the per-token compressed K/V `kv_a_k_pe` into `qkv`,
// with head_dim = kv_lora_rank + rope_dim. Only 2-byte dtypes are supported (checked at
// runtime). The parameter name matches the definition in mla_utils.cu.
void MLACopyQKV(DataType     dtype,
                void*        qkv,
                const void*  q,
                const void*  kv_a_k_pe,
                int          token_num,
                int          head_num,
                int          kv_lora_rank,
                int          rope_dim,
                cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/moe_ffn_layer.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 

#include "src/turbomind/core/context.h"
#include "src/turbomind/kernels/activation.h"
#include "src/turbomind/kernels/norm/rms_norm.h"

#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/moe_ffn_layer.h"

#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/cuda_utils.h"

// #include "dbg.h"

namespace turbomind {

// Sets up the MoE FFN layer and allocates the routing buffers reused by Forward/Combine.
//
// Buffers are sized for the largest entry of `param.expert_num` (presumably the per-layer
// expert counts) and for `engine.max_forward_token_num * engine.attn_dp_size` tokens.
MoeFfnLayer::MoeFfnLayer(const ModelParam& model, const MoeParam& param, const EngineParam& engine, const Context& ctx):
    inter_size_(param.inter_size / engine.mlp_tp_size),  // this rank's share of the intermediate size
    hidden_dim_(model.hidden_units),
    tp_size_(engine.mlp_tp_size),
    param_(param),
    is_warm_up_{*ctx.is_warm_up},
    linear_(*ctx.linear)
{
    TM_CHECK(!param.expert_num.empty());

    const int max_expert_num = *std::max_element(param.expert_num.begin(), param.expert_num.end());

    if (param_.method == MoeParam::kFused) {
        // pass
    }
    else {
        // the naive method runs each expert through a separate dense FFN
        expert_ffn_ = std::make_unique(model, ctx);
    }

    // pinned host mirror of the expert offsets (expert_num + 1 prefix sums)
    h_offsets_ = {max_expert_num + 1, kCPUpinned};

    const int max_token_num = engine.max_forward_token_num * engine.attn_dp_size;
    // token count rounded up to a multiple of kMoeGateVecSize, as used for the `masks_` rows
    const int pad_token_num = (max_token_num + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;

    // dbg(inter_size_,
    //     hidden_dim_,
    //     tp_size_,
    //     param_.method,
    //     param.expert_num,
    //     max_expert_num,
    //     max_token_num,
    //     pad_token_num,
    //     param_.experts_per_token);

    // Per-(token, selected expert) buffers have experts_per_token * max_token_num entries.
    masks_   = {max_expert_num * pad_token_num, kDEVICE};
    f2n_     = {param_.experts_per_token * max_token_num, kDEVICE};
    f2E_     = {param_.experts_per_token * max_token_num, kDEVICE};
    en2f_    = {param_.experts_per_token * max_token_num, kDEVICE};
    scales_  = {param_.experts_per_token * max_token_num, kDEVICE};
    offsets_ = {max_expert_num + 1, kDEVICE};
    accum_   = {max_expert_num * kMoeGateMaxTiles, kDEVICE};
}

// Computes gate logits of shape (input.shape(0), gate.weight.shape(1)) by running `input`
// through `gate` with linear_.Forward, then adding `gate.bias` via ApplyBias.
// Used for both the routing gate and the optional shared-expert gate.
Tensor_ MoeFfnLayer::Gate(const Tensor& input, const LlamaDenseWeight& gate)
{
    auto& weight = gate.weight;
    TM_CHECK_EQ(input.shape(1), weight.shape(0));
    Tensor_ logits{{input.shape(0), weight.shape(1)}, kDEVICE};
    linear_.Forward(input, gate, logits);
    sync_check_cuda_error();
    ApplyBias(logits, gate.bias, core::Context::stream().handle());
    sync_check_cuda_error();
    return logits;
}

// Routes `p.input` to the experts and runs the expert FFNs. The per-(token, expert) results
// are left in `temp_` for Combine() to reduce into `p.output`.
//
// Steps: gate logits -> top-k routing (fills f2n_, f2E_, en2f_, offsets_, scales_) -> expert
// compute (naive per-expert loop or fused grouped linear) -> optional shared-expert gate.
void MoeFfnLayer::Forward(ForwardParam& p)
{
    const int   tokens = p.input.shape(0);
    const auto& moe    = *p.weights;

    // token count rounded up to a multiple of kMoeGateVecSize for the gate kernels
    const size_t padded     = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;
    const int    expert_num = moe.experts.size();

    FT_CHECK(expert_num);

    auto logits = Gate(p.input, moe.gate);

    TM_DEBUG_TENSOR(logits, "logits", 2);

    const auto st = core::Context::stream().handle();

    if (param_.topk_method == "noaux_tc") {
        // invokeMoeGate_NoAuxTC clears accum and masks internally
        TM_CHECK_EQ(param_.n_group, 1);
        TM_CHECK_EQ(param_.topk_group, 1);
        const float* correction_bias =
            (moe.score_correction_bias.size() > 0) ? moe.score_correction_bias.data() : nullptr;
        invokeMoeGate_NoAuxTC(f2n_.data(),
                              f2E_.data(),
                              en2f_.data(),
                              offsets_.data(),
                              scales_.data(),
                              masks_.data(),
                              accum_.data(),
                              logits.data(),
                              correction_bias,
                              tokens,
                              padded,
                              expert_num,
                              param_.experts_per_token,
                              param_.norm_topk_prob,
                              param_.routed_scale,
                              param_.scoring_func == "sigmoid",
                              st);
    }
    else {
        // V2: accum must be cleared by caller; masks cleared internally
        check_cuda_error(cudaMemsetAsync(accum_.data(), 0, sizeof(int) * expert_num * kMoeGateMaxTiles, st));

        bool softmax = true;
        if (param_.topk_method == "group_limited_greedy") {
            // softmax + group masking happen here, so the gate below must not apply softmax again
            invokeMoeSoftmaxMaskTopKGroups(
                logits.data(), tokens, expert_num, expert_num / param_.n_group, param_.topk_group, st);
            sync_check_cuda_error();
            softmax = false;
        }

        /// TODO: fix illegal memory access even if NaN are present in logits
        invokeMoeGate_V2(f2n_.data(),
                         f2E_.data(),
                         en2f_.data(),
                         offsets_.data(),
                         scales_.data(),
                         masks_.data(),
                         accum_.data(),
                         logits.data(),
                         tokens,
                         padded,
                         expert_num,
                         param_.experts_per_token,
                         softmax,
                         param_.norm_topk_prob,
                         param_.routed_scale,
                         st);
    }
    sync_check_cuda_error();

    if (is_warm_up_) {
        // During warm-up the expert offsets are overwritten with counts from a uniformly
        // sampled assignment (fixed-seed generator), so every expert gets some rows
        // (presumably to exercise every expert GEMM shape — TODO confirm).
        std::mt19937     g;
        const auto       expert_ids = SampleUniform(tokens, expert_num, param_.experts_per_token, g);
        std::vector<int> cnt(expert_num);
        for (const auto& x : expert_ids) {
            ++cnt[x];
        }
        h_offsets_[0] = 0;
        for (int i = 0; i < expert_num; ++i) {
            h_offsets_[i + 1] = h_offsets_[i] + cnt[i];
        }
        check_cuda_error(
            cudaMemcpyAsync(offsets_.data(), h_offsets_.data(), sizeof(int) * (expert_num + 1), cudaMemcpyDefault, st));
    }

    // one row per (token, selected expert) pair
    temp_ = Tensor{{param_.experts_per_token * tokens, hidden_dim_}, p.input.dtype(), p.input.device()};

    if (param_.method == MoeParam::kNaive) {

        invokeMoeDispatch(temp_, p.input, f2n_.data(), param_.experts_per_token, st);
        sync_check_cuda_error();

        // The host needs the per-expert row ranges to drive the loop below.
        check_cuda_error(
            cudaMemcpyAsync(h_offsets_.data(), offsets_.data(), sizeof(int) * (expert_num + 1), cudaMemcpyDefault, st));

        check_cuda_error(cudaStreamSynchronize(st));

        TM_CHECK_EQ(h_offsets_[expert_num], tokens * param_.experts_per_token);

        for (int i = 0; i < expert_num; ++i) {
            if (int count = h_offsets_[i + 1] - h_offsets_[i]) {
                auto io = temp_.slice({h_offsets_[i], 0}, {count, -1});
                expert_ffn_->forward({io, io, moe.experts.at(i).get(), p.layer_id});
            }
        }
    }
    else {

        auto& block = moe.block;

        // Views restricted to this layer's routed rows and experts.
        auto indices = f2n_.slice(0, tokens * param_.experts_per_token);
        auto offsets = offsets_.slice(0, expert_num + 1);

        // Pass the `expert_num + 1` view here too, as the output projection below does. The whole
        // `offsets_` buffer is sized for the largest layer, and its tail past `expert_num + 1` is
        // presumably not written for a layer with fewer experts.
        Tensor inter = linear_.Forward(p.input, block.fused_gating_intermediate, indices, offsets);
        sync_check_cuda_error();

        if (!block.is_fused_silu) {
            Activation(inter, block.fused_gating_intermediate.bias, f2E_, moe.block.act_type, st);
            sync_check_cuda_error();
        }

        linear_.Forward(inter.slice({0, 0}, {-1, inter_size_}), block.output, {}, offsets, temp_);
        sync_check_cuda_error();
    }

    if (moe.shared_gate.weight) {
        shared_scales_ = Gate(p.input, moe.shared_gate);
    }
}

// Reduces the per-(token, expert) rows that Forward() left in `temp_` into `p.output`, using
// the routing scales, the output bias and (if present) the shared-expert gate. Then releases
// the per-forward tensors `temp_` and `shared_scales_`.
void MoeFfnLayer::Combine(ForwardParam& p)
{
    const auto& moe = *p.weights;

    invokeMoeCombine(p.output,
                     temp_,
                     moe.block.output.bias,
                     scales_.data(),
                     en2f_.data(),
                     f2E_.data(),
                     shared_scales_.data_or((float*)nullptr),  // null when there is no shared gate
                     param_.experts_per_token,
                     1.f / tp_size_,
                     p.scale,
                     core::Context::stream().handle());
    sync_check_cuda_error();

    temp_          = {};
    shared_scales_ = {};
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/moe_ffn_layer.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm/context.h"
#include "src/turbomind/kernels/gemm/moe_utils_v2.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Mixture-of-experts FFN layer. Forward() gates and routes tokens and computes expert outputs
// into an internal buffer; Combine() reduces them into the output tensor.
class MoeFfnLayer {
public:
    MoeFfnLayer(const ModelParam& model, const MoeParam& param, const EngineParam& engine, const Context& ctx);

    struct ForwardParam {
        Tensor              input;
        Tensor              output;
        const MoeFfnWeight* weights;
        float               scale;
        int                 layer_id;
    };

    void Forward(ForwardParam& p);

    void Combine(ForwardParam& p);

private:
    Tensor_ Gate(const Tensor& input, const LlamaDenseWeight& gate);

    // NOTE(review): no definition appears in moe_ffn_layer.cc, and the only reference there is a
    // commented-out two-argument call — possibly dead; confirm before using.
    void dump_logits(int token_num, int layer_id, int expert_num);

    const int inter_size_;  // intermediate size per MLP-TP rank
    const int hidden_dim_;
    const int tp_size_;

    const MoeParam param_;

    int& is_warm_up_;  // bound to the context's warm-up flag

    LlamaLinear& linear_;

    std::unique_ptr expert_ffn_;  // only created for non-fused methods

    ///////////////////////////////////////////////////////
    /// runtime states
    Buffer_ h_offsets_;  // pinned host copy of offsets_

    Buffer_   masks_;
    Buffer_   f2n_;
    Buffer_   f2E_;
    Buffer_   en2f_;
    Buffer_ scales_;
    Buffer_   accum_;
    Buffer_   offsets_;  // per-expert row offsets (expert_num + 1 entries used per layer)

    Tensor         temp_;           // expert outputs between Forward() and Combine()
    Tensor_ shared_scales_;  // shared-expert gate output, empty if none
    ///////////////////////////////////////////////////////
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/test_cache_manager.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include "BlockManager.h"
#include "SequenceManager.h"

#include "src/turbomind/utils/allocator.h"

#include "src/turbomind/utils/debug_utils.h"
#include 
#include 

using namespace turbomind;

// Debug printer used by dbg(): writes a block as "(id,timestamp)".
std::ostream& operator<<(std::ostream& os, const Block* b)
{
    return os << "(" << b->id << "," << b->timestamp << ")";
}

// Walks a 32-block BlockManager through allocate -> unlock (cache) -> evict -> allocate and
// checks the active/free/cached counts after every step.
TEST_CASE("BlockManager")
{
    Allocator allocator(0);

    BlockManager m(1024, 32, 8, &allocator);
    REQUIRE(m.max_block_count() == 32);
    REQUIRE(m.free_count() == 32);

    auto blocks1 = m.Allocate(10);

    dbg(blocks1);

    REQUIRE(blocks1.size() == 10);
    REQUIRE(m.active_count() == blocks1.size());
    REQUIRE(m.free_count() == 22);

    auto blocks2 = m.Allocate(6);
    REQUIRE(blocks2.size() == 6);
    REQUIRE(m.active_count() == blocks1.size() + blocks2.size());
    REQUIRE(m.free_count() == 16);

    // 10 + 6 + 16 = 32: the pool is exhausted
    auto blocks3 = m.Allocate(16);
    REQUIRE(blocks3.size() == 16);
    REQUIRE(m.active_count() == 32);
    REQUIRE(m.free_count() == 0);

    // gather all 32 blocks into blocks1
    std::copy(blocks3.begin(), blocks3.end(), std::back_inserter(blocks1));
    std::copy(blocks2.begin(), blocks2.end(), std::back_inserter(blocks1));

    m.Touch(blocks1);

    // unlocking moves every active block to the cached state (not free)
    REQUIRE(m.Unlock(blocks1) == 32);
    REQUIRE(m.active_count() == 0);
    REQUIRE(m.free_count() == 0);
    REQUIRE(m.cached_count() == 32);

    // eviction turns cached blocks into free ones
    m.Evict(16);
    REQUIRE(m.active_count() == 0);
    REQUIRE(m.free_count() == 16);
    REQUIRE(m.cached_count() == 16);

    // allocation is served from free blocks; the cached ones are left alone
    auto blocks4 = m.Allocate(14);
    REQUIRE(m.active_count() == 14);
    REQUIRE(m.free_count() == 2);
    REQUIRE(m.cached_count() == 16);
}

// Basic SequenceManager flow: create/erase/recreate a sequence, then materialize one and two
// sequences and check the allocation / swap counts.
// NOTE(review): the meaning of the constructor and Materialize arguments is not visible here;
// see SequenceManager.h. The final Materialize result is only printed, not asserted.
TEST_CASE("SequenceManager basic test")
{
    Allocator allocator(0);

    SequenceManager manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);

    REQUIRE(manager.max_block_count() == 20);
    REQUIRE(manager.Contains(1) == false);

    auto s1 = manager.Create(1);
    dbg(*s1);
    REQUIRE(manager.Contains(1) == true);

    manager.Erase(1);
    REQUIRE(manager.Contains(1) == false);

    s1 = manager.Create(1);
    REQUIRE(manager.Contains(1) == true);

    auto outcome = manager.Materialize({s1}, {128}, {100}, 1);
    dbg(s1->blocks);
    REQUIRE(s1->blocks.size() == 2);

    auto s2 = manager.Create(2);
    REQUIRE(manager.Contains(2));

    outcome = manager.Materialize({s1, s2}, {128, 2559}, {2, 1}, 1);
    dbg(outcome);
    REQUIRE(outcome.allocation == 20);
    REQUIRE(outcome.swap_in == 1);
    REQUIRE(outcome.swap_out == 1);

    auto s3 = manager.Create(3);
    outcome = manager.Materialize({s1, s2, s3}, {127, 2559, 255}, {1, 100, 2}, 1);
    dbg(outcome);
}

// Materializes a single sequence for i = 0..1023 (passed as the second Materialize argument)
// and prints the outcome whenever it reports an allocation. Smoke test only: no assertions.
TEST_CASE("SequenceManager functional test")
{
    Allocator allocator(0);
    SequenceManager                manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);

    auto seq = manager.Create(1);
    for (int i = 0; i < 1024; ++i) {
        auto outcome = manager.Materialize({seq}, {i}, {0}, 1);
        if (outcome.allocation) {
            dbg(i, outcome);
        }
    }
}


================================================
FILE: src/turbomind/models/llama/unified_attention_layer.cc
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc

#include 
#include 
#include 
#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/core.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/tensor.h"
#include "src/turbomind/engine/request.h"

#include "src/turbomind/kernels/attention/attention.h"
#include "src/turbomind/kernels/attention/decoding.h"
#include "src/turbomind/kernels/attention/kv_cache_utils_v2.h"
#include "src/turbomind/kernels/norm/rms_norm.h"

#include "src/turbomind/macro.h"

#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_rope.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/mla_utils.h"
#include "src/turbomind/models/llama/unified_attention_layer.h"

#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"

// #include "dbg.h"

namespace turbomind {

// Per-phase state of UnifiedAttentionLayer, filled by Setup() and by Run(kPrepare).
struct AttentionData {
    // Batch statistics computed in Setup(). The batch is split at the first request with
    // input_len > 1: requests before it count as `decode`, the rest as `prefill`.
    //   n     : number of requests
    //   q_sum : total input_len;                         q_max : largest input_len
    //   k_sum : total history_len + alpha + input_len;   k_max : largest such value
    struct Stat {
        int n;
        int q_sum;
        int q_max;
        int k_sum;
        int k_max;
    } decode, prefill;

    // KV cache block pointers and their per-request offsets (bsz + 1 entries), uploaded in Setup()
    Buffer_ block_ptrs;
    Buffer_   block_ptrs_offsets;

    // per-request RoPE base (dynamic NTK only)
    Buffer_ rope_base;

    // M-RoPE inputs (kMrope only)
    Tensor_ mrope_position_ids;
    Buffer_ mrope_position_delta;
    Buffer_ mrope_length;

    // borrowed from env
    Buffer_ finished;
    Buffer_  q_offsets;
    Buffer_  k_offsets;

    // int dbg_offset;
    // int dbg_size;
};

UnifiedAttentionLayer::~UnifiedAttentionLayer()
{
    // Release the events and the auxiliary stream created in the constructor, then reset the
    // handles so that they no longer refer to destroyed objects.
    check_cuda_error(cudaEventDestroy(aux_event_));
    check_cuda_error(cudaEventDestroy(qkv_event_));
    check_cuda_error(cudaStreamDestroy(aux_stream_));

    aux_event_  = {};
    qkv_event_  = {};
    aux_stream_ = {};
}

// Builds the attention layer for this rank: per-rank head counts, the auxiliary CUDA
// stream/events, the RoPE kernel parameters, a layer -> KV-cache-slot map, split-KV workspaces
// and one AttentionData per phase.
//   phases : number of AttentionData instances (one per pipeline phase)
//   init   : when true, also allocates `tmp_attn_` (max_forward_token_num x local heads * head_dim)
UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam&     model,
                                             const AttentionParam& attn,
                                             const EngineParam&    engine,
                                             int                   tp_size,
                                             const Context&        ctx,
                                             int                   phases,
                                             bool                  init):
    head_num_(model.head_num),
    kv_head_num_(model.kv_head_num),
    size_per_head_(model.head_dim),
    hidden_units_(model.hidden_units),
    local_head_num_(head_num_ / tp_size),
    local_kv_head_num_(model.kv_head_num / tp_size),
    param_(attn),
    model_param_(model),
    engine_param_(engine),
    cp_fn_ctx_(ctx.comm.d_comm, ctx.comm.d_cp_group),
    is_warm_up_{*ctx.is_warm_up},
    context_(ctx),
    linear_(*ctx.linear),
    arch_(getSMVersion())
{
    TM_CHECK_EQ(head_num_ % tp_size, 0) << head_num_ << " " << tp_size;
    TM_CHECK_EQ(head_num_ % kv_head_num_, 0) << head_num_ << " " << kv_head_num_;
    // NOTE(review): kv_head_num is not checked for divisibility by tp_size; local_kv_head_num_
    // truncates (and is 0 when kv_head_num < tp_size) — confirm this is handled elsewhere.

    check_cuda_error(cudaStreamCreateWithFlags(&aux_stream_, cudaStreamNonBlocking));
    check_cuda_error(cudaEventCreateWithFlags(&qkv_event_, cudaEventDisableTiming));
    check_cuda_error(cudaEventCreateWithFlags(&aux_event_, cudaEventDisableTiming));

    init_rope_kernel_param(param_.rope, rope_param_);

    // Skip other attention layer types
    // Layers of type 0 get consecutive KV-cache ids starting at 0; all other layers keep -1.
    std::vector layer_types = model_param_.layer_types;
    layer_types.resize(model_param_.layer_num);
    cache_layer_ids_.resize(layer_types.size(), -1);
    int next_cache_id = 0;
    for (size_t i = 0; i < layer_types.size(); ++i) {
        if (layer_types[i] == 0) {
            cache_layer_ids_[i] = next_cache_id++;
        }
    }

    // With context parallelism the ML workspace comes from the communicator's symmetric
    // allocator and gets room for an extra max_forward_token_num tokens.
    Allocator alloc            = core::Context::device_alloc();
    ssize_t   workspace_tokens = kMaxWorkspaceTokens;
    if (engine_param_.attn_cp_size > 1) {
        alloc = GetSymmAllocator(ctx.comm.d_comm);
        workspace_tokens += engine_param_.max_forward_token_num;
    }
    // partial_O layout:
    //   w/  cp, decode(q, h, k, 2) + prefill(q, h, 1, 2)
    //   w/o cp, decode(q, h, k, 2)
    partial_O_  = Tensor_({workspace_tokens, local_head_num_, size_per_head_}, kDEVICE);
    partial_ML_ = Tensor_({engine_param_.attn_cp_size, workspace_tokens, local_head_num_, 2}, alloc);
    split_cnt_  = Tensor_({workspace_tokens}, kDEVICE);
    if (init) {
        const int dim = (int)local_head_num_ * (int)size_per_head_;
        tmp_attn_     = Tensor{{engine_param_.max_forward_token_num, dim}, model.data_type, kDEVICE};
    }

    Clear(split_cnt_.buffer());

    const int bsz = engine.max_batch_size;

    // Pinned host staging buffers for the RoPE variants that need per-request data.
    if (rope_param_.type == RopeType::kDynamic) {
        rope_base_buf_ = {bsz + 1, kCPUpinned};
    }
    else if (rope_param_.type == RopeType::kMrope) {
        // `mrope_position_ids` is not buffered
        mrope_position_delta_buf_ = {bsz, kCPUpinned};
        mrope_length_buf_         = {bsz, kCPUpinned};
    }
    // Upper bound on KV blocks: every request at session_len. The +16 on block_ptrs below is
    // extra slack (reason not visible here).
    const int max_blocks = bsz * cdiv(engine.session_len, param_.cache_block_seq_len);
    for (int i = 0; i < phases; ++i) {
        auto& d               = data_.emplace_back(std::make_shared());
        d->block_ptrs         = {max_blocks + 16, kDEVICE};
        d->block_ptrs_offsets = {bsz + 1, kDEVICE};
        if (rope_param_.type == RopeType::kDynamic) {
            d->rope_base = empty_like(rope_base_buf_, kDEVICE);
        }
        else if (rope_param_.type == RopeType::kMrope) {
            /// TODO: total space for `mrope_position_ids` can be reduced to (max_fwd_tokens, 3)
            d->mrope_position_ids    = {{bsz, engine.session_len, 3}, kDEVICE};
            d->mrope_position_delta  = empty_like(mrope_position_delta_buf_, kDEVICE);
            d->mrope_length          = empty_like(mrope_length_buf_, kDEVICE);
            rope_param_.mrope.stride = d->mrope_position_ids.stride(0);
        }
    }
}

// Dynamic-NTK RoPE: sets the request's RoPE base from its prompt length.
// The base is `rope.base` unless rope.factor > 1 and prompt_len > max_position_embeddings,
// in which case
//   s     = factor * prompt_len / max_pos - (factor - 1)
//   base' = base * s ^ (dim / (dim - 2))
static void init_dynamic_ntk(RequestCache& cache, const RopeParam& rope)
{
    cache.rope_base = rope.base;

    auto       scaling_factor = rope.factor;
    const auto max_seq_len    = cache.prompt_len;
    const auto max_pos_emb    = rope.max_position_embeddings;

    if (scaling_factor > 1.f && max_seq_len > max_pos_emb) {
        scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
        cache.rope_base *= powf(scaling_factor, rope.dim / (rope.dim - 2.f));
        // clang-format off
        TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
                    (long)cache.req->id, scaling_factor, cache.rope_base);
        // clang-format on
    }
}

// Batch-lifecycle hook:
//   kAdd     : for dynamic-NTK RoPE, computes the RoPE base of each newly added request;
//   kSetup   : uploads per-phase batch data (see Setup);
//   kPrepare : borrows the finished / q_offsets / k_offsets buffers from `env` and, when
//              `tmp_attn_` exists, zeroes its rows for the current batch plus `split_cnt_`.
void UnifiedAttentionLayer::Run(BatchOp op, int phase, TensorMap& env)
{
    if (op == BatchOp::kAdd) {
        Buffer_ rc = env.at("requests").buffer();
        if (rope_param_.type == RopeType::kDynamic) {
            for (int i = 0; i < rc.size(); ++i) {
                init_dynamic_ntk(*rc[i], param_.rope);
            }
        }
    }
    else if (op == BatchOp::kSetup) {
        Setup(phase, env);
    }
    else if (op == BatchOp::kPrepare) {
        data_.at(phase)->finished  = env.at("finished").buffer().borrow();
        data_.at(phase)->q_offsets = env.at("q_offsets").buffer().borrow();
        data_.at(phase)->k_offsets = env.at("k_offsets").buffer().borrow();

        // This is needed in async mode to clear the `attn` buffer for the finished sequences. Otherwise random NaNs
        // will crash the MoE router later
        /// TODO: use a better solution; this increases memory usage, and heterogeneous attention layers may still break it
        if (tmp_attn_) {
            auto& d = data_.at(phase);
            // row count: decode requests (presumably one token each) + all prefill tokens
            Clear(tmp_attn_.slice(0, d->decode.n + d->prefill.q_sum));
            Clear(split_cnt_);
        }
    }
}

/// Prepare per-phase attention state for the batch in `env`.
///
/// Uploads the KV-cache block pointer table and its offsets through the `copy` helper found in `env`.
/// Computes query/key statistics separately for the decode and prefill parts of the batch, and stages the
/// RoPE side data (dynamic NTK bases or mrope position ids) into the per-phase `AttentionData`.
void UnifiedAttentionLayer::Setup(int phase, TensorMap& env)
{
    const auto& rc  = env.at("batch").data()[0]->rc;
    const int   bsz = rc.size();

    auto& d    = *data_.at(phase);
    // NOTE(review): `copy` appears to be a batched copy helper owned by the engine. The order of the calls
    // below is preserved deliberately; confirm whether copies are deferred.
    auto& copy = *env.at("copy").data()[0];

    {  /// Upload KV cache ptrs
        // `offsets` has bsz + 1 entries; its last entry is used as the total number of block pointers
        const Buffer_ offsets = env.at("block_ptrs_offsets").buffer();
        copy(env.at("block_ptrs").buffer(), offsets[bsz], d.block_ptrs);
        copy(offsets, bsz + 1, d.block_ptrs_offsets);
    }

    /// prepare Q/K stats for decode/prefill
    d.decode = d.prefill = {};

    // Decode requests are expected to come first in `rc`: the decode count is the length of the leading run
    // of requests whose `input_len` is <= 1; everything after it is treated as prefill.
    d.decode.n  = std::find_if(rc.begin(), rc.end(), [](auto r) { return r->input_len > 1; }) - rc.begin();
    d.prefill.n = bsz - d.decode.n;

    // d.dbg_offset = d.dbg_size = 0;

    for (int i = 0; i < bsz; ++i) {
        const auto& c = *rc[i];

        // if (c.request->id == 4 && c.input_len > 1) {
        //     d.dbg_offset = d.decode.q_sum + d.prefill.q_sum;
        //     d.dbg_size   = c.input_len;
        // }

        // Key length of a request is history_len + alpha + input_len.
        // NOTE(review): the meaning of `alpha` is defined by RequestCache; confirm there.
        auto& s = i < d.decode.n ? d.decode : d.prefill;
        s.q_sum += c.input_len;
        s.k_sum += c.history_len + c.alpha + c.input_len;
        s.q_max = std::max(s.q_max, c.input_len);
        s.k_max = std::max(s.k_max, c.history_len + c.alpha + c.input_len);
    }

    // auto &D = d.decode, &P = d.prefill;
    // dbg(D.n, D.k_sum, D.k_max, P.n, P.q_sum, P.q_max, P.k_sum, P.k_max);

    /// handling different RoPE types
    if (rope_param_.type == RopeType::kDynamic) {
        // per-request RoPE bases (set by init_dynamic_ntk on kAdd) are gathered and uploaded
        for (int i = 0; i < bsz; ++i) {
            rope_base_buf_[i] = rc[i]->rope_base;
        }
        copy(rope_base_buf_, bsz, d.rope_base);
    }
    else if (rope_param_.type == RopeType::kMrope) {
        const auto stride = d.mrope_position_ids.stride(0);
        for (int i = 0; i < rc.size(); ++i) {
            auto& c = *rc[i];
            auto& r = *c.req;
            if (auto pos_ids = r.inputs.try_("mrope_position_ids")) {
                // position ids hold 3 values per token; only the span overlapping this step's tokens
                // [history_len + alpha, history_len + alpha + input_len) is uploaded into row `i`
                int length                   = pos_ids->shape(0);
                mrope_length_buf_[i]         = length;
                mrope_position_delta_buf_[i] = *r.inputs.at("mrope_position_delta").data();
                if (auto o = Interval{0, length} & Interval{c.history_len + c.alpha, Interval::Size{c.input_len}}) {
                    copy(pos_ids->data() + o.begin() * 3,
                         (int)o.size() * 3,
                         d.mrope_position_ids.data() + i * stride + o.begin() * 3);
                }
            }
            else {
                // requests without mrope inputs get zero length / zero delta
                mrope_length_buf_[i] = mrope_position_delta_buf_[i] = 0;
            }
        }
        copy(mrope_length_buf_, rc.size(), d.mrope_length);
        copy(mrope_position_delta_buf_, rc.size(), d.mrope_position_delta);
    }
}

/// Attention block forward pass for one layer, in four steps:
///  1. QKV projection: the fused `qkv` weight plus optional QK-norm, or the MLA projection path when no
///     fused weight exists.
///  2. The attention core.
///  3. Optional sigmoid output gating.
///  4. The output projection, written into `p.output`.
/// Returns immediately when `p.input` has no tokens.
void UnifiedAttentionLayer::Forward(ForwardParam p)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    /////////////////////////////////////////////
    /// parse inputs
    const int token_num = p.input.shape(0);

    if (token_num == 0) {
        return;
    }

    const int layer_id = p.layer_id;

    const auto& weights = *p.weights;

    Tensor qkv;

    // NOTE: only referenced by the commented-out debug hooks below; `.at()` also bounds-checks `p.phase`
    auto& d = *data_.at(p.phase);

    // if (d.dbg_size) {
    //     DebugTensor(p.input.slice(d.dbg_offset, d.dbg_size), Concat("attn_in", p.layer_id), 0);
    // }

    if (weights.qkv.output_dim) {
        // [token_num, hidden_dim] -> [token_num, local_q_kv_head_num, head_dim]
        qkv = linear_.Forward(p.input, weights.qkv);
        sync_check_cuda_error();

        if (model_param_.qk_norm) {
            qk_norm(qkv, weights);
        }
    }
    else {
        // no fused QKV weight: build the packed tensor through the MLA projections
        qkv = forward_mla(p.input, weights);
    }

    TM_DEBUG_TENSOR(qkv, Concat("qkv", layer_id), 3);

    // the typed attention core is selected from the dtype of `qkv`
    auto invoke = [&](auto t) -> Tensor {
        using T = decltype(t);
        return core_attention(qkv, p, weights);
    };

    Tensor attn = [&]() -> Tensor { TM_DISPATCH_PRIMARY_DTYPES_RET(qkv.dtype(), invoke); }();

    // Apply sigmoid gating: attn *= sigmoid(gate)
    // Gate is stored at the end of each token's QKV: [Q|K|V|Gate]
    // Row width is (2H + 2KV) * D, so the gate occupies the last H * D elements of each row.
    if (model_param_.attn_output_gate) {
        const int  q_count     = qkv.shape(0);
        const int  attn_dim    = local_head_num_ * size_per_head_;
        const int  gate_offset = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
        const int  qkv_stride  = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
        const auto stream      = core::Context::stream().handle();
        invokeSigmoidGateMultiply(attn.raw_data(),
                                  (const char*)qkv.raw_data() + gate_offset * byte_size(qkv.dtype(), 1),
                                  attn_dim,
                                  qkv_stride,
                                  q_count,
                                  qkv.dtype(),
                                  stream);
        sync_check_cuda_error();
    }

    TM_DEBUG_TENSOR(attn, Concat("attn", layer_id), 3);

    // if (d.dbg_size) {
    //     DebugTensor(attn.slice(d.dbg_offset, d.dbg_size), Concat("attn_out", p.layer_id), 0);
    // }

    //////////////////////////////////////////////
    /// output gemm: attn -> p.output
    (void)linear_.Forward(attn, weights.output, p.output);
    sync_check_cuda_error();
}

/// Attention core for the current phase: KV processing plus prefill/decode attention kernels.
///
/// `qkv` holds the fused per-token projections produced by `Forward`. Decode requests occupy the first
/// `d.decode.n` batch slots and prefill requests the remaining ones. When both are present, the prefill
/// kernels are issued on `aux_stream_` so they can overlap with decoding on the current stream. The two
/// streams are joined with `aux_event_` before returning.
///
/// During warm-up no attention kernels are launched; the output is filled with uniform random values
/// instead.
///
/// Returns a [q_count, local_head_num_ * size_per_head_] tensor. It is a slice of `tmp_attn_` when that
/// buffer exists, otherwise a fresh allocation.
template
Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights)
{
    const auto device = qkv.device();
    const auto dtype  = qkv.dtype();

    auto& d = *data_.at(p.phase);

    const int q_count = qkv.shape(0);

    // the token count must equal the prefill query tokens plus one query token per decode request
    TM_CHECK_EQ(d.prefill.q_sum + d.decode.n, q_count);

    Tensor attn;
    if (tmp_attn_) {
        attn = tmp_attn_.slice(0, q_count);
    }
    else {
        attn = {{q_count, (int)local_head_num_ * (int)size_per_head_}, dtype, device};
    }

    const bool is_mla = model_param_.mla.kv_lora_rank > 0;

    // Scratch buffer for the flattened K/V of prefill requests (K and V share one slot for MLA).
    // NOTE(review): the extra MAX_CTA_S tokens presumably pad for tile over-reads; confirm.
    Tensor tmp_kv{
        {(int)local_kv_head_num_, is_mla ? 1 : 2, d.prefill.k_sum + MAX_CTA_S, (int)size_per_head_}, dtype, device};

    const int cache_layer_id = cache_layer_ids_[p.layer_id];

    // Builds kernel parameters for the sub-batch starting at batch slot `offset` with statistics `stat`.
    auto CreateParams = [&](int offset, AttentionData::Stat stat, int max_kv_splits, cudaStream_t stream) {
        AttentionParams params{};

        // Batch offset for `out` and `q` are computed inside the kernel
        params.out = (T*)attn.raw_data();

        params.q = (T*)qkv.raw_data();
        params.k = params.q + local_head_num_ * size_per_head_;
        if (is_mla) {
            params.v      = params.k;
            params.stride = (local_head_num_ + 1 * local_kv_head_num_) * size_per_head_;
        }
        else {
            params.v = params.k + local_kv_head_num_ * size_per_head_;
            // When attn_output_gate, QKV layout is [Q|K|V|Gate] per token
            // stride must account for the extra gate portion at the end
            if (model_param_.attn_output_gate) {
                params.stride = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
            }
            else {
                params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
            }
        }

        if (weights.qkv.bias) {
            params.q_bias = (T*)weights.qkv.bias.data_or(nullptr);
            params.k_bias = params.q_bias + local_head_num_ * size_per_head_;
            params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;
        }

        params.batch_size = stat.n;

        params.token_num = stat.q_sum;
        params.max_q_len = stat.q_max;
        params.max_k_len = stat.k_max;

        // decode only
        params.block_iter_params = BlockIteratorParams{(char**)d.block_ptrs.data(),  //
                                                       d.block_ptrs_offsets.data() + offset,
                                                       cache_layer_id,
                                                       (int)param_.cache_block_seq_len};

        // prefill only
        if (is_mla) {
            params.linear_iter_params = LinearIteratorParams{
                tmp_kv.raw_data(),            // flattened KV
                stat.k_sum * size_per_head_,  // stride to next head
                0                             // stride from K to V
            };
        }
        else {
            params.linear_iter_params = LinearIteratorParams{
                tmp_kv.raw_data(),                // flattened KV
                stat.k_sum * size_per_head_ * 2,  // stride to next head
                stat.k_sum * size_per_head_       // stride from K to V
            };
        }

        params.finished = d.finished.data() + offset;
        params.cu_q_len = d.q_offsets.data() + offset;
        params.cu_k_len = d.k_offsets.data() + offset;

        params.num_heads     = local_head_num_;
        params.num_kv_heads  = local_kv_head_num_;
        params.size_per_head = size_per_head_;
        params.layer_id      = cache_layer_id;

        double scaling = 1.;
        if (param_.softmax_scale) {  // model predefined softmax scale
            scaling *= param_.softmax_scale;
        }
        else {  // default value
            scaling /= std::sqrt((float)params.size_per_head);
        }
        // softmax scale pre-multiplied by log2(e)
        params.inv_sqrt_dh = scaling * std::log2(std::exp(1.));

        params.sinks       = weights.sinks.data_or((T*)nullptr);
        params.scale_sinks = scaling;

        params.window_size = weights.window_size;
        if (!params.window_size) {
            params.window_size = 256 << 20;  // 256 M
        }

        params.rope_param = rope_param_;
        if (rope_param_.type == RopeType::kDynamic) {
            params.rope_param.base = d.rope_base.data() + offset;
        }
        else if (rope_param_.type == RopeType::kMrope) {
            params.rope_param.mrope.position_ids   = d.mrope_position_ids.data() + offset * rope_param_.mrope.stride;
            params.rope_param.mrope.position_delta = d.mrope_position_delta.data() + offset;
            params.rope_param.mrope.length         = d.mrope_length.data() + offset;
        }

        // logn attn
        params.use_logn_attn           = param_.use_logn_attn;
        params.max_position_embeddings = param_.max_position_embeddings;

        // Decoding use only for now
        params.split_cnt   = split_cnt_.data();
        params.partial_ML  = partial_ML_.data();
        params.partial_O   = partial_O_.data();
        params.max_split_k = std::min(std::max(1, kMaxWorkspaceTokens / params.token_num), max_kv_splits);

        // context parallel
        params.cp_rank = engine_param_.attn_cp_rank;
        params.cp_size = engine_param_.attn_cp_size;
        if (params.cp_size > 1) {
            params.cp_size = cutlass::FastDivmod(params.cp_size);

            // update ML,O offset if both prefill and decode present
            const int offset_ML_stage =
                engine_param_.attn_cp_size * (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0);
            const int offset_ML_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2;
            const int offset_O       = offset ? kMaxWorkspaceTokens * local_head_num_ * size_per_head_ : 0;

            params.partial_ML = partial_ML_.data() + offset_ML_stage + offset_ML_rank;
            params.partial_O  = partial_O_.data() + offset_O;
            params.offset_q   = offset;

            // postprocess func
            params.cp_fn          = CpPost;
            params.cp_fn_ctx      = (void*)&cp_fn_ctx_;
            cp_fn_ctx_.cp_rank    = params.cp_rank;
            cp_fn_ctx_.count      = params.token_num * local_head_num_ * params.max_split_k * 2;
            cp_fn_ctx_.partial_ML = partial_ML_.data() + offset_ML_stage;
            cp_fn_ctx_.stream     = stream;
        }

        params.arch   = arch_;
        params.stream = stream;

        params.quant_policy = model_param_.quant_policy;
        return params;
    };

    const cudaStream_t stream = core::Context::stream().handle();

    cudaStream_t pf_stream = stream;
    cudaStream_t dc_stream = pf_stream;

    // fork: with a mixed batch, prefill runs on the aux stream after QKV is ready on the main stream
    if (d.decode.n && d.prefill.n) {
        pf_stream = aux_stream_;
        check_cuda_error(cudaEventRecord(qkv_event_, stream));
        check_cuda_error(cudaStreamWaitEvent(aux_stream_, qkv_event_));
    }

    if (d.prefill.n && !is_warm_up_) {
        const int offset = d.decode.n;
        // We are executing prefill & decoding kernels concurrently, but only have 1 workspace
        // disable split kv for prefill for now
        auto params = CreateParams(offset, d.prefill, 1, pf_stream);
        if constexpr (sizeof(T) == 2) {
            invokeProcessKV_v2_(params);
            sync_check_cuda_error();

            /// TODO: skip flattening for `sm_80`
            invokeFlattenKV_v2_(params, d.prefill.k_sum);
            sync_check_cuda_error();

            dispatchAttention(params);
            sync_check_cuda_error();
        }
    }

    if (d.decode.n && !is_warm_up_) {
        auto params = CreateParams(0, d.decode, kMaxKVSplits, dc_stream);
        if constexpr (sizeof(T) == 2) {
            dispatchDecoding(params);
            sync_check_cuda_error();
        }
    }

    // join: the main stream waits for the prefill work issued on the aux stream
    if (d.decode.n && d.prefill.n) {
        check_cuda_error(cudaEventRecord(aux_event_, aux_stream_));
        check_cuda_error(cudaStreamWaitEvent(stream, aux_event_));
    }

    if (is_warm_up_) {
        rng_.set_stream(stream);
        rng_.GenerateUniform(attn.data(), attn.size(), .02f, -.01f);
    }

    return attn;
}

/// Build the packed MLA "qkv" tensor from `hidden_state`.
///
/// Q is produced either by a single `q_proj` or by the low-rank path q_a_proj -> RMSNorm -> q_b_proj.
/// `kv_a_proj` yields the compressed KV latent (its first `kv_lora_rank` columns, RMS-normalized in place)
/// followed by `qk_rope_dim` rope columns. `MLACopyQKV` then packs Q and the KV part into a tensor of shape
/// [token_num, local_head_num_ + local_kv_head_num_, size_per_head_].
Tensor UnifiedAttentionLayer::forward_mla(const Tensor& hidden_state, const WeightType& w)
{
    const auto token_num = hidden_state.shape(0);
    const auto dtype     = hidden_state.dtype();

    const int kv_lora_rank = w.kv_a_layernorm.size();
    const int qk_rope_dim  = w.kv_a_proj.output_dim - kv_lora_rank;

    Tensor q;

    const auto stream = core::Context::stream().handle();

    if (w.q_proj.weight) {
        q = linear_.Forward(hidden_state, w.q_proj);
        sync_check_cuda_error();
    }
    else {
        // low-rank Q: down-project, normalize in place, up-project
        Tensor q_a = linear_.Forward(hidden_state, w.q_a_proj);
        sync_check_cuda_error();

        invokeRMSNorm(q_a, q_a, w.q_a_layernorm, model_param_.norm_eps, stream);
        sync_check_cuda_error();

        q = linear_.Forward(q_a, w.q_b_proj);
        sync_check_cuda_error();
    }

    Tensor kv_a_k_pe = linear_.Forward(hidden_state, w.kv_a_proj);
    sync_check_cuda_error();

    // normalize only the latent part; the trailing rope columns are left as-is
    auto kv_a = kv_a_k_pe.slice({0, 0}, {-1, kv_lora_rank});
    invokeRMSNorm(kv_a, kv_a, w.kv_a_layernorm, model_param_.norm_eps, stream);
    sync_check_cuda_error();

    const int local_q_kv_head_num = local_head_num_ + 1 * local_kv_head_num_;

    Tensor qkv{{token_num, local_q_kv_head_num, size_per_head_}, dtype, hidden_state.device()};
    MLACopyQKV(dtype,
               qkv.raw_data(),
               q.raw_data(),
               kv_a_k_pe.raw_data(),
               token_num,
               local_head_num_,
               kv_lora_rank,
               qk_rope_dim,
               stream);
    sync_check_cuda_error();

    return qkv;
}

/// Apply per-head RMSNorm to the Q and K heads of the fused `qkv` tensor.
///
/// The kernels receive views into `qkv`:
///  - the first `local_head_num_` heads are normalized with `q_a_layernorm` on the current stream;
///  - the next `local_kv_head_num_` heads are normalized with `kv_a_layernorm` on `aux_stream_`.
/// Events fork the aux stream off the current one and join it back before returning. Remaining heads are not
/// touched.
void UnifiedAttentionLayer::qk_norm(Tensor& qkv, const WeightType& weights)
{
    const auto main_stream = core::Context::stream().handle();
    const auto eps         = model_param_.norm_eps;

    // fork: the aux stream must observe the QKV produced so far on the main stream
    check_cuda_error(cudaEventRecord(qkv_event_, main_stream));
    check_cuda_error(cudaStreamWaitEvent(aux_stream_, qkv_event_));

    TM_CHECK(model_param_.attn_bias == false) << "not implemented";

    // per-head view: [tokens, heads, head_dim]
    auto heads = qkv.view({qkv.shape(0), -1, (int)size_per_head_});

    auto q_part = heads.slice({0, 0, 0}, {-1, (int)local_head_num_, -1});
    invokeRMSNormQK(q_part, weights.q_a_layernorm, eps, main_stream);
    sync_check_cuda_error();

    auto k_part = heads.slice({0, (int)local_head_num_, 0}, {-1, (int)local_kv_head_num_, -1});
    invokeRMSNormQK(k_part, weights.kv_a_layernorm, eps, aux_stream_);
    sync_check_cuda_error();

    // join: the main stream resumes only after the K normalization has completed
    check_cuda_error(cudaEventRecord(aux_event_, aux_stream_));
    check_cuda_error(cudaStreamWaitEvent(main_stream, aux_event_));
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/unified_attention_layer.h
================================================
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h

#pragma once

#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/batch.h"
#include "src/turbomind/kernels/attention/cp_utils.h"
#include "src/turbomind/kernels/gemm/test/test_utils.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Per-phase attention state (batch statistics, offsets, RoPE data); defined in the implementation file.
struct AttentionData;

/// Softmax attention layer shared by all attention layers of the decoder.
///
/// Per layer, `Forward` runs the QKV projection (dense or MLA), the attention core and the output
/// projection. `Run` reacts to batch operations (add / setup / prepare) and keeps one `AttentionData` per
/// phase.
class UnifiedAttentionLayer {
public:
    using WeightType = LlamaAttentionWeight;

    // upper bound on the number of KV splits used by split-K decoding
    static constexpr int kMaxKVSplits        = 128;
    // token budget of the split-K workspace; the split count is capped at kMaxWorkspaceTokens / token_num
    static constexpr int kMaxWorkspaceTokens = 4096;

    struct ForwardParam {
        int               phase;     // selects the per-phase AttentionData
        Tensor            input;     // [token_num, hidden]
        Tensor            output;    // destination of the output projection
        const WeightType* weights;   // weights of this layer
        int               layer_id;  // decoder layer index (mapped to a KV-cache layer via cache_layer_ids_)
    };

    ~UnifiedAttentionLayer();

    // NOTE(review): the decoder passes `init = (bool)moe_ffn_layer_`; its exact effect is defined in the
    // constructor (presumably allocation of `tmp_attn_`) — confirm there.
    UnifiedAttentionLayer(const ModelParam&     model,
                          const AttentionParam& attn,
                          const EngineParam&    engine,
                          int                   tp_size,
                          const Context&        context,
                          int                   phases,
                          bool                  init);

    void Run(BatchOp op, int phase, TensorMap& env);

    void Forward(ForwardParam p);

private:
    void Setup(int phase, TensorMap& env);

    Tensor forward_mla(const Tensor& hidden_state, const WeightType& weights);

    /// TODO: dropping the `T` here requires deep refactor of attention dispatch
    template
    Tensor core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights);

    void qk_norm(Tensor& qkv, const WeightType& weights);

private:
    // NOTE: member declaration order defines initialization order; do not reorder.
    const int head_num_;
    const int kv_head_num_;
    const int size_per_head_;
    const int hidden_units_;
    const int local_head_num_;
    const int local_kv_head_num_;

    const AttentionParam param_;
    const EngineParam    engine_param_;
    const ModelParam     model_param_;
    const Context&       context_;

    // when non-zero, core_attention launches no kernels and fills its output with random values
    int& is_warm_up_;

    LlamaLinear& linear_;
    const int    arch_{};

    // secondary stream plus fork/join events used to overlap prefill with decode and Q-norm with K-norm
    cudaStream_t aux_stream_;
    cudaEvent_t  qkv_event_;
    cudaEvent_t  aux_event_;

    RNG rng_;

    RopeKernelParam rope_param_{};

    // one AttentionData per phase, indexed by `phase`
    std::vector> data_;

    // maps decoder layer id -> KV-cache layer id
    std::vector cache_layer_ids_;

    ///////////////////////////////////////////////////////
    /// temp runtime buffers
    // split-K decoding workspaces (partial outputs, partial max/sum, per-split counters)
    Tensor_ partial_O_;
    Tensor_ partial_ML_;
    Tensor_   split_cnt_;
    // optional shared attention output; when present it is sliced per call and cleared on kPrepare
    Tensor         tmp_attn_;

    // staging buffers filled per request in Setup and copied into the per-phase AttentionData
    Buffer_ rope_base_buf_;
    Buffer_   mrope_position_delta_buf_;
    Buffer_   mrope_length_buf_;

    CpPostContext cp_fn_ctx_;  // context parallel
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/unified_decoder.cc
================================================


#include 
#include 

#include 

#include "src/turbomind/core/allocator.h"
#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/norm/rms_norm.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/moe_ffn_layer.h"
#include "src/turbomind/models/llama/unified_attention_layer.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/cuda_utils.h"

#include "src/turbomind/engine/request.h"

// #include "dbg.h"

namespace turbomind {

/// Forward a batch operation to the attention layers owned by this decoder.
///
/// The softmax-attention layer always receives it. The linear-attention layer only exists for models with
/// linear-attention layers, and only then does it receive the operation, after the softmax layer.
void UnifiedDecoder::Run(BatchOp op, int phase, TensorMap& env)
{
    attn_layer_->Run(op, phase, env);

    if (auto* const linear_attn = linear_attn_layer_.get()) {
        linear_attn->Run(op, phase, env);
    }
}

/// Build the decoder and the sub-layers the model actually needs.
///
/// Each sub-layer is created only when required:
///  - MoE FFN: only when some layer has experts;
///  - dense FFN: only when some layer has a non-zero intermediate size;
///  - linear-attention layer: only when some entry of `model.layer_types` equals 1.
/// The softmax-attention layer is always created.
UnifiedDecoder::UnifiedDecoder(const ModelParam&     model,
                               const EngineParam&    engine,
                               const AttentionParam& attn,
                               const MoeParam&       moe,
                               const Context&        ctx,
                               int                   phases):
    layer_num_(model.layer_num),
    hidden_units_(model.hidden_units),
    attn_tp_size_(engine.attn_tp_size),
    attn_dp_size_(engine.attn_dp_size),
    attn_dp_rank_(engine.attn_dp_rank),
    mlp_tp_size_(engine.mlp_tp_size),
    attn_tp_group_(ctx.comm.d_tp_group),
    rmsnorm_eps_(model.norm_eps),
    d_comm_(ctx.comm.d_comm),
    tune_layer_num_(model.tune_layer_num),
    is_warm_up_{*ctx.is_warm_up}
{
    // any layer with a non-zero expert count -> MoE FFN layer
    if (std::accumulate(moe.expert_num.begin(), moe.expert_num.end(), 0LL)) {
        moe_ffn_layer_ = std::make_unique(model, moe, engine, ctx);
    }

    // Must come after the MoE layer: the last argument reports whether an MoE FFN exists.
    attn_layer_ =
        std::make_unique(model, attn, engine, attn_tp_size_, ctx, phases, (bool)moe_ffn_layer_);

    // layer type 1 -> linear attention (presumably the gated delta-net layer included by the header)
    if (std::find(model.layer_types.begin(), model.layer_types.end(), 1) != model.layer_types.end()) {
        linear_attn_layer_ = std::make_unique(model, attn, engine, attn_tp_size_, ctx, phases);
    }

    // any layer with a non-zero intermediate size -> dense FFN layer
    if (std::accumulate(model.inter_size.begin(), model.inter_size.end(), 0LL)) {
        ffn_layer_ = std::make_unique(model, ctx);
    }
}

/// Fused residual add + optional bias + RMSNorm, combined with an all-reduce when running distributed.
///
/// The implementation is chosen by the arguments:
///  - `group0` or `group1` non-zero: grouped all-reduce through the device communicator, using
///    `local_token_nums` for the per-rank token counts;
///  - otherwise, when a device communicator exists: plain all-reduce over `token_num` rows;
///  - otherwise: single-device fused kernel over `token_num` rows.
/// Empty `residual` / `bias` tensors are passed as null pointers.
void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor&       hidden_states,
                                              Tensor&       residual,
                                              const Tensor& bias,
                                              const Tensor& weight,
                                              int           token_num,
                                              int           group0,
                                              int           group1,
                                              const int*    local_token_nums)
{
    const auto stream = core::Context::stream().handle();
    const auto dtype  = hidden_states.dtype();

    const auto x_ptr     = hidden_states.raw_data();
    const auto res_ptr   = residual.data_or((void*)nullptr);
    const auto bias_ptr  = bias.data_or((void*)nullptr);
    const auto gamma_ptr = weight.raw_data();

    if (group0 || group1) {
        d_comm_->AllreduceResidualBiasRMSnormEx(x_ptr,
                                                res_ptr,
                                                bias_ptr,
                                                gamma_ptr,
                                                rmsnorm_eps_,
                                                hidden_units_,
                                                dtype,
                                                group0,
                                                group1,
                                                local_token_nums,
                                                stream);
    }
    else if (d_comm_) {
        d_comm_->AllreduceResidualBiasRMSnorm(
            x_ptr, res_ptr, bias_ptr, gamma_ptr, rmsnorm_eps_, hidden_units_, token_num, dtype, 0, stream);
    }
    else {
        // note: this kernel takes the norm weight before the bias
        invokeResidualBiasRMSNorm(
            x_ptr, res_ptr, gamma_ptr, bias_ptr, dtype, hidden_units_, token_num, rmsnorm_eps_, stream);
    }

    sync_check_cuda_error();
}

void UnifiedDecoder::Forward(int phase, TensorMap& args, const std::vector& weights)
{
    /**
     * Run the decoder stack for one batch phase.
     *
     * Read from `args`:
     *   \param input_embeds         [local_token_num, hidden_units]; consumed and used as the residual stream
     *   \param batch                batch descriptor; `local_token_num` lists token counts per attention-DP rank
     *   \param symm_buf             storage for the global hidden states (only read when a device comm exists)
     *   \param output_norm_weight   RMSNorm weight applied after the last layer
     *   \param selected_token_pos   consumed; positions of the tokens whose hidden states are returned
     *   \param output_hidden_states optional; when the key is present the full hidden states are produced too
     *
     * Produced into `args`:
     *   \param hidden_states        [selected_token_pos.size(), hidden_units]
     *   \param full_hidden_states   [local_token_num, hidden_units]; only when `output_hidden_states` is present
     */

    Tensor      local_residual   = args.try_consume("input_embeds");
    const auto& local_token_nums = args.at("batch").data()[0]->local_token_num;

    const auto local_token_num  = local_residual.shape(0);
    const auto global_token_num = std::accumulate(local_token_nums.begin(), local_token_nums.end(), ssize_t{});

    TM_CHECK_EQ(local_token_num, local_token_nums[attn_dp_rank_]);

    const DataType dtype = local_residual.dtype();

    // Hidden states for the tokens of all attention-DP ranks. With a device communicator they live in the
    // symmetric buffer so collectives can operate on them directly.
    Tensor global_hidden_states;
    if (d_comm_) {
        Buffer symm_buf      = args.at("symm_buf").buffer();
        global_hidden_states = {symm_buf.view(dtype), {global_token_num, (int)hidden_units_}};
    }
    else {
        global_hidden_states = {{global_token_num, (int)hidden_units_}, local_residual.dtype(), kDEVICE};
    }

    Tensor local_hidden_states;
    if (attn_dp_size_ > 1) {  // Offset hidden states buffer for mixed DP
        TM_CHECK_EQ(local_token_nums.size(), attn_dp_size_);
        // exclusive prefix sum of per-rank token counts -> row offset of this rank's tokens
        std::vector offsets(attn_dp_size_ + 1, 0);
        std::inclusive_scan(local_token_nums.data(), local_token_nums.data() + attn_dp_size_, offsets.begin() + 1);
        const int offset    = offsets[attn_dp_rank_];
        local_hidden_states = global_hidden_states.slice({offset, 0}, {local_token_num, -1});

        // dbg(attn_dp_size_, attn_dp_rank_, local_token_nums, local_token_num, global_token_num);
    }
    else {
        local_hidden_states = global_hidden_states;
    }

    TM_DEBUG_TENSOR(local_residual, "res", 1);
    TM_DEBUG_TENSOR(weights.at(0)->self_attn_norm, "norm_weight", 2);

    const auto stream = core::Context::stream().handle();

    invokeRMSNorm(local_hidden_states, local_residual, weights.at(0)->self_attn_norm, rmsnorm_eps_, stream);
    sync_check_cuda_error();

    TM_DEBUG_TENSOR(local_hidden_states, Concat("norm0", 0), 2);

    // auto stack_alloc{core::Context::device_alloc().adapt()};
    // core::ContextGuard ctx{Allocator{stack_alloc}};

    for (int layer = 0; layer < layer_num_; ++layer) {

        // stack_alloc->iter();

        if (global_token_num == 0) {
            break;
        }

        // during warm-up only the first `tune_layer_num_` layers are executed
        if (is_warm_up_ && layer >= tune_layer_num_) {
            continue;
        }

        /////////////////////////////////////////////
        /// self-attention or linear-attention
        if (weights.at(layer)->linear_attn_weights) {
            linear_attn_layer_->Forward(
                {phase, local_hidden_states, local_hidden_states, weights.at(layer)->linear_attn_weights.get(), layer});
        }
        else {
            attn_layer_->Forward(
                {phase, local_hidden_states, local_hidden_states, weights.at(layer)->self_attn_weights.get(), layer});
        }

        TM_DEBUG_TENSOR(local_hidden_states, Concat("attn_block", layer), 2);

        // The output-projection bias comes from whichever attention weights this layer uses
        // (`out_proj.bias` for linear attention, `output.bias` otherwise).
        Tensor out_bias;
        if (weights.at(layer)->linear_attn_weights) {
            out_bias = weights.at(layer)->linear_attn_weights->out_proj.bias;
        }
        else {
            out_bias = weights.at(layer)->self_attn_weights->output.bias;
        }

        AllreduceResidualRMSnorm(global_hidden_states,
                                 local_residual,
                                 out_bias,
                                 weights.at(layer)->ffn_norm,
                                 local_token_num,
                                 attn_tp_group_,
                                 0,
                                 local_token_nums.data());

        TM_DEBUG_TENSOR(local_residual, Concat("residual0", layer), 2);
        TM_DEBUG_TENSOR(local_hidden_states, Concat("norm1", layer), 2);

        ////////////////////////////////////////////
        /// feed-forward network

        std::optional moe_fwd_param;

        if (weights.at(layer)->moe_weights) {
            moe_fwd_param = MoeFfnLayer::ForwardParam{global_hidden_states,
                                                      global_hidden_states,
                                                      weights.at(layer)->moe_weights.get(),
                                                      ffn_layer_ ? 1.f : 0.f,
                                                      layer};
            moe_ffn_layer_->Forward(*moe_fwd_param);
        }

        if (weights.at(layer)->ffn_weights) {
            ffn_layer_->forward(
                {global_hidden_states, global_hidden_states, weights.at(layer)->ffn_weights.get(), (int)layer});
        }

        // the MoE result is combined only after the (optional) dense FFN has run
        if (moe_fwd_param) {
            moe_ffn_layer_->Combine(*moe_fwd_param);
        }

        TM_DEBUG_TENSOR(global_hidden_states, Concat("ffn_block", layer), 2);

        const bool last = layer == layer_num_ - 1;

        // pre-attention norm of the next layer, or the final output norm after the last layer
        auto& scale_weight = !last ? weights.at(layer + 1)->self_attn_norm : args.at("output_norm_weight");

        AllreduceResidualRMSnorm(global_hidden_states,
                                 local_residual,
                                 {},
                                 scale_weight,
                                 local_token_num,
                                 0,
                                 attn_tp_group_,
                                 local_token_nums.data());
        sync_check_cuda_error();

        TM_DEBUG_TENSOR(local_residual, Concat("residual1", layer), 2);
        TM_DEBUG_TENSOR(local_hidden_states, Concat("norm0", layer + 1), 2);

        // if (layer == layer_num_ - 1) {
        //     args.at("batch").data()[0]->Notify();
        // }
    }

    // Token indices selected for decoding
    const Buffer selected_pos = args.consume("selected_token_pos").buffer();
    // dbg(selected_pos);
    // When there are no prefill sequences, token selection is not needed
    const bool reuse_hidden_states = selected_pos.size() == local_token_num;

    const bool output_hidden_states = args.try_("output_hidden_states");

    Tensor hidden_states{local_hidden_states};

    if (d_comm_ && (output_hidden_states || reuse_hidden_states)) {
        // The full `hidden_states` buffer is needed for output but it's a ref into `symm_buf` atm.
        // Copy to residual buf so that `symm_buf` may be reused safely later
        Copy(hidden_states, local_residual);
        hidden_states = local_residual;
    }

    Tensor selected_states;
    if (reuse_hidden_states) {
        selected_states = hidden_states;
    }
    else {
        selected_states = {{selected_pos.size(), (int)hidden_units_}, dtype, kDEVICE};
        CollectHiddenStates(hidden_states, selected_pos, selected_states, stream);
    }
    args.produce("hidden_states", selected_states);

    // TM_DEBUG_TENSOR(selected_states.slice(0, selected_pos.size()), "out", 1);

    if (output_hidden_states) {
        args.produce("full_hidden_states", hidden_states);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/llama/unified_decoder.h
================================================
#pragma once

#include "src/turbomind/comm/device_comm.h"
#include "src/turbomind/models/llama/GatedDeltaNetLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/moe_ffn_layer.h"
#include "src/turbomind/models/llama/unified_attention_layer.h"

namespace turbomind {

// Decoder stack of the model. It owns the per-layer building blocks (softmax attention,
// linear attention, dense FFN and MoE FFN) and applies them over `layer_num_` layers.
// Inputs and outputs are exchanged through a TensorMap (`env`).
class UnifiedDecoder {
public:
    using WeightType = LlamaDecoderLayerWeight;

    // `phases` is presumably the number of batch phases kept in flight by the engine — TODO confirm.
    UnifiedDecoder(const ModelParam&     model,
                   const EngineParam&    engine,
                   const AttentionParam& attn,
                   const MoeParam&       moe,
                   const Context&        ctx,
                   int                   phases);

    // Batch-lifecycle hook (see `BatchOp`) for the given phase.
    void Run(BatchOp op, int phase, TensorMap& env);

    // Runs all decoder layers for one phase using the per-layer `weights`.
    void Forward(int phase, TensorMap& env, const std::vector& weights);

private:
    const size_t layer_num_;     // number of decoder layers
    const size_t hidden_units_;  // hidden size of the residual stream

    // Attention tensor/data-parallel layout and MLP tensor-parallel size of this rank.
    const int attn_tp_size_;
    const int attn_dp_size_;
    const int attn_dp_rank_;
    const int mlp_tp_size_;

    // Communication group id used for the attention tensor-parallel all-reduce.
    const int attn_tp_group_;

    const float rmsnorm_eps_;

    // Device communicator; may be null (callers test it before use).
    comm::DeviceCommImpl* const d_comm_;

    const int tune_layer_num_;

    // Reference to a warm-up flag owned elsewhere (presumably by the engine context) — TODO confirm.
    int& is_warm_up_;

    std::unique_ptr attn_layer_;
    std::unique_ptr    linear_attn_layer_;
    std::unique_ptr         ffn_layer_;
    std::unique_ptr           moe_ffn_layer_;

    // Presumably a fused all-reduce of `hidden_states` followed by residual add (+ optional bias)
    // and RMSNorm with `weight`, over tokens [t0, t1) — TODO confirm against the definition.
    // `local_token_nums` gives the token count of each rank.
    void AllreduceResidualRMSnorm(Tensor&       hidden_states,
                                  Tensor&       residual,
                                  const Tensor& bias,
                                  const Tensor& weight,
                                  int           token_num,
                                  int           t0,
                                  int           t1,
                                  const int*    local_token_nums);
};

}  // namespace turbomind


================================================
FILE: src/turbomind/models/output_processor.cc
================================================

#include "src/turbomind/models/output_processor.h"

#include 

#include "src/turbomind/engine/request.h"

// #include "dbg.h"

namespace turbomind {

using std::vector;
using std::shared_ptr;

// Private implementation of OutputProcessor. It copies the requested last-layer hidden states
// and logits of the current step into each request's output tensors ("last_hidden_state" /
// "logits" in `req->outputs`).
//
// Requested ranges are tracked as `Interval`s over sequence positions. For every phase, `Setup`
// matches what each request still wants against the tokens of the current step. It records
// tuples (request index, type, src rows, dst rows), where:
//   type 1 = rows of the per-request selected outputs (env "hidden_states" / "logits");
//   type 2 = rows of the hidden states of every token in the step (env "full_hidden_states").
//            Logits for these rows are computed on demand by `lm_head_` in chunks of at most
//            `max_logits_len_` tokens.
// Only tensor-parallel rank 0 writes into request outputs.
struct OutputProcessor::Impl {

    // Value of `output_logits` / `output_last_hidden_state` that requests all positions (see `Add`).
    static constexpr auto kAll = GenerationConfig::kAll;

    const int vocab_size_;      // number of logits per row copied into request outputs
    const int max_logits_len_;  // max number of rows passed to `lm_head_` at once
    const int tp_rank_;         // only rank 0 copies into request outputs

    // Maps a slice of hidden-state rows to the corresponding logits.
    std::function lm_head_;

    // Creates one empty `Data` slot per phase.
    Impl(const ModelParam&                    model,
         int                                  max_logits_len,
         int                                  tp_rank,
         int                                  phases,
         std::function lm_head):
        vocab_size_{(int)model.vocab_size},
        max_logits_len_{max_logits_len},
        tp_rank_{tp_rank},
        lm_head_{std::move(lm_head)}
    {
        for (int i = 0; i < phases; ++i) {
            data_.emplace_back();
        }
    }

    // Per-phase bookkeeping, reset and refilled by `Setup`.
    struct Data {
        Interval full_states;  // requested range for full hidden states
        Interval full_logits;  // requested range for full logits

        // Entries are (request index in batch, type, src rows, dst rows); see the struct comment.
        vector> output_states;
        vector> output_logits;
    };

    vector data_;

    // Matches the still-pending part of a request's range against a span of positions that is
    // available in some output buffer.
    //   target   - pending range of the request; eroded from the front on a successful match
    //   offset_d - position that maps to row 0 of the request's output tensor
    //   src/dst  - after a match: buffer rows to read / output rows to write
    // A match only counts if it starts exactly at `target.begin()`, so each request's output
    // is filled contiguously and in order.
    struct Matching {
        Interval& target;
        const int offset_d;
        Interval  src;
        Interval  dst;

        // `x` is the span of positions in the buffer, `offset_s` the buffer row holding
        // `x.begin()`. `merged` accumulates (via `|`) the buffer rows needed so far.
        bool operator()(const Interval& x, int offset_s, Interval& merged)
        {
            if (auto y = target & x; y && y.begin() == target.begin()) {
                dst    = {y.begin() - offset_d, y.size()};
                src    = {offset_s + (y.begin() - x.begin()), y.size()};
                merged = merged | src;
                target = -(int)y.size() | target;
                return true;
            }
            return false;
        }
    };

    // Called when requests are added. It records, per request, the positions whose logits or
    // hidden states must be returned, and remembers the first one as the position of output row 0.
    // kAll starts at `step0`; otherwise output starts at the last prompt token (`prompt_len - 1`).
    // NOTE(review): single-argument `Interval{x}` presumably denotes an open-ended range starting
    // at x — confirm against Interval's definition.
    void Add(int phase, TensorMap& env)
    {
        const Buffer_ rc = env.at("requests").buffer();

        for (int i = 0; i < rc.size(); ++i) {
            auto& c = *rc[i];
            auto& r = *c.req;
            auto& g = r.gen_cfg;
            if (g.output_logits) {
                c.output_logits = g.output_logits == kAll ? Interval{c.step0} : Interval{c.prompt_len - 1};
                c.logits_offset = c.output_logits.begin();
            }
            if (g.output_last_hidden_state) {
                c.output_hidden_states =
                    g.output_last_hidden_state == kAll ? Interval{c.step0} : Interval{c.prompt_len - 1};
                c.hidden_states_offset = c.output_hidden_states.begin();
                // dbg(&c.output_hidden_states, c.hidden_states_offset);
            }
        }
    }

    // Called for each step: decides which rows of this step's outputs go to which request,
    // and which range of full (all-token) rows is needed at all.
    void Setup(int phase, TensorMap& env)
    {
        auto& d = data_.at(phase);

        const auto& rc = env.at("batch").data()[0]->rc;

        // Per request: all positions processed in this step, and the single "selected" position
        // (the last input token; empty when the request is not generating).
        vector all_tokens;
        vector sel_tokens;
        for (int i = 0; i < rc.size(); ++i) {
            using Size = Interval::Size;
            auto& c    = *rc[i];
            all_tokens.emplace_back(c.history_len + c.alpha, Size{c.input_len});
            sel_tokens.emplace_back(c.history_len + c.alpha + c.input_len - 1, Size{1});
            if (!c.generating) {
                sel_tokens.back() = {};
            }
            // dbg(&all_tokens.back(), &sel_tokens.back());
        }

        // NOTE(review): `token_num` is read but not used below.
        const int token_num = *env.at("token_num").data();

        // {INT_MAX, 0} presumably encodes an empty interval that acts as the identity for `|`.
        d.full_logits = {INT_MAX, 0};
        d.full_states = {INT_MAX, 0};

        // Only accumulated by `Matching`; not read afterwards.
        Interval select_states{INT_MAX, 0};
        Interval select_logits{INT_MAX, 0};

        d.output_logits = {};
        d.output_states = {};

        // Row of the full (all-token) buffers where request i's tokens start.
        int offset = 0;

        for (int i = 0; i < rc.size(); ++i) {
            auto& c = *rc[i];
            auto& g = c.req->gen_cfg;  // NOTE(review): unused
            // Prefer the selected-token output (row i, type 1); otherwise fall back to the
            // full-token buffer (type 2).
            // NOTE(review): assumes row i of the selected outputs belongs to request i — confirm
            // against how "selected_token_pos" is built when some requests are not generating.
            if (c.output_hidden_states) {
                Matching m{c.output_hidden_states, c.hidden_states_offset};
                int      type = 0;
                if (m(sel_tokens[i], i, select_states)) {
                    type = 1;
                }
                else if (m(all_tokens[i], offset, d.full_states)) {
                    type = 2;
                }
                if (type) {
                    d.output_states.emplace_back(i, type, m.src, m.dst);
                    // dbg(type, &m.src, &m.dst);
                }
            }
            if (c.output_logits) {
                Matching m{c.output_logits, c.logits_offset};
                int      type = 0;
                if (m(sel_tokens[i], i, select_logits)) {
                    type = 1;
                }
                else if (m(all_tokens[i], offset, d.full_logits)) {
                    type = 2;
                }
                if (type) {
                    d.output_logits.emplace_back(i, type, m.src, m.dst);
                }
            }
            offset += c.input_len;
        }

        // logits depends on hidden states
        d.full_states = d.full_states | d.full_logits;
    }

    // When any full-token rows are needed this step, publish "output_hidden_states" as a
    // presence flag. It is presumably read by the decoder to keep the full hidden states.
    void Prepare(int phase, TensorMap& env)
    {
        auto& d = data_.at(phase);
        if (d.full_states) {
            env.produce("output_hidden_states", Tensor{});
        }
    }

    // Copies rows `src` of `h` into rows `dst` of each request's "last_hidden_state", for the
    // entries of the given `type` only. Only rank 0 copies.
    template
    void OutputHiddenStates(const Ranges& ranges, const Tensor& h, int type, const vector>& rs)
    {
        for (const auto& [i, t, src, dst] : ranges) {
            if (t == type) {
                auto& out = rs[i]->req->outputs.at("last_hidden_state");
                if (tp_rank_ == 0) {
                    // dbg(&src, &dst);
                    Copy(h.slice(src.begin(), (int)src.size()), out.slice(dst.begin(), (int)dst.size()));
                }
            }
        }
    }

    // Type-2 logits: runs `lm_head_` over `data.full_logits` in chunks of at most
    // `max_logits_len_` rows and scatters each chunk into request outputs. It stops as soon as
    // every type-2 entry is satisfied. The check fails if any entry is still pending at the end,
    // including when `data.full_logits` is empty, so callers must guard against that.
    void ComputeAndOutputLogits(const Data& data, const Tensor& h, const vector>& rs)
    {
        const int step_size = max_logits_len_;

        // Coroutine frame: `p` and the (copied) `ranges` persist across chunks so that
        // `OutputLogitsImpl` resumes at the entry it suspended on.
        int  p      = 0;
        auto ranges = data.output_logits;

        using Size = Interval::Size;

        bool success = false;
        // Erode the range iteratively until empty
        for (auto r = data.full_logits; r; r = -step_size | r) {
            // dbg(&r);
            if (auto chunk = r & Interval{r.begin(), Size{step_size}}) {
                // dbg(&chunk);
                // Compute & output full logits by chunks
                auto logits = lm_head_(h.slice(chunk.begin(), (int)chunk.size()));
                success     = OutputLogitsImpl(ranges, p, logits, chunk.begin(), 2, rs);
                if (success) {  // all requests satisfied, exit early
                    break;
                }
            }
        }

        TM_CHECK(success);  // all requests must be satisfied at the end
    }

    // Type-1 logits: `l` already holds every needed row (base 0), so all entries must be
    // satisfied in a single pass. Works on a copy, so the caller's ranges are not eroded.
    template
    void OutputLogits(Ranges& ranges_, const Tensor& l, int type, const vector>& rs)
    {
        // Coroutine frame
        int  p      = 0;
        auto ranges = ranges_;

        TM_CHECK(OutputLogitsImpl(ranges, p, l, /* base */ 0, type, rs));
    }

    // Resumable copy of logits rows into request outputs.
    //   l      - logits for buffer rows [base, base + l.shape(0))
    //   p      - index into `ranges` to resume from; advanced as entries complete
    //   ranges - entries whose `src`/`dst` are eroded as rows are copied
    // Returns false when the current entry still needs rows beyond this chunk. The caller then
    // supplies the next chunk and calls again with the same `p`. Returns true once every entry
    // of `type` is complete.
    template
    bool OutputLogitsImpl(
        Ranges& ranges, int& p, const Tensor& l, int base, int type, const vector>& rs)
    {
        // dbg("OutputLogitsImpl");
        const auto stream = core::Context::stream().handle();
        for (; p < ranges.size(); ++p) {
            if (auto& [i, t, src, dst] = ranges[p]; t == type) {
                Tensor&        out   = rs[i]->req->outputs.at("logits");
                const DataType dtype = out.dtype();
                TM_CHECK_LE(base, src.begin());  // logical error
                if (Interval msrc = src & Interval{base, Interval::Size{(int)l.shape(0)}}) {
                    const int tokens = (int)msrc.size();
                    Interval  mdst{dst.begin(), msrc.size()};
                    // Each row holds `vocab_size_` elements. A 2-D copy is used because the row
                    // pitches (`stride(0)`) of source and destination may differ.
                    // TODO: support strides in `DLTensor`, so that batched 1D copy can be used
                    if (tp_rank_ == 0) {
                        // dbg(&mdst, &msrc, tokens, out, base, l);
                        TM_CHECK_EQ(cudaMemcpy2DAsync(out.slice(mdst.begin(), tokens).raw_data(),
                                                      byte_size(dtype, out.stride(0)),
                                                      l.slice(msrc.begin() - base, tokens).raw_data(),
                                                      byte_size(dtype, l.stride(0)),
                                                      byte_size(dtype, vocab_size_),
                                                      tokens,
                                                      cudaMemcpyDefault,
                                                      stream),
                                    0);
                    }
                    // move to next request if they are empty after the erosion
                    src = -(int)msrc.size() | src;
                    dst = -(int)mdst.size() | dst;
                }
                // dbg(&src, (int)src.size(), &dst, (int)dst.size());
                if (src) {
                    // request not completed, suspend and wait for next chunk
                    return false;
                }
            }
        }
        return true;
    }

    // Entry point after the forward pass.
    //   type 2 - full-token outputs; consumes "full_hidden_states" (only when rows are needed)
    //   type 1 - selected-token outputs from env "hidden_states" / "logits"
    void OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type)
    {
        auto& d = data_.at(phase);
        auto& b = *env.at("batch").data()[0];

        if (type == 2 && d.full_states) {
            auto hidden_states = env.consume("full_hidden_states");
            if (!d.output_states.empty()) {
                OutputHiddenStates(d.output_states, hidden_states, 2, b.rc);
            }
            if (!d.output_logits.empty() && d.full_logits) {
                ComputeAndOutputLogits(d, hidden_states, b.rc);
            }
        }

        if (type == 1) {
            if (!d.output_states.empty()) {
                OutputHiddenStates(d.output_states, env.at("hidden_states"), 1, b.rc);
            }
            if (!d.output_logits.empty()) {
                OutputLogits(d.output_logits, env.at("logits"), 1, b.rc);
            }
        }
    }
};

// Presumably defined out of line (pimpl idiom) so that `Impl` is a complete type where
// `impl_` is destroyed.
OutputProcessor::~OutputProcessor() = default;

// Builds the implementation; all arguments are forwarded and `lm_head` is moved into it.
OutputProcessor::OutputProcessor(
    const ModelParam& model, int max_logits_len, int tp_rank, int phases, std::function lm_head):
    impl_{std::make_unique(model, max_logits_len, tp_rank, phases, std::move(lm_head))}
{
}

void OutputProcessor::Run(BatchOp op, int phase, TensorMap& env)
{
    // Dispatch the batch-lifecycle ops this processor cares about; every other op is a no-op.
    if (op == BatchOp::kAdd) {
        impl_->Add(phase, env);
    }
    else if (op == BatchOp::kSetup) {
        impl_->Setup(phase, env);
    }
    else if (op == BatchOp::kPrepare) {
        impl_->Prepare(phase, env);
    }
}

void OutputProcessor::OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type)
{
    // Forward to the implementation. `type` 1 handles selected-token outputs and `type` 2 the
    // outputs taken from the full hidden states of every token.
    impl_->OutputHiddenStatesAndLogits(phase, env, type);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/models/output_processor.h
================================================
#pragma once

#include "src/turbomind/engine/batch.h"
#include "src/turbomind/models/llama/llama_params.h"

namespace turbomind {

// Copies requested last-layer hidden states and logits into per-request output tensors.
// Requests opt in through their generation config (`output_last_hidden_state`,
// `output_logits`).
class OutputProcessor {
public:
    ~OutputProcessor();

    // `max_logits_len` caps the rows passed to `lm_head` at once. `tp_rank` selects the
    // rank that writes outputs (rank 0). `phases` is the number of independent per-phase
    // states. `lm_head` maps hidden-state rows to logits.
    OutputProcessor(const ModelParam&                    model,  //
                    int                                  max_logits_len,
                    int                                  tp_rank,
                    int                                  phases,
                    std::function lm_head);

    // Batch-lifecycle hook; only kAdd, kSetup and kPrepare have an effect.
    void Run(BatchOp op, int phase, TensorMap& env);

    // Writes outputs after the forward pass: type 1 = selected-token outputs,
    // type 2 = outputs taken from the full hidden states of every token.
    void OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type);

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/python/CMakeLists.txt
================================================
# Copyright (c) OpenMMLab. All rights reserved.

cmake_minimum_required(VERSION 3.11)
project(_turbomind LANGUAGES CXX CUDA)

# Locate pybind11. If its CMake package config is not on the default search path, ask the
# `pybind11-config` executable for its CMake directory and search again.
find_package(pybind11 CONFIG)
if(NOT pybind11_FOUND)
    execute_process(COMMAND "pybind11-config" "--cmakedir"
                    RESULT_VARIABLE _COMMAND_SUCCESS
                    OUTPUT_VARIABLE pybind11_DIR
                    OUTPUT_STRIP_TRAILING_WHITESPACE)
    # REQUIRED: if pybind11 still cannot be found, stop here with a clear diagnostic instead
    # of failing later on the unknown `pybind11_add_module` command.
    find_package(pybind11 CONFIG REQUIRED)
endif()

# Main extension module: the TurboMind Python bindings (bind.cpp).
pybind11_add_module(${PROJECT_NAME} bind.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE turbomind xgrammar)

# Separate extension module, presumably exposing the xgrammar bindings (xgrammar_bind.cpp).
pybind11_add_module(_xgrammar xgrammar_bind.cpp)
target_link_libraries(_xgrammar PRIVATE core xgrammar)
target_compile_features(_xgrammar PRIVATE cxx_std_14)

# When configured from setup.py, give the installed main module an RPATH that covers its own
# directory plus NCCL/CUDA library directories under `$ORIGIN/../../nvidia/` (presumably the
# NVIDIA pip packages next to this package in site-packages — confirm).
if (CALL_FROM_SETUP_PY)
  # Major CUDA version, e.g. "12" from "12.8.93".
  string(REPLACE "." ";" _ver ${CMAKE_CUDA_COMPILER_VERSION})
  list(GET _ver 0 CUDA_MAJOR)

  if(CUDA_MAJOR GREATER_EQUAL "13")
    # CUDA >= 13: runtime libraries are looked up in a single `nvidia/cu<major>/lib` directory.
    set(_INSTALL_CUDA_RPATH
        "\$ORIGIN"
        "\$ORIGIN/../../nvidia/nccl/lib/"
        "\$ORIGIN/../../nvidia/cu${CUDA_MAJOR}/lib/"
    )
  else()
    # Older CUDA: one library directory per component (runtime, cuBLAS, cuRAND).
    set(_INSTALL_CUDA_RPATH
        "\$ORIGIN"
        "\$ORIGIN/../../nvidia/nccl/lib/"
        "\$ORIGIN/../../nvidia/cuda_runtime/lib/"
        "\$ORIGIN/../../nvidia/cublas/lib/"
        "\$ORIGIN/../../nvidia/curand/lib/"
    )
  endif()
  set_target_properties(${PROJECT_NAME} PROPERTIES
      BUILD_RPATH "\$ORIGIN"
      INSTALL_RPATH "${_INSTALL_CUDA_RPATH}"
  )
endif ()


================================================
FILE: src/turbomind/python/bind.cpp
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 

#include 

#include 
#include 
#include 
#include 
#include 

#include "xgrammar/compiler.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/core/tensor.h"
#include "src/turbomind/engine/model_request.h"
#include "src/turbomind/python/dlpack.h"
#include "src/turbomind/turbomind.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/metrics.h"

namespace py = pybind11;
namespace ft = turbomind;
// Enables the `"name"_a` keyword-argument literals used by the bindings below.
using namespace pybind11::literals;

using ft::core::Tensor;

// prepare to bind container:
// declaring TensorMap opaque makes pybind11 pass it by reference (bound with py::bind_map
// in the module) instead of converting it to/from a Python dict by value.
using TensorMap = ft::core::TensorMap;
PYBIND11_MAKE_OPAQUE(TensorMap);
// Capsule name used by the DLPack Python protocol for an unconsumed DLManagedTensor.
static const char kDlTensorCapsuleName[] = "dltensor";

DLDevice getDLDevice(const Tensor& tensor)
{
    // Reports the DLPack device that a TurboMind tensor lives on. For device tensors the CUDA
    // ordinal comes from the pointer attributes; host and pinned tensors report device id 0.
    int device_id = 0;
    if (tensor.device().type == ft::kDEVICE) {
        cudaPointerAttributes ptr_attr{};
        if (cudaPointerGetAttributes(&ptr_attr, tensor.raw_data()) == cudaSuccess) {
            device_id = ptr_attr.device;
        }
        else {
            // The query can fail, presumably e.g. for the null data pointer of an empty tensor.
            // The failure is recorded as the thread's last CUDA runtime error, which a later,
            // unrelated cudaGetLastError-based check would report. Clear it and keep id 0.
            (void)cudaGetLastError();
        }
    }

    DLDevice device{kDLCPU, device_id};

    switch (tensor.device().type) {
        case ft::kCPU:
            device.device_type = DLDeviceType::kDLCPU;
            break;
        case ft::kCPUpinned:
            device.device_type = DLDeviceType::kDLCUDAHost;
            break;
        case ft::kDEVICE:
            device.device_type = DLDeviceType::kDLCUDA;
            break;
        default:
            break;
    }

    return device;
}

// Exports a TurboMind tensor as a heap-allocated DLManagedTensor without copying the data.
// A copy of the Tensor handle is kept in `manager_ctx`, so the underlying buffer stays alive
// until the consumer invokes the deleter, which frees both the handle and the DLManagedTensor.
DLManagedTensor* TritonTensorToDLManagedTensor(Tensor& tensor)
{
    DLDevice   device = getDLDevice(tensor);
    // Placeholder: bits = 0 is not a valid dtype. It remains for dtypes the switch below
    // does not map.
    DLDataType data_type{0, 0, 1};
    using ft::data_type_v;
    switch (tensor.dtype()) {
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLBool;
            data_type.bits = 8;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLUInt;
            data_type.bits = 8;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLUInt;
            data_type.bits = 16;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLUInt;
            data_type.bits = 32;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLUInt;
            data_type.bits = 64;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLInt;
            data_type.bits = 8;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLInt;
            data_type.bits = 16;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLInt;
            data_type.bits = 32;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLInt;
            data_type.bits = 64;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLFloat;
            data_type.bits = 16;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLFloat;
            data_type.bits = 32;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLFloat;
            data_type.bits = 64;
            break;
        case data_type_v:
            data_type.code = DLDataTypeCode::kDLBfloat;
            data_type.bits = 16;
            break;
        default:
            // NOTE(review): unmapped dtypes are exported with the invalid placeholder above
            // instead of raising — consider failing loudly here.
            break;
    }

    // The shape storage is aliased directly as `DLTensor::shape`, so its elements must be 64-bit.
    static_assert(sizeof(int64_t) == sizeof(tensor.shape(0)));

    Tensor*  ctx = new Tensor(tensor);
    // `shape` points into `ctx`, which lives as long as the DLManagedTensor. Null `strides`
    // declares a compact row-major layout under DLPack.
    // NOTE(review): a strided/non-contiguous Tensor view would be misdescribed — confirm that
    // only contiguous tensors are exported.
    DLTensor dl_tensor{const_cast(ctx->raw_data()),
                       device,
                       (int32_t)(ctx->ndim()),
                       data_type,
                       (int64_t*)ctx->shape().data(),
                       (int64_t*)(nullptr),
                       0};
    return new DLManagedTensor{dl_tensor, ctx, [](DLManagedTensor* dlmt) {  //
                                   delete (Tensor*)dlmt->manager_ctx;
                                   delete dlmt;
                               }};
}

ft::DeviceType getMemoryType(DLDevice device)
{
    // Translate a DLPack device kind into TurboMind's memory type. Anything that is neither
    // CUDA device memory nor pinned host memory (including kDLCPU) is pageable host memory.
    const auto kind = device.device_type;
    if (kind == DLDeviceType::kDLCUDA) {
        return ft::DeviceType::kDEVICE;
    }
    if (kind == DLDeviceType::kDLCUDAHost) {
        return ft::DeviceType::kCPUpinned;
    }
    return ft::DeviceType::kCPU;
}

// Maps a DLPack dtype (code + bit width) to a TurboMind DataType. `lanes` is not inspected.
// Unsupported widths or codes return the fallback `data_type_v` of each `default:` branch
// (presumably the null/invalid type — TODO confirm). The `break`s after the inner switches
// are unreachable because every inner branch returns.
ft::DataType getDataType(DLDataType data_type)
{
    using ft::data_type_v;
    switch (data_type.code) {
        case DLDataTypeCode::kDLUInt:
            switch (data_type.bits) {
                case 8:
                    return data_type_v;
                case 16:
                    return data_type_v;
                case 32:
                    return data_type_v;
                case 64:
                    return data_type_v;
                default:
                    return data_type_v;
            }
            break;
        case DLDataTypeCode::kDLInt:
            switch (data_type.bits) {
                case 8:
                    return data_type_v;
                case 16:
                    return data_type_v;
                case 32:
                    return data_type_v;
                case 64:
                    return data_type_v;
                default:
                    return data_type_v;
            }
            break;
        case DLDataTypeCode::kDLFloat:
            switch (data_type.bits) {
                case 16:
                    return data_type_v;
                case 32:
                    return data_type_v;
                case 64:
                    return data_type_v;
                default:
                    return data_type_v;
            }
            break;
        case DLDataTypeCode::kDLBfloat:
            switch (data_type.bits) {
                case 16:
                    return data_type_v;
                default:
                    return data_type_v;
            }
            break;
        case DLDataTypeCode::kDLBool:
            // Any bit width is accepted for booleans.
            return data_type_v;
        default:
            return data_type_v;
    }
}

// Wraps an imported DLManagedTensor as a TurboMind Tensor without copying. The data pointer
// is held by a shared_ptr whose deleter calls the producer's DLPack deleter once the last
// reference is dropped. The caller must already have taken ownership of `tensor`.
// NOTE(review): `strides` and `byte_offset` are ignored, so the data is assumed to be compact
// row-major starting exactly at `data` — confirm producers never hand over strided views.
std::shared_ptr DLManagedTensorToTritonTensor(DLManagedTensor* tensor)
{
    auto& dl_tensor = tensor->dl_tensor;
    auto  where     = getMemoryType(dl_tensor.device);
    auto  dtype     = getDataType(dl_tensor.dtype);
    // NOTE(review): 0-d (scalar) inputs are only rejected in debug builds; with NDEBUG they
    // yield an empty shape.
    assert(dl_tensor.ndim > 0);
    std::vector shape(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim);

    // The deleter ignores `p` and releases the whole DLManagedTensor instead.
    std::shared_ptr ptr{dl_tensor.data, [tensor](void* p) {
                                  if (tensor->deleter) {
                                      tensor->deleter(tensor);
                                  }
                              }};
    return std::make_shared(ptr, std::move(shape), dtype, where);
}

static void safe_memcpy(void* dst, const void* src, size_t size)
{
    // Copies `size` bytes between two buffers of unknown location (host, pinned, managed or
    // device memory on any GPU). The pointer attributes pick the strategy:
    //   - both addressable from the current context        -> plain cudaMemcpy
    //   - both device memory, on different GPUs            -> cudaMemcpyPeer
    //   - both device memory on the same GPU, but not
    //     addressable from the current context             -> cudaMemcpy with that device made current
    //   - anything else                                    -> plain cudaMemcpy as a last resort
    // If the copy throws, the attributes of both buffers are logged before rethrowing.
    // Failures of the attribute queries themselves propagate without that log.
    cudaPointerAttributes dat{};
    cudaPointerAttributes sat{};
    ft::check_cuda_error(cudaPointerGetAttributes(&dat, dst));
    ft::check_cuda_error(cudaPointerGetAttributes(&sat, src));

    const bool reachable   = dat.devicePointer && sat.devicePointer;
    const bool both_device = dat.type == cudaMemoryTypeDevice && sat.type == cudaMemoryTypeDevice;
    const bool use_peer    = !reachable && both_device && dat.device != sat.device;
    const bool use_guard   = !reachable && both_device && dat.device == sat.device;

    try {
        if (use_peer) {
            ft::check_cuda_error(cudaMemcpyPeer(dst, dat.device, src, sat.device, size));
        }
        else if (use_guard) {
            ft::CudaDeviceGuard guard(dat.device);
            ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault));
        }
        else {
            ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault));
        }
    }
    catch (...) {
        int device_id{-1};
        cudaGetDevice(&device_id);
        TM_LOG_ERROR("cudaMemcpy failed: dst=(%d, %d, %p, %p), src=(%d, %d, %p, %p), size=%s, device=%d",
                     (int)dat.type,
                     dat.device,
                     dat.devicePointer,
                     dat.hostPointer,
                     (int)sat.type,
                     sat.device,
                     sat.devicePointer,
                     sat.hostPointer,
                     std::to_string(size).c_str(),
                     device_id);
        throw;
    }
}

namespace {

struct ScopedGIL {
    ScopedGIL(const ScopedGIL&) = delete;
    ScopedGIL& operator=(const ScopedGIL&) = delete;
    ScopedGIL(ScopedGIL&&)                 = delete;
    ScopedGIL& operator=(ScopedGIL&&) = delete;
    ScopedGIL()
    {
        state = PyGILState_Ensure();
    }
    ~ScopedGIL()
    {
        PyGILState_Release(state);
    }
    PyGILState_STATE state;
};

}  // namespace

// Python extension module `_turbomind`. It exposes request/schedule metrics, session and
// generation configs, request state, the DataType/MemoryType enums, a DLPack-interoperable
// Tensor and TensorMap, ModelRequest, and the TurboMind engine object.
PYBIND11_MODULE(_turbomind, m)
{
    // Read-only per-request timestamps. Relaxed atomic loads return a snapshot; the values may
    // presumably be updated concurrently by the engine.
    py::class_>(m, "RequestMetrics")
        .def(py::init())
        .def_property_readonly("enqueue_time",
                               [](ft::RequestMetrics& m) { return m.enqueue_time.load(std::memory_order_relaxed); })
        .def_property_readonly("scheduled_time",
                               [](ft::RequestMetrics& m) { return m.scheduled_time.load(std::memory_order_relaxed); });

    // Scheduler statistics (sequence and KV-block counts), exposed read-only.
    py::class_>(m, "ScheduleMetrics")
        .def(py::init())
        .def_readonly("total_seqs", &ft::ScheduleMetrics::total_seqs)
        .def_readonly("active_seqs", &ft::ScheduleMetrics::active_seqs)
        .def_readonly("waiting_seqs", &ft::ScheduleMetrics::waiting_seqs)
        .def_readonly("total_blocks", &ft::ScheduleMetrics::total_blocks)
        .def_readonly("active_blocks", &ft::ScheduleMetrics::active_blocks)
        .def_readonly("cached_blocks", &ft::ScheduleMetrics::cached_blocks)
        .def_readonly("free_blocks", &ft::ScheduleMetrics::free_blocks);

    // Session parameters. The constructor rejects start=false together with end=true;
    // the attribute setters do not re-validate.
    py::class_(m, "SessionParam")
        .def(py::init([](uint64_t id, int step, bool start, bool end) {
                 if (!start && end) {
                     throw std::logic_error("unsupported arguments: start=false, end=true");
                 }
                 ft::SessionParam param{};
                 param.id         = id;
                 param.step       = step;
                 param.start_flag = start;
                 param.end_flag   = end;
                 return param;
             }),
             "id"_a,
             "step"_a,
             "start"_a,
             "end"_a)
        .def_readwrite("id", &ft::SessionParam::id)
        .def_readwrite("step", &ft::SessionParam::step)
        .def_readwrite("start", &ft::SessionParam::start_flag)
        .def_readwrite("end", &ft::SessionParam::end_flag);

    // Sampling / output options. __repr__ reuses the C++ stream operator.
    py::class_(m, "GenerationConfig")
        .def(py::init())
        .def_readwrite("max_new_tokens", &ft::GenerationConfig::max_new_tokens)
        .def_readwrite("min_new_tokens", &ft::GenerationConfig::min_new_tokens)
        .def_readwrite("eos_ids", &ft::GenerationConfig::eos_ids)
        .def_readwrite("stop_ids", &ft::GenerationConfig::stop_ids)
        .def_readwrite("bad_ids", &ft::GenerationConfig::bad_ids)
        .def_readwrite("top_p", &ft::GenerationConfig::top_p)
        .def_readwrite("top_k", &ft::GenerationConfig::top_k)
        .def_readwrite("min_p", &ft::GenerationConfig::min_p)
        .def_readwrite("temperature", &ft::GenerationConfig::temperature)
        .def_readwrite("repetition_penalty", &ft::GenerationConfig::repetition_penalty)
        .def_readwrite("random_seed", &ft::GenerationConfig::random_seed)
        .def_readwrite("output_logprobs", &ft::GenerationConfig::output_logprobs)
        .def_readwrite("output_last_hidden_state", &ft::GenerationConfig::output_last_hidden_state)
        .def_readwrite("output_logits", &ft::GenerationConfig::output_logits)
        .def("__repr__", [](const ft::GenerationConfig& c) {
            std::ostringstream oss;
            oss << c;
            return oss.str();
        });

    py::class_>(m, "RequestState")
        .def_readonly("status", &ft::RequestState::status)
        .def_readonly("seq_len", &ft::RequestState::seq_len);

    // `consume` atomically takes the pending state and leaves nullptr behind. A null result
    // presumably maps to None in Python.
    py::class_>(m, "AtomicRequestState")
        .def("consume", [](ft::AtomicRequestState& s) { return s.exchange(nullptr); });

    // data type
    {
        using namespace turbomind;
        py::enum_(m, "DataType")
            .value("TYPE_INVALID", kNull)
            .value("TYPE_BOOL", kBool)
            .value("TYPE_UINT8", kUint8)
            .value("TYPE_UINT16", kUint16)
            .value("TYPE_UINT32", kUint32)
            .value("TYPE_UINT64", kUint64)
            .value("TYPE_INT8", kInt8)
            .value("TYPE_INT16", kInt16)
            .value("TYPE_INT32", kInt32)
            .value("TYPE_INT64", kInt64)
            .value("TYPE_FP16", kFloat16)
            .value("TYPE_FP32", kFloat32)
            .value("TYPE_FP64", kFloat64)
            .value("TYPE_BF16", kBfloat16);

        // memory type
        py::enum_(m, "MemoryType")
            .value("MEMORY_CPU", ft::DeviceType::kCPU)
            .value("MEMORY_CPU_PINNED", ft::DeviceType::kCPUpinned)
            .value("MEMORY_GPU", ft::DeviceType::kDEVICE);
    }

    // tensor
    py::class_>(m, "Tensor")
        .def_property_readonly("where", [](const Tensor& t) { return t.device().type; })
        .def_property_readonly("type", [](const Tensor& t) { return t.dtype(); })
        .def_property_readonly("shape", [](const Tensor& t) { return t.shape(); })
        .def_property_readonly("data", [](const Tensor& t) { return t.raw_data(); })
        // Copies the bytes of any object implementing `__dlpack__` into this tensor. Only the
        // byte sizes must match (dtype and shape are not compared).
        // NOTE(review): PyCapsule_GetPointer returns nullptr (with a Python error set) when the
        // capsule is not named "dltensor", and that pointer is used without a check.
        .def(
            "copy_from",
            [](Tensor& self, py::object obj) {
                py::capsule      cap = obj.attr("__dlpack__")();
                DLManagedTensor* dlmt =
                    static_cast(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
                auto src = DLManagedTensorToTritonTensor(dlmt);
                // take ownership of capsule's payload
                cap.set_name("used_dltensor");

                TM_CHECK_EQ(self.byte_size(), src->byte_size()) << self << " " << *src;
                safe_memcpy(self.raw_data(), src->raw_data(), self.byte_size());
            },
            "tensor"_a)
        // Zero-copy export. `stream` is accepted for protocol compatibility but ignored: no
        // synchronization with the consumer's stream is performed.
        .def(
            "__dlpack__",
            [](Tensor& self, long stream) {
                DLManagedTensor* dlmt = TritonTensorToDLManagedTensor(self);
                // Capsule destructor: frees the payload only if no consumer renamed the capsule
                // (i.e. took ownership).
                return py::capsule(dlmt, kDlTensorCapsuleName, [](PyObject* obj) {
                    DLManagedTensor* dlmt =
                        static_cast(PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                    if (dlmt) {
                        dlmt->deleter(dlmt);
                    }
                    else {
                        // The tensor has been deleted. Clear any error from
                        // PyCapsule_GetPointer.
                        PyErr_Clear();
                    }
                });
            },
            "stream"_a = 0)
        // Returns (DLPack device type, device id).
        .def("__dlpack_device__", [](const Tensor& self) {
            auto device = getDLDevice(self);
            return std::tuple(int(device.device_type), device.device_id);
        });
    // Zero-copy import: the returned Tensor shares the producer's memory and releases it through
    // the DLPack deleter once the last reference goes away.
    // NOTE(review): same unchecked PyCapsule_GetPointer as in `copy_from`.
    m.def(
        "from_dlpack",
        [](py::object obj) {
            py::capsule      cap = obj.attr("__dlpack__")();
            DLManagedTensor* dlmt =
                static_cast(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
            auto ret = DLManagedTensorToTritonTensor(dlmt);
            // take ownership of capsule's payload
            cap.set_name("used_dltensor");
            return ret;
        },
        "dl_managed_tensor"_a);

    py::bind_map>(m, "TensorMap");

    // Per-request handle. The call guards presumably release the GIL
    // (py::gil_scoped_release) while calling into the engine — TODO confirm.
    using ft::ModelRequest;
    py::class_(m, "ModelRequest")
        // Submits a request and returns (output tensors, request state, metrics). A
        // py::error_already_set thrown by `cb` is printed to stderr and swallowed, so it never
        // propagates into the engine.
        .def(
            "forward",
            [](ModelRequest*               model_request,
               std::shared_ptr  input_tensors,
               const ft::SessionParam&     session,
               const ft::GenerationConfig& gen_cfg,
               bool                        stream_output,
               bool                        enable_metrics,
               std::function       cb) {
                ModelRequest::InputParam param{};
                param.tensors        = std::move(input_tensors);
                param.session        = session;
                param.gen_cfg        = gen_cfg;
                param.stream_output  = stream_output;
                param.enable_metrics = enable_metrics;

                auto ret = model_request->Forward(std::move(param), [cb = std::move(cb)]() {
                    try {
                        cb();
                    }
                    catch (const py::error_already_set& e) {
                        std::cerr << e.what() << std::endl;
                    }
                });
                return std::make_tuple(std::move(ret.tensors), std::move(ret.state), std::move(ret.metrics));
            },
            py::call_guard(),
            "input_tensors"_a,
            "session"_a,
            "gen_cfg"_a,
            "stream_output"_a,
            "enable_metrics"_a,
            "cb"_a)
        .def(
            "cancel",
            [](ModelRequest* model_request) {
                model_request->Cancel();  //
            },
            py::call_guard())
        .def(
            "end",
            [](ModelRequest* model_request, std::function cb, uint64_t session_id) {
                model_request->End(std::move(cb), session_id);  //
            },
            py::call_guard(),
            "cb"_a,
            "session_id"_a)
        .def(
            "set_grammar",
            [](ModelRequest* model_request, const xgrammar::CompiledGrammar& grammar) {
                TM_LOG_INFO("Set grammar for model_request");
                model_request->setGrammar(grammar);
            },
            py::call_guard(),
            "grammar"_a);

    // transformer model
    using ft::TurboMind;
    py::class_>(m, "TurboMind")
        // Factory. `gil_factory` gives the engine a type-erased GIL guard object; `no_gil_deleter`
        // releases the GIL while the engine is destroyed (presumably so that teardown cannot
        // deadlock on threads that need the GIL).
        // NOTE(review): `weight_type` is accepted but not used.
        .def_static(
            "create",
            [](std::string model_dir, std::string config, std::string weight_type) -> std::shared_ptr {
                auto gil_factory = [] {  //
                    // erase the type
                    return std::static_pointer_cast(std::make_shared());
                };
                auto no_gil_deleter = [](TurboMind* ptr) {
                    pybind11::gil_scoped_release release;
                    delete ptr;
                };

                std::shared_ptr model(new TurboMind(model_dir, config, gil_factory), no_gil_deleter);
                return model;
            },
            "model_dir"_a,
            "config"_a      = "",
            "weight_type"_a = "half")
        .def(
            "create_request",
            [](TurboMind* model) { return model->CreateRequest(); },
            py::call_guard())
        .def("create_weights", &TurboMind::CreateWeights, py::call_guard(), "index"_a)
        .def(
            "get_weights",
            [](TurboMind* model, int index) { return model->GetWeights(index); },
            py::call_guard(),
            "index"_a)
        .def(
            "process_weight",
            [](TurboMind* model, int index) { model->ProcessWeights(index); },
            py::call_guard(),
            "index"_a)
        .def(
            "create_engine",
            [](TurboMind* model, int index) { model->CreateEngine(index); },
            py::call_guard(),
            "index"_a)
        .def(
            "get_schedule_metrics",
            [](TurboMind* model, int index) { return model->GetScheduleMetrics(index); },
            py::call_guard(),
            "index"_a)
        .def(
            "sleep",
            [](TurboMind* model, int index, int level) { model->Sleep(index, level); },
            py::call_guard(),
            "index"_a,
            "level"_a)
        .def(
            "wakeup",
            [](TurboMind* model, int index, const std::vector& tags) { model->WakeUp(index, tags); },
            py::call_guard(),
            "index"_a,
            "tags"_a)
        // No call guard: runs with the GIL held.
        .def("is_dummy_node", [](TurboMind* model) { return model->is_dummy_node(); });
}


================================================
FILE: src/turbomind/python/dlpack.h
================================================
/*!
 *  Copyright (c) 2017 by Contributors
 * \file dlpack.h
 * \brief The common header of DLPack.
 */
#ifndef DLPACK_DLPACK_H_
#define DLPACK_DLPACK_H_

/**
 * \brief Compatibility with C++
 *
 * Expands to `extern "C"` when compiled as C++ and to nothing in C.
 * NOTE: within this header the `extern "C"` block further below is opened
 * directly under `#ifdef __cplusplus`, not through this macro.
 */
#ifdef __cplusplus
#define DLPACK_EXTERN_C extern "C"
#else
#define DLPACK_EXTERN_C
#endif

/*! \brief The current major version of dlpack (ABI-breaking changes bump this) */
#define DLPACK_MAJOR_VERSION 1

/*! \brief The current minor version of dlpack (ABI-compatible additions bump this) */
#define DLPACK_MINOR_VERSION 0

/*!
 * \brief DLPACK_DLL prefix for windows
 *
 * On _WIN32 it expands to __declspec(dllexport) when DLPACK_EXPORTS is defined
 * and to __declspec(dllimport) otherwise; on other platforms it is empty.
 */
#ifdef _WIN32
#ifdef DLPACK_EXPORTS
#define DLPACK_DLL __declspec(dllexport)
#else
#define DLPACK_DLL __declspec(dllimport)
#endif
#else
#define DLPACK_DLL
#endif

#include 
#include 

#ifdef __cplusplus
extern "C" {
#endif

/*!
 * \brief The DLPack version.
 *
 * A change in major version indicates that we have changed the
 * data layout of the ABI - DLManagedTensorVersioned.
 *
 * A change in minor version indicates that we have added new
 * code, such as a new device type, but the ABI is kept the same.
 *
 * If an obtained DLPack tensor has a major version that disagrees
 * with the version number specified in this header file
 * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
 * (and it is safe to do so). It is not safe to access any other fields
 * as the memory layout will have changed.
 *
 * In the case of a minor version mismatch, the tensor can be safely used as
 * long as the consumer knows how to interpret all fields. Minor version
 * updates indicate the addition of enumeration values.
 */
typedef struct {
    /*! \brief DLPack major version (compare against DLPACK_MAJOR_VERSION). */
    uint32_t major;
    /*! \brief DLPack minor version (compare against DLPACK_MINOR_VERSION). */
    uint32_t minor;
} DLPackVersion;

/*!
 * \brief The device type in DLDevice.
 *
 * In C++ the underlying type is pinned to int32_t; in C the compiler chooses
 * it. Values are part of the ABI and must never be renumbered (gaps such as
 * 5 and 6 are intentional).
 */
#ifdef __cplusplus
typedef enum: int32_t
{
#else
typedef enum
{
#endif
    /*! \brief CPU device */
    kDLCPU = 1,
    /*! \brief CUDA GPU device */
    kDLCUDA = 2,
    /*!
     * \brief Pinned CUDA CPU memory by cudaMallocHost
     */
    kDLCUDAHost = 3,
    /*! \brief OpenCL devices. */
    kDLOpenCL = 4,
    /*! \brief Vulkan buffer for next generation graphics. */
    kDLVulkan = 7,
    /*! \brief Metal for Apple GPU. */
    kDLMetal = 8,
    /*! \brief Verilog simulator buffer */
    kDLVPI = 9,
    /*! \brief ROCm GPUs for AMD GPUs */
    kDLROCM = 10,
    /*!
     * \brief Pinned ROCm CPU memory allocated by hipMallocHost
     */
    kDLROCMHost = 11,
    /*!
     * \brief Reserved extension device type,
     * used for quickly testing extension devices.
     * The semantics can differ depending on the implementation.
     */
    kDLExtDev = 12,
    /*!
     * \brief CUDA managed/unified memory allocated by cudaMallocManaged
     */
    kDLCUDAManaged = 13,
    /*!
     * \brief Unified shared memory allocated on a oneAPI non-partitioned
     * device. Call to oneAPI runtime is required to determine the device
     * type, the USM allocation type and the sycl context it is bound to.
     *
     */
    kDLOneAPI = 14,
    /*! \brief GPU support for next generation WebGPU standard. */
    kDLWebGPU = 15,
    /*! \brief Qualcomm Hexagon DSP */
    kDLHexagon = 16,
} DLDeviceType;

/*!
 * \brief A Device for Tensor and operator: a (device type, device index) pair.
 */
typedef struct {
    /*! \brief The device type used in the device. */
    DLDeviceType device_type;
    /*!
     * \brief The device index.
     * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
     */
    int32_t device_id;
} DLDevice;

/*!
 * \brief The type code options DLDataType.
 *
 * Stored in DLDataType::code as a uint8_t, so every value must fit in 8 bits.
 */
typedef enum
{
    /*! \brief signed integer */
    kDLInt = 0U,
    /*! \brief unsigned integer */
    kDLUInt = 1U,
    /*! \brief IEEE floating point */
    kDLFloat = 2U,
    /*!
     * \brief Opaque handle type, reserved for testing purposes.
     * Frameworks need to agree on the handle data type for the exchange to be well-defined.
     */
    kDLOpaqueHandle = 3U,
    /*! \brief bfloat16 */
    kDLBfloat = 4U,
    /*!
     * \brief complex number
     * (C/C++/Python layout: compact struct per complex number)
     */
    kDLComplex = 5U,
    /*! \brief boolean */
    kDLBool = 6U,
} DLDataTypeCode;

/*!
 * \brief The data type the tensor can hold. The data type is assumed to follow the
 * native endianness. An explicit error message should be raised when attempting to
 * export an array with non-native endianness.
 *
 *  Examples
 *   - float: type_code = 2, bits = 32, lanes = 1
 *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
 *   - int8: type_code = 0, bits = 8, lanes = 1
 *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1
 *   - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of
 * bool is 8 bits)
 */
typedef struct {
    /*!
     * \brief Type code of base types.
     * We keep it uint8_t instead of DLDataTypeCode for minimal memory
     * footprint, but the value should be one of DLDataTypeCode enum values.
     * */
    uint8_t code;
    /*!
     * \brief Number of bits per lane, common choices are 8, 16, 32.
     */
    uint8_t bits;
    /*! \brief Number of lanes in the type, used for vector types. */
    uint16_t lanes;
} DLDataType;

/*!
 * \brief Plain C Tensor object, does not manage memory.
 */
typedef struct {
    /*!
     * \brief The data pointer points to the allocated data. This will be CUDA
     * device pointer or cl_mem handle in OpenCL. It may be opaque on some device
     * types. This pointer is always aligned to 256 bytes as in CUDA. The
     * `byte_offset` field should be used to point to the beginning of the data.
     *
     * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
     * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
     * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This must be fixed
     * (after which this note will be updated); at the moment it is recommended
     * to not rely on the data pointer being correctly aligned.
     *
     * For given DLTensor, the size of memory required to store the contents of
     * data is calculated as follows:
     *
     * \code{.c}
     * static inline size_t GetDataSize(const DLTensor* t) {
     *   size_t size = 1;
     *   for (tvm_index_t i = 0; i < t->ndim; ++i) {
     *     size *= t->shape[i];
     *   }
     *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
     *   return size;
     * }
     * \endcode
     */
    void* data;
    /*! \brief The device of the tensor */
    DLDevice device;
    /*! \brief Number of dimensions (length of `shape` and, if non-NULL, `strides`) */
    int32_t ndim;
    /*! \brief The data type of the pointer*/
    DLDataType dtype;
    /*! \brief The shape of the tensor */
    int64_t* shape;
    /*!
     * \brief strides of the tensor (in number of elements, not bytes)
     *  can be NULL, indicating tensor is compact and row-majored.
     */
    int64_t* strides;
    /*! \brief The offset in bytes to the beginning pointer to data */
    uint64_t byte_offset;
} DLTensor;

/*!
 * \brief C Tensor object, manage memory of DLTensor. This data structure is
 *  intended to facilitate the borrowing of DLTensor by another framework. It is
 *  not meant to transfer the tensor. When the borrowing framework doesn't need
 *  the tensor, it should call the deleter to notify the host that the resource
 *  is no longer needed.
 *
 * \note This data structure is used as Legacy DLManagedTensor
 *       in DLPack exchange and is deprecated after DLPack v0.8
 *       Use DLManagedTensorVersioned instead.
 *       This data structure may get renamed or deleted in future versions.
 *
 * \sa DLManagedTensorVersioned
 */
typedef struct DLManagedTensor {
    /*! \brief DLTensor which is being memory managed */
    DLTensor dl_tensor;
    /*! \brief the context of the original host framework of DLManagedTensor in
     *   which DLManagedTensor is used in the framework. It can also be NULL.
     */
    void* manager_ctx;
    /*!
     * \brief Destructor - this should be called
     * to destruct the manager_ctx  which backs the DLManagedTensor. It can be
     * NULL if there is no way for the caller to provide a reasonable destructor.
     * The destructor deletes the argument self as well.
     */
    void (*deleter)(struct DLManagedTensor* self);
} DLManagedTensor;

// bit masks used in the DLManagedTensorVersioned::flags field

/*! \brief bit mask to indicate that the tensor is read only. */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)

/*!
 * \brief A versioned and managed C Tensor object, manage memory of DLTensor.
 *
 * This data structure is intended to facilitate the borrowing of DLTensor by
 * another framework. It is not meant to transfer the tensor. When the borrowing
 * framework doesn't need the tensor, it should call the deleter to notify the
 * host that the resource is no longer needed.
 *
 * \note This is the current standard DLPack exchange data structure.
 *       `version` comes first so a consumer can check it before touching any
 *       other field.
 */
struct DLManagedTensorVersioned {
    /*!
     * \brief The API and ABI version of the current managed Tensor
     */
    DLPackVersion version;
    /*!
     * \brief the context of the original host framework.
     *
     * Stores the context in which the DLManagedTensorVersioned is used in the
     * framework. It can also be NULL.
     */
    void* manager_ctx;
    /*!
     * \brief Destructor.
     *
     * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
     * It can be NULL if there is no way for the caller to provide a reasonable
     * destructor. The destructor deletes the argument self as well.
     */
    void (*deleter)(struct DLManagedTensorVersioned* self);
    /*!
     * \brief Additional bitmask flags information about the tensor.
     *
     * By default the flags should be set to 0.
     *
     * \note Future ABI changes should keep everything until this field
     *       stable, to ensure that deleter can be correctly called.
     *
     * \sa DLPACK_FLAG_BITMASK_READ_ONLY
     */
    uint64_t flags;
    /*! \brief DLTensor which is being memory managed */
    DLTensor dl_tensor;
};

#ifdef __cplusplus
}  // DLPACK_EXTERN_C
#endif
#endif  // DLPACK_DLPACK_H_


================================================
FILE: src/turbomind/python/xgrammar_bind.cpp
================================================
// Modified from xgrammar/nanobind/nanobind.cc from xgrammar project.
/*!
 *  Copyright (c) 2024 by Contributors
 * \file xgrammar/nanobind/nanobind.cc
 */

#include 
#include 
#include 

#include 
#include 
#include 
#include 
#include 

#include 

#include "src/turbomind/core/check.h"

namespace py = pybind11;
using namespace xgrammar;
using namespace pybind11::literals;

namespace {

/// Converts the Python `encoded_vocab` list into a C++ vector with one entry per
/// list item, preserving order. Each item must pass one of the two isinstance
/// checks below (str or bytes, per the error message); both kinds are cast to the
/// element type. Any other item raises std::invalid_argument.
/// NOTE(review): returning a `const` value prevents the result from being moved
/// into the caller; consider dropping the `const`.
static const std::vector
CommonEncodedVocabType(const py::typing::List>& lst)
{
    std::vector out;
    out.reserve(lst.size());
    for (const auto& h : lst) {
        if (py::isinstance(h)) {
            out.emplace_back(h.cast());
        }
        else if (py::isinstance(h)) {
            out.emplace_back(h.cast());
        }
        else {
            throw std::invalid_argument("encoded_vocab items must be str or bytes");
        }
    }
    return out;
}

/// Validates `vocab_type` (must be 0, 1 or 2; TM_CHECK fails otherwise) and
/// constructs an xgrammar TokenizerInfo, casting the integer to xgrammar's
/// vocab-type enum (presumably xgrammar::VocabType — confirm against xgrammar).
/// NOTE(review): `stop_token_ids` is taken by value but forwarded as an lvalue,
/// so it is copied again; `std::move` would avoid that.
TokenizerInfo TokenizerInfo_Init(const std::vector&     encoded_vocab,
                                 int                                 vocab_type,
                                 std::optional                  vocab_size,
                                 std::optional> stop_token_ids,
                                 bool                                add_prefix_space)
{
    TM_CHECK(vocab_type == 0 || vocab_type == 1 || vocab_type == 2) << "Invalid vocab type: " << vocab_type;
    return TokenizerInfo(
        encoded_vocab, static_cast(vocab_type), vocab_size, stop_token_ids, add_prefix_space);
}

/// Exposes the tokenizer's vocabulary type to Python as a plain integer.
int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer)
{
    const auto kind = tokenizer.GetVocabType();
    return static_cast<int>(kind);
}

std::vector TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer)
{
    const auto&            decoded_vocab = tokenizer.GetDecodedVocab();
    std::vector py_result;
    py_result.reserve(decoded_vocab.size());
    for (const auto& item : decoded_vocab) {
        py_result.emplace_back(py::bytes(item.c_str()));
    }
    return py_result;
}

}  // namespace

/// Python extension module `_xgrammar`: exposes xgrammar's TokenizerInfo,
/// an opaque CompiledGrammar handle, and GrammarCompiler.
PYBIND11_MODULE(_xgrammar, m)
{
    // TokenizerInfo(encoded_vocab, vocab_type, vocab_size=None, stop_token_ids=None, add_prefix_space)
    // encoded_vocab items may be str or bytes; vocab_type is validated to be 0, 1 or 2.
    // NOTE(review): `add_prefix_space` has no default although it follows defaulted
    // arguments, so Python callers effectively have to pass it by keyword.
    py::class_>(m, "TokenizerInfo")
        .def(py::init([](const py::typing::List>& encoded_vocab,
                         int                                                           vocab_type,
                         std::optional                                            vocab_size,
                         std::optional>                           stop_token_ids,
                         bool                                                          add_prefix_space) {
                 return TokenizerInfo{TokenizerInfo_Init(CommonEncodedVocabType(encoded_vocab),
                                                         vocab_type,
                                                         vocab_size,
                                                         std::move(stop_token_ids),
                                                         add_prefix_space)};
             }),
             py::arg("encoded_vocab"),
             py::arg("vocab_type"),
             py::arg("vocab_size")     = py::none(),
             py::arg("stop_token_ids") = py::none(),
             py::arg("add_prefix_space"))

        // Read-only properties forwarding to the C++ getters (vocab_type as int,
        // decoded_vocab as a list of bytes)
        .def_property_readonly("vocab_type", &TokenizerInfo_GetVocabType)
        .def_property_readonly("vocab_size", &TokenizerInfo::GetVocabSize)
        .def_property_readonly("add_prefix_space", &TokenizerInfo::GetAddPrefixSpace)
        .def_property_readonly("decoded_vocab", &TokenizerInfo_GetDecodedVocab)
        .def_property_readonly("stop_token_ids", &TokenizerInfo::GetStopTokenIds)
        .def_property_readonly("special_token_ids", &TokenizerInfo::GetSpecialTokenIds)

        // Presumably the serialized counterpart consumed by from_vocab_and_metadata
        .def("dump_metadata", &TokenizerInfo::DumpMetadata)

        // Rebuilds a TokenizerInfo from an encoded vocab (str/bytes items) and metadata string
        .def_static("from_vocab_and_metadata",
                    [](const py::typing::List>& encoded_vocab,
                       const std::string&                                            metadata) {
                        return TokenizerInfo::FromVocabAndMetadata(CommonEncodedVocabType(encoded_vocab), metadata);
                    })

        .def_static("_detect_metadata_from_hf", &TokenizerInfo::DetectMetadataFromHF);

    // Opaque result type of the compile_* methods; no methods are exposed
    py::class_(m, "CompiledGrammar");

    // GrammarCompiler(tokenizer_info, max_threads=8, cache_enabled=True, max_memory_bytes=-1)
    // NOTE(review): the compile_* methods use py::call_guard whose template argument
    // is not visible here — presumably py::gil_scoped_release so compilation runs
    // without the GIL; confirm.
    py::class_ pyGrammarCompiler(m, "GrammarCompiler");
    pyGrammarCompiler
        .def(py::init(),
             py::arg("tokenizer_info"),
             py::arg("max_threads")      = 8,
             py::arg("cache_enabled")    = true,
             py::arg("max_memory_bytes") = -1)
        .def("compile_json_schema",
             &GrammarCompiler::CompileJSONSchema,
             py::call_guard(),
             py::arg("schema"),
             py::arg("any_whitespace")     = false,
             py::arg("indent")             = py::none(),
             py::arg("separators")         = py::none(),
             py::arg("strict_mode")        = true,
             py::arg("max_whitespace_cnt") = py::none())
        // NOTE(review): the regex argument is exposed under the keyword name "schema"
        .def("compile_regex",
             &GrammarCompiler::CompileRegex,
             py::call_guard(),
             py::arg("schema"));
}


================================================
FILE: src/turbomind/turbomind.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 

#include "src/turbomind/turbomind.h"

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/check.h"
#include "src/turbomind/core/context.h"
#include "src/turbomind/core/core.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/engine/engine.h"
#include "src/turbomind/engine/gateway.h"
#include "src/turbomind/engine/model_executor.h"
#include "src/turbomind/engine/model_request.h"

#include "src/turbomind/models/language_model.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"

#include "src/turbomind/kernels/gemm/tuner/params.h"

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/metrics.h"

#include 

// #include "dbg.h"

namespace turbomind {

using std::vector;
using std::string;
using std::shared_ptr;
using std::unique_ptr;

static std::optional get_moe_method()
{
    static const auto value = []() -> std::optional {
        const auto p = std::getenv("TM_MOE_METHOD");
        if (p) {
            std::string str(p);
            for (auto& x : str) {
                x = std::tolower(x);
            }
            if (str == "naive") {
                return MoeParam::kNaive;
            }
            else if (str == "fused") {
                return MoeParam::kFused;
            }
            else {
                std::cerr << "[WARNING] unrecognised MoE method: " << str << "\n";
            }
        }
        return {};
    }();
    return value;
}

/// TODO: move config parsing to suitable place
/// Reads the parameters common to every RoPE variant, `base` and `dim`, into
/// `param`. Logs and aborts unless both are strictly positive; the
/// `!(base > 0)` form also rejects a NaN base.
static void parse_default_rope_param(const YAML::Node& node, RopeParam& param)
{
    param.base = node["base"].as<float>();
    param.dim  = node["dim"].as<int>();
    if (!(param.base > 0.f) || param.dim <= 0) {
        TM_LOG_ERROR("invalid rope param: base = %f, dim = %d", param.base, param.dim);
        FT_CHECK(0);
    }
}

/// Linear-scaling RoPE: the default params plus the scaling `factor`.
static void parse_linear_rope_param(const YAML::Node& node, RopeParam& param)
{
    parse_default_rope_param(node, param);
    param.factor = node["factor"].as();
}

/// Dynamic RoPE: the linear params (base, dim, factor) plus `max_position_embeddings`.
static void parse_dynamic_rope_param(const YAML::Node& node, RopeParam& param)
{
    parse_linear_rope_param(node, param);
    param.max_position_embeddings = node["max_position_embeddings"].as();
}

/// YaRN RoPE: the dynamic params (base, dim, factor, max_position_embeddings)
/// plus `attention_factor`, `beta_fast` and `beta_slow` stored in `param.yarn`.
static void parse_yarn_rope_param(const YAML::Node& node, RopeParam& param)
{
    parse_dynamic_rope_param(node, param);
    param.yarn.attention_factor = node["attention_factor"].as();
    param.yarn.beta_fast        = node["beta_fast"].as();
    param.yarn.beta_slow        = node["beta_slow"].as();
}

/// Llama-3 RoPE: the linear params (base, dim, factor) plus the frequency factors
/// and `original_max_position_embeddings` stored in `param.llama3`.
static void parse_llama3_rope_param(const YAML::Node& node, RopeParam& param)
{
    parse_linear_rope_param(node, param);
    param.llama3.low_freq_factor                  = node["low_freq_factor"].as();
    param.llama3.high_freq_factor                 = node["high_freq_factor"].as();
    param.llama3.original_max_position_embeddings = node["original_max_position_embeddings"].as();
}

/// Multimodal RoPE: the default params plus `mrope_section`, which must list
/// exactly three entries (aborts otherwise); they are copied into `param.mrope.section`.
static void parse_mrope_rope_param(const YAML::Node& node, RopeParam& param)
{
    parse_default_rope_param(node, param);
    auto mrope_section = node["mrope_section"].as>();
    FT_CHECK(mrope_section.size() == 3);
    param.mrope.section = {mrope_section[0], mrope_section[1], mrope_section[2]};
}

/// Parses the `rope_param` section of the attention config into `rope`.
///
/// `node["type"]` is mapped to a RopeType via GetRoPEType, and the matching
/// variant parser reads the keys that variant needs. A type without a parser
/// aborts with a message naming the offending type string.
static void parse_rope_param(const YAML::Node& node, RopeParam& rope)
{
    const auto type_name = node["type"].as<std::string>();
    rope.type            = GetRoPEType(type_name);

    switch (rope.type) {
        case RopeType::kDefault:
            parse_default_rope_param(node, rope);
            break;
        case RopeType::kLinear:
            parse_linear_rope_param(node, rope);
            break;
        case RopeType::kDynamic:
            parse_dynamic_rope_param(node, rope);
            break;
        case RopeType::kYarn:
            parse_yarn_rope_param(node, rope);
            break;
        case RopeType::kLlama3:
            parse_llama3_rope_param(node, rope);
            break;
        case RopeType::kMrope:
            parse_mrope_rope_param(node, rope);
            break;
        default:
            TM_CHECK(0) << "unsupported rope type: " << type_name;
            break;
    }
}

/// Maps a data-type name from the model config to a turbomind DataType.
///
/// Accepted names: "fp16"/"float16", "bf16"/"bfloat16", "fp32"/"float32",
/// "int8" (-> kUint8), "int4" (-> kUint4), "fp8" (-> kFloat8_e4m3) and
/// "e2m1" (-> kFloat4_e2m1). Any other name fails a TM_CHECK.
static DataType data_type_from_string(const std::string& str)
{
    if (str == "fp16" || str == "float16") {
        return kFloat16;
    }
    if (str == "bf16" || str == "bfloat16") {
        return kBfloat16;
    }
    // Accept the long form too, consistent with float16/bfloat16
    if (str == "fp32" || str == "float32") {
        return kFloat32;
    }
    if (str == "int8") {
        return kUint8;
    }
    if (str == "int4") {
        return kUint4;
    }
    if (str == "fp8") {
        return kFloat8_e4m3;
    }
    if (str == "e2m1") {
        return kFloat4_e2m1;
    }
    TM_CHECK(0) << "unsupported weight type: " << str;
    return {};
}

/// Private implementation of TurboMind.
///
/// Owns the parameters parsed from the YAML config, the host communicator group
/// id, the request gateway, and per-local-device weights, contexts and engines.
/// Methods taking `index` operate on the device `engine_param_.devices[index]`.
struct TurboMind::Impl {
    DataType       data_type_;     // model data type (model_config.data_type; checked to be bf16 or fp16)
    ModelParam     model_param_;
    AttentionParam attn_param_;
    MoeParam       moe_param_;
    EngineParam    engine_param_;  // parsed engine params, shared template for all local devices
    size_t         comm_size_;     // attn_dp_size * attn_tp_size * attn_cp_size (inner group size)

    vector engine_params_;  // per-device copies of engine_param_; rank fields filled by CreateContext

    string communicator_type_;  // communicator backend

    unique_ptr group_id_;  // host group id; CreateContext creates host communicators from it

    shared_ptr gateway_;  // request gateway; created/shut down by device index 0

    FFICtxFactory ffi_ctx_factory_;  // forwarded to the gateway on creation

    vector global_rank_;  // per local device: node_rank * n_local_devices + index

    // Weights & engine instances for the ranks
    vector> weights_;
    vector>     contexts_;
    vector                  engines_;

    string model_name_;
    string model_dir_;

    // Per local device: after CreateContext(0) runs, devices that are rank 0 of
    // their host TP group hold a queue index in [0, n_queues_), others hold -1
    vector queue_id_;
    int         n_queues_{0};

    int need_warm_up_{1};
    int phases_{1};  // set to 2 when engine_config.async_ is true, otherwise 1

    ~Impl();

    Impl(string model_dir, string config, FFICtxFactory ffi_ctx_factory);

    // Creates a request object bound to the current gateway. NOTE: gateway_ is
    // null between Sleep(0, ...) and WakeUp(0, {"kv_cache"}).
    unique_ptr CreateRequest()
    {
        return std::make_unique(gateway_.get(),  //
                                              data_type_,
                                              engine_param_.session_len,
                                              model_param_.vocab_size,
                                              model_param_.hidden_units);
    }

    // Creates the device context (which also fills the rank fields of
    // engine_params_[index]) and then allocates the weights for this device,
    // so the weights see this rank's parallel layout.
    void CreateWeights(int index)
    {
        CudaDeviceGuard dev_guard(engine_param_.devices[index]);

        CreateContext(index);

        weights_[index] = std::make_shared(data_type_,  //
                                                        model_param_,
                                                        engine_params_.at(index),
                                                        moe_param_);
    }

    // Returns a name -> tensor map of this device's weight parameters (fails if
    // CreateWeights(index) has not run). Tensors are copied by value from the
    // weight object — presumably handles sharing the same storage; confirm.
    TensorMap GetWeights(int index)
    {
        const auto& tensor_ptr_map = TM_CHECK_NOTNULL(weights_[index])->get_parameters();
        TensorMap   params;
        for (const auto& [name, tensor_ptr] : tensor_ptr_map) {
            params[name] = *tensor_ptr;
        }
        return params;
    }

    // Post-load processing of this device's weights, given the device properties
    // (e.g. architecture-dependent layout conversion — assumption; see prepare()).
    void ProcessWeights(int index)
    {
        CudaDeviceGuard dev_guard(engine_param_.devices[index]);
        FT_CHECK(weights_[index] != nullptr);

        cudaDeviceProp props{};
        check_cuda_error(cudaGetDeviceProperties(&props, engine_param_.devices[index]));

        weights_[index]->prepare(props);
        sync_check_cuda_error();
    }

    void CreateEngine(int index);

    void CreateContext(int index);

    void WarmUp(int index);

    // Releases device memory for this device.
    //   level == 2: weights are released;
    //   any other level: weights are moved to CPU (not supported for MoE models).
    // In both cases the engine is dropped, the context allocator and the default
    // mempool are trimmed, and device 0 also shuts down and drops the gateway.
    void Sleep(int index, int level)
    {
        CudaDeviceGuard dev_guard(engine_param_.devices[index]);

        if (level == 2) {
            // free weights
            weights_[index]->release();
        }
        else {
            // offload weights to CPU
            TM_CHECK(moe_param_.experts_per_token == 0) << "level 1 sleep not supported for MoE model";
            weights_[index]->to_device(kCPU);
        }

        // free model (kv cache and buffer)
        if (index == 0) {
            gateway_->shutdown();
            gateway_.reset();
        }

        engines_[index] = {};
        contexts_[index]->allocator->trim(0);

        trim_default_mempool(engine_param_.devices[index]);
    }

    // Restores resources released by Sleep, selected by `tags`:
    //   "weights":  moves weights back to the device, or initializes them if they
    //               were never initialized / were released;
    //   "kv_cache": device 0 recreates the gateway, then the engine is recreated.
    // Unknown tags are ignored.
    void WakeUp(int index, const std::vector& tags)
    {
        CudaDeviceGuard dev_guard(engine_param_.devices[index]);

        std::set keys(tags.begin(), tags.end());

        // Only used for its null check: the context must survive Sleep
        auto& ctx = *TM_CHECK_NOTNULL(contexts_[index]);

        if (keys.find("weights") != keys.end()) {
            TM_CHECK(weights_[index] != nullptr);
            if (weights_[index]->is_initialized()) {
                weights_[index]->to_device(kDEVICE);
            }
            else {
                weights_[index]->initialize();
            }
        }

        if (keys.find("kv_cache") != keys.end()) {
            if (index == 0) {
                gateway_ = std::make_shared(n_queues_, ffi_ctx_factory_);
            }
            CreateEngine(index);
        }
    }

    // Fills in engine params the config left unset: max_context_token_num
    // defaults to session_len, and a value not larger than max_batch_size is
    // multiplied by session_len (presumably interpreted as a sequence count —
    // NOTE(review): confirm intent).
    void HandleMissingParams()
    {
        if (!engine_param_.max_context_token_num) {
            engine_param_.max_context_token_num = engine_param_.session_len;
            TM_LOG_WARNING("[TM] `max_context_token_num` is not set, default to %d.",
                           (int)engine_param_.max_context_token_num);
        }

        if (engine_param_.max_context_token_num <= engine_param_.max_batch_size) {
            engine_param_.max_context_token_num *= engine_param_.session_len;
            TM_LOG_WARNING("[TM] `max_context_token_num` = %d.", (int)engine_param_.max_context_token_num);
        }
    }
};

/// Tears down the instance: shuts down the gateway (if it still exists), then,
/// for every local device, destroys the engine and context with that device
/// current and the context's core stream bound, and finally drops the weights.
TurboMind::Impl::~Impl()
{
    // Use a literal format string rather than passing the function name as one
    TM_LOG_INFO("%s", __PRETTY_FUNCTION__);
    if (gateway_) {
        gateway_->shutdown();
    }
    for (int i = 0; i < (int)engines_.size(); ++i) {
        /// TODO: make device part of core::Context
        CudaDeviceGuard device(engine_param_.devices[i]);
        if (contexts_[i]) {
            // Engine and context are released while the rank's core stream is current
            core::ContextGuard context{contexts_[i]->core_stream};
            engines_[i]  = {};
            contexts_[i] = {};
        }
        else {
            // The context for this device was never created (e.g. initialization
            // aborted before CreateWeights(i)); there is no stream to bind, so
            // only drop the engine handle instead of dereferencing a null context.
            engines_[i] = {};
        }
        weights_[i] = {};
    }
}

/// Parses the YAML `config` (sections `model_config`, `attention_config` and
/// `engine_config`) into the model, attention, MoE and engine parameters. It then
/// sizes the per-device containers, initializes the host group id, computes
/// each local device's global rank, and creates one engine_params_ copy per device.
/// Per-device state (contexts, weights, engines) is created later by
/// CreateWeights / CreateEngine.
///
/// Keys read with `.as(<fallback>)` (yaml-cpp fallback overload) are optional;
/// keys read with `.as()` are treated as required.
///
/// NOTE(review): `model_dir` is not stored into `model_dir_` here — confirm
/// whether that is intentional.
TurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_factory):
    data_type_{}, model_param_{}, attn_param_{}, moe_param_{}, engine_param_{}, ffi_ctx_factory_{ffi_ctx_factory}
{
    TM_CHECK(!config.empty());

    YAML::Node node;
    try {
        node = YAML::Load(config);
    }
    catch (const YAML::Exception& e) {
        TM_CHECK(0) << "Error loading YAML config: " << e.what() << "\nconfig:\n" << config;
    }

    /// TODO: move config parsing to suitable place
    const auto model     = node["model_config"];
    const auto attention = node["attention_config"];
    const auto engine    = node["engine_config"];

    // Only bf16 / fp16 are accepted as the model data type
    data_type_ = model_param_.data_type = data_type_from_string(model["data_type"].as());
    TM_CHECK(data_type_ == kBfloat16 || data_type_ == kHalf);

    // Core model dimensions and MLA (low-rank attention) sizes
    model_name_                     = model["model_name"].as();
    model_param_.head_num           = model["head_num"].as();
    model_param_.head_dim           = model["size_per_head"].as();
    model_param_.kv_head_num        = model["kv_head_num"].as(0);
    model_param_.hidden_units       = model["hidden_units"].as();
    model_param_.layer_num          = model["num_layer"].as();
    model_param_.vocab_size         = model["vocab_size"].as();
    model_param_.embedding_size     = model["embedding_size"].as();
    model_param_.norm_eps           = model["norm_eps"].as();
    model_param_.tune_layer_num     = model["tune_layer_num"].as(1);
    model_param_.mla.q_lora_rank    = model["q_lora_rank"].as();
    model_param_.mla.kv_lora_rank   = model["kv_lora_rank"].as();
    model_param_.mla.qk_rope_dim    = model["qk_rope_dim"].as();
    model_param_.mla.v_head_dim     = model["v_head_dim"].as();
    attn_param_.cache_block_seq_len = attention["cache_block_seq_len"].as(0);
    model_param_.quant_policy       = engine["quant_policy"].as(0);

    // Per-layer FFN intermediate sizes
    auto inter_size = model["inter_size"];
    for (auto it = inter_size.begin(); it != inter_size.end(); ++it) {
        model_param_.inter_size.push_back(it->as());
    }

    // Per-layer attention kind: 1 = linear attention, 0 = full attention
    // (also used for empty or unknown strings, with a warning for unknown ones)
    if (auto layer_types = model["layer_types"]) {
        for (auto it = layer_types.begin(); it != layer_types.end(); ++it) {
            auto type_str = it->as("");
            if (type_str == "linear_attention") {
                model_param_.layer_types.push_back(1);
            }
            else if (type_str == "full_attention" || type_str.empty()) {
                model_param_.layer_types.push_back(0);
            }
            else {
                TM_LOG_WARNING("[TM] Unknown layer_type '%s', treating as full_attention.", type_str.c_str());
                model_param_.layer_types.push_back(0);
            }
        }
    }

    // Qwen3.5 Gated DeltaNet linear attention parameters
    model_param_.linear_key_head_dim    = model["linear_key_head_dim"].as(0);
    model_param_.linear_value_head_dim  = model["linear_value_head_dim"].as(0);
    model_param_.linear_conv_kernel_dim = model["linear_conv_kernel_dim"].as(0);
    model_param_.linear_num_key_heads   = model["linear_num_key_heads"].as(0);
    model_param_.linear_num_value_heads = model["linear_num_value_heads"].as(0);
    model_param_.attn_output_gate       = model["attn_output_gate"].as(false);
    model_param_.linear_state_dtype     = data_type_;

    // Layer indices whose experts are not quantized
    if (auto uqel = model["unquantized_expert_layers"]) {
        for (auto it = uqel.begin(); it != uqel.end(); ++it) {
            model_param_.unquantized_expert_layers.insert(it->as());
        }
    }
    model_param_.attn_sink = model["attn_sink"].as();
    model_param_.mlp_bias  = model["mlp_bias"].as();
    if (model["activation_type"].as("") == "gpt-oss") {
        model_param_.act_type = ActivationType::kSiluGptOss;
    }

    // Per-layer attention window sizes
    auto window_size = model["window_size"];
    for (auto it = window_size.begin(); it != window_size.end(); ++it) {
        model_param_.window_size.push_back(it->as());
    }

    model_param_.attn_bias  = model["attn_bias"].as(0);
    model_param_.qk_norm    = model["qk_norm"].as();
    model_param_.group_size = model["group_size"].as(0);

    attn_param_.softmax_scale = attention["softmax_scale"].as(0);
    // logn attn for qwen model
    attn_param_.use_logn_attn           = attention["use_logn_attn"].as(0);
    attn_param_.max_position_embeddings = attention["max_position_embeddings"].as(0);
    // rotary embedding parameters
    parse_rope_param(attention["rope_param"], attn_param_.rope);

    // Forward token budget = prefill budget + one decode token per batch slot
    // (rounded up to a multiple of attn_tp * attn_cp further below)
    engine_param_.max_batch_size = engine["max_batch_size"].as(0);
    auto max_forward_token_num   = engine["max_prefill_token_num"].as(0);
    max_forward_token_num += engine_param_.max_batch_size;

    engine_param_.max_context_token_num = engine["max_context_token_num"].as(0);
    engine_param_.session_len           = model["session_len"].as(0);

    // NOTE(review): `cache_max_entry_count` is stored into cache_max_block_count;
    // confirm the config value's type/meaning matches the field (ratio vs count).
    engine_param_.cache_max_block_count = engine["cache_max_entry_count"].as(0);
    engine_param_.cache_chunk_size      = engine["cache_chunk_size"].as(0);
    engine_param_.enable_prefix_caching = engine["enable_prefix_caching"].as(false);
    engine_param_.enable_metrics        = engine["enable_metrics"].as(false);

    if (engine_param_.enable_prefix_caching && HasLinearAttention(model_param_)) {
        TM_CHECK(0) << "Prefix caching is unsupported when linear attention is present";
    }

    engine_param_.num_tokens_per_iter = engine["num_tokens_per_iter"].as(0);
    engine_param_.max_prefill_iters   = engine["max_prefill_iters"].as(1);

    phases_ = engine["async_"].as() ? 2 : 1;

    // Parallel layout sizes: outer data parallel, then attention dp/tp/cp and MLP tp
    engine_param_.outer_dp_size = engine["outer_dp_size"].as();

    engine_param_.attn_dp_size = engine["attn_dp_size"].as();
    engine_param_.attn_tp_size = engine["attn_tp_size"].as();
    engine_param_.attn_cp_size = engine["attn_cp_size"].as();

    engine_param_.mlp_tp_size = engine["mlp_tp_size"].as();

    engine_param_.devices = engine["devices"].as>();

    // multi-node information
    engine_param_.nnodes    = engine["nnodes"].as();
    engine_param_.node_rank = engine["node_rank"].as();

    {
        // Round the forward token budget up to a multiple of attn_tp * attn_cp
        auto sp                             = engine_param_.attn_tp_size * engine_param_.attn_cp_size;
        engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + sp - 1) / sp * sp;
    }

    // The MLP TP group must span exactly the whole inner (dp x tp x cp) group
    comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size * engine_param_.attn_cp_size;
    FT_CHECK(engine_param_.mlp_tp_size == comm_size_);

    communicator_type_ = engine["communicator"].as();

    // Mixture-of-experts parameters (experts_per_token == 0 means a dense model)
    moe_param_.experts_per_token = model["experts_per_token"].as(0);
    moe_param_.inter_size        = model["expert_inter_size"].as(0);
    moe_param_.shared_gate       = model["moe_shared_gate"].as();
    moe_param_.norm_topk_prob    = model["norm_topk_prob"].as();
    moe_param_.routed_scale      = model["routed_scale"].as(1.f);
    moe_param_.topk_group        = model["topk_group"].as(1);
    moe_param_.topk_method       = model["topk_method"].as("greedy");
    moe_param_.n_group           = model["moe_group_num"].as(1);
    moe_param_.scoring_func      = model["scoring_func"].as("softmax");
    moe_param_.router_n_groups   = model["router_n_groups"].as(-1);
    moe_param_.router_bias       = model["expert_router_bias"].as();
    YAML::Node expert_num        = model["expert_num"];
    for (auto it = expert_num.begin(); it != expert_num.end(); ++it) {
        moe_param_.expert_num.push_back(it->as());
    }

    HandleMissingParams();

    weights_.resize(engine_param_.devices.size());
    engines_.resize(engine_param_.devices.size());
    contexts_.resize(engine_param_.devices.size());

    // Weight types; ffn_weight_type falls back to weight_type when absent
    model_param_.weight_type        = data_type_from_string(model["weight_type"].as());
    model_param_.expert_weight_type = data_type_from_string(model["expert_weight_type"].as());
    model_param_.ffn_weight_type =
        data_type_from_string(model["ffn_weight_type"].as(model["weight_type"].as()));

    // TM_MOE_METHOD env var overrides the MoE method; fused is the default
    if (auto method = get_moe_method()) {
        moe_param_.method = *method;
    }
    else {
        moe_param_.method = MoeParam::kFused;
    }

    // NOTE: This runs on Python main thread
    // Multi-node runs use the "hybrid" host group id backend
    group_id_ = comm::CreateHostGroupId((engine_param_.nnodes == 1) ? "" : "hybrid");
    group_id_->Initialize();

    const int devices = engine_param_.devices.size();

    // Global rank of local device i (assumes every node has the same device count)
    for (int i = 0; i < devices; ++i) {
        global_rank_.push_back(engine_param_.node_rank * devices + i);
    }

    queue_id_.resize(devices);
    engine_params_.resize(devices, engine_param_);
}

// Create the per-device context for local device slot `index` and set up its
// host-side (`h_*`) and device-side (`d_*`) communicators. Also assigns request
// queue ids and (on slot 0) creates `gateway_`.
//
// NOTE(review): contains collective calls (CreateCommunicator / Split / Sync), so it
// is presumably invoked concurrently with one thread per local device; calling it
// sequentially for several slots on one thread would block at the first barrier.
void TurboMind::Impl::CreateContext(int index)
{
    // Per-slot copy of the engine parameters; rank fields are filled in below.
    auto& p = engine_params_[index];

    CudaDeviceGuard dev_guard(p.devices[index]);

    TM_CHECK(contexts_[index] == nullptr);

    auto& ctx = contexts_[index] = std::make_shared(p.devices[index]);

    // Layout: (outer, dp, tp, cp)

    const int global_rank = global_rank_[index];

    // `outer_rank` selects a group of `comm_size_` consecutive global ranks;
    // `inner_rank` is the position inside that group.
    const int outer_rank = global_rank / comm_size_;
    const int inner_rank = global_rank % comm_size_;

    p.outer_dp_rank = outer_rank;

    const int tp_cp_size = p.attn_tp_size * p.attn_cp_size;

    // Split colors (assuming MPI_Comm_split-like semantics where equal colors form
    // one sub-group — TODO confirm):
    //   tp_color: consecutive runs of tp_cp_size ranks (same dp slot)
    //   dp_color: same offset inside each run (same tp/cp position)
    //   cp_color: consecutive runs of attn_cp_size ranks (same dp and tp slot)
    const int tp_color = inner_rank / tp_cp_size;
    const int dp_color = inner_rank % tp_cp_size;
    const int cp_color = inner_rank / p.attn_cp_size;

    auto& c = ctx->comm;

    // NOTE(review): `comm_size_` is passed here while `global_rank` can exceed it when
    // outer_rank > 0; confirm the meaning of CreateCommunicator's first argument.
    c.h_global = group_id_->CreateCommunicator(comm_size_, global_rank, p.node_rank);

    c.h_comm = c.h_global->Split(outer_rank, 0);

    c.h_tp_group = c.h_comm->Split(tp_color, 0);
    c.h_dp_group = c.h_comm->Split(dp_color, 0);

    if (comm_size_ > 1) {
        c.d_comm = CreateDeviceCommunicator(communicator_type_, comm_size_, inner_rank, c.h_comm);

        // Group handle 0 is the default for tp/cp when no split is needed.
        c.d_tp_group = 0;
        c.d_cp_group = 0;

        if (p.attn_dp_size > 1) {  // has attn_dp
            c.d_tp_group   = c.d_comm->Split(tp_color, 0, 0);
            p.attn_dp_rank = c.h_dp_group->rank();
        }

        if (p.attn_cp_size > 1) {  // has attn_cp
            c.d_cp_group   = c.d_comm->Split(cp_color, 0, 0);
            p.attn_cp_rank = c.d_comm->rank(c.d_cp_group);
        }

        // The tp group is ordered (tp, cp), so dividing out cp gives the tp index.
        p.attn_tp_rank = c.d_comm->rank(c.d_tp_group) / p.attn_cp_size;
        p.mlp_tp_rank  = c.d_comm->rank(0);
    }

    // The first rank of each attention-TP group owns a request queue.
    if (c.h_tp_group->rank() == 0) {
        queue_id_[index] = 1;
    }

    // Barrier: every slot has marked `queue_id_` before slot 0 compacts it.
    c.h_global->Sync();

    if (index == 0) {
        // Turn the marks into consecutive queue ids; unmarked slots get -1.
        n_queues_ = 0;
        for (size_t i = 0; i < queue_id_.size(); ++i) {
            queue_id_[i] = queue_id_[i] ? n_queues_++ : -1;
        }
        gateway_ = std::make_shared(n_queues_, ffi_ctx_factory_);
    }

    // Barrier: `gateway_` and final queue ids are ready before anyone proceeds.
    c.h_global->Sync();
}

// Build the LanguageModel and Engine for local device slot `index`, start the
// engine, and run WarmUp(index) while `need_warm_up_` is set.
// Requires CreateContext(index) to have run first (checked via TM_CHECK_NOTNULL).
void TurboMind::Impl::CreateEngine(int index)
{
    CudaDeviceGuard dev_guard(engine_param_.devices[index]);

    auto& ctx = *TM_CHECK_NOTNULL(contexts_[index]);

    // Presumably installs this context's stream and allocators as the current ones
    // for the scope below (core::Context::stream() is used later) — confirm.
    core::ContextGuard guard{ctx.core_stream, ctx.allocator, Allocator{kCPUpinned}};

    const auto& param = engine_params_.at(index);

    // Barrier across this communicator group before model construction.
    ctx.comm.h_comm->Sync();

    // create model
    LanguageModel model{data_type_,  //
                        model_param_,
                        param,
                        attn_param_,
                        moe_param_,
                        ctx,
                        *weights_[index],
                        phases_};

    // create engine
    engines_[index] = Engine{data_type_,  //
                             param,
                             std::move(model),
                             ctx,
                             *gateway_,
                             engine_param_.devices[index],
                             queue_id_[index],
                             phases_};

    // Drain work queued on the current stream during construction.
    core::Context::stream().Sync();

    // Barrier: all ranks finished construction before any engine starts.
    ctx.comm.h_comm->Sync();

    engines_[index].Start();

    if (need_warm_up_) {
        WarmUp(index);
    }
}

// Stream every element of [first, last) into one string, placing `delim`
// between consecutive elements. An empty range yields an empty string.
template<class Iter>
static std::string Join(Iter first, Iter last, const std::string& delim)
{
    std::ostringstream out;
    bool               at_start = true;
    for (; first != last; ++first) {
        if (!at_start) {
            out << delim;
        }
        out << *first;
        at_start = false;
    }
    return out.str();
}

// Warm up the engine on local device slot `index` and tune GEMM kernels.
//
// If TM_GEMM_IMPORT is set, GEMM records are imported from that file and warm-up
// is skipped. Otherwise `linear` is put into measure mode and the first rank of
// each attention-TP group submits one request per warm-up length (max_new_tokens
// = 1). Afterwards slot 0 optionally exports the records to TM_GEMM_EXPORT,
// restores the gateway threshold to 1 and clears `need_warm_up_`.
// The global Sync calls require every rank to reach them.
void TurboMind::Impl::WarmUp(int index)
{
    auto& ctx = *TM_CHECK_NOTNULL(contexts_[index]);

    auto& global = ctx.comm.h_global;
    auto& linear = *ctx.linear;

    // NOTE(review): this path returns before any barrier, so all ranks must agree on
    // TM_GEMM_IMPORT (otherwise the others presumably block in Sync). `ifs` is not
    // checked for a successful open, and `need_warm_up_` is left unchanged here.
    if (auto str = std::getenv("TM_GEMM_IMPORT")) {
        std::ifstream ifs(str);
        const int     n_imported = linear.Import(ifs);
        if (index == 0) {
            TM_LOG_INFO("[GEMM] %d records imported", n_imported);
        }
        return;
    }

    global->Sync();

    *ctx.is_warm_up = 1;
    linear.set_measure(true);

    // During warm-up the gateway threshold is raised to the attention-DP size
    // (presumably so batches are formed across all DP ranks — confirm).
    if (index == 0) {
        gateway_->set_threshold(engine_param_.attn_dp_size);
    }

    global->Sync();

    // Only the first rank of each attention-TP group submits warm-up requests.
    if (ctx.comm.h_tp_group->rank() == 0) {

        std::vector bss = linear.GetTuningSeq();
        if (bss.empty()) {
            bss = gemm::GenerateTuningSequence(gemm::GetDefaultTuningGenerators());
        }

        const int max_fwd_token_num = engine_param_.max_forward_token_num;

        // remove bs that is too large
        bss.erase(std::remove_if(bss.begin(), bss.end(), [&](auto x) { return x > max_fwd_token_num; }), bss.end());

        // Make sure the largest allowed length is covered.
        // NOTE(review): `bss.back()` assumes `bss` is ascending — TODO confirm; otherwise
        // max_fwd_token_num could be appended although already present.
        if (bss.empty() || bss.back() < max_fwd_token_num) {
            bss.push_back(max_fwd_token_num);
        }

        auto str = Join(bss.begin(), bss.end(), ", ");
        TM_LOG_INFO("[Engine] Warm-up lengths: %s", str.c_str());

        // Always true after the push_back above; kept as a guard.
        if (!bss.empty()) {
            // Random token ids from a default-seeded mt19937, so the inputs are
            // the same on every run.
            const auto                         max_bs = *std::max_element(bss.begin(), bss.end());
            Buffer_                       input_ids(max_bs, kCPU);
            std::mt19937                       g{};
            std::uniform_int_distribution d{0, (int)model_param_.vocab_size - 1};
            for (auto& x : input_ids) {
                x = d(g);
            }

            auto tick = std::chrono::steady_clock::now();

            for (auto token_num : bss) {

                TM_LOG_INFO("[WarmUp] %d", token_num);

                auto r = CreateRequest();

                // Prefix of the random ids with `token_num` tokens.
                TensorMap inputs{{"input_ids", input_ids.slice(0, token_num)}};

                // A self-contained session producing a single token.
                ModelRequest::InputParam param{};
                param.session.start_flag     = true;
                param.session.end_flag       = true;
                param.gen_cfg.max_new_tokens = 1;
                param.tensors                = std::make_shared(inputs);

                // The callback may fire more than once; `flag` ensures the promise
                // is fulfilled only on the first call.
                struct Channel {
                    int                flag = 1;
                    std::promise promise;
                };
                auto c = std::make_shared();

                ModelRequest::OutputParam out = r->Forward(std::move(param), [c] {
                    /// NOTE: It's risky to set `out.state` here, `out` may not be initialized at this point
                    if (std::exchange(c->flag, 0)) {
                        c->promise.set_value();
                    }
                });

                // Block until the first callback.
                c->promise.get_future().get();

                int status = -1;
                if (auto state = out.state->exchange(nullptr)) {
                    status = state->status;
                }

                if (status != Request::kFinish) {
                    TM_LOG_ERROR("[Engine] Warm-up for %d tokens failed with status %d", (int)token_num, (int)status);
                }
            }

            auto tock = std::chrono::steady_clock::now();

            TM_LOG_INFO("[WarmUp] Warm-up finished in %.2f seconds.",
                        std::chrono::duration>(tock - tick).count());
        }
    }

    global->Sync();

    linear.set_measure(false);
    *ctx.is_warm_up = 0;

    if (index == 0) {
        if (auto path = std::getenv("TM_GEMM_EXPORT")) {
            std::ofstream ofs(path);
            // NOTE(review): `%d` assumes Export returns an int — confirm its return type.
            const auto    n_records = linear.Export(ofs);
            TM_LOG_INFO("[GEMM] %d records exported.", n_records);
        }

        gateway_->set_threshold(1);
        need_warm_up_ = 0;
    }

    global->Sync();
}

// Public TurboMind API: thin forwarders to the pimpl `Impl`.
// `index` arguments select a local device slot.

// Defined here, where `Impl` is complete, as required for destroying the
// `std::unique_ptr` to the forward-declared `Impl`.
TurboMind::~TurboMind() = default;

TurboMind::TurboMind(string model_dir, string config, FFICtxFactory ffi_ctx_factory):
    impl_{std::make_unique(model_dir, config, ffi_ctx_factory)}
{
}

void TurboMind::CreateWeights(int index)
{
    return impl_->CreateWeights(index);
}

TensorMap TurboMind::GetWeights(int index)
{
    return impl_->GetWeights(index);
}

void TurboMind::ProcessWeights(int index)
{
    return impl_->ProcessWeights(index);
}

void TurboMind::CreateEngine(int index)
{
    return impl_->CreateEngine(index);
}

void TurboMind::Sleep(int index, int level)
{
    return impl_->Sleep(index, level);
}

void TurboMind::WakeUp(int index, const vector& tags)
{
    return impl_->WakeUp(index, tags);
}

// Reads `engines_[index]` directly; `index` is not bounds-checked here.
shared_ptr TurboMind::GetScheduleMetrics(int index)
{
    return impl_->engines_[index].GetScheduleMetrics();
}

unique_ptr TurboMind::CreateRequest()
{
    return impl_->CreateRequest();
}

// True when no local slot owns a request queue, i.e. `n_queues_` (computed while
// creating contexts) is zero.
bool TurboMind::is_dummy_node() const noexcept
{
    return impl_->n_queues_ == 0;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/turbomind.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 

#include "src/turbomind/core/core.h"
#include "src/turbomind/engine/model_request.h"
#include "src/turbomind/utils/metrics.h"

namespace turbomind {

// Host-side handle to a TurboMind inference engine. Methods taking `index`
// operate on one local device slot; the implementation lives in `Impl` (pimpl).
class TurboMind {
public:
    // Factory forwarded to the request gateway (presumably creates per-request
    // FFI context objects — confirm).
    using FFICtxFactory = std::function()>;

    ~TurboMind();

    // NOTE(review): `config` is presumably a YAML string (parsed as YAML in the
    // implementation) — confirm.
    TurboMind(std::string model_dir, std::string config, FFICtxFactory ffi_ctx_factory);

    void CreateWeights(int index);

    TensorMap GetWeights(int index);

    void ProcessWeights(int index);

    // Build model + engine for slot `index`, start it and warm up if needed.
    void CreateEngine(int index);

    void Sleep(int index, int level);

    void WakeUp(int index, const std::vector& tags);

    // True when this process owns no request queue.
    bool is_dummy_node() const noexcept;

    std::shared_ptr GetScheduleMetrics(int index);

    std::unique_ptr CreateRequest();

private:
    struct Impl;
    std::unique_ptr impl_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/CMakeLists.txt
================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE(review): find_package(CUDAToolkit) and the CUDA:: imported targets used below
# require CMake >= 3.17, higher than the version declared here; the effective minimum
# is presumably enforced by the top-level CMakeLists.txt — confirm.
cmake_minimum_required(VERSION 3.11)

find_package(CUDAToolkit REQUIRED)

add_library(logger STATIC logger.cc)
set_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
target_link_libraries(logger PUBLIC CUDA::cudart)


add_library(cuda_utils STATIC cuda_utils.cc)
set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
target_link_libraries(cuda_utils PUBLIC logger CUDA::cudart CUDA::cuda_driver)


add_library(nvtx_utils STATIC nvtx_utils.cc)
set_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
# FindCUDAToolkit provides the NVTX v3 target CUDA::nvtx3 only from CMake 3.25;
# older CMake falls back to the legacy CUDA::nvToolsExt library.
# ${CMAKE_DL_LIBS} names the platform's dlopen library ("dl" on Linux, empty where
# none is needed, e.g. Windows) instead of a hard-coded "-ldl" linker flag.
if(CMAKE_VERSION VERSION_LESS "3.25")
    target_link_libraries(nvtx_utils PUBLIC CUDA::nvToolsExt ${CMAKE_DL_LIBS})
else()
    target_link_libraries(nvtx_utils PUBLIC CUDA::nvtx3 ${CMAKE_DL_LIBS})
endif()

add_library(memory_utils STATIC memory_utils.cu)
set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
target_link_libraries(memory_utils PUBLIC cuda_utils logger)

add_library(anomaly_handler STATIC anomaly_handler.cu)
set_property(TARGET anomaly_handler PROPERTY POSITION_INDEPENDENT_CODE  ON)
set_property(TARGET anomaly_handler PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
target_link_libraries(anomaly_handler PUBLIC cuda_utils logger)

add_library(parser STATIC parser.cc)
set_property(TARGET parser PROPERTY POSITION_INDEPENDENT_CODE  ON)


================================================
FILE: src/turbomind/utils/anomaly_handler.cu
================================================


#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {

// Find the first occurrence of `key` in `s` and parse the float that directly
// follows it, e.g. parse_float("level=2", "level=") -> 2.f.
//
// Returns std::nullopt when `key` is absent or when no number follows it.
// NOTE: `key` is matched as a plain substring, so a key that is a suffix of a
// longer key (e.g. "inf=" inside "pinf=") can match the longer one.
static std::optional<float> parse_float(const std::string& s, const std::string& key)
{
    if (auto pos = s.find(key); pos != std::string::npos) {
        float value{};
        // sscanf returns the number of successful conversions: 0 when the text after
        // `key` is not a number, EOF when nothing follows. Only 1 means `value` was
        // written; treating 0 as success would silently yield 0.f for e.g. "level=abc".
        if (sscanf(s.c_str() + pos + key.size(), "%f", &value) == 1) {
            return value;
        }
    }
    return {};
}

// Replace INF/NaN values of `data[0, size)` in place (+INF -> pinf_val,
// -INF -> ninf_val, NaN -> nan_val) and add the number of replaced INF / NaN
// values to `*n_inf` / `*n_nan`; a null counter pointer skips that count.
// Iterates with a grid-stride loop; counts are reduced per block with
// cub::BlockReduce and added with one atomicAdd per block.
// (The "Anormaly" spelling is part of the kernel name as used by its launch site.)
template
__global__ void CountAndFixAnormaly(
    T* data, int64_t size, unsigned long long* n_inf, unsigned long long* n_nan, T pinf_val, T ninf_val, T nan_val)
{
    // Per-thread counts.
    int inf_count{};
    int nan_count{};

    // NOTE(review): `threadIdx.x + blockIdx.x * blockDim.x` is evaluated in 32-bit
    // unsigned arithmetic before widening to size_t.
    for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) {
        auto x = static_cast(data[i]);
        if (isinf(x)) {
            ++inf_count;
            data[i] = x > 0.f ? pinf_val : ninf_val;
        }
        else if (isnan(x)) {
            ++nan_count;
            data[i] = nan_val;
        }
    }

    typedef cub::BlockReduce BlockReduce;

    __shared__ typename BlockReduce::TempStorage temp_storage;

    // `n_inf` / `n_nan` are kernel arguments, so each branch is taken uniformly by
    // the whole block, as the collective BlockReduce requires.
    if (n_inf) {
        inf_count = BlockReduce(temp_storage).Sum(inf_count);
        if (threadIdx.x == 0) {
            atomicAdd(n_inf, inf_count);
        }
    }

    // Wait for last use of `temp_storage`
    __syncthreads();

    if (n_nan) {
        nan_count = BlockReduce(temp_storage).Sum(nan_count);
        if (threadIdx.x == 0) {
            atomicAdd(n_nan, nan_count);
        }
    }
}

// Per-row logits sanitizer, launched with one block per batch row (bi = blockIdx.x).
// Assumes blockDim.x == BLOCK_SIZE, which the strided loops rely on.
//
// If any logit of the row is NaN or +/-INF, the whole row is zeroed except
// `fallback`, which gets a large finite value so it dominates the row, and
// `is_anomaly[bi]` (when non-null) is set to 1. Clean rows are left untouched.
// `batch_size` is not read by the kernel.
template
__global__ void FixLogitsAnomaly(T*   logits,  //
                                 int* is_anomaly,
                                 int  vocab_size,
                                 int  batch_size,
                                 int  fallback)
{
    const int bi = blockIdx.x;

    // Row offset in 64-bit: `vocab_size * bi` in `int` overflows once the logits
    // buffer exceeds INT_MAX elements (e.g. ~150k vocab x ~14k rows).
    T* ptr = logits + (int64_t)vocab_size * bi;

    int count = 0;

    // Accumulate per thread anomaly count
    for (int i = threadIdx.x; i < vocab_size; i += BLOCK_SIZE) {
        const float val = static_cast(ptr[i]);
        count += static_cast(isnan(val) || isinf(val));
    }

    // If anything goes wrong
    int error = __syncthreads_or(count);

    if (!error) {
        return;
    }

    // Clear all logits
    for (int i = threadIdx.x; i < vocab_size; i += BLOCK_SIZE) {
        ptr[i] = T(0.f);
    }

    // Set the fallback token
    if (fallback % BLOCK_SIZE == threadIdx.x) {
        // Ideally we want INF here, but it leads to `INF - INF -> NaN` in the sampling kernels
        // Setting other logits to -INF has similar problem when banning bad words (same -INF)
        ptr[fallback] = T(65504.f);  // Maximum finite value of half
    }

    if (threadIdx.x == 0 && is_anomaly) {
        is_anomaly[bi] = 1;
    }
}

// Implementation state of AnomalyHandler (one instance per thread, created through
// AnomalyHandler::instance()).
//
// Configuration (level, replacement values, fallback token) is process-wide and is
// parsed once from the TM_ANOMALY_HANDLER environment variable. Buffers and the CUDA
// stream are per instance. Except for GlobalInit, every method does nothing while
// `g_level == 0`.
struct AnomalyHandler::Impl {

    Impl()
    {
        GlobalInit();

        // Two device counters (INF, NaN) per traced CountAndFix call.
        if (g_level) {
            d_count_.resize(max_entries * 2);
            h_count_.resize(d_count_.size());
        }
    }

    // Process level initialization from environment variable.
    // Keys (parsed with parse_float): level=, pinf=, ninf=, inf= (only used when
    // neither pinf= nor ninf= parses), nan=, fallback=. Runs once per process
    // (function-local static).
    static void GlobalInit()
    {
        [[maybe_unused]] static const auto _ = []() -> bool {
            const auto var = std::getenv("TM_ANOMALY_HANDLER");
            if (!var) {
                return false;
            }
            const std::string str{var};

            const auto level = parse_float(str, "level=");
            if (level) {
                g_level = static_cast(*level);
            }

            TM_LOG_WARNING("[AnomalyHandler] level: %d", g_level);

            if (!g_level) {
                return {};
            }

            const auto pos_inf = parse_float(str, "pinf=");
            if (pos_inf) {
                g_pinf_val_ = *pos_inf;
                TM_LOG_WARNING("[AnomalyHandler] +INF -> %f", g_pinf_val_);
            }

            const auto neg_inf = parse_float(str, "ninf=");
            if (neg_inf) {
                g_ninf_val_ = *neg_inf;
                TM_LOG_WARNING("[AnomalyHandler] -INF -> %f", g_ninf_val_);
            }

            if (!pos_inf && !neg_inf) {
                if (const auto flush_inf = parse_float(str, "inf=")) {
                    g_pinf_val_ = *flush_inf;
                    g_ninf_val_ = -g_pinf_val_;
                    TM_LOG_WARNING("[AnomalyHandler] +INF -> %f", g_pinf_val_);
                    TM_LOG_WARNING("[AnomalyHandler] -INF -> %f", g_ninf_val_);
                }
            }

            if (const auto nan = parse_float(str, "nan=")) {
                g_nan_val_ = *nan;
                TM_LOG_WARNING("[AnomalyHandler] NaN -> %f", g_nan_val_);
            }

            const auto fallback = parse_float(str, "fallback=");
            if (fallback) {
                g_fallback = *fallback;
                TM_LOG_WARNING("[AnomalyHandler] fallback -> %d", g_fallback);
            }

            return {};
        }();
    }

    // Bind this handler to a rank, stream and logits shape. `fallback` is used as
    // the fallback token only when `fallback=` was not set in the environment.
    void Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream)
    {
        if (g_level) {
            rank_       = rank;
            stream_     = stream;
            vocab_size_ = vocab_size;

            max_batch_size_ = max_batch_size;

            d_is_anomaly_.resize(max_batch_size);
            h_is_anomaly_.resize(max_batch_size);

            fallback_ = g_fallback;

            // When fallback is not set from env
            if (fallback_ == -1) {
                fallback_ = fallback;
                TM_LOG_WARNING("[AnomalyHandler] fallback: %d", fallback_);
            }

            FT_CHECK(0 <= fallback_);
            FT_CHECK(fallback_ < vocab_size);

            TM_LOG_WARNING("[AnomalyHandler] max_batch_size: %d", max_batch_size);
            TM_LOG_WARNING("[AnomalyHandler] vocab_size: %d", vocab_size);
        }
    }

    // Copy counters and per-row anomaly flags to host, wait for `stream_`, then
    // call `handler(flags, batch_size_)` for the rows of the last FixLogits call.
    void Summarize(std::function handler)
    {
        if (g_level) {
            check_cuda_error(cudaMemcpyAsync(h_count_.data(),
                                             d_count_.data().get(),
                                             sizeof(size_type) * info_.size() * 2,
                                             cudaMemcpyDefault,
                                             stream_));

            check_cuda_error(cudaMemcpyAsync(h_is_anomaly_.data(),
                                             d_is_anomaly_.data().get(),
                                             sizeof(int) * batch_size_,
                                             cudaMemcpyDefault,
                                             stream_));

            check_cuda_error(cudaStreamSynchronize(stream_));

#if 0
            int die = 0;
            for (size_t i = 0; i < info_.size(); ++i) {
                const auto& n_inf = h_count_[i * 2];
                const auto& n_nan = h_count_[i * 2 + 1];
                if (n_inf || n_nan) {
                    TM_LOG_WARNING("[AnomalyHandler][rank=%d] (%s) INF: %s, NaN: %s",
                                   rank_,
                                   info_[i].c_str(),
                                   std::to_string(n_inf).c_str(),
                                   std::to_string(n_nan).c_str());
                    ++die;
                }
            }
            TM_CHECK_EQ(die, 0);
#endif

            handler(h_is_anomaly_.data(), batch_size_);
        }
    }

    // Clear per-iteration state (counters, trace keys, anomaly flags) on host and,
    // asynchronously on `stream_`, on device.
    void Reset()
    {
        if (g_level) {
            if (!info_.empty()) {
                std::fill_n(h_count_.data(), info_.size() * 2, 0);
                check_cuda_error(
                    cudaMemsetAsync(d_count_.data().get(), 0, sizeof(size_type) * info_.size() * 2, stream_));
                info_.clear();
            }

            if (batch_size_) {
                std::fill_n(h_is_anomaly_.data(), batch_size_, 0);
                check_cuda_error(cudaMemsetAsync(d_is_anomaly_.data().get(), 0, sizeof(int) * batch_size_, stream_));
                batch_size_ = 0;
            }
        }
    }

    // When `level` is enabled: replace INF/NaN in `data[0, size)` on `stream_` and
    // record `key` together with its two device counters.
    template
    void invokeCountAndFixAnomaly(T* data, int64_t size, const std::string& key, int level)
    {
        if (g_level && level <= g_level) {
            FT_CHECK(size >= 0);

            // Nothing to scan. A launch with `grid == 0` is an invalid launch
            // configuration and would be reported by `sync_check_cuda_error`.
            if (size == 0) {
                return;
            }

            constexpr int block = 512;
            const int     grid  = (size + block - 1) / block;

            auto idx = info_.size();
            auto ptr = d_count_.data().get() + idx * 2;

            info_.push_back(key);

            FT_CHECK(info_.size() <= max_entries);

            CountAndFixAnormaly<<>>(data,  //
                                                                       size,
                                                                       ptr,
                                                                       ptr + 1,
                                                                       g_pinf_val_,
                                                                       g_ninf_val_,
                                                                       g_nan_val_);

            sync_check_cuda_error();
        }
    }

    // When `level` is enabled: sanitize each of the `batch_size` logits rows (one
    // block per row) and record per-row anomaly flags for Summarize.
    template
    void invokeFixLogitsAnomaly(T* logits, int batch_size, int level)
    {
        if (g_level && level <= g_level) {
            FT_CHECK(batch_size >= 0);
            FT_CHECK(batch_size <= max_batch_size_);

            batch_size_ = batch_size;

            // Empty batch: no rows to fix, and a zero-block launch is invalid.
            if (batch_size == 0) {
                return;
            }

            constexpr int block = 256;

            FixLogitsAnomaly<<>>(logits,  //
                                                                          d_is_anomaly_.data().get(),
                                                                          vocab_size_,
                                                                          batch_size,
                                                                          fallback_);

            sync_check_cuda_error();
        }
    }

    static int   g_level;
    static int   g_fallback;
    static float g_pinf_val_;
    static float g_ninf_val_;
    static float g_nan_val_;

    cudaStream_t stream_{};
    int          rank_{};
    int          vocab_size_{};
    int          fallback_{};
    int          max_batch_size_{};

    ////////////////////////////////////////////////////////////////////////////////
    /// Members below has SINGLE iteration validity and must be cleared in `Reset`

    // Datum for tracing anomalies
    thrust::device_vector d_count_;
    thrust::host_vector   h_count_;
    std::vector         info_;

    // Datum for fixing logits
    thrust::device_vector d_is_anomaly_;
    thrust::host_vector   h_is_anomaly_;
    int                        batch_size_{};
};

// Process-wide settings, overwritten once from TM_ANOMALY_HANDLER by GlobalInit.
// Defaults: handler disabled (level 0), fallback token not set (-1), and
// replacement values equal to the anomalies themselves (INF -> INF, NaN -> NaN).
int   AnomalyHandler::Impl::g_level{0};
int   AnomalyHandler::Impl::g_fallback{-1};
float AnomalyHandler::Impl::g_pinf_val_{INFINITY};
float AnomalyHandler::Impl::g_ninf_val_{-INFINITY};
float AnomalyHandler::Impl::g_nan_val_{NAN};

// Constructing the Impl triggers the one-time parsing of TM_ANOMALY_HANDLER.
AnomalyHandler::AnomalyHandler(): impl_{new Impl{}} {}

AnomalyHandler::~AnomalyHandler() = default;

// One handler per thread (thread_local), so per-instance buffers and the bound
// stream are not shared between threads.
AnomalyHandler& AnomalyHandler::instance()
{
    thread_local AnomalyHandler inst{};
    return inst;
}

// NOTE(review): Impl::Init uses FT_CHECK; if FT_CHECK throws, this `noexcept`
// turns the failure into std::terminate — confirm that is intended.
void AnomalyHandler::Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream) noexcept
{
    impl_->Init(rank, vocab_size, fallback, max_batch_size, stream);
}

void AnomalyHandler::Summarize(std::function handler)
{
    impl_->Summarize(handler);
}

void AnomalyHandler::Reset()
{
    impl_->Reset();
}

template
void AnomalyHandler::CountAndFix(T* data, int64_t size, std::string key, int level)
{
    return impl_->invokeCountAndFixAnomaly(data, size, key, level);
}

// Explicit instantiations for the supported element types.
template void AnomalyHandler::CountAndFix(float*, int64_t, std::string, int);
template void AnomalyHandler::CountAndFix(half*, int64_t, std::string, int);
#ifdef ENABLE_BF16
template void AnomalyHandler::CountAndFix(__nv_bfloat16*, int64_t, std::string, int);
#endif

template
void AnomalyHandler::FixLogits(T* logits, int batch_size, int level)
{
    impl_->invokeFixLogitsAnomaly(logits, batch_size, level);
}

// Reads the process-wide level. It is only set when the first Impl is constructed
// (GlobalInit), so it returns 0 until instance() has been called on some thread.
int AnomalyHandler::level() noexcept
{
    return Impl::g_level;
}

template void AnomalyHandler::FixLogits(float*, int, int);
template void AnomalyHandler::FixLogits(half*, int, int);
#ifdef ENABLE_BF16
template void AnomalyHandler::FixLogits(__nv_bfloat16*, int, int);
#endif

// Run the calling thread's INF/NaN counter over `tensor` at debug `level`,
// dispatching on the tensor's dtype (float, half, bfloat16).
void DebugTensor(Tensor& tensor, const std::string& key, int level)
{
    // Empty tensors are skipped up front (they would give the count kernel a
    // zero-sized grid).
    if (tensor.size() == 0) {
        return;
    }

    auto count_and_fix_as = [&](auto tag) {
        using Elem = decltype(tag);
        AnomalyHandler::instance().CountAndFix((Elem*)tensor.raw_data(), tensor.size(), key, level);
    };

    TM_DISPATCH_DTYPES(tensor.dtype(), count_and_fix_as, float, half_t, bfloat16_t);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/anomaly_handler.h
================================================

// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 
#include 
#include 
#include 
#include 
#include 

#include "src/turbomind/core/core.h"

namespace turbomind {

// Debug helper that detects and optionally replaces INF/NaN values on the GPU.
// Configuration comes from the TM_ANOMALY_HANDLER environment variable; while the
// configured level is 0 every operation is a no-op. Obtain the per-thread
// instance via instance().
class AnomalyHandler {
public:
    // Maximum number of CountAndFix calls that can be traced between Reset() calls.
    static constexpr size_t max_entries = 65536;

    // Element type of the INF/NaN counters.
    using size_type = unsigned long long;

    ~AnomalyHandler();

    // Thread-local singleton accessor.
    static AnomalyHandler& instance();

    // Process-wide debug level (0 = disabled). It is parsed lazily when the first
    // handler is constructed, so it reads 0 before instance() has been called.
    static int level() noexcept;

    // Bind to a rank, CUDA stream and logits shape. `fallback` is the token id
    // forced for anomalous logits rows unless `fallback=` is set in the environment.
    void Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream) noexcept;

    // If `level` <= the configured level: replace INF/NaN in `data[0, size)` in
    // place and record the counts under `key`.
    template
    void CountAndFix(T* data, int64_t size, std::string key, int level);

    // If `level` <= the configured level: every row (of `batch_size` rows) that
    // contains INF/NaN is reset so that only the fallback token remains favoured.
    template
    void FixLogits(T* logits, int batch_size, int level);

    // Copy results to host, synchronize the stream, and pass the per-row anomaly
    // flags of the last FixLogits call to `handler`.
    void Summarize(std::function handler);

    // Clear per-iteration counters, keys and flags.
    void Reset();

private:
    AnomalyHandler();

private:
    struct Impl;
    std::unique_ptr impl_;
};

template
void count_and_fix(T* data, size_t size, std::string key, int level)
{
    AnomalyHandler::instance().CountAndFix(data, size, key, level);
}

// Scan `tensor` for INF/NaN at debug `level` using the calling thread's handler.
void DebugTensor(Tensor& tensor, const std::string& key, int level);

// Overload accepting temporaries. A named rvalue reference is an lvalue, so this
// forwards to the lvalue overload above.
inline void DebugTensor(Tensor&& tensor, const std::string& key, int level)
{
    Tensor& as_lvalue = tensor;
    DebugTensor(as_lvalue, key, level);
}

// Debug hooks gated on the process-wide anomaly-handler level: the wrapped call
// only runs when `AnomalyHandler::level() >= lvl`.
//
// `lvl` is parenthesized so expressions such as `cond ? 1 : 2` or `a | b` are
// compared as a whole (the former `__level` spelling was also a reserved
// identifier). `lvl` is evaluated twice. The expansion is an `if` statement, so
// use it as a complete statement followed by `;`. An `else` written directly
// after the macro would bind to the macro's internal `if`.
#define TM_DEBUG_RAW(ptr, size, key, lvl)                                                                              \
    if (::turbomind::AnomalyHandler::level() >= (lvl)) {                                                               \
        ::turbomind::count_and_fix(ptr, size, key, (lvl));                                                             \
    }

#define TM_DEBUG_TENSOR(tensor, key, lvl)                                                                              \
    if (::turbomind::AnomalyHandler::level() >= (lvl)) {                                                               \
        ::turbomind::DebugTensor(tensor, key, (lvl));                                                                  \
    }

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/constant.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

// Upper bound on the number of log-probabilities handled at once
// (presumably the cap for a logprobs-style output option — confirm at call sites).
constexpr int kMaxLogProb = 1024;

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/cuda_bf16_fallbacks.cuh
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include 

namespace turbomind {

#ifdef ENABLE_BF16
// Unpack a bf16x2 pair into two fp32 values (x = low lane, y = high lane).
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Pre-sm_80 path: convert each lane individually.
    const float2 unpacked{__low2float(val), __high2float(val)};
    return unpacked;
#else
    return __bfloat1622float2(val);
#endif
}

// Clamp both lanes of a bf16x2 to [-128, 127], convert each to int8 and pack the
// two bytes into one int16 (low lane in int8[0], high lane in int8[1]).
// NOTE(review): the exact float->int conversion (truncation vs. rounding) depends
// on the cast target types, which are not visible in this copy — TODO confirm.
// Reading `int16` after writing `int8[]` through the anonymous union relies on
// compiler support for union type punning.
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Pre-sm_80: clamp in fp32.
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast(static_cast(f_val.x));
    int8[1] = static_cast(static_cast(f_val.y));
    return int16;
#else
    // sm_80+: clamp with native bf16x2 min/max.
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast(static_cast(val.x));
    int8[1] = static_cast(static_cast(val.y));
    return int16;
#endif
}

// Round a pair of fp32 values to bf16x2 (round-to-nearest-even; x -> low lane).
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo = val.x;
    const float hi = val.y;
    return __floats2bfloat162_rn(lo, hi);
#else
    return __float22bfloat162_rn(val);
#endif
}

// Broadcast one bf16 value into both lanes of a bf16x2.
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 pair;
    pair.y = val;
    pair.x = val;
    return pair;
#else
    return __bfloat162bfloat162(val);
#endif
}

// Lane-wise bf16x2 addition. Pre-sm_80 widens each lane to fp32, adds, and rounds back.
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo = __low2float(x) + __low2float(y);
    const float hi = __high2float(x) + __high2float(y);
    return __floats2bfloat162_rn(lo, hi);
#else
    return __hadd2(x, y);
#endif
}

// Scalar bf16 addition; pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float sum = __bfloat162float(x) + __bfloat162float(y);
    return __float2bfloat16(sum);
#else
    return __hadd(x, y);
#endif
}

// Lane-wise bf16x2 subtraction (x - y); pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo = __low2float(x) - __low2float(y);
    const float hi = __high2float(x) - __high2float(y);
    return __floats2bfloat162_rn(lo, hi);
#else
    return __hsub2(x, y);
#endif
}

// Scalar bf16 subtraction (x - y); pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float diff = __bfloat162float(x) - __bfloat162float(y);
    return __float2bfloat16(diff);
#else
    return __hsub(x, y);
#endif
}

// Lane-wise bf16x2 multiplication; pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo = __low2float(x) * __low2float(y);
    const float hi = __high2float(x) * __high2float(y);
    return __floats2bfloat162_rn(lo, hi);
#else
    return __hmul2(x, y);
#endif
}

// Scalar bf16 multiplication; pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float prod = __bfloat162float(x) * __bfloat162float(y);
    return __float2bfloat16(prod);
#else
    return __hmul(x, y);
#endif
}

// Lane-wise bf16x2 fused multiply-add x * y + z; pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Each lane is kept as a single `a * b + c` expression, like the original form.
    const float lo = __low2float(x) * __low2float(y) + __low2float(z);
    const float hi = __high2float(x) * __high2float(y) + __high2float(z);
    return __floats2bfloat162_rn(lo, hi);
#else
    return __hfma2(x, y, z);
#endif
}

// Scalar bf16 fused multiply-add x * y + z; pre-sm_80 is emulated in fp32.
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float acc = __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z);
    return __float2bfloat16(acc);
#else
    return __hfma(x, y, z);
#endif
}

// Natural exponential e^x of both lanes of a bf16x2. The trailing `2` denotes the
// packed pair, as in bf16hadd2; it is not a base-2 exponential.
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo = expf(__low2float(x));
    const float hi = expf(__high2float(x));
    return __floats2bfloat162_rn(lo, hi);
#else
    return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Device passes for targets below sm_80 only: bf16x2 arithmetic operators routed through the
// fp32-emulating bf16hmul2 / bf16hadd2 helpers, plus a make_bfloat162 constructor helper
// (presumably because cuda_bf16.h does not provide these for such targets — confirm per CUDA version).
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hmul2(x, y);
}

inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hadd2(x, y);
}

// Build a bf16x2 value with `x` in the .x lane and `y` in the .y lane.
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 pair;
    pair.x = x;
    pair.y = y;
    return pair;
}

#endif

// Three-operand bf16 sum a + b + c. Below sm_80 the sum is formed in fp32 and rounded once;
// on newer targets the native operator+ is applied twice (one rounding per addition).
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return a + b + c;
#else
    const float sum = __bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c);
    return __float2bfloat16(sum);
#endif
}

// Four-operand bf16 sum a + b + c + d, accumulated in fp32 on every target and rounded to bf16 once.
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    const float sum = static_cast<float>(a) + static_cast<float>(b) + static_cast<float>(c) + static_cast<float>(d);
    return static_cast<__nv_bfloat16>(sum);
#else
    const float sum = __bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d);
    return __float2bfloat16(sum);
#endif
}

// Lane-wise three-operand bf16x2 sum a + b + c; emulated per lane in fp32 below sm_80.
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return a + b + c;
#else
    const float lo = __low2float(a) + __low2float(b) + __low2float(c);
    const float hi = __high2float(a) + __high2float(b) + __high2float(c);
    return __floats2bfloat162_rn(lo, hi);
#endif
}

// Three-operand bf16 product a * b * c; below sm_80 formed in fp32 and rounded once.
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return a * b * c;
#else
    const float product = __bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c);
    return __float2bfloat16(product);
#endif
}

// Lane-wise three-operand bf16x2 product a * b * c; emulated per lane in fp32 below sm_80.
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return a * b * c;
#else
    const float lo = __low2float(a) * __low2float(b) * __low2float(c);
    const float hi = __high2float(a) * __high2float(b) * __high2float(c);
    return __floats2bfloat162_rn(lo, hi);
#endif
}

// Lane-wise a * b * c + d on bf16x2 vectors (triple product plus addend); emulated per lane in
// fp32 below sm_80.
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return a * b * c + d;
#else
    const float lo = __low2float(a) * __low2float(b) * __low2float(c) + __low2float(d);
    const float hi = __high2float(a) * __high2float(b) * __high2float(c) + __high2float(d);
    return __floats2bfloat162_rn(lo, hi);
#endif
}

#endif  // ENABLE_BF16

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/cuda_bf16_wrapper.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifdef ENABLE_BF16
#include 
#endif


================================================
FILE: src/turbomind/utils/cuda_type_utils.cuh
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "src/turbomind/utils/cuda_bf16_fallbacks.cuh"
#include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include 
#include 

namespace turbomind {

// Generic load through __ldg (the read-only / non-coherent global load intrinsic).
template
inline __device__ T ldg(const T* val)
{
    return __ldg(val);
}

#if ENABLE_BF16
// bf16 specializations of ldg(). On sm_80+ (and in the host pass) they forward to __ldg; older
// targets fall back to an ordinary load through the pointer.
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return __ldg(val);
#else
    return *val;
#endif
}

template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return __ldg(val);
#else
    return *val;
#endif
}
#endif  // ENABLE_BF16

// Get type2 from type or vice versa (applied to half and bfloat16):
// TypeConverter<X>::Type maps a 16-bit scalar type to its packed two-lane type and back
// (half <-> half2, __nv_bfloat16 <-> __nv_bfloat162). The primary template yields half2.
template
struct TypeConverter {
    using Type = half2;
};  // keep for generality

// half2 -> half
template<>
struct TypeConverter {
    using Type = half;
};

// half -> half2
template<>
struct TypeConverter {
    using Type = half2;
};

#if ENABLE_BF16
template<>
struct TypeConverter<__nv_bfloat162> {
    using Type = __nv_bfloat16;
};

template<>
struct TypeConverter<__nv_bfloat16> {
    using Type = __nv_bfloat162;
};
#endif  // ENABLE_BF16

// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
// hadd2: lane-wise addition of two packed two-lane values. The generic version calls __hadd2;
// the __nv_bfloat162 specialization routes through bf16hadd2 (presumably the fp32-emulating helper
// from cuda_bf16_fallbacks.cuh).
template
inline __device__ T hadd2(T a, T b)
{
    return __hadd2(a, b);
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hadd2(a, b);
}
#endif  // ENABLE_BF16

// add(a, b): generic addition. The primary template uses operator+; half / half2 use the
// __hadd / __hadd2 intrinsics and bf16 types route through the bf16hadd / bf16hadd2 helpers.
template
inline __device__ T add(T a, T b)
{
    return a + b;
}

template<>
inline __device__ half2 add(half2 a, half2 b)
{
    return __hadd2(a, b);
}

template<>
inline __device__ half add(half a, half b)
{
    return __hadd(a, b);
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hadd2(a, b);
}

template<>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return bf16hadd(a, b);
}

// Mixed-type overload (not a specialization): `b` is rounded to bf16 before the addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
    return bf16hadd(a, __float2bfloat16(b));
}
#endif  // ENABLE_BF16

// applies to three-value addition: a + b + c
template
inline __device__ T add(T a, T b, T c)
{
    return a + b + c;
}

#if ENABLE_BF16
// Plain (non-template) overloads; overload resolution prefers them over the template for exact
// bf16 argument types.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hadd(a, b, c);
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hadd2(a, b, c);
}
#endif  // ENABLE_BF16

// applies to all 4 values addition: the generic version sums in float and casts the result back to T
template
inline __device__ T add(T a, T b, T c, T d)
{
    return (T)((float)a + (float)b + (float)c + (float)d);
}

#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
    return bf16hadd(a, b, c, d);
}
#endif  // ENABLE_BF16

// hsub2 / hmul2: lane-wise subtraction and multiplication of packed two-lane values. The generic
// versions call the __hsub2 / __hmul2 intrinsics; the __nv_bfloat162 specializations use the
// bf16hsub2 / bf16hmul2 helpers.
template
inline __device__ T hsub2(T a, T b)
{
    return __hsub2(a, b);
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hsub2(a, b);
}
#endif  // ENABLE_BF16

template
inline __device__ T hmul2(T a, T b)
{
    return __hmul2(a, b);
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hmul2(a, b);
}
#endif  // ENABLE_BF16

// Three-operand product a * b * c.
template
inline __device__ T hmul2(T a, T b, T c)
{
    return a * b * c;
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16

// mul: three-operand product a * b * c; bf16 types route through bf16hmul / bf16hmul2.
template
inline __device__ T mul(T a, T b, T c)
{
    return a * b * c;
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hmul(a, b, c);
}

// Plain overload for the packed type (not a specialization of the template in this block).
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16

// Four-operand "fma": a * b * c + d (a triple product plus an addend, not a standard fused op).
template
inline __device__ T fma(T a, T b, T c, T d)
{
    return a * b * c + d;
}

#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
    return bf16hfma2(a, b, c, d);
}
#endif  // ENABLE_BF16

// Three-operand fma: a * b + c; bf16 types route through bf16hfma2 / bf16hfma.
template
inline __device__ T fma(T a, T b, T c)
{
    return a * b + c;
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hfma2(a, b, c);
}

template<>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
    return bf16hfma(a, b, c);
}
#endif  // ENABLE_BF16

// hexp2: lane-wise natural exponential of a packed two-lane value via h2exp; the
// __nv_bfloat162 specialization uses bf16exp2.
template
inline __device__ T hexp2(T a)
{
    return h2exp(a);
}

#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
    return bf16exp2(a);
}
#endif  // ENABLE_BF16

// cuda_cast<T_OUT, T_IN>: device-side type conversion. The primary template relies on the
// implicit conversion from T_IN to T_OUT; the specializations cover vector and 16-bit types.
template
__device__ inline T_OUT cuda_cast(T_IN val)
{
    return val;
}

// int2 -> float2, lane by lane.
template<>
__device__ inline float2 cuda_cast(int2 val)
{
    return make_float2(val.x, val.y);
}
// Broadcast a float into both lanes of a float2.
template<>
__device__ inline float2 cuda_cast(float val)
{
    return make_float2(val, val);
}
template<>
__device__ inline float2 cuda_cast(half2 val)
{
    return __half22float2(val);
}
// float2 -> half2 with round-to-nearest.
template<>
__device__ inline half2 cuda_cast(float2 val)
{
    return __float22half2_rn(val);
}
// Scalar -> half2 conversions broadcast the value into both lanes.
template<>
__device__ inline half2 cuda_cast(float val)
{
    return __float2half2_rn(val);
}
template<>
__device__ inline half2 cuda_cast(half val)
{
    return __half2half2(val);
}

// half -> int8 with round-to-nearest and saturation via the PTX cvt.rni.sat.s8.f16 instruction.
// The result lands in a 16-bit register; int8[0] (its low byte on little-endian targets) is
// returned. The anonymous unions reinterpret bits between the integer and half views.
template<>
__device__ inline int8_t cuda_cast(half val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    union {
        half    fp16;
        int16_t int16_in;
    };
    fp16 = val;
    asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
    return int8[0];
}

// Convert both lanes of a half2 to int8 and pack them into one int16 (x -> int8[0], y -> int8[1]).
template<>
__device__ inline int16_t cuda_cast(half2 val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast(val.x);
    int8[1] = cuda_cast(val.y);
    return int16;
}

// float -> int8 with round-to-nearest and saturation via cvt.rni.sat.s8.f32 (low byte returned).
template<>
__device__ inline int8_t cuda_cast(float val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
    return int8[0];
}

// Convert both lanes of a float2 to int8 and pack them into one int16 (x -> int8[0], y -> int8[1]).
template<>
__device__ inline int16_t cuda_cast(float2 val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast(val.x);
    int8[1] = cuda_cast(val.y);
    return int16;
}

// Unpack the two int8 bytes of an int16 into the lanes of a half2 (int8[0] -> x, int8[1] -> y).
template<>
__device__ inline half2 cuda_cast(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_half2(int8[0], int8[1]);
}

// Unpack the two int8 bytes of an int16 into the lanes of a float2 (int8[0] -> x, int8[1] -> y).
template<>
__device__ inline float2 cuda_cast(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_float2(int8[0], int8[1]);
}

#ifdef ENABLE_BF16
// Integer -> bf16 and bf16 -> int8 conversions via static_cast.
// NOTE(review): bf16 -> int8 here is a plain static_cast rather than the cvt.rni.sat path used for
// float/half -> int8, so rounding and saturation may differ — confirm this is intended.
template<>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
    return static_cast(val);
}
template<>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
    return static_cast(val);
}
template<>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
    return static_cast(val);
}

template<>
__device__ inline float cuda_cast(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}

// bf1622float2 / bf1622int16 / bf162bf162 / float22bf162 are presumably the helpers declared in
// cuda_bf16_fallbacks.cuh (included at the top of this header).
template<>
__device__ inline float2 cuda_cast(__nv_bfloat162 val)
{
    return bf1622float2(val);
}

// bf16 -> half goes through fp32.
template<>
__device__ inline half cuda_cast(__nv_bfloat16 val)
{
    return __float2half(__bfloat162float(val));
}

template<>
__device__ inline int16_t cuda_cast(__nv_bfloat162 val)
{
    return bf1622int16(val);
}

template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
    return __float2bfloat16(val);
}
// half -> bf16 goes through fp32.
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
    return __float2bfloat16(__half2float(val));
}

// Scalar -> bf16x2: __float2bfloat162_rn broadcasts into both lanes; bf162bf162 presumably does too.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
    return bf162bf162(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
    return __float2bfloat162_rn(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
    return float22bf162(val);
}
// Unpack the two int8 bytes of an int16 into the bf16 lanes (int8[0] -> x, int8[1] -> y).
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int16 = val;
    __nv_bfloat162 res;
    res.x = cuda_cast<__nv_bfloat16>(int8[0]);
    res.y = cuda_cast<__nv_bfloat16>(int8[1]);
    return res;
}

// half2 -> bf16x2 goes through float2.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
    return float22bf162(__half22float2(val));
}

#endif  // ENABLE_BF16

// cuda_abs: absolute value. The primary template is only declared; each supported type has an
// explicit specialization below.
template
__device__ inline T cuda_abs(T val);
template<>
__device__ inline float cuda_abs(float val)
{
    return fabs(val);
}
template<>
__device__ inline half cuda_abs(half val)
{
    return __habs(val);
}
template<>
__device__ inline half2 cuda_abs(half2 val)
{
    return __habs2(val);
}

#ifdef ENABLE_BF16

// sm_80+ and the host pass use the native bf16 __habs / __habs2; older device targets compute
// |x| in fp32 with fabs and convert the result back to bf16.
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
    return __habs(val);
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
    return __habs2(val);
}
#else
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
    return fabs(cuda_cast(val));
}
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
    return make_bfloat162(fabs(cuda_cast(val.x)), fabs(cuda_cast(val.y)));
}
#endif

#endif  // ENABLE_BF16

// Unary maximum: compute the max of a vector type.
// The generic version only converts the value with cuda_cast; the half2 and __nv_bfloat162
// specializations return the larger of the two lanes (the .y lane when they compare equal).
template
__device__ inline To cuda_max(Ti val)
{
    return cuda_cast(val);
};

template<>
__device__ inline half cuda_max(half2 val)
{
    return (val.x > val.y) ? val.x : val.y;
}
#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
    return (val.x > val.y) ? val.x : val.y;
}
#endif

// Binary maximum: compute the max of two scalar types (returns val2 when they compare equal or
// when either comparison operand is unordered).
template
__device__ inline T cuda_max(T val1, T val2)
{
    return (val1 > val2) ? val1 : val2;
}

#ifdef ENABLE_FP8
// fp8x2 (e4m3) -> float2, going through bf16x2 via fp8x2_e4m3_to_bfloat2.
template<>
__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val)
{
    return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
// float2 -> fp8x2 (e4m3): the input is first rounded to bf16x2 and widened back to float2, so the
// value is rounded twice (to bf16, then to fp8).
template<>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
    return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}

// Scalar half / bf16 / float -> fp8 (e4m3) via the __nv_fp8_e4m3 converting constructors.
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
    return __nv_fp8_e4m3(val);
}
template<>
__device__ inline float cuda_cast(__nv_fp8_e4m3 val)
{
    return (float)val;
}
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
    return fp8x2_e4m3_to_bfloat2(&val);
}

// NOTE(review): fp8 -> int8 is not implemented; it ignores `val` and always returns 0.
template<>
__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val)
{
    // no impl
    return 0;
}

// int8 -> fp8 is routed through two intermediate conversions (the innermost cuda_cast, then bf16).
template<>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
    return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val)));
}

#endif  // ENABLE_FP8

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/cuda_utils.cc
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/macro.h"
#include 
#include 

namespace turbomind {

void syncAndCheck(const char* const file, int const line)
{
    // When FT_DEBUG_LEVEL=DEBUG, must check error
    static char* level_name = std::getenv("TM_DEBUG_LEVEL");
    if (level_name != nullptr) {
        static std::string level = std::string(level_name);
        if (level == "DEBUG") {
            cudaDeviceSynchronize();
            cudaError_t result = cudaGetLastError();
            if (result) {
                TM_LOG_ERROR((std::string("CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " + file + ":"
                              + std::to_string(line))
                                 .c_str());
                std::abort();
            }
            TM_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line));
        }
    }
}

/* **************************** debug tools ********************************* */

// Debug print of the top-left m x k window of a row-major matrix whose rows are `stride` elements
// apart. The first printed line is a header of column indices; each data line is prefixed by its
// row index and values are printed as floats ("%7.3f"). When `is_device_ptr` is true, m * stride
// elements are copied into a temporary host buffer first and freed at the end.
// NOTE(review): the malloc result is not checked, and m * stride is evaluated in int before being
// multiplied by sizeof(T).
template
void printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr)
{
    T* tmp;
    if (is_device_ptr) {
        // k <= stride is expected; stride is the column dimension (row pitch in elements).
        tmp = reinterpret_cast(malloc(m * stride * sizeof(T)));
        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
    }
    else {
        tmp = ptr;
    }

    // ii == -1 prints the column-index header row.
    for (int ii = -1; ii < m; ++ii) {
        if (ii >= 0) {
            printf("%02d ", ii);
        }
        else {
            printf("   ");
        }

        for (int jj = 0; jj < k; jj += 1) {
            if (ii >= 0) {
                printf("%7.3f ", (float)tmp[ii * stride + jj]);
            }
            else {
                printf("%7d ", jj);
            }
        }
        printf("\n");
    }
    if (is_device_ptr) {
        free(tmp);
    }
}

template void printMatrix(float* ptr, int m, int k, int stride, bool is_device_ptr);
template void printMatrix(half* ptr, int m, int k, int stride, bool is_device_ptr);
#ifdef ENABLE_BF16
template void printMatrix(__nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr);
#endif

// Debug print of the top-left m x k window of a row-major unsigned long long matrix with row pitch
// `stride` elements: a header line of column indices, then one line per row prefixed by its
// index. Device data is first copied into a temporary host buffer that is freed at the end.
void printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr)
{
    typedef unsigned long long T;
    T*                         tmp;
    if (is_device_ptr) {
        // k <= stride is expected; stride is the column dimension (row pitch in elements).
        tmp = reinterpret_cast(malloc(m * stride * sizeof(T)));
        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
    }
    else {
        tmp = ptr;
    }

    // ii == -1 prints the column-index header row.
    for (int ii = -1; ii < m; ++ii) {
        if (ii >= 0) {
            printf("%02d ", ii);
        }
        else {
            printf("   ");
        }

        for (int jj = 0; jj < k; jj += 1) {
            if (ii >= 0) {
                printf("%4llu ", tmp[ii * stride + jj]);
            }
            else {
                printf("%4d ", jj);
            }
        }
        printf("\n");
    }
    if (is_device_ptr) {
        free(tmp);
    }
}

// Debug print of the top-left m x k window of a row-major int matrix with row pitch `stride`
// elements: a header line of column indices, then one line per row prefixed by its index. Device
// data is first copied into a temporary host buffer that is freed at the end.
void printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr)
{
    typedef int T;
    T*          tmp;
    if (is_device_ptr) {
        // k <= stride is expected; stride is the column dimension (row pitch in elements).
        tmp = reinterpret_cast(malloc(m * stride * sizeof(T)));
        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
    }
    else {
        tmp = ptr;
    }

    // ii == -1 prints the column-index header row.
    for (int ii = -1; ii < m; ++ii) {
        if (ii >= 0) {
            printf("%02d ", ii);
        }
        else {
            printf("   ");
        }

        for (int jj = 0; jj < k; jj += 1) {
            if (ii >= 0) {
                printf("%4d ", tmp[ii * stride + jj]);
            }
            else {
                printf("%4d ", jj);
            }
        }
        printf("\n");
    }
    if (is_device_ptr) {
        free(tmp);
    }
}

// multiple definitions for msvc: on MSVC size_t is presumably the same type as unsigned long long,
// so this overload would redefine the unsigned long long one and is compiled out there.
#ifndef _MSC_VER
// Debug print of the top-left m x k window of a row-major size_t matrix with row pitch `stride`
// elements: a header line of column indices, then one line per row prefixed by its index. Device
// data is first copied into a temporary host buffer that is freed at the end.
void printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr)
{
    typedef size_t T;
    T*             tmp;
    if (is_device_ptr) {
        // k <= stride is expected; stride is the column dimension (row pitch in elements).
        // The byte count is computed in size_t so m * stride cannot overflow int.
        const size_t bytes = sizeof(T) * m * stride;
        tmp                = reinterpret_cast<T*>(malloc(bytes));
        check_cuda_error(cudaMemcpy(tmp, ptr, bytes, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
    }
    else {
        tmp = ptr;
    }

    // ii == -1 prints the column-index header row.
    for (int ii = -1; ii < m; ++ii) {
        if (ii >= 0) {
            printf("%02d ", ii);
        }
        else {
            printf("   ");
        }

        for (int jj = 0; jj < k; jj += 1) {
            if (ii >= 0) {
                // %zu is the conversion specifier for size_t (%ld expects a signed long).
                printf("%4zu ", tmp[ii * stride + jj]);
            }
            else {
                printf("%4d ", jj);
            }
        }
        printf("\n");
    }
    if (is_device_ptr) {
        free(tmp);
    }
}
#endif

template
void check_max_val(const T* result, const int size)
{
    T* tmp = new T[size];
    cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost);
    float max_val = -100000;
    for (int i = 0; i < size; i++) {
        float val = static_cast(tmp[i]);
        if (val > max_val) {
            max_val = val;
        }
    }
    delete tmp;
    printf("[INFO][CUDA] addr %p max val: %f \n", result, max_val);
}

template void check_max_val(const float* result, const int size);
template void check_max_val(const half* result, const int size);
#ifdef ENABLE_BF16
template void check_max_val(const __nv_bfloat16* result, const int size);
#endif

template
void check_abs_mean_val(const T* result, const int size)
{
    T* tmp = new T[size];
    cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost);
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        sum += abs(static_cast(tmp[i]));
    }
    delete tmp;
    printf("[INFO][CUDA] addr %p abs mean val: %f \n", result, sum / size);
}

template void check_abs_mean_val(const float* result, const int size);
template void check_abs_mean_val(const half* result, const int size);
#ifdef ENABLE_BF16
template void check_abs_mean_val(const __nv_bfloat16* result, const int size);
#endif

/* ***************************** common utils ****************************** */

int getSMVersion()
{
    // Compute capability of the current device encoded as major * 10 + minor (8.6 -> 86).
    int device_id = -1;
    check_cuda_error(cudaGetDevice(&device_id));
    int major = 0;
    int minor = 0;
    check_cuda_error(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
    check_cuda_error(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
    return 10 * major + minor;
}

int getSMCount()
{
    // Number of streaming multiprocessors on the current device.
    int device_id = -1;
    check_cuda_error(cudaGetDevice(&device_id));
    int multiprocessors = 0;
    check_cuda_error(cudaDeviceGetAttribute(&multiprocessors, cudaDevAttrMultiProcessorCount, device_id));
    return multiprocessors;
}

std::string getDeviceName()
{
    int device{-1};
    check_cuda_error(cudaGetDevice(&device));
    cudaDeviceProp props;
    check_cuda_error(cudaGetDeviceProperties(&props, device));
    return std::string(props.name);
}

int getDevice()
{
    // Ordinal of the device currently used by the calling host thread (as reported by cudaGetDevice);
    // aborts through check_cuda_error if the query fails.
    int current_dev_id = 0;
    // Pass the address of the output variable (this line previously held the mis-encoded "¤t_dev_id").
    check_cuda_error(cudaGetDevice(&current_dev_id));
    return current_dev_id;
}

int getDeviceCount()
{
    // Number of CUDA devices visible to this process; aborts through check_cuda_error on failure.
    int num_devices{0};
    check_cuda_error(cudaGetDeviceCount(&num_devices));
    return num_devices;
}

void trim_default_mempool(int device_id)
{
    cudaMemPool_t mempool;
    check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));
    check_cuda_error(cudaMemPoolTrimTo(mempool, 0));
}

/* ************************** end of common utils ************************** */
}  // namespace turbomind


================================================
FILE: src/turbomind/utils/cuda_utils.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 
#include 
#include 
#include 
#include 

#include 
#include 

#include 
#include 
#ifdef SPARSITY_ENABLED
#include 
#endif

#include "src/turbomind/core/check.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

/* **************************** debug tools ********************************* */
// Human-readable description of a CUDA runtime error code (delegates to cudaGetErrorString).
static const char* _cudaGetErrorEnum(cudaError_t error)
{
    const char* const description = cudaGetErrorString(error);
    return description;
}

// Name of a cuBLAS status enumerator (e.g. "CUBLAS_STATUS_ALLOC_FAILED"); "" for unlisted values.
static const char* _cudaGetErrorEnum(cublasStatus_t error)
{
// Each case returns the stringified enumerator name.
#define TM_CUBLAS_STATUS_NAME(status)                                                                                  \
    case status:                                                                                                       \
        return #status
    switch (error) {
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_SUCCESS);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_NOT_INITIALIZED);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_ALLOC_FAILED);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_INVALID_VALUE);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_ARCH_MISMATCH);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_MAPPING_ERROR);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_EXECUTION_FAILED);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_INTERNAL_ERROR);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_NOT_SUPPORTED);
        TM_CUBLAS_STATUS_NAME(CUBLAS_STATUS_LICENSE_ERROR);
    }
#undef TM_CUBLAS_STATUS_NAME
    return "";
}

// When `result` is non-zero (the success value of cudaError_t / cublasStatus_t is 0), log
// "CUDA runtime error: <description> <file>:<line>" and abort the process. `func` receives the
// stringified call expression from the macros below but is not included in the message.
// NOTE(review): the macros call `check` unqualified, so they only resolve inside namespace turbomind.
template
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        TM_LOG_ERROR((std::string("CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " + file + ":"
                      + std::to_string(line))
                         .c_str());
        std::abort();
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)

// When TM_DEBUG_LEVEL=DEBUG: synchronize the device and abort on a pending CUDA error, reporting
// file:line. Otherwise a no-op. sync_check_cuda_error() passes the call site automatically.
void syncAndCheck(const char* const file, int const line);

#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__)

// Evaluate a CUDA driver API call once. If it does not return CUDA_SUCCESS, throw std::runtime_error
// with the cuGetErrorString text (or "Unknown error") and the file:line of the call site.
// NOTE(review): this expands to a bare `if` statement rather than `do { ... } while (0)`, so
// `if (c) CUDRVCHECK(e); else ...` does not compile, and an `else` placed directly after the macro
// (without `;`) attaches to the macro's own `if`.
#define CUDRVCHECK(expr)                                                                                               \
    if (auto ec = expr; ec != CUDA_SUCCESS) {                                                                          \
        const char* p_str{};                                                                                           \
        cuGetErrorString(ec, &p_str);                                                                                  \
        p_str    = p_str ? p_str : "Unknown error";                                                                    \
        auto msg = fmtstr("[TM][ERROR] CUDA driver error: %s:%d '%s'", __FILE__, __LINE__, p_str);                     \
        throw std::runtime_error(msg.c_str());                                                                         \
    }

// Debug printers: dump the top-left m x k window of a row-major matrix with row pitch `stride`;
// `is_device_ptr` selects whether `ptr` is first copied from device memory.
// NOTE(review): the size_t overload is declared unconditionally, but its definition in cuda_utils.cc
// is compiled out under _MSC_VER.
template
void printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr);

void printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr);
void printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr);
void printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr);

// Copy `size` device elements to the host and print their maximum.
template
void check_max_val(const T* result, const int size);

// Copy `size` device elements to the host and print the mean of their absolute values.
template
void check_abs_mean_val(const T* result, const int size);

// Debug helper: print "[TM][CALL] <enclosing function name>" to std::cout.
#define PRINT_FUNC_NAME_()                                                                                             \
    do {                                                                                                               \
        std::cout << "[TM][CALL] " << __FUNCTION__ << " " << std::endl;                                                \
    } while (0)

// Throw a std::runtime_error of the form "[TM][ERROR] <info> Assertion fail: <file>:<line> \n".
[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
{
    std::string message = "[TM][ERROR] ";
    message += info;
    message += " Assertion fail: ";
    message += file;
    message += ":";
    message += std::to_string(line);
    message += " \n";
    throw std::runtime_error(message);
}

// No-op when `result` is true; otherwise throws through throwRuntimeError with the given location and info.
inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "")
{
    if (result) {
        return;
    }
    throwRuntimeError(file, line, info);
}

// Assertion macros that throw std::runtime_error (through myAssert / throwRuntimeError) on failure.
// The helpers are fully qualified so the macros also work when expanded outside namespace turbomind
// (previously only FT_CHECK_WITH_INFO was qualified).
#define FT_CHECK(val) ::turbomind::myAssert(bool(val), __FILE__, __LINE__)
#define FT_CHECK_WITH_INFO(val, info)                                                                                  \
    do {                                                                                                               \
        /* `val` is evaluated exactly once; `info` is only evaluated when the check fails */                           \
        bool is_valid_val = bool(val);                                                                                 \
        if (!is_valid_val) {                                                                                           \
            ::turbomind::myAssert(is_valid_val, __FILE__, __LINE__, (info));                                           \
        }                                                                                                              \
    } while (0)

#define FT_THROW(info) ::turbomind::throwRuntimeError(__FILE__, __LINE__, info)

/* ***************************** common utils ****************************** */

// Compute capability of the current device encoded as major * 10 + minor.
int getSMVersion();

// Number of multiprocessors on the current device.
int getSMCount();

// Name of the current device.
std::string getDeviceName();

// Ceiling division (a + n - 1) / n: the smallest q with q * n >= a for non-negative a and positive n.
// NOTE(review): a + n - 1 can overflow when a is close to the maximum value of T.
template
inline T div_up(T a, T n)
{
    return (a + n - 1) / n;
}

// Ordinal of the device currently used by the calling host thread.
int getDevice();

// Number of visible CUDA devices.
int getDeviceCount();

// Scope guard: makes `device` current for the guard's lifetime and restores the previously current
// device when the guard is destroyed.
class CudaDeviceGuard {
public:
    CudaDeviceGuard(int device)
    {
        check_cuda_error(cudaGetDevice(&last_device_id_));
        // Nothing to switch when the requested device is already current.
        if (device == last_device_id_) {
            return;
        }
        check_cuda_error(cudaSetDevice(device));
    }

    ~CudaDeviceGuard()
    {
        // Restored unconditionally, even when the constructor did not switch devices.
        TM_CHECK_EQ(cudaSetDevice(last_device_id_), cudaSuccess);
    }

private:
    int last_device_id_{-1};  // device that was current when the guard was constructed
};

// Release unused memory held by the default CUDA memory pool of `device_id`.
void trim_default_mempool(int device_id);

/* ************************** end of common utils ************************** */
}  // namespace turbomind


================================================
FILE: src/turbomind/utils/debug_utils.h
================================================
#pragma once

// Use the third-party dbg() macro when 3rdparty/dbg.h is available; otherwise dbg(...) expands to
// nothing, so its arguments are not evaluated.
#if __has_include("3rdparty/dbg.h")
#include "3rdparty/dbg.h"
#else
#define dbg(...)
#endif


================================================
FILE: src/turbomind/utils/dispatch.h
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include 

namespace turbomind {

namespace detail {

// Compile-time constant object for each candidate value (presumably a variable template of
// std::integral_constant instances, one per value).
template
inline constexpr std::integral_constant _Int{};

// Walk the candidate values Xs in order. For the I-th value, call f with its constant when
// p(constant) is truthy, or unconditionally for the last candidate when `g` is truthy (a
// fallback). The trailing `, 1` makes a taken branch count as true regardless of what f returns,
// and the `||` fold stops at the first taken branch. Returns whether any branch was taken.
template
bool dispatch_impl(F&& f, P&& p, G g, std::integer_sequence, std::index_sequence)
{
    constexpr int N = sizeof...(Xs);
    return (((((P &&) p)(_Int) || (g && Is == N - 1)) && (((F &&) f)(_Int), 1)) || ...);
}

}  // namespace detail

template
bool dispatch(std::integer_sequence seq, P&& p, F&& f, G g = {})
{
    return detail::dispatch_impl((F &&) f, (P &&) p, g, seq, std::make_index_sequence{});
}

// Invokes `f` with each value of `seq` in order until one call returns a truthy result.
// Returns true iff some call of `f` returned truthy (later values are then skipped).
template
bool dispatch(std::integer_sequence seq, F&& f)
{
    return (((F &&) f)(detail::_Int) || ...);
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/logger.cc
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/utils/logger.h"
#include 

namespace turbomind {

Logger& Logger::getLogger()
{
    // One Logger per thread, built lazily on that thread's first call. The level picked in the
    // constructor (environment plus the thread's current CUDA device) and later setLevel()
    // calls therefore apply to the calling thread only.
    thread_local Logger per_thread_logger;
    return per_thread_logger;
}

// Initializes the logging level from the environment.
//  - TM_LOG_FIRST_RANK_ONLY=ON: loggers on any CUDA device other than 0 are restricted to ERROR.
//  - TM_LOG_LEVEL=TRACE|DEBUG|INFO|WARNING|ERROR: selects the level (on device 0, or on every
//    device when TM_LOG_FIRST_RANK_ONLY is not ON). An unknown name prints a warning to stderr
//    and keeps the default level.
Logger::Logger()
{
    const char* first_rank_only_env = std::getenv("TM_LOG_FIRST_RANK_ONLY");
    const bool  is_first_rank_only  = first_rank_only_env != nullptr && std::string(first_rank_only_env) == "ON";

    // Start from 0 so a failing cudaGetDevice() (e.g. no visible device) leaves a defined value.
    int device_id = 0;
    cudaGetDevice(&device_id);

    // Non-first ranks are silenced even when TM_LOG_LEVEL is unset.
    if (is_first_rank_only && device_id != 0) {
        setLevel(ERROR);
        return;
    }

    const char* level_name = std::getenv("TM_LOG_LEVEL");
    if (level_name == nullptr) {
        return;  // keep DEFAULT_LOG_LEVEL
    }

    const std::map<std::string, Level> name_to_level = {
        {"TRACE", TRACE},
        {"DEBUG", DEBUG},
        {"INFO", INFO},
        {"WARNING", WARNING},
        {"ERROR", ERROR},
    };
    const auto level = name_to_level.find(level_name);
    if (level != name_to_level.end()) {
        setLevel(level->second);
    }
    else {
        fprintf(stderr,
                "[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
                "Ignore the environment variable and use a default "
                "logging level.\n",
                level_name);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/logger.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 
#include 
#include 

#include "src/turbomind/utils/string_utils.h"

namespace turbomind {

// cub.cuh brings in windows.h, which may define an ERROR macro that would clash with the
// Logger::ERROR enumerator below, so it is undefined here.
// NOTE: consequently this header should be included after cub.cuh.
#ifdef ERROR
#undef ERROR
#endif

// Minimal leveled logger writing printf-style formatted lines to stderr.
// Obtain the (per-thread) instance with getLogger(); normally used through the TM_LOG_* macros.
class Logger {

public:
    // Severity levels. A message is emitted when its level is >= the logger's current level.
    enum Level
    {
        TRACE   = 0,
        DEBUG   = 10,
        INFO    = 20,
        WARNING = 30,
        ERROR   = 40
    };

    static Logger& getLogger();
    Logger(Logger const&) = delete;
    void operator=(Logger const&) = delete;

    // Formats `format` printf-style with `args`, prefixes "[TM][<LEVEL>] ", appends a newline,
    // and writes the result to stderr when `level` is enabled.
    // `format` is taken by reference to avoid a string copy per call.
    template<typename... Args>
    void log(const Level level, const std::string& format, const Args&... args)
    {
        if (level_ <= level) {
            const std::string fmt    = getPrefix(level) + format + "\n";
            const std::string logstr = fmtstr(fmt, args...);
            fprintf(stderr, "%s", logstr.c_str());
        }
    }

    // Same as above, but with the prefix "[TM][<LEVEL>][<rank>] ".
    template<typename... Args>
    void log(const Level level, const int rank, const std::string& format, const Args&... args)
    {
        if (level_ <= level) {
            const std::string fmt    = getPrefix(level, rank) + format + "\n";
            const std::string logstr = fmtstr(fmt, args...);
            fprintf(stderr, "%s", logstr.c_str());
        }
    }

    // Sets the minimum level that will be emitted (and reports the change at DEBUG level).
    void setLevel(const Level level)
    {
        level_ = level;
        log(DEBUG, "Set logger level by %s", getLevelName(level).c_str());
    }

    int getLevel() const
    {
        return level_;
    }

private:
    const std::string                  PREFIX      = "[TM]";
    const std::map<Level, std::string> level_name_ = {
        {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}};

#ifndef NDEBUG
    const Level DEFAULT_LOG_LEVEL = DEBUG;
#else
    const Level DEFAULT_LOG_LEVEL = INFO;
#endif
    Level level_ = DEFAULT_LOG_LEVEL;

    Logger();

    // Returns a reference into level_name_, which lives as long as the logger.
    const std::string& getLevelName(const Level level) const
    {
        return level_name_.at(level);
    }

    // Non-const return values so the results can be moved into the caller.
    std::string getPrefix(const Level level) const
    {
        return PREFIX + "[" + getLevelName(level) + "] ";
    }

    std::string getPrefix(const Level level, const int rank) const
    {
        return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] ";
    }
};

// Logs through the calling thread's Logger only when `level` is enabled.
// The check happens before the call, so disabled messages skip the call (and its formatting).
// `level` is parenthesized so expression arguments are compared and forwarded as a unit.
// The logger reference is fetched once rather than twice.
#define TM_LOG(level, ...)                                                                                             \
    do {                                                                                                               \
        auto& _tm_logger = turbomind::Logger::getLogger();                                                             \
        if (_tm_logger.getLevel() <= (level)) {                                                                        \
            _tm_logger.log((level), __VA_ARGS__);                                                                      \
        }                                                                                                              \
    } while (0)

#define TM_LOG_TRACE(...) TM_LOG(turbomind::Logger::TRACE, __VA_ARGS__)
#define TM_LOG_DEBUG(...) TM_LOG(turbomind::Logger::DEBUG, __VA_ARGS__)
#define TM_LOG_INFO(...) TM_LOG(turbomind::Logger::INFO, __VA_ARGS__)
#define TM_LOG_WARNING(...) TM_LOG(turbomind::Logger::WARNING, __VA_ARGS__)
#define TM_LOG_ERROR(...) TM_LOG(turbomind::Logger::ERROR, __VA_ARGS__)
}  // namespace turbomind


================================================
FILE: src/turbomind/utils/memory_utils.cu
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/macro.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

namespace turbomind {

// Copies src, viewed as [dim0, dim1, dim2], into dst viewed as [dim1, dim0, dim2]
// (swaps the two leading axes), converting T_IN elements to T_OUT on assignment.
template<typename T_OUT, typename T_IN>
__global__ void transpose102(T_OUT* dst, T_IN* src, const int dim0, const int dim1, const int dim2)
{
    // src permutation: [0, 1, 2]
    // dst permutation: [1, 0, 2]
    // Each thread starts at its global index and advances by the total number of launched
    // threads. All index math is in size_t: both dim0 * dim1 * dim2 and the flattened indices
    // can exceed INT_MAX for large tensors.
    const size_t count  = static_cast<size_t>(dim0) * dim1 * dim2;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t tid = threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x; tid < count; tid += stride) {
        const size_t dim_2_idx = tid % dim2;
        const size_t rest      = tid / dim2;
        const size_t dim_1_idx = rest % dim1;
        const size_t dim_0_idx = (rest / dim1) % dim0;
        dst[(dim_1_idx * dim0 + dim_0_idx) * dim2 + dim_2_idx] = src[tid];
    }
}

// Transposes `data` from [dim0, dim1, dim2] to [dim1, dim0, dim2] in place.
// The source is first staged in `workspace`, which must hold dim0 * dim1 * dim2 elements.
// With copy == false the caller guarantees `workspace` already contains `data`.
// All work is enqueued on `stream`.
// NOTE: this kernel is used for pre-processing and is not tuned for performance.
template<typename T>
void invokeInPlaceTranspose102(
    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream)
{
    // Widen before multiplying: an int product overflows for large tensors.
    const size_t count = static_cast<size_t>(dim0) * dim1 * dim2;
    if (count == 0) {
        return;  // nothing to move; a zero-sized grid would be an invalid launch configuration
    }
    if (copy) {
        check_cuda_error(cudaMemcpyAsync(workspace, data, sizeof(T) * count, cudaMemcpyDefault, stream));
    }
    const int block = 512;
    const int grid  = static_cast<int>(std::min((count + block - 1) / block, static_cast<size_t>(8192)));
    transpose102<<<grid, block, 0, stream>>>(data, workspace, dim0, dim1, dim2);
}

// Explicit instantiation for uint16_t elements, the only element type compiled in this file.
// NOTE(review): presumably used for 16-bit floating-point data viewed as raw bits — confirm at call sites.
template void invokeInPlaceTranspose102(uint16_t*    data,
                                        uint16_t*    workspace,
                                        const int    dim0,
                                        const int    dim1,
                                        const int    dim2,
                                        bool         copy,
                                        cudaStream_t stream);

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/memory_utils.h
================================================
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 

namespace turbomind {

// Transposes `data` from [dim0, dim1, dim2] to [dim1, dim0, dim2] in place.
// `workspace` is scratch space of at least dim0 * dim1 * dim2 elements. When `copy` is true
// (default), `data` is first copied into `workspace`; otherwise `workspace` must already hold
// that copy. Work is enqueued asynchronously on `stream` (default stream 0).
template
void invokeInPlaceTranspose102(
    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy = true, cudaStream_t stream = 0);

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/metrics.h
================================================
#pragma once

#include 
#include 
#include 
#include 

namespace turbomind {

// Snapshot of scheduler state.
// Every counter defaults to 0, so a default-constructed snapshot is never read uninitialized.
// Aggregate initialization keeps working (C++14+ aggregates allow default member initializers).
struct ScheduleMetrics {
    // sequences
    int total_seqs{};    // the number of received sequences
    int active_seqs{};   // the number of active sequences
    int waiting_seqs{};  // the number of waiting sequences

    // kv block usage
    int total_blocks{};   // the number of kv blocks
    int active_blocks{};  // the number of active kv blocks
    int cached_blocks{};  // the number of cached kv blocks
    int free_blocks{};    // the number of free kv blocks
};

// Per-request timing information. The fields are atomics so they can be written and read
// concurrently.
// NOTE(review): presumably they hold values produced by timestamp() — confirm at the writers.
struct RequestMetrics {
    std::atomic enqueue_time{};    // when a request is enqueued
    std::atomic scheduled_time{};  // when a request is scheduled for inference

    // Returns the current wall-clock time as an integer count since the Unix epoch.
    static int64_t timestamp()
    {
        // Get current timestamp in microseconds since Unix epoch
        // system_clock uses wall-clock time (matches Python's time.time())
        return std::chrono::duration_cast(
                   std::chrono::system_clock::now().time_since_epoch())
            .count();
    }
};

inline std::ostream& operator<<(std::ostream& os, const ScheduleMetrics& m)
{
    os << "ScheduleMetrics { ";
    os << "total_seqs=" << m.total_seqs;
    os << ", active_seqs=" << m.active_seqs;
    os << ", waiting_seqs=" << m.waiting_seqs;
    os << ", total_blocks=" << m.total_blocks;
    os << ", cached_blocks=" << m.cached_blocks;
    os << ", free_blocks=" << m.free_blocks;
    os << " }";
    return os;
}

// Prints "RequestMetrics { enqueue_time=<t>, scheduled_time=<t> }".
// The two atomics are read with relaxed loads, one after the other (C++17 sequences the
// operands of a `<<` chain left to right), so this is a best-effort diagnostic snapshot.
inline std::ostream& operator<<(std::ostream& os, const RequestMetrics& m)
{
    return os << "RequestMetrics { "
              << "enqueue_time=" << m.enqueue_time.load(std::memory_order_relaxed)
              << ", scheduled_time=" << m.scheduled_time.load(std::memory_order_relaxed) << " }";
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/monotonic.h
================================================
#pragma once

#include 
#include 
#include 

namespace turbomind {

// Bump allocator that carves consecutive, aligned sub-buffers out of a caller-provided region.
// It never frees and does no bounds checking; the caller must size the region.
class Monotonic {
public:
    // `base` is the start of the region; the first allocation is aligned up to `alignment` bytes.
    // An alignment of 0 is treated as 1 (no alignment) instead of dividing by zero.
    Monotonic(void* base, size_t alignment = 256): ptr_{base}, alignment_{alignment ? alignment : size_t{1}}
    {
        ptr_ = align(ptr_);
    }

    // Hands out `numel` elements of T at the current position via *ptr, then advances the
    // cursor past them, aligned up for the next request.
    template<class T>
    void operator()(T** ptr, size_t numel) noexcept
    {
        *ptr = (T*)std::exchange(ptr_, align((T*)ptr_ + numel));
    }

    // Current cursor: the address the next allocation will start at.
    void* ptr() const noexcept
    {
        return ptr_;
    }

private:
    // Rounds `p` up to the next multiple of alignment_.
    template<class T>
    void* align(T* p)
    {
        static_assert(sizeof(T*) == sizeof(uintptr_t));
        auto x = reinterpret_cast<uintptr_t>(p);
        if (auto remainder = x % alignment_) {
            x += alignment_ - remainder;
        }
        return reinterpret_cast<void*>(x);
    }

    void*  ptr_;
    size_t alignment_;
};

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/nvtx_utils.cc
================================================
/*
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include 

#include "nvtx_utils.h"
#ifdef USE_NVTX
#include "nvtx3/nvToolsExt.h"
#endif

namespace ft_nvtx {

// Returns the current scope prefix that is prepended to NVTX range names.
std::string getScope()
{
    return scope;
}

// Appends `name` as one more nested scope level ("<scope><name>/").
void addScope(std::string name)
{
    scope = scope + name + "/";
}

// Replaces the whole scope with the single level "<name>/".
void setScope(std::string name)
{
    scope = name + "/";
}

// Clears the scope prefix.
void resetScope()
{
    scope = "";
}

// Sets the device id attached as the payload of subsequently pushed ranges.
void setDeviceDomain(int deviceId)
{
    domain = deviceId;
}

// Resets the range payload device id to 0.
void resetDeviceDomain()
{
    domain = 0;
}

// Returns the device id attached as the payload of pushed ranges.
int getDeviceDomain()
{
    return domain;
}

// Returns true when the environment variable FT_NVTX is "ON". The lookup is done once and cached.
// NOTE(review): the cache flags are plain (non-atomic) globals, so concurrent first calls race.
bool isEnableNvtx()
{
    if (!has_read_nvtx_env) {
        static const char* ft_nvtx_env_char = std::getenv("FT_NVTX");
        is_enable_ft_nvtx = ft_nvtx_env_char != nullptr && std::string(ft_nvtx_env_char) == "ON";
        has_read_nvtx_env = true;
    }
    return is_enable_ft_nvtx;
}

// Pushes an NVTX range named "<scope><name>" carrying the device domain as an int32 payload.
// Compiles to a no-op unless USE_NVTX is defined.
void ftNvtxRangePush(std::string name)
{
#ifdef USE_NVTX
    nvtxStringHandle_t    nameId      = nvtxDomainRegisterStringA(NULL, (getScope() + name).c_str());
    nvtxEventAttributes_t eventAttrib = {0};
    // The NVTX API requires `version` and `size` to describe the attribute structure.
    eventAttrib.version            = NVTX_VERSION;
    eventAttrib.size               = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
    eventAttrib.messageType        = NVTX_MESSAGE_TYPE_REGISTERED;
    eventAttrib.message.registered = nameId;
    eventAttrib.payloadType        = NVTX_PAYLOAD_TYPE_INT32;
    eventAttrib.payload.iValue     = getDeviceDomain();
    nvtxRangePushEx(&eventAttrib);
#endif
}

// Pops the innermost NVTX range. No-op unless USE_NVTX is defined.
void ftNvtxRangePop()
{
#ifdef USE_NVTX
    nvtxRangePop();
#endif
}

}  // namespace ft_nvtx


================================================
FILE: src/turbomind/utils/nvtx_utils.h
================================================
/*
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <string>

namespace ft_nvtx {
// Shared NVTX state. The variables are C++17 `inline` variables, so every translation unit sees
// the same single definition. With `static`, each TU including this header got a private copy.
inline std::string scope;
std::string        getScope();
void               addScope(std::string name);
void               setScope(std::string name);
void               resetScope();
inline int         domain = 0;
void               setDeviceDomain(int deviceId);
int                getDeviceDomain();
void               resetDeviceDomain();
bool               isEnableNvtx();

inline bool has_read_nvtx_env = false;
inline bool is_enable_ft_nvtx = false;
void        ftNvtxRangePush(std::string name);
void        ftNvtxRangePop();
}  // namespace ft_nvtx

// PUSH_RANGE(name) / POP_RANGE push and pop an NVTX range, but only when isEnableNvtx()
// reports FT_NVTX=ON.
// NOTE(review): the bodies are bare `{ ... }` blocks rather than `do { ... } while (0)`, so
// `if (c) PUSH_RANGE(x); else ...` does not compile (the `;` ends the `if`). Switch to the
// do/while(0) form once every call site is known to terminate the macro with `;`.
#define PUSH_RANGE(name)                                                                                               \
    {                                                                                                                  \
        if (ft_nvtx::isEnableNvtx()) {                                                                                 \
            ft_nvtx::ftNvtxRangePush(name);                                                                            \
        }                                                                                                              \
    }

#define POP_RANGE                                                                                                      \
    {                                                                                                                  \
        if (ft_nvtx::isEnableNvtx()) {                                                                                 \
            ft_nvtx::ftNvtxRangePop();                                                                                 \
        }                                                                                                              \
    }


================================================
FILE: src/turbomind/utils/parser.cc
================================================
// Copyright (c) OpenMMLab. All rights reserved.

#include 
#include 
#include 
#include 

namespace turbomind {

namespace {

// Returns the end index (exclusive) of an argument value starting at `begin`, or `begin`
// itself when no valid value starts there.
//  - A value opening with '[' or '(' extends to its matching closing bracket; nesting of the
//    same bracket kind is tracked. Unbalanced brackets yield no value.
//  - Any other value runs up to (not including) the next ',', '[' or '(' or the end of the
//    string, and must be non-empty.
size_t ScanArgValue(const std::string& str, size_t begin)
{
    if (begin >= str.size()) {
        return begin;
    }
    const char open = str[begin];
    if (open == '[' || open == '(') {
        const char close = (open == '[') ? ']' : ')';
        int        depth = 0;
        for (size_t i = begin; i < str.size(); ++i) {
            if (str[i] == open) {
                ++depth;
            }
            else if (str[i] == close && --depth == 0) {
                return i + 1;
            }
        }
        return begin;
    }
    const size_t end = str.find_first_of(",[(", begin);
    return end == std::string::npos ? str.size() : end;
}

}  // namespace

// Extracts "key=value" pairs from `str`, e.g. "a=1,b=[2,3],c=(4,5)" yields
// {("a","1"), ("b","[2,3]"), ("c","(4,5)")}. Keys are runs of word characters.
// Bracketed values end at their matching bracket, so several list-valued arguments no longer
// merge into one (a greedy `\[.*\]` would swallow everything up to the last ']').
// Text that does not form a key=value pair is skipped.
std::vector<std::pair<std::string, std::string>> ParseArgsList(const std::string& str)
{
    static const std::regex key_regex(R"((\w+)=)");

    std::vector<std::pair<std::string, std::string>> ret;
    std::smatch                                      match;
    size_t                                           pos = 0;
    while (pos < str.size() && std::regex_search(str.cbegin() + pos, str.cend(), match, key_regex)) {
        const size_t key_pos   = pos + static_cast<size_t>(match.position(0));
        const size_t value_pos = key_pos + static_cast<size_t>(match.length(0));
        const size_t value_end = ScanArgValue(str, value_pos);
        if (value_end == value_pos) {
            // No value after this key: resume one character later, as a regex scan would.
            pos = key_pos + 1;
            continue;
        }
        ret.emplace_back(match.str(1), str.substr(value_pos, value_end - value_pos));
        pos = value_end;
    }
    return ret;
}

// Splits `str` on runs of ',', '[', ']', '(' and ')' and returns the non-empty pieces,
// e.g. "[1,(2,3)]" -> {"1", "2", "3"}. Nested brackets are flattened.
std::vector<std::string> ParseListOrTuple(const std::string& str)
{
    static const std::regex regex(R"([,\[\]\(\)]+)");

    std::vector<std::string> ret;
    std::copy_if(std::sregex_token_iterator(str.begin(), str.end(), regex, -1),
                 std::sregex_token_iterator{},
                 std::back_inserter(ret),
                 [](const std::string& s) { return !s.empty(); });

    return ret;
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/parser.h
================================================
// Include guard: without it, a double include redefines the inline Parse() overloads.
#pragma once

#include <string>
#include <utility>
#include <vector>

namespace turbomind {

// Extracts "key=value" pairs from a comma-separated argument string.
std::vector<std::pair<std::string, std::string>> ParseArgsList(const std::string& str);

// Splits a list/tuple literal such as "[1,2]" or "(1,2)" into its non-empty items.
std::vector<std::string> ParseListOrTuple(const std::string& str);

// Parses an int with std::stoi (throws std::invalid_argument / std::out_of_range on bad input).
inline void Parse(int& value, const std::string& str)
{
    value = std::stoi(str);
}

// Parses a float with std::stof (throws std::invalid_argument / std::out_of_range on bad input).
inline void Parse(float& value, const std::string& str)
{
    value = std::stof(str);
}

// Parses a list/tuple literal and appends every item, parsed as T, to `xs`.
// Existing elements of `xs` are kept.
template<class T>
void Parse(std::vector<T>& xs, const std::string& str)
{
    const auto ss = ParseListOrTuple(str);
    for (const auto& s : ss) {
        xs.emplace_back();
        Parse(xs.back(), s);
    }
}

}  // namespace turbomind


================================================
FILE: src/turbomind/utils/string_utils.h
================================================
/*
 * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include    // std::make_unique
#include   // std::stringstream
#include 
#include 

namespace turbomind {

// printf-style formatting into a std::string.
// Throws std::runtime_error if the format/arguments cannot be formatted.
template<typename... Args>
inline std::string fmtstr(const std::string& format, Args... args)
{
    // This function came from a code snippet in stackoverflow under cc-by-1.0
    //   https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf

    // Disable format-security warning in this function (the format is not a literal).
#if defined(_MSC_VER)  // for visual studio
#pragma warning(push)
#pragma warning(disable : 4996)
#elif defined(__GNUC__) || defined(__clang__)  // for gcc or clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-security"
#endif
    // First pass measures the output; +1 for the terminating '\0'.
    const int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1;
    if (size_s <= 0) {
        throw std::runtime_error("Error during formatting.");
    }
    const auto size = static_cast<size_t>(size_s);
    auto       buf  = std::make_unique<char[]>(size);
    std::snprintf(buf.get(), size, format.c_str(), args...);
#if defined(_MSC_VER)
#pragma warning(pop)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
    return std::string(buf.get(), buf.get() + size - 1);  // We don't want the '\0' inside
}

// Renders `vec` as "(a, b, c)"; an empty vector gives "()".
// Taken by const reference to avoid copying the vector.
template<typename T>
inline std::string vec2str(const std::vector<T>& vec)
{
    std::stringstream ss;
    ss << "(";
    for (size_t i = 0; i < vec.size(); ++i) {
        if (i > 0) {
            ss << ", ";
        }
        ss << vec[i];
    }
    ss << ")";
    return ss.str();
}

// Renders the first `size` elements of `arr` as "(a, b, c)"; size == 0 gives "()".
// The separator is emitted before every element but the first, so size == 0 no longer
// underflows `size - 1` and reads out of bounds.
template<typename T>
inline std::string arr2str(T* arr, size_t size)
{
    std::stringstream ss;
    ss << "(";
    for (size_t i = 0; i < size; ++i) {
        if (i > 0) {
            ss << ", ";
        }
        ss << arr[i];
    }
    ss << ")";
    return ss.str();
}
}  // namespace turbomind


================================================
FILE: src/turbomind/utils/test_utils.h
================================================
/*
 * Copyright (c) 2022 NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include 
#include 
#include 

namespace turbomind {

// Runs fn(__VA_ARGS__) `n` times on `stream` and evaluates to the average time per call in
// milliseconds (a GCC/Clang statement expression). The result is printed in microseconds when
// `print` is true.
// The timing events are destroyed after use (previously they leaked on every expansion).
// Locals use a _macro_ prefix so they cannot capture names used in the macro arguments.
#define TIMEIT(print, n, stream, fn, ...)                                                                              \
    ({                                                                                                                 \
        cudaEvent_t _macro_event_start, _macro_event_stop;                                                             \
        cudaEventCreate(&_macro_event_start);                                                                          \
        cudaEventCreate(&_macro_event_stop);                                                                           \
        cudaEventRecord(_macro_event_start, stream);                                                                   \
        for (int _macro_iter = 0; _macro_iter < (n); _macro_iter++) {                                                  \
            fn(__VA_ARGS__);                                                                                           \
        }                                                                                                              \
        cudaEventRecord(_macro_event_stop, stream);                                                                    \
        cudaStreamSynchronize(stream);                                                                                 \
        float _macro_ms = 0.0f;                                                                                        \
        cudaEventElapsedTime(&_macro_ms, _macro_event_start, _macro_event_stop);                                       \
        cudaEventDestroy(_macro_event_start);                                                                          \
        cudaEventDestroy(_macro_event_stop);                                                                           \
        _macro_ms /= (n);                                                                                              \
        if (print)                                                                                                     \
            printf("[TIMEIT] " #fn ": %.2fµs\n", _macro_ms * 1000);                                                    \
        _macro_ms;                                                                                                     \
    })

// Relative absolute difference |lhs - rhs| / |lhs|, with lhs as the reference.
// Defined as 0 when the reference is 0.
template<typename T>
struct rel_abs_diff {
    T operator()(const T& lhs, const T& rhs) const
    {
        if (lhs == 0) {
            return 0;
        }
        return static_cast<T>(fabs(lhs - rhs) / fabs(lhs));
    }
};

// Absolute difference |lhs - rhs|, converted back to T.
template<typename T>
struct abs_diff {
    T operator()(const T& lhs, const T& rhs) const
    {
        const auto delta = lhs - rhs;
        return static_cast<T>(fabs(delta));
    }
};

}  // namespace turbomind


================================================
FILE: tests/csrc/CMakeLists.txt
================================================
# Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Build the C++/CUDA unit tests defined in unittests/.
add_subdirectory(unittests)


================================================
FILE: tests/csrc/unittests/CMakeLists.txt
================================================
# Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# GoogleTest Preparation - Code block copied from
#   https://google.github.io/googletest/quickstart-cmake.html
include(FetchContent)
FetchContent_Declare(
  googletest
  GIT_REPOSITORY https://github.com/google/googletest.git
  GIT_TAG release-1.12.1
)

find_package(CUDAToolkit REQUIRED)

# Directory-scoped definition: applies to every target in this directory on non-MSVC builds.
if (NOT MSVC)
  add_definitions(-DTORCH_CUDA=1)
endif()

# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

# A single test executable aggregating all kernel/layer unit tests.
add_executable(unittest
    test_logprob_kernels.cu
    test_penalty_kernels.cu
    test_sampling_kernels.cu
    test_sampling_layer.cu
)

# gtest_main supplies a default main() that runs every registered test.
# NOTE(review): tests are not registered with CTest here (no gtest_discover_tests/add_test);
# run the `unittest` binary directly. TORCH_LIBRARIES is not set in this file — presumably it
# comes from a parent scope.
target_link_libraries(unittest PUBLIC "${TORCH_LIBRARIES}" gtest_main)
target_compile_features(unittest PRIVATE cxx_std_14)

# Sorted by alphabetical order of test name.
# NOTE(review): the test_attention_kernels and test_tensor groups have no matching source in
# add_executable() above; their libraries may be stale.
target_link_libraries(  # Libs for test_attention_kernels
  unittest PUBLIC
    CUDA::cudart CUDA::curand
    gpt_kernels gtest memory_utils tensor unfused_attention_kernels cuda_utils logger)
target_link_libraries(  # Libs for test_logprob_kernels
  unittest PUBLIC
    CUDA::cudart
    logprob_kernels memory_utils cuda_utils logger)
target_link_libraries(  # Libs for test_penalty_kernels
  unittest PUBLIC
    CUDA::cublas CUDA::cublasLt CUDA::cudart
    sampling_penalty_kernels memory_utils cuda_utils logger)
target_link_libraries(  # Libs for test_sampling_kernel
  unittest PUBLIC
    CUDA::cudart
    sampling_topk_kernels sampling_topp_kernels memory_utils tensor cuda_utils logger)
target_link_libraries(  # Libs for test_sampling_layer
  unittest PUBLIC
    CUDA::cublas CUDA::cublasLt CUDA::cudart
    cublasMMWrapper memory_utils
    DynamicDecodeLayer cuda_utils logger
)
target_link_libraries(  # Libs for test_tensor
  unittest PUBLIC cuda_utils logger)


================================================
FILE: tests/csrc/unittests/gtest_utils.h
================================================
#include    // std::fill_n
#include     // snprintf
#include       // expf, log
#include     // rand
#include       // std::string
#include       // std::vector

#include 
#include 

#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/memory_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/logger.h"

namespace ft = turbomind;

namespace {

// Tiny constant added to the reference magnitude when checkResult accumulates the relative gap,
// so a zero reference value does not divide by zero.
#define EPSILON (1e-20)

bool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8)
{
    // Params: a = value to compare and b = reference
    // This function follows implementation of numpy.isclose(), which checks
    //   abs(a - b) <= (atol + rtol * abs(b)).
    // Note that the inequality above is asymmetric where b is considered as
    // a reference value. To account into both absolute/relative errors, it
    // uses absolute tolerance and relative tolerance at the same time. The
    // default values of atol and rtol borrowed from numpy.isclose(). For the
    // case of nan value, the result will be true.
    if (isnan(a) && isnan(b)) {
        return true;
    }
    if (isinf(a) && isinf(b) && (a > 0 && b > 0 || a < 0 && b < 0)) {
        return true;
    }
    return fabs(a - b) <= (atol + rtol * fabs(b));
}

template
bool checkResult(std::string name, T* out, T*ref, size_t size, float atol, float rtol) {
    size_t failures = 0;
    float relative_gap = 0.0f;;

    for (size_t i = 0; i < size; ++i) {
        // The values for the output and the reference.
        float a = (float)out[i];
        float b = (float)ref[i];

        bool ok = almostEqual(a, b, atol, rtol);
        // Print the error.
        if (!ok && failures < 4) {
            TM_LOG_ERROR(">> invalid result for i=%lu:", i);
            TM_LOG_ERROR(">>    found......: %10.6f", a);
            TM_LOG_ERROR(">>    expected...: %10.6f", b);
            TM_LOG_ERROR(">>    error......: %.6f", fabsf(a - b));
            TM_LOG_ERROR(">>    tol........: %.6f", atol + rtol * fabs(b));
        }
        // Update the number of failures.
        failures += ok ? 0 : 1;
        // Update the relative gap.
        relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON);
    }

    relative_gap /= size;

    // Allow not matched up to 1% elements.
    size_t tol_failures = (size_t)(0.01 * size);
    if (failures > tol_failures) {
        TM_LOG_ERROR("%s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)",
                     name.c_str(), 100. * failures / size, atol, rtol, 100. * relative_gap);
    }
    return failures <= tol_failures;
}

template
bool checkResult(std::string name, T* out, T* ref, size_t size,
                 bool device_out = true, bool device_ref = false)
{
    bool is_fp32 = sizeof(T) == 4;
    float atol = is_fp32 ? 1e-4f : 1e-3f;
    float rtol = is_fp32 ? 1e-2f : 1e-1f;

    T* h_out = nullptr;
    if (device_out) {
        h_out = new T[size];
        cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost);
        out = h_out;
    }
    T* h_ref = nullptr;
    if (device_ref) {
        h_ref = new T[size];
        cudaMemcpy(h_ref, ref, sizeof(T) * size, cudaMemcpyDeviceToHost);
        ref = h_ref;
    }
    bool is_ok = checkResult(name, out, ref, size, atol, rtol);
    if (h_out != nullptr){
        delete[] h_out;
    }
    if (h_ref != nullptr) {
        delete[] h_ref;
    }
    return is_ok;
}

// Fills ptr[0, size) with pseudo-random values from rand(), scaled into the
// closed range [minval, maxval] and converted to T.
template<typename T>
void initRandom(T* ptr, size_t size, float minval, float maxval)
{
    const float span = maxval - minval;
    for (T* it = ptr; it != ptr + size; ++it) {
        // rand() / RAND_MAX lies in [0, 1]; stretch it to the requested width.
        const float offset = static_cast<float>(rand()) / static_cast<float>(RAND_MAX) * span;
        *it                = static_cast<T>(minval + offset);
    }
}

// Fills ptr[0, size) with pseudo-random integers (via rand()) drawn from the
// half-open range [minval, maxval). The range must be non-empty.
void initRandomInt(int* ptr, size_t size, int minval, int maxval)
{
    // assert() disappears under NDEBUG, and an empty range would then reach
    // `rand() % 0` (undefined behavior). So also fail loudly in release builds.
    assert(minval < maxval);
    if (minval >= maxval) {
        abort();
    }
    // Width of the range in 64 bits: maxval - minval can overflow int
    // (e.g. INT_MIN..INT_MAX).
    const long long span = static_cast<long long>(maxval) - minval;
    for (size_t i = 0; i < size; ++i) {
        ptr[i] = static_cast<int>(minval + rand() % span);
    }
}

// Replicates the first row x[0, n) of a row-major m x n buffer into rows
// 1 .. m-1 (row 0 is the source and is left as is).
template<typename T>
void tile(T* x, int m, int n)
{
    const T* first_row = x;
    for (int row = 1; row < m; ++row) {
        T* dst_row = x + row * n;
        for (int col = 0; col < n; ++col) {
            dst_row[col] = first_row[col];
        }
    }
}

// Fills every row of the row-major m x n buffer `dst` with the n values of
// `src`.
//
// Unlike the in-place overload, the source row lives in a separate buffer, so
// row 0 of `dst` must be written too. Starting at row 1 (as the in-place
// version does) left it uninitialized. Copying row 0 is harmless when
// dst == src.
template<typename T>
void tile(T* dst, T* src, int m, int n)
{
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            dst[i * n + j] = src[j];
        }
    }
}

// Host-side arithmetic helpers. Every operation is carried out in float and
// the result is converted back to T. Presumably this is so the helpers also
// work for 16-bit types whose host arithmetic operators may be unavailable.
namespace math {

// Returns a + b, computed in float.
template<typename T>
inline T add(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return static_cast<T>(lhs + rhs);
}

// Returns a * b, computed in float.
template<typename T>
inline T mul(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return static_cast<T>(lhs * rhs);
}

// Returns a * b + c, computed in float.
template<typename T>
inline T fma(T a, T b, T c)
{
    return static_cast<T>(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
}

}  // namespace math

// Type lists used to instantiate the typed test suites.
// NOTE(review): in this view the testing::Types declarations below appear to be
// missing their template argument lists; the element types are only implied by
// the typedef names (e.g. FloatAndHalfTypes) — confirm against the repository.
//
// SamplingTypes: selected by both ENABLE_FP32 and ENABLE_BF16.
#ifdef ENABLE_FP32
#ifdef ENABLE_BF16
typedef testing::Types SamplingTypes;
#else
typedef testing::Types SamplingTypes;
#endif
#else
#ifdef ENABLE_BF16
typedef testing::Types        SamplingTypes;
#else
typedef testing::Types SamplingTypes;
#endif
#endif

typedef testing::Types FloatType;
typedef testing::Types FloatAndHalfTypes;
// SupportTypes: FloatAndHalfTypes, or FloatHalfBf16Types when ENABLE_BF16 is defined.
#ifndef ENABLE_BF16
typedef FloatAndHalfTypes SupportTypes;
#else
typedef testing::Types FloatHalfBf16Types;
typedef FloatHalfBf16Types SupportTypes;
#endif

class FtTestBase: public testing::Test {
public:
    void SetUp() override
    {
        int device = 0;
        cudaGetDevice(&device);
        cudaStreamCreate(&stream);
        allocator = new ft::Allocator(device);
        allocator->setStream(stream);
    }

    void TearDown() override
    {
        // Automatically allocated CPU buffers should be released at the end of a test.
        // We don't need to care GPU buffers allocated by Allocator because they are
        // managed by the allocator.
        for (auto& buffer : allocated_cpu_buffers) {
            free(buffer);
        }
        allocated_cpu_buffers.clear();
        delete allocator;
        cudaStreamDestroy(stream);
    }

protected:
    cudaStream_t                            stream;
    ft::Allocator* allocator;
    std::vector                      allocated_cpu_buffers;

    // Utilities to easily handle tensor instances in test cases.

    ft::Tensor createTensor(const ft::MemoryType mtype,
                            const ft::DataType dtype,
                            const std::vector shape)
    {
        size_t n_elmts  = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());
        size_t buf_size = ft::Tensor::getTypeSize(dtype) * n_elmts;

        void* data = nullptr;
        if (mtype == ft::MEMORY_CPU || mtype == ft::MEMORY_CPU_PINNED) {
            data = malloc(buf_size);
            allocated_cpu_buffers.push_back(data);
        }
        else {
            data = allocator->malloc(buf_size);
        }
        return ft::Tensor(mtype, dtype, shape, data);
    };

    template
    ft::Tensor toHost(ft::Tensor& device_tensor)
    {
        if (device_tensor.data == nullptr) {
            return ft::Tensor();
        }
        ft::Tensor host_tensor = createTensor(ft::MEMORY_CPU, device_tensor.type, device_tensor.shape);
        ft::cudaAutoCpy(host_tensor.getPtr(), device_tensor.getPtr(), host_tensor.size(), stream);
        cudaStreamSynchronize(stream);
        return host_tensor;
    };

    template
    ft::Tensor toDevice(ft::Tensor& host_tensor)
    {
        if (host_tensor.data == nullptr) {
            return ft::Tensor();
        }
        ft::Tensor device_tensor = createTensor(ft::MEMORY_GPU, host_tensor.type, host_tensor.shape);
        ft::cudaAutoCpy(device_tensor.getPtr(), host_tensor.getPtr(), host_tensor.size(), stream);
        return device_tensor;
    };

    void copyTensor(ft::Tensor& dst, ft::Tensor& src)
    {
        FT_CHECK_WITH_INFO(
            src.sizeBytes() == dst.sizeBytes(),
            ft::fmtstr("src and dst has different size (%ld != %ld)", src.sizeBytes(), dst.sizeBytes()));
        ft::cudaAutoCpy(dst.getPtr(), src.getPtr(), src.sizeBytes(), stream);
        cudaStreamSynchronize(stream);
    }

};

}


================================================
FILE: tests/csrc/unittests/test_logprob_kernels.cu
================================================
#include 
#include 
#include 
#include 
#include 
#include 
#ifdef __linux__
#include 
#endif
#include "src/turbomind/kernels/logprob_kernels.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"

#include "gtest_utils.h"

using namespace turbomind;

////////////////////////////////////////////////////////////////////////////////////

// Parameters of one log-prob kernel test case. The kernel batch is
// batch_size * beam_width rows.
struct LogProbKernelTestParam {
    size_t max_input_length;
    size_t batch_size;
    size_t vocab_size;
    size_t beam_width;

    // Human-readable description used to tag failures. %zu is the portable
    // specifier for size_t; %ld mismatches it where long is 32-bit (e.g. Windows).
    std::string toString() const
    {
        return fmtstr("LogProbKernelTestParam[max_input_length=%zu, batch=%zu, vocab=%zu, beam_width=%zu]",
                      max_input_length,
                      batch_size,
                      vocab_size,
                      beam_width);
    }
};

/////////////////////////////////// Unittests //////////////////////////////////////////
template<typename T>
class LogProbKernelTest: public FtTestBase {

protected:
    // Returns log(softmax(vec[0, vocab_size))[token_id]) computed on the host in
    // float. The row maximum is subtracted before exponentiating, for numerical
    // stability. Entries at or beyond vocab_size (padding) are ignored.
    static float tokenLogProb(const T* vec, const size_t vocab_size, const int token_id)
    {
        float max_logits = -FLT_MAX;
        for (size_t v = 0; v < vocab_size; ++v) {
            const float val = static_cast<float>(vec[v]);
            if (val > max_logits) {
                max_logits = val;
            }
        }
        float sum = 0.0f;
        for (size_t v = 0; v < vocab_size; ++v) {
            sum += expf(static_cast<float>(vec[v]) - max_logits);
        }
        return static_cast<float>(vec[token_id]) - max_logits - log(sum);
    }

    // Host reference for the step-major layout (runTest, batch_first = false):
    //   logits[((step - 1) * batch_size + i) * vocab_size_padded + v]
    //   input_ids[step * batch_size + i]
    // The logits of step - 1 score the token at `step`. Step 0 contributes 0,
    // and steps at or beyond input_lengths[i] are ignored. When `log_probs` is
    // non-null, per-step values are written to log_probs[step * batch_size + i]
    // for step 0 and every counted step.
    void computeCumLogProbs(float*       cum_log_probs,
                            float*       log_probs,
                            const T*     logits,
                            const int*   input_ids,
                            const int*   input_lengths,
                            const size_t max_input_length,
                            const size_t batch_size,
                            const size_t vocab_size,
                            const size_t vocab_size_padded)
    {
        for (size_t step = 0; step < max_input_length; ++step) {
            for (size_t i = 0; i < batch_size; ++i) {
                if (step == 0) {
                    if (log_probs != nullptr) {
                        log_probs[i] = 0.0f;
                    }
                    cum_log_probs[i] = 0.0f;
                }
                else if ((int)step < input_lengths[i]) {
                    const T*    vec      = logits + ((step - 1) * batch_size + i) * vocab_size_padded;
                    const int   token_id = input_ids[step * batch_size + i];
                    const float log_prob = tokenLogProb(vec, vocab_size, token_id);
                    if (log_probs != nullptr) {
                        log_probs[step * batch_size + i] = log_prob;
                    }
                    cum_log_probs[i] += log_prob;
                }
            }
        }
    }

    // Host reference for the batch-major layout (runBatchFirstTest,
    // batch_first = true):
    //   logits[(i * max_input_length + step - 1) * vocab_size_padded + v]
    //   input_ids[i * max_input_length + step]
    // Otherwise identical to computeCumLogProbs; per-step values, when
    // requested, go to log_probs[i * max_input_length + step].
    void computeCumLogProbsBatchFirst(float*       cum_log_probs,
                                      float*       log_probs,
                                      const T*     logits,
                                      const int*   input_ids,
                                      const int*   input_lengths,
                                      const size_t max_input_length,
                                      const size_t batch_size,
                                      const size_t vocab_size,
                                      const size_t vocab_size_padded)
    {
        for (size_t i = 0; i < batch_size; ++i) {
            const size_t batch_offset = i * max_input_length * vocab_size_padded;
            for (size_t step = 0; step < max_input_length; ++step) {
                if (step == 0) {
                    if (log_probs != nullptr) {
                        log_probs[i * max_input_length] = 0.0f;
                    }
                    cum_log_probs[i] = 0.0f;
                }
                else if ((int)step < input_lengths[i]) {
                    const T*    vec      = logits + batch_offset + (step - 1) * vocab_size_padded;
                    const int   token_id = input_ids[i * max_input_length + step];
                    const float log_prob = tokenLogProb(vec, vocab_size, token_id);
                    if (log_probs != nullptr) {
                        log_probs[i * max_input_length + step] = log_prob;
                    }
                    cum_log_probs[i] += log_prob;
                }
            }
        }
    }

public:
    // Step-major test. It generates random logits, ids and lengths, runs
    // invokeLogProbFromLogits with batch_first = false, and compares the
    // cumulative log-probs against computeCumLogProbs.
    void runTest(LogProbKernelTestParam param)
    {
        size_t max_input_length = param.max_input_length;
        size_t batchxbeam       = param.batch_size * param.beam_width;
        size_t vocab_size       = param.vocab_size;
        // Make multiple of 8 as GPT does.
        size_t vocab_size_padded = static_cast<size_t>(ceil(vocab_size / 8.f) * 8);
        // Logit rows are vocab_size_padded apart, both in the kernel call and in
        // computeCumLogProbs, so the buffers must be sized with the padded vocab.
        // Sizing them with vocab_size made long runs (e.g. 1023/4096 steps with a
        // 5001-token vocab) index past the end.
        const size_t logits_size = max_input_length * batchxbeam * vocab_size_padded;

        // input values
        T*   h_logits        = new T[logits_size];
        int* h_input_ids     = new int[max_input_length * batchxbeam];
        int* h_input_lengths = new int[batchxbeam];

        // output buffers
        float* expected_cum_log_probs = new float[batchxbeam];

        // initialize host buffers
        initRandom(h_logits, logits_size, -10.0f / vocab_size, -1.0f);
        initRandomInt(h_input_ids, max_input_length * batchxbeam, 0, vocab_size);
        initRandomInt(h_input_lengths, batchxbeam, 1, max_input_length + 1);
        memset(expected_cum_log_probs, 0, sizeof(float) * batchxbeam);

        // device buffers
        T*     d_logits        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * logits_size));
        int*   d_input_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_input_length * batchxbeam));
        int*   d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));
        float* d_cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));

        // initialize device buffers
        cudaH2Dcpy(d_logits, h_logits, logits_size);
        cudaH2Dcpy(d_input_ids, h_input_ids, max_input_length * batchxbeam);
        cudaH2Dcpy(d_input_lengths, h_input_lengths, batchxbeam);
        deviceFill(d_cum_log_probs, batchxbeam, 0.0f);

        size_t workspace_size = sizeof(float) * max_input_length * batchxbeam;
        void*  workspace      = allocator->malloc(workspace_size);
        invokeLogProbFromLogits(d_cum_log_probs,
                                d_logits,
                                d_input_ids,
                                d_input_lengths,
                                max_input_length,
                                batchxbeam,
                                vocab_size,
                                vocab_size_padded,
                                workspace,
                                workspace_size,
                                stream,
                                false);
        computeCumLogProbs(expected_cum_log_probs,
                           nullptr,
                           h_logits,
                           h_input_ids,
                           h_input_lengths,
                           max_input_length,
                           batchxbeam,
                           vocab_size,
                           vocab_size_padded);
        // Make sure the kernel issued on `stream` has finished before its output is read.
        check_cuda_error(cudaStreamSynchronize(stream));
        bool passed = checkResult(param.toString(), d_cum_log_probs, expected_cum_log_probs, batchxbeam);
        EXPECT_TRUE(passed);

        TM_LOG_DEBUG("free host buffers");
        delete[] expected_cum_log_probs;
        delete[] h_input_lengths;
        delete[] h_input_ids;
        delete[] h_logits;
    }

    // Batch-major variant of runTest: batch_first = true, checked against
    // computeCumLogProbsBatchFirst.
    void runBatchFirstTest(LogProbKernelTestParam param)
    {
        size_t max_input_length = param.max_input_length;
        size_t batchxbeam       = param.batch_size * param.beam_width;
        size_t vocab_size       = param.vocab_size;
        // Make multiple of 8 as GPT does.
        size_t vocab_size_padded = static_cast<size_t>(ceil(vocab_size / 8.f) * 8);
        const size_t logits_size = max_input_length * batchxbeam * vocab_size_padded;

        // input values
        T*   h_logits        = new T[logits_size];
        int* h_input_ids     = new int[max_input_length * batchxbeam];
        int* h_input_lengths = new int[batchxbeam];

        // output buffers
        float* expected_cum_log_probs = new float[batchxbeam];

        // initialize host buffers
        initRandom(h_logits, logits_size, -10.0f / vocab_size, -1.0f);
        initRandomInt(h_input_ids, max_input_length * batchxbeam, 0, vocab_size);
        initRandomInt(h_input_lengths, batchxbeam, 1, max_input_length + 1);
        memset(expected_cum_log_probs, 0, sizeof(float) * batchxbeam);

        // device buffers
        T*     d_logits        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * logits_size));
        int*   d_input_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_input_length * batchxbeam));
        int*   d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));
        float* d_cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));

        // initialize device buffers
        cudaH2Dcpy(d_logits, h_logits, logits_size);
        cudaH2Dcpy(d_input_ids, h_input_ids, max_input_length * batchxbeam);
        cudaH2Dcpy(d_input_lengths, h_input_lengths, batchxbeam);
        check_cuda_error(cudaMemset(d_cum_log_probs, 0, sizeof(float) * batchxbeam));

        size_t workspace_size = sizeof(float) * max_input_length * batchxbeam;
        void*  workspace      = allocator->malloc(workspace_size);
        invokeLogProbFromLogits(d_cum_log_probs,
                                d_logits,
                                d_input_ids,
                                d_input_lengths,
                                max_input_length,
                                batchxbeam,
                                vocab_size,
                                vocab_size_padded,
                                workspace,
                                workspace_size,
                                stream,
                                true);

        computeCumLogProbsBatchFirst(expected_cum_log_probs,
                                     nullptr,
                                     h_logits,
                                     h_input_ids,
                                     h_input_lengths,
                                     max_input_length,
                                     batchxbeam,
                                     vocab_size,
                                     vocab_size_padded);
        // Make sure the kernel issued on `stream` has finished before its output is read.
        check_cuda_error(cudaStreamSynchronize(stream));
        std::string tag    = param.toString() + (std::is_same<T, float>::value ? " (fp32)" : " (fp16)");
        bool        passed = checkResult(tag.c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam);
        EXPECT_TRUE(passed);

        delete[] expected_cum_log_probs;
        delete[] h_input_lengths;
        delete[] h_input_ids;
        delete[] h_logits;
    }
};

TYPED_TEST_SUITE(LogProbKernelTest, FloatAndHalfTypes);

// Every case is {max_input_length, batch_size, vocab_size, beam_width}.

// ---- Step-major layout (runTest) ----

TYPED_TEST(LogProbKernelTest, SingleStep)
{
    const LogProbKernelTestParam param{1, 32, 16, 1};
    this->runTest(param);
}

TYPED_TEST(LogProbKernelTest, AccumLongStep129)
{
    const LogProbKernelTestParam param{129, 8, 50211, 1};
    this->runTest(param);
}

TYPED_TEST(LogProbKernelTest, AccumLongStep1023)
{
    const LogProbKernelTestParam param{1023, 8, 5001, 1};
    this->runTest(param);
}

TYPED_TEST(LogProbKernelTest, AccumLongStep4096)
{
    const LogProbKernelTestParam param{4096, 8, 5001, 1};
    this->runTest(param);
}

// ---- Batch-major layout (runBatchFirstTest) ----

TYPED_TEST(LogProbKernelTest, BatchFirstSingleStep)
{
    const LogProbKernelTestParam param{1, 32, 16, 1};
    this->runBatchFirstTest(param);
}

TYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep129)
{
    const LogProbKernelTestParam param{129, 8, 50211, 1};
    this->runBatchFirstTest(param);
}

TYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep1023)
{
    const LogProbKernelTestParam param{1023, 8, 5001, 1};
    this->runBatchFirstTest(param);
}

TYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep4096)
{
    const LogProbKernelTestParam param{4096, 8, 5001, 1};
    this->runBatchFirstTest(param);
}


================================================
FILE: tests/csrc/unittests/test_penalty_kernels.cu
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include   // std::min, std::max
#include    // snprintf
#include      // expf, log
#include 
#include   // rand
#include     // std::string
#include 
#include   // std::vector

#include 
#include 
#include 

#include "gtest_utils.h"
#include "src/turbomind/kernels/penalty_types.h"
#include "src/turbomind/kernels/sampling_penalty_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

using namespace turbomind;

// Parameters of one temperature-penalty test. temperatures_size == 1 selects
// the scalar API; otherwise it must equal batch_size (per-row temperatures).
struct TemperatureTestParam {
    size_t batch_size;
    size_t vocab_size;
    float* temperatures;
    size_t temperatures_size;

    // Human-readable description used to tag failures. %zu is the portable
    // specifier for size_t; %ld mismatches it where long is 32-bit (e.g. Windows).
    std::string toString() const
    {
        return fmtstr("TemperatureTestParam[batch=%zu, vocab=%zu, temperatures=%s]",
                      batch_size,
                      vocab_size,
                      arr2str(temperatures, temperatures_size).c_str());
    }
};

// Rounds vocab_size up to the next multiple of `pad` (default 8); values that
// are already a multiple of `pad` are returned unchanged.
size_t pad_vocab_size(size_t vocab_size, size_t pad = 8)
{
    const size_t remainder = vocab_size % pad;
    if (remainder == 0) {
        return vocab_size;
    }
    return vocab_size + (pad - remainder);
}

// Host reference for a single repetition penalty shared by the whole batch.
//
// For each batch row i, every distinct token id found in output_ids at
// positions t < step is penalized once in that row's logits. Negative logits
// are multiplied by repetition_penalty and non-negative ones are divided by
// it. output_ids is read step-major, as output_ids[i + t * batch_size], and
// logits rows are vocab_size_padded apart. Positions t in
// [input_lengths[i], max_input_length) are skipped; presumably these are
// padding between a short prompt and the generated tokens.
template<typename T>
void applyRepetitonPenalty(T*           logits,
                           const int*   output_ids,
                           const int*   input_lengths,
                           const float  repetition_penalty,
                           const size_t step,
                           const size_t max_input_length,
                           const size_t batch_size,
                           const size_t vocab_size,
                           const size_t vocab_size_padded)
{
    // Tokens already penalized in the current row (each is penalized only once).
    std::vector<bool> penalized(vocab_size, false);
    for (size_t i = 0; i < batch_size; ++i) {
        std::fill(penalized.begin(), penalized.end(), false);
        const size_t offset = i * vocab_size_padded;
        for (size_t t = 0; t < step; ++t) {
            if (t >= (size_t)input_lengths[i] && t < max_input_length) {
                continue;
            }
            const int token_id = output_ids[i + t * batch_size];
            if (!penalized[token_id]) {
                const float logit = static_cast<float>(logits[offset + token_id]);
                logits[offset + token_id] =
                    static_cast<T>(logit < 0.0f ? logit * repetition_penalty : logit / repetition_penalty);
                penalized[token_id] = true;
            }
        }
    }
}

// Host reference for per-row repetition penalties.
//
// Row b uses repetition_penalties[b]. Every distinct token id seen at
// positions t < step of that row (read as output_ids[b + t * batch_size]) is
// penalized once. Negative logits are multiplied by the penalty and
// non-negative ones are divided by it. Positions t in
// [input_lengths[b], max_input_length) are skipped. Logits rows are
// vocab_size_padded apart.
template<typename T>
void batchApplyRepetitonPenalty(T*           logits,
                                const int*   output_ids,
                                const int*   input_lengths,
                                const float* repetition_penalties,
                                const size_t step,
                                const size_t max_input_length,
                                const size_t batch_size,
                                const size_t vocab_size,
                                const size_t vocab_size_padded)
{
    std::vector<bool> seen(vocab_size);
    for (size_t b = 0; b < batch_size; ++b) {
        const float penalty = repetition_penalties[b];
        T* const    row     = logits + b * vocab_size_padded;
        seen.assign(vocab_size, false);

        for (size_t t = 0; t < step; ++t) {
            const bool in_gap = t >= (size_t)input_lengths[b] && t < max_input_length;
            if (in_gap) {
                continue;
            }
            const int token = output_ids[b + t * batch_size];
            if (seen[token]) {
                continue;
            }
            seen[token]       = true;
            const float value = static_cast<float>(row[token]);
            row[token]        = static_cast<T>(value < 0.0f ? value * penalty : value / penalty);
        }
    }
}

// Initializes host test inputs.
//
// All batch_size * vocab_size_padded logits get random values in [-5, 5].
// When bias is non-null, its first vocab_size entries do too. Then the padded
// columns [vocab_size, vocab_size_padded) of every logits row are set to a
// large negative value (-65504 for half, -FLT_MAX otherwise, converted to T).
// If there is at least one row, the padded bias entries are zeroed.
// NOTE(review): for __nv_bfloat16, -FLT_MAX is outside the finite range and
// presumably converts to -inf — confirm this is intended.
template<typename T>
void initLogitsAndBias(
    T* logits, T* bias, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded)
{
    initRandom(logits, batch_size * vocab_size_padded, -5.0f, 5.0f);
    if (bias != nullptr) {
        initRandom(bias, vocab_size, -5.0f, 5.0f);
    }
    if (batch_size == 0) {
        // No rows to pad; the bias padding is left untouched as well.
        return;
    }

    const bool is_half   = std::is_same<T, half>::value;
    const T    pad_logit = static_cast<T>(is_half ? -65504.f : -FLT_MAX);

    if (bias != nullptr) {
        for (size_t col = vocab_size; col < vocab_size_padded; ++col) {
            bias[col] = (T)0.0f;
        }
    }
    for (size_t row = 0; row < batch_size; ++row) {
        T* row_logits = logits + row * vocab_size_padded;
        for (size_t col = vocab_size; col < vocab_size_padded; ++col) {
            row_logits[col] = pad_logit;
        }
    }
}

/////////////////////////////////// Tests //////////////////////////////////////////

// Typed test for invokeApplyTemperaturePenalty (scalar temperature) and
// invokeBatchApplyTemperaturePenalty (per-row temperatures), compared against
// a host reference.
template<typename T>
class TemperaturePenaltyTest: public FtTestBase {
protected:
    // Set up test
    size_t batch_size_;
    size_t vocab_size_;
    size_t vocab_size_padded_;

    T* h_logits_;
    T* h_bias_;
    T* d_logits_;
    T* d_bias_;

    float* d_temperatures_;

    // Allocates and initializes host/device logits and bias for `param`. For
    // per-row temperatures (temperatures_size > 1) it also uploads them to
    // d_temperatures_. Host buffers are released by subteardown().
    void subsetup(TemperatureTestParam param)
    {
        batch_size_        = param.batch_size;
        vocab_size_        = param.vocab_size;
        vocab_size_padded_ = pad_vocab_size(vocab_size_);
        d_temperatures_    = nullptr;

        h_logits_ = new T[batch_size_ * vocab_size_padded_];
        h_bias_   = new T[vocab_size_padded_];
        initLogitsAndBias(h_logits_, h_bias_, batch_size_, vocab_size_, vocab_size_padded_);

        d_logits_ = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));
        d_bias_   = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));
        cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream);
        cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);
        if (param.temperatures_size > 1) {
            ASSERT_EQ(param.temperatures_size, param.batch_size) << "Invalid test configuration.";
            // Temperatures are float regardless of T. Sizing this buffer with
            // sizeof(T) under-allocates for 16-bit T, and the float copy below
            // would then overflow it.
            d_temperatures_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * param.temperatures_size));
            cudaAutoCpy(d_temperatures_, param.temperatures, batch_size_, stream);
        }
    }

    // Frees the host buffers allocated by subsetup().
    void subteardown()
    {
        delete[] h_logits_;
        delete[] h_bias_;
    }

    // Host reference, applied in place. For each row i and each real column
    // j < vocab_size it computes logit = (logit + bias[j]) / temperature, where
    // temperature is temperatures[i], or temperatures[0] when temperatures_size
    // is 1. Padded columns are left untouched.
    void computeReference(T*           logits,
                          const T*     bias,
                          const float* temperatures,
                          const size_t temperatures_size,
                          const size_t batch_size,
                          const size_t vocab_size,
                          const size_t vocab_size_padded)
    {
        for (size_t i = 0; i < batch_size; ++i) {
            float temperature = temperatures_size > 1 ? temperatures[i] : temperatures[0];
            ASSERT_GT(temperature, 0.0f) << "temperature should be positive but got " << temperature;
            for (size_t j = 0; j < vocab_size; ++j) {
                size_t index = i * vocab_size_padded + j;
                float  logit = static_cast<float>(logits[index]);
                if (bias != nullptr) {
                    logit += static_cast<float>(bias[j]);
                }
                logits[index] = static_cast<T>(logit / temperature);
            }
        }
    }

public:
    // Runs the scalar or per-row API, depending on temperatures_size, and
    // compares the device result with computeReference().
    void runTest(TemperatureTestParam param)
    {
        subsetup(param);
        // ASSERT_* inside subsetup() only returns from subsetup(); stop here
        // instead of launching the kernel with an unset d_temperatures_.
        if (HasFatalFailure()) {
            subteardown();
            return;
        }
        // Do test
        if (param.temperatures_size == 1) {
            invokeApplyTemperaturePenalty(
                d_logits_, d_bias_, param.temperatures[0], batch_size_, vocab_size_, vocab_size_padded_, stream);
        }
        else {
            invokeBatchApplyTemperaturePenalty(
                d_logits_, d_bias_, d_temperatures_, batch_size_, vocab_size_, vocab_size_padded_, stream);
        }
        computeReference(h_logits_,
                         h_bias_,
                         param.temperatures,
                         param.temperatures_size,
                         batch_size_,
                         vocab_size_,
                         vocab_size_padded_);
        // Same reasoning for ASSERT_* inside computeReference().
        if (HasFatalFailure()) {
            subteardown();
            return;
        }
        // Make sure the kernel issued on `stream` has finished before its output is read.
        check_cuda_error(cudaStreamSynchronize(stream));
        bool passed = checkResult(param.toString(), d_logits_, h_logits_, batch_size_ * vocab_size_padded_);
        EXPECT_TRUE(passed);
        subteardown();
    }

    // Checks that the scalar API and the per-row API, with the same temperature
    // replicated across the batch, produce matching device results.
    void runConsistencyTest(TemperatureTestParam param)
    {
        // Set up test
        ASSERT_EQ(param.temperatures_size, 1) << "A consistency test assumes temperatures_size=1";
        subsetup(param);

        // Run a single runtime value case.
        invokeApplyTemperaturePenalty(
            d_logits_, d_bias_, param.temperatures[0], batch_size_, vocab_size_, vocab_size_padded_, stream);

        // Replicate the scalar temperature for the per-row API. The buffer holds
        // floats, so it is sized with sizeof(float), not sizeof(T).
        std::vector<float> h_temperatures(batch_size_, param.temperatures[0]);
        d_temperatures_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size_));
        cudaAutoCpy(d_temperatures_, h_temperatures.data(), batch_size_, stream);

        T* d_logits_batch = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));
        T* d_bias_batch   = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));
        cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);
        cudaAutoCpy(d_bias_batch, h_bias_, vocab_size_padded_, stream);

        invokeBatchApplyTemperaturePenalty(
            d_logits_batch, d_bias_batch, d_temperatures_, batch_size_, vocab_size_, vocab_size_padded_, stream);
        // Make sure both kernels on `stream` have finished before reading their output.
        check_cuda_error(cudaStreamSynchronize(stream));
        bool passed =
            checkResult(param.toString(), d_logits_, d_logits_batch, batch_size_ * vocab_size_padded_, true, true);
        EXPECT_TRUE(passed);

        // Tear down test
        subteardown();
    }
};

// Since a compiler doesn't correctly catch the use of a variable inside gtest,
// we carefully suppress a compile warning message (nvcc diagnostic 177,
// "variable declared but never referenced").
#pragma nv_diag_suppress 177

TYPED_TEST_SUITE(TemperaturePenaltyTest, testing::Types<__nv_bfloat16>);

// Scalar-temperature cases: {batch_size, vocab_size, &temperature, 1}.

TYPED_TEST(TemperaturePenaltyTest, NoPenalty)
{
    float temperature = 1.0f;
    this->runTest({6, 4, &temperature, 1});
}

TYPED_TEST(TemperaturePenaltyTest, LessThanOne)
{
    float temperature = 0.53f;
    this->runTest({6, 4, &temperature, 1});
}

TYPED_TEST(TemperaturePenaltyTest, GreaterThaneOne)
{
    float temperature = 2.01f;
    this->runTest({6, 4, &temperature, 1});
}

TYPED_TEST(TemperaturePenaltyTest, LargeVocab)
{
    float temperature = 2.01f;
    this->runTest({6, 50001, &temperature, 1});
}

// Per-row temperature cases. The temperatures live in a std::vector so they are
// released when the test ends (they were previously leaked via `new float[]`).

TYPED_TEST(TemperaturePenaltyTest, BatchNoPenalty)
{
    size_t             batch_size = 6;
    std::vector<float> temperatures(batch_size, 1.0f);
    this->runTest({batch_size, 4, temperatures.data(), batch_size});
}

TYPED_TEST(TemperaturePenaltyTest, BatchLessThanOne)
{
    size_t             batch_size = 6;
    std::vector<float> temperatures(batch_size, 0.53f);
    this->runTest({batch_size, 4, temperatures.data(), batch_size});
}

TYPED_TEST(TemperaturePenaltyTest, BatchGreaterThaneOne)
{
    size_t             batch_size = 6;
    std::vector<float> temperatures(batch_size, 2.01f);
    this->runTest({batch_size, 4, temperatures.data(), batch_size});
}

TYPED_TEST(TemperaturePenaltyTest, BatchMixed)
{
    size_t             batch_size = 6;
    std::vector<float> temperatures(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        temperatures[i] = i % 2 == 0 ? 2.01f : 0.53f;
    }
    this->runTest({batch_size, 4, temperatures.data(), batch_size});
}

TYPED_TEST(TemperaturePenaltyTest, Consistency)
{
    float temperature = 2.01f;
    this->runConsistencyTest({6, 4, &temperature, 1});
}

// Parameters of one repetition-penalty test case.
struct RepetitionPenaltyTestCase {
    size_t                batch_size;
    size_t                vocab_size;
    size_t                max_input_length;
    float*                repetition_penalties;
    size_t                repetition_penalties_size;
    RepetitionPenaltyType repetition_penalty_type;

    // Human-readable description used to tag failures. %zu is the portable
    // specifier for size_t; %ld mismatches it where long is 32-bit (e.g. Windows).
    std::string toString() const
    {
        static const std::unordered_map<RepetitionPenaltyType, std::string> typestr_map{
            {RepetitionPenaltyType::Additive, "additive"},
            {RepetitionPenaltyType::Multiplicative, "multiplicative"},
            {RepetitionPenaltyType::None, "none"}};
        return fmtstr("RepetitionPenaltyTestCase[batch=%zu, vocab=%zu, max_input_length=%zu, "
                      "repetition_penalties=%s, repetition_penalty_type=%s]",
                      batch_size,
                      vocab_size,
                      max_input_length,
                      arr2str(repetition_penalties, repetition_penalties_size).c_str(),
                      typestr_map.at(repetition_penalty_type).c_str());
    }
};

// Fixture comparing invokeApplyRepetitionPenalty (one shared penalty) and
// invokeBatchApplyRepetitionPenalty (per-row penalties) against the host reference in
// computeReference. T is the logits element type; penalties are always float.
template<typename T>
class RepetitionPenaltyTest: public FtTestBase {
protected:
    // Set up test
    size_t batch_size_;
    size_t vocab_size_;
    size_t vocab_size_padded_;
    size_t max_input_length_;
    size_t sequence_length_;  // positions stored per row: input + output
    size_t step_;             // number of positions scanned for repeated tokens

    T*   h_logits_;
    T*   h_bias_;
    int* h_output_ids_;
    int* h_input_lengths_;

    T*   d_logits_;
    T*   d_bias_;
    int* d_output_ids_;
    int* d_input_lengths_;
    int* d_penalty_workspace_;

    // Device copy of per-row penalties (float regardless of T). Allocated by subsetup when
    // repetition_penalties_size > 1, and by runConsistencyTest.
    float* d_repetition_penalties_ = nullptr;

    // Allocates host/device buffers for `param` and fills them with random logits, ids and lengths.
    // NOTE: an ASSERT_* failure here only returns from subsetup; callers check HasFatalFailure().
    void subsetup(RepetitionPenaltyTestCase param)
    {
        batch_size_        = param.batch_size;
        vocab_size_        = param.vocab_size;
        vocab_size_padded_ = pad_vocab_size(vocab_size_);
        max_input_length_  = param.max_input_length;
        sequence_length_   = 2 * max_input_length_;  // input + output
        step_              = static_cast<size_t>(sequence_length_ * 0.7);

        h_logits_        = new T[batch_size_ * vocab_size_padded_];
        h_bias_          = new T[vocab_size_padded_];
        h_output_ids_    = new int[sequence_length_ * batch_size_];
        h_input_lengths_ = new int[batch_size_];
        initLogitsAndBias(h_logits_, h_bias_, batch_size_, vocab_size_, vocab_size_padded_);
        initRandomInt(h_output_ids_, sequence_length_ * batch_size_, 0, vocab_size_);
        initRandomInt(h_input_lengths_, batch_size_, 1, max_input_length_);

        d_logits_        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));
        d_bias_          = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));
        d_output_ids_    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * sequence_length_ * batch_size_));
        d_input_lengths_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size_));
        d_penalty_workspace_ =
            reinterpret_cast<int*>(allocator->malloc((sizeof(int) + sizeof(float)) * batch_size_ * step_));

        cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream);
        cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);
        cudaAutoCpy(d_output_ids_, h_output_ids_, sequence_length_ * batch_size_, stream);
        cudaAutoCpy(d_input_lengths_, h_input_lengths_, batch_size_, stream);
        if (param.repetition_penalties_size > 1) {
            ASSERT_EQ(param.repetition_penalties_size, param.batch_size) << "Invalid test configuration.";
            // The buffer holds floats; sizing it with sizeof(T) under-allocates whenever
            // sizeof(T) < sizeof(float), and the copy below would then overrun it.
            d_repetition_penalties_ =
                reinterpret_cast<float*>(allocator->malloc(sizeof(float) * param.repetition_penalties_size));
            cudaAutoCpy(d_repetition_penalties_, param.repetition_penalties, batch_size_, stream);
        }
    }

    // Releases the host buffers allocated in subsetup (device buffers are not freed here).
    void subteardown()
    {
        delete[] h_logits_;
        delete[] h_bias_;
        delete[] h_output_ids_;
        delete[] h_input_lengths_;
    }

    // Host reference. For each row i, every distinct token id found at output_ids[i + t * batch_size]
    // for t < step is penalized once: Additive subtracts the penalty, Multiplicative multiplies
    // negative logits and divides non-negative ones, None leaves the logit unchanged.
    // Positions with input_lengths[i] <= t < max_input_length are skipped (presumably input padding).
    void computeReference(T*                          logits,
                          const int*                  output_ids,
                          const int*                  input_lengths,
                          const float*                repetition_penalties,
                          const size_t                repetition_penalties_size,
                          const RepetitionPenaltyType repetition_penalty_type,
                          const size_t                step,
                          const size_t                max_input_length,
                          const size_t                batch_size,
                          const size_t                vocab_size,
                          const size_t                vocab_size_padded)
    {
        // A std::vector is released even when the default branch below throws.
        std::vector<bool> penalized(vocab_size);
        for (size_t i = 0; i < batch_size; ++i) {
            float repetition_penalty =
                repetition_penalties_size > 1 ? repetition_penalties[i] : repetition_penalties[0];

            penalized.assign(vocab_size, false);
            size_t offset = i * vocab_size_padded;
            for (size_t t = 0; t < step; ++t) {
                if (t >= (size_t)input_lengths[i] && t < max_input_length) {
                    continue;
                }
                int token_id = output_ids[i + t * batch_size];
                if (!penalized[token_id]) {
                    float logit = static_cast<float>(logits[offset + token_id]);
                    switch (repetition_penalty_type) {
                        case RepetitionPenaltyType::Additive:
                            logits[offset + token_id] = static_cast<T>(logit - repetition_penalty);
                            break;
                        case RepetitionPenaltyType::Multiplicative:
                            logits[offset + token_id] =
                                static_cast<T>(logit < 0.0f ? logit * repetition_penalty : logit / repetition_penalty);
                            break;
                        case RepetitionPenaltyType::None:
                            // None. do nothing.
                            break;
                        default:
                            throw std::domain_error("Invalid repetition penalty type.");
                    }
                    penalized[token_id] = true;
                }
            }
        }
    }

public:
    // Runs the single-penalty kernel when repetition_penalties_size == 1, otherwise the batched one,
    // and compares the device logits with computeReference applied to the host copy.
    void runTest(RepetitionPenaltyTestCase param)
    {
        subsetup(param);
        if (::testing::Test::HasFatalFailure()) {
            // An ASSERT_* in subsetup only left subsetup; the device state is incomplete, so stop.
            subteardown();
            return;
        }
        // Do test
        if (param.repetition_penalties_size == 1) {
            invokeApplyRepetitionPenalty(d_logits_,
                                         param.repetition_penalties[0],
                                         nullptr,
                                         d_output_ids_,
                                         batch_size_,
                                         batch_size_,
                                         vocab_size_,
                                         vocab_size_padded_,
                                         d_input_lengths_,
                                         max_input_length_,
                                         step_,
                                         param.repetition_penalty_type,
                                         stream);
        }
        else {
            invokeBatchApplyRepetitionPenalty(d_logits_,
                                              d_repetition_penalties_,
                                              d_penalty_workspace_,
                                              d_output_ids_,
                                              batch_size_,
                                              batch_size_,
                                              vocab_size_padded_,
                                              d_input_lengths_,
                                              max_input_length_,
                                              step_,
                                              param.repetition_penalty_type,
                                              stream);
        }
        computeReference(h_logits_,
                         h_output_ids_,
                         h_input_lengths_,
                         param.repetition_penalties,
                         param.repetition_penalties_size,
                         param.repetition_penalty_type,
                         step_,
                         max_input_length_,
                         batch_size_,
                         vocab_size_,
                         vocab_size_padded_);
        bool passed = checkResult(param.toString(), d_logits_, h_logits_, batch_size_ * vocab_size_padded_);
        EXPECT_TRUE(passed);
        subteardown();
    }

    // Applies one shared penalty via the single-value kernel and the same value broadcast to every
    // row via the batched kernel, then checks that both device results agree.
    void runConsistencyTest(RepetitionPenaltyTestCase param)
    {
        // Set up test
        ASSERT_EQ(param.repetition_penalties_size, 1u) << "A consistency test assumes repetition_penalties_size=1";
        subsetup(param);

        // Run a single runtime value case.
        invokeApplyRepetitionPenalty(d_logits_,
                                     param.repetition_penalties[0],
                                     nullptr,
                                     d_output_ids_,
                                     batch_size_,
                                     batch_size_,
                                     vocab_size_,
                                     vocab_size_padded_,
                                     d_input_lengths_,
                                     max_input_length_,
                                     step_,
                                     param.repetition_penalty_type,
                                     stream);

        // Broadcast the shared penalty to every row for the batched kernel.
        std::vector<float> h_repetition_penalties(batch_size_, param.repetition_penalties[0]);
        // Penalties are float, so size the buffer with sizeof(float), not sizeof(T).
        d_repetition_penalties_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size_));
        cudaAutoCpy(d_repetition_penalties_, h_repetition_penalties.data(), batch_size_, stream);

        // Fresh copy of the untouched host logits for the batched kernel.
        T* d_logits_batch = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));
        cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);
        invokeBatchApplyRepetitionPenalty(d_logits_batch,
                                          d_repetition_penalties_,
                                          d_penalty_workspace_,
                                          d_output_ids_,
                                          batch_size_,
                                          batch_size_,
                                          vocab_size_padded_,
                                          d_input_lengths_,
                                          max_input_length_,
                                          step_,
                                          param.repetition_penalty_type,
                                          stream);
        bool passed =
            checkResult(param.toString(), d_logits_, d_logits_batch, batch_size_ * vocab_size_padded_, true, true);
        EXPECT_TRUE(passed);

        // Tear down test
        subteardown();
    }
};

// Instantiate every RepetitionPenaltyTest case once for each type listed in SamplingTypes.
TYPED_TEST_SUITE(RepetitionPenaltyTest, SamplingTypes);

TYPED_TEST(RepetitionPenaltyTest, NoPenalty)
{
    // Shared multiplicative penalty of 1.0 (batch=6, vocab=4, max_input_length=5).
    float                     shared_penalty = 1.0f;
    RepetitionPenaltyTestCase test_case{6, 4, 5, &shared_penalty, 1, RepetitionPenaltyType::Multiplicative};
    this->runTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, LessThanOne)
{
    // Shared multiplicative penalty below one (batch=6, vocab=4, max_input_length=5).
    float                     shared_penalty = 0.53f;
    RepetitionPenaltyTestCase test_case{6, 4, 5, &shared_penalty, 1, RepetitionPenaltyType::Multiplicative};
    this->runTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, GreaterThaneOne)
{
    // Shared multiplicative penalty above one (batch=6, vocab=4, max_input_length=5).
    float                     shared_penalty = 2.01f;
    RepetitionPenaltyTestCase test_case{6, 4, 5, &shared_penalty, 1, RepetitionPenaltyType::Multiplicative};
    this->runTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, LargeVocab)
{
    // Shared penalty 2.01 with a large vocabulary (50001) and long inputs (max_input_length=1003).
    float                     shared_penalty = 2.01f;
    RepetitionPenaltyTestCase test_case{6, 50001, 1003, &shared_penalty, 1, RepetitionPenaltyType::Multiplicative};
    this->runTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)
{
    // Per-row multiplicative penalties, all 1.0. std::vector frees the buffer at scope exit
    // (the former raw new[] was never deleted).
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size, 1.0f);
    this->runTest(
        {batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Multiplicative});
}

TYPED_TEST(RepetitionPenaltyTest, BatchLessThanOne)
{
    // Per-row multiplicative penalties, all 0.53. std::vector frees the buffer at scope exit
    // (the former raw new[] was never deleted).
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size, 0.53f);
    this->runTest(
        {batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Multiplicative});
}

TYPED_TEST(RepetitionPenaltyTest, BatchGreaterThaneOne)
{
    // Per-row multiplicative penalties, all 2.01. The buffer holds repetition penalties (it was
    // misnamed `temperatures`) and is now owned by a std::vector instead of a leaked new[].
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size, 2.01f);
    this->runTest(
        {batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Multiplicative});
}

TYPED_TEST(RepetitionPenaltyTest, BatchMixed)
{
    // Alternating per-row multiplicative penalties: even rows 2.01, odd rows 0.53.
    // std::vector frees the buffer at scope exit (the former raw new[] was never deleted).
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        repetition_penalties[i] = i % 2 == 0 ? 2.01f : 0.53f;
    }
    this->runTest(
        {batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Multiplicative});
}

TYPED_TEST(RepetitionPenaltyTest, Consistency)
{
    // Single shared multiplicative penalty, compared between the single-value and batched kernels.
    float                     shared_penalty = 2.01f;
    RepetitionPenaltyTestCase test_case{6, 4, 5, &shared_penalty, 1, RepetitionPenaltyType::Multiplicative};
    this->runConsistencyTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditive)
{
    // Alternating per-row additive penalties: even rows 2.01, odd rows 0.53.
    // std::vector frees the buffer at scope exit (the former raw new[] was never deleted).
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        repetition_penalties[i] = i % 2 == 0 ? 2.01f : 0.53f;
    }
    this->runTest({batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Additive});
}

TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditiveHasDefaultValueZero)
{
    // Single shared additive penalty of 1.0 (batch=6, vocab=4, max_input_length=5).
    float                     shared_penalty = 1.0f;
    RepetitionPenaltyTestCase test_case{6, 4, 5, &shared_penalty, 1, RepetitionPenaltyType::Additive};
    this->runTest(test_case);
}

TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditiveHasDefaultValueZero2)
{
    // Alternating per-row additive penalties: even rows 1.0, odd rows 0.0.
    // std::vector frees the buffer at scope exit (the former raw new[] was never deleted).
    size_t             batch_size = 6;
    std::vector<float> repetition_penalties(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        repetition_penalties[i] = i % 2 == 0 ? 1.0f : 0.0f;
    }
    this->runTest({batch_size, 4, 5, repetition_penalties.data(), batch_size, RepetitionPenaltyType::Additive});
}

// Turn the warning message back on: restore the default handling of diagnostic 177
// (the directive here previously re-issued nv_diag_suppress, contradicting this comment).
#pragma nv_diag_default 177


================================================
FILE: tests/csrc/unittests/test_sampling_kernels.cu
================================================
#include   // std::fill_n
#include    // snprintf
#include      // expf, log
#include    // rand
#include      // std::string
#include      // std::vector

#include 
#include 
#include 
#include 

#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/layers/DynamicDecodeLayer.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/constant.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

#include "gtest_utils.h"

using namespace turbomind;

namespace {

// Writes one curand_uniform draw per block: output[blockIdx.x] from curandstate[blockIdx.x].
// Blocks whose index is >= n return without touching memory (n was previously ignored).
// NOTE(review): every thread of a block would advance the same state; presumably this is
// launched with one thread per block — confirm at the call site.
__global__ void get_curand_uniform(curandState_t* curandstate, float* output, int n)
{
    int batch_id = blockIdx.x;
    if (batch_id >= n) {
        return;
    }
    float rand_num   = (float)curand_uniform(curandstate + batch_id);
    output[batch_id] = rand_num;
}

// Compares GPU sort/filter output with the CPU reference row by row.
// kept counts must match exactly. Inside the kept prefix an entry fails only when BOTH its value
// differs by more than 1e-6 AND its index differs ("soft check", presumably tolerating reordered
// near-equal values). Prints the first mismatch and returns false; returns true otherwise.
template<typename T>
bool checkSorted(int  batch_size,
                 T*   expected_logits,
                 T*   output_logits,
                 int* expected_indices,
                 int* output_indices,
                 int* expected_kept,
                 int* output_kept,
                 int  vocab_size)
{
    for (int row = 0; row < batch_size; ++row) {
        const int want_kept = expected_kept[row];
        const int got_kept  = output_kept[row];
        if (want_kept != got_kept) {
            printf("batch=%d, expected_kept[i]=%d, output_kept[i]=%d\n", row, want_kept, got_kept);
            return false;
        }

        const int base = row * vocab_size;
        for (int k = 0; k < want_kept; ++k) {
            const int   pos      = base + k;
            const float want_val = (float)expected_logits[pos];
            const float got_val  = (float)output_logits[pos];
            const bool  val_diff = std::abs(want_val - got_val) > 1e-6;
            const bool  idx_diff = expected_indices[pos] != output_indices[pos];
            if (!(val_diff && idx_diff)) {
                continue;
            }
            printf("batch=%d, ith=%d, expected=(%d, %.5f), output=(%d, %.5f)\n",
                   row,
                   k,
                   expected_indices[pos],
                   want_val,
                   output_indices[pos],
                   got_val);
            return false;
        }
    }
    return true;
}

// Compares GPU sampling output with the CPU reference (sampleCpu) for each row:
// sampled_nums and output_ids must match exactly. For each of the sampled_nums log-prob entries
// (row stride kMaxLogProb), a failure requires both a value gap above 1e-5 and a differing index.
// Prints the first mismatch and returns false; returns true otherwise.
template<typename T>
bool checkSample(int* expected_output_ids,
                 int* output_ids,
                 int  batch_size,
                 T*   expected_sampled_logprobs,
                 int* expected_sampled_indices,
                 int* expected_sampled_nums,
                 T*   output_sampled_logprobs,
                 int* output_sampled_indices,
                 int* output_sampled_nums)
{
    for (int row = 0; row < batch_size; ++row) {
        if (expected_sampled_nums[row] != output_sampled_nums[row]) {
            printf("batch=%d, sampled_nums, cpu=%d, gpu=%d\n", row, expected_sampled_nums[row], output_sampled_nums[row]);
            return false;
        }
        if (expected_output_ids[row] != output_ids[row]) {
            printf("batch=%d, expected_output_ids=%d, output_ids=%d\n", row, expected_output_ids[row], output_ids[row]);
            return false;
        }

        const int base = row * kMaxLogProb;
        for (int k = 0; k < expected_sampled_nums[row]; ++k) {
            const float gpu_val  = output_sampled_logprobs[base + k];
            const float cpu_val  = expected_sampled_logprobs[base + k];
            const int   gpu_idx  = output_sampled_indices[base + k];
            const int   cpu_idx  = expected_sampled_indices[base + k];
            const bool  mismatch = std::abs(gpu_val - cpu_val) > 1e-5 && gpu_idx != cpu_idx;
            if (!mismatch) {
                continue;
            }
            printf("%d %d\n", expected_output_ids[row], output_ids[row]);
            printf("batch=%d, ith=%d, idx cpu=%d, gpu=%d, val cpu=%.5f, gpu=%.5f\n",
                   row,
                   k,
                   cpu_idx,
                   gpu_idx,
                   cpu_val,
                   gpu_val);
            return false;
        }
    }
    return true;
}

// CPU reference sampler. For each row it walks the first kept[row] entries of `logits`
// (row stride vocab_size), accumulating their values until the running sum exceeds
// uniforms[row]; that entry's index becomes output_ids[row] (left untouched if never exceeded).
// If all three sampled_* outputs are non-null, it also records log(value) and index for the first
// min(kept, kMaxLogProb) entries; when the pick lies beyond that prefix, it overwrites the last slot.
template<typename T>
void sampleCpu(int    batch_size,
               int    vocab_size,
               T*     logits,
               int*   indices,
               int*   kept,
               float* uniforms,
               int*   output_ids,
               T*     sampled_logprobs,
               int*   sampled_indices,
               int*   sampled_nums)
{

    for (int row = 0; row < batch_size; ++row) {
        const int base = row * vocab_size;

        int   pick    = -1;
        float running = 0.f;
        for (int k = 0; k < kept[row] && pick < 0; ++k) {
            running += (float)logits[base + k];
            if (running > uniforms[row]) {
                pick            = k;
                output_ids[row] = indices[base + k];
            }
        }

        if (!(sampled_logprobs && sampled_indices && sampled_nums)) {
            continue;
        }

        T*        lp_row  = sampled_logprobs + row * kMaxLogProb;
        int*      idx_row = sampled_indices + row * kMaxLogProb;
        const int n_kept  = min(kept[row], kMaxLogProb);
        for (int k = 0; k < n_kept; ++k) {
            lp_row[k]  = std::log((float)logits[base + k]);
            idx_row[k] = indices[base + k];
        }
        if (kept[row] > kMaxLogProb && pick >= kMaxLogProb) {
            lp_row[kMaxLogProb - 1]  = std::log((float)logits[base + pick]);
            idx_row[kMaxLogProb - 1] = indices[base + pick];
        }
        sampled_nums[row] = n_kept;
    }
}

// Row-wise softmax over the first kept[row] entries of each row (row stride vocab_size).
// Uses max-subtraction for stability; each exp is stored as T before being summed, so the
// normalizer matches the stored (possibly rounded) values. Safe to call with input == output.
template<typename T>
void softmax(T* input, int batch_size, int vocab_size, int* kept, T* output)
{
    for (int row = 0; row < batch_size; ++row) {
        const int n   = kept[row];
        T*        src = input + row * vocab_size;
        T*        dst = output + row * vocab_size;

        float peak = src[0];
        for (int k = 0; k < n; ++k) {
            peak = std::max((float)src[k], peak);
        }

        float denom = 0.f;
        for (int k = 0; k < n; ++k) {
            dst[k] = std::exp((float)src[k] - peak);
            denom += (float)dst[k];
        }

        for (int k = 0; k < n; ++k) {
            dst[k] = (float)dst[k] / denom;
        }
    }
}

// CPU reference for the sort/filter kernels. For each row i it:
//   1. sorts (logit, index) pairs in descending order — only the first top_ks[i] when
//      top_ks[i] != 0 (clamped to vocab_size), otherwise the whole row;
//   2. writes the kept prefix to sorted_logits/sorted_indices and softmaxes it in place;
//   3. when top_ks[i] == 0 and the top probability already exceeds top_ps[i], keeps only that token;
//   4. optionally truncates by top-p (filter_topp) and by min-p (filter_minp), renormalizing the
//      survivors by their sum (+1e-6).
// kept[i] receives the number of surviving entries. top_ks may be null (no top-k for any row).
template<typename T>
void filterCpu(int    batch_size,
               int*   top_ks,
               float* top_ps,
               float* min_ps,
               T*     logits,
               T*     sorted_logits,
               int*   sorted_indices,
               int*   kept,
               int    vocab_size,
               bool   filter_topp = false,
               bool   filter_minp = false)
{
    for (int i = 0; i < batch_size; i++) {
        // fill: pair each logit with its vocabulary index
        std::vector<std::pair<T, int>> work(vocab_size);
        for (int j = 0; j < vocab_size; j++) {
            work[j] = {logits[i * vocab_size + j], j};
        }

        // sort
        if (top_ks && top_ks[i] != 0) {
            // std::partial_sort requires middle <= last; a k larger than the vocabulary would be UB
            // and the copy loop below would then write past this row, so clamp it.
            const int top_k = std::min(top_ks[i], vocab_size);
            std::partial_sort(work.begin(), work.begin() + top_k, work.end(), std::greater<>{});
            kept[i] = top_k;
        }
        else {
            std::sort(work.begin(), work.end(), std::greater<>{});
            kept[i] = vocab_size;
        }
        for (int j = 0; j < kept[i]; j++) {
            sorted_logits[i * vocab_size + j]  = work[j].first;
            sorted_indices[i * vocab_size + j] = work[j].second;
        }
        // softmax over the kept prefix, in place
        softmax(sorted_logits + i * vocab_size, 1, vocab_size, kept + i, sorted_logits + i * vocab_size);
        if (top_ks && top_ks[i] == 0) {
            // The top token alone already covers top_p: keep just it, with probability 1.
            if (top_ps && (float)sorted_logits[i * vocab_size] > top_ps[i]) {
                sorted_logits[i * vocab_size] = 1.f;
                kept[i]                       = 1;
            }
        }

        // topp filter: keep the shortest prefix whose cumulative probability exceeds top_p
        if (filter_topp && top_ps[i] != 1.f) {
            float topp    = top_ps[i];
            float sum_val = 0;
            int   n       = kept[i];
            for (int j = 0; j < kept[i]; j++) {
                sum_val += (float)sorted_logits[i * vocab_size + j];
                if (sum_val > topp) {
                    n = j + 1;
                    break;
                }
            }
            if (n != kept[i]) {
                kept[i] = n;
                for (int j = 0; j < n; j++) {
                    sorted_logits[i * vocab_size + j] = (float)sorted_logits[i * vocab_size + j] / (sum_val + 1e-6f);
                }
            }
        }

        // minp filter: drop entries below min_p times the top probability
        if (filter_minp && min_ps[i] != 0.f) {
            float minp      = min_ps[i];
            float threshold = (float)sorted_logits[i * vocab_size] * minp;
            float sum_val   = 0;
            int   n         = kept[i];
            for (int j = 0; j < kept[i]; j++) {
                if ((float)sorted_logits[i * vocab_size + j] < threshold) {
                    n = j;
                    break;
                }
                sum_val += (float)sorted_logits[i * vocab_size + j];
            }
            if (n != kept[i]) {
                kept[i] = n;
                for (int j = 0; j < n; j++) {
                    sorted_logits[i * vocab_size + j] = (float)sorted_logits[i * vocab_size + j] / (sum_val + 1e-6f);
                }
            }
        }
    }
}

// Base fixture for the sampling kernel tests: creates a CUDA stream in SetUp and a CUDA allocator
// bound to it; TearDown releases both. Members start as null so TearDown is safe even when SetUp
// fails before assigning them (gtest still runs TearDown after a failed SetUp).
template<typename T>
class SamplingKernelTest: public testing::Test {
public:
    void SetUp() override
    {
        check_cuda_error(cudaStreamCreate(&stream));
        allocator = new Allocator<AllocatorType::CUDA>(getDevice());
        allocator->setStream(stream);
    }
    void TearDown() override
    {
        delete allocator;
        allocator = nullptr;
        if (stream) {
            check_cuda_error(cudaStreamDestroy(stream));
            stream = nullptr;
        }
    }

protected:
    cudaStream_t                    stream{nullptr};
    Allocator<AllocatorType::CUDA>* allocator{nullptr};
    curandState_t*                  curand_states{nullptr};  // not set by this fixture itself
};

// Runs the GPU top-k sort-filter, softmax and top-p sort on random logits, then compares the
// sorted probabilities, indices and kept counts against filterCpu called without the optional
// top-p / min-p filter passes.
template
class TopKTopPSortTest: public SamplingKernelTest {
protected:
    using SamplingKernelTest::stream;
    using SamplingKernelTest::allocator;

public:
    // top_ks / top_ps are host arrays with batch_size entries; logits are batch_size x vocab_size.
    void runTest(int batch_size, int* top_ks, float* top_ps, int vocab_size)
    {

        TopKSortFilterParams params1{};
        params1.batch_size = batch_size;
        int max_top_k      = *std::max_element(top_ks, top_ks + batch_size);
        // max_top_k is clamped to [0, 1024].
        params1.max_top_k  = std::min(1024, std::max(0, max_top_k));
        // Called before any buffers are attached; presumably this only reports the required
        // params1.workspace_size, which is read when allocating d_ws_topk below — confirm.
        invokeTopKSortFilter(params1, stream);

        TopPSortParams params2{};
        params2.batch_size        = batch_size;
        params2.vocab_size        = vocab_size;
        params2.vocab_size_padded = vocab_size;
        // Same pattern: presumably fills params2.workspace_size (read for d_ws_topp) — confirm.
        invokeTopPSort(params2, stream);

        // host buffer
        std::vector   logits(batch_size * vocab_size);
        std::vector   expected_logits(batch_size * vocab_size);
        std::vector expected_indices(batch_size * vocab_size);
        std::vector expected_kept(batch_size);

        std::vector   output_logits(batch_size * vocab_size);
        std::vector output_indices(batch_size * vocab_size);
        std::vector output_kept(batch_size);

        // device buffer
        void*  d_ws_topk        = allocator->malloc(params1.workspace_size);
        void*  d_ws_topp        = allocator->malloc(params2.workspace_size);
        T*     d_logits         = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);
        T*     d_sorted_logits  = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);
        int*   d_sorted_indices = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);
        int*   d_kept           = (int*)allocator->malloc(sizeof(int) * batch_size);
        int*   d_top_ks         = (int*)allocator->malloc(sizeof(int) * batch_size);
        float* d_top_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);

        // Logit range grows with the vocabulary: boundary = 10^(decimal digits of vocab_size - 1).
        float boundary = 1.f;
        for (int x = vocab_size; x >= 10; x /= 10) {
            boundary *= 10;
        }
        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);

        // Every row starts with the full vocabulary kept; this also seeds d_kept below.
        std::fill_n(expected_kept.data(), batch_size, vocab_size);

        cudaAutoCpy(d_logits, logits.data(), batch_size * vocab_size, stream);
        cudaAutoCpy(d_top_ps, top_ps, batch_size, stream);
        cudaAutoCpy(d_top_ks, top_ks, batch_size, stream);
        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);

        // gpu
        params1.workspace         = d_ws_topk;
        params1.logits            = d_logits;
        params1.sorted_logits     = d_sorted_logits;
        params1.sorted_indices    = d_sorted_indices;
        params1.kept              = d_kept;
        params1.top_ks            = d_top_ks;
        params1.vocab_size        = vocab_size;
        params1.vocab_size_padded = vocab_size;
        invokeTopKSortFilter(params1, stream);

        // NOTE(review): presumably normalizes d_logits using the kept counts from the top-k pass —
        // confirm against invokeSoftmax's parameter order.
        invokeSoftmax(d_logits, vocab_size, vocab_size, batch_size, d_kept, stream);
        params2.workspace      = d_ws_topp;
        params2.logits         = d_logits;
        params2.sorted_logits  = d_sorted_logits;
        params2.sorted_indices = d_sorted_indices;
        params2.kept           = d_kept;
        params2.top_ks         = d_top_ks;
        params2.top_ps         = d_top_ps;
        invokeTopPSort(params2, stream);

        // outputs
        // NOTE(review): the next two copies pass no stream, unlike the rest; they rely on
        // cudaAutoCpy's default stream argument for ordering after the kernels above — confirm.
        cudaAutoCpy(output_logits.data(), d_sorted_logits, batch_size * vocab_size);
        cudaAutoCpy(output_indices.data(), d_sorted_indices, batch_size * vocab_size);
        cudaAutoCpy(output_kept.data(), d_kept, batch_size, stream);
        cudaStreamSynchronize(stream);

        // cpu
        filterCpu(batch_size,
                  top_ks,
                  top_ps,
                  nullptr,
                  logits.data(),
                  expected_logits.data(),
                  expected_indices.data(),
                  expected_kept.data(),
                  vocab_size);

        EXPECT_TRUE(checkSorted(batch_size,
                                expected_logits.data(),
                                output_logits.data(),
                                expected_indices.data(),
                                output_indices.data(),
                                expected_kept.data(),
                                output_kept.data(),
                                vocab_size));
    }
};

// Instantiate every TopKTopPSortTest case once for each type listed in SamplingTypes.
TYPED_TEST_SUITE(TopKTopPSortTest, SamplingTypes);

TYPED_TEST(TopKTopPSortTest, OnlyTopKBatch)
{
    // k = 1..8 with every top_p at 1.0, over a 20-token vocabulary.
    int   top_ks[] = {1, 2, 3, 4, 5, 6, 7, 8};
    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
    static_assert(sizeof(top_ks) / sizeof(top_ks[0]) == sizeof(top_ps) / sizeof(top_ps[0]),
                  "top_ks and top_ps need one entry per batch row");
    this->runTest(sizeof(top_ks) / sizeof(top_ks[0]), top_ks, top_ps, 20);
}

TYPED_TEST(TopKTopPSortTest, OnlyTopKLargeVocab)
{
    // Power-of-two k values up to 1024 with every top_p at 1.0, over a 32000-token vocabulary.
    int   top_ks[] = {1, 2, 4, 8, 16, 32, 64, 1024};
    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
    static_assert(sizeof(top_ks) / sizeof(top_ks[0]) == sizeof(top_ps) / sizeof(top_ps[0]),
                  "top_ks and top_ps need one entry per batch row");
    this->runTest(sizeof(top_ks) / sizeof(top_ks[0]), top_ks, top_ps, 32000);
}

TYPED_TEST(TopKTopPSortTest, OnlyTopPBatch)
{
    // No top-k (k = 0) with top_p from 0.0 to 1.0, over a 20-token vocabulary.
    int   top_ks[] = {0, 0, 0, 0, 0, 0, 0, 0};
    float top_ps[] = {0.0f, 0.1f, 0.3f, 0.4f, 0.5f, 0.7f, 0.9f, 1.0f};
    static_assert(sizeof(top_ks) / sizeof(top_ks[0]) == sizeof(top_ps) / sizeof(top_ps[0]),
                  "top_ks and top_ps need one entry per batch row");
    this->runTest(sizeof(top_ks) / sizeof(top_ks[0]), top_ks, top_ps, 20);
}

TYPED_TEST(TopKTopPSortTest, OnlyTopPLargeVocab)
{
    // No top-k (k = 0) with top_p from 0.0 to 1.0, over a 32000-token vocabulary.
    int   top_ks[] = {0, 0, 0, 0, 0, 0, 0, 0};
    float top_ps[] = {0.0f, 0.1f, 0.3f, 0.4f, 0.5f, 0.7f, 0.9f, 1.0f};
    static_assert(sizeof(top_ks) / sizeof(top_ks[0]) == sizeof(top_ps) / sizeof(top_ps[0]),
                  "top_ks and top_ps need one entry per batch row");
    this->runTest(sizeof(top_ks) / sizeof(top_ks[0]), top_ks, top_ps, 32000);
}

TYPED_TEST(TopKTopPSortTest, MixedTopKTopP)
{
    // Rows alternate between top-k (k != 0) and top-p only (k = 0), over a 32000-token vocabulary.
    int   top_ks[] = {1, 0, 16, 0, 32, 0, 64, 1024};
    float top_ps[] = {0.0f, 0.1f, 0.0f, 0.4f, 0.5f, 0.7f, 0.9f, 1.0f};
    static_assert(sizeof(top_ks) / sizeof(top_ks[0]) == sizeof(top_ps) / sizeof(top_ps[0]),
                  "top_ks and top_ps need one entry per batch row");
    this->runTest(sizeof(top_ks) / sizeof(top_ks[0]), top_ks, top_ps, 32000);
}

// Builds sorted, softmaxed rows on the CPU, runs invokeTopPMinPFilter on them on the GPU, and
// compares the result with filterCpu applying the top-p and min-p passes.
template
class TopPMinPFilterTest: public SamplingKernelTest {
protected:
    using SamplingKernelTest::stream;
    using SamplingKernelTest::allocator;

public:
    // top_ps / min_ps are host arrays with batch_size entries; logits are batch_size x vocab_size.
    void runTest(int batch_size, float* top_ps, float* min_ps, int vocab_size)
    {

        // host buffer
        std::vector   logits(batch_size * vocab_size);
        std::vector   expected_logits(batch_size * vocab_size);
        std::vector expected_indices(batch_size * vocab_size);
        std::vector expected_kept(batch_size);

        std::vector   output_logits(batch_size * vocab_size);
        std::vector output_indices(batch_size * vocab_size);
        std::vector output_kept(batch_size);

        // device buffer
        T*     d_sorted_logits  = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);
        int*   d_sorted_indices = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);
        int*   d_kept           = (int*)allocator->malloc(sizeof(int) * batch_size);
        float* d_top_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);
        float* d_min_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);

        // Logit range grows with the vocabulary: boundary = 10^(decimal digits of vocab_size - 1).
        float boundary = 1.f;
        for (int x = vocab_size; x >= 10; x /= 10) {
            boundary *= 10;
        }
        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);
        std::fill_n(expected_kept.data(), batch_size, vocab_size);

        // Kernel input: fully sorted + softmaxed rows (top_ks is null and both filter flags are
        // left at their false defaults, so no truncation happens here).
        filterCpu(batch_size,
                  nullptr,
                  top_ps,
                  min_ps,
                  logits.data(),
                  expected_logits.data(),
                  expected_indices.data(),
                  expected_kept.data(),
                  vocab_size);

        // NOTE(review): the next two copies pass no stream, unlike the rest; they rely on
        // cudaAutoCpy's default stream argument — confirm ordering relative to `stream`.
        cudaAutoCpy(d_sorted_logits, expected_logits.data(), batch_size * vocab_size);
        cudaAutoCpy(d_sorted_indices, expected_indices.data(), batch_size * vocab_size);
        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);
        cudaAutoCpy(d_top_ps, top_ps, batch_size, stream);
        cudaAutoCpy(d_min_ps, min_ps, batch_size, stream);

        TopPMinPFilterParams params{};
        params.sorted_logits     = d_sorted_logits;
        params.sorted_indices    = d_sorted_indices;
        params.kept              = d_kept;
        params.top_ps            = d_top_ps;
        params.min_ps            = d_min_ps;
        params.batch_size        = batch_size;
        params.vocab_size        = vocab_size;
        params.vocab_size_padded = vocab_size;
        invokeTopPMinPFilter(params, stream);
        cudaStreamSynchronize(stream);

        // outputs
        cudaAutoCpy(output_logits.data(), d_sorted_logits, batch_size * vocab_size);
        cudaAutoCpy(output_indices.data(), d_sorted_indices, batch_size * vocab_size);
        cudaAutoCpy(output_kept.data(), d_kept, batch_size, stream);
        cudaStreamSynchronize(stream);

        // cpu: recompute from the raw logits, this time with both top-p and min-p passes enabled
        filterCpu(batch_size,
                  nullptr,
                  top_ps,
                  min_ps,
                  logits.data(),
                  expected_logits.data(),
                  expected_indices.data(),
                  expected_kept.data(),
                  vocab_size,
                  true,
                  true);

        EXPECT_TRUE(checkSorted(batch_size,
                                expected_logits.data(),
                                output_logits.data(),
                                expected_indices.data(),
                                output_indices.data(),
                                expected_kept.data(),
                                output_kept.data(),
                                vocab_size));
    }
};

// Instantiate every TopPMinPFilterTest case once for each type listed in SamplingTypes.
TYPED_TEST_SUITE(TopPMinPFilterTest, SamplingTypes);

TYPED_TEST(TopPMinPFilterTest, OnlyTopP)
{
    // top_p from 0.80 to 1.00 with min_p disabled (0), over a 200-token vocabulary.
    float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};
    float min_ps[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
    static_assert(sizeof(top_ps) / sizeof(top_ps[0]) == sizeof(min_ps) / sizeof(min_ps[0]),
                  "top_ps and min_ps need one entry per batch row");
    this->runTest(sizeof(top_ps) / sizeof(top_ps[0]), top_ps, min_ps, 200);
}

TYPED_TEST(TopPMinPFilterTest, OnlyMinP)
{
    // min_p from 0.0 to 0.02 with top_p disabled (1.0), over a 200-token vocabulary.
    float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};
    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
    static_assert(sizeof(top_ps) / sizeof(top_ps[0]) == sizeof(min_ps) / sizeof(min_ps[0]),
                  "top_ps and min_ps need one entry per batch row");
    this->runTest(sizeof(top_ps) / sizeof(top_ps[0]), top_ps, min_ps, 200);
}

TYPED_TEST(TopPMinPFilterTest, MixedTopPMinP)
{
    // Both filters active per row: top_p from 0.80 to 1.00 and min_p from 0.0 to 0.02.
    float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};
    float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};
    static_assert(sizeof(top_ps) / sizeof(top_ps[0]) == sizeof(min_ps) / sizeof(min_ps[0]),
                  "top_ps and min_ps need one entry per batch row");
    this->runTest(sizeof(top_ps) / sizeof(top_ps[0]), top_ps, min_ps, 200);
}

// Typed fixture that checks the GPU sampling kernel (invokeSampling) against the CPU
// reference (sampleCpu) on logits that were already sorted/softmaxed by filterCpu.
template
class SamplingTest: public SamplingKernelTest {
protected:
    using SamplingKernelTest::stream;
    using SamplingKernelTest::allocator;

public:
    // Runs one comparison.
    //   batch_size   - number of rows to sample.
    //   vocab_size   - number of logits per row.
    //   top_logprobs - NOTE(review): not referenced anywhere in this body; callers pass 5 / 1024
    //                  but it has no effect. Confirm whether it should feed kMaxLogProb-related checks.
    void runTest(int batch_size, int vocab_size, int top_logprobs)
    {

        // host buffer
        std::vector     logits(batch_size * vocab_size);
        std::vector     expected_logits(batch_size * vocab_size);
        std::vector   expected_indices(batch_size * vocab_size);
        std::vector   expected_kept(batch_size);
        std::vector   expected_output_ids(batch_size);
        std::vector uniforms(batch_size);

        // CPU-reference log-prob outputs, kMaxLogProb slots per row.
        std::vector   sampled_logprobs(batch_size * kMaxLogProb);
        std::vector sampled_indexes(batch_size * kMaxLogProb);
        std::vector sampled_nums(batch_size);

        // GPU results copied back to the host (the commented-out buffers are unused leftovers).
        // std::vector     output_logits(batch_size * vocab_size);
        // std::vector   output_indices(batch_size * vocab_size);
        // std::vector   output_kept(batch_size);
        std::vector output_ids(batch_size);
        std::vector   output_sampled_logprobs(batch_size * kMaxLogProb);
        std::vector output_sampled_indexes(batch_size * kMaxLogProb);
        std::vector output_sampled_nums(batch_size);

        // device buffer
        T*             d_sorted_logits    = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);
        int*           d_sorted_indices   = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);
        int*           d_kept             = (int*)allocator->malloc(sizeof(int) * batch_size);
        float*         d_top_ps           = (float*)allocator->malloc(sizeof(float) * batch_size);
        float*         d_min_ps           = (float*)allocator->malloc(sizeof(float) * batch_size);
        float*         d_uniforms         = (float*)(allocator->malloc(sizeof(float) * batch_size));
        int*           d_output_ids       = (int*)(allocator->malloc(sizeof(int) * batch_size));
        T*             d_sampled_logprobs = (T*)(allocator->malloc(sizeof(T) * batch_size * kMaxLogProb));
        int*           d_sampled_indexes  = (int*)(allocator->malloc(sizeof(int) * batch_size * kMaxLogProb));
        int*           d_sampled_nums     = (int*)(allocator->malloc(sizeof(int) * batch_size));
        // The trailing `false` argument differs from the other allocations; presumably it skips
        // zero-initialisation because the states are initialised below — TODO confirm.
        curandState_t* curand_states =
            reinterpret_cast(allocator->malloc(sizeof(curandState_t) * batch_size, false));

        // Random logits in [-boundary, boundary], where boundary = 10^(number of decimal digits
        // of vocab_size - 1), e.g. 10 for vocab 20 and 1000 for vocab 9700.
        float boundary = 1.f;
        for (int x = vocab_size; x >= 10; x /= 10) {
            boundary *= 10;
        }
        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);
        // Every row starts with all vocab_size candidates kept.
        std::fill_n(expected_kept.data(), batch_size, vocab_size);

        // sort & softmax
        // (no top-p / min-p tables are passed, so no probability threshold is applied here)
        filterCpu(batch_size,
                  nullptr,
                  nullptr,
                  nullptr,
                  logits.data(),
                  expected_logits.data(),
                  expected_indices.data(),
                  expected_kept.data(),
                  vocab_size);

        // Upload the CPU-sorted data as the kernel input.
        // NOTE(review): the first two copies omit `stream` while the d_kept copy passes it;
        // presumably they run on the default stream — confirm ordering relative to `stream`.
        cudaAutoCpy(d_sorted_logits, expected_logits.data(), batch_size * vocab_size);
        cudaAutoCpy(d_sorted_indices, expected_indices.data(), batch_size * vocab_size);
        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);

        // uniforms
        // Initialise the curand states, draw one uniform per row and copy those draws to the host
        // for the CPU reference. Then re-initialise the states with identical arguments so the GPU
        // sampling kernel consumes the same random numbers.
        for (int i = 0; i < batch_size; i++) {
            invokeCurandInitialize(curand_states + i, 1, i, stream);
        }
        get_curand_uniform<<>>(curand_states, d_uniforms, batch_size);
        cudaAutoCpy(uniforms.data(), d_uniforms, batch_size, stream);
        for (int i = 0; i < batch_size; i++) {
            invokeCurandInitialize(curand_states + i, 1, i, stream);
        }

        // sample
        SamplingParams params{};
        params.logits           = d_sorted_logits;
        params.stride           = vocab_size;
        params.indices          = d_sorted_indices;
        params.kept             = d_kept;
        params.curandstate      = curand_states;
        params.batch_size       = batch_size;
        params.output_ids       = d_output_ids;
        params.sequence_length  = nullptr;
        params.sampled_logprobs = d_sampled_logprobs;
        params.sampled_indexes  = (uint32_t*)d_sampled_indexes;
        params.sampled_nums     = (uint32_t*)d_sampled_nums;
        invokeSampling(params, stream);

        // outputs
        // Copy the GPU results back; the synchronize makes the host buffers valid below.
        cudaAutoCpy(output_ids.data(), d_output_ids, batch_size, stream);
        cudaAutoCpy(output_sampled_logprobs.data(), d_sampled_logprobs, batch_size * kMaxLogProb, stream);
        cudaAutoCpy(output_sampled_indexes.data(), d_sampled_indexes, batch_size * kMaxLogProb, stream);
        cudaAutoCpy(output_sampled_nums.data(), d_sampled_nums, batch_size, stream);
        cudaStreamSynchronize(stream);

        // CPU reference sampling, driven by the same uniforms the GPU consumed.
        sampleCpu(batch_size,
                  vocab_size,
                  expected_logits.data(),
                  expected_indices.data(),
                  expected_kept.data(),
                  uniforms.data(),
                  expected_output_ids.data(),
                  sampled_logprobs.data(),
                  sampled_indexes.data(),
                  sampled_nums.data());

        EXPECT_TRUE(checkSample(expected_output_ids.data(),
                                output_ids.data(),
                                batch_size,
                                sampled_logprobs.data(),
                                sampled_indexes.data(),
                                sampled_nums.data(),
                                output_sampled_logprobs.data(),
                                output_sampled_indexes.data(),
                                output_sampled_nums.data()));
    }
};

// Instantiate every SamplingTest case once per element type listed in SamplingTypes.
TYPED_TEST_SUITE(SamplingTest, SamplingTypes);

TYPED_TEST(SamplingTest, Single)
{
    // A single row over a 20-entry vocabulary.
    const int rows         = 1;
    const int vocab        = 20;
    const int top_logprobs = 5;
    this->runTest(rows, vocab, top_logprobs);
}

TYPED_TEST(SamplingTest, Batch)
{
    // 32 rows over a 9700-entry vocabulary.
    const int rows         = 32;
    const int vocab        = 9700;
    const int top_logprobs = 1024;
    this->runTest(rows, vocab, top_logprobs);
}

}  // end of namespace


================================================
FILE: tests/csrc/unittests/test_sampling_layer.cu
================================================
#include   // std::min, std::max
#include    // snprintf
#include      // expf, log
#include    // rand
#include      // std::string
#include      // std::vector

#include 
#include 
#include 

#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/layers/DynamicDecodeLayer.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

#include "gtest_utils.h"

using namespace turbomind;

struct SamplingLayerTestParam {
    size_t batch_size;
    size_t vocab_size;
    size_t beam_width;
    size_t top_k;
    float  top_p;
    size_t output_len;

    // Returns a one-line, human-readable summary of the parameters.
    // Every count field is size_t, so it is formatted with %zu. The previous %ld is undefined
    // behaviour where long is narrower than size_t (e.g. 64-bit Windows).
    std::string toString() const
    {
        return fmtstr("SamplingLayerTestParam[batch=%zu, vocab=%zu, beam=%zu, k=%zu, p=%3.1f, output_len=%zu]",
                      batch_size,
                      vocab_size,
                      beam_width,
                      top_k,
                      top_p,
                      output_len);
    }
};

template
void computeProb(T* probs, T* logits, int batch_size, int vocab_size)
{
    // Compute the log probability from logits.
    //   logits = batch_size x vocab_size vector.
    //   logprobs = log(softmax(logits)) (softmax along with vocab dimension)
    for (int bidx = 0; bidx < batch_size; ++bidx) {
        float sum = 0.0f;
        for (int i = 0; i < vocab_size; ++i) {
            sum += expf((float)logits[bidx * vocab_size + i]);
        }
        for (int i = 0; i < vocab_size; ++i) {
            int idx    = bidx * vocab_size + i;
            probs[idx] = static_cast(expf((float)logits[idx]) / (sum + EPSILON));
        }
    }
}

template
void computeLogProb(T* logprobs, T* logits, int batch_size, int vocab_size)
{
    // Compute the log probability from logits.
    //   logits = batch_size x vocab_size vector.
    //   logprobs = log(softmax(logits)) (softmax along with vocab dimension)
    for (int bidx = 0; bidx < batch_size; ++bidx) {
        float sum = 0.0f;
        for (int i = 0; i < vocab_size; ++i) {
            sum += expf(logits[bidx * vocab_size + i]);
        }
        for (int i = 0; i < vocab_size; ++i) {
            int idx       = bidx * vocab_size + i;
            logprobs[idx] = static_cast(logf(expf(logits[idx]) / (sum + EPSILON) + EPSILON));
        }
    }
}

// Typed fixture that drives DynamicDecodeLayer<T> for max_output_len steps on a fixed set of
// per-step logits and checks that every generated id is one of the ids allowed by the test.
template<typename T>
class SamplingDecodeTest: public testing::Test {
protected:
    unsigned long long              seed           = 0;
    const static unsigned long long max_seed       = 30;
    const size_t                    batch_size     = 6;
    const size_t                    beam_width     = 1;
    const size_t                    batchxbeam     = batch_size * beam_width;
    const size_t                    vocab_size     = 8;
    const size_t                    max_input_len  = 0;  // has no effect.
    const size_t                    max_output_len = 3;
    const size_t                    max_seq_len    = max_input_len + max_output_len;
    const int                       end_id         = vocab_size - 1;
    const DataType                  data_type      = getTensorType<T>();

    // Per-step logits: 3 steps x vocab size 8, stored step by step.
    T* test_input_logits;

    cudaStream_t                            stream;
    ft::Allocator<ft::AllocatorType::CUDA>* allocator;
    // Device properties and the cuBLAS algo map are members, not setup() locals: their addresses
    // are handed to the decode layer / cuBLAS wrapper, which outlive setup().
    struct cudaDeviceProp                   prop;
    cublasHandle_t                          cublas_handle;
    cublasLtHandle_t                        cublaslt_handle;
    cublasAlgoMap*                          cublas_algo_map;
    std::mutex*                             cublas_wrapper_mutex;
    cublasMMWrapper*                        cublas_wrapper;
    DynamicDecodeLayer<T>*                  dynamic_decode_layer;

    int*   h_output_ids;
    T*     h_logits;
    T*     h_probs;
    T*     h_log_probs;
    float* h_cum_log_probs;
    float* h_output_log_probs;

    T*                  d_logits;
    int*                d_input_lengths;
    float*              d_cum_log_probs;
    float*              d_output_log_probs;
    int*                d_output_ids;
    int*                d_end_ids;
    curandState_t*      d_curand_state;
    unsigned long long* d_random_seed;

    // Creates the stream, allocator, cuBLAS handles and decode layer for one run, then allocates
    // and zero-fills the device buffers. Everything created here is released by teardown().
    void setup(unsigned long long seed = 0)
    {
        this->seed = seed;

        check_cuda_error(cudaStreamCreate(&stream));
        allocator = new Allocator<AllocatorType::CUDA>(getDevice());
        allocator->setStream(stream);

        check_cuda_error(cudaGetDeviceProperties(&prop, 0));
        check_cuda_error(cublasCreate(&cublas_handle));
        check_cuda_error(cublasLtCreate(&cublaslt_handle));
        check_cuda_error(cublasSetStream(cublas_handle, stream));
        // Heap-allocated: the wrapper receives a pointer to it, so it must outlive this function.
        cublas_algo_map      = new cublasAlgoMap(GEMM_CONFIG);
        cublas_wrapper_mutex = new std::mutex();

        cublas_wrapper = new cublasMMWrapper(
            cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, allocator);

        dynamic_decode_layer = new DynamicDecodeLayer<T>(vocab_size,
                                                         vocab_size,
                                                         stream,
                                                         cublas_wrapper,
                                                         allocator,
                                                         false,   // is_free_buffer_after_forward
                                                         &prop);  // cuda_device_prop

        h_output_ids       = new int[batchxbeam];
        h_logits           = new T[batchxbeam * vocab_size];
        h_probs            = new T[batchxbeam * vocab_size];
        h_log_probs        = new T[batchxbeam * vocab_size];
        h_cum_log_probs    = new float[batchxbeam];
        h_output_log_probs = new float[max_output_len * batchxbeam];

        // prob = (0.4, 0.3, 0.2, 0.1, ...)
        test_input_logits = new T[24]{
            -0.9163,  -1.2040,  -1.6094,  -2.3026,  -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX,  // step 0
            -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163,  -1.2040,  -1.6094,  -2.3026,   // step 1
            -FLT_MAX, -FLT_MAX, -0.9163,  -1.2040,  -1.6094,  -2.3026,  -FLT_MAX, -FLT_MAX   // step 2
        };

        d_logits        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batchxbeam * vocab_size, true));
        d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));
        d_cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));
        // Sized with max_seq_len because the "output_log_probs" tensor describes
        // max_seq_len rows (identical to max_output_len while max_input_len == 0).
        d_output_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * max_seq_len * batchxbeam));
        d_output_ids       = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam));
        d_end_ids          = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));
        d_curand_state     = reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * batch_size));
        d_random_seed =
            reinterpret_cast<unsigned long long*>(allocator->malloc(sizeof(unsigned long long) * batch_size));

        // Init by zero.
        check_cuda_error(cudaMemset(d_cum_log_probs, 0, sizeof(float) * batchxbeam));
        check_cuda_error(cudaMemset(d_output_log_probs, 0, sizeof(float) * max_seq_len * batchxbeam));
        check_cuda_error(cudaMemset(d_output_ids, 0, sizeof(int) * max_seq_len * batchxbeam));
        check_cuda_error(cudaMemset(d_random_seed, 0, sizeof(unsigned long long) * batch_size));
        invokeCurandBatchInitialize(d_curand_state, batch_size, d_random_seed, stream);
        deviceFill(d_end_ids, batchxbeam, end_id, stream);
    }

    // Releases everything created by setup(), in reverse dependency order.
    void teardown()
    {
        delete[] test_input_logits;
        delete[] h_output_ids;
        delete[] h_logits;
        delete[] h_probs;
        delete[] h_log_probs;
        delete[] h_cum_log_probs;
        delete[] h_output_log_probs;
        delete dynamic_decode_layer;
        delete cublas_wrapper;
        delete cublas_wrapper_mutex;
        delete cublas_algo_map;
        delete allocator;
        check_cuda_error(cublasDestroy(cublas_handle));
        check_cuda_error(cublasLtDestroy(cublaslt_handle));
        check_cuda_error(cudaStreamDestroy(stream));
    }

    // Builds the input TensorMap for one run. The optional sampling knobs are inserted only when
    // non-null. The caller owns (and deletes) the returned map; the tensors point at buffers owned
    // by the caller or by this fixture.
    TensorMap* createInputTensors(
        int* topk, size_t topk_size, float* topp, size_t topp_size, float* temperature, float* repetition_penalty)
    {
        // construct common input tensors
        TensorMap* input_tensors = new TensorMap();
        if (topk != nullptr) {
            input_tensors->insert({"runtime_top_k", {MEMORY_CPU, TYPE_INT32, {topk_size}, topk}});
        }
        if (topp != nullptr) {
            input_tensors->insert({"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {topp_size}, topp}});
        }
        if (temperature != nullptr) {
            input_tensors->insert({"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, temperature}});
        }
        if (repetition_penalty != nullptr) {
            input_tensors->insert({"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, repetition_penalty}});
        }
        // d_logits holds elements of type T, so describe it with data_type rather than a fixed FP32.
        input_tensors->insert(
            {"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}});
        input_tensors->insert({"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}});
        // NOTE(review): max_input_len is size_t but tagged INT32; presumably only the low 4 bytes are
        // read (fine on little-endian hosts) — confirm.
        input_tensors->insert({"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}});
        input_tensors->insert(
            {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, d_input_lengths}});
        // The end ids live in device memory (filled by deviceFill in setup()). Pass that buffer,
        // not the address of the host-side pointer variable, which would be read as batchxbeam ints.
        input_tensors->insert({"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batchxbeam}, d_end_ids}});
        input_tensors->insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, {1}, &seed}});
        return input_tensors;
    }

    // Builds the output TensorMap for one run. The caller owns (and deletes) the returned map.
    // "finished" and "sequence_length" are passed as null buffers.
    TensorMap* createOutputTensors()
    {
        // construct common output tensors
        TensorMap* output_tensors = new TensorMap();
        output_tensors->insert(
            {"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}});
        output_tensors->insert({"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}});
        output_tensors->insert(
            {"cum_log_probs", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width}, d_cum_log_probs}});
        output_tensors->insert(
            {"output_log_probs",
             Tensor{MEMORY_GPU, TYPE_FP32, {max_seq_len, batch_size, beam_width}, d_output_log_probs}});
        output_tensors->insert({"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}});
        output_tensors->insert({"curand_state"}, {MEMORY_GPU, TYPE_VOID, {batch_size}, d_curand_state});
        return output_tensors;
    }

    // Copies the same n-element host row `src` into each of the m consecutive rows of `dst`.
    void batchH2Dcpy(T* dst, T* src, size_t m, size_t n)
    {
        for (size_t i = 0; i < m; ++i) {
            cudaH2Dcpy(dst + i * n, src, n);
        }
    }

    // Copies the generated ids to the host and checks that every (step, batch) entry is one of the
    // allowed ids in expected_ids (indexed step-major: i = step * batchxbeam + batch).
    // Logs up to 10 mismatches at debug level and returns true when there are none.
    bool checkResult(int* d_output_ids, const std::vector<std::set<int>>& expected_ids)
    {
        const size_t num_ids = max_seq_len * batchxbeam;
        assert(expected_ids.size() == num_ids);
        std::vector<int> actual(num_ids);
        cudaD2Hcpy(actual.data(), d_output_ids, num_ids);
        int failures = 0;
        for (size_t i = 0; i < num_ids; ++i) {
            const size_t         s     = i / batchxbeam;
            const size_t         b     = i % batchxbeam;
            const std::set<int>& expts = expected_ids.at(i);  // by reference: no per-entry copy
            if (expts.count(actual[i]) == 0) {
                if (failures < 10) {
                    std::stringstream ss;
                    ss << " - Fail "
                       << " (step=" << s << ", batch=" << b << ") "
                       << "actual=" << actual[i] << ", expected";
                    for (auto& expt : expts) {
                        ss << " " << expt;
                    }
                    TM_LOG_DEBUG("%s", ss.str().c_str());
                }
                ++failures;
            }
        }
        // %d requires an int; the total is a size_t, so convert it explicitly.
        TM_LOG_DEBUG("check...%6s : failures: %d / %d",
                     failures == 0 ? "....OK" : "FAILED",
                     failures,
                     static_cast<int>(num_ids));
        return failures == 0;
    }

public:
    // For every seed in [0, max_seed): set up the fixture, run the decode layer over all steps with
    // the given sampling knobs (null pointers mean "not provided"), check the generated ids against
    // expected_output_ids (max_seq_len * batchxbeam allowed-id sets, step-major), then tear down.
    // use_local_batch runs with local_batch_size = batch_size / 3 and ite = 1.
    void runTest(const std::vector<std::set<int>>& expected_output_ids,
                 int*                              top_ks,
                 size_t                            top_k_size,
                 float*                            top_ps,
                 size_t                            top_p_size,
                 float*                            temperature,
                 float*                            repetition_penalty,
                 bool                              use_local_batch = false)
    {
        size_t local_batch_size = use_local_batch ? batch_size / 3 : batch_size;
        uint   ite              = use_local_batch ? 1 : 0;
        for (unsigned long long seed = 0; seed < max_seed; ++seed) {
            this->setup(seed);
            size_t     step = max_input_len;
            TensorMap* input_tensors =
                createInputTensors(top_ks, top_k_size, top_ps, top_p_size, temperature, repetition_penalty);
            input_tensors->insert({"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}});
            input_tensors->insert({"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}});
            input_tensors->insert({"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}});
            TensorMap* output_tensors = createOutputTensors();

            dynamic_decode_layer->setup(batch_size, beam_width, input_tensors);
            for (step = max_input_len; step < max_output_len; ++step) {
                // Reset by the test value since the sampling layer internally update the logit buffer.
                batchH2Dcpy(input_tensors->at("logits").getPtr<T>(),
                            test_input_logits + step * vocab_size,
                            batchxbeam,
                            vocab_size);
                dynamic_decode_layer->forward(output_tensors, input_tensors);
            }
            bool passed = checkResult(d_output_ids, expected_output_ids);
            EXPECT_TRUE(passed) << "Failed at seed " << seed;
#ifndef NDEBUG
            if (!passed) {
                TM_LOG_ERROR("actual output ids");
                printMatrix(d_output_ids, max_seq_len, batch_size, batch_size, true);
            }
#endif
            delete output_tensors;
            delete input_tensors;
            this->teardown();
        }
    }
};

// Instantiate every SamplingDecodeTest case once per element type listed in SamplingTypes.
TYPED_TEST_SUITE(SamplingDecodeTest, SamplingTypes);

TYPED_TEST(SamplingDecodeTest, TopK)
{
    // A single top_k = 2 for all rows: at each step any row may emit either of
    // the two highest-probability ids of that step.
    int k = 2;

    const std::set<int> allowed_step0{0, 1};
    const std::set<int> allowed_step1{4, 5};
    const std::set<int> allowed_step2{2, 3};

    // Expected ids are laid out step-major: batch_size identical entries per step.
    std::vector<std::set<int>> expected;
    for (const std::set<int>* allowed : {&allowed_step0, &allowed_step1, &allowed_step2}) {
        expected.insert(expected.end(), this->batch_size, *allowed);
    }

    this->runTest(expected, &k, 1, nullptr, 0, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, BatchTopK)
{
    // Per-row top_k: rows with k = 2 may emit either of the two most likely ids,
    // rows with k = 1 must emit the most likely one.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<int> top_ks{2, 1, 1, 2, 1, 1};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        //  0    1    2       3    4    5
        {0, 1},
        {0},
        {0},
        {0, 1},
        {0},
        {0},  // step 0
        {4, 5},
        {4},
        {4},
        {4, 5},
        {4},
        {4},  // step 1
        {2, 3},
        {2},
        {2},
        {2, 3},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), nullptr, 0, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, TopP)
{
    // A single top_p = 0.3 for all rows: the most likely id of each step already
    // covers that mass, so every row must emit exactly that id.
    float p = 0.3;

    const std::set<int> allowed_step0{0};
    const std::set<int> allowed_step1{4};
    const std::set<int> allowed_step2{2};

    // Expected ids are laid out step-major: batch_size identical entries per step.
    std::vector<std::set<int>> expected;
    for (const std::set<int>* allowed : {&allowed_step0, &allowed_step1, &allowed_step2}) {
        expected.insert(expected.end(), this->batch_size, *allowed);
    }

    this->runTest(expected, nullptr, 0, &p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, BatchTopP)
{
    // Per-row top_p: rows with p = 0.3 must emit the most likely id, rows with p = 0.5
    // may emit either of the two most likely ids.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<float> top_ps{0.3f, 0.5f, 0.5f, 0.3f, 0.5f, 0.5f};
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        {0},
        {0, 1},
        {0, 1},
        {0},
        {0, 1},
        {0, 1},  // step 0
        {4},
        {4, 5},
        {4, 5},
        {4},
        {4, 5},
        {4, 5},  // step 1
        {2},
        {2, 3},
        {2, 3},
        {2},
        {2, 3},
        {2, 3}  // step 2
    };
    this->runTest(expected_output_ids, nullptr, 0, top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, TopKTopP)
{
    // top_k = 2 combined with top_p = 0.3 for all rows: only the most likely id of
    // each step is allowed.
    int   k = 2;
    float p = 0.3;

    const std::set<int> allowed_step0{0};
    const std::set<int> allowed_step1{4};
    const std::set<int> allowed_step2{2};

    // Expected ids are laid out step-major: batch_size identical entries per step.
    std::vector<std::set<int>> expected;
    for (const std::set<int>* allowed : {&allowed_step0, &allowed_step1, &allowed_step2}) {
        expected.insert(expected.end(), this->batch_size, *allowed);
    }

    this->runTest(expected, &k, 1, &p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, BatchTopKTopP)
{
    // Per-row top_k combined with a shared top_p = 0.3: every row must emit the most likely id.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<int> top_ks{2, 2, 1, 2, 2, 1};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    float                      top_p = 0.3f;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), &top_p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, TopKBatchTopP)
{
    // A shared top_k = 2 combined with per-row top_p values.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    int                top_k = 2;
    std::vector<float> top_ps{0.5f, 0.3f, 0.5f, 0.5f, 0.3f, 0.5f};
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0, 1},
        {0},
        {0, 1},
        {0, 1},
        {0},
        {0, 1},  // step 0
        {4, 5},
        {4},
        {4, 5},
        {4, 5},
        {4},
        {4, 5},  // step 1
        {2, 3},
        {2},
        {2, 3},
        {2, 3},
        {2},
        {2, 3}  // step 2
    };
    this->runTest(expected_output_ids, &top_k, 1, top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, BatchTopKBatchTopP)
{
    // Per-row top_k and per-row top_p. Each row sets exactly one of the two knobs
    // (the other is 0).
    // std::vector instead of new[]/delete[]: the buffers are released even if runTest throws.
    std::vector<int>   top_ks{2, 2, 0, 2, 2, 0};
    std::vector<float> top_ps{0.0f, 0.3f, 0.5f, 0.0f, 0.3f, 0.5f};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0, 1},
        {0},
        {0, 1},
        {0, 1},
        {0},
        {0, 1},  // step 0
        {4, 5},
        {4},
        {4, 5},
        {4, 5},
        {4},
        {4, 5},  // step 1
        {2, 3},
        {2},
        {2, 3},
        {2, 3},
        {2},
        {2, 3}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopK)
{
    // top_k = 0 with no other sampling knob: the expected ids are the argmax of each
    // step, i.e. the layer is expected to fall back to greedy decoding.
    int                        top_k = 0;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, &top_k, 1, nullptr, 0, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopP)
{
    // top_p = 0 with no other sampling knob: the expected ids are the argmax of each
    // step, i.e. the layer is expected to fall back to greedy decoding.
    float                      top_p = 0;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, nullptr, 0, &top_p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopKTopP)
{
    // top_k = 0 and top_p = 0: the expected ids are the argmax of each step,
    // i.e. the layer is expected to fall back to greedy decoding.
    int                        top_k = 0;
    float                      top_p = 0;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, &top_k, 1, &top_p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsZeroBatchTopKTopP)
{
    // All per-row top_k values are 0 and the shared top_p is 0: greedy (argmax) is expected.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<int> top_ks{0, 0, 0, 0, 0, 0};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    float                      top_p = 0;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), &top_p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopKBatchTopP)
{
    // A shared top_k of 0 and all per-row top_p values 0: greedy (argmax) is expected.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    int                top_k = 0;
    std::vector<float> top_ps{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0},
        {0},
        {0},  // step 0
        {4},
        {4},
        {4},
        {4},
        {4},
        {4},  // step 1
        {2},
        {2},
        {2},
        {2},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, &top_k, 1, top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKContainZero)
{
    // Per-row top_k where some rows use 0; those rows are expected to behave greedily.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<int> top_ks{2, 1, 0, 0, 2, 1};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0, 1},
        {0},
        {0},
        {0},
        {0, 1},
        {0},  // step 0
        {4, 5},
        {4},
        {4},
        {4},
        {4, 5},
        {4},  // step 1
        {2, 3},
        {2},
        {2},
        {2},
        {2, 3},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), nullptr, 0, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopPContainZero)
{
    // Per-row top_p where some rows use 0; those rows are expected to behave greedily.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<float> top_ps{0.5f, 0.5f, 0.0f, 0.5f, 0.0f, 0.3f};
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0, 1},
        {0, 1},
        {0},
        {0, 1},
        {0},
        {0},  // step 0
        {4, 5},
        {4, 5},
        {4},
        {4, 5},
        {4},
        {4},  // step 1
        {2, 3},
        {2, 3},
        {2},
        {2, 3},
        {2},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, nullptr, 0, top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKTopPContainZero)
{
    // Per-row top_k (some rows 0) with a shared top_p of 0: rows without a usable
    // knob are expected to behave greedily.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    std::vector<int> top_ks{2, 2, 1, 0, 2, 0};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    float                      top_p = 0.0f;
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0, 1},
        {0, 1},
        {0},
        {0},
        {0, 1},
        {0},  // step 0
        {4, 5},
        {4, 5},
        {4},
        {4},
        {4, 5},
        {4},  // step 1
        {2, 3},
        {2, 3},
        {2},
        {2},
        {2, 3},
        {2}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), &top_p, 1, nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsTopKBatchTopPContainZero)
{
    // A shared top_k of 0 with per-row top_p (some rows 0): rows without a usable
    // knob are expected to behave greedily.
    // std::vector instead of new[]/delete[]: the buffer is released even if runTest throws.
    int                top_k = 0;
    std::vector<float> top_ps{0.0f, 0.3f, 0.5f, 0.0f, 0.3f, 0.5f};
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0, 1},
        {0},
        {0},
        {0, 1},  // step 0
        {4},
        {4},
        {4, 5},
        {4},
        {4},
        {4, 5},  // step 1
        {2},
        {2},
        {2, 3},
        {2},
        {2},
        {2, 3}  // step 2
    };
    this->runTest(expected_output_ids, &top_k, 1, top_ps.data(), top_ps.size(), nullptr, nullptr);
}

TYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKBatchTopPContainZero)
{
    // Per-row top_k and per-row top_p, with some rows setting both to 0 (expected greedy).
    // std::vector instead of new[]/delete[]: the buffers are released even if runTest throws.
    std::vector<int>   top_ks{0, 2, 1, 2, 2, 0};
    std::vector<float> top_ps{0.0f, 0.3f, 0.9f, 0.0f, 0.3f, 0.5f};
    ASSERT_EQ(top_ks.size(), this->batch_size);
    ASSERT_EQ(top_ps.size(), this->batch_size);
    std::vector<std::set<int>> expected_output_ids{
        // batch
        {0},
        {0},
        {0},
        {0, 1},
        {0},
        {0, 1},  // step 0
        {4},
        {4},
        {4},
        {4, 5},
        {4},
        {4, 5},  // step 1
        {2},
        {2},
        {2},
        {2, 3},
        {2},
        {2, 3}  // step 2
    };
    this->runTest(expected_output_ids, top_ks.data(), top_ks.size(), top_ps.data(), top_ps.size(), nullptr, nullptr);
}

// Typed fixture for DynamicDecodeLayer sampling tests that need a cuBLAS wrapper.
// SetUp() creates cuBLAS / cuBLASLt handles bound to the base fixture's stream plus a
// cublasMMWrapper; TearDown() destroys them in reverse order.
template
class SamplingDecodeTest2: public FtTestBase {

public:
    void SetUp() override
    {
        FtTestBase::SetUp();
        check_cuda_error(cudaGetDeviceProperties(&prop, 0));
        check_cuda_error(cublasCreate(&cublas_handle));
        check_cuda_error(cublasLtCreate(&cublaslt_handle));
        check_cuda_error(cublasSetStream(cublas_handle, stream));
        cublas_algo_map      = new cublasAlgoMap("");
        cublas_wrapper_mutex = new std::mutex();
        cublas_wrapper       = new cublasMMWrapper(
            cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, allocator);
    }
    void TearDown() override
    {
        delete cublas_wrapper;
        delete cublas_wrapper_mutex;
        delete cublas_algo_map;
        check_cuda_error(cublasLtDestroy(cublaslt_handle));
        check_cuda_error(cublasDestroy(cublas_handle));
        FtTestBase::TearDown();
    }

protected:
    using FtTestBase::stream;
    using FtTestBase::allocator;

    struct cudaDeviceProp prop;
    cublasHandle_t        cublas_handle;
    cublasLtHandle_t      cublaslt_handle;
    cublasAlgoMap*        cublas_algo_map;
    std::mutex*           cublas_wrapper_mutex;
    cublasMMWrapper*      cublas_wrapper;

    DataType data_type = getTensorType();

    size_t batch_size;
    size_t beam_width;
    size_t batchxbeam;
    size_t vocab_size;
    size_t max_input_len;
    size_t max_output_len;
    size_t max_seq_len;

    uint  top_k;
    float top_p;
    float temperature;
    float repetition_penalty;
    int   end_id;

    // Host buffers. Only h_logits and h_output_ids are allocated (by setup()); the
    // others are declared but runCumLogProbTest() uses same-named locals instead.
    T*     h_logits;
    T*     h_probs;
    T*     h_log_probs;
    float* h_cum_log_probs;
    float* h_output_log_probs;
    int*   h_output_ids;

    // Device buffers obtained from `allocator` in setup(); not freed explicitly here.
    T*                  d_logits;
    int*                d_input_lengths;
    float*              d_cum_log_probs;
    float*              d_output_log_probs;
    int*                d_output_ids;
    int*                d_end_ids;
    curandState_t*      d_curand_state;
    unsigned long long* d_random_seed;

    // Reads sizes/sampling knobs from `param`, allocates host and device buffers and
    // zero-initializes the device ones. Pair every call with teardown().
    void setup(SamplingLayerTestParam param)
    {
        batch_size     = param.batch_size;
        beam_width     = param.beam_width;
        batchxbeam     = batch_size * param.beam_width;
        vocab_size     = param.vocab_size;
        max_input_len  = 0;
        max_output_len = param.output_len;
        max_seq_len    = max_input_len + max_output_len;

        top_k = param.top_k;
        top_p = param.top_p;
        // use default values having no effect.
        temperature        = 1.0f;
        repetition_penalty = 1.0f;
        end_id             = 0;

        h_logits     = new T[batchxbeam * vocab_size];
        h_output_ids = new int[batchxbeam];

        d_logits        = reinterpret_cast(allocator->malloc(sizeof(T) * batchxbeam * vocab_size));
        d_input_lengths = reinterpret_cast(allocator->malloc(sizeof(int) * batchxbeam));
        d_output_ids    = reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam));
        d_end_ids       = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size));
        d_curand_state  = reinterpret_cast(allocator->malloc(sizeof(curandState_t) * batch_size));
        d_random_seed =
            reinterpret_cast(allocator->malloc(sizeof(unsigned long long) * batch_size));

        // Init by zero.
        deviceFill(d_input_lengths, batchxbeam, 0, stream);
        deviceFill(d_output_ids, max_seq_len * batchxbeam, 0, stream);
        deviceFill(d_end_ids, batch_size, end_id);
        cudaMemset(d_random_seed, 0, sizeof(unsigned long long) * batch_size);
    }

    // Releases the host buffers allocated by setup().
    void teardown()
    {
        delete[] h_logits;
        delete[] h_output_ids;
    }

    // Checks that rows sharing a random seed sample identical ids. Seeds are assigned as
    // i / period_size (groups of 3 rows share one seed) and every row gets the same logits,
    // so within each group all sampled ids must match at every step.
    void runCurandTest(SamplingLayerTestParam param, bool use_local_batch, bool use_single_random_seed)
    {
        setup(param);
        const DataType data_type = getTensorType();

        const size_t local_batch_size = use_local_batch ? 3 : batch_size;
        assert(batch_size % local_batch_size == 0);

        DynamicDecodeLayer* dynamic_decode_layer = new DynamicDecodeLayer(vocab_size,
                                                                                vocab_size,
                                                                                stream,
                                                                                cublas_wrapper,
                                                                                allocator,
                                                                                false,   // is_free_buffer_after_forward
                                                                                &prop);  // cuda_device_prop

        // Prepare decoding arguments
        const size_t        random_seed_size = use_single_random_seed ? 1 : batch_size;
        const size_t        period_size      = 3;
        unsigned long long* random_seed      = new unsigned long long[random_seed_size];
        for (size_t i = 0; i < random_seed_size; ++i) {
            random_seed[i] = i / period_size;
        }
        cudaH2Dcpy(d_random_seed, random_seed, random_seed_size);
        if (use_single_random_seed) {
            invokeCurandInitialize(d_curand_state, batch_size, random_seed[0], stream);
        }
        else {
            invokeCurandBatchInitialize(d_curand_state, batch_size, d_random_seed, stream);
        }
        sync_check_cuda_error();

        TensorMap runtime_args;
        runtime_args.insert({"random_seed", Tensor(MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed)});
        runtime_args.insert({"runtime_top_k", Tensor(MEMORY_CPU, TYPE_UINT32, {1}, &top_k)});
        runtime_args.insert({"runtime_top_p", Tensor(MEMORY_CPU, TYPE_FP32, {1}, &top_p)});
        dynamic_decode_layer->setup(batch_size, beam_width, &runtime_args);

        for (size_t step = max_input_len; step < max_output_len; ++step) {
            const size_t iteration_num = batch_size / local_batch_size;
            // Same logits for every batch row: only the random state can make rows differ.
            initRandom(h_logits, beam_width * vocab_size, -3.0f, 3.0f);
            tile(h_logits, batch_size, beam_width * vocab_size);
            cudaH2Dcpy(d_logits, h_logits, batchxbeam * vocab_size);

            for (uint ite = 0; ite < iteration_num; ++ite) {
                // NOTE(review): step / max_input_len / local_batch_size are size_t but tagged
                // TYPE_INT32; this relies on little-endian layout — confirm intended.
                TensorMap dynamic_decode_input_tensors(
                    {{"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}},
                     {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}},
                     {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}},
                     {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}},
                     {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, d_input_lengths}},
                     {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}},
                     {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}},
                     {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, d_end_ids}},
                     {"random_seed", {MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed}},
                     {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},
                     {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}}});

                // common outputs
                TensorMap dynamic_decode_output_tensors(
                    {{"output_ids",
                      Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}},
                     {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}},
                     {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}},
                     {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, d_curand_state}}});

                dynamic_decode_layer->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
                sync_check_cuda_error();

                // check results.
                cudaD2Hcpy(h_output_ids,
                           dynamic_decode_output_tensors.at("output_ids").getPtrWithOffset(step * batchxbeam),
                           batchxbeam);
            }
            // The same seed produces the same random number.
            for (size_t i = 0; i + period_size - 1 < batchxbeam; i += period_size) {
                for (size_t j = 1; j < period_size; ++j) {
                    // %zu: step, i and i + j are size_t.
                    EXPECT_TRUE(h_output_ids[i] == h_output_ids[i + j])
                        << fmtstr("Fail at step %zu val[%zu]=%d <> val[%zu]=%d",
                                  step,
                                  i,
                                  h_output_ids[i],
                                  i + j,
                                  h_output_ids[i + j]);
                }
            }
        }
        delete dynamic_decode_layer;
        delete[] random_seed;
        teardown();
    }

    // Runs the decode steps on fixed random logits and checks that the layer's
    // cum_log_probs equal the host-side log-probs summed over the ids it actually sampled.
    void runCumLogProbTest(SamplingLayerTestParam param)
    {
        setup(param);
        unsigned long long     seed                 = 43;
        const DataType         data_type            = getTensorType();
        DynamicDecodeLayer* dynamic_decode_layer = new DynamicDecodeLayer(vocab_size,
                                                                                vocab_size,
                                                                                stream,
                                                                                cublas_wrapper,
                                                                                allocator,
                                                                                false,   // is_free_buffer_after_forward
                                                                                &prop);  // cuda_device_prop

        // Host reference buffers of shape ((batch_size x beam_width) x vocab_size). The member
        // h_logits / h_output_ids buffers come from setup() and are released by teardown().
        T*     h_probs                = new T[batch_size * beam_width * vocab_size];
        T*     h_log_probs            = new T[batch_size * beam_width * vocab_size];
        float* h_cum_log_probs        = new float[batch_size * beam_width];
        float* h_output_log_probs     = new float[max_output_len * batch_size * beam_width];
        float* expected_cum_log_probs = new float[batch_size * beam_width];
        initRandom(h_logits, batch_size * beam_width * vocab_size, -3.0f, 3.0f);
        computeProb(h_probs, h_logits, batch_size * beam_width, vocab_size);
        computeLogProb(h_log_probs, h_logits, batch_size * beam_width, vocab_size);
        std::fill_n(expected_cum_log_probs, batch_size * beam_width, 0);

        // NOTE(review): these allocator-owned device buffers are not freed explicitly here.
        int* tiled_input_lengths_buf = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size * beam_width));
        float* cum_log_probs = reinterpret_cast(allocator->malloc(sizeof(float) * batch_size * beam_width));
        float* output_log_probs =
            reinterpret_cast(allocator->malloc(sizeof(float) * max_output_len * batch_size * beam_width));

        int* output_ids =
            reinterpret_cast(allocator->malloc(sizeof(int) * max_seq_len * batch_size * beam_width));

        int* end_ids = reinterpret_cast(allocator->malloc(sizeof(int) * batch_size));
        deviceFill(end_ids, batch_size, end_id);

        // Init by zero. input_lengths is handed to forward() below, so it must not be left
        // uninitialized (setup() zero-fills its own d_input_lengths the same way).
        cudaMemset(tiled_input_lengths_buf, 0, sizeof(int) * batch_size * beam_width);
        cudaMemset(cum_log_probs, 0, sizeof(float) * batch_size * beam_width);
        cudaMemset(output_log_probs, 0, sizeof(float) * max_output_len * batch_size * beam_width);
        cudaMemset(output_ids, 0, sizeof(int) * max_seq_len * batch_size * beam_width);

        // setup() only allocates d_curand_state; seed it like runCurandTest's single-seed path
        // before it is passed to forward() as "curand_state".
        invokeCurandInitialize(d_curand_state, batch_size, seed, stream);
        sync_check_cuda_error();

        // `seed` is unsigned long long, hence TYPE_UINT64 (matching the forward() inputs below).
        TensorMap input_tensors({{"random_seed", {MEMORY_CPU, TYPE_UINT64, {1}, &seed}},
                                 {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},
                                 {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}},
                                 {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}},
                                 {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}}});
        dynamic_decode_layer->setup(batch_size, beam_width, &input_tensors);

        for (size_t step = max_input_len; step < max_output_len; ++step) {
            uint ite = 0;
            // Reset by the test value since the sampling layer internally update the logit buffer (making it log-prob).
            cudaH2Dcpy(d_logits, h_logits, batch_size * beam_width * vocab_size);
            // d_logits holds T values, so the tensor is tagged with data_type (as in runCurandTest).
            TensorMap dynamic_decode_input_tensors(
                {{"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}},
                 {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}},
                 {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}},
                 {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}},
                 {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf}},
                 {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}},
                 {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &batch_size}},
                 {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}},
                 {"random_seed", {MEMORY_CPU, TYPE_UINT64, {1}, &seed}},
                 {"runtime_top_k", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},
                 {"runtime_top_p", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}},
                 {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}},
                 {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}}});

            // common outputs
            TensorMap dynamic_decode_output_tensors(
                {{"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids}},
                 {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}},
                 {"cum_log_probs", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width}, cum_log_probs}},
                 {"output_log_probs",
                  Tensor{MEMORY_GPU, TYPE_FP32, {max_seq_len, batch_size, beam_width}, output_log_probs}},
                 {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}},
                 {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, d_curand_state}}});

            dynamic_decode_layer->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);

            TM_LOG_DEBUG("Step %2d generated ids", (int)step);
            // h_output_ids is the member buffer from setup() (batch_size * beam_width ints).
            cudaD2Hcpy(
                h_output_ids,
                dynamic_decode_output_tensors.at("output_ids").getPtrWithOffset(step * (batch_size * beam_width)),
                batch_size * beam_width);
            cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size * beam_width);
            cudaD2Hcpy(h_output_log_probs, output_log_probs, max_output_len * batch_size * beam_width);
            for (size_t i = 0; i < batch_size * beam_width; ++i) {
                int idx = i * vocab_size + h_output_ids[i];
                expected_cum_log_probs[i] += (float)h_log_probs[idx];
                TM_LOG_DEBUG("| step %2d batch %2d idx %7d id %6d | log-prob %9.4f (expt: %9.4f) "
                             "| cum-log-prob %9.4f (expt: %9.4f) | prob %9.4e",
                             (int)step,
                             (int)i,
                             (int)idx,
                             (int)h_output_ids[i],
                             h_output_log_probs[step * batch_size * beam_width + i],
                             (float)h_log_probs[idx],
                             h_cum_log_probs[i],
                             expected_cum_log_probs[i],
                             (float)h_probs[idx]);
            }
            TM_LOG_DEBUG("");
        }

        bool passed = checkResult(param.toString(), cum_log_probs, expected_cum_log_probs, batch_size * beam_width);
        EXPECT_TRUE(passed);

        delete[] expected_cum_log_probs;
        delete[] h_output_log_probs;
        delete[] h_cum_log_probs;
        delete[] h_log_probs;
        delete[] h_probs;

        delete dynamic_decode_layer;
        // Release h_logits / h_output_ids allocated by setup().
        teardown();
    }
};

TYPED_TEST_SUITE(SamplingDecodeTest2, SamplingTypes);

// The two trailing arguments of runCurandTest are (use_local_batch, use_single_random_seed).
// NOTE(review): the braced param presumably maps to SamplingLayerTestParam's fields in
// declaration order (batch/vocab/beam/top_k/top_p/output_len?) — confirm against the struct.

TYPED_TEST(SamplingDecodeTest2, CorrectnessSingleRandTopK)
{
    // test TopKSampling
    this->runCurandTest({113, 1201, 1, 3, 1.0f, 5}, false, true);
}

TYPED_TEST(SamplingDecodeTest2, CorrectnessSingleRandTopP)
{
    this->runCurandTest({113, 1201, 1, 0, 1.0f, 5}, false, true);
}

TYPED_TEST(SamplingDecodeTest2, CorrectnessBatchRandTopK)
{
    // test TopKSampling
    this->runCurandTest({113, 1201, 1, 3, 1.0f, 5}, false, false);
}

TYPED_TEST(SamplingDecodeTest2, CorrectnessBatchRandTopP)
{
    this->runCurandTest({113, 1201, 1, 0, 1.0f, 5}, false, false);
}


================================================
FILE: tests/csrc/unittests/unittest_utils.h
================================================
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include   // min, max
#include    // assert
#include     // FLT_MAX
#include    // snprintf
#include      // numeric_limits
#include      // expf, log
#include    // rand
#include      // string
#include      // vector

#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"
#include "src/turbomind/utils/string_utils.h"

// Maximum number of rows / columns printed by printMatrixWithLimit().
#define PRINT_LIMIT 16
// Added to the reference magnitude in checkResult()'s relative-gap sum so the
// division stays finite when the reference value is 0.
#define EPSILON (1e-20)
// FP16 counterpart of EPSILON; not referenced elsewhere in this header.
#define EPSILON_FP16 (1e-10)

// NOTE(review): a using-directive in a header leaks turbomind names into every includer.
using namespace turbomind;

// Exception thrown by this header's EXPECT_TRUE / EXPECT_FALSE macros on failure.
// what() returns "TEST FAIL [<name>] <msg>" (empty for a default-constructed error).
class TestFailureError: public std::exception {
private:
    std::string msg_;

public:
    explicit TestFailureError() = default;
    explicit TestFailureError(const std::string& name, const std::string& msg = ""):
        msg_(fmtstr("TEST FAIL [%s] %s", name.c_str(), msg.c_str()))
    {
    }
    // noexcept replaces the dynamic exception specification `throw()`, which was
    // deprecated in C++11/17 and removed in C++20.
    const char* what() const noexcept override
    {
        return msg_.c_str();
    }
};

// Fatal assertion: if `cond` is false, log the stringified condition with the
// enclosing function name and source location, then throw TestFailureError.
// The expansion is a do/while statement, so `EXPECT_TRUE(x) << msg` does not compile.
// NOTE(review): these names collide with gtest's EXPECT_TRUE/EXPECT_FALSE if both are visible.
#define EXPECT_TRUE(cond)                                                                                              \
    do {                                                                                                               \
        if (!(cond)) {                                                                                                 \
            TM_LOG_ERROR("TEST FAIL [%s]: %s at %s:%d", __func__, #cond, __FILE__, __LINE__);                          \
            throw TestFailureError(__func__);                                                                          \
        }                                                                                                              \
    } while (false)

// Fatal assertion: the mirror of EXPECT_TRUE, failing when `cond` is true.
#define EXPECT_FALSE(cond)                                                                                             \
    do {                                                                                                               \
        if (cond) {                                                                                                    \
            TM_LOG_ERROR("TEST FAIL [%s]: %s at %s:%d", __func__, #cond, __FILE__, __LINE__);                          \
            throw TestFailureError(__func__);                                                                          \
        }                                                                                                              \
    } while (false)

// Returns whether `a` (value under test) is close to `b` (reference), following
// numpy.isclose():  |a - b| <= atol + rtol * |b|.
// The check is asymmetric: only |b| scales the relative tolerance. Defaults are
// numpy.isclose()'s. Two NaNs compare equal; a NaN against a number never does.
// `inline` because this is a non-template function defined in a header: without it,
// including the header from two translation units violates the one-definition rule.
inline bool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8)
{
    if (std::isnan(a) && std::isnan(b)) {
        return true;
    }
    return fabs(a - b) <= (atol + rtol * fabs(b));
}

template
bool checkResult(std::string name, T* out, T* ref, size_t size, float atol, float rtol)
{
    size_t failures     = 0;
    float  relative_gap = 0.0f;
    ;

    for (size_t i = 0; i < size; ++i) {
        // The values for the output and the reference.
        float a = (float)out[i];
        float b = (float)ref[i];

        bool ok = almostEqual(a, b, atol, rtol);
        // Print the error.
        if (!ok && failures < 4) {
            TM_LOG_ERROR(">> invalid result for i=%lu:", i);
            TM_LOG_ERROR(">>    found......: %10.6f", a);
            TM_LOG_ERROR(">>    expected...: %10.6f", b);
            TM_LOG_ERROR(">>    error......: %.6f", fabsf(a - b));
            TM_LOG_ERROR(">>    tol........: %.6f", atol + rtol * fabs(b));
        }
        // Update the number of failures.
        failures += ok ? 0 : 1;
        // Update the relative gap.
        relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON);
    }

    relative_gap /= size;

    // Allow not matched up to 1% elements.
    size_t tol_failures = (size_t)(0.01 * size);
    TM_LOG_INFO("check...%6s : %-50s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)",
                failures <= tol_failures ? "....OK" : "FAILED",
                name.c_str(),
                100. * failures / size,
                atol,
                rtol,
                100. * relative_gap);
    return failures <= tol_failures;
}

template
bool checkResult(std::string name, T* out, T* ref, size_t size, bool device_out = true, bool device_ref = false)
{
    bool  is_fp32 = sizeof(T) == 4;
    float atol    = is_fp32 ? 1e-4f : 1e-3f;
    float rtol    = is_fp32 ? 1e-2f : 1e-1f;

    T* h_out = nullptr;
    if (device_out) {
        h_out = new T[size];
        cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost);
        out = h_out;
    }
    T* h_ref = nullptr;
    if (device_ref) {
        h_ref = new T[size];
        cudaMemcpy(h_ref, ref, sizeof(T) * size, cudaMemcpyDeviceToHost);
        ref = h_ref;
    }
    bool is_ok = checkResult(name, out, ref, size, atol, rtol);
    if (h_out != nullptr) {
        delete[] h_out;
    }
    if (h_ref != nullptr) {
        delete[] h_ref;
    }
    return is_ok;
}

// Fills ptr[0..size) with rand()-driven values spread over [minval, maxval]
// (both ends reachable, since rand() can return RAND_MAX), converted to T.
template<typename T>
void initRandom(T* ptr, size_t size, float minval, float maxval)
{
    const float span = maxval - minval;
    for (T* cursor = ptr; cursor != ptr + size; ++cursor) {
        const float unit = (float)rand() / (float)RAND_MAX;
        *cursor          = (T)(minval + unit * span);
    }
}

// Fills ptr[0..size) with rand()-driven ints in the half-open range [minval, maxval).
// Requires minval < maxval (checked with assert). `inline` because this non-template
// function is defined in a header; without it, multiple includers break the ODR.
inline void initRandomInt(int* ptr, size_t size, int minval, int maxval)
{
    assert(minval < maxval);
    const int mod = maxval - minval;
    for (size_t i = 0; i < size; ++i) {
        ptr[i] = minval + rand() % mod;
    }
}

// In-place tiling of a row-major m x n buffer: row 0 is the source and is copied
// into rows 1..m-1.
template<typename T>
void tile(T* x, int m, int n)
{
    for (int row = 1; row < m; ++row) {
        T* dst = x + row * n;
        for (int col = 0; col < n; ++col) {
            dst[col] = x[col];
        }
    }
}

// Writes m copies of the n-element row `src` into the row-major m x n buffer `dst`.
// Every row, including row 0, is written: the earlier loop started at row 1 (copied from
// the in-place overload), which left dst's first row untouched. When dst == src the
// row-0 copy is a harmless self-assignment.
template<typename T>
void tile(T* dst, T* src, int m, int n)
{
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            dst[i * n + j] = src[j];
        }
    }
}

// Largest finite value of IEEE fp16 (binary16).
#define HALF_FLT_MAX 65504.0f

// Compile-time type check via std::is_same.
// NOTE(review): the template arguments were lost in this copy; presumably this tests
// T against the half type — confirm against the original header.
template
bool isHalf()
{
    return std::is_same::value;
}

template
static inline void printMatrixWithLimit(T* ptr, int m, int k, int stride, bool is_device_ptr)
{
    printMatrix(ptr, std::min(PRINT_LIMIT, m), std::min(PRINT_LIMIT, k), stride, is_device_ptr);
}


================================================
FILE: tests/pytorch/config/test_hf_overrides.py
================================================
import pytest


class TestHFOverrides:
    """Checks that ``override_hf_config`` patches top-level and nested HF config fields."""

    @pytest.fixture
    def hf_config(self):
        from transformers.models.llava import LlavaConfig
        return LlavaConfig()

    def test_hf_overrides(self, hf_config):
        from lmdeploy.pytorch.config import override_hf_config

        # Top-level attribute only.
        assert hf_config.model_type == 'llava'
        override_hf_config(hf_config, {'model_type': 'llava_custom'})
        assert hf_config.model_type == 'llava_custom'

        # Nested attribute: text_config.rope_parameters (named rope_scaling in older transformers).
        assert hf_config.text_config.model_type == 'llama'
        assert hf_config.text_config.rope_parameters['rope_type'] == 'default'
        override_hf_config(hf_config, {'text_config': {'rope_parameters': {'rope_type': 'yarn'}}})
        assert hf_config.text_config.model_type == 'llama'
        assert hf_config.text_config.rope_parameters['rope_type'] == 'yarn'

        # Top-level and nested overrides in a single call.
        combined = {
            'model_type': 'llava_custom2',
            'text_config': {
                'rope_parameters': {
                    'rope_type': 'yarn2'
                }
            },
        }
        override_hf_config(hf_config, combined)
        assert hf_config.model_type == 'llava_custom2'
        assert hf_config.text_config.model_type == 'llama'
        assert hf_config.text_config.rope_parameters['rope_type'] == 'yarn2'


================================================
FILE: tests/pytorch/engine/test_logits_process.py
================================================
# yapf: disable
import torch
from transformers.generation.logits_process import (MinPLogitsWarper, RepetitionPenaltyLogitsProcessor,
                                                    TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)

# yapf: enable


def test_process_temperature():
    """``_process_temperature_`` must agree with HF's TemperatureLogitsWarper row by row."""
    from lmdeploy.pytorch.engine.logits_process import _process_temperature_

    batch_size, num_tokens = 4, 16
    scores = torch.rand(batch_size, num_tokens)
    temperatures = torch.rand(batch_size)

    expected = torch.cat(
        [TemperatureLogitsWarper(temp.item())(None, row[None]) for row, temp in zip(scores, temperatures)])

    out = _process_temperature_(scores, temperatures)
    torch.testing.assert_close(out, expected)


def test_process_bad_words():
    """Banned token ids must be set to -inf while every other logit is left unchanged.

    The previous version only checked the banned positions, so an implementation that masked
    the whole row would still have passed.
    """
    from lmdeploy.pytorch.engine.logits_process import _process_bad_words_

    filter_value: float = -float('inf')
    batch_size = 4
    num_tokens = 16
    scores = torch.rand(batch_size, num_tokens)
    # Negative ids are padding; ``mask`` marks the real entries.
    bad_words = torch.tensor([
        [0, 1],
        [3, -1],
        [4, 4],
        [-1, -1],
    ])
    mask = bad_words >= 0
    # The processor may write into ``scores`` (trailing-underscore naming), so keep a pristine copy.
    ref_scores = scores.clone()

    out_scores = _process_bad_words_(scores, bad_words, mask)

    for score, ref, bw in zip(out_scores, ref_scores, bad_words):
        banned = {w for w in bw.tolist() if w >= 0}
        for token_id in range(num_tokens):
            if token_id in banned:
                assert score[token_id] == filter_value
            else:
                assert score[token_id] == ref[token_id]


def test_processrepetition_penalty():
    """``_process_repetition_penalty_`` must agree with HF's RepetitionPenaltyLogitsProcessor."""
    from lmdeploy.pytorch.engine.logits_process import _process_repetition_penalty_
    batch_size, num_tokens = 4, 16
    scores = torch.rand(batch_size, num_tokens)
    input_ids = torch.tensor([
        [0, 1],
        [3, 6],
        [4, 4],
        [0, 0],
    ])
    penalties = 1 + torch.rand(batch_size)

    # Each reference row works on its own clone so the HF processor cannot touch ``scores``.
    expected_rows = []
    for row, ids, penalty in zip(scores, input_ids, penalties):
        processor = RepetitionPenaltyLogitsProcessor(penalty.item())
        expected_rows.append(processor(ids[None], row[None].clone()))
    expected = torch.cat(expected_rows)

    out = _process_repetition_penalty_(scores, input_ids, penalties)
    torch.testing.assert_close(out, expected)


def test_filter_topk_sorted():
    """``_filter_topk_sorted_`` on descending-sorted scores must match HF's TopKLogitsWarper."""
    from lmdeploy.pytorch.engine.logits_process import _filter_topk_sorted_

    batch_size, num_tokens = 4, 16
    scores, _ = torch.rand(batch_size, num_tokens).sort(1, descending=True)
    top_k = torch.randint(4, num_tokens - 4, (batch_size, ))

    # Reference rows are computed on clones so ``scores`` reaches the kernel unchanged.
    expected = torch.cat([TopKLogitsWarper(k.item())(None, row[None].clone()) for row, k in zip(scores, top_k)])

    out = _filter_topk_sorted_(scores, top_k)
    torch.testing.assert_close(out, expected)


def test_filter_topp_sorted():
    """``_filter_topp_sorted_`` on descending-sorted scores must match HF's TopPLogitsWarper."""
    from lmdeploy.pytorch.engine.logits_process import _filter_topp_sorted_

    batch_size, num_tokens = 4, 16
    scores, _ = torch.rand(batch_size, num_tokens).sort(1, descending=True)
    top_p = torch.rand(batch_size)

    # Reference rows are computed on clones so ``scores`` reaches the kernel unchanged.
    expected = torch.cat([TopPLogitsWarper(p.item())(None, row[None].clone()) for row, p in zip(scores, top_p)])

    out = _filter_topp_sorted_(scores, top_p)
    torch.testing.assert_close(out, expected)


def test_filter_minp_sorted():
    """``_filter_minp_sorted_`` on descending-sorted scores must match HF's MinPLogitsWarper."""
    from lmdeploy.pytorch.engine.logits_process import _filter_minp_sorted_

    batch_size, num_tokens = 4, 16
    scores, _ = torch.rand(batch_size, num_tokens).sort(1, descending=True)
    min_p = torch.rand(batch_size)

    # Reference rows are computed on clones so ``scores`` reaches the kernel unchanged.
    expected = torch.cat([MinPLogitsWarper(p.item())(None, row[None].clone()) for row, p in zip(scores, min_p)])

    out = _filter_minp_sorted_(scores, min_p)
    torch.testing.assert_close(out, expected)


def test_filter_ngram():
    """Exercise ``_filter_repetition_ngram_``.

    The assertions expect that each flagged row ends with every logit but one at -inf, and that
    ``stop_words[row, 0]`` holds 0. Unflagged rows must keep all their logits finite.
    """
    from lmdeploy.pytorch.engine.logits_process import _filter_repetition_ngram_
    vocab_size = 100

    def _get_metas(generated_ids, n, window_size):
        """Return ``(batch_size, max_n, max_window_size, n)`` for the kernel call.

        ``n`` collapses to ``None`` when every row uses the same n. ``generated_ids`` is passed
        explicitly rather than read from the enclosing scope via a late-binding closure.
        """
        batch_size = generated_ids.size(0)
        max_n = int(n.max().item())
        if n.eq(max_n).all().item():
            n = None
        return batch_size, max_n, window_size, n

    # base test
    generated_ids = torch.tensor([
        [2, 3, 4, 1, 2, 3, 4, 2, 3, 4],
        [9, 8, 7, 3, 8, 7, 5, 9, 8, 7],
        [9, 8, 7, 3, 8, 7, 5, 9, 8, 7],
    ],
                                 dtype=torch.int64)
    n = torch.tensor([3, 3, 2], dtype=torch.int64)
    threshold = torch.tensor([3, 3, 3], dtype=torch.int64)

    batch_size, max_n, max_window_size, n = _get_metas(generated_ids, n, 10)
    scores = torch.rand(batch_size, vocab_size)
    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)
    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)

    assert not scores[1].isinf().any().item()
    assert scores[0].isinf().sum().item() == vocab_size - 1
    assert scores[2].isinf().sum().item() == vocab_size - 1
    assert scores[0, stop_words[0, 0]] == 0
    assert scores[2, stop_words[2, 0]] == 0

    # test no ngram
    generated_ids = torch.tensor([
        [2, 3, 4, 1, 2, 3, 4, 2, 3, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ])
    n = torch.tensor([3, 0], dtype=torch.int64)
    threshold = torch.tensor([3, 0], dtype=torch.int64)
    batch_size, max_n, max_window_size, n = _get_metas(generated_ids, n, 10)

    scores = torch.rand(batch_size, vocab_size)
    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)
    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)
    assert not scores[1].isinf().any().item()
    assert scores[0].isinf().sum().item() == vocab_size - 1

    # test ids all 0
    generated_ids = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ])
    n = torch.tensor([3], dtype=torch.int64)
    threshold = torch.tensor([3], dtype=torch.int64)
    batch_size, max_n, max_window_size, n = _get_metas(generated_ids, n, 10)

    scores = torch.rand(batch_size, vocab_size)
    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)
    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)
    assert scores[0].isinf().sum().item() == vocab_size - 1


================================================
FILE: tests/pytorch/engine/test_request.py
================================================
# yapf: disable
import asyncio

import pytest

from lmdeploy.pytorch.engine.request import RequestManager, RequestType, ResponseType

# yapf: enable


class TestRequestHander:
    """Tests for ``RequestManager``: handler binding and request/response round
    trips through a sender built by the manager."""

    @pytest.fixture
    def event_loop(self):
        """Install a fresh event loop for the duration of a test.

        The previously current loop (if any) is restored afterwards, and the
        fresh loop is closed so its resources are not leaked across tests.
        """
        try:
            old_loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python versions raise when there is no current loop.
            old_loop = None
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            yield new_loop
        finally:
            new_loop.stop()
            asyncio.set_event_loop(old_loop)
            # stop() alone leaves the loop open; close it to release its
            # selector and self-pipe.
            new_loop.close()

    @pytest.fixture
    def manager(self):
        yield RequestManager()

    def test_bind(self, manager, event_loop):
        """Unbound request types report HANDLER_NOT_EXIST; bound ones are
        answered by the bound callback."""

        def __stop_engine_callback(reqs, **kwargs):
            # Answer every request with SUCCESS and echo its payload.
            for req in reqs:
                resp = req.resp
                resp.type = ResponseType.SUCCESS
                resp.data = f'{req.data} success'
                manager.response(resp)

        async def __dummy_loop():
            # Drive the manager until step() raises (e.g. when the loop task
            # is stopped/cancelled during cleanup).
            while True:
                try:
                    await manager.step()
                except Exception:
                    return

        sender = manager.build_sender()
        manager.set_main_loop_func(__dummy_loop)

        # test not bind
        resp = sender.send_async(RequestType.STOP_ENGINE, None)
        resp = sender.recv(resp)
        assert resp.type == ResponseType.HANDLER_NOT_EXIST

        assert manager.is_loop_alive()

        # test bind success
        # NOTE(review): this request is sent before the handler is bound and its
        # response is never received; confirm this is intentional.
        sender.send_async(RequestType.STOP_ENGINE, None)
        manager.bind_func(RequestType.STOP_ENGINE, __stop_engine_callback)
        resp = sender.send_async(RequestType.STOP_ENGINE, 'test')
        resp = sender.recv(resp)
        assert resp.data == 'test success'

        # cleanup, cancel main task
        task_to_cancel = manager._loop_task
        manager.stop_loop()
        event_loop.run_until_complete(asyncio.gather(task_to_cancel, return_exceptions=True))


================================================
FILE: tests/pytorch/engine/test_zmq_rpc.py
================================================
import asyncio
import multiprocessing as mp


class TestZMQRPC:
    """Round-trip test of ``AsyncRPCServer``/``AsyncRPCClient`` with the server
    running in a spawned subprocess."""

    def sub_proc(self, shared_dict=None, condition=None):
        """Server process entry point.

        Publishes the server port through ``shared_dict`` and notifies
        ``condition``, registers the test methods, then runs the server. The
        registered ``close`` method calls ``server.stop()``.
        """
        from lmdeploy.pytorch.engine.mp_engine.zmq_rpc import AsyncRPCServer
        server = AsyncRPCServer()
        with condition:
            shared_dict['rpc_server_port'] = server.port
            condition.notify()

        async def streaming_method(name):
            for i in range(3):
                yield f'{name}: streaming method {i}'

        def method(name):
            return f'{name}: method'

        async def async_method(name):
            return f'{name}: async method'

        def close():
            print('close server...')
            server.stop()

        server.register_method('method', method)
        server.register_method('async_method', async_method)
        server.register_method('streaming_method', streaming_method)
        server.register_method('close', close)

        # asyncio.run creates, runs and closes its own event loop, so no loop
        # is created or installed beforehand (it would never be used or closed).
        asyncio.run(server.run())

    async def async_main(self, port):
        """Client side: exercise sync, async and streaming calls, then shut the
        server down."""
        from lmdeploy.pytorch.engine.mp_engine.zmq_rpc import AsyncRPCClient
        client = AsyncRPCClient(port=port)

        # Keep a reference to the listener task while this coroutine runs.
        _ = asyncio.get_running_loop().create_task(client.listen())

        # Example usage
        # NOTE(review): ``call`` is invoked synchronously from inside the
        # coroutine; presumably AsyncRPCClient supports this - confirm.
        result = client.call('async_method', 'test2')
        assert result == 'test2: async method'
        result = await client.async_call('method', 'test1')
        assert result == 'test1: method'

        async for result in client.async_stream_call('streaming_method', 'test3'):
            pass
        assert result == 'test3: streaming method 2'

        await client.async_call('close')
        client.stop()

    def test_zmq_rpc(self):
        with mp.Manager() as manager:
            shared_dict = manager.dict()
            condition = manager.Condition()
            ctx = mp.get_context('spawn')
            proc = ctx.Process(target=self.sub_proc, args=(shared_dict, condition), daemon=True)
            proc.start()

            with condition:
                # Re-check the predicate after every wakeup: wait() may return
                # before the port is published. Poll with a timeout so a server
                # process that dies early fails the test instead of hanging it.
                while 'rpc_server_port' not in shared_dict:
                    if not proc.is_alive():
                        raise RuntimeError('RPC server process exited before publishing its port.')
                    condition.wait(timeout=1.0)
            port = shared_dict['rpc_server_port']

        asyncio.run(self.async_main(port))

        proc.join()


================================================
FILE: tests/pytorch/kernel/test_activation.py
================================================
import pytest
import torch


class TestSiluAndMul:
    """Check the fused CUDA ``silu_and_mul`` against an eager PyTorch result.

    The last dimension of the input holds ``[gate | up]``; the expected output
    is ``silu(gate) * up``.
    """

    @pytest.fixture
    def seqlen(self, request):
        return request.param

    @pytest.fixture
    def feat_size(self, request):
        return request.param

    @pytest.fixture
    def x(self, seqlen, feat_size):
        return torch.rand(seqlen, feat_size, dtype=torch.float16, device='cuda')

    @pytest.fixture
    def gt(self, x):
        gate_half, up_half = torch.chunk(x, 2, dim=-1)
        return torch.nn.functional.silu(gate_half) * up_half

    @pytest.mark.parametrize('seqlen', [65536, 256], indirect=True)
    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)
    def test_silu_and_mul(self, x, gt):
        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

        torch.testing.assert_close(silu_and_mul(x), gt)


class TestSiluAndMulMoEEP:
    """Check ``silu_and_mul_moe_ep`` against PyTorch on valid rows only.

    The input is laid out per expert. ``mask_m[e]`` is the number of leading
    rows of expert ``e`` that are valid; all other rows are zeroed on both sides
    before comparison.
    """

    @pytest.fixture
    def num_experts(self, request):
        return request.param

    @pytest.fixture
    def seqlen(self, request):
        return request.param

    @pytest.fixture
    def feat_size(self, request):
        return request.param

    @pytest.fixture
    def dtype(self):
        return torch.float16

    @pytest.fixture
    def x(self, num_experts, seqlen, feat_size, dtype):
        return torch.rand(num_experts, seqlen, feat_size, dtype=dtype, device='cuda')

    @pytest.fixture
    def mask_m(self, num_experts, seqlen):
        return torch.randint(0, seqlen, (num_experts, ), device='cuda')

    @pytest.fixture
    def elem_mask(self, mask_m, seqlen):
        # (num_experts, seqlen, 1): True where the row index is below mask_m.
        row_ids = torch.arange(seqlen, device='cuda')
        valid = row_ids[None, :] < mask_m[:, None]
        return valid.unsqueeze(-1)

    @pytest.fixture
    def gt(self, x):
        gate_half, up_half = torch.chunk(x, 2, dim=-1)
        return torch.nn.functional.silu(gate_half) * up_half

    @pytest.mark.parametrize('num_experts', [4], indirect=True)
    @pytest.mark.parametrize('seqlen', [1024], indirect=True)
    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)
    def test_silu_and_mul(self, x, mask_m, elem_mask, gt):
        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_moe_ep

        out = silu_and_mul_moe_ep(x, mask_m)
        invalid = ~elem_mask
        out.masked_fill_(invalid, 0.0)
        gt.masked_fill_(invalid, 0.0)
        torch.testing.assert_close(out, gt)


================================================
FILE: tests/pytorch/kernel/test_apply_rotary.py
================================================
import pytest
import torch

from lmdeploy.utils import is_bf16_supported


def _rotate_half(x):
    """Return ``cat(-x2, x1)`` where ``x1``/``x2`` are the leading/trailing
    halves of the last dimension (``x1`` gets ``size // 2`` elements)."""
    half = x.shape[-1] // 2
    leading = x[..., :half]
    trailing = x[..., half:]
    return torch.cat((-trailing, leading), dim=-1)


def _bf16_mark():
    """Build a pytest ``skipif`` mark that skips when bfloat16 is unsupported."""
    bf16_missing = not is_bf16_supported()
    return pytest.mark.skipif(bf16_missing, reason='bf16 not supported.')


class TestApplyRotary:
    """Compare the CUDA ``apply_rotary_pos_emb`` kernel with the rotate-half
    RoPE formula on packed ``(num_tokens, num_heads, feature_dim)`` q/k."""

    @pytest.fixture
    def dtype(self, request):
        return request.param

    @pytest.fixture
    def batch_size(self):
        return 4

    @pytest.fixture
    def num_heads_q(self, request):
        return request.param

    @pytest.fixture
    def num_heads_k(self, request):
        return request.param

    @pytest.fixture
    def feature_dim(self):
        return 128

    @pytest.fixture
    def seq_length(self, batch_size):
        # Per-sequence lengths drawn from [8, 16).
        return torch.randint(8, 16, (batch_size, ), device='cuda')

    @pytest.fixture
    def max_seqlen(self, seq_length):
        return seq_length.max()

    @pytest.fixture
    def q_states(self, seq_length, num_heads_q, feature_dim, dtype):
        num_tokens = seq_length.sum()
        return torch.randn(num_tokens, num_heads_q, feature_dim, dtype=dtype, device='cuda')

    @pytest.fixture
    def k_states(self, seq_length, num_heads_k, feature_dim, dtype):
        num_tokens = seq_length.sum()
        return torch.randn(num_tokens, num_heads_k, feature_dim, dtype=dtype, device='cuda')

    @pytest.fixture
    def position_ids_1d(self, seq_length, max_seqlen):
        num_tokens = seq_length.sum().item()
        return torch.randint(0, max_seqlen.item(), (num_tokens, ), device='cuda')

    @pytest.fixture
    def cached_cos(self, max_seqlen, feature_dim, dtype):
        return torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')

    @pytest.fixture
    def cached_sin(self, max_seqlen, feature_dim, dtype):
        return torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')

    @pytest.fixture
    def cos(self, cached_cos, position_ids_1d):
        # Gather per-token rows and add a broadcast head axis.
        return cached_cos[position_ids_1d, None, :]

    @pytest.fixture
    def sin(self, cached_sin, position_ids_1d):
        return cached_sin[position_ids_1d, None, :]

    @pytest.fixture
    def gt(self, q_states, k_states, cos, sin):

        def _rope(states):
            return states * cos + _rotate_half(states) * sin

        return _rope(q_states), _rope(k_states)

    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16, torch.float32],
                             indirect=True)
    @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)], indirect=True)
    def test_apply_rotary(self, q_states, k_states, cos, sin, gt):
        from lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb

        q_out, k_out = apply_rotary_pos_emb(q_states, k_states, cos, sin)
        q_expected, k_expected = gt
        # Default dtype-dependent tolerances of assert_close.
        torch.testing.assert_close(q_out, q_expected)
        torch.testing.assert_close(k_out, k_expected)


================================================
FILE: tests/pytorch/kernel/test_bitonic_topk.py
================================================
import pytest
import torch


class TestBitonicTopk:
    """Compare ``bitonic_topk`` with per-sequence ``torch.topk``.

    Rows of ``scores`` are query tokens packed across the batch (sequence ``i``
    owns ``q_seqlens[i]`` consecutive rows); each row only considers its first
    ``kv_seqlens[i]`` columns. Unused top-k slots hold the fill value ``-1``.
    """

    @pytest.fixture
    def device(self):
        return 'cuda'

    @pytest.fixture
    def k(self):
        return 2048

    @pytest.fixture
    def q_seqlens(self, device):
        return torch.tensor([4, 16, 1, 32], dtype=torch.int32, device=device)

    @pytest.fixture
    def kv_seqlens(self, device):
        return torch.tensor([1024, 2048, 4096, 4096 + 133], dtype=torch.int32, device=device)

    @pytest.fixture
    def batch_size(self, kv_seqlens):
        return kv_seqlens.numel()

    @pytest.fixture
    def max_kv_len(self, kv_seqlens):
        return kv_seqlens.max().item()

    @pytest.fixture
    def scores(self, q_seqlens, max_kv_len, device):
        num_tokens = q_seqlens.sum().item()
        return torch.randn((num_tokens, max_kv_len), device=device)

    @pytest.fixture
    def gt(self, scores, q_seqlens, kv_seqlens, k):
        expected = torch.full((scores.size(0), k), -1, dtype=torch.int32, device=scores.device)
        row = 0
        for q_len, kv_len in zip(q_seqlens.tolist(), kv_seqlens.tolist()):
            # Sequences shorter than k only produce kv_len indices.
            take = min(kv_len, k)
            rows = slice(row, row + q_len)
            expected[rows, :take] = torch.topk(scores[rows, :kv_len], take, largest=True, sorted=True).indices
            row += q_len
        return expected

    def test_bitonic_topk(self, scores, q_seqlens, kv_seqlens, k, gt):
        from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk
        out = bitonic_topk(scores, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, k=k, fill=-1, sorted=True)
        # Compare selected scores rather than indices (ties may pick different
        # columns); -1 padding is mapped to column 0 so it can be gathered.
        gt_score = torch.gather(scores, 1, gt.clamp_min(0).to(torch.int64))
        out_score = torch.gather(scores, 1, out.clamp_min(0).to(torch.int64))
        torch.testing.assert_close(gt_score, out_score)


================================================
FILE: tests/pytorch/kernel/test_causal_conv1d.py
================================================
import pytest
import torch


def do_test():
    """Return True when ``causal_conv1d`` (with ``causal_conv1d_fn`` and
    ``causal_conv1d_update``) and ``tilelang`` are importable, else False."""
    try:
        import causal_conv1d  # noqa: F401
        import tilelang  # noqa: F401
        for attr_name in ('causal_conv1d_fn', 'causal_conv1d_update'):
            getattr(causal_conv1d, attr_name)
    except Exception:
        return False
    return True


@pytest.mark.skipif(not do_test(), reason='tilelang or causal_conv1d is not available')
class TestCausalConv1dUpdate:
    """Compare lmdeploy's ``causal_conv1d_update`` with the reference
    ``causal_conv1d`` implementation, including the updated conv state."""

    @pytest.fixture
    def device(self):
        return 'cuda'

    @pytest.fixture
    def batch(self):
        return 512

    @pytest.fixture
    def hidden_size(self):
        return 2048

    @pytest.fixture
    def width(self):
        return 4

    @pytest.fixture
    def x(self, batch, hidden_size, device):
        return torch.randn(batch, hidden_size, 1, device=device)

    @pytest.fixture
    def weight(self, hidden_size, width, device):
        return torch.randn(hidden_size, width, device=device)

    @pytest.fixture
    def conv_state(self, batch, hidden_size, width, device):
        # Every other row of a 4 * batch pool: a non-contiguous state tensor
        # with 2 * batch slots.
        pool = torch.randn(batch * 4, hidden_size, width, device=device)
        return pool[::2]

    @pytest.fixture
    def bias(self, hidden_size, device):
        return torch.randn(hidden_size, device=device)

    @pytest.fixture
    def conv_state_indices(self, batch, device):
        # Slots 2*batch-1, 2*batch-3, ..., 1 (odd slots in reverse order).
        slots = torch.arange(batch * 2 - 1, 0, -2, device=device)
        return slots.to(torch.int32)

    @pytest.fixture(params=[None, 'silu'])
    def activation(self, request):
        return request.param

    def test_causal_conv1d_update(self, x, conv_state, weight, bias, activation, conv_state_indices):
        from causal_conv1d import causal_conv1d_update as causal_conv1d_update_gt

        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_update

        shared = dict(x=x, weight=weight, bias=bias, activation=activation, conv_state_indices=conv_state_indices)
        # Each implementation updates its own copy of the state in place.
        state_ours = conv_state.clone()
        out = causal_conv1d_update(conv_state=state_ours, **shared)
        out_gt = causal_conv1d_update_gt(conv_state=conv_state, **shared)
        torch.testing.assert_close(out, out_gt, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(state_ours, conv_state, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not do_test(), reason='tilelang or causal_conv1d is not available')
class TestCausalConv1dFn:
    """Compare lmdeploy's ``causal_conv1d_fn`` with the reference
    implementation on a packed input holding two sequences."""

    @pytest.fixture
    def device(self):
        return 'cuda'

    @pytest.fixture
    def hidden_size(self):
        return 2048

    @pytest.fixture
    def seqlen(self):
        return 4096

    @pytest.fixture
    def seq_idx(self, seqlen, device):
        # (1, seqlen) int32: 0 for the first 3/4 of positions, 1 afterwards.
        boundary = seqlen // 4 * 3
        positions = torch.arange(seqlen, device=device)
        return (positions >= boundary).to(torch.int32).unsqueeze(0)

    @pytest.fixture
    def x(self, hidden_size, seqlen, device):
        # Logical shape (1, hidden_size, seqlen) with channel-last memory.
        dense = torch.randn(1, hidden_size, seqlen, device=device)
        channel_last = dense.transpose(1, 2).contiguous()
        return channel_last.transpose(1, 2)

    @pytest.fixture
    def weight(self, hidden_size, device):
        return torch.randn(hidden_size, 4, device=device)

    @pytest.fixture
    def bias(self, hidden_size, device):
        return torch.randn(hidden_size, device=device)

    @pytest.fixture(params=[None, 'silu'])
    def activation(self, request):
        return request.param

    def test_causal_conv1d_fn(self, x, weight, bias, activation, seq_idx):
        from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_gt

        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_fn

        kwargs = dict(x=x, weight=weight, bias=bias, activation=activation, return_final_states=False, seq_idx=seq_idx)
        out = causal_conv1d_fn(**kwargs)
        out_gt = causal_conv1d_fn_gt(**kwargs)
        torch.testing.assert_close(out, out_gt, rtol=1e-3, atol=1e-3)


================================================
FILE: tests/pytorch/kernel/test_ds_index.py
================================================
import pytest
import torch


def _make_A(M, K, group_size, out_dtype, device):
    """Build a group-quantized ``(M, K)`` matrix for low-precision kernel tests.

    Each row is split into ``K // group_size`` groups. Each group is drawn
    uniformly from ``[-1, 1)``, rescaled so its absolute maximum equals the
    largest finite value of ``out_dtype``, and cast to ``out_dtype``. A random
    normal per-group ``scale`` (divided by that maximum) dequantizes it.

    Returns:
        A: float32 ``(M, K)`` dequantized reference, ``quant_A * scale``.
        quant_A: ``(M, K)`` tensor of dtype ``out_dtype``.
        scale: float32 ``(M, K // group_size)`` tensor with column-major strides.
    """
    # Uniform [0, 1) mapped to [-1, 1); ``torch.randn`` here would instead give
    # N(-1, 4) values skewed towards negatives.
    quant_A = torch.rand(M, K // group_size, group_size, dtype=torch.float32, device=device)
    quant_A = quant_A * 2 - 1
    # scaling abs max to fmax
    finfo = torch.finfo(out_dtype)
    fmax = finfo.max
    scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
    quant_A *= scaling
    # Round-trip through out_dtype so A is built from representable values.
    quant_A = quant_A.to(out_dtype).to(torch.float32)

    # create scale and A
    scale = torch.randn(M, K // group_size, dtype=torch.float32, device=device)
    scale /= fmax
    A = quant_A * scale[..., None]

    A = A.reshape(M, K)
    quant_A = quant_A.reshape(M, K).to(out_dtype)
    # Same values, column-major memory layout.
    scale = scale.T.contiguous().T
    return A, quant_A, scale


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestDSIndex:
    """Smoke test for the ``fp8_index`` kernel.

    Queries and keys are ``float8_e4m3fn`` with one float32 scale per token
    (per token and head for queries). Keys are scattered into a block-table
    cache before the kernel is invoked. All tensors are created on the
    ``device`` fixture.
    """

    @pytest.fixture
    def num_heads(self):
        yield 64

    @pytest.fixture
    def head_dim(self):
        yield 128

    @pytest.fixture
    def block_size(self):
        yield 64

    @pytest.fixture
    def device(self):
        yield 'cuda'

    @pytest.fixture
    def q_seqlens(self, request):
        yield request.param

    @pytest.fixture
    def kv_seqlens(self, request):
        yield request.param

    @pytest.fixture
    def k_seqlens(self, kv_seqlens, device):
        yield torch.tensor(kv_seqlens, dtype=torch.int32, device=device)

    @pytest.fixture
    def cu_seqlen_q(self, q_seqlens, device):
        # Prefix sums with a leading zero: [0, q0, q0 + q1, ...].
        yield torch.tensor([0] + list(q_seqlens), dtype=torch.int32, device=device).cumsum(0)

    @pytest.fixture
    def cu_seqlen_kv(self, kv_seqlens, device):
        yield torch.tensor([0] + list(kv_seqlens), dtype=torch.int32, device=device).cumsum(0)

    @pytest.fixture
    def query(self, q_seqlens, num_heads, head_dim, device):
        total_len = sum(q_seqlens)
        # group_size == head_dim, so there is one scale per (token, head).
        fp_q, q, q_s = _make_A(total_len * num_heads, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device=device)
        fp_q = fp_q.view(total_len, num_heads, head_dim)
        q = q.view(total_len, num_heads, head_dim)
        q_s = q_s.view(total_len, num_heads)
        yield fp_q, q, q_s

    @pytest.fixture
    def q(self, query):
        yield query[1]

    @pytest.fixture
    def q_s(self, query):
        yield query[2]

    @pytest.fixture
    def key(self, kv_seqlens, head_dim, device):
        total_len = sum(kv_seqlens)
        # One scale per key token.
        fp_k, k, k_s = _make_A(total_len, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device=device)
        fp_k = fp_k.view(total_len, head_dim)
        k = k.view(total_len, head_dim)
        k_s = k_s.view(total_len)
        yield fp_k, k, k_s

    @pytest.fixture
    def k(self, key):
        yield key[1]

    @pytest.fixture
    def k_s(self, key):
        yield key[2]

    @pytest.fixture
    def cache_key(self, k, k_s, kv_seqlens, block_size, head_dim, device):
        """Scatter packed keys/scales into a block cache.

        Sequence ``i`` owns the contiguous physical blocks
        ``[i * max_num_blocks, (i + 1) * max_num_blocks)``; unused slots stay
        zero.
        """
        batch_size = len(kv_seqlens)
        max_num_blocks = (max(kv_seqlens) + block_size - 1) // block_size

        # get block offsets
        batch_ids = torch.arange(batch_size, device=device) * max_num_blocks
        block_ids = torch.arange(max_num_blocks, device=device)
        block_offsets = (batch_ids[:, None] + block_ids[None, :])

        k_cache = torch.zeros((max_num_blocks * batch_size * block_size, head_dim),
                              dtype=torch.float8_e4m3fn,
                              device=device)
        k_s_cache = torch.zeros((max_num_blocks * batch_size * block_size), dtype=torch.float32, device=device)

        k = k.split(kv_seqlens, dim=0)
        k_s = k_s.split(kv_seqlens, dim=0)
        for i in range(batch_size):
            size = k[i].size(0)
            start = i * max_num_blocks * block_size
            end = start + size
            k_cache[start:end] = k[i]
            k_s_cache[start:end] = k_s[i]

        k_cache = k_cache.view(batch_size * max_num_blocks, block_size, head_dim)
        k_s_cache = k_s_cache.view(batch_size * max_num_blocks, block_size)

        yield k_cache, k_s_cache, block_offsets

    @pytest.fixture
    def k_cache(self, cache_key):
        yield cache_key[0]

    @pytest.fixture
    def k_s_cache(self, cache_key):
        yield cache_key[1]

    @pytest.fixture
    def block_offset(self, cache_key):
        yield cache_key[2]

    @pytest.mark.parametrize('q_seqlens', [(1, 1, 1, 1), (1024, 2048, 1024, 1)], indirect=True)
    @pytest.mark.parametrize('kv_seqlens', [(2048, 4096, 1024, 128)], indirect=True)
    def test_fp8_index(self, q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset):
        # gt requires tilelang, so this test just ensure the kernel works
        from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index
        fp8_index(q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset)


================================================
FILE: tests/pytorch/kernel/test_fill_kv_cache.py
================================================
import pytest
import torch


def _div_up(a, b):
    """Ceiling division ``ceil(a / b)`` for non-negative ``a`` and positive ``b``."""
    padded = a + b - 1
    return padded // b


def quant(kv: torch.Tensor, nbits: int = 8):
    """Asymmetrically quantize ``kv`` along its last (head_dim) axis.

    Returns ``(q_kv, scales_zeros)``. ``q_kv`` is uint8. For ``nbits == 4`` the
    two halves of the last axis are packed into one byte: the first half in the
    low nibble, the second half in the high nibble. ``scales_zeros`` holds
    ``[scale, zero]`` in its last axis.
    """
    levels = 2**nbits - 1
    lo = kv.amin(dim=-1, keepdim=True)
    hi = kv.amax(dim=-1, keepdim=True)
    scales = (hi - lo) / levels
    zeros = -lo / scales
    # +0.5 then truncation rounds the non-negative codes to nearest.
    codes = (kv / scales + zeros + 0.5).to(torch.uint8)
    if nbits == 4:
        low_nibble, high_nibble = codes.split(codes.shape[-1] // 2, -1)
        codes = low_nibble + high_nibble * 16
    return codes, torch.cat([scales, zeros], dim=-1)


class TestFillKVCache:
    """Check ``fill_kv_cache`` against a pure-PyTorch reference.

    The caches are indexed through ``block_offsets``: logical block ``j`` of
    sequence ``b`` lives in physical block ``block_offsets[b, j]``, each with
    ``block_size`` token slots. A sequence's new tokens are written starting
    at position ``history_len`` of that sequence.
    """

    @pytest.fixture
    def num_heads(self):
        yield 4

    @pytest.fixture
    def head_dim(self):
        yield 32

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def seq_lens(self, request):
        # Number of new tokens per sequence (indirectly parametrized).
        yield request.param

    @pytest.fixture
    def history_lens(self, request):
        # Number of tokens already cached per sequence (indirectly parametrized).
        yield request.param

    @pytest.fixture
    def batch_size(self, seq_lens):
        yield len(seq_lens)

    @pytest.fixture
    def kv_lens(self, seq_lens, history_lens):
        # Total cached length per sequence after the fill.
        yield [s + h for s, h in zip(seq_lens, history_lens)]

    @pytest.fixture
    def max_q_seq_length(self, seq_lens):
        yield max(seq_lens)

    @pytest.fixture
    def num_tokens(self, seq_lens):
        yield sum(seq_lens)

    @pytest.fixture
    def num_blocks_per_input(self, kv_lens, block_size):
        yield [_div_up(kv_len, block_size) for kv_len in kv_lens]

    @pytest.fixture
    def max_num_blocks(self, num_blocks_per_input):
        yield max(num_blocks_per_input)

    @pytest.fixture
    def q_seq_length(self, seq_lens):
        yield torch.tensor(seq_lens).cuda()

    @pytest.fixture
    def q_start_loc(self, q_seq_length):
        # Exclusive prefix sum: start of each sequence in the packed token axis.
        cum_seq_length = q_seq_length.cumsum(0)
        yield cum_seq_length - q_seq_length

    @pytest.fixture
    def kv_seq_length(self, kv_lens):
        yield torch.tensor(kv_lens).cuda()

    @pytest.fixture
    def k_states(self, num_tokens, num_heads, head_dim):
        # New keys packed along the token axis: (num_tokens, num_heads, head_dim).
        yield torch.randn(num_tokens, num_heads, head_dim).cuda()

    @pytest.fixture
    def v_states(self, k_states):
        yield torch.randn_like(k_states)

    @pytest.fixture
    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):
        # (num_physical_blocks, block_size, num_heads, head_dim), zero-filled.
        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
        yield torch.full(shape, 0.0).cuda()

    @pytest.fixture
    def v_caches(self, k_caches):
        # Random initial contents: slots the kernel must not touch keep these
        # values, and the whole tensor is compared against ``gt``.
        yield torch.rand_like(k_caches)

    @pytest.fixture
    def block_offsets(self, num_blocks_per_input):
        # Logical block j of sequence b maps to physical block b + j * batch_size,
        # i.e. the sequences' blocks are interleaved.
        batch_size = len(num_blocks_per_input)
        max_num_blocks = max(num_blocks_per_input)
        batch_ids = torch.arange(batch_size)
        ret = torch.arange(max_num_blocks)
        ret = batch_ids[:, None] + ret[None, :] * batch_size
        yield ret.cuda()

    @pytest.fixture
    def gt(self, k_states, v_states, k_caches, v_caches, seq_lens, history_lens, block_offsets, block_size):
        """Expected caches after the fill, computed on clones of the fixtures."""
        batch_size = len(seq_lens)
        k_caches = k_caches.clone()
        v_caches = v_caches.clone()
        splited_k_states = k_states.split(seq_lens)
        splited_v_states = v_states.split(seq_lens)
        for bidx in range(batch_size):
            k_state = splited_k_states[bidx]
            v_state = splited_v_states[bidx]
            h_len = history_lens[bidx]
            b_offs = block_offsets[bidx]
            # Logical block containing position h_len (== h_len // block_size)
            # and the slot inside it where writing starts.
            block_id = _div_up(h_len + 1, block_size) - 1
            fill_start = h_len % block_size
            fill_size = min(block_size - fill_start, k_state.size(0))
            # Fill the rest of the current block, then whole blocks, until all
            # of this sequence's new tokens are written.
            while True:
                boff = b_offs[block_id]
                tmp_ks = k_state[:fill_size]
                tmp_vs = v_state[:fill_size]
                fill_end = fill_start + fill_size
                k_caches[boff, fill_start:fill_end] = tmp_ks
                v_caches[boff, fill_start:fill_end] = tmp_vs
                k_state = k_state[fill_size:]
                v_state = v_state[fill_size:]
                block_id += 1
                fill_start = 0
                fill_size = min(block_size, k_state.size(0))
                if fill_size == 0:
                    break

        yield k_caches, v_caches

    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [
        ((1, 1, 1, 1), (1, 16, 31, 24)),
        ((1, 8, 16, 24), (1, 16, 31, 24)),
    ],
                             indirect=True)
    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, block_offsets, q_start_loc, q_seq_length,
                           kv_seq_length, max_q_seq_length, gt):
        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache
        # The kernel writes into the fixture caches in place.
        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,
                      max_q_seq_length, block_offsets)

        torch.testing.assert_close(k_caches, gt[0])
        torch.testing.assert_close(v_caches, gt[1])


class TestFillKVCacheInt8(TestFillKVCache):
    """``fill_kv_cache`` with asymmetric integer quantization of k/v.

    Quantized values go into uint8 caches; per-token, per-head
    ``[scale, zero]`` pairs go into separate float32 buffers.
    """

    @pytest.fixture
    def head_dim(self, request):
        yield request.param

    @pytest.fixture
    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):
        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
        yield torch.full(shape, 0, dtype=torch.uint8).cuda()

    @pytest.fixture
    def v_caches(self, k_caches):
        # Zero-filled, same shape/dtype/device as the k cache.
        yield torch.zeros_like(k_caches)

    @pytest.fixture
    def k_scales_zeros(self, batch_size, max_num_blocks, block_size, num_heads):
        shape = (batch_size * max_num_blocks, block_size, num_heads, 2)
        yield torch.full(shape, 0.0).cuda()

    @pytest.fixture
    def v_scales_zeros(self, k_scales_zeros):
        yield torch.zeros_like(k_scales_zeros)

    @pytest.fixture
    def nbits(self):
        yield 8

    @pytest.fixture
    def gt(self, k_states, v_states, k_caches, v_caches, seq_lens, history_lens, block_offsets, block_size,
           k_scales_zeros, v_scales_zeros, nbits):
        """Expected quantized caches and scale/zero buffers.

        Everything is computed on clones, so the fixture tensors handed to the
        kernel stay untouched. Writing the expected scales/zeros into the
        fixtures themselves would make the test compare the kernel's input
        buffer with itself.
        """
        k_states, k_states_sz = quant(k_states, nbits)
        v_states, v_states_sz = quant(v_states, nbits)
        batch_size = len(seq_lens)
        k_caches = k_caches.clone()
        v_caches = v_caches.clone()
        k_scales_zeros = k_scales_zeros.clone()
        v_scales_zeros = v_scales_zeros.clone()
        splited_k_states = k_states.split(seq_lens)
        splited_v_states = v_states.split(seq_lens)
        splited_k_states_sz = k_states_sz.split(seq_lens)
        splited_v_states_sz = v_states_sz.split(seq_lens)
        for bidx in range(batch_size):
            k_state = splited_k_states[bidx]
            v_state = splited_v_states[bidx]
            k_state_sz = splited_k_states_sz[bidx]
            v_state_sz = splited_v_states_sz[bidx]
            h_len = history_lens[bidx]
            b_offs = block_offsets[bidx]
            # Logical block holding position h_len and the slot to start at.
            block_id = _div_up(h_len + 1, block_size) - 1
            fill_start = h_len % block_size
            fill_size = min(block_size - fill_start, k_state.size(0))
            while True:
                boff = b_offs[block_id]
                tmp_ks = k_state[:fill_size]
                tmp_vs = v_state[:fill_size]
                tmp_ks_sz = k_state_sz[:fill_size]
                tmp_vs_sz = v_state_sz[:fill_size]
                fill_end = fill_start + fill_size
                k_caches[boff, fill_start:fill_end] = tmp_ks
                v_caches[boff, fill_start:fill_end] = tmp_vs
                k_scales_zeros[boff, fill_start:fill_end] = tmp_ks_sz
                v_scales_zeros[boff, fill_start:fill_end] = tmp_vs_sz
                k_state = k_state[fill_size:]
                v_state = v_state[fill_size:]
                k_state_sz = k_state_sz[fill_size:]
                v_state_sz = v_state_sz[fill_size:]
                block_id += 1
                fill_start = 0
                fill_size = min(block_size, k_state.size(0))
                if fill_size == 0:
                    break

        yield k_caches, v_caches, k_scales_zeros, v_scales_zeros

    @pytest.mark.parametrize('head_dim', [128, 96], indirect=True)
    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [
        ((1, 1, 1, 1), (1, 16, 31, 24)),
        ((1, 8, 16, 24), (1, 16, 31, 24)),
    ],
                             indirect=True)
    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_zeros, v_scales_zeros, block_offsets,
                           q_start_loc, q_seq_length, kv_seq_length, max_q_seq_length, gt, nbits):
        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache
        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,
                      max_q_seq_length, block_offsets, k_scales_zeros, v_scales_zeros, nbits)

        # Quantized codes are compared loosely (normalized by the uint8 range).
        torch.testing.assert_close(k_caches / 256, gt[0] / 256, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(v_caches / 256, gt[1] / 256, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_scales_zeros, gt[2])
        torch.testing.assert_close(v_scales_zeros, gt[3])


class TestFillKVCacheInt4(TestFillKVCacheInt8):
    """4-bit variant: two quantized values share one uint8, so the caches hold
    ``head_dim // 2`` bytes per head."""

    @pytest.fixture
    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):
        packed_dim = head_dim // 2
        shape = (batch_size * max_num_blocks, block_size, num_heads, packed_dim)
        return torch.zeros(shape, dtype=torch.uint8).cuda()

    @pytest.fixture
    def nbits(self):
        return 4

    @pytest.mark.parametrize('head_dim', [128], indirect=True)
    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [
        ((1, 1, 1, 1), (1, 16, 31, 24)),
        ((1, 8, 16, 24), (1, 16, 31, 24)),
    ],
                             indirect=True)
    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_zeros, v_scales_zeros, block_offsets,
                           q_start_loc, q_seq_length, kv_seq_length, max_q_seq_length, gt, nbits):
        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache

        # Give the kernel fresh zeroed scale/zero buffers instead of the fixture
        # tensors, which the inherited ``gt`` fixture may have written into.
        k_sz_out = torch.zeros_like(k_scales_zeros)
        v_sz_out = torch.zeros_like(v_scales_zeros)
        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,
                      max_q_seq_length, block_offsets, k_sz_out, v_sz_out, nbits)

        expected_k, expected_v, expected_k_sz, expected_v_sz = gt
        torch.testing.assert_close(k_sz_out, expected_k_sz)
        torch.testing.assert_close(v_sz_out, expected_v_sz)
        torch.testing.assert_close(k_caches, expected_k)
        torch.testing.assert_close(v_caches, expected_v)


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestFillKVCacheBlockedFP8(TestFillKVCache):
    """Blocked-FP8 variant of ``TestFillKVCache``.

    K/V are stored as ``quant_dtype`` with one scale per ``group_size``
    channels (separate ``ks_caches`` / ``vs_caches``). The kernel output is
    read back with :meth:`uncache` and compared with a ``quant_fp8`` reference.
    """

    @pytest.fixture(autouse=True, scope='class')
    def initialize(self):
        # Seed CPU and CUDA RNGs once per class so random fixtures are reproducible.
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        yield

    @pytest.fixture
    def scale_fmt(self, request):
        # Parametrized indirectly; forwarded as ``scale_fmt`` to both quant_fp8 and the kernel.
        yield request.param

    @pytest.fixture
    def quant_dtype(self):
        yield torch.float8_e4m3fn

    @pytest.fixture
    def num_heads(self):
        yield 4

    @pytest.fixture
    def head_dim(self):
        yield 128

    @pytest.fixture
    def block_size(self):
        yield 64

    @pytest.fixture
    def group_size(self):
        yield 128

    @pytest.fixture
    def cu_seqlen_q(self, q_start_loc, q_seq_length):
        # [0, end_0, end_1, ...] where end_i = q_start_loc[i] + q_seq_length[i] (int32, CUDA).
        batch_size = q_start_loc.size(0)
        cu_seqlen = torch.zeros(batch_size + 1, dtype=torch.int32).cuda()
        cu_seqlen[1:] = q_start_loc + q_seq_length
        return cu_seqlen

    @pytest.fixture
    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, quant_dtype):
        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
        yield torch.full(shape, 0, dtype=quant_dtype).cuda()

    @pytest.fixture
    def v_caches(self, k_caches):
        yield torch.zeros_like(k_caches)

    @pytest.fixture
    def ks_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, group_size):
        # One scale per ``group_size`` channels of each head.
        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim // group_size)
        yield torch.full(shape, 0.0).cuda()

    @pytest.fixture
    def vs_caches(self, ks_caches):
        # NOTE(review): initialised with ones while ks_caches starts at zeros; uncache only
        # reads the slots of the newly filled tokens, so the initial value should not matter.
        yield torch.ones_like(ks_caches)

    @pytest.fixture
    def gt(self, k_states, v_states, group_size, quant_dtype, scale_fmt):
        """Reference: quantize k/v states with quant_fp8 using the same group size and scale format."""
        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
        # NOTE(review): despite the name this is k_states.size(0), i.e. the leading
        # dim of k_states (presumably the packed token count) -- not the batch size.
        batch_size = k_states.size(0)
        num_heads = k_states.size(1)
        head_dim = k_states.size(2)

        # quant_fp8 is applied on 2-D input; restore the 3-D layout afterwards.
        k_states = k_states.flatten(0, -2)
        v_states = v_states.flatten(0, -2)
        quant_k, quant_ks = quant_fp8(k_states, group_size=group_size, dtype=quant_dtype, scale_fmt=scale_fmt)
        quant_v, quant_vs = quant_fp8(v_states, group_size=group_size, dtype=quant_dtype, scale_fmt=scale_fmt)

        quant_k = quant_k.view(batch_size, num_heads, head_dim)
        quant_ks = quant_ks.view(batch_size, num_heads, head_dim // group_size)
        quant_v = quant_v.view(batch_size, num_heads, head_dim)
        quant_vs = quant_vs.view(batch_size, num_heads, head_dim // group_size)

        yield quant_k, quant_ks, quant_v, quant_vs

    def uncache(self, k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, kv_seqlens, block_offsets):
        """Gather the cache slots written for each sequence's new tokens.

        For sequence ``b`` the slots ``kv_len - q_len .. kv_len`` of its
        block-table view are taken; results for all sequences are concatenated
        in order, giving (tokens, num_heads, last_dim) tensors for k, ks, v, vs.
        """
        batch_size = block_offsets.size(0)
        out_k = []
        out_ks = []
        out_v = []
        out_vs = []
        q_seqlens = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
        for bidx in range(batch_size):
            seqlen = q_seqlens[bidx].item()
            kv_len = kv_seqlens[bidx].item()
            start = kv_len - seqlen
            end = kv_len
            # Index physical blocks via the block table, then merge (blocks, block_size) into one token axis.
            k = k_caches[block_offsets[bidx]].reshape(-1, k_caches.size(-2), k_caches.size(-1))
            ks = ks_caches[block_offsets[bidx]].reshape(-1, ks_caches.size(-2), ks_caches.size(-1))
            v = v_caches[block_offsets[bidx]].reshape(-1, v_caches.size(-2), v_caches.size(-1))
            vs = vs_caches[block_offsets[bidx]].reshape(-1, vs_caches.size(-2), vs_caches.size(-1))
            out_k.append(k[start:end])
            out_ks.append(ks[start:end])
            out_v.append(v[start:end])
            out_vs.append(vs[start:end])
        out_k = torch.cat(out_k, dim=0)
        out_ks = torch.cat(out_ks, dim=0)
        out_v = torch.cat(out_v, dim=0)
        out_vs = torch.cat(out_vs, dim=0)
        return out_k, out_ks, out_v, out_vs

    @pytest.mark.parametrize('scale_fmt', [None, 'ue8m0'], indirect=True)
    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [
        ((1, 1, 1, 1), (1, 128, 256, 200)),
        ((1, 64, 128, 50), (1, 128, 256, 200)),
    ],
                             indirect=True)
    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, vs_caches, block_offsets,
                           cu_seqlen_q, kv_seq_length, max_q_seq_length, gt, group_size, scale_fmt):
        """Fill blocked-FP8 caches and compare the read-back values/scales with ``gt``."""
        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8
        fill_kv_cache_blocked_fp8(k_states,
                                  v_states,
                                  k_caches,
                                  v_caches,
                                  ks_caches,
                                  vs_caches,
                                  cu_seqlen_q,
                                  kv_seq_length,
                                  max_q_seq_length,
                                  block_offsets=block_offsets,
                                  group_size=group_size,
                                  scale_fmt=scale_fmt)

        gt_k, gt_ks, gt_v, gt_vs = gt

        # uncache
        out_k, out_ks, out_v, out_vs = self.uncache(k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q,
                                                    kv_seq_length, block_offsets)

        # Compare quantized values after casting to float and dividing by each tensor's max,
        # so both sides are on a common scale before the default-tolerance check.
        out_k = out_k.float()
        out_k = out_k / out_k.max()
        gt_k = gt_k.float()
        gt_k = gt_k / gt_k.max()
        out_v = out_v.float()
        out_v = out_v / out_v.max()
        gt_v = gt_v.float()
        gt_v = gt_v / gt_v.max()
        torch.testing.assert_close(out_k, gt_k)
        torch.testing.assert_close(out_ks, gt_ks)
        torch.testing.assert_close(out_v, gt_v)
        torch.testing.assert_close(out_vs, gt_vs)


================================================
FILE: tests/pytorch/kernel/test_flash_attention.py
================================================
import math

import pytest
import torch


def _conti_input(data, q_seqlens):
    data = [x[:l] for x, l in zip(data, q_seqlens)]
    data = torch.cat(data, dim=0)
    return data


def _make_bias(q_seqlens, history_lens, neg_val, causal):
    batch_size = q_seqlens.shape[0]
    kv_seqlens = q_seqlens + history_lens
    max_seq_len = q_seqlens.max().item()
    max_kv_len = kv_seqlens.max().item()
    if causal:
        seq_ranges = torch.arange(max_seq_len).cuda()
        seq_ranges = seq_ranges.repeat(batch_size, 1)
        seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

        kv_ranges = torch.arange(max_kv_len).cuda()
        kv_ranges = kv_ranges.repeat(batch_size, 1)

        mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
        return mask.float() * neg_val
    else:
        q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None]
        k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None]
        mask = q_mask[:, :, None] & k_mask[:, None, :]

        return (~mask).float() * neg_val


def _make_bias_alibi(q_seqlens, history_lens, neg_val, causal, alibi_slopes):
    """Build a per-head additive bias that combines the attention mask with ALiBi.

    Args:
        q_seqlens: per-sequence query lengths, shape ``(batch,)``.
        history_lens: per-sequence cached lengths, shape ``(batch,)``.
        neg_val: value used at masked positions (passed to ``_make_bias``).
        causal: whether the mask is causal (passed to ``_make_bias``).
        alibi_slopes: one slope per head, shape ``(num_heads,)``.

    Returns:
        Tensor of shape ``(batch, num_heads, max_q_len, max_kv_len)``: the
        mask bias minus ``slope * |q_pos - kv_pos|``, where ``q_pos`` is the
        absolute position ``history_lens[b] + i`` of query ``i``.
    """
    batch_size = q_seqlens.shape[0]
    kv_seqlens = q_seqlens + history_lens
    max_q_len = q_seqlens.max().item()
    max_kv_len = kv_seqlens.max().item()

    # Derive the device from the inputs instead of hard-coding 'cuda' (respects cuda:N / CPU).
    device = q_seqlens.device
    q_ranges = torch.arange(max_q_len, device=device)
    # Absolute query positions: the new tokens come after the cached history.
    seq_ranges = q_ranges.repeat(batch_size, 1) + history_lens[:, None]

    kv_ranges = torch.arange(max_kv_len, device=device)
    kv_ranges = kv_ranges.repeat(batch_size, 1)

    diff = (seq_ranges[:, :, None] - kv_ranges[:, None, :]).abs()
    slope_diff = -diff[:, None] * alibi_slopes[None, :, None, None]

    # add bias
    bias = _make_bias(q_seqlens, history_lens, neg_val, causal)
    bias = bias[:, None] + slope_diff
    return bias


def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,
                            block_sparse_size: int):
    """Make block sparse bias."""
    batch_size = q_seqlens.shape[0]
    kv_seqlens = q_seqlens + history_lens
    max_seq_len = q_seqlens.max().item()
    max_kv_len = kv_seqlens.max().item()

    seq_ranges = torch.arange(max_seq_len).cuda()
    seq_ranges = seq_ranges // block_sparse_size * block_sparse_size
    seq_ranges = seq_ranges.repeat(batch_size, 1)
    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

    kv_ranges = torch.arange(max_kv_len).cuda()
    kv_ranges = kv_ranges // block_sparse_size * block_sparse_size
    kv_ranges = kv_ranges.repeat(batch_size, 1)

    mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
    return mask.float() * neg_val


def _naive_attention(batched_q, batched_kv, bias, sinks=None):
    batched_k, batched_v = batched_kv

    num_heads_q = batched_q.shape[2]
    num_heads_k = batched_k.shape[2]
    head_dim = batched_q.shape[-1]
    group = num_heads_q // num_heads_k

    q = batched_q.transpose(1, 2)
    k = batched_k.permute(0, 2, 3, 1)
    v = batched_v.transpose(1, 2)

    # expand group
    k = k.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)
    v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)

    qk = torch.matmul(q, k) / math.sqrt(head_dim)
    if bias.dim() == 3:
        bias = bias[:, None]
    attn_weight = qk + bias
    if sinks is None:
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    else:
        sinks = sinks[None, :, None, None].to(torch.float32)
        sinks = sinks.expand(attn_weight.shape[0], -1, attn_weight.shape[2], -1)
        attn_weight = attn_weight.to(torch.float32)
        combined_logits = torch.cat([attn_weight, sinks], dim=-1)
        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
        attn_weight = torch.softmax(combined_logits, dim=-1, dtype=torch.float32)
        attn_weight = attn_weight[..., :-1]
    attn_weight = attn_weight.to(q.dtype)
    attn_output = torch.matmul(attn_weight, v)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output


def _naive_window_attention(q, k, v, seqlens_q, seqlens_k, window_size):
    """Reference causal sliding-window attention computed by an external flash-attn.

    It first tries lmdeploy's bundled third-party interface, then the
    ``flash_attn`` package. If neither can be imported, the calling test is
    skipped. ``q``/``k``/``v`` are packed (token-major) and ``seqlens_*`` hold
    per-sequence lengths.
    """
    try:
        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
    except Exception:
        try:
            from flash_attn import flash_attn_varlen_func
        except Exception:
            pytest.skip('Skip window attention test since flash attention is not available.')

    def _prefix_offsets(lengths):
        # [0, l0, l0 + l1, ...]: start offset of every sequence plus the total.
        totals = lengths.cumsum(0)
        return torch.cat([totals.new_zeros(1), totals])

    longest_q = seqlens_q.max().item()
    longest_k = seqlens_k.max().item()
    offsets_q = _prefix_offsets(seqlens_q).int()
    offsets_k = _prefix_offsets(seqlens_k).int()

    return flash_attn_varlen_func(q,
                                  k,
                                  v,
                                  offsets_q,
                                  offsets_k,
                                  max_seqlen_q=longest_q,
                                  max_seqlen_k=longest_k,
                                  causal=True,
                                  window_size=window_size)


class TestFlashAttention:
    """Compare ``flash_attn_varlen_func`` against naive reference attention.

    Inputs are generated padded (``batched_*``) for the references and packed
    (``conti_*``) for the kernel. Covered variants: plain causal/non-causal,
    sliding window, attention sinks, block-sparse masking and ALiBi.
    """

    @pytest.fixture
    def dtype(self):
        # Element dtype of the random q/k/v inputs.
        yield torch.float16

    # The fixtures below are parametrized indirectly by each test (``request.param``).
    @pytest.fixture
    def head_dim_k(self, request):
        yield request.param

    @pytest.fixture
    def head_dim_v(self, request):
        yield request.param

    @pytest.fixture
    def num_heads_q(self, request):
        yield request.param

    @pytest.fixture
    def num_heads_k(self, request):
        yield request.param

    @pytest.fixture
    def causal(self, request):
        yield request.param

    @pytest.fixture
    def q_seqlens(self, request):
        # Per-sequence number of new query tokens, on CUDA.
        yield torch.tensor(request.param, device='cuda')

    @pytest.fixture
    def cu_seqlens_q(self, q_seqlens):
        # int32 offsets [0, l0, l0+l1, ...] into the packed query tensor.
        cu_seqlens = q_seqlens.cumsum(0)
        cu_zero = cu_seqlens.new_zeros(1)
        yield torch.cat([cu_zero, cu_seqlens]).int()

    @pytest.fixture
    def history_lens(self, request):
        # Per-sequence number of already-cached tokens, on CUDA.
        yield torch.tensor(request.param, device='cuda')

    @pytest.fixture
    def kv_seqlens(self, q_seqlens, history_lens):
        # Key/value length = new query tokens + cached history.
        yield q_seqlens + history_lens

    @pytest.fixture
    def cu_seqlens_k(self, kv_seqlens):
        # int32 offsets [0, l0, l0+l1, ...] into the packed key/value tensors.
        cu_seqlens = kv_seqlens.cumsum(0)
        cu_zero = cu_seqlens.new_zeros(1)
        yield torch.cat([cu_zero, cu_seqlens]).int()

    @pytest.fixture
    def batched_q(self, q_seqlens, num_heads_q, head_dim_k, dtype):
        # Padded queries: (batch, max_q_len, num_heads_q, head_dim_k).
        torch.manual_seed(123)
        batch_size = len(q_seqlens)
        max_seq_len = q_seqlens.max().item()
        inputs = torch.rand(batch_size, max_seq_len, num_heads_q, head_dim_k, dtype=dtype, device='cuda')
        yield inputs

    @pytest.fixture
    def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k, head_dim_v, dtype):
        # Padded keys/values: (batch, max_kv_len, num_heads_k, head_dim_k / head_dim_v).
        # NOTE(review): re-seeded with the same seed as batched_q.
        torch.manual_seed(123)
        batch_size = len(q_seqlens)
        kv_seqlens = q_seqlens + history_lens
        max_seq_len = kv_seqlens.max().item()
        k = torch.rand(batch_size, max_seq_len, num_heads_k, head_dim_k, dtype=dtype, device='cuda')
        v = torch.rand(batch_size, max_seq_len, num_heads_k, head_dim_v, dtype=dtype, device='cuda')
        yield k, v

    @pytest.fixture
    def conti_q(self, q_seqlens, batched_q):
        # Packed queries: (total_q_tokens, num_heads_q, head_dim_k).
        yield _conti_input(batched_q, q_seqlens)

    @pytest.fixture
    def conti_kv(self, kv_seqlens, batched_kv):
        # Packed keys/values transposed to head-major: (num_heads_k, total_kv_tokens, head_dim).
        conti_k = _conti_input(batched_kv[0], kv_seqlens)
        conti_k = conti_k.transpose(0, 1).contiguous()
        conti_v = _conti_input(batched_kv[1], kv_seqlens)
        conti_v = conti_v.transpose(0, 1).contiguous()
        yield (conti_k, conti_v)

    @pytest.fixture
    def mask(self, q_seqlens, history_lens, causal):
        # Additive bias: 0 where attention is allowed, neg_val where it is masked.
        neg_val = -1e30
        yield _make_bias(q_seqlens, history_lens, neg_val, causal)

    @pytest.fixture
    def gt(self, batched_q, batched_kv, mask):
        yield _naive_attention(batched_q, batched_kv, mask)

    @pytest.fixture
    def conti_gt(self, gt, q_seqlens):
        # Reference output packed the same way as the kernel output.
        yield _conti_input(gt, q_seqlens)

    @pytest.mark.parametrize('head_dim_k', [32, 48], indirect=True)
    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)
    @pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True)
    @pytest.mark.parametrize('num_heads_k', [2], indirect=True)
    @pytest.mark.parametrize('causal', [True, False], indirect=True)
    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True)
    def test_flash_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, conti_gt):
        """Plain (causal / non-causal) varlen attention, including GQA and head_dim_k != head_dim_v."""
        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func
        max_seq_len = q_seqlens.max().item()

        conti_k, conti_v = conti_kv
        out = flash_attn_varlen_func(conti_q,
                                     conti_k,
                                     conti_v,
                                     cu_seqlens_q,
                                     cu_seqlens_k,
                                     max_seqlen_q=max_seq_len,
                                     causal=causal)
        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)

    @pytest.fixture
    def win_size(self, request):
        yield request.param

    @pytest.fixture
    def window_gt(self, conti_q, conti_kv, q_seqlens, kv_seqlens, win_size):
        # External flash-attn reference (the test is skipped if unavailable); k/v are
        # transposed back to token-major layout for that API.
        conti_k, conti_v = conti_kv
        yield _naive_window_attention(conti_q,
                                      conti_k.transpose(0, 1),
                                      conti_v.transpose(0, 1),
                                      q_seqlens,
                                      kv_seqlens,
                                      window_size=(win_size, win_size))

    @pytest.mark.parametrize('head_dim_k', [16], indirect=True)
    @pytest.mark.parametrize('head_dim_v', [16], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)
    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [
        ([30, 50, 70, 90], [50, 40, 30, 90]),
    ], indirect=True)
    @pytest.mark.parametrize('win_size', (32, ), indirect=True)
    def test_window_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, win_size, window_gt):
        """Causal sliding-window attention (kernel takes an int window; the reference takes a (left, right) pair)."""
        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func
        max_seq_len = q_seqlens.max().item()

        conti_k, conti_v = conti_kv
        out = flash_attn_varlen_func(conti_q,
                                     conti_k,
                                     conti_v,
                                     cu_seqlens_q,
                                     cu_seqlens_k,
                                     max_seqlen_q=max_seq_len,
                                     window_size=win_size,
                                     causal=True)
        torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)

    @pytest.fixture
    def sinks(self, num_heads_q, dtype):
        # One random sink logit per query head.
        yield torch.rand(num_heads_q, dtype=dtype, device='cuda')

    @pytest.fixture
    def sink_gt(self, batched_q, batched_kv, mask, sinks):
        yield _naive_attention(batched_q, batched_kv, mask, sinks)

    @pytest.fixture
    def conti_sink_gt(self, sink_gt, q_seqlens):
        yield _conti_input(sink_gt, q_seqlens)

    @pytest.mark.parametrize('head_dim_k', [32], indirect=True)
    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)
    @pytest.mark.parametrize('num_heads_q', [8], indirect=True)
    @pytest.mark.parametrize('num_heads_k', [2], indirect=True)
    @pytest.mark.parametrize('causal', [True], indirect=True)
    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True)
    def test_sinks(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, sinks, conti_sink_gt):
        """Causal attention with per-head sink logits."""
        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func
        max_seq_len = q_seqlens.max().item()

        conti_k, conti_v = conti_kv
        out = flash_attn_varlen_func(conti_q,
                                     conti_k,
                                     conti_v,
                                     cu_seqlens_q,
                                     cu_seqlens_k,
                                     max_seqlen_q=max_seq_len,
                                     sinks=sinks,
                                     causal=causal)
        torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5)

    # block sparse attention
    @pytest.fixture
    def block_sparse_size(self):
        yield 4

    @pytest.fixture
    def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size):
        neg_val = -1e30
        yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size)

    @pytest.fixture
    def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask):
        yield _naive_attention(batched_q, batched_kv, block_sparse_mask)

    @pytest.mark.parametrize('head_dim_k', [32], indirect=True)
    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)
    @pytest.mark.parametrize('num_heads_q', [8], indirect=True)
    @pytest.mark.parametrize('num_heads_k', [2], indirect=True)
    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True)
    def test_block_sparse_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, block_sparse_size,
                                    block_sparse_gt):
        """Causal attention with the mask decided per ``block_sparse_size`` block."""
        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func
        max_seq_len = q_seqlens.max().item()

        conti_k, conti_v = conti_kv
        out = flash_attn_varlen_func(conti_q,
                                     conti_k,
                                     conti_v,
                                     cu_seqlens_q,
                                     cu_seqlens_k,
                                     max_seqlen_q=max_seq_len,
                                     block_sparse_size=block_sparse_size,
                                     causal=True)
        gt = _conti_input(block_sparse_gt, q_seqlens)
        torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5)

    @pytest.fixture
    def alibi_slopes(self, num_heads_q):
        # One random float32 slope per query head.
        yield torch.rand(num_heads_q, dtype=torch.float32, device='cuda')

    @pytest.fixture
    def alibi_bias(self, q_seqlens, history_lens, causal, alibi_slopes):
        neg_val = -1e30
        yield _make_bias_alibi(q_seqlens, history_lens, neg_val, causal, alibi_slopes)

    @pytest.fixture
    def alibi_gt(self, batched_q, batched_kv, alibi_bias):
        yield _naive_attention(batched_q, batched_kv, alibi_bias)

    @pytest.fixture
    def conti_alibi_gt(self, alibi_gt, q_seqlens):
        yield _conti_input(alibi_gt, q_seqlens)

    @pytest.mark.parametrize('head_dim_k', [128], indirect=True)
    @pytest.mark.parametrize('head_dim_v', [128], indirect=True)
    @pytest.mark.parametrize('num_heads_q', [40], indirect=True)
    @pytest.mark.parametrize('num_heads_k', [8], indirect=True)
    @pytest.mark.parametrize('causal', [True], indirect=True)
    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [
        ([30, 50, 70, 90], [50, 40, 30, 20]),
    ], indirect=True)
    def test_alibi(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, alibi_slopes,
                   conti_alibi_gt):
        """Causal attention with per-head ALiBi slopes."""
        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func
        max_seq_len = q_seqlens.max().item()

        conti_k, conti_v = conti_kv
        out = flash_attn_varlen_func(conti_q,
                                     conti_k,
                                     conti_v,
                                     cu_seqlens_q,
                                     cu_seqlens_k,
                                     max_seqlen_q=max_seq_len,
                                     alibi_slopes=alibi_slopes,
                                     causal=causal)
        torch.testing.assert_close(out, conti_alibi_gt, atol=1e-3, rtol=1e-5)


================================================
FILE: tests/pytorch/kernel/test_flatten_kv_cache.py
================================================
import pytest
import torch


def _div_up(a, b):
    return (a + b - 1) // b


class TestFlattenKVCache:
    """Check ``flatten_kv_cache`` against a Python gather over the block table.

    Subclasses override fixtures (dtype, head layout, quantization) and reuse
    the block-table setup defined here.
    """

    @pytest.fixture
    def out_dtype(self):
        yield torch.float16

    @pytest.fixture
    def num_heads(self):
        yield 4

    @pytest.fixture
    def head_dim(self):
        yield 32

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def kv_lens(self):
        # Per-sequence cached lengths; chosen to exercise partial and exactly-full last blocks.
        yield [2, 24, 47, 48]

    @pytest.fixture
    def batch_size(self, kv_lens):
        yield len(kv_lens)

    @pytest.fixture
    def num_blocks_per_input(self, kv_lens, block_size):
        yield [_div_up(kv_len, block_size) for kv_len in kv_lens]

    @pytest.fixture
    def max_num_blocks(self, num_blocks_per_input):
        yield max(num_blocks_per_input)

    @pytest.fixture
    def out_size(self, kv_lens):
        # Total number of tokens in the flattened output.
        yield sum(kv_lens)

    @pytest.fixture
    def kv_seqlens(self, kv_lens):
        yield torch.tensor(kv_lens).cuda()

    @pytest.fixture
    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, out_dtype):
        # Paged cache: (num_physical_blocks, block_size, num_heads, head_dim).
        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
        yield torch.rand(shape, dtype=out_dtype, device='cuda')

    @pytest.fixture
    def v_caches(self, k_caches):
        yield torch.rand_like(k_caches)

    @pytest.fixture
    def block_offsets(self, num_blocks_per_input):
        # Block j of sequence i lives at physical block i + j * batch_size,
        # so the blocks of different sequences are interleaved.
        batch_size = len(num_blocks_per_input)
        max_num_blocks = max(num_blocks_per_input)
        batch_ids = torch.arange(batch_size)
        ret = torch.arange(max_num_blocks)
        ret = batch_ids[:, None] + ret[None, :] * batch_size
        yield ret.cuda()

    @pytest.fixture
    def gt(self, k_caches, v_caches, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim):
        """Reference: walk each sequence's blocks in order and copy its tokens.

        Tokens of all sequences are laid out back-to-back as
        (num_heads, out_size, head_dim).
        """
        k_states = k_caches.new_empty(num_heads, out_size, head_dim)
        v_states = v_caches.new_empty(num_heads, out_size, head_dim)
        start_loc = 0
        for kv_len, block_offs in zip(kv_lens, block_offsets):
            remain_len = kv_len
            for idx, _ in enumerate(range(0, kv_len, block_size)):
                b_off = block_offs[idx]
                # The last block of a sequence may be only partially filled.
                block_len = min(block_size, remain_len)
                end_loc = start_loc + block_len
                k_block = k_caches[b_off, :block_len]
                v_block = v_caches[b_off, :block_len]
                k_states[:, start_loc:end_loc] = k_block.transpose(0, 1)
                v_states[:, start_loc:end_loc] = v_block.transpose(0, 1)
                start_loc = end_loc
                remain_len -= block_len

        yield k_states, v_states

    def test_flatten_kv_cache(self, k_caches, v_caches, kv_seqlens, block_offsets, out_size, gt):
        """Unquantized flatten must match the reference exactly (default tolerances)."""
        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache

        k_states, v_states = flatten_kv_cache(k_caches, v_caches, kv_seqlens, block_offsets, out_size=out_size)
        torch.testing.assert_close(k_states, gt[0])
        torch.testing.assert_close(v_states, gt[1])


def precise_round(x: torch.Tensor):
    """Round to the nearest integer, with ties going away from zero.

    ``torch.round`` would round ties to even instead.
    """
    magnitude = torch.floor(torch.abs(x) + 0.5)
    return torch.sign(x) * magnitude


def quant(kv: torch.Tensor, nbits: int = 8):
    """Quant kv on the head_dim.

    Asymmetric min/max quantization over the last dimension to ``nbits``
    levels, with rounding by ``+0.5`` then truncation to uint8. For
    ``nbits == 4``, the first half of the last dim is stored in the low
    nibble and the second half in the high nibble.

    Returns:
        ``(q_kv, scales_zeros)``, where ``scales_zeros`` concatenates the
        per-row scale and zero point along the last dim.
    """
    lo = kv.amin(dim=-1, keepdim=True)
    hi = kv.amax(dim=-1, keepdim=True)
    scales = (hi - lo) / (2**nbits - 1)
    zeros = -lo / scales
    packed = (kv / scales + zeros + 0.5).to(torch.uint8)
    if nbits == 4:
        low_half, high_half = packed.split(packed.shape[-1] // 2, -1)
        packed = low_half + high_half * 16
    return packed, torch.cat([scales, zeros], dim=-1)


class TestFlattenKVCacheQuant8(TestFlattenKVCache):
    """8-bit quantized variant of ``TestFlattenKVCache``.

    The caches are quantized with ``quant`` and the kernel dequantizes while
    flattening. The result is compared with the parent's unquantized ``gt``
    within ``atol`` / ``rtol``.
    """

    @pytest.fixture
    def nbits(self):
        # Forwarded both to ``quant`` and to the kernel as ``quant_policy``.
        yield 8

    @pytest.fixture
    def atol(self):
        yield 4e-3

    @pytest.fixture
    def rtol(self):
        yield 1e-5

    @pytest.fixture
    def k_quant(self, k_caches, nbits):
        # (quantized cache, scales/zeros) for the keys.
        yield quant(k_caches, nbits)

    @pytest.fixture
    def v_quant(self, v_caches, nbits):
        yield quant(v_caches, nbits)

    def test_flatten_kv_cache(self, k_quant, v_quant, kv_seqlens, block_offsets, out_size, out_dtype, nbits, gt, atol,
                              rtol):
        """Flatten + dequantize and compare with the float reference."""
        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache

        k_caches, k_sz = k_quant
        v_caches, v_sz = v_quant

        # Scales/zeros are handed to the kernel in the output dtype.
        k_sz = k_sz.to(out_dtype)
        v_sz = v_sz.to(out_dtype)

        k_states, v_states = flatten_kv_cache(k_caches,
                                              v_caches,
                                              kv_seqlens,
                                              block_offsets,
                                              out_size=out_size,
                                              out_dtype=out_dtype,
                                              k_scales_zeros=k_sz,
                                              v_scales_zeros=v_sz,
                                              quant_policy=nbits)

        torch.testing.assert_close(k_states, gt[0], atol=atol, rtol=rtol)
        torch.testing.assert_close(v_states, gt[1], atol=atol, rtol=rtol)


class TestFlattenKVCacheQuant4(TestFlattenKVCacheQuant8):
    """4-bit variant: the Quant8 test rerun with ``nbits=4`` and looser tolerances."""

    @pytest.fixture
    def nbits(self):
        yield 4

    @pytest.fixture
    def atol(self):
        # Looser than the 8-bit case, since there are only 16 quantization levels.
        yield 0.05

    @pytest.fixture
    def rtol(self):
        yield 1e-3


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestFlattenKVCacheMLAFP8(TestFlattenKVCache):
    """MLA FP8 variant: a single packed key cache per token.

    Each token stores ``[fp8 nope part (512) | float32 scales (16 bytes) |
    bf16 rope part (64 values)]``, all viewed as fp8 bytes. Only keys are
    flattened here.
    """

    @pytest.fixture
    def out_dtype(self):
        yield torch.bfloat16

    @pytest.fixture
    def num_heads(self):
        yield 1

    @pytest.fixture
    def head_dim(self):
        # 512 quantized "nope" channels + 64 unquantized rope channels.
        yield 576

    @pytest.fixture
    def block_size(self):
        yield 64

    @pytest.fixture
    def k_cache_mla(self, k_caches):
        """Pack ``k_caches`` into the fp8 MLA layout described on the class."""
        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
        num_blocks, block_size, num_heads, _ = k_caches.shape
        k_cache_pe = k_caches[:, :, :, 512:]
        k_cache_nope = k_caches[:, :, :, :512].flatten(0, -2)
        # Quantize the first 512 channels with one scale per 128 channels.
        k_cache_nope, k_cache_scale = quant_fp8(k_cache_nope, group_size=128)
        k_cache_nope = k_cache_nope.view(num_blocks, block_size, num_heads, -1)
        k_cache_scale = k_cache_scale.reshape(num_blocks, block_size, num_heads, -1).to(torch.float32)
        dtype = k_cache_nope.dtype
        # Reinterpret the scales and the rope part as raw fp8-sized bytes so all three can be concatenated.
        out = torch.cat([k_cache_nope, k_cache_scale.view(dtype), k_cache_pe.view(dtype)], dim=-1)
        yield out

    def _dequant(self, k_cache_mla):
        """Inverse of ``k_cache_mla``: rebuild a bf16 (..., 576) key cache."""
        k_cache_nope = k_cache_mla[..., :512].to(torch.float32)
        # Bytes 512..527 hold 4 float32 scales (one per 128-channel group).
        k_cache_scale = k_cache_mla[..., 512:512 + 16].view(torch.float32)
        # Remaining bytes are the bf16 rope channels.
        k_cache_pe = k_cache_mla[..., 512 + 16:].view(torch.bfloat16)
        k_cache_nope = k_cache_nope.unflatten(-1, (-1, 128))
        k_cache_scale = k_cache_scale[..., None]
        k_cache_nope *= k_cache_scale
        k_cache_nope = k_cache_nope.flatten(-2, -1).to(k_cache_pe.dtype)
        k_cache = torch.cat([k_cache_nope, k_cache_pe], dim=-1)
        return k_cache

    @pytest.fixture
    def gt(self, k_cache_mla, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim):
        """Reference: dequantize the packed cache, then gather it like the parent ``gt`` (keys only)."""
        k_caches = self._dequant(k_cache_mla)
        k_states = k_caches.new_empty(num_heads, out_size, head_dim)
        start_loc = 0
        for kv_len, block_offs in zip(kv_lens, block_offsets):
            remain_len = kv_len
            for idx, _ in enumerate(range(0, kv_len, block_size)):
                b_off = block_offs[idx]
                block_len = min(block_size, remain_len)
                end_loc = start_loc + block_len
                k_block = k_caches[b_off, :block_len]
                k_states[:, start_loc:end_loc] = k_block.transpose(0, 1)
                start_loc = end_loc
                remain_len -= block_len

        yield k_states

    def test_flatten_kv_cache(self, k_cache_mla, kv_seqlens, block_offsets, out_size, out_dtype, gt):
        """The kernel's flatten + dequant must match the Python dequant + gather."""
        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8

        k_states = flatten_kv_cache_mla_fp8(k_cache_mla,
                                            kv_seqlens,
                                            block_offsets,
                                            out_size=out_size,
                                            out_dtype=out_dtype)
        torch.testing.assert_close(k_states, gt)


================================================
FILE: tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py
================================================
import pytest
import torch


def _make_A(M, K, group_size, out_dtype, device='cuda'):
    quant_A = torch.rand(M, K // group_size, group_size, dtype=torch.float32, device=device)
    # -1 ~ 1
    quant_A = quant_A * 2 - 1
    # scaling abs max to fmax
    finfo = torch.finfo(out_dtype)
    fmax = finfo.max
    scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
    quant_A *= scaling
    quant_A = quant_A.to(out_dtype).to(torch.float32)

    # create scale and A
    scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device)
    scale /= fmax
    A = quant_A * scale[..., None]

    A = A.reshape(M, K)
    quant_A = quant_A.reshape(M, K).to(out_dtype)
    return A, quant_A, scale


def _make_B(E, K, N, group_size, out_dtype, device='cuda'):
    quant_B = torch.rand(E,
                         N // group_size,
                         group_size,
                         K // group_size,
                         group_size,
                         dtype=torch.float32,
                         device=device)
    quant_B = quant_B * 2 - 1

    # scaling abs max to fmax
    finfo = torch.finfo(out_dtype)
    fmax = finfo.max
    scaling = fmax / quant_B.abs().amax((2, 4), keepdim=True)
    quant_B *= scaling
    quant_B = quant_B.to(out_dtype).to(torch.float32)

    scale = torch.rand(E, N // group_size, 1, K // group_size, 1, dtype=torch.float32, device=device)
    scale /= fmax

    B = quant_B * scale

    B = B.reshape(E, N, K)
    quant_B = quant_B.reshape(E, N, K).to(out_dtype)
    scale = scale.reshape(E, N // group_size, K // group_size)
    bias = torch.rand(E, N, dtype=torch.float32, device=device) - 0.5
    return B, quant_B, scale, bias


def _get_sorted_idx(topk_idx: torch.Tensor, num_experts: int):
    flatten_topk_idx = topk_idx.flatten()
    sorted_ids = flatten_topk_idx.argsort()
    exp_range = torch.arange(0, num_experts, device=topk_idx.device)
    exp_tok_cnt = (flatten_topk_idx[None, :] == exp_range[:, None]).sum(1)
    return sorted_ids, exp_tok_cnt


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestFusedMoEFP8KernelLauncher:
    """Compare the blocked-FP8 MoE launcher with the fp16 ``fused_moe_kernel_launcher``.

    Operands are generated by ``_make_A`` / ``_make_B``. The reference ``gt`` runs the
    non-quantized launcher on the dequantized tensors cast to ``dtype``. The FP8
    launcher gets the quantized tensors plus their block scales.
    """

    @pytest.fixture
    def dtype(self):
        return torch.float16

    @pytest.fixture
    def quant_dtype(self):
        return torch.float8_e4m3fn

    @pytest.fixture
    def device(self):
        return torch.device('cuda')

    @pytest.fixture
    def N(self):
        return 512

    @pytest.fixture
    def K(self):
        return 1024

    @pytest.fixture
    def M(self):
        return 256

    @pytest.fixture
    def num_experts(self):
        return 64

    @pytest.fixture
    def top_k(self):
        return 6

    @pytest.fixture
    def group_size(self):
        return 128

    # activations: (dequantized, quantized, scale)

    @pytest.fixture
    def build_A(self, M, K, group_size, quant_dtype, device):
        return _make_A(M, K, group_size=group_size, out_dtype=quant_dtype, device=device)

    @pytest.fixture
    def A(self, build_A, dtype):
        dequant, _, _ = build_A
        return dequant.to(dtype)

    @pytest.fixture
    def A_quant(self, build_A):
        _, quant, _ = build_A
        return quant

    @pytest.fixture
    def A_scale(self, build_A):
        _, _, scale = build_A
        return scale

    # weights: (dequantized, quantized, scale, bias)

    @pytest.fixture
    def build_B(self, num_experts, N, K, group_size, quant_dtype, device):
        return _make_B(num_experts, K, N, group_size=group_size, out_dtype=quant_dtype, device=device)

    @pytest.fixture
    def B(self, build_B, dtype):
        return build_B[0].to(dtype)

    @pytest.fixture
    def B_quant(self, build_B):
        return build_B[1]

    @pytest.fixture
    def B_scale(self, build_B):
        return build_B[2]

    @pytest.fixture
    def bias(self, build_B, dtype):
        # return None here instead to exercise both launchers without a bias term
        return build_B[3].to(dtype)

    # routing

    @pytest.fixture
    def router_weights(self, M, num_experts, device, dtype):
        return torch.rand(M, num_experts, device=device, dtype=dtype)

    @pytest.fixture
    def topk_weights(self, router_weights, top_k):
        # (values, indices) pair from topk
        return router_weights.topk(top_k, dim=-1)

    @pytest.fixture
    def topk_idx(self, topk_weights):
        return topk_weights.indices

    @pytest.fixture
    def sort_and_cnt(self, topk_idx, num_experts):
        return _get_sorted_idx(topk_idx, num_experts)

    @pytest.fixture
    def sorted_idx(self, sort_and_cnt):
        order, _ = sort_and_cnt
        return order

    @pytest.fixture
    def exp_tok_cnt(self, sort_and_cnt):
        _, counts = sort_and_cnt
        return counts

    @pytest.fixture
    def exp_end(self, exp_tok_cnt):
        return exp_tok_cnt.cumsum(0)

    @pytest.fixture
    def exp_start(self, exp_end, exp_tok_cnt):
        # exclusive prefix sum of per-expert token counts
        return exp_end - exp_tok_cnt

    @pytest.fixture
    def gt(self, A, B, bias, top_k, sorted_idx, exp_start, exp_end, M):
        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe_kernel_launcher
        ref_out = B.new_empty(M * top_k, B.size(1))
        fused_moe_kernel_launcher(A,
                                  B,
                                  ref_out,
                                  sorted_idx,
                                  exp_start,
                                  exp_end,
                                  bias=bias,
                                  top_k=top_k,
                                  num_tokens=M)
        return ref_out

    @torch.inference_mode()
    def test_launcher(self, A_quant, A_scale, B, B_quant, B_scale, bias, sorted_idx, exp_start, exp_end, top_k, M, gt):
        from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8_kernel_launcher
        out = B.new_empty(M * top_k, B.size(1))
        fused_moe_blocked_fp8_kernel_launcher(A=A_quant,
                                              A_scale=A_scale,
                                              B=B_quant,
                                              B_scale=B_scale,
                                              C=out,
                                              sorted_idx=sorted_idx,
                                              exp_start=exp_start,
                                              exp_end=exp_end,
                                              bias=bias,
                                              top_k=top_k,
                                              num_tokens=M)

        # compare after normalizing both by the reference's peak magnitude
        peak = gt.abs().max()
        torch.testing.assert_close(out / peak, gt / peak, atol=4e-3, rtol=1e-3)


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestFusedMoeBlockedFP8:
    """End-to-end check of ``fused_moe_blocked_fp8`` against fp16 ``fused_moe``.

    Hidden states and both expert weights are block-quantized with ``_make_A`` /
    ``_make_B``. ``w1`` is built as ``in_size -> hidden_size`` and ``w2`` as
    ``hidden_size // 2 -> out_size``. The reference ``gt`` runs ``fused_moe`` on the
    dequantized tensors cast to ``dtype``. The test first checks that the two outputs'
    peak magnitudes agree within 5%, then compares the peak-normalized outputs.
    """

    @pytest.fixture
    def dtype(self):
        yield torch.float16

    @pytest.fixture
    def quant_dtype(self):
        yield torch.float8_e4m3fn

    @pytest.fixture
    def device(self):
        yield torch.device('cuda')

    @pytest.fixture
    def in_size(self):
        yield 512

    @pytest.fixture
    def seq_len(self):
        # first parameter was previously misnamed ``seq_len``; it is the bound instance
        yield 128

    @pytest.fixture
    def hidden_size(self):
        yield 2048

    @pytest.fixture
    def out_size(self):
        yield 1024

    @pytest.fixture
    def num_experts(self):
        yield 4

    @pytest.fixture
    def top_k(self):
        yield 2

    @pytest.fixture
    def group_size(self):
        yield 128

    @pytest.fixture
    def renormalize(self):
        yield True

    @pytest.fixture
    def build_hidden_states(self, seq_len, in_size, group_size, quant_dtype, device):
        # (dequantized, quantized, scale)
        yield _make_A(seq_len, in_size, group_size=group_size, out_dtype=quant_dtype, device=device)

    @pytest.fixture
    def hidden_states(self, build_hidden_states, dtype):
        yield build_hidden_states[0].to(dtype)

    @pytest.fixture
    def states_quanted(self, build_hidden_states):
        yield build_hidden_states[1]

    @pytest.fixture
    def states_scale(self, build_hidden_states):
        yield build_hidden_states[2]

    @pytest.fixture
    def build_w1(self, num_experts, hidden_size, in_size, group_size, quant_dtype, device):
        # (dequantized, quantized, scale, bias); bias is unused by this test
        yield _make_B(num_experts, in_size, hidden_size, group_size=group_size, out_dtype=quant_dtype, device=device)

    @pytest.fixture
    def w1(self, build_w1, dtype):
        yield build_w1[0].to(dtype)

    @pytest.fixture
    def w1_quant(self, build_w1):
        yield build_w1[1]

    @pytest.fixture
    def w1_scale(self, build_w1):
        yield build_w1[2]

    @pytest.fixture
    def build_w2(self, num_experts, out_size, hidden_size, group_size, quant_dtype, device):
        # w2 consumes hidden_size // 2 features
        yield _make_B(num_experts,
                      hidden_size // 2,
                      out_size,
                      group_size=group_size,
                      out_dtype=quant_dtype,
                      device=device)

    @pytest.fixture
    def w2(self, build_w2, dtype):
        yield build_w2[0].to(dtype)

    @pytest.fixture
    def w2_quant(self, build_w2):
        yield build_w2[1]

    @pytest.fixture
    def w2_scale(self, build_w2):
        yield build_w2[2]

    @pytest.fixture
    def router_logits(self, seq_len, num_experts, dtype, device):
        yield torch.rand(seq_len, num_experts, dtype=dtype, device=device)

    @pytest.fixture
    def topk_logits(self, router_logits, top_k):
        routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        yield torch.topk(routing_weights, top_k, dim=-1)

    @pytest.fixture
    def topk_weights(self, topk_logits):
        yield topk_logits[0]

    @pytest.fixture
    def topk_idx(self, topk_logits):
        yield topk_logits[1]

    @pytest.fixture
    def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, renormalize):
        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe
        output = fused_moe(hidden_states, w1, w2, topk_weights, topk_idx, topk=top_k, renormalize=renormalize)
        yield output

    @torch.inference_mode()
    def test_fused_moe(self, states_quanted, states_scale, w1_quant, w1_scale, w2_quant, w2_scale, topk_weights,
                       topk_idx, top_k, renormalize, gt):
        from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
        output = fused_moe_blocked_fp8(states_quanted,
                                       states_scale,
                                       w1_quant,
                                       w1_scale,
                                       w2_quant,
                                       w2_scale,
                                       topk_weights,
                                       topk_idx,
                                       topk=top_k,
                                       renormalize=renormalize)
        # the overall magnitudes must agree before comparing shapes of the outputs
        out_max = output.abs().max()
        gt_max = gt.abs().max()
        assert (out_max - gt_max).abs() / out_max < 0.05

        norm_out = output / out_max
        norm_gt = gt / gt_max
        torch.testing.assert_close(norm_out, norm_gt, atol=0.05, rtol=1e-3)


================================================
FILE: tests/pytorch/kernel/test_fused_lora.py
================================================
import pytest
import torch

from lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora


class TestFusedLoRA:
    """Compare ``fused_lora`` with a per-sequence ``x @ A @ B * scaling`` reference.

    Two adapters (ranks 2 and 4) are packed. ``fused_lora_a`` concatenates the A
    matrices along the rank axis and transposes the result to ``(sum(ranks), head_size)``.
    ``fused_lora_b`` stacks the B matrices along dim 0. Sequences are assigned to
    adapters round-robin.
    """

    @pytest.fixture
    def dtype(self):
        return torch.float16

    @pytest.fixture
    def head_size(self):
        return 32

    @pytest.fixture
    def out_head_size(self):
        return 16

    @pytest.fixture
    def seq_lens(self, request):
        # parametrized indirectly by test_fused_lora
        return torch.tensor(request.param).cuda()

    @pytest.fixture
    def ranks(self):
        return torch.tensor([2, 4]).cuda()

    @pytest.fixture
    def start_loc(self, seq_lens):
        # exclusive prefix sum: offset of each sequence in the packed input
        return seq_lens.cumsum(0) - seq_lens

    @pytest.fixture
    def input(self, seq_lens, head_size, dtype):
        return torch.rand(seq_lens.sum(), head_size, dtype=dtype).cuda()

    @pytest.fixture
    def adapter_ids(self, seq_lens, ranks):
        seq_index = torch.arange(0, len(seq_lens))
        return (seq_index % len(ranks)).cuda()

    @pytest.fixture
    def scaling(self, ranks):
        return torch.arange(ranks.size(0)).cuda() + 1

    @pytest.fixture
    def lora_a(self, ranks, head_size, dtype):
        return [torch.rand(head_size, rank, dtype=dtype).cuda() for rank in ranks]

    @pytest.fixture
    def lora_b(self, ranks, out_head_size, dtype):
        return [torch.rand(rank, out_head_size, dtype=dtype).cuda() for rank in ranks]

    @pytest.fixture
    def fused_lora_a(self, lora_a):
        return torch.cat(lora_a, dim=1).t().contiguous()

    @pytest.fixture
    def fused_lora_b(self, lora_b):
        return torch.cat(lora_b, dim=0).contiguous()

    @pytest.fixture
    def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b, scaling):
        pieces = []
        for begin, length, adapter in zip(start_loc, seq_lens, adapter_ids):
            rows = input[begin:begin + length]
            pieces.append(rows @ lora_a[adapter] @ lora_b[adapter] * scaling[adapter])
        return torch.cat(pieces)

    @pytest.mark.parametrize('seq_lens', [
        (2, 4, 6, 8),
        (1, 1, 1, 1),
    ], indirect=True)
    def test_fused_lora(self, input, fused_lora_a, fused_lora_b, start_loc, seq_lens, adapter_ids, scaling, ranks, gt):
        # offset of each adapter's rows inside the packed A / B matrices
        rank_offset = ranks.cumsum(0) - ranks

        output = fused_lora(
            input,
            fused_lora_a,
            fused_lora_b,
            scaling=scaling,
            rank_start=rank_offset,
            ranks=ranks,
            seq_start=start_loc,
            seq_lens=seq_lens,
            adapter_ids=adapter_ids,
            max_rank=ranks.max().item(),
            max_seqlen=seq_lens.max().item(),
        )

        torch.testing.assert_close(gt, output)


================================================
FILE: tests/pytorch/kernel/test_fused_moe.py
================================================
import pytest
import torch
import torch.nn.functional as F


def _get_sorted_idx(topk_idx: torch.Tensor, num_experts: int):
    flatten_topk_idx = topk_idx.flatten()
    sorted_ids = flatten_topk_idx.argsort()
    exp_range = torch.arange(0, num_experts, device=topk_idx.device)
    exp_tok_cnt = (flatten_topk_idx[None, :] == exp_range[:, None]).sum(1)
    return sorted_ids, exp_tok_cnt


class TestFusedMoEKernelLauncher:
    """Compare ``fused_moe_kernel_launcher`` with a per-expert matmul reference.

    The reference computes ``A[token] @ B[expert].T + bias[expert]`` for every
    (token, slot) routed to ``expert`` and flattens the ``(M, top_k, N)`` result to
    ``(M * top_k, N)``.
    """

    @pytest.fixture
    def dtype(self):
        return torch.float16

    @pytest.fixture
    def device(self):
        return torch.device('cuda')

    @pytest.fixture
    def N(self):
        return 128

    @pytest.fixture
    def K(self):
        return 64

    @pytest.fixture
    def M(self):
        return 256

    @pytest.fixture
    def num_experts(self):
        return 64

    @pytest.fixture
    def top_k(self):
        return 6

    @pytest.fixture
    def A(self, M, K, device, dtype):
        # values in [-0.25, 0.25)
        return (torch.rand(M, K, device=device, dtype=dtype) - 0.5) / 2

    @pytest.fixture
    def B(self, num_experts, N, K, device, dtype):
        return (torch.rand(num_experts, N, K, device=device, dtype=dtype) - 0.5) / 2

    @pytest.fixture
    def bias(self, num_experts, N, device, dtype):
        return torch.rand(num_experts, N, device=device, dtype=dtype) - 0.5

    @pytest.fixture
    def router_weights(self, M, num_experts, device, dtype):
        return torch.rand(M, num_experts, device=device, dtype=dtype)

    @pytest.fixture
    def topk_weights(self, router_weights, top_k):
        # (values, indices) pair from topk
        return router_weights.topk(top_k, dim=-1)

    @pytest.fixture
    def topk_idx(self, topk_weights):
        return topk_weights.indices

    @pytest.fixture
    def sort_and_cnt(self, topk_idx, num_experts):
        return _get_sorted_idx(topk_idx, num_experts)

    @pytest.fixture
    def sorted_idx(self, sort_and_cnt):
        order, _ = sort_and_cnt
        return order

    @pytest.fixture
    def exp_tok_cnt(self, sort_and_cnt):
        _, counts = sort_and_cnt
        return counts

    @pytest.fixture
    def exp_end(self, exp_tok_cnt):
        return exp_tok_cnt.cumsum(0)

    @pytest.fixture
    def exp_start(self, exp_end, exp_tok_cnt):
        # exclusive prefix sum of per-expert token counts
        return exp_end - exp_tok_cnt

    @pytest.fixture
    def gt(self, A, B, bias, top_k, topk_idx):
        num_tokens, num_out, num_exp = A.size(0), B.size(1), B.size(0)
        ref = B.new_empty(num_tokens, top_k, num_out)
        for expert in range(num_exp):
            token_pos, slot_pos = torch.where(topk_idx == expert)
            if token_pos.numel() == 0:
                continue
            ref[token_pos, slot_pos] = A[token_pos] @ B[expert].t() + bias[expert]
        return ref.flatten(0, 1)

    @torch.inference_mode()
    def test_launcher(self, A, B, bias, sorted_idx, exp_start, exp_end, top_k, M, gt):
        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe_kernel_launcher
        out = B.new_empty(M * top_k, B.size(1))

        fused_moe_kernel_launcher(A,
                                  B,
                                  out,
                                  sorted_idx,
                                  exp_start,
                                  exp_end,
                                  bias=bias,
                                  top_k=top_k,
                                  num_tokens=M)
        torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-3)


def _mlp_forward(hidden_states, gate_proj, up_proj, down_proj):
    gate = F.linear(hidden_states, gate_proj)
    up = F.linear(hidden_states, up_proj)
    return F.linear(F.silu(gate) * up, down_proj)


class TestFusedMoe:
    """Compare ``fused_moe`` with a per-expert gated-MLP reference.

    ``w1`` holds the gate and up projections stacked along its output dim (the
    reference splits it with ``chunk(2, dim=0)``), and ``w2`` is the down projection
    taking ``hidden_size // 2`` inputs. Routing weights are softmax-then-top-k,
    optionally renormalized to sum to one per token.
    """

    @pytest.fixture
    def dtype(self):
        yield torch.float16

    @pytest.fixture
    def device(self):
        yield torch.device('cuda')

    @pytest.fixture
    def in_size(self):
        yield 128

    @pytest.fixture
    def seq_len(self):
        # first parameter was previously misnamed ``seq_len``; it is the bound instance
        yield 128

    @pytest.fixture
    def hidden_size(self):
        yield 256

    @pytest.fixture
    def out_size(self):
        yield 128

    @pytest.fixture
    def num_experts(self):
        yield 64

    @pytest.fixture
    def top_k(self):
        yield 6

    @pytest.fixture
    def renormalize(self):
        yield True

    @pytest.fixture
    def hidden_states(self, seq_len, in_size, dtype, device):
        ret = torch.rand(seq_len, in_size, dtype=dtype, device=device)
        yield (ret - 0.5) / 2

    @pytest.fixture
    def w1(self, num_experts, hidden_size, in_size, dtype, device):
        ret = torch.rand(num_experts, hidden_size, in_size, dtype=dtype, device=device)
        yield (ret - 0.5) / 2

    @pytest.fixture
    def w2(self, num_experts, out_size, hidden_size, dtype, device):
        ret = torch.rand(num_experts, out_size, hidden_size // 2, dtype=dtype, device=device)
        yield (ret - 0.5) / 2

    @pytest.fixture
    def router_logits(self, seq_len, num_experts, dtype, device):
        yield torch.rand(seq_len, num_experts, dtype=dtype, device=device)

    @pytest.fixture
    def topk_logits(self, router_logits, top_k):
        routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        yield torch.topk(routing_weights, top_k, dim=-1)

    @pytest.fixture
    def topk_weights(self, topk_logits):
        yield topk_logits[0]

    @pytest.fixture
    def topk_idx(self, topk_logits):
        yield topk_logits[1]

    @pytest.fixture
    def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, renormalize):
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        seq_len = hidden_states.size(0)
        out_size = w2.size(1)
        output = hidden_states.new_zeros(seq_len, out_size)
        num_experts = w1.size(0)
        for eid in range(num_experts):
            # every (token, slot) pair routed to this expert
            token_idx, k_idx = torch.where(topk_idx == eid)
            gate_proj, up_proj = w1[eid].chunk(2, dim=0)
            down_proj = w2[eid]
            tmp_out = _mlp_forward(hidden_states[token_idx], gate_proj, up_proj, down_proj)
            tmp_out = tmp_out * topk_weights[token_idx, k_idx, None]
            # accumulate weighted expert outputs back onto their tokens
            output.index_add_(0, token_idx, tmp_out.to(output.dtype))
        yield output

    @torch.inference_mode()
    def test_fused_moe(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, renormalize, gt):
        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe
        output = fused_moe(hidden_states, w1, w2, topk_weights, topk_idx, topk=top_k, renormalize=renormalize)
        torch.testing.assert_close(output, gt, atol=1e-3, rtol=1e-3)


class TestFusedMoeW8A8(TestFusedMoe):
    """Run ``fused_moe_w8a8`` on int8-quantized operands against the inherited fp16 ``gt``.

    Activations are quantized per token and weights per output channel. Every
    fixture not overridden here (shapes, routing, ``gt``) comes from ``TestFusedMoe``.
    """

    @pytest.fixture
    def quant_states(self, hidden_states):
        from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8
        q_states, q_scale = per_token_quant_int8(hidden_states, 1e-7)
        return q_states, q_scale

    def quant_weight(self, w):
        """Quantize a stacked expert weight per output channel.

        ``w`` is unpacked as ``(num_experts, num_outs, _)``. The int8 weight and its
        scales are both reshaped to ``(num_experts, num_outs, -1)``.
        """
        from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_channel_quant
        num_experts, num_outs, _ = w.shape
        # fold experts into the row axis so each output channel gets its own scale
        q_weight, q_scale = per_channel_quant(w.flatten(0, -2), torch.int8)
        return q_weight.view(num_experts, num_outs, -1), q_scale.view(num_experts, num_outs, -1)

    @pytest.fixture
    def quant_w1(self, w1):
        q_weight, q_scale = self.quant_weight(w1)
        return q_weight, q_scale

    @pytest.fixture
    def quant_w2(self, w2):
        q_weight, q_scale = self.quant_weight(w2)
        return q_weight, q_scale

    @torch.inference_mode()
    def test_fused_moe(self, quant_states, quant_w1, quant_w2, topk_weights, topk_idx, top_k, renormalize, gt):
        from lmdeploy.pytorch.kernels.cuda.w8a8_fused_moe import fused_moe_w8a8
        state_i8, state_scale = quant_states
        w1_i8, w1_scale = quant_w1
        w2_i8, w2_scale = quant_w2

        output = fused_moe_w8a8(state_i8,
                                state_scale,
                                w1_i8,
                                w1_scale,
                                w2_i8,
                                w2_scale,
                                topk_weights=topk_weights,
                                topk_ids=topk_idx,
                                topk=top_k,
                                out_dtype=torch.float16,
                                renormalize=renormalize)
        torch.testing.assert_close(output, gt, atol=5e-3, rtol=1e-3)


================================================
FILE: tests/pytorch/kernel/test_gated_delta_rule.py
================================================
import pytest
import torch


def do_test():
    """Return True only when ``tilelang`` imports cleanly and CUDA is available.

    Any exception raised by the import or the CUDA probe is treated as "unavailable".
    """
    try:
        import tilelang  # noqa: F401
        available = torch.cuda.is_available()
    except Exception:
        available = False
    return available


def naive_recurrent_gdr(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
    """Step-by-step float32 reference for the gated delta rule recurrence.

    ``q``/``k``/``v`` are laid out as ``(B, T, H, D)`` and ``beta``/``g`` as
    ``(B, T, H)``. They are transposed to put heads before time internally. The
    state ``S`` has shape ``(B, H, K, V)``. For each step ``t``::

        S   = S * exp(g_t)
        u   = beta_t * (v_t - k_t @ S)
        S   = S + outer(k_t, u)
        o_t = (scale * q_t) @ S

    ``scale`` defaults to ``1 / sqrt(K)``. Returns ``(o, S_final)`` cast to ``q``'s
    dtype; ``S_final`` is ``None`` unless ``output_final_state`` is set.
    """
    out_dtype = q.dtype
    if use_qk_l2norm_in_kernel:
        q = torch.nn.functional.normalize(q, p=2, dim=-1)
        k = torch.nn.functional.normalize(k, p=2, dim=-1)

    # move heads before time and compute in float32
    q, k, v, beta, g = (x.transpose(1, 2).contiguous().to(torch.float32) for x in (q, k, v, beta, g))
    batch, heads, steps, key_dim = k.shape
    val_dim = v.shape[-1]

    out = torch.zeros(batch, heads, steps, val_dim).to(v)
    if initial_state is None:
        state = torch.zeros(batch, heads, key_dim, val_dim).to(v)
    else:
        state = initial_state.to(torch.float32)
    if scale is None:
        scale = 1 / (q.shape[-1]**0.5)
    q = q * scale

    for t in range(steps):
        k_t = k[:, :, t]
        state = state * g[:, :, t].exp()[..., None, None]
        # delta-rule correction: subtract what the decayed state already predicts for k_t
        u_t = (v[:, :, t] - (state * k_t[..., None]).sum(-2)) * beta[:, :, t][..., None]
        state = state + k_t.unsqueeze(-1) * u_t.unsqueeze(-2)
        out[:, :, t] = torch.einsum('bhd,bhdm->bhm', q[:, :, t], state)

    out = out.transpose(1, 2).contiguous().to(out_dtype)
    final_state = state.to(out_dtype) if output_final_state else None
    return out, final_state


@pytest.mark.skipif(not do_test(), reason='tilelang is not available')
class TestRecurrentGatedDeltaRule:
    """Compare ``fused_recurrent_gated_delta_rule`` with ``naive_recurrent_gdr``.

    All tensors are created as bfloat16 on CUDA through the autouse context
    fixture. Both implementations start from a clone of the same initial state and
    return the final state, which is compared with a looser tolerance than the outputs.
    """

    @pytest.fixture(autouse=True)
    def auto_context(self):
        # switch the global default dtype/device for the test, then restore them
        saved_dtype = torch.get_default_dtype()
        saved_device = torch.get_default_device()
        with torch.inference_mode():
            torch.set_default_dtype(torch.bfloat16)
            torch.set_default_device('cuda')
            try:
                yield
            finally:
                torch.set_default_dtype(saved_dtype)
                torch.set_default_device(saved_device)

    @pytest.fixture
    def batch(self):
        return 512

    @pytest.fixture
    def num_heads(self):
        return 16

    @pytest.fixture
    def seqlen(self):
        return 1

    @pytest.fixture
    def head_dim(self):
        return 128

    @pytest.fixture(params=[True, False])
    def use_qk_l2norm_in_kernel(self, request):
        return request.param

    @pytest.fixture
    def q(self, batch, seqlen, num_heads, head_dim):
        return torch.rand(batch, seqlen, num_heads, head_dim) - 0.5

    @pytest.fixture
    def k(self, batch, seqlen, num_heads, head_dim):
        return torch.rand(batch, seqlen, num_heads, head_dim) - 0.5

    @pytest.fixture
    def v(self, batch, seqlen, num_heads, head_dim):
        return torch.rand(batch, seqlen, num_heads, head_dim) - 0.5

    @pytest.fixture
    def g(self, batch, seqlen, num_heads):
        # log-decay in (-2, 0]
        return -2 * torch.rand(batch, seqlen, num_heads)

    @pytest.fixture
    def beta(self, batch, seqlen, num_heads):
        return torch.rand(batch, seqlen, num_heads)

    @pytest.fixture
    def initial_state(self, batch, num_heads, head_dim):
        return torch.rand(batch, num_heads, head_dim, head_dim) - 0.5

    @pytest.fixture
    def gt(self, q, k, v, g, beta, initial_state, use_qk_l2norm_in_kernel):
        return naive_recurrent_gdr(q,
                                   k,
                                   v,
                                   beta,
                                   g,
                                   initial_state=initial_state.clone(),
                                   output_final_state=True,
                                   use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel)

    def test_fused_gated_delta_rule(self, q, k, v, g, beta, initial_state, use_qk_l2norm_in_kernel, gt):
        from lmdeploy.pytorch.kernels.cuda.gated_delta_rule import fused_recurrent_gated_delta_rule
        out, out_h = fused_recurrent_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state.clone(),
            output_final_state=True,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        ref_out, ref_h = gt
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-4)
        torch.testing.assert_close(out_h, ref_h, atol=1e-2, rtol=1e-3)


================================================
FILE: tests/pytorch/kernel/test_gemm_fp8.py
================================================
import pytest
import torch


def _make_quant_val(shape, out_dtype):
    x = torch.rand(shape, dtype=torch.float32, device='cuda')
    # -1 ~ 1
    x = x * 2 - 1
    # scaling abs max to fmax
    finfo = torch.finfo(out_dtype)
    fmax = finfo.max
    scaling = fmax / x.abs().amax(-1, keepdim=True)
    x *= scaling
    return x.to(out_dtype).to(torch.float32)


def fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:
    """Compute ``ceil(log2(|x|))`` for normal float32 values by inspecting their bits.

    The biased exponent field gives ``floor(log2(|x|))``. One is added whenever
    the mantissa is non-zero, i.e. when ``|x|`` is not an exact power of two.
    The sign bit is masked away. Returns an int32 tensor.
    """
    raw = x.view(torch.int32)
    biased_exp = (raw >> 23) & 0xFF
    has_fraction = (raw & 0x7FFFFF) != 0
    return (biased_exp - 127 + has_fraction.to(torch.int32)).to(torch.int32)


def fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:
    """Return ``2.0 ** x`` as float32 by writing ``x`` into the exponent field.

    ``x`` is expected to be an int32 tensor, since the result is reinterpreted
    4 bytes per element. Only exponents that land in the normal float32 range give
    meaningful results.
    """
    biased = x + 127
    return (biased << 23).view(torch.float32)


def fast_round_scale_torch(amax: torch.Tensor, fp8_max_inv: torch.Tensor) -> torch.Tensor:
    """Round ``|amax * fp8_max_inv|`` up to the next power of two (power-of-two scale)."""
    exponent = fast_log2_ceil_torch(amax * fp8_max_inv)
    return fast_pow2_torch(exponent)


def _make_quant_scale_ue8m0(shape, out_dtype):
    """Random power-of-two scales: ``|randn| / finfo(out_dtype).max`` rounded up to a power of two.

    Allocated on CUDA.
    """
    raw = torch.randn(shape, dtype=torch.float32, device='cuda')
    inv_fmax = 1 / torch.finfo(out_dtype).max
    return fast_round_scale_torch(raw, inv_fmax)


def _make_quant_scale(shape, out_dtype, scale_fmt: str = None):
    """Random quantization scales on CUDA.

    With ``scale_fmt == 'ue8m0'`` the scales are powers of two. Otherwise they are
    uniform in ``[0, 1 / finfo(out_dtype).max)``.
    """
    if scale_fmt != 'ue8m0':
        fmax = torch.finfo(out_dtype).max
        return torch.rand(shape, dtype=torch.float32, device='cuda') / fmax
    return _make_quant_scale_ue8m0(shape, out_dtype)


def _make_A(M, K, group_size, out_dtype, scale_fmt: str = None):
    """Build a per-group quantized activation for the FP8 GEMM tests.

    Returns ``(A, quant_A, scale)``:

    * ``A``: float32, shape ``(M, K)``, the dequantized values.
    * ``quant_A``: shape ``(M, K)``, dtype ``out_dtype``.
    * ``scale``: shape ``(M, K // group_size)``, stored with a transposed
      (column-major) memory layout.
    """
    num_groups = K // group_size
    grouped = _make_quant_val((M, num_groups, group_size), out_dtype)
    scale = _make_quant_scale((M, num_groups), out_dtype, scale_fmt)

    dequant = (grouped * scale.unsqueeze(-1)).reshape(M, K)
    quant = grouped.reshape(M, K).to(out_dtype)
    # same values, but contiguous along M (column-major layout)
    scale = scale.T.contiguous().T
    return dequant, quant, scale


def _aligned_size(a, b):
    return (a + b - 1) // b * b


def _make_B(K, N, group_size, out_dtype, scale_fmt: str = None):
    """Build a block-quantized ``(K, N)`` weight for the FP8 GEMM tests.

    Data is generated on a grid padded up to multiples of ``group_size`` and then
    cropped, so ``K`` / ``N`` need not be aligned. Returns ``(B, quant_B, scale)``:

    * ``B``: float32 dequantized ``(K, N)`` weight.
    * ``quant_B``: ``(K, N)`` in ``out_dtype`` with a column-major layout.
    * ``scale``: shape ``(ceil(K / group_size), ceil(N / group_size))``.
    """
    k_pad = _aligned_size(K, group_size)
    n_pad = _aligned_size(N, group_size)
    k_blocks, n_blocks = k_pad // group_size, n_pad // group_size

    # 4-D layout: (k-block, row-in-block, n-block, col-in-block)
    tiles = _make_quant_val((k_blocks, group_size, n_blocks, group_size), out_dtype)
    tile_scale = _make_quant_scale((k_blocks, 1, n_blocks, 1), out_dtype, scale_fmt)

    dequant = (tiles * tile_scale).reshape(k_pad, n_pad)[:K, :N]
    quant = tiles.reshape(k_pad, n_pad).to(out_dtype)[:K, :N]
    # same values, but contiguous along K (column-major layout)
    quant = quant.transpose(0, 1).contiguous().transpose(0, 1)
    return dequant, quant, tile_scale.reshape(k_blocks, n_blocks)


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestQuantFP8:
    """Check that ``quant_fp8`` recovers the quantized values and scales from ``_make_A``.

    The scales must match exactly. For the quantized values, fewer than 0.01% of
    elements may differ.
    """

    @pytest.fixture
    def M(self, request):
        # parametrized indirectly by test_quant_fp8
        return request.param

    @pytest.fixture
    def K(self):
        return 512

    @pytest.fixture
    def group_size(self):
        return 128

    @pytest.fixture
    def out_dtype(self):
        return torch.float8_e4m3fn

    @pytest.fixture
    def scale_fmt(self, request):
        # parametrized indirectly by test_quant_fp8
        return request.param

    @pytest.fixture
    def build_A(self, M, K, group_size, out_dtype, scale_fmt):
        return _make_A(M, K, group_size, out_dtype, scale_fmt)

    @pytest.fixture
    def A(self, build_A):
        dequant, _, _ = build_A
        return dequant

    @pytest.fixture
    def quant_A(self, build_A):
        _, quant, _ = build_A
        return quant

    @pytest.fixture
    def scale(self, build_A):
        _, _, scale = build_A
        return scale

    @pytest.fixture
    def gt(self, quant_A, scale):
        return quant_A, scale

    @pytest.mark.parametrize('scale_fmt', [None, 'ue8m0'], indirect=True)
    @pytest.mark.parametrize('M', [65536, 256], indirect=True)
    def test_quant_fp8(self, A, group_size, out_dtype, scale_fmt, gt):
        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
        ref_quant, ref_scale = gt

        quant_A, scale = quant_fp8(A, group_size=group_size, dtype=out_dtype, scale_fmt=scale_fmt)
        torch.testing.assert_close(scale, ref_scale)
        # tolerate a tiny fraction of rounding mismatches in the quantized values
        mismatch = (quant_A.to(torch.float16) - ref_quant.to(torch.float16)).abs() > 1e-5
        assert mismatch.count_nonzero() / mismatch.numel() < 1e-4


@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
class TestGemmFP8:
    """Compare ``blocked_gemm_fp8`` with an fp16 matmul of the dequantized operands.

    ``N`` is deliberately not a multiple of ``group_size``, to cover the unaligned path.
    """

    @pytest.fixture
    def M(self):
        return 256

    @pytest.fixture
    def N(self):
        # not a multiple of group_size on purpose
        return 1024 + 64

    @pytest.fixture
    def K(self):
        return 512

    @pytest.fixture
    def group_size(self):
        return 128

    @pytest.fixture
    def quant_dtype(self):
        return torch.float8_e4m3fn

    @pytest.fixture
    def out_dtype(self):
        return torch.float16

    @pytest.fixture
    def build_A(self, M, K, group_size, quant_dtype):
        return _make_A(M, K, group_size, quant_dtype)

    @pytest.fixture
    def A(self, build_A, out_dtype):
        dequant, _, _ = build_A
        return dequant.to(out_dtype)

    @pytest.fixture
    def quant_A(self, build_A):
        _, quant, _ = build_A
        return quant

    @pytest.fixture
    def scale_A(self, build_A):
        _, _, scale = build_A
        return scale

    @pytest.fixture
    def build_B(self, K, N, group_size, quant_dtype):
        return _make_B(K, N, group_size, quant_dtype)

    @pytest.fixture
    def B(self, build_B, out_dtype):
        dequant, _, _ = build_B
        return dequant.to(out_dtype)

    @pytest.fixture
    def quant_B(self, build_B):
        _, quant, _ = build_B
        return quant

    @pytest.fixture
    def scale_B(self, build_B):
        _, _, scale = build_B
        return scale

    @pytest.fixture
    def gt(self, A, B):
        return A @ B

    def test_gemm_fp8(self, quant_A, scale_A, quant_B, scale_B, out_dtype, gt):
        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8
        result = blocked_gemm_fp8(quant_A, scale_A, quant_B, scale_B, out_dtype=out_dtype)
        torch.testing.assert_close(result, gt, atol=0.5, rtol=1e-4)


================================================
FILE: tests/pytorch/kernel/test_moe_route.py
================================================
import pytest
import torch


def reference_noaux_tc_routing(
    logits: torch.Tensor,
    bias: torch.Tensor,
    num_experts: int = 256,
    n_group: int = 8,
    topk_group: int = 4,
    top_k: int = 8,
    renormalize: bool = True,
    routed_scaling_factor: float = 2.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for grouped ("noaux_tc") top-k expert routing.

    Experts are split into ``n_group`` equal groups. Each group is ranked by the sum of its two
    best biased scores, and only the best ``topk_group`` groups stay eligible. ``top_k`` experts
    are then chosen from the eligible ones by biased score. Their *unbiased* sigmoid scores,
    optionally renormalized, are scaled by ``routed_scaling_factor``.

    Returns:
        ``(weights, expert_ids)``, each of shape ``(num_tokens, top_k)``. Ids are not sorted.
    """
    num_tokens = logits.shape[0]
    experts_per_group = num_experts // n_group

    probs = torch.sigmoid(logits.float())
    biased = probs + bias[None, :]

    # Score each group by the sum of its two highest biased scores, keep the best `topk_group`.
    per_group = biased.view(num_tokens, n_group, experts_per_group)
    group_rank = per_group.topk(2, dim=-1)[0].sum(dim=-1)
    kept_groups = torch.topk(group_rank, k=topk_group, dim=-1, sorted=False)[1]
    group_keep = torch.zeros_like(group_rank).scatter_(1, kept_groups, 1)
    expert_keep = group_keep.unsqueeze(-1).expand(num_tokens, n_group, experts_per_group).reshape(num_tokens, -1)

    # NOTE: experts outside the kept groups are filled with 0.0 (not -inf), intended to mirror the
    # inference code in deepseek_v2.py. `biased` includes `bias`, so eligible candidates can be
    # negative and a masked expert could in principle outrank them.
    # TODO(review): confirm this matches the kernel.
    candidates = biased.masked_fill(~expert_keep.bool(), 0.0)
    chosen = torch.topk(candidates, k=top_k, dim=-1, sorted=False)[1]

    weights = probs.gather(1, chosen)
    if renormalize:
        denom = weights.sum(dim=-1, keepdim=True) + 1e-20
        weights = weights / denom

    return weights * routed_scaling_factor, chosen


class TestNoauxTC:
    """Compare ``fused_noaux_tc_routing`` against ``reference_noaux_tc_routing``."""

    @pytest.fixture(autouse=True)
    def auto_context(self):
        # Run every test under inference mode with float32 / CUDA as global defaults, and restore
        # the previous defaults afterwards.
        origin_dtype = torch.get_default_dtype()
        origin_device = torch.get_default_device()
        with torch.inference_mode():
            torch.set_default_dtype(torch.float32)
            torch.set_default_device('cuda')
            try:
                yield
            finally:
                torch.set_default_dtype(origin_dtype)
                torch.set_default_device(origin_device)

    @pytest.fixture
    def batch_size(self):
        yield 32

    @pytest.fixture
    def num_experts(self):
        yield 256

    @pytest.fixture
    def logits(self, batch_size, num_experts):
        yield torch.randn(batch_size, num_experts)

    @pytest.fixture
    def bias(self, num_experts):
        yield torch.randn(num_experts)

    @pytest.fixture
    def kwargs(self, num_experts):
        # Take num_experts from its fixture so the router config always matches the logits/bias width.
        yield {
            'num_experts': num_experts,
            'n_group': 8,
            'topk_group': 4,
            'top_k': 8,
            'renormalize': True,
            'routed_scaling_factor': 2.5,
        }

    @pytest.fixture
    def gt(self, logits, bias, kwargs):
        yield reference_noaux_tc_routing(logits, bias, **kwargs)

    def test_noaux_tc_router(self, logits, bias, kwargs, gt):
        from lmdeploy.pytorch.kernels.cuda.fused_noaux_tc import fused_noaux_tc_routing

        out_weights, out_ids = fused_noaux_tc_routing(logits, bias, **kwargs)
        gt_weights, gt_ids = gt

        torch.testing.assert_close(out_weights, gt_weights, rtol=1e-4, atol=1e-5)
        # topk in torch is not stable, so we won't assert ids


================================================
FILE: tests/pytorch/kernel/test_multinomial_sampling.py
================================================
import pytest
import torch

from lmdeploy.utils import is_bf16_supported


def _bf16_mark():
    """Build a ``skipif`` marker that skips a parameter when bf16 is not supported."""
    unsupported = not is_bf16_supported()
    return pytest.mark.skipif(unsupported, reason='bf16 not supported.')


class TestMultinomialSampling:
    """``multinomial_sampling`` on one-hot score rows must return ``indices[b, select_ids[b]]``."""

    @pytest.fixture
    def num_tokens(self, request):
        yield request.param

    @pytest.fixture
    def select_ids(self, request):
        yield request.param

    @pytest.fixture
    def batch_size(self, select_ids):
        yield len(select_ids)

    @pytest.fixture
    def dtype(self, request):
        yield request.param

    @pytest.fixture
    def scores(self, num_tokens, batch_size, select_ids, dtype):
        # One-hot rows: row b has a single 1 at column select_ids[b], zeros elsewhere.
        one_hot = torch.zeros(batch_size, num_tokens).cuda()
        rows = torch.arange(batch_size).cuda()
        one_hot[rows, select_ids] = 1
        yield one_hot.to(dtype)

    @pytest.fixture
    def seeds(self, batch_size):
        yield torch.randint(1000, 2000, (batch_size, )).cuda()

    @pytest.fixture
    def offsets(self, batch_size):
        yield torch.randint(1000, 2000, (batch_size, )).cuda()

    @pytest.fixture
    def indices(self, scores):
        # An independent random permutation of token ids for every row.
        vocab = scores.size(1)
        perms = [torch.randperm(vocab) for _ in range(scores.size(0))]
        yield torch.stack(perms, 0).cuda()

    @pytest.fixture
    def gt(self, batch_size, select_ids, indices):
        rows = torch.arange(batch_size).cuda()
        yield indices[rows, select_ids]

    @pytest.mark.parametrize('dtype', [torch.float32, torch.half, pytest.param(torch.bfloat16, marks=_bf16_mark())])
    @pytest.mark.parametrize(['num_tokens', 'select_ids'], [
        (8, (4, 2) * 30),
        (2000, (500, 1500)),
    ], indirect=True)
    def test_multinomial_sampling(self, scores, seeds, offsets, indices, gt):
        from lmdeploy.pytorch.kernels.cuda import multinomial_sampling
        sampled = multinomial_sampling(scores, seeds, offsets, indices)
        torch.testing.assert_close(sampled, gt)


================================================
FILE: tests/pytorch/kernel/test_paged_attention.py
================================================
import math

import pytest
import torch


def _conti_input(data, seq_lens):
    data = [x[:l] for x, l in zip(data, seq_lens)]
    data = torch.cat(data, dim=0)
    return data


def _make_bias(q_seqlens, history_lens, neg_val):
    batch_size = q_seqlens.shape[0]
    full_seq_lens = q_seqlens + history_lens
    max_seq_len = q_seqlens.max().item()
    max_kv_len = full_seq_lens.max().item()
    seq_ranges = torch.arange(max_seq_len).cuda()
    seq_ranges = seq_ranges.repeat(batch_size, 1)
    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

    kv_ranges = torch.arange(max_kv_len).cuda()
    kv_ranges = kv_ranges.repeat(batch_size, 1)
    mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]
    return mask.float() * neg_val


def _make_alibi_bias(q_seqlens, history_lens, neg_val, alibi_slopes):
    """Causal mask plus per-head ALiBi distance penalty.

    A query at row ``i`` of sequence ``b`` sits at absolute position ``history_lens[b] + i``.
    The ALiBi term for key ``j`` is ``-|position - j| * alibi_slopes[h]``. It is added on top of
    the causal mask from ``_make_bias``.

    Returns:
        Float tensor of shape ``(batch, num_heads, max_q_len, max_kv_len)`` on the inputs' device.
    """
    # Follow the inputs' device instead of hard-coding CUDA; CUDA inputs behave exactly as before.
    device = q_seqlens.device
    batch_size = q_seqlens.shape[0]
    kv_seqlens = q_seqlens + history_lens
    max_seq_len = q_seqlens.max().item()
    max_kv_len = kv_seqlens.max().item()

    seq_ranges = torch.arange(max_seq_len, device=device)
    seq_ranges = seq_ranges.repeat(batch_size, 1) + history_lens[:, None]

    kv_ranges = torch.arange(max_kv_len, device=device)
    kv_ranges = kv_ranges.repeat(batch_size, 1)

    diff = (seq_ranges[:, :, None] - kv_ranges[:, None, :]).abs()
    slope_diff = -diff[:, None] * alibi_slopes[None, :, None, None]

    # add bias
    bias = _make_bias(q_seqlens, history_lens, neg_val)
    bias = bias[:, None] + slope_diff
    return bias


def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,
                            block_sparse_size: int):
    """Make block sparse bias."""
    batch_size = q_seqlens.shape[0]
    kv_seqlens = q_seqlens + history_lens
    max_seq_len = q_seqlens.max().item()
    max_kv_len = kv_seqlens.max().item()

    seq_ranges = torch.arange(max_seq_len).cuda()
    seq_ranges = seq_ranges // block_sparse_size * block_sparse_size
    seq_ranges = seq_ranges.repeat(batch_size, 1)
    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

    kv_ranges = torch.arange(max_kv_len).cuda()
    kv_ranges = kv_ranges // block_sparse_size * block_sparse_size
    kv_ranges = kv_ranges.repeat(batch_size, 1)

    mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
    return mask.float() * neg_val


def _make_blocked_cache(batched_k,
                        batched_v,
                        seq_lens,
                        history_lens,
                        block_offsets,
                        block_size,
                        num_heads_k,
                        feat_dim,
                        feat_dim_v,
                        layout: str = 'bshd'):
    max_blocks_nums = block_offsets.max() + 1
    full_seq_lens = seq_lens + history_lens
    blocked_k = batched_k.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim)
    blocked_v = batched_v.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim_v)

    for batch_id, offset in enumerate(block_offsets):
        ori_k = batched_k[batch_id]
        ori_v = batched_v[batch_id]
        seq_len = full_seq_lens[batch_id]
        for block_id, block_start in enumerate(range(0, seq_len, block_size)):
            block_off = offset[block_id]
            tmp_k = ori_k[block_start:block_start + block_size]
            tmp_v = ori_v[block_start:block_start + block_size]
            size = tmp_k.size(0)
            blocked_k[block_off, :size] = tmp_k
            blocked_v[block_off, :size] = tmp_v

    if layout == 'bhsd':
        blocked_k = blocked_k.transpose(1, 2).contiguous()
        blocked_v = blocked_v.transpose(1, 2).contiguous()

    return blocked_k, blocked_v


def _naive_attention(batched_q, batched_kv, bias, sinks=None):
    batched_k, batched_v = batched_kv

    num_heads_q = batched_q.shape[2]
    num_heads_k = batched_k.shape[2]
    head_dim = batched_q.shape[-1]
    group = num_heads_q // num_heads_k

    q = batched_q.transpose(1, 2)
    k = batched_k.permute(0, 2, 3, 1)
    v = batched_v.transpose(1, 2)

    # expand group
    k = k.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)
    v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)

    qk = torch.matmul(q, k) / math.sqrt(head_dim)
    if bias.dim() == 3:
        bias = bias[:, None]
    attn_weight = qk + bias
    if sinks is None:
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    else:
        sinks = sinks[None, :, None, None].to(torch.float32)
        sinks = sinks.expand(attn_weight.shape[0], -1, attn_weight.shape[2], -1)
        attn_weight = attn_weight.to(torch.float32)
        combined_logits = torch.cat([attn_weight, sinks], dim=-1)
        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
        attn_weight = torch.softmax(combined_logits, dim=-1, dtype=torch.float32)
        attn_weight = attn_weight[..., :-1]
    attn_weight = attn_weight.to(q.dtype)
    attn_output = torch.matmul(attn_weight, v)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output


def _naive_window_attention(q, k, v, seqlens_q, seqlens_k, window_size):
    """Causal sliding-window attention reference backed by ``flash_attn_varlen_func``.

    Tries lmdeploy's vendored ``flash_attn_interface`` first, then the ``flash_attn`` package.
    If neither import succeeds, the calling test is skipped.

    Args:
        q, k, v: packed (concatenated) sequences, forwarded unchanged to the flash-attn kernel.
        seqlens_q, seqlens_k: per-sequence query / key lengths (1-D tensors).
        window_size: forwarded unchanged as ``window_size``.
    """
    try:
        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
    except Exception:
        try:
            from flash_attn import flash_attn_varlen_func
        except Exception:
            pytest.skip('Skip window attention test since flash attention is not available.')

    def _offsets(lens):
        # cumulative offsets with a leading zero: [0, l0, l0 + l1, ...]
        running = lens.cumsum(0)
        return torch.cat([running.new_zeros(1), running])

    longest_q = seqlens_q.max().item()
    longest_k = seqlens_k.max().item()
    return flash_attn_varlen_func(q,
                                  k,
                                  v,
                                  _offsets(seqlens_q).int(),
                                  _offsets(seqlens_k).int(),
                                  max_seqlen_q=longest_q,
                                  max_seqlen_k=longest_k,
                                  causal=True,
                                  window_size=window_size)


class TestPagedAttentionBase:
    """Shared fixtures for paged-attention kernel tests.

    Builds random padded Q/K/V, scatters K/V into a paged cache and computes a reference output
    with ``_naive_attention``. Subclasses parametrize shape fixtures indirectly and may override
    ``seq_len`` (query tokens per sequence) or ``mask``. The KV cache, packed KV and default mask
    all derive from ``seq_lens`` / ``kv_seqlens``, so they stay correct when ``seq_len > 1``.
    """

    @pytest.fixture
    def dtype(self):
        yield torch.float16

    @pytest.fixture
    def feat_dim(self, request):
        yield request.param

    @pytest.fixture
    def feat_dim_v(self, request):
        yield request.param

    @pytest.fixture
    def num_heads_q(self, request):
        yield request.param

    @pytest.fixture
    def num_heads_k(self, request):
        yield request.param

    @pytest.fixture
    def block_size(self, request):
        yield request.param

    @pytest.fixture
    def layout(self, request):
        yield request.param

    @pytest.fixture
    def history_lens(self, request):
        yield torch.tensor(request.param, device='cuda')

    @pytest.fixture
    def seq_len(self):
        # number of new query tokens per sequence
        yield 1

    @pytest.fixture
    def seq_lens(self, seq_len, history_lens):
        yield torch.ones_like(history_lens) * seq_len

    @pytest.fixture
    def kv_seqlens(self, seq_lens, history_lens):
        yield seq_lens + history_lens

    @pytest.fixture
    def batched_q(self, seq_len, kv_seqlens, num_heads_q, feat_dim, dtype):
        torch.manual_seed(123)
        batch_size = len(kv_seqlens)
        inputs = torch.rand(batch_size, seq_len, num_heads_q, feat_dim, dtype=dtype, device='cuda')
        yield inputs

    @pytest.fixture
    def batched_kv(self, kv_seqlens, num_heads_k, feat_dim, feat_dim_v, dtype):
        torch.manual_seed(123)
        batch_size = len(kv_seqlens)
        max_seq_len = kv_seqlens.max().item()
        k = torch.rand(batch_size, max_seq_len, num_heads_k, feat_dim, dtype=dtype, device='cuda')
        v = torch.rand(batch_size, max_seq_len, num_heads_k, feat_dim_v, dtype=dtype, device='cuda')
        yield k, v

    @pytest.fixture
    def conti_q(self, seq_lens, batched_q):
        yield _conti_input(batched_q, seq_lens)

    @pytest.fixture
    def block_offsets(self, kv_seqlens, block_size):
        # Interleave physical blocks: sequence `idx` owns blocks idx, idx + batch, idx + 2 * batch, ...
        batch_size = kv_seqlens.size(0)
        num_blocks = (kv_seqlens + block_size - 1) // block_size

        offset = [torch.arange(size) * batch_size + idx for idx, size in enumerate(num_blocks)]
        max_len = max(len(o) for o in offset)
        new_offset = offset[0].new_zeros(batch_size, max_len)
        for o, no in zip(offset, new_offset):
            len_o = o.size(0)
            no[:len_o] = o

        yield new_offset.cuda()

    @pytest.fixture
    def conti_kv(self, batched_kv, kv_seqlens):
        # Use the real per-sequence KV length instead of assuming a single query token.
        conti_k = _conti_input(batched_kv[0], kv_seqlens)
        conti_v = _conti_input(batched_kv[1], kv_seqlens)
        yield (conti_k, conti_v)

    @pytest.fixture
    def blocked_kv(self, batched_kv, seq_lens, history_lens, block_offsets, block_size, num_heads_k, feat_dim,
                   feat_dim_v, layout):
        batched_k, batched_v = batched_kv
        yield _make_blocked_cache(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, num_heads_k,
                                  feat_dim, feat_dim_v, layout)

    @pytest.fixture
    def mask(self, seq_lens, history_lens):
        neg_val = -1e30
        yield _make_bias(seq_lens, history_lens, neg_val)

    @pytest.fixture
    def gt(self, batched_q, batched_kv, mask):
        yield _naive_attention(batched_q, batched_kv, mask)

    @pytest.fixture
    def conti_gt(self, gt, seq_lens):
        yield _conti_input(gt, seq_lens)


class TestPagedAttention(TestPagedAttentionBase):
    """Single-token paged decoding, with full and sliding-window attention."""

    # feat_dim was parametrized as [32, 32], which only ran identical cases twice.
    @pytest.mark.parametrize('feat_dim', [32], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    @pytest.mark.parametrize('layout', ['bshd', 'bhsd'], indirect=True)
    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, conti_gt):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        blocked_k, blocked_v = blocked_kv
        out = flash_attn_with_kvcache(conti_q,
                                      blocked_k,
                                      blocked_v,
                                      page_table=block_offsets,
                                      cache_seqlens=kv_seqlens,
                                      kv_layout=layout)
        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)

    @pytest.fixture
    def win_size(self, request):
        yield request.param

    @pytest.fixture
    def window_gt(self, conti_q, conti_kv, seq_lens, kv_seqlens, win_size):
        # Reference for symmetric (win_size, win_size) sliding-window causal attention.
        yield _naive_window_attention(conti_q,
                                      conti_kv[0],
                                      conti_kv[1],
                                      seq_lens,
                                      kv_seqlens,
                                      window_size=(win_size, win_size))

    @pytest.mark.parametrize('feat_dim', [16], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [16], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [
        (50, 40, 30, 20),
    ], indirect=True)
    @pytest.mark.parametrize('win_size', (32, ), indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)
    def test_window_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, win_size, layout, window_gt):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        blocked_k, blocked_v = blocked_kv
        out = flash_attn_with_kvcache(conti_q,
                                      blocked_k,
                                      blocked_v,
                                      page_table=block_offsets,
                                      cache_seqlens=kv_seqlens,
                                      window_size=win_size,
                                      kv_layout=layout)
        torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)


class TestPagedAttentionSink(TestPagedAttentionBase):
    """Paged decoding attention with one attention-sink logit per query head."""

    @pytest.fixture
    def sinks(self, num_heads_q, dtype):
        yield torch.rand(num_heads_q, dtype=dtype, device='cuda')

    @pytest.fixture
    def sink_gt(self, batched_q, batched_kv, mask, sinks):
        yield _naive_attention(batched_q, batched_kv, mask, sinks=sinks)

    @pytest.fixture
    def conti_sink_gt(self, sink_gt, seq_lens):
        yield _conti_input(sink_gt, seq_lens)

    @pytest.mark.parametrize('feat_dim', [32], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)
    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, sinks, conti_sink_gt):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        k_cache, v_cache = blocked_kv
        kernel_kwargs = dict(page_table=block_offsets, cache_seqlens=kv_seqlens, sinks=sinks, kv_layout=layout)
        result = flash_attn_with_kvcache(conti_q, k_cache, v_cache, **kernel_kwargs)
        torch.testing.assert_close(result, conti_sink_gt, atol=1e-3, rtol=1e-5)


def quant(kv: torch.Tensor, nbits: int = 8):
    """Quant kv on the head_dim.

    Asymmetric min/max quantisation along the last dimension:
    ``q = trunc(kv / scale + zero + 0.5)``, where ``scale = (max - min) / (2**nbits - 1)`` and
    ``zero = -min / scale``, so ``(q - zero) * scale`` approximates ``kv``. For ``nbits == 4``, the
    first and second halves of the last dim are packed into one byte as ``low + high * 16``.

    A constant row (``max == min``) would give ``scale == 0`` and NaN/inf outputs. Such rows use
    ``scale = 1`` instead, so they quantise to 0 and dequantise exactly to ``min``.

    Returns:
        ``(q_kv, scales_zeros)``. ``q_kv`` is uint8. ``scales_zeros`` stacks ``[scale, zero]`` on
        the last dim.
    """
    amax = kv.amax(dim=-1, keepdim=True)
    amin = kv.amin(dim=-1, keepdim=True)
    scales = (amax - amin) / (2**nbits - 1)
    scales = scales.masked_fill(scales == 0, 1)
    zeros = -amin / scales
    q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)
    if nbits == 4:
        q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)
        q_kv = q_kv1 + q_kv2 * 16
    return q_kv, torch.cat([scales, zeros], dim=-1)


def _make_blocked_cache_quant(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, num_heads_k,
                              feat_dim, feat_dim_v, nbits):
    """Quantise K/V with ``quant`` and scatter them into a paged cache.

    Same block scattering as the unquantised cache builder. In addition, the per-token
    ``[scale, zero]`` pairs are scattered into parallel caches with a trailing dim of 2. With
    ``nbits == 4``, two values share one byte, so the stored feature widths are halved.

    Returns:
        ``(cache_k, cache_v, cache_k_scales_zeros, cache_v_scales_zeros)``.
    """
    q_k, k_sz = quant(batched_k, nbits)
    q_v, v_sz = quant(batched_v, nbits)
    num_blocks = block_offsets.max() + 1
    total_lens = seq_lens + history_lens
    if nbits == 4:
        feat_dim //= 2
        feat_dim_v //= 2

    cache_k = q_k.new_zeros(num_blocks, block_size, num_heads_k, feat_dim)
    cache_v = q_v.new_zeros(num_blocks, block_size, num_heads_k, feat_dim_v)
    cache_ksz = k_sz.new_zeros(num_blocks, block_size, num_heads_k, 2)
    cache_vsz = v_sz.new_zeros(num_blocks, block_size, num_heads_k, 2)

    for seq_idx, table in enumerate(block_offsets):
        seq_k, seq_v = q_k[seq_idx], q_v[seq_idx]
        seq_ksz, seq_vsz = k_sz[seq_idx], v_sz[seq_idx]
        for logical, start in enumerate(range(0, total_lens[seq_idx], block_size)):
            physical = table[logical]
            rows = slice(start, start + block_size)
            chunk_k = seq_k[rows]
            filled = chunk_k.size(0)
            cache_k[physical, :filled] = chunk_k
            cache_v[physical, :filled] = seq_v[rows]
            cache_ksz[physical, :filled] = seq_ksz[rows]
            cache_vsz[physical, :filled] = seq_vsz[rows]

    return cache_k, cache_v, cache_ksz, cache_vsz


class TestPagedAttentionInt8(TestPagedAttention):
    """Paged attention over a KV cache quantised to ``nbits`` with per-token scale/zero."""

    @pytest.fixture
    def nbits(self):
        yield 8

    @pytest.fixture
    def blocked_kv(self, batched_kv, seq_lens, history_lens, block_offsets, block_size, num_heads_k, feat_dim,
                   feat_dim_v, nbits):
        k_full, v_full = batched_kv
        yield _make_blocked_cache_quant(k_full, v_full, seq_lens, history_lens, block_offsets, block_size,
                                        num_heads_k, feat_dim, feat_dim_v, nbits)

    @staticmethod
    def _assert_quant_close(out, ref, nbits):
        # 4-bit caches are checked with looser tolerances than 8-bit ones.
        if nbits == 4:
            atol, rtol = 0.05, 0.01
        else:
            atol, rtol = 1e-3, 1e-5
        torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)

    @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, conti_gt, nbits):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        k_cache, v_cache, k_sz, v_sz = blocked_kv
        result = flash_attn_with_kvcache(conti_q,
                                         k_cache,
                                         v_cache,
                                         k_scales_zeros=k_sz,
                                         v_scales_zeros=v_sz,
                                         quant_policy=nbits,
                                         page_table=block_offsets,
                                         cache_seqlens=kv_seqlens)
        self._assert_quant_close(result, conti_gt, nbits)

    @pytest.mark.parametrize('feat_dim', [16], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [16], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [
        (50, 40, 30, 20),
    ], indirect=True)
    @pytest.mark.parametrize('win_size', (32, ), indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    def test_window_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, win_size, window_gt, nbits):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        k_cache, v_cache, k_sz, v_sz = blocked_kv
        result = flash_attn_with_kvcache(conti_q,
                                         k_cache,
                                         v_cache,
                                         k_scales_zeros=k_sz,
                                         v_scales_zeros=v_sz,
                                         quant_policy=nbits,
                                         page_table=block_offsets,
                                         cache_seqlens=kv_seqlens,
                                         window_size=win_size)
        self._assert_quant_close(result, window_gt, nbits)


class TestPagedAttentionInt4(TestPagedAttentionInt8):
    """Re-run the quantised paged-attention tests with a 4-bit KV cache."""

    @pytest.fixture
    def nbits(self):
        bits = 4
        yield bits


class TestPagedAttentionBlockDecoding(TestPagedAttentionBase):
    """Decoding several query tokens per step under a block-sparse mask.

    The mask granularity equals ``seq_len``. ``gt`` / ``conti_gt`` come from the base class, whose
    definitions were duplicated here verbatim and have been removed.
    """

    @pytest.fixture
    def seq_len(self):
        yield 4

    @pytest.fixture
    def mask(self, seq_lens, history_lens, seq_len):
        neg_val = -1e30
        yield _make_block_sparse_bias(seq_lens, history_lens, neg_val, seq_len)

    @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)
    @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)
    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, conti_gt):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        blocked_k, blocked_v = blocked_kv

        out = flash_attn_with_kvcache(conti_q,
                                      blocked_k,
                                      blocked_v,
                                      page_table=block_offsets,
                                      cache_seqlens=kv_seqlens,
                                      kv_layout=layout)
        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)


class TestPagedAttentionAlibi(TestPagedAttentionBase):
    """Paged decoding attention with per-head ALiBi slopes."""

    @pytest.fixture
    def alibi_slopes(self, num_heads_q):
        yield torch.rand(num_heads_q, dtype=torch.float32, device='cuda')

    @pytest.fixture
    def mask(self, seq_lens, history_lens, alibi_slopes):
        yield _make_alibi_bias(seq_lens, history_lens, -1e30, alibi_slopes)

    @pytest.mark.parametrize('feat_dim', [128], indirect=True)
    @pytest.mark.parametrize('feat_dim_v', [128], indirect=True)
    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(40, 8)], indirect=True)
    @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True)
    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)
    @pytest.mark.parametrize('block_size', [16], indirect=True)
    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, alibi_slopes, conti_gt):
        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache

        k_cache, v_cache = blocked_kv
        kernel_kwargs = dict(page_table=block_offsets,
                             cache_seqlens=kv_seqlens,
                             alibi_slopes=alibi_slopes,
                             kv_layout=layout)
        result = flash_attn_with_kvcache(conti_q, k_cache, v_cache, **kernel_kwargs)
        torch.testing.assert_close(result, conti_gt, atol=1e-3, rtol=1e-5)


================================================
FILE: tests/pytorch/kernel/test_rms_norm.py
================================================
import pytest
import torch

from lmdeploy.utils import is_bf16_supported


def _bf16_mark():
    """Build a ``skipif`` marker that skips a parameter when bf16 is not supported."""
    unsupported = not is_bf16_supported()
    return pytest.mark.skipif(unsupported, reason='bf16 not supported.')


class TestRMSNorm:
    """Compare the RMSNorm kernel, with and without a fused residual add, against a float32 reference."""

    @staticmethod
    def _rms_norm_ref(x, weight, eps):
        # Normalise in float32, cast back to the input dtype, then apply the weight.
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        variance = (x32 * x32).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(variance + eps)
        return weight * normed.to(orig_dtype)

    @pytest.fixture(autouse=True, scope='class')
    def initialize(self):
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        yield

    @pytest.fixture(scope='class')
    def dtype(self, request):
        yield request.param

    @pytest.fixture(scope='class')
    def input_shape(self, request):
        yield request.param

    @pytest.fixture(scope='class')
    def hidden_size(self, input_shape):
        yield input_shape[-1]

    @pytest.fixture(scope='class')
    def input(self, dtype, input_shape):
        yield torch.randn(input_shape, dtype=dtype, device='cuda')

    @pytest.fixture(scope='class')
    def weight(self, dtype, hidden_size):
        yield torch.randn(hidden_size, dtype=dtype, device='cuda')

    @pytest.fixture(scope='class')
    def eps(self):
        yield 1e-6

    @pytest.fixture(scope='class')
    def gt(self, input, weight, eps):
        return self._rms_norm_ref(input, weight, eps)

    @pytest.mark.parametrize('input_shape', [(2, 4, 4096), (4, 4096), (4096, )], indirect=True)
    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16], indirect=True)
    def test_rms_norm(self, input, weight, eps, gt):
        from lmdeploy.pytorch.kernels.cuda import rms_norm

        torch.testing.assert_close(rms_norm(input, weight, eps), gt)

    @pytest.fixture(scope='class')
    def residual(self, dtype, input_shape):
        yield torch.randn(input_shape, dtype=dtype, device='cuda')

    @pytest.fixture(scope='class')
    def gt_residual(self, input, residual, weight, eps):
        # The updated residual is the plain sum; the normalised output is computed from that sum.
        summed = input + residual
        return self._rms_norm_ref(summed, weight, eps), summed

    @pytest.mark.parametrize('input_shape', [(2, 4, 4096), (4, 4096), (4096, )], indirect=True)
    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16], indirect=True)
    def test_rms_norm_residual(self, input, residual, weight, eps, gt_residual):
        from lmdeploy.pytorch.kernels.cuda import rms_norm

        got, got_res = rms_norm(input, weight, eps, residual=residual)
        want, want_res = gt_residual
        torch.testing.assert_close(got, want)
        torch.testing.assert_close(got_res, want_res)


================================================
FILE: tests/pytorch/nn/test_embedding.py
================================================
import os
import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn

from lmdeploy.pytorch.distributed import DefaultContext
from lmdeploy.pytorch.nn import ParallelEmbedding


def parallel_emb(rank: int, world_size: int, vocab_size: int, feat_size: int, padding_idx: int, dtype: torch.dtype,
                 x: torch.Tensor, weight: torch.Tensor, result_queue: mp.Queue):
    """Worker for one tensor-parallel rank of the ``ParallelEmbedding`` test.

    Builds a TP-sharded ``ParallelEmbedding`` on ``cuda:<rank>``, loads the full
    ``weight`` through the module's ``weight_loader`` and embeds ``x``. Rank 0 sends
    its output to the parent through ``result_queue`` as the ``(rebuild_fn, args)``
    pair produced by ``mp.reductions.reduce_tensor``.

    Args:
        rank: Rank of this process in the TP group; also used as the CUDA device index.
        world_size: Number of TP ranks.
        vocab_size: Vocabulary size of the embedding.
        feat_size: Hidden size of the embedding.
        padding_idx: Padding index forwarded to ``ParallelEmbedding``.
        dtype: Parameter dtype of the embedding.
        x: Token ids to embed; moved to ``cuda:<rank>``.
        weight: Full (unsharded) embedding table.
        result_queue: Queue that receives rank 0's output.
    """
    device = torch.device(type='cuda', index=rank)
    # Bind this process to its own GPU before NCCL initialisation; otherwise NCCL
    # may use the default device (cuda:0) for communicators on every rank.
    torch.cuda.set_device(device)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    try:
        gpu_group = dist.new_group(ranks=list(range(world_size)), backend='nccl')

        DefaultContext.attn_tp_group.rank = rank
        DefaultContext.dist_config.attn_tp = world_size
        DefaultContext.attn_tp_group.gpu_group = gpu_group

        model = ParallelEmbedding(vocab_size=vocab_size,
                                  hidden_size=feat_size,
                                  padding_idx=padding_idx,
                                  dtype=dtype,
                                  is_tp=True,
                                  device=device)

        model.weight_loader(model.weight, weight.to(device))

        with torch.inference_mode():
            out = model(x.to(device))

        if rank == 0:
            result_queue.put(mp.reductions.reduce_tensor(out))
    finally:
        # Tear down NCCL even if building or running the model raised.
        if dist.is_initialized():
            dist.destroy_process_group()


class TestEmbedding:
    """A tensor-parallel ``ParallelEmbedding`` must match a plain ``nn.Embedding``."""

    @pytest.fixture
    def vocab_size(self, request):
        yield request.param

    @pytest.fixture
    def feat_size(self, request):
        yield request.param

    @pytest.fixture
    def padding_idx(self, request):
        yield request.param

    @pytest.fixture
    def dtype(self, request):
        yield request.param

    @pytest.fixture
    def tp(self, request):
        yield request.param

    @pytest.fixture
    def seqlen(self, request):
        yield request.param

    @pytest.fixture
    def weight(self, vocab_size, feat_size, dtype):
        yield torch.rand(vocab_size, feat_size, dtype=dtype)

    @pytest.fixture
    def x(self, seqlen, vocab_size):
        yield torch.randint(low=0, high=vocab_size, size=(seqlen, ), dtype=torch.int32)

    @pytest.fixture
    def gt(self, x, vocab_size, feat_size, padding_idx, dtype, weight):
        """Reference output of ``nn.Embedding`` with the same weights, on ``cuda:0``."""
        token_emb = nn.Embedding(vocab_size,
                                 feat_size,
                                 padding_idx=padding_idx,
                                 dtype=dtype,
                                 device=torch.device(type='cuda', index=0))
        token_emb.weight.data.copy_(weight)
        token_emb._fill_padding_idx_with_zero()
        input = x.to(torch.device(type='cuda', index=0))
        yield token_emb(input)

    @pytest.mark.parametrize('vocab_size', [65576, 65533, 3333], indirect=True)
    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)
    @pytest.mark.parametrize('padding_idx', [None], indirect=True)
    @pytest.mark.parametrize('seqlen', [1024, 1011, 128], indirect=True)
    @pytest.mark.parametrize('tp', [2], indirect=True)
    @pytest.mark.parametrize('dtype', [torch.bfloat16], indirect=True)
    def test_embedding(self, vocab_size, feat_size, padding_idx, seqlen, tp, dtype, x, weight, gt):
        """Spawn one process per TP rank and compare rank 0's output with ``gt``."""
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['NCCL_SOCKET_IFNAME'] = 'lo'

        world_size = tp
        # CUDA requires 'spawn' for child processes. Use a local context instead of
        # changing the global start method for the rest of the test session.
        ctx = mp.get_context('spawn')
        result_queue = ctx.Queue()
        processes = []

        try:
            for rank in range(world_size):
                p = ctx.Process(target=parallel_emb,
                                args=(rank, world_size, vocab_size, feat_size, padding_idx, dtype, x, weight,
                                      result_queue))
                processes.append(p)
                p.start()
                time.sleep(0.5)

            # Bounded wait: fail instead of hanging forever if a worker dies early.
            func, args = result_queue.get(timeout=300)
            # The rebuilt tensor aliases rank 0's CUDA memory through IPC; copy it now
            # so the comparison below does not depend on rank 0 still being alive.
            out = func(*args).clone()
        finally:
            for p in processes:
                p.join(timeout=10)
                if p.is_alive():
                    p.terminate()
                    p.join(timeout=5)
                    if p.is_alive():
                        p.kill()

        torch.testing.assert_close(out, gt)


================================================
FILE: tests/pytorch/paging/test_block_manager.py
================================================
# yapf: disable
import pytest
import torch

from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
from lmdeploy.pytorch.messages import SequenceMeta
from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator
from lmdeploy.pytorch.paging.scheduler import Scheduler

# yapf: enable


class TestAllocator:
    """Unit tests for ``LogicalAllocator`` backed by a GPU pool and a CPU pool."""

    @pytest.fixture
    def num_gpu_blocks(self):
        yield 16

    @pytest.fixture
    def num_cpu_blocks(self):
        yield 4

    @pytest.fixture
    def allocator(self, num_cpu_blocks, num_gpu_blocks):
        yield LogicalAllocator(num_cpu_blocks, num_gpu_blocks)

    def test_alloc(self, allocator, num_cpu_blocks, num_gpu_blocks):
        """Allocate GPU blocks, then free them, honouring an extra reference."""
        total = num_cpu_blocks + num_gpu_blocks
        gpu_pool = allocator.get_phy_allocator('gpu')
        cpu_pool = allocator.get_phy_allocator('cpu')

        # every block starts out free
        assert allocator.get_num_free_blocks() == total
        assert cpu_pool.get_num_free_blocks() == num_cpu_blocks
        assert gpu_pool.get_num_free_blocks() == num_gpu_blocks

        # allocating n GPU blocks reduces both the total and the GPU free counts by n
        n_alloc = 4
        blocks = allocator.allocate(n_alloc, 'gpu')
        assert len(blocks) == n_alloc
        assert allocator.get_num_free_blocks() == total - n_alloc
        assert gpu_pool.get_num_free_blocks() == num_gpu_blocks - n_alloc

        # with one extra reference, the first free() leaves the blocks allocated
        # and the second free() returns them to their pool
        allocator.add_ref_count(blocks, 1)
        allocator.free(blocks)
        assert allocator.get_num_free_blocks() == total - n_alloc
        allocator.free(blocks)
        assert allocator.get_num_free_blocks() == total
        assert gpu_pool.get_num_free_blocks() == num_gpu_blocks
        assert cpu_pool.get_num_free_blocks() == num_cpu_blocks

    def test_full(self, allocator, num_cpu_blocks, num_gpu_blocks):
        """Draining both pools makes further GPU allocations raise ``MemoryError``."""
        total = num_cpu_blocks + num_gpu_blocks
        gpu_pool = allocator.get_phy_allocator('gpu')
        cpu_pool = allocator.get_phy_allocator('cpu')

        # take every block from both pools
        gpu_blocks = allocator.allocate(num_gpu_blocks, 'gpu')
        cpu_blocks = allocator.allocate(num_cpu_blocks, 'cpu')
        assert cpu_pool.get_num_free_blocks() == 0
        assert gpu_pool.get_num_free_blocks() == 0

        with pytest.raises(MemoryError):
            allocator.allocate(1, 'gpu')

        # releasing everything restores the initial free counts
        for held in (gpu_blocks, cpu_blocks):
            allocator.free(held)
        assert allocator.get_num_free_blocks() == total
        assert gpu_pool.get_num_free_blocks() == num_gpu_blocks
        assert cpu_pool.get_num_free_blocks() == num_cpu_blocks


class TestDefaultBlockManager:
    """Tests for the block manager of a ``Scheduler`` built without ``window_size``
    or ``enable_prefix_caching`` in its ``CacheConfig``."""

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def num_cpu_blocks(self):
        yield 4

    @pytest.fixture
    def num_gpu_blocks(self):
        yield 4

    @pytest.fixture
    def max_batch_size(self):
        yield 4

    @pytest.fixture
    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):
        yield CacheConfig(max_batches=max_batch_size,
                          block_size=block_size,
                          num_cpu_blocks=num_cpu_blocks,
                          num_gpu_blocks=num_gpu_blocks)

    @pytest.fixture
    def scheduler_config(self, max_batch_size):
        yield SchedulerConfig(max_batches=max_batch_size,
                              max_session_len=128,
                              max_request_output_len=64,
                              eviction_type='recompute')

    @pytest.fixture
    def seq_meta(self, block_size):
        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
        strategy = ARSequenceStrategy()
        yield SequenceMeta(block_size, strategy=strategy)

    @pytest.fixture
    def scheduler(self, cache_config, scheduler_config, seq_meta):
        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)

    @pytest.fixture
    def block_mgr(self, scheduler):
        yield scheduler.block_manager

    def test_alloc(self, scheduler, block_mgr, num_gpu_blocks):
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        # test alloc
        token_ids = torch.tensor([1])
        msg = sess.add_sequence(token_ids)
        assert block_mgr.can_allocate(msg)
        block_mgr.allocate(msg)
        block_table = block_mgr.get_block_table(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1
        assert block_table is not None
        assert len(block_table) == 1

        # test free
        block_mgr.free(msg)
        block_table = block_mgr.get_block_table(msg)
        assert block_table is None or len(block_table) == 0
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks

        # alloc over limit: one token more than the whole GPU pool can hold
        token_ids = torch.zeros((num_gpu_blocks * block_size + 1, ), dtype=torch.int64)
        msg = sess.add_sequence(token_ids)
        assert not block_mgr.can_allocate(msg)

    def test_num_required_blocks(self, scheduler, block_mgr):
        from lmdeploy.pytorch.messages import InputEmbeddings
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        token_ids = torch.tensor([1])
        msg = sess.add_sequence(token_ids)
        num_required = block_mgr.num_required_blocks(msg)
        assert num_required == 1

        embedding = InputEmbeddings(None, 0, block_size * 2)
        msg = sess.add_sequence(token_ids, input_embeddings=[embedding])
        num_required = block_mgr.num_required_blocks(msg)
        assert num_required == 1

        token_ids = torch.tensor([1] * block_size * 3)
        embedding = InputEmbeddings(None, 0, block_size * 2)
        msg = sess.add_sequence(token_ids, input_embeddings=[embedding])
        num_required = block_mgr.num_required_blocks(msg)
        assert num_required == 3

    def test_append_slot(self, scheduler, block_mgr, num_gpu_blocks):
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        # test append
        token_ids = torch.tensor([1])
        msg = sess.add_sequence(token_ids)
        block_mgr.allocate(msg)
        block_table = block_mgr.get_block_table(msg)
        assert len(block_table) == 1

        # no new logical block: 1 + (block_size - 1) tokens still fit in one block
        msg.update_token_ids(torch.tensor([1] * (block_size - 1)))
        assert block_mgr.can_allocate(msg)
        block_mgr.allocate(msg)
        block_table = block_mgr.get_block_table(msg)
        assert len(block_table) == 1
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1

        # with new logical block
        msg.update_token_ids(torch.tensor([1]))
        block_mgr.allocate(msg)
        block_table = block_mgr.get_block_table(msg)
        assert len(block_table) == 2
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2

    def test_swap(self, scheduler, block_mgr, num_gpu_blocks, num_cpu_blocks):
        """Swap a 2-block sequence GPU -> CPU -> GPU and check pool counts and maps."""
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        # block_size + 1 tokens occupy 2 blocks
        token_ids = torch.tensor([1] * (block_size + 1))
        msg = sess.add_sequence(token_ids)
        block_mgr.allocate(msg)

        # swap out: GPU pool fully free again, CPU pool loses 2 blocks
        old_phy_blocks = block_mgr.get_block_table(msg)
        success, swap_map = block_mgr.try_swap_out(msg)
        new_phy_blocks = block_mgr.get_block_table(msg)
        assert success
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks
        # compare against the CPU pool size (previously compared to num_gpu_blocks,
        # which only passed because both fixtures happen to be 4)
        assert block_mgr.get_num_free_cpu_blocks() == num_cpu_blocks - 2
        assert len(swap_map) == 2
        for block_id in old_phy_blocks:
            assert block_id in swap_map
        # block-table ids of CPU blocks are offset by num_gpu_blocks relative to swap_map values
        for block_id in new_phy_blocks:
            assert block_id - num_gpu_blocks in swap_map.values()

        # swap back in
        old_phy_blocks = block_mgr.get_block_table(msg)
        success, swap_map = block_mgr.try_swap_in(msg)
        new_phy_blocks = block_mgr.get_block_table(msg)
        assert success
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2
        assert block_mgr.get_num_free_cpu_blocks() == num_cpu_blocks
        assert len(swap_map) == 2
        for block_id in old_phy_blocks:
            assert block_id - num_gpu_blocks in swap_map
        for block_id in new_phy_blocks:
            assert block_id in swap_map.values()

        # after another swap-out and a sequence filling the GPU, swapping msg out fails
        success, swap_map = block_mgr.try_swap_out(msg)
        assert success
        token_ids = torch.tensor([1] * (block_size * 4))
        msg_full = sess.add_sequence(token_ids)
        block_mgr.allocate(msg_full)
        success, swap_map = block_mgr.try_swap_out(msg)
        assert not success


class TestWindowBlockManager:
    """Tests for the block manager of a ``Scheduler`` whose ``CacheConfig`` sets ``window_size``."""

    @pytest.fixture
    def window_size(self):
        yield 32

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def num_cpu_blocks(self):
        yield 4

    @pytest.fixture
    def num_gpu_blocks(self):
        yield 4

    @pytest.fixture
    def max_batch_size(self):
        yield 4

    @pytest.fixture
    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size, window_size):
        yield CacheConfig(max_batches=max_batch_size,
                          block_size=block_size,
                          num_cpu_blocks=num_cpu_blocks,
                          num_gpu_blocks=num_gpu_blocks,
                          window_size=window_size)

    @pytest.fixture
    def scheduler_config(self, max_batch_size):
        yield SchedulerConfig(max_batches=max_batch_size,
                              max_session_len=128,
                              max_request_output_len=64,
                              eviction_type='recompute')

    @pytest.fixture
    def seq_meta(self, block_size):
        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
        strategy = ARSequenceStrategy()
        yield SequenceMeta(block_size, strategy=strategy)

    @pytest.fixture
    def scheduler(self, cache_config, scheduler_config, seq_meta):
        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)

    @pytest.fixture
    def block_mgr(self, scheduler):
        yield scheduler.block_manager

    def test_alloc(self, scheduler, block_mgr, num_gpu_blocks):
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        # test alloc
        token_ids = torch.tensor([1])
        msg = sess.add_sequence(token_ids)
        assert block_mgr.can_allocate(msg)
        block_mgr.allocate(msg)
        block_table = block_mgr.get_block_table(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1
        assert block_table is not None
        assert len(block_table) == 1

        # test free
        block_mgr.free(msg)
        block_table = block_mgr.get_block_table(msg)
        assert block_table is None or len(block_table) == 0
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks

        # alloc over limit
        token_ids = torch.zeros((num_gpu_blocks * block_size + 1, ), dtype=torch.int64)
        msg = sess.add_sequence(token_ids)
        assert not block_mgr.can_allocate(msg)

    def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size):
        """Check GPU block usage as sequences grow around the attention window.

        The table is checked right after ``allocate``, so it must exist. The old
        ``block_table is None or ...`` disjunct made the length check vacuous.
        """
        sess = scheduler.add_session(0)

        # window_size tokens, then one more: 3 blocks in use
        token_ids = torch.tensor([1] * window_size)
        msg = sess.add_sequence(token_ids)
        block_mgr.allocate(msg)
        msg.update_token_ids(torch.tensor([1]))
        block_mgr.allocate(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3
        block_table = block_mgr.get_block_table(msg)
        assert block_table is not None and len(block_table) == 3
        block_mgr.free(msg)

        # window_size + 2 tokens, then one more: still 3 blocks in use
        token_ids = torch.tensor([1] * (window_size + 2))
        msg = sess.add_sequence(token_ids)
        block_mgr.allocate(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3
        msg.update_token_ids(torch.tensor([1]))
        block_mgr.allocate(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3
        block_table = block_mgr.get_block_table(msg)
        assert block_table is not None and len(block_table) == 3
        block_mgr.free(msg)

        # not full window: window_size - 2 tokens, then one more: 2 blocks in use
        token_ids = torch.tensor([1] * (window_size - 2))
        msg = sess.add_sequence(token_ids)
        block_mgr.allocate(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2
        msg.update_token_ids(torch.tensor([1]))
        block_mgr.allocate(msg)
        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2
        block_table = block_mgr.get_block_table(msg)
        assert block_table is not None and len(block_table) == 2
        block_mgr.free(msg)


================================================
FILE: tests/pytorch/paging/test_block_trie.py
================================================
import numpy as np
import pytest

from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
from lmdeploy.pytorch.messages import SequenceMeta
from lmdeploy.pytorch.paging import Scheduler


class TestBlockTire:
    """Tests for the prefix-caching block trie (``scheduler.block_trie``).

    NOTE(review): the class name is a typo for "Trie"; renaming it would change
    pytest node IDs, so it is left as-is.
    """

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def num_cpu_blocks(self):
        yield 4

    @pytest.fixture
    def num_gpu_blocks(self):
        yield 16

    @pytest.fixture
    def max_batch_size(self):
        yield 4

    @pytest.fixture
    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):
        # prefix caching must be enabled for the scheduler to maintain a block trie
        yield CacheConfig(max_batches=max_batch_size,
                          block_size=block_size,
                          num_cpu_blocks=num_cpu_blocks,
                          num_gpu_blocks=num_gpu_blocks,
                          enable_prefix_caching=True)

    @pytest.fixture
    def scheduler_config(self, max_batch_size):
        yield SchedulerConfig(max_batches=max_batch_size,
                              max_session_len=128,
                              max_request_output_len=64,
                              eviction_type='recompute')

    @pytest.fixture
    def seq_meta(self, block_size):
        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
        strategy = ARSequenceStrategy()
        yield SequenceMeta(block_size, strategy=strategy)

    @pytest.fixture
    def scheduler(self, cache_config, scheduler_config, seq_meta):
        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)

    @pytest.fixture
    def block_mgr(self, scheduler):
        yield scheduler.block_manager

    @pytest.fixture
    def block_trie(self, scheduler):
        yield scheduler.block_trie

    def test_allocate(self, block_trie, block_mgr, scheduler):
        """Full blocks are inserted into the trie; the trailing partial block is not."""
        allocator = block_trie.allocator
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size
        # two full blocks ([1]*bs, [2]*bs) plus half a block of 3s
        token_ids = ([1] * block_size + [2] * block_size)
        token_ids += [3] * (block_size // 2)
        seq = sess.add_sequence(token_ids)

        # first allocate
        block_mgr.allocate(seq)
        block_trie.allocate(seq)
        logical_blocks = seq.logical_blocks
        assert len(logical_blocks) == 3
        # full blocks have ref count 2 (presumably sequence + trie), the partial block 1
        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())
        assert np.array_equal(ref_cnt, [2, 2, 1])
        node = getattr(seq.logical_blocks, 'last_shared_node', None)
        assert node is not None
        assert node.num_matched == block_size * 2
        assert np.array_equal(node.tokens, [2] * block_size)
        assert np.array_equal(node.parent.tokens, [1] * block_size)
        assert node in block_trie.leaves
        assert node.parent not in block_trie.leaves

        # append: the third block becomes full ([3]*bs/2 + [4]*bs/2) and joins the trie
        seq.update_token_ids([4] * block_size)
        block_mgr.allocate(seq)
        block_trie.allocate(seq)
        logical_blocks = seq.logical_blocks
        assert len(logical_blocks) == 4
        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())
        assert np.array_equal(ref_cnt, [2, 2, 2, 1])
        node = getattr(seq.logical_blocks, 'last_shared_node', None)
        assert node is not None
        assert node.num_matched == block_size * 3
        expect_tokens = [3] * (block_size // 2) + [4] * (block_size // 2)
        assert np.array_equal(node.tokens, expect_tokens)
        assert node in block_trie.leaves
        assert len(block_trie.leaves) == 1

    def test_match(self, block_trie, block_mgr, scheduler):
        """New sequences reuse the longest cached prefix of full blocks."""
        allocator = block_trie.allocator
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size

        # initialize cache
        token_ids = ([1] * block_size + [2] * block_size)
        token_ids += [3] * (block_size // 2)
        seq = sess.add_sequence(token_ids)
        block_mgr.allocate(seq)
        block_trie.allocate(seq)

        # test1: only the first block ([1]*bs) matches the cached prefix
        token_ids = ([1] * block_size + [3] * block_size)
        seq = sess.add_sequence(token_ids)
        block_trie.match(seq)
        logical_blocks = seq.logical_blocks
        assert len(logical_blocks) == 1
        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())
        assert np.array_equal(ref_cnt, [3])
        node = getattr(seq.logical_blocks, 'last_shared_node', None)
        assert node is not None
        assert node.num_matched == block_size
        assert np.array_equal(node.tokens, [1] * block_size)
        block_mgr.allocate(seq)
        block_trie.allocate(seq)
        # the diverging [3]*bs block adds a second branch
        assert len(block_trie.leaves) == 2

        # test2: both full blocks match; the differing partial tail does not
        token_ids = ([1] * block_size + [2] * block_size)
        token_ids += [4] * (block_size // 2)
        seq = sess.add_sequence(token_ids)
        block_trie.match(seq)
        logical_blocks = seq.logical_blocks
        assert len(logical_blocks) == 2
        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())
        assert np.array_equal(ref_cnt, [4, 3])

    def test_evict(self, block_trie, scheduler, num_gpu_blocks):
        """Freed sequences keep full blocks cached until the trie evicts them."""
        block_mgr = block_trie.block_manager
        sess = scheduler.add_session(0)
        block_size = sess.seq_meta.block_size
        # (num_gpu_blocks - 1) full blocks + half a block: uses every GPU block
        token_ids = ([1] * block_size * (num_gpu_blocks - 1))
        token_ids += [2] * (block_size // 2)
        seq = sess.add_sequence(token_ids)
        block_mgr.allocate(seq)
        block_trie.allocate(seq)
        assert block_mgr.get_num_free_gpu_blocks() == 0

        # test free: only the partial block is released; the full blocks stay referenced by the trie
        block_mgr.free(seq)
        seq.set_step(0)
        assert block_mgr.get_num_free_gpu_blocks() == 1

        # test evict: evicting 4 blocks frees them and changes the trie's leaf
        leaf = next(iter(block_trie.leaves))
        block_trie.evict(4)
        new_leaf = next(iter(block_trie.leaves))
        assert leaf != new_leaf
        assert block_mgr.get_num_free_gpu_blocks() == 5


================================================
FILE: tests/pytorch/paging/test_scheduler.py
================================================
import pytest
import torch

from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta
from lmdeploy.pytorch.paging.scheduler import Scheduler


class TestScheduler:
    """Scheduling, state-transition and eviction tests for ``Scheduler``."""

    @pytest.fixture
    def block_size(self):
        yield 16

    @pytest.fixture
    def num_cpu_blocks(self):
        yield 4

    @pytest.fixture
    def num_gpu_blocks(self):
        yield 4

    @pytest.fixture
    def max_batch_size(self):
        yield 4

    @pytest.fixture
    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):
        yield CacheConfig(max_batches=max_batch_size,
                          block_size=block_size,
                          num_cpu_blocks=num_cpu_blocks,
                          num_gpu_blocks=num_gpu_blocks)

    @pytest.fixture
    def scheduler_config(self, max_batch_size):
        yield SchedulerConfig(max_batches=max_batch_size,
                              max_session_len=128,
                              max_request_output_len=64,
                              eviction_type='recompute')

    @pytest.fixture
    def seq_meta(self, block_size):
        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
        strategy = ARSequenceStrategy()
        yield SequenceMeta(block_size, strategy=strategy)

    @pytest.fixture
    def scheduler(self, cache_config, scheduler_config, seq_meta):
        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)

    def test_schedule_base(self, scheduler, block_size, num_gpu_blocks):
        """A waiting 2-block sequence becomes READY after a prefill schedule."""
        block_manager = scheduler.block_manager
        session_id = 0
        session = scheduler.add_session(session_id)
        assert session_id in scheduler.sessions
        assert scheduler.sessions[session_id] == session

        num_blocks = 2
        token_ids = torch.tensor([0] * block_size * num_blocks)
        seq = session.add_sequence(token_ids)

        assert seq.status == MessageStatus.WAITING
        assert seq in scheduler.waiting

        output = scheduler.schedule(is_prefill=True)
        block_tables = scheduler.get_block_tables(output.running)

        assert seq.status == MessageStatus.READY
        assert seq in output.running
        assert len(block_tables) == 1
        assert len(block_tables[0]) == num_blocks
        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - num_blocks

        assert scheduler.has_unfinished()

    def test_update(self, scheduler, block_size, num_gpu_blocks):
        """Stopping/removing sequences and stopping/ending sessions update queues and blocks."""
        block_manager = scheduler.block_manager
        session_id1 = 0
        session1 = scheduler.add_session(session_id1)
        token_ids1 = torch.tensor([0] * block_size * 1)
        seq1 = session1.add_sequence(token_ids1)

        session_id2 = 1
        session2 = scheduler.add_session(session_id2)
        token_ids2 = torch.tensor([0] * block_size * 2)
        seq2 = session2.add_sequence(token_ids2)
        token_ids3 = torch.tensor([0] * block_size * 3)
        seq3 = session2.add_sequence(token_ids3)

        # 1 + 2 blocks fit in the 4-block GPU pool; seq3 (3 blocks) does not
        scheduler.schedule(is_prefill=True)
        assert seq1.status == MessageStatus.READY
        assert seq2.status == MessageStatus.READY
        assert seq3.status == MessageStatus.WAITING

        # stop seq
        seq1.state.stop()
        assert len(scheduler.ready) == 1
        assert seq1 in scheduler.hanging

        # end seq
        seq1.session.remove_sequence(seq1)
        assert session_id1 in scheduler.sessions
        assert seq1 not in scheduler.ready
        assert seq1 not in scheduler.hanging
        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2

        # stop session
        scheduler.stop_session(session_id2)
        assert len(scheduler.ready) == 0
        assert len(scheduler.waiting) == 0
        assert len(scheduler.hanging) == 2

        # end session
        scheduler.end_session(session_id2)
        assert session_id2 not in scheduler.sessions
        assert len(scheduler.hanging) == 0
        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

    def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks):
        """Walk three sequences through allocation pressure and eviction.

        NOTE(review): ``num_cpu_blocks`` is requested but unused. ``eviction_type``
        is 'recompute', so the "cpu" labels in the state comments below may be
        inaccurate; the asserted GPU free counts only show which sequences hold
        GPU blocks.
        """
        block_manager = scheduler.block_manager
        session_id = 0
        session = scheduler.add_session(session_id)

        # test: add 3 seq
        token_ids1 = torch.tensor([0] * block_size * 1)
        seq1 = session.add_sequence(token_ids1)
        token_ids2 = torch.tensor([0] * block_size * 2)
        seq2 = session.add_sequence(token_ids2)
        token_ids3 = torch.tensor([0] * block_size * 3)
        seq3 = session.add_sequence(token_ids3)
        scheduler.schedule(is_prefill=True)
        # seq1: 1 running gpu
        # seq2: 2 running gpu
        # seq3: 3 waiting empty
        assert seq1.status == MessageStatus.READY
        assert seq2.status == MessageStatus.READY
        assert seq3.status == MessageStatus.WAITING
        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3

        # test: waiting alloc
        seq2.state.stop()
        assert len(scheduler.ready) == 1
        assert len(scheduler.waiting) == 1
        assert len(scheduler.hanging) == 1

        scheduler.schedule(is_prefill=True)
        # seq1: 1 running gpu
        # seq2: 2 hanging cpu (holds no GPU blocks: seq1 + seq3 use all 4)
        # seq3: 3 running gpu
        assert seq1.status == MessageStatus.READY
        assert seq2.status == MessageStatus.STOPPED
        assert seq3.status == MessageStatus.READY
        assert block_manager.get_num_free_gpu_blocks() == 0

        # test: waiting append token
        seq2.state.activate()
        seq3.session.remove_sequence(seq3)
        seq2.update_token_ids(torch.tensor([1] * block_size))
        assert len(scheduler.ready) == 1
        assert len(scheduler.waiting) == 1
        assert len(scheduler.hanging) == 0

        scheduler.schedule(is_prefill=True)
        # seq1: 1 running gpu
        # seq2: 3 running gpu
        # seq3: 3 nan
        assert seq1.status == MessageStatus.READY
        assert seq2.status == MessageStatus.READY
        assert block_manager.get_num_free_gpu_blocks() == 0

        # test running append: both sequences need one more block but only one fits
        seq1.update_token_ids(torch.tensor([1] * block_size))
        seq2.update_token_ids(torch.tensor([1] * block_size))
        assert len(scheduler.ready) == 2
        scheduler.schedule(is_prefill=False)
        # seq1: 2 running gpu
        # seq2: 4 waiting cpu (holds no GPU blocks: only seq1's 2 blocks are in use)
        # seq3: 3 nan
        assert seq1.status == MessageStatus.READY
        assert seq2.status == MessageStatus.WAITING
        assert block_manager.get_num_free_gpu_blocks() == 2


================================================
FILE: tests/test_lmdeploy/test_auto_backend.py
================================================
import os
import tempfile

import numpy as np
import pytest


class TestAutoBackend:
    """Tests for turbomind support detection and automatic backend selection."""

    @pytest.fixture
    def turbomind_workspace(self):
        """Yield a temporary directory laid out like a converted turbomind workspace.

        The original ``tempfile.TemporaryDirectory(...).name`` dropped the
        ``TemporaryDirectory`` object immediately. Its finalizer removed the directory,
        ``makedirs`` then recreated it, and nothing ever cleaned it up. Using the
        context manager keeps the directory alive for the test and removes it afterwards.
        """
        with tempfile.TemporaryDirectory(suffix='internlm-chat-7b-turbomind') as workspace:
            os.makedirs(os.path.join(workspace, 'triton_models'), exist_ok=True)
            yield workspace

    @pytest.fixture
    def models(self):
        """Example models as ``(model_path, is_turbomind_supported)`` pairs."""
        models = [
            ('baichuan-inc/Baichuan-7B', True),
            ('baichuan-inc/Baichuan2-7B-Chat', True),
            ('baichuan-inc/Baichuan-13B-Chat', False),
            ('baichuan-inc/Baichuan2-13B-Chat', False),
            ('internlm/internlm-chat-7b', True),
            ('internlm/internlm2-chat-7b', True),
            ('internlm/internlm-xcomposer2-7b', True),
            ('internlm/internlm-xcomposer-7b', False),
            ('THUDM/chatglm2-6b', False),
            ('THUDM/chatglm3-6b', False),
            ('deepseek-ai/deepseek-moe-16b-chat', False),
            ('01-ai/Yi-34B-Chat', True),
            ('codellama/CodeLlama-7b-Instruct-hf', True),
            ('Qwen/Qwen-7B-Chat', True),
            ('Qwen/Qwen-VL-Chat', True),
            ('Qwen/Qwen1.5-4B-Chat', True),
            ('Qwen/Qwen1.5-0.5B-Chat', True),
        ]
        return models

    def test_turbomind_is_supported(self, turbomind_workspace, models):
        from lmdeploy.turbomind.supported_models import is_supported
        assert is_supported(turbomind_workspace) is True
        for m, flag in models:
            assert is_supported(m) is flag

    def test_autoget_backend(self, turbomind_workspace, models):
        """``autoget_backend`` picks turbomind for a workspace and for a random half of the models."""
        from lmdeploy.archs import autoget_backend
        assert autoget_backend(turbomind_workspace) == 'turbomind'
        n = len(models)
        choices = np.random.choice(n, n // 2, replace=False)
        for i in choices:
            model, is_support_turbomind = models[i]
            target = 'turbomind' if is_support_turbomind else 'pytorch'
            backend = autoget_backend(model)
            assert backend == target


================================================
FILE: tests/test_lmdeploy/test_content_merge.py
================================================
import pytest

from lmdeploy.serve.processors import MultimodalProcessor


class TestMergeMessageContent:
    """Tests for ``MultimodalProcessor.merge_message_content``."""

    @staticmethod
    def _merge(msg):
        # single call site for the function under test
        return MultimodalProcessor.merge_message_content(msg)

    def test_missing_content_field(self):
        """A tool-call-only assistant message gains ``content == ''``.

        Assistant turns that only carry ``tool_calls`` have no ``content`` key.
        """
        msg = {
            'role': 'assistant',
            'tool_calls': [{
                'id': 'chatcmpl-tool-123',
                'type': 'function',
                'function': {'name': 'get_weather', 'arguments': '{"city": "Paris"}'},
            }],
        }
        result = self._merge(msg)

        assert 'content' in result
        assert result['content'] == ''
        assert 'tool_calls' in result
        assert result['tool_calls'] == msg['tool_calls']

    def test_explicit_none_content(self):
        """``content=None`` is normalised to ``''`` (vLLM does None -> [] -> '')."""
        msg = {
            'role': 'assistant',
            'content': None,
            'tool_calls': [{
                'id': 'chatcmpl-tool-456',
                'type': 'function',
                'function': {'name': 'Bash', 'arguments': '{"command": "ls"}'},
            }],
        }
        result = self._merge(msg)

        assert result['content'] == ''
        assert 'tool_calls' in result

    def test_string_content_unchanged(self):
        """Plain string content is returned as-is, on the very same dict."""
        msg = {'role': 'user', 'content': 'Hello, world!'}
        result = self._merge(msg)

        assert result['content'] == 'Hello, world!'
        assert result is msg

    def test_single_text_block(self):
        """A one-element text-block list collapses to that block's text."""
        result = self._merge({'role': 'user', 'content': [{'type': 'text', 'text': 'Single block'}]})

        assert result['content'] == 'Single block'

    def test_multiple_text_blocks_newline_join(self):
        """Several text blocks are joined with a newline, as vLLM does."""
        blocks = [{'type': 'text', 'text': text} for text in ('First block', 'Second block', 'Third block')]
        result = self._merge({'role': 'user', 'content': blocks})

        assert result['content'] == 'First block\nSecond block\nThird block'

    def test_mixed_content_types(self):
        """Only text blocks contribute; ``image_url`` blocks are dropped."""
        content = [
            {'type': 'text', 'text': 'Analyze this image:'},
            {'type': 'image_url', 'image_url': {'url': 'http://example.com/img.jpg'}},
            {'type': 'text', 'text': 'What do you see?'},
        ]
        result = self._merge({'role': 'user', 'content': content})

        assert result['content'] == 'Analyze this image:\nWhat do you see?'

    def test_empty_list_content(self):
        """An empty block list becomes the empty string."""
        result = self._merge({'role': 'user', 'content': []})

        assert result['content'] == ''

    def test_list_with_non_text_blocks_only(self):
        """A list made only of image blocks yields ``''``."""
        content = [
            {'type': 'image_url', 'image_url': {'url': 'http://example.com/img1.jpg'}},
            {'type': 'image_url', 'image_url': {'url': 'http://example.com/img2.jpg'}},
        ]
        result = self._merge({'role': 'user', 'content': content})

        assert result['content'] == ''

    def test_preserve_all_message_fields(self):
        """Merging rewrites ``content`` but keeps every other key intact."""
        msg = {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': 'Response'}],
            'tool_calls': [{'id': '123', 'type': 'function'}],
            'name': 'assistant',
            'custom_field': 'custom_value',
        }
        result = self._merge(msg)

        assert result['content'] == 'Response'
        assert result['tool_calls'] == msg['tool_calls']
        assert result['name'] == 'assistant'
        assert result['custom_field'] == 'custom_value'
        assert set(result) == set(msg)

    def test_text_block_with_missing_text_field(self):
        """A text block lacking a ``'text'`` key counts as an empty string."""
        content = [
            {'type': 'text', 'text': 'First'},
            {'type': 'text'},  # no 'text' key on purpose
            {'type': 'text', 'text': 'Third'},
        ]
        result = self._merge({'role': 'user', 'content': content})

        # the empty middle block leaves a blank line between its neighbours
        assert result['content'] == 'First\n\nThird'

    def test_gpt_oss_tool_call_scenario(self):
        """Regression: GPT-OSS tool-call replies arrive without any ``content``."""
        msg = {
            'role': 'assistant',
            'tool_calls': [{
                'id': 'chatcmpl-tool-UK9rkwzMAyxt9DxBezk7E2',
                'type': 'function',
                'function': {
                    'name': 'Bash',
                    'arguments': '{"command": "ls", "description": "List files in current directory"}'
                },
            }],
        }
        result = self._merge(msg)

        # an empty content field is added ...
        assert 'content' in result
        assert result['content'] == ''
        # ... and the tool call survives untouched
        assert len(result['tool_calls']) == 1
        assert result['tool_calls'][0]['function']['name'] == 'Bash'


@pytest.mark.parametrize(
    'msg,expected_content',
    [
        # plain string / None / absent content
        ({'role': 'user', 'content': 'test'}, 'test'),
        ({'role': 'user', 'content': None}, ''),
        ({'role': 'assistant'}, ''),
        # text-block lists
        ({'role': 'user', 'content': [{'type': 'text', 'text': 'a'}]}, 'a'),
        ({'role': 'user', 'content': [{'type': 'text', 'text': 'a'}, {'type': 'text', 'text': 'b'}]}, 'a\nb'),
        # lists that contribute no text at all
        ({'role': 'user', 'content': []}, ''),
        ({'role': 'user', 'content': [{'type': 'image_url'}]}, ''),
    ])
def test_merge_message_content_parametrized(msg, expected_content):
    """Table-driven check of ``merge_message_content`` over content shapes."""
    merged = MultimodalProcessor.merge_message_content(msg)
    assert merged['content'] == expected_content


def test_batch_message_processing():
    """Merging a mixed batch leaves every message with a ``content`` field."""
    messages = [
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'tool_calls': [{'id': '123', 'type': 'function'}]},
        {'role': 'user', 'content': [{'type': 'text', 'text': 'Block 1'}, {'type': 'text', 'text': 'Block 2'}]},
    ]

    processed = list(map(MultimodalProcessor.merge_message_content, messages))

    assert all('content' in m for m in processed)
    assert [m['content'] for m in processed] == ['Hello', '', 'Block 1\nBlock 2']
    # shape check expected by the chat-template code (original note: "model.py assertion")
    assert all(isinstance(m, dict) and 'role' in m and 'content' in m for m in processed)


================================================
FILE: tests/test_lmdeploy/test_grammar.py
================================================
import json
import re

import pytest
from jsonschema import validate

from lmdeploy import pipeline
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig

# Hub model ids exercised by every guided-decoding test below.
MODEL_IDS = [
    'Qwen/Qwen3-0.6B',
    'OpenGVLab/InternVL3_5-1B',
]

# (short backend name, zero-arg factory returning that backend's engine config).
# Each test calls the factory itself, so every case gets its own config object.
BACKEND_FACTORIES = [
    ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
]

# Constraint payload for each ``response_format`` type under test:
#   'json_schema'  -> a JSON Schema the output is validated against
#   'regex_schema' -> a regex the output must fully match
#   'json_object'  -> no payload; any JSON object is accepted
SCHEMA_MAP = {
    'json_schema': {
        'type': 'object',
        'properties': {
            'name': {
                'type': 'string'
            },
            'skills': {
                'type': 'array',
                'items': {
                    'type': 'string',
                    'maxLength': 10
                },
                'minItems': 3,
                'maxItems': 10,
            },
            'work history': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'company': {
                            'type': 'string'
                        },
                        'duration': {
                            'type': 'string'
                        },
                    },
                    'required': ['company'],
                },
            },
        },
        'required': ['name', 'skills', 'work history'],
    },
    'regex_schema': 'call me [A-Za-z]{1,10}',
    'json_object': None,
}


@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP.keys()) + [None])
def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
    """Guided decoding must constrain every response in a batch.

    ``schema_type=None`` is an unguided baseline that only requires non-empty
    text. ``backend_name`` only serves to label the parametrized case.
    """
    # Build the generation config first, so a failure here cannot leave a
    # pipeline running without ``close()``.
    schema = SCHEMA_MAP.get(schema_type)
    if schema_type is None:
        gen_config = GenerationConfig()
    else:
        response_format = {'type': schema_type}
        if schema_type == 'json_schema':
            response_format[schema_type] = dict(name='test', schema=schema)
        elif schema_type == 'regex_schema':
            response_format[schema_type] = schema
        gen_config = GenerationConfig(response_format=response_format)

    prompts = ['Make a self introduction please.'] * 3
    pipe = pipeline(
        model_id,
        backend_config=backend_factory(),
        log_level='INFO',
    )
    try:
        responses = pipe(prompts, gen_config=gen_config)
        assert responses and len(responses) == len(prompts)

        # every prompt in the batch shares the same constraint, so check them all
        for resp in responses:
            assert resp.text
            if schema_type == 'json_schema':
                validate(instance=json.loads(resp.text), schema=schema)
            elif schema_type == 'json_object':
                validate(instance=json.loads(resp.text), schema={'type': 'object', 'additionalProperties': True})
            elif schema_type == 'regex_schema':
                assert re.fullmatch(schema, resp.text)
    finally:
        pipe.close()


@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
def test_mix_guided_matrix(model_id, backend_name, backend_factory):
    """Guided and unguided requests mixed in one batch must not interfere.

    Requests 0 and 3 carry the JSON-schema constraint and must validate.
    Requests 1 and 2 are unguided and must not happen to produce
    schema-conformant JSON.
    """
    from jsonschema import ValidationError

    schema_type = 'json_schema'
    schema = SCHEMA_MAP[schema_type]
    response_format = {'type': schema_type, schema_type: dict(name='test', schema=schema)}
    config = GenerationConfig(response_format=response_format)

    num_prompts = 4
    prompts = ['Make a self introduction please.'] * num_prompts
    # idx % 3 == 0 -> guided (0 and 3); otherwise unguided (1 and 2)
    gen_config = [None if idx % 3 else config for idx in range(num_prompts)]

    pipe = pipeline(
        model_id,
        backend_config=backend_factory(),
        log_level='INFO',
    )
    try:
        responses = pipe.batch_infer(prompts, gen_config=gen_config)
        assert len(responses) == num_prompts

        for resp, c in zip(responses, gen_config):
            if c is not None:
                validate(instance=json.loads(resp.text), schema=schema)
                continue

            assert resp and resp.text
            try:
                data = json.loads(resp.text)
            except json.JSONDecodeError:
                continue  # not JSON at all, so it cannot conform
            try:
                validate(instance=data, schema=schema)
            except ValidationError:
                # JSON, but not schema-conformant. Only this outcome is
                # expected; SchemaError and other errors now surface.
                continue
            pytest.fail('Unguided generation unexpectedly produced schema-conformant JSON')
    finally:
        pipe.close()


================================================
FILE: tests/test_lmdeploy/test_harmony_gpt_oss_parser.py
================================================
import collections
import json
import os
import sys
import time
import types
from typing import Generator, List

import pytest
import shortuuid

# Make the repository checkout importable ahead of any site-packages copy of
# lmdeploy. REPO_ROOT is three dirname() steps above this file
# (tests/test_lmdeploy/<this file> -> repository root).
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)


def _install_openai_harmony_stub():
    """Install a minimal stub for `openai_harmony` so the module imports
    without the real dependency.

    The GptOssChatParser test injects its own dummy parser, so the stub is sufficient.
    """
    if 'openai_harmony' in sys.modules:
        return
    m = types.ModuleType('openai_harmony')

    class HarmonyEncodingName:
        HARMONY_GPT_OSS = 'HARMONY_GPT_OSS'

    class Role:
        ASSISTANT = 'assistant'

    class StreamableParser:  # pragma: no cover - constructor only used

        def __init__(self, encoding, role=None):
            self.encoding = encoding
            self.role = role

    def load_harmony_encoding(name):  # pragma: no cover - not used in test
        return object()

    m.HarmonyEncodingName = HarmonyEncodingName
    m.Role = Role
    m.StreamableParser = StreamableParser
    m.load_harmony_encoding = load_harmony_encoding
    sys.modules['openai_harmony'] = m


# Expected (function name, ``location`` argument) for one parsed tool call.
TestExpects = collections.namedtuple('TestExpects', 'func_name location')
# Without this, the ``Test`` prefix makes pytest try to collect the namedtuple
# as a test class; it then emits a PytestCollectionWarning because the class
# defines ``__new__``.
TestExpects.__test__ = False


class DummyParser:
    """Small stand-in for Harmony's ``StreamableParser`` that tracks channels.

    Negative tokens are control codes:
      -1: open a ``functions.get_weather`` call (commentary channel)
      -4: open a ``functions.get_time`` call (commentary channel)
      -6: open another ``functions.get_weather`` call
      -9: close the current call; commentary calls to ``functions.*`` are
          appended to ``messages``, then channel and recipient are cleared
      -2: switch to the visible ``final`` channel
      -3: switch to the ``analysis`` (reasoning) channel
    Every other token is a character code, exposed as ``last_content_delta``.
    """

    class _Msg:

        def __init__(self, channel, recipient):
            self.channel = channel
            self.recipient = recipient

    # control token -> (channel, recipient) entered when the token is seen
    _SWITCHES = {
        -1: ('commentary', 'functions.get_weather'),
        -4: ('commentary', 'functions.get_time'),
        -6: ('commentary', 'functions.get_weather'),
        -2: ('final', None),
        -3: ('analysis', None),
    }
    _END_CALL = -9

    def __init__(self):
        self.current_channel = None
        self.current_recipient = None
        self.last_content_delta = ''
        self.messages = []

    def process(self, token):
        if token == self._END_CALL:
            recipient = self.current_recipient
            if self.current_channel == 'commentary' and recipient and recipient.startswith('functions.'):
                self.messages.append(self._Msg(self.current_channel, recipient))
            # a closed call leaves no active channel or recipient behind
            self.current_channel, self.current_recipient = None, None
            self.last_content_delta = ''
        elif token in self._SWITCHES:
            self.current_channel, self.current_recipient = self._SWITCHES[token]
            self.last_content_delta = ''
        else:
            self.last_content_delta = chr(token)


def _chat_completion_v1(request, token_chunks: List[List[int]]):
    """Mimic the server's chat-completion handler on canned token chunks.

    A ``GptOssChatParser`` is built and its Harmony parser is replaced by
    ``DummyParser``. If ``request.stream`` is truthy, a generator yielding one
    ``ChatCompletionStreamResponse`` per chunk is returned. Otherwise all
    chunks are concatenated, parsed with ``parse_full`` and wrapped in a single
    ``ChatCompletionResponse``, whose finish reason is 'tool_calls' when any
    tool call was parsed.
    """
    from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
    from lmdeploy.serve.openai.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice,
                                                ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                                UsageInfo)

    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model

    chat_parser = GptOssChatParser()
    chat_parser.parser = DummyParser()

    if not request.stream:
        # parse the whole token sequence in one go
        all_tokens: List[int] = [tok for chunk in token_chunks for tok in chunk]
        message = chat_parser.parse_full(all_tokens)
        choice = ChatCompletionResponseChoice(index=0,
                                              message=message,
                                              finish_reason='tool_calls' if message.tool_calls else 'stop')
        return ChatCompletionResponse(id=request_id,
                                      created=created_time,
                                      model=model_name,
                                      choices=[choice],
                                      usage=UsageInfo())

    def completion_stream_generator() -> Generator['ChatCompletionStreamResponse', None, None]:
        for chunk in token_chunks:
            delta = chat_parser.parse_streaming(chunk)
            stream_choice = ChatCompletionResponseStreamChoice(index=0, delta=delta, finish_reason='stop', logprobs=None)
            yield ChatCompletionStreamResponse(id=request_id,
                                               created=created_time,
                                               model=model_name,
                                               choices=[stream_choice],
                                               usage=None)

    return completion_stream_generator()


def _stream_parse(request, token_chunks: List[List[int]]):
    """Stream ``token_chunks`` through ``_chat_completion_v1`` and accumulate
    the deltas the way an OpenAI client would.

    Returns:
        ``(content, reasoning_content, tool_calls)``. ``tool_calls`` holds one
        merged call per tool-call ``index``, sorted by index. When deltas share
        an index, a non-empty name overwrites the stored one and argument
        fragments are concatenated.
    """
    from lmdeploy.serve.openai.protocol import DeltaMessage

    content = ''
    reasoning_content = ''
    tool_calls_by_index = {}

    for stream_resp in _chat_completion_v1(request, token_chunks):
        delta_message: DeltaMessage = stream_resp.choices[0].delta
        if delta_message.content:
            content += delta_message.content
        if delta_message.reasoning_content:
            reasoning_content += delta_message.reasoning_content
        for call in delta_message.tool_calls or []:
            existing_call = tool_calls_by_index.get(call.index)
            if existing_call is None:
                # the first fragment for an index becomes the accumulator
                tool_calls_by_index[call.index] = call
                continue
            if call.function.name:
                existing_call.function.name = call.function.name
            if call.function.arguments:
                existing_call.function.arguments = (existing_call.function.arguments or '') + call.function.arguments
    # sorted by index for a stable order
    tool_calls = [tool_calls_by_index[idx] for idx in sorted(tool_calls_by_index)]
    return content, reasoning_content, tool_calls


def _t(s: str) -> List[int]:
    return [ord(c) for c in s]


# Token fixtures for DummyParser. Each constant is a list of chunks. In the
# streaming path each chunk is passed to one ``parse_streaming`` call; in the
# non-stream path all chunks are concatenated. Negative ints are DummyParser
# control codes (-1/-4/-6 open a tool call, -9 closes it, -2 switches to final
# content, -3 to analysis). ``_t`` turns text into character-code tokens.

# Basic: single function call split across two chunks (bug repro scenario)
TOKENS_SINGLE_CALL_TWO_CHUNKS = [
    [-1] + _t('{"location": "Paris'),
    _t(', France"}'),
]

# Multiple calls with indices and different function names
TOKENS_TWO_CALLS_DIFFERENT_FUNCS = [
    [-1] + _t('{"location": "Berlin"}') + [-9] + [-4] + _t('{"city": "New'),
    _t(' York"}') + [-9],
]

# Interleaved channels: analysis, tool call, final content
TOKENS_INTERLEAVED = [
    [-3] + _t('Thinking about the weather. ') + [-1] + _t('{"location": "Par'),
    _t('is, France"}') + [-9] + [-2] + _t('Fetching the weather now.'),
]

# Two calls with the same function name; their indices must still increment
TOKENS_TWO_CALLS_SAME_FUNC = [
    [-1] + _t('{"location": "Tokyo"}') + [-9],
    [-6] + _t('{"location": "Ky'),
    _t('oto"}') + [-9],
]


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
])
def test_parser_stream_basic(token_chunks: List[List[int]], expects: List[TestExpects]):
    """A tool call whose arguments span two chunks is reassembled into one call."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    stream_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(stream_request, token_chunks)

    assert len(tool_calls) == len(expects)
    for call, want in zip(tool_calls, expects):
        assert call.function.name == want.func_name
        assert json.loads(call.function.arguments)['location'] == want.location
    # a pure tool-call stream carries no visible or reasoning text
    assert not content.strip()
    assert not (reasoning_content or '').strip()


def test_parser_stream_multiple_calls_indices():
    """Two different functions in one stream get tool-call indices 0 and 1."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    stream_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(stream_request, TOKENS_TWO_CALLS_DIFFERENT_FUNCS)

    assert len(tool_calls) == 2
    weather_call, time_call = tool_calls  # already ordered by index
    assert (weather_call.index, time_call.index) == (0, 1)
    assert weather_call.function.name == 'get_weather'
    assert json.loads(weather_call.function.arguments)['location'] == 'Berlin'
    assert time_call.function.name == 'get_time'
    assert json.loads(time_call.function.arguments)['city'] == 'New York'
    assert not (content or '').strip()
    assert not (reasoning_content or '').strip()


def test_parser_stream_interleaved_channels():
    """Reasoning, a tool call and final content land in their own fields."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    stream_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(stream_request, TOKENS_INTERLEAVED)

    first_args = json.loads(tool_calls[0].function.arguments)
    assert first_args['location'] == 'Paris, France'
    assert reasoning_content == 'Thinking about the weather. '
    assert content == 'Fetching the weather now.'


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_stream_two_calls_same_func(token_chunks: List[List[int]], expects: List[TestExpects]):
    """Repeated calls to the same function stay separate, in stream order."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    stream_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    tool_calls = _stream_parse(stream_request, token_chunks)[2]

    assert len(tool_calls) == len(expects)
    for call, want in zip(tool_calls, expects):
        assert call.function.name == want.func_name
        assert json.loads(call.function.arguments)['location'] == want.location


def test_open_tool_call_no_args():
    """A tool call that is opened but never given arguments still surfaces."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    stream_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(stream_request, [[-1]])

    assert len(tool_calls) == 1
    only_call = tool_calls[0]
    assert only_call.function.name == 'get_weather'
    assert not (only_call.function.arguments or '')
    assert not (content or '')
    assert not (reasoning_content or '')


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_nonstream(token_chunks: List[List[int]], expects: List[TestExpects]):
    """The non-streaming path returns one choice holding every tool call."""
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    full_request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=False)
    resp = _chat_completion_v1(full_request, token_chunks)

    assert len(resp.choices) == 1
    message = resp.choices[0].message
    assert message.content is None
    assert (message.reasoning_content or '') == ''
    assert len(message.tool_calls) == len(expects)
    for call, want in zip(message.tool_calls, expects):
        assert call.function.name == want.func_name
        assert json.loads(call.function.arguments)['location'] == want.location


================================================
FILE: tests/test_lmdeploy/test_lite/test_quantization/test_utils/test_cal_qparams.py
================================================
# yapf: disable
import torch

from lmdeploy.lite.utils import (cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax,
                                 cal_qparams_per_group_absmax, cal_qparams_per_group_minmax,
                                 cal_qparams_per_tensor_absmax, cal_qparams_per_tensor_minmax)


# yapf: enable
def test_cal_qparams():
    """Check the output shapes of every ``cal_qparams_*`` helper on a 64x64 weight."""
    w = torch.randn(64, 64)

    # (helper, args after the weight, expected scales shape,
    #  expected zero_points shape -- None when no zero points are produced)
    cases = [
        (cal_qparams_per_channel_absmax, (8, ), (64, 1), None),
        (cal_qparams_per_channel_minmax, (8, ), (64, 1), (64, 1)),
        (cal_qparams_per_group_absmax, (8, 16), (64, 4, 1), None),
        (cal_qparams_per_group_minmax, (8, 16), (64, 4, 1), (64, 4, 1)),
        (cal_qparams_per_tensor_absmax, (8, ), (), None),
        (cal_qparams_per_tensor_minmax, (8, ), (), ()),
    ]
    for fn, extra_args, scales_shape, zero_points_shape in cases:
        qparams = fn(w, *extra_args)
        assert qparams.scales.shape == scales_shape
        if zero_points_shape is None:
            # absmax variants are symmetric: no zero points at all
            assert qparams.zero_points is None
        else:
            assert qparams.zero_points.shape == zero_points_shape


================================================
FILE: tests/test_lmdeploy/test_messages.py
================================================
from typing import List

import pytest

from lmdeploy import GenerationConfig, Tokenizer
from lmdeploy.utils import get_hf_gen_cfg


def test_engine_generation_config():
    """``convert_stop_bad_words_to_ids`` maps stop words to the tokenizer's ids."""
    tokenizer = Tokenizer('internlm/internlm-chat-7b')
    # '<eoa>' is InternLM-chat's end-of-answer marker. The previous empty
    # stop word '' encodes to no token ids, so ``stop_token_ids[0]`` below
    # could never hold. One local keeps the config and the reference encode
    # in sync.
    stop_word = '<eoa>'
    config = GenerationConfig(n=3, stop_words=[stop_word])
    stop_token_ids = tokenizer.encode(stop_word, add_bos=False)
    config.convert_stop_bad_words_to_ids(tokenizer)
    assert stop_token_ids == config.stop_token_ids
    assert isinstance(config.stop_token_ids, List) and \
        isinstance(config.stop_token_ids[0], int)


@pytest.mark.parametrize('model_path', [
    'deepseek-ai/DeepSeek-V3',
    'Qwen/Qwen2.5-32B-Instruct',
    'internlm/internlm3-8b-instruct',
])
def test_update_from_hf_gen_cfg(model_path):
    """Merging a model's HF generation config must yield stop token ids."""
    tok = Tokenizer(model_path)
    hf_gen_cfg = get_hf_gen_cfg(model_path)

    gen_config = GenerationConfig()
    gen_config.update_from_hf_gen_cfg(hf_gen_cfg, tok.eos_token_id)
    assert gen_config.stop_token_ids is not None


================================================
FILE: tests/test_lmdeploy/test_model.py
================================================
import pytest

from lmdeploy.model import MODELS

# Hub models whose HF tokenizer ships a chat template. The tests below check
# that lmdeploy's 'hf' chat template reproduces
# ``tokenizer.apply_chat_template`` for a single user turn. Commented-out
# entries are currently excluded.
HF_MODELS_WITH_CHAT_TEMPLATES = [
    'Qwen/Qwen1.5-7B-Chat',
    'Qwen/Qwen2.5-7B-Instruct',
    'Qwen/Qwen3-8B',
    'Qwen/QwQ-32B',
    'Qwen/QwQ-32B-Preview',
    'Qwen/QwQ-32B-AWQ',
    'Qwen/Qwen2.5-VL-7B-Instruct',
    'Qwen/Qwen2-VL-7B-Instruct',
    'internlm/internlm2-chat-7b',
    'internlm/internlm2_5-7b-chat',
    'internlm/internlm3-8b-instruct',
    # 'internlm/Intern-S1',
    # 'internlm/Intern-S1-mini',
    'OpenGVLab/InternVL-Chat-V1-2',
    'OpenGVLab/InternVL-Chat-V1-5',
    'OpenGVLab/Mini-InternVL-Chat-2B-V1-5',
    'OpenGVLab/InternVL2-2B',
    'OpenGVLab/InternVL2-4B',
    'OpenGVLab/InternVL2-8B',
    'OpenGVLab/InternVL2_5-2B',
    'OpenGVLab/InternVL2_5-4B',
    'OpenGVLab/InternVL2_5-8B',
    'OpenGVLab/InternVL3-2B',
    'OpenGVLab/InternVL3-8B',
    'OpenGVLab/InternVL3-9B',
    'OpenGVLab/InternVL3_5-1B',
    'OpenGVLab/InternVL3_5-4B',
    'OpenGVLab/InternVL3_5-8B',
    'OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview',
    'deepseek-ai/DeepSeek-V2-Lite',
    'deepseek-ai/DeepSeek-V3',
    'deepseek-ai/DeepSeek-R1',
    'deepseek-ai/DeepSeek-R1-Zero',
    'deepseek-ai/DeepSeek-V3.1',
    'deepseek-ai/deepseek-coder-1.3b-instruct',
    'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
    'zai-org/chatglm3-6b',
    'zai-org/glm-4-9b-chat',
    'zai-org/codegeex4-all-9b',
    'zai-org/cogvlm2-llama3-chat-19B',
    'microsoft/Phi-3-mini-128k-instruct',
    'microsoft/Phi-3-vision-128k-instruct',
    'microsoft/Phi-3.5-mini-instruct',
    'microsoft/Phi-3.5-vision-instruct',
    'microsoft/Phi-3.5-MoE-instruct',
    '01-ai/Yi-1.5-34B-Chat',
    # The following models are gated on the hub and need an authenticated
    # account, so they are disabled here.
    # 'openbmb/MiniCPM-V-2_6',
    # 'google/gemma-3-4b-it',
]


@pytest.mark.parametrize('model_path', HF_MODELS_WITH_CHAT_TEMPLATES)
def test_HFChatTemplate_get_prompt_sequence_start_True(model_path):
    """get_prompt on a fresh sequence must equal transformers' rendering of
    the same single user turn with a generation prompt appended."""
    chat_template = MODELS.get('hf')(model_path=model_path)
    question = 'How to apply chat template using transformers?'

    from transformers import AutoTokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    reference = hf_tokenizer.apply_chat_template([{
        'role': 'user',
        'content': question
    }],
                                                 tokenize=False,
                                                 add_generation_prompt=True)

    rendered = chat_template.get_prompt(question, sequence_start=True)
    assert rendered == reference


@pytest.mark.parametrize('model_path', HF_MODELS_WITH_CHAT_TEMPLATES)
def test_HFChatTemplate_message2prompt_sequence_start_True(model_path):
    """messages2prompt accepts either a bare string or an OpenAI-style
    message list, and both must render exactly like transformers'
    apply_chat_template for the equivalent single user turn."""
    chat_template = MODELS.get('hf')(model_path=model_path)
    question = 'How to apply chat template using transformers?'
    conversation = [{'role': 'user', 'content': question}]

    from transformers import AutoTokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    reference = hf_tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

    # Same expectation for the string form and the list form, in that order.
    for request in (question, conversation):
        assert chat_template.messages2prompt(request, sequence_start=True) == reference


def test_base_model():
    """A completion-capability template passes its input through unchanged."""
    template = MODELS.get('internlm')(capability='completion')
    assert template.capability == 'completion'
    for render, text in ((template.get_prompt, 'hi'), (template.messages2prompt, 'test')):
        assert render(text) == text


def test_vicuna():
    """Vicuna template: completion passes the prompt through, chat output
    differs from the raw prompt, and an unsupported capability makes
    get_prompt raise AssertionError."""
    prompt = 'hello, can u introduce yourself'
    model = MODELS.get('vicuna')(capability='completion')
    assert model.get_prompt(prompt, sequence_start=True) == prompt
    assert model.get_prompt(prompt, sequence_start=False) == prompt

    model = MODELS.get('vicuna')(capability='chat', system='Provide answers in Python')
    assert model.get_prompt(prompt, sequence_start=True) != prompt
    assert model.get_prompt(prompt, sequence_start=False) != prompt
    assert model.system == 'Provide answers in Python'

    model = MODELS.get('vicuna')(capability='voice')
    _prompt = None
    with pytest.raises(AssertionError):
        _prompt = model.get_prompt(prompt, sequence_start=True)
    # Must live outside the `with`: inside it, a failing assert raises
    # AssertionError itself and pytest.raises swallows it, so the test would
    # pass even if get_prompt returned a prompt instead of raising.
    assert _prompt is None


def test_prefix_response():
    """A conversation ending in an assistant message renders a prompt that
    ends with that message's content."""
    prefix = 'prefix test'
    chat_template = MODELS.get('hf')(model_path='Qwen/Qwen3-8B')
    rendered = chat_template.messages2prompt([dict(role='assistant', content=prefix)])
    assert rendered.endswith(prefix)


def test_internlm_chat():
    """InternLM template: completion passes the prompt through (with the
    default system and stop words set), chat output differs from the raw
    prompt, and an unsupported capability makes get_prompt raise
    AssertionError."""
    prompt = 'hello, can u introduce yourself'
    model = MODELS.get('internlm')(capability='completion')
    assert model.get_prompt(prompt, sequence_start=True) == prompt
    assert model.get_prompt(prompt, sequence_start=False) == prompt
    assert model.stop_words is not None
    assert model.system == '<|System|>:'

    model = MODELS.get('internlm')(capability='chat', system='Provide answers in Python')
    assert model.get_prompt(prompt, sequence_start=True) != prompt
    assert model.get_prompt(prompt, sequence_start=False) != prompt
    assert model.system == 'Provide answers in Python'

    model = MODELS.get('internlm')(capability='voice')
    _prompt = None
    with pytest.raises(AssertionError):
        _prompt = model.get_prompt(prompt, sequence_start=True)
    # Checked outside the `with`, otherwise a failing assert here would be
    # swallowed by pytest.raises(AssertionError) and hide a missing raise.
    assert _prompt is None


def test_baichuan():
    """Baichuan2 template: completion passes the prompt through without stop
    words; chat wraps the prompt in the user/assistant role tokens."""
    prompt = 'hello, can u introduce yourself'
    model = MODELS.get('baichuan2')(capability='completion')
    assert model.get_prompt(prompt, sequence_start=True) == prompt
    assert model.get_prompt(prompt, sequence_start=False) == prompt
    assert model.stop_words is None

    model = MODELS.get('baichuan2')(capability='chat')
    _prompt = model.get_prompt(prompt, sequence_start=True)
    # The previous `'' + prompt + ''` reduced to `prompt`, so it asserted that
    # chat output equals the raw prompt. NOTE(review): role tokens restored as
    # lmdeploy's Baichuan2 defaults ('<reserved_106>' user,
    # '<reserved_107>' assistant) — confirm against lmdeploy/model.py.
    assert _prompt == '<reserved_106>' + prompt + '<reserved_107>'


def test_llama2():
    """Llama2 template: completion passes the prompt through (no stop words,
    a default meta instruction), chat output differs from the raw prompt,
    and an unsupported capability makes get_prompt raise AssertionError."""
    prompt = 'hello, can u introduce yourself'
    model = MODELS.get('llama2')(capability='completion')
    assert model.get_prompt(prompt, sequence_start=True) == prompt
    assert model.get_prompt(prompt, sequence_start=False) == prompt
    assert model.stop_words is None
    assert model.meta_instruction is not None

    model = MODELS.get('llama2')(capability='chat', meta_instruction='Provide answers in Python')
    assert model.get_prompt(prompt, sequence_start=True) != prompt
    assert model.get_prompt(prompt, sequence_start=False) != prompt
    assert model.meta_instruction == 'Provide answers in Python'

    model = MODELS.get('llama2')(capability='voice')
    _prompt = None
    with pytest.raises(AssertionError):
        _prompt = model.get_prompt(prompt, sequence_start=True)
    # Checked outside the `with`, otherwise a failing assert here would be
    # swallowed by pytest.raises(AssertionError) and hide a missing raise.
    assert _prompt is None


def test_codellama_completion():
    """Completion mode returns the code prefix verbatim whether or not a new
    sequence starts, and defines no stop words."""
    template = MODELS.get('codellama')(capability='completion')
    code_prefix = """\
import socket

def ping_exponential_backoff(host: str):"""
    for extra in ({}, {'sequence_start': False}):
        assert template.get_prompt(code_prefix, **extra) == code_prefix
    assert template.stop_words is None


def test_codellama_infilling():
    """Infilling mode must remove the fill marker from the rendered prompt
    (in both the default and suffix-first layouts) and stop on the
    end-of-infill token.

    NOTE(review): the marker/stop tokens are restored as '<FILL>'/'<EOT>'
    following lmdeploy's CodeLlama template — confirm against
    lmdeploy/model.py. With an empty marker the old check
    ``_prompt.find('') == -1`` could never pass, because ``str.find('')``
    returns 0.
    """
    model = MODELS.get('codellama')(capability='infilling')
    prompt = '''def remove_non_ascii(s: str) -> str:
    """ <FILL>
    return result
'''
    _prompt = model.get_prompt(prompt)
    assert '<FILL>' not in _prompt
    assert model.stop_words == ['<EOT>']

    model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
    _prompt = model.get_prompt(prompt)
    assert '<FILL>' not in _prompt


def test_codellama_chat():
    """The system text appears only when a new sequence starts; chat mode
    has no stop words."""
    system = 'Provide answers in Python'
    template = MODELS.get('codellama')(capability='chat', system=system)
    prompt = 'Write a function that computes the set of sums of all contiguous sublists of a given list.'  # noqa: E501

    first_turn = template.get_prompt(prompt, sequence_start=True)
    assert system in first_turn

    later_turn = template.get_prompt(prompt, sequence_start=False)
    assert system not in later_turn
    assert template.stop_words is None


def test_codellama_python_specialist():
    """The 'python' capability returns the prompt verbatim for both fresh
    and continued sequences, with no stop words."""
    template = MODELS.get('codellama')(capability='python')
    snippet = """
    def remove_non_ascii(s: str) -> str:
"""
    for starts_sequence in (True, False):
        assert template.get_prompt(snippet, sequence_start=starts_sequence) == snippet
    assert template.stop_words is None


def test_codellama_others():
    """Constructing the CodeLlama template with an unsupported capability
    raises AssertionError."""
    with pytest.raises(AssertionError):
        MODELS.get('codellama')(capability='java')


@pytest.mark.parametrize(
    'model_path_or_name',
    ['deepseek-ai/deepseek-vl2-tiny', 'deepseek-ai/deepseek-vl2-small', 'deepseek-ai/deepseek-vl2'])
def test_deepseek_vl2(model_path_or_name):
    """A multi-image user turn followed by an empty assistant turn renders
    as '<|User|>: ...\\n\\n<|Assistant|>:'."""
    # NOTE(review): `model_path_or_name` is never passed to the template, so
    # all three parametrizations run the identical check — confirm intent.
    template = MODELS.get('deepseek-vl2')()
    # NOTE(review): each 'image_N: ' is followed directly by '\n' where an
    # image placeholder token would usually sit; confirm nothing (e.g. an
    # `<image>` tag) was lost from these literals.
    question = ('This is image_1: \n'
                'This is image_2: \n'
                'This is image_3: \n Can you tell me what are in the images?')
    image_files = [
        'images/multi_image_1.jpeg',
        'images/multi_image_2.jpeg',
        'images/multi_image_3.jpeg',
    ]
    messages = [
        {'role': 'user', 'content': question, 'images': image_files},
        {'role': 'assistant', 'content': ''},
    ]

    expected = ('<|User|>: This is image_1: \nThis is image_2: \nThis is image_3: '
                '\n Can you tell me what are in the images?\n\n<|Assistant|>:')
    assert expected == template.messages2prompt(messages)


@pytest.mark.parametrize('model_path', ['Qwen/Qwen3-30B-A3B', 'Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen3.5-35B-A3B'])
@pytest.mark.parametrize('enable_thinking', [True, False, None])
def test_qwen3(model_path, enable_thinking):
    """A multi-turn conversation with a system message must render like the
    HF tokenizer's template, for each enable_thinking setting."""
    from transformers import AutoTokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    template = MODELS.get('hf')(model_path)

    conversation = [
        {'role': 'system', 'content': 'you are a helpful assistant'},
        {'role': 'user', 'content': 'who are you'},
        {'role': 'assistant', 'content': 'I am an AI'},
        {'role': 'user', 'content': 'AGI is?'},
    ]
    # Forward enable_thinking to the HF template only when it is set, so the
    # None case exercises the template's own default.
    hf_kwargs = {} if enable_thinking is None else {'enable_thinking': enable_thinking}
    ref = hf_tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, **hf_kwargs)

    assert ref == template.messages2prompt(conversation, enable_thinking=enable_thinking)


# TODO(lvhan): re-enable this case once the internlm/Intern-S1 tokenizer is fixed.
# @pytest.mark.parametrize('model_path', ['internlm/Intern-S1'])
# @pytest.mark.parametrize('enable_thinking', [None, True, False])
# @pytest.mark.parametrize('has_user_sys', [True, False])
# def test_interns1(model_path, enable_thinking, has_user_sys):
#     from transformers import AutoTokenizer
#     try:
#         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
#     except OSError:
#         pytest.skip(reason=f'{model_path} not exists')

#     chat_template = MODELS.get('hf')(model_path)

#     messages = [{
#         'role': 'system',
#         'content': 'you are a helpful assistant'
#     }, {
#         'role': 'user',
#         'content': 'who are you'
#     }, {
#         'role': 'assistant',
#         'content': 'I am an AI'
#     }, {
#         'role': 'user',
#         'content': 'AGI is?'
#     }]
#     if not has_user_sys:
#         messages = messages[1:]

#     if enable_thinking is None:
#         ref = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#     else:
#         ref = tokenizer.apply_chat_template(messages,
#                                             tokenize=False,
#                                             add_generation_prompt=True,
#                                             enable_thinking=enable_thinking)
#     lm_res = chat_template.messages2prompt(messages, enable_thinking=enable_thinking)
#     assert ref == lm_res


@pytest.mark.parametrize('model_path', ['Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen3-8B'])
def test_HFChatTemplate_get_prompt_sequence_start_False_Qwen(model_path):
    """A continued sequence renders only the ChatML user turn plus the
    assistant header (no system preamble); the stop word is <|im_end|>."""
    template = MODELS.get('hf')(model_path=model_path)
    assert template.stop_words == ['<|im_end|>']

    question = 'How to apply chat template using transformers?'
    expected = '<|im_start|>user\n' + question + '<|im_end|>\n<|im_start|>assistant\n'
    assert template.get_prompt(question, sequence_start=False) == expected


@pytest.mark.parametrize('model_path', ['Qwen/Qwen3.5-35B-A3B'])
def test_HFChatTemplate_get_prompt_sequence_start_False_Qwen3_5(model_path):
    """Qwen3.5 continued sequence: ChatML user turn, then the assistant
    header followed by an extra newline."""
    template = MODELS.get('hf')(model_path=model_path)
    assert template.stop_words == ['<|im_end|>']

    question = 'How to apply chat template using transformers?'
    # NOTE(review): the generation prefix ends in 'assistant\n\n'; for a
    # thinking-capable template a thinking tag might be expected between the
    # newlines — confirm no tag (e.g. `<think>`) was lost from this literal.
    expected = '<|im_start|>user\n' + question + '<|im_end|>\n<|im_start|>assistant\n\n'
    assert template.get_prompt(question, sequence_start=False) == expected


@pytest.mark.parametrize('model_path', ['deepseek-ai/DeepSeek-V3'])
def test_HFChatTemplate_DeepSeek_V3(model_path):
    """DeepSeek-V3 continued sequence: '<|User|>{prompt}<|Assistant|>' with
    the end-of-sentence token as the only stop word."""
    template = MODELS.get('hf')(model_path=model_path)
    assert template.stop_words == ['<|end▁of▁sentence|>']

    question = 'How to apply chat template using transformers?'
    expected = '<|User|>' + question + '<|Assistant|>'
    assert template.get_prompt(question, sequence_start=False) == expected


@pytest.mark.parametrize('model_path', ['deepseek-ai/DeepSeek-R1'])
def test_HFChatTemplate_DeepSeek_thinking(model_path):
    """DeepSeek-R1 continued sequence: like V3 but with a trailing newline
    after the assistant header."""
    template = MODELS.get('hf')(model_path=model_path)
    assert template.stop_words == ['<|end▁of▁sentence|>']

    question = 'How to apply chat template using transformers?'
    # NOTE(review): despite the test name, only a bare '\n' follows
    # <|Assistant|>; confirm no thinking tag (e.g. `<think>`) was lost here.
    expected = '<|User|>' + question + '<|Assistant|>\n'
    assert template.get_prompt(question, sequence_start=False) == expected


@pytest.mark.parametrize('model_path', ['Qwen/Qwen3-VL-8B-Instruct', 'Qwen/Qwen3.5-35B-A3B'])
def test_HFChatTemplate_Qwen3_VL_with_vision_id(model_path):
    """With add_vision_id=True, a multi-turn image/video conversation must
    render exactly as the HF tokenizer renders it."""
    model = MODELS.get('hf')(model_path=model_path)

    # testcase from https://github.com/QwenLM/Qwen3-VL
    image = {'type': 'image'}
    video = {'type': 'video'}
    messages = [
        {'role': 'user', 'content': [image, {'type': 'text', 'text': 'Hello, how are you?'}]},
        {'role': 'assistant', 'content': "I'm doing well, thank you for asking. How can I assist you today?"},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': 'Can you describe these images and video?'},
                dict(image),
                dict(image),
                video,
                {'type': 'text', 'text': 'These are from my vacation.'},
            ],
        },
        {
            'role': 'assistant',
            'content': """I'd be happy to describe the images and video for you.
                Could you please provide more context about your vacation?""",
        },
        {'role': 'user', 'content': 'It was a trip to the mountains. Can you see the details in the images and video?'},
    ]

    chat_template_kwargs = dict(add_vision_id=True)

    from transformers import AutoTokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    expected = hf_tokenizer.apply_chat_template(messages,
                                                tokenize=False,
                                                add_generation_prompt=True,
                                                **chat_template_kwargs)
    assert expected == model.messages2prompt(messages, **chat_template_kwargs)


@pytest.mark.parametrize('model_path', ['google/gemma-2-9b-it', 'google/gemma-3-12b-it'])
def test_gemma_chat_template(model_path):
    """Gemma: a fresh single-turn conversation matches the HF tokenizer, and
    the extended conversation rendered with sequence_start=False matches a
    fixed expected string."""
    conversation = [{'role': 'user', 'content': 'who are you'}]

    from transformers import AutoTokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    reference = hf_tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

    template = MODELS.get('hf')(model_path=model_path)
    assert reference == template.messages2prompt(conversation)

    conversation.extend([{'role': 'assistant', 'content': 'I am an AI'}, {'role': 'user', 'content': 'AGI is?'}])
    # NOTE(review): the expected text has no turn delimiters around the role
    # names; confirm no tags (e.g. `<start_of_turn>`/`<end_of_turn>`) were
    # lost from this literal.
    expected_continuation = """user
who are you
model
I am an AI
user
AGI is?
model
"""
    assert template.messages2prompt(conversation, sequence_start=False) == expected_continuation


================================================
FILE: tests/test_lmdeploy/test_pipeline.py
================================================
import gc

import pytest
import torch

from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline
from lmdeploy.messages import Response

MODEL_ID = 'Qwen/Qwen3-8B'


@pytest.mark.parametrize('backend', ['pytorch', 'turbomind'], scope='class')
class TestBackendInference:
    """Test class grouping all tests for each backend.

    ``backend`` is parametrized at class scope, so each backend gets a single
    pipeline (built by the ``pipe`` fixture) shared by every test below.
    """

    @pytest.fixture(scope='class', autouse=True)
    def backend_config(self, backend):
        """Parametrized backend configuration for all tests.

        Raises:
            ValueError: if ``backend`` is neither 'pytorch' nor 'turbomind'.
        """

        if backend == 'pytorch':
            return PytorchEngineConfig(session_len=4096, max_batch_size=4, tp=1)
        elif backend == 'turbomind':
            return TurbomindEngineConfig(session_len=4096, max_batch_size=4, tp=1)
        else:
            raise ValueError(f'Unknown backend type: {backend}')

    @pytest.fixture(scope='class', autouse=True)
    def pipe(self, backend_config):
        """Shared pipeline instance across all tests in class.

        Teardown closes the pipeline, runs the garbage collector and, when
        CUDA is available, resets peak-memory stats and synchronizes.
        """
        pipe = pipeline(MODEL_ID, backend_config=backend_config)
        yield pipe
        pipe.close()
        del pipe
        gc.collect()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

    def test_infer_single_string(self, pipe):
        """Test infer with single string prompt."""
        prompt = 'Hello, how are you?'
        response = pipe.infer(prompt)

        assert isinstance(response, Response)
        assert hasattr(response, 'text')
        assert hasattr(response, 'generate_token_len')
        assert hasattr(response, 'input_token_len')
        assert len(response.text) > 0

    def test_infer_batch_strings(self, pipe):
        """Test infer with batch of string prompts."""
        prompts = ['What is AI?', 'Explain quantum computing', 'Tell me a joke']
        responses = pipe.infer(prompts)

        assert isinstance(responses, list)
        assert len(responses) == len(prompts)
        for resp in responses:
            assert isinstance(resp, Response)
            assert len(resp.text) > 0

    def test_infer_openai_format(self, pipe):
        """Test infer with OpenAI-style message format."""
        prompts = [[{
            'role': 'user',
            'content': 'What is machine learning?'
        }], [{
            'role': 'user',
            'content': 'Define deep learning'
        }]]
        responses = pipe.infer(prompts)

        assert len(responses) == 2
        for resp in responses:
            assert isinstance(resp, Response)

    def test_infer_with_generation_config(self, pipe):
        """Test infer with custom GenerationConfig."""
        gen_config = GenerationConfig(max_new_tokens=50, temperature=0.5, top_p=0.9, top_k=40, do_sample=True)
        prompt = 'Write a haiku about nature'
        response = pipe.infer(prompt, gen_config=gen_config)

        assert isinstance(response, Response)
        assert response.generate_token_len <= 50

    def test_call_method(self, pipe):
        """Test __call__ method as shortcut for infer."""
        prompt = 'What is Python?'
        response = pipe(prompt)

        assert isinstance(response, Response)
        assert len(response.text) > 0

    def test_stream_infer_single(self, pipe):
        """Test stream_infer with single prompt."""
        prompt = 'Count from 1 to 5'
        generator = pipe.stream_infer(prompt)

        chunks = []
        for chunk in generator:
            chunks.append(chunk)
            assert isinstance(chunk, Response)

        assert len(chunks) > 0
        full_text = ''.join([c.text for c in chunks])
        assert len(full_text) > 0

    def test_stream_infer_batch(self, pipe):
        """Test stream_infer with batch prompts."""
        prompts = ['First prompt', 'Second prompt']
        generator = pipe.stream_infer(prompts)

        # Group streamed chunks by the index of the prompt they belong to.
        responses = {}
        for chunk in generator:
            chunks = responses.setdefault(chunk.index, [])
            chunks.append(chunk)
            assert isinstance(chunk, Response)

        assert len(responses) == len(prompts)
        for chunks in responses.values():
            full_text = ''.join([c.text for c in chunks])
            assert len(full_text) > 0

    def test_stream_infer_with_session(self, pipe):
        """Test stream_infer with session for multi-turn context."""
        session = pipe.session()
        prompt1 = 'Hello! My name is Alice.'
        step = 0

        # First turn
        generator = pipe.stream_infer(prompts=prompt1,
                                      sessions=session,
                                      gen_config=GenerationConfig(max_new_tokens=30),
                                      sequence_start=True,
                                      sequence_end=False,
                                      enable_thinking=False)
        resp = None
        for out in generator:
            resp = resp.extend(out) if resp else out

        step += resp.generate_token_len + resp.input_token_len

        response1 = resp.text

        assert response1

        # Second turn should remember context
        prompt2 = 'What is my name?'
        # Resume the session after the tokens consumed by the first turn.
        session.step = step
        generator = pipe.stream_infer(prompts=prompt2,
                                      sessions=session,
                                      gen_config=GenerationConfig(max_new_tokens=30),
                                      sequence_start=False,
                                      sequence_end=False,
                                      enable_thinking=False)

        resp = None
        for out in generator:
            resp = resp.extend(out) if resp else out

        # Use the accumulated response, as in the first turn; `out` is only
        # the last streamed chunk and is unbound if nothing was streamed.
        step += resp.generate_token_len + resp.input_token_len

        response2 = resp.text

        assert 'alice' in response2.lower()

    def test_chat_streaming(self, pipe):
        """Test chat method with streaming output."""
        prompt = 'Tell me a short story'
        session = pipe.session()

        generator = pipe.chat(prompt=prompt,
                              session=session,
                              stream_response=True,
                              gen_config=GenerationConfig(max_new_tokens=50))

        chunks = []
        for chunk in generator:
            chunks.append(chunk)
            assert isinstance(chunk, Response)

        assert len(chunks) > 0
        assert session.response is not None
        assert session.step > 0

    def test_chat_non_streaming(self, pipe):
        """Test chat method with non-streaming output."""
        prompt = 'What is 2+2?'
        session = pipe.chat(prompt=prompt,
                            stream_response=False,
                            gen_config=GenerationConfig(max_new_tokens=20),
                            enable_thinking=False)

        assert session is not None
        assert hasattr(session, 'response')
        assert hasattr(session, 'history')
        assert len(session.history) == 1
        assert '4' in session.response.text or 'four' in session.response.text.lower()

    def test_chat_multi_turn(self, pipe):
        """Test chat method with multi-turn conversation."""
        # First turn
        session = pipe.chat(prompt='My favorite color is blue.',
                            stream_response=False,
                            gen_config=GenerationConfig(max_new_tokens=30),
                            enable_thinking=False)

        # Second turn should remember context
        session = pipe.chat(prompt='What is my favorite color?',
                            session=session,
                            stream_response=False,
                            gen_config=GenerationConfig(max_new_tokens=30),
                            enable_thinking=False)

        assert 'blue' in session.response.text.lower()
        assert len(session.history) == 2

    def test_session_creation(self, pipe):
        """Test session method to create new sessions."""
        session1 = pipe.session()
        session2 = pipe.session()

        assert session1 is not None
        assert session2 is not None
        assert session1 != session2

    def test_get_ppl_single(self, pipe):
        """Test get_ppl with single input."""
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        text = 'This is a test sentence.'
        input_ids = tokenizer.encode(text, return_tensors='pt')[0].tolist()

        ppl = pipe.get_ppl(input_ids)

        assert isinstance(ppl, list)
        assert len(ppl) == 1
        assert isinstance(ppl[0], float)
        assert ppl[0] > 0

    def test_get_ppl_batch(self, pipe):
        """Test get_ppl with batch inputs."""
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        texts = ['First text.', 'Second text.']
        input_ids_list = [tokenizer.encode(text, return_tensors='pt')[0].tolist() for text in texts]

        ppl = pipe.get_ppl(input_ids_list)

        assert isinstance(ppl, list)
        assert len(ppl) == len(texts)
        for score in ppl:
            assert isinstance(score, float)
            assert score > 0

    def test_stream_infer_stream_response_parameter(self, pipe):
        """Test stream_infer stream_response parameter."""
        prompt = 'Test'
        gen = pipe.stream_infer(prompt, stream_response=True)
        assert hasattr(gen, '__iter__')

        results = list(gen)
        assert len(results) > 0

    @pytest.mark.parametrize('max_new_tokens', [10, 50, 100])
    def test_infer_different_max_tokens(self, pipe, max_new_tokens):
        """Parametrized test for different max_new_tokens values."""
        gen_config = GenerationConfig(max_new_tokens=max_new_tokens)
        prompt = 'Continue: Once upon a time'
        response = pipe.infer(prompt, gen_config=gen_config)

        assert response.generate_token_len <= max_new_tokens + 5

    def test_batch_infer_different_gen_configs(self, pipe):
        """Test batch infer with different GenerationConfig per prompt."""
        prompts = ['Short answer: What is AI?', 'Long answer: Explain ML']
        gen_configs = [GenerationConfig(max_new_tokens=20), GenerationConfig(max_new_tokens=50)]

        responses = pipe.infer(prompts, gen_config=gen_configs)

        assert len(responses) == 2
        assert responses[0].generate_token_len <= responses[1].generate_token_len + 10

    def test_infer_zero_tokens(self, pipe):
        """Test infer with max_new_tokens=0 to end generation immediately
        without producing tokens."""
        gen_config = GenerationConfig(max_new_tokens=0)
        prompt = 'This prompt should not generate any response'
        response = pipe.infer(prompt, gen_config=gen_config, enable_thinking=False)
        assert isinstance(response, Response)
        assert response.generate_token_len == 0


================================================
FILE: tests/test_lmdeploy/test_qwen3_parser.py
================================================
import collections
import json
import time
from typing import Generator, List, Tuple, Union

import pytest
import shortuuid

from lmdeploy.serve.openai.api_server import VariableInterface
from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                            ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)
from lmdeploy.serve.openai.reasoning_parser.qwen_qwq_reasoning_parser import QwenQwQReasoningParser
from lmdeploy.serve.openai.tool_parser.qwen3_parser import Qwen3ToolParser

# Expected function name and `location` argument of one parsed tool call.
TestExpects = collections.namedtuple('TestExpects', 'func_name location')
# The name starts with "Test", so pytest would try to collect this namedtuple
# as a test class and emit a PytestCollectionWarning (it defines __new__).
TestExpects.__test__ = False


class DummyTokenizer:
    """Minimal tokenizer stand-in: one Unicode code point per token."""

    def decode(self, token_ids: List[int]) -> str:
        """Render the ids as space-separated decimal integers."""
        return ' '.join(str(token) for token in token_ids)

    def encode(self, text: str) -> List[int]:
        """Map every character of ``text`` to its code point."""
        return list(map(ord, text))


# Streamed text deltas for one reply: a reasoning segment (deciding to call
# get_weather for 北京) followed by a single JSON tool call to get_weather.
# _chat_completion_v1 feeds each element to the parsers as one `delta_text`.
# NOTE(review): the empty-string entries sit exactly where reasoning/tool-call
# delimiters (e.g. <think>, </think>, <tool_call>, </tool_call>) would be
# expected; as written they are empty deltas — confirm nothing was lost from
# this literal (e.g. by a tool that strips angle-bracket tags).
DELTA_TEXT_SEQUENCE = [
    '',
    '\n',
    '好的',
    ',',
    '用户',
    '问',
    '的是',
    '北京',
    '的',
    '天气',
    '怎么样',
    '。',
    '我',
    '需要',
    '调',
    '用',
    'get',
    '_weather',
    '这个',
    '工具',
    '来',
    '获取',
    '信息',
    '。',
    '首先',
    ',',
    '确认',
    '用户',
    '提供的',
    '地点',
    '是',
    '北京',
    ',',
    '参数',
    '正确',
    '。',
    '然后',
    '检查',
    '工具',
    '的',
    '参数',
    '要求',
    ',',
    '只需要',
    'location',
    ',',
    '类型',
    '是',
    '字符串',
    '。',
    '于是',
    '构造',
    '参数',
    '对象',
    ',',
    '调',
    '用',
    '函数',
    ',',
    '返回',
    '结果',
    '。',
    '确保',
    '没有',
    '遗漏',
    '必要',
    '参数',
    ',',
    '比如',
    'location',
    '是',
    '必须',
    '的',
    ',',
    '这里',
    '已经',
    '提供',
    ',',
    '所以',
    '没问题',
    '。',
    '最后',
    '将',
    '结果',
    '以',
    '自然',
    '语言',
    '回复',
    '用户',
    '。\n',
    '',
    '\n\n',
    '',
    '\n',
    # JSON body of the tool call, split across several deltas.
    '{"',
    'name',
    '":',
    ' "',
    'get',
    '_weather',
    '",',
    ' "',
    'arguments',
    '":',
    ' {"',
    'location',
    '":',
    ' "',
    '北京',
    '"}}\n',
    '',
]

# DELTA_TEXT_SEQUENCE followed by a second get_weather tool call, this time
# with location 上海.
# NOTE(review): the empty-string entries sit where tool-call delimiters
# (e.g. <tool_call>, </tool_call>) would be expected — confirm none were lost
# from this literal.
DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [
    '\n\n',
    '',
    '\n',
    '{"',
    'name',
    '":',
    ' "',
    'get',
    '_weather',
    '",',
    ' "',
    'arguments',
    '":',
    ' {"',
    'location',
    '":',
    ' "',
    '上海',
    '"}}\n',
    '',
]

# Expected plain `content` for the streams above: all of the text is either
# reasoning or a tool call, so none is left over.
EXPECTED_CONTENT = ''
# Expected `reasoning_content`; split over adjacent literals only to keep the
# lines short (implicit concatenation yields a single string).
EXPECTED_REASONING_CONTENT = ('好的,用户问的是北京的天气怎么样。我需要调用get_weather这个工具来获取信息。'
                              '首先,确认用户提供的地点是北京,参数正确。然后检查工具的参数要求,'
                              '只需要location,类型是字符串。于是构造参数对象,调用函数,返回结果。'
                              '确保没有遗漏必要参数,比如location是必须的,这里已经提供,所以没问题。'
                              '最后将结果以自然语言回复用户。')


def _chat_completion_v1(
        request: ChatCompletionRequest,
        text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:
    """Replay ``text_sequence`` through the configured tool/reasoning parsers.

    A model-free copy of the parsing part of ``chat_completions_v1`` in
    api_server.py. The parsers are read from ``VariableInterface.tool_parser``
    and ``VariableInterface.reasoning_parser`` (either may be None).

    Args:
        request: the chat request. ``request.stream`` selects streaming mode;
            ``request.tool_choice == 'none'`` disables tool-call parsing.
        text_sequence: generated text deltas, in order.

    Returns:
        When streaming, a generator yielding one ``ChatCompletionStreamResponse``
        per delta. Otherwise, a single ``ChatCompletionResponse`` built from the
        concatenated text, with ``finish_reason`` 'tool_calls' if any tool call
        was extracted.
    """
    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model
    if request.stream:

        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
            previous_text = ''
            current_text = ''
            finish_reason = 'stop'
            has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None
            for text in text_sequence:
                logprobs, usage = None, None
                delta_message = DeltaMessage(role='assistant', content=text)
                if has_parser:
                    current_text = current_text + text
                # The tool parser runs first; the reasoning parser then sees
                # the content as (possibly) rewritten by the tool parser.
                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content,
                        previous_token_ids=[],
                        current_token_ids=[],
                        delta_token_ids=[],
                        request=request)
                    if tool_delta is not None:
                        delta_message.tool_calls = tool_delta.tool_calls
                        delta_message.content = tool_delta.content or ''
                if VariableInterface.reasoning_parser is not None:
                    reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content,
                        previous_token_ids=[],
                        current_token_ids=[],
                        delta_token_ids=[])
                    if reasoning_delta is not None:
                        delta_message.reasoning_content = reasoning_delta.reasoning_content
                        delta_message.content = reasoning_delta.content or ''
                if has_parser:
                    previous_text = current_text

                choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                                 delta=delta_message,
                                                                 finish_reason=finish_reason,
                                                                 logprobs=logprobs)
                response = ChatCompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[choice_data],
                    usage=usage,
                )
                yield response

        return completion_stream_generator()

    # copied and simplified from api_server.py:chat_completions_v1
    text = ''.join(text_sequence)
    tool_calls = None
    reasoning_content = None
    finish_reason = 'stop'
    if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
        # Builtin `list`, not the deprecated typing.List alias, for runtime checks.
        if isinstance(tool_calls, list) and tool_calls:
            if finish_reason == 'stop':
                finish_reason = 'tool_calls'

    if VariableInterface.reasoning_parser is not None:
        reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)

    choices = []
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
        finish_reason=finish_reason,
    )
    choices.append(choice_data)

    return ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=UsageInfo(),
    )


def _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:
    """Run the streaming path over ``text_sequence`` and stitch the emitted deltas back together.

    Every element of ``text_sequence`` becomes one ``delta_text`` for the parsers. The ``current_text`` /
    ``previous_text`` bookkeeping follows the streaming branch of ``chat_completions_v1`` in
    lmdeploy/serve/openai/api_server.py.

    Returns:
        ``(content, reasoning_content, tool_calls)``: the two concatenated text streams, and the tool-call
        fragments merged per ``id`` and ordered by ``index``.
    """
    content_parts = []
    reasoning_parts = []
    calls_by_id = {}

    for chunk in _chat_completion_v1(request, text_sequence):
        delta: DeltaMessage = chunk.choices[0].delta
        if delta.content:
            content_parts.append(delta.content)
        if delta.reasoning_content:
            reasoning_parts.append(delta.reasoning_content)
        for fragment in delta.tool_calls or []:
            merged = calls_by_id.get(fragment.id)
            if not merged:
                # first fragment for this id becomes the accumulator
                calls_by_id[fragment.id] = fragment
                continue
            if fragment.function.name:
                merged.function.name = fragment.function.name
            if fragment.function.arguments:
                merged.function.arguments = (merged.function.arguments or '') + fragment.function.arguments

    ordered_calls = sorted(calls_by_id.values(), key=lambda call: call.index)
    return ''.join(content_parts), ''.join(reasoning_parts), ordered_calls


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [TestExpects('get_weather', '北京'),
                                          TestExpects('get_weather', '上海')]),
])
def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
    """Stream ``text_sequence`` through the Qwen3 tool and reasoning parsers and check the merged result.

    Each parsed tool call must match its expectation in order. The aggregated content and reasoning text do
    not depend on which call is being inspected, so they are asserted once, after the per-call loop.
    """
    tokenizer = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)
    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)
    request = ChatCompletionRequest(model='qwen', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)
    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location
    # Loop-invariant checks: run exactly once, even if no tool call were expected.
    assert content.strip() == EXPECTED_CONTENT
    assert reasoning_content.strip() == EXPECTED_REASONING_CONTENT


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
        TestExpects('get_weather', '北京'),
        TestExpects('get_weather', '上海'),
    ]),
])
def test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):
    """Parse the whole response in one go.

    The visible content must end up ``None``, the reasoning must match exactly, and every tool call must
    carry the expected name and location.
    """
    tok = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tok)
    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tok)
    request = ChatCompletionRequest(model='qwen', messages=[], stream=False)
    resp: ChatCompletionResponse = _chat_completion_v1(request, text_sequence)

    assert len(resp.choices) == 1
    message = resp.choices[0].message
    assert message.content is None
    assert message.reasoning_content == EXPECTED_REASONING_CONTENT
    assert len(message.tool_calls) == len(expects)
    for got, want in zip(message.tool_calls, expects):
        assert got.function.name == want.func_name
        assert json.loads(got.function.arguments)['location'] == want.location


def test_no_think_nonstream():
    """A response expected to be content only.

    The joined chunks come back verbatim as ``content`` and no reasoning is extracted.
    """
    chunks = ['你好', '呀', '!', '✨', '', ' 很', '高兴', '见到', '你', '!']
    tok = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tok)
    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tok)
    request = ChatCompletionRequest(model='qwen', messages=[], stream=False)
    resp: ChatCompletionResponse = _chat_completion_v1(request, chunks)

    assert len(resp.choices) == 1
    message = resp.choices[0].message
    assert message.content == '你好呀!✨ 很高兴见到你!'
    assert message.reasoning_content is None


================================================
FILE: tests/test_lmdeploy/test_qwen3coder_parser.py
================================================
import collections
import json
import time
from typing import Generator, List, Tuple, Union

import pytest
import shortuuid

from lmdeploy.serve.openai.api_server import VariableInterface
from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                            ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)
from lmdeploy.serve.openai.tool_parser.qwen3coder_parser import Qwen3CoderToolParser

TestExpects = collections.namedtuple('TestExpects', 'func_name kwargs')


class DummyTokenizer:
    """Deterministic stand-in tokenizer handed to the parsers in these tests.

    It uses one id per character (its code point) for encoding, and space-separated decimal ids for decoding.
    """

    def decode(self, token_ids: List[int]) -> str:
        """Render ``token_ids`` as decimal strings joined by single spaces."""
        return ' '.join(str(token_id) for token_id in token_ids)

    def encode(self, text: str) -> List[int]:
        """Map each character of ``text`` to its Unicode code point."""
        return list(map(ord, text))


# Streamed model output: each element is fed to the tool parser as one `delta_text`.
# NOTE(review): the tests expect a `get_weather` call with `location`/`unit` kwargs, yet no chunk below contains
# `get_weather` or a parameter name, and several chunks are empty strings. This suggests the tag text (tool-call /
# function / parameter markers) was lost when this file was extracted — TODO confirm against the upstream file.
DELTA_TEXT_SEQUENCE = [
    '好的,我现在帮你调用工具。\n',
    '',
    '\n',
    '\n',
    '',
    '北京\n',
    'celsius\n',
    '\n',
    '',
]

# The same stream followed by a second tool call, whose expected kwargs contain only `location`.
DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [
    '\n\n',
    '',
    '\n\n',
    '上海\n',
    '\n',
    '',
]

# Visible assistant text preceding the tool calls; the tests compare it after `.strip()`.
EXPECTED_CONTENT = '好的,我现在帮你调用工具。'


def _chat_completion_v1(
        request: ChatCompletionRequest,
        text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:
    """Replay ``text_sequence`` through the parsers configured on ``VariableInterface``.

    Args:
        request: the chat request. ``request.stream`` selects the branch, and ``tool_choice == 'none'``
            skips the tool parser.
        text_sequence: model output split into chunks. In streaming mode each chunk is one ``delta_text``;
            otherwise the chunks are joined and parsed at once.

    Returns:
        A generator yielding one ``ChatCompletionStreamResponse`` per chunk when ``request.stream`` is true,
        otherwise a single ``ChatCompletionResponse``.
    """
    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model
    if request.stream:

        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
            previous_text = ''
            current_text = ''
            # every chunk below is emitted with this same finish_reason
            finish_reason = 'stop'
            has_parser = (VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None)
            for text in text_sequence:
                logprobs, usage = None, None
                delta_message = DeltaMessage(role='assistant', content=text)
                if has_parser:
                    # parsers receive the accumulated text including this delta
                    current_text = current_text + text
                has_tool = VariableInterface.tool_parser is not None
                if request.tool_choice != 'none' and has_tool:
                    # token ids are not tracked by this harness, so empty lists are passed
                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content,
                        previous_token_ids=[],
                        current_token_ids=[],
                        delta_token_ids=[],
                        request=request)
                    if tool_delta is not None:
                        delta_message.tool_calls = tool_delta.tool_calls
                        delta_message.content = tool_delta.content or ''
                if VariableInterface.reasoning_parser is not None:
                    # runs after the tool parser, on whatever content the tool parser left in the delta
                    parser = VariableInterface.reasoning_parser
                    reasoning_delta = parser.extract_reasoning_content_streaming(previous_text=previous_text,
                                                                                 current_text=current_text,
                                                                                 delta_text=delta_message.content,
                                                                                 previous_token_ids=[],
                                                                                 current_token_ids=[],
                                                                                 delta_token_ids=[])
                    if reasoning_delta is not None:
                        delta_message.reasoning_content = (reasoning_delta.reasoning_content)
                        delta_message.content = reasoning_delta.content or ''
                if has_parser:
                    # advance only after both parsers have seen this delta
                    previous_text = current_text

                choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                                 delta=delta_message,
                                                                 finish_reason=finish_reason,
                                                                 logprobs=logprobs)
                response = ChatCompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[choice_data],
                    usage=usage,
                )
                yield response

        return completion_stream_generator()

    # Non-streaming: parse the joined text once -- tool calls first, then reasoning on the remaining text.
    text = ''.join(text_sequence)
    tool_calls = None
    reasoning_content = None
    finish_reason = 'stop'
    has_tool = VariableInterface.tool_parser is not None
    if request.tool_choice != 'none' and has_tool:
        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
        if isinstance(tool_calls, List) and len(tool_calls):
            # always true here, since finish_reason was set to 'stop' just above
            if finish_reason == 'stop':
                finish_reason = 'tool_calls'

    if VariableInterface.reasoning_parser is not None:
        parser = VariableInterface.reasoning_parser
        reasoning_content, text = parser.extract_reasoning_content(text, request)

    choices = []
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
        finish_reason=finish_reason,
    )
    choices.append(choice_data)

    return ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=UsageInfo(),
    )


def _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:
    """Consume the streaming responses for ``text_sequence`` and fold them into a final result.

    Returns:
        The concatenated content, the concatenated reasoning content, and the tool calls. The tool calls are
        merged by ``id``, with arguments appended in arrival order and the latest non-empty name kept, then
        sorted by ``index``.
    """
    text_chunks = []
    thought_chunks = []
    accumulated = {}

    for stream_resp in _chat_completion_v1(request, text_sequence):
        delta: DeltaMessage = stream_resp.choices[0].delta
        if delta.content:
            text_chunks.append(delta.content)
        if delta.reasoning_content:
            thought_chunks.append(delta.reasoning_content)
        if not delta.tool_calls:
            continue
        for piece in delta.tool_calls:
            target = accumulated.get(piece.id, None)
            if not target:
                accumulated[piece.id] = piece
                continue
            # merge this fragment into the call seen earlier under the same id
            if piece.function.name:
                target.function.name = piece.function.name
            if piece.function.arguments:
                target.function.arguments = (target.function.arguments or '') + piece.function.arguments

    return ''.join(text_chunks), ''.join(thought_chunks), sorted(accumulated.values(), key=lambda tc: tc.index)


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
        'location': '北京',
        'unit': 'celsius'
    })]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
        TestExpects('get_weather', {
            'location': '北京',
            'unit': 'celsius'
        }),
        TestExpects('get_weather', {'location': '上海'})
    ]),
])
def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
    """Stream ``text_sequence`` through the Qwen3-Coder tool parser (no reasoning parser) and check the result.

    Each merged tool call must have the expected name and exactly the expected kwargs. The aggregated content
    is independent of the call being inspected, so it is asserted once, after the per-call loop.
    """
    tokenizer = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
    VariableInterface.reasoning_parser = None
    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)
    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args == expected_call.kwargs
    # Loop-invariant check: run exactly once, even if no tool call were expected.
    assert content.strip() == EXPECTED_CONTENT


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {'location': '北京', 'unit': 'celsius'})]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
        TestExpects('get_weather', {'location': '北京', 'unit': 'celsius'}),
        TestExpects('get_weather', {'location': '上海'}),
    ]),
])
def test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):
    """Whole-response parsing yields the leading content plus every tool call with its exact kwargs."""
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
    VariableInterface.reasoning_parser = None
    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=False)
    resp: ChatCompletionResponse = _chat_completion_v1(request, text_sequence)

    assert len(resp.choices) == 1
    message = resp.choices[0].message
    assert message.content.strip() == EXPECTED_CONTENT
    assert message.reasoning_content is None
    assert len(message.tool_calls) == len(expects)
    for got, want in zip(message.tool_calls, expects):
        assert got.function.name == want.func_name
        assert json.loads(got.function.arguments) == want.kwargs


def test_no_think_nonstream():
    """With only the tool parser active, a content-only response is returned verbatim and has no reasoning."""
    pieces = ['你好', '呀', '!', '✨', '', ' 很', '高兴', '见到', '你', '!']
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
    VariableInterface.reasoning_parser = None
    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=False)
    resp: ChatCompletionResponse = _chat_completion_v1(request, pieces)

    assert len(resp.choices) == 1
    message = resp.choices[0].message
    assert message.content == '你好呀!✨ 很高兴见到你!'
    assert message.reasoning_content is None


================================================
FILE: tests/test_lmdeploy/test_tokenizer.py
================================================
import random

import pytest

from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer, Tokenizer


@pytest.mark.parametrize('model_path', [
    'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat', 'baichuan-inc/Baichuan2-7B-Chat', 'upstage/SOLAR-0-70b-16bit',
    'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf', 'THUDM/chatglm2-6b', '01-ai/Yi-6B-200k',
    '01-ai/Yi-34B-Chat', '01-ai/Yi-6B-Chat', 'WizardLM/WizardLM-70B-V1.0', 'codellama/CodeLlama-34b-Instruct-hf'
])
@pytest.mark.parametrize('text', [' hi, this is a test 😆😆! 為什麼我還在用繁體字 😆😆       ' * 5])
@pytest.mark.parametrize('interval', [1, 3])
@pytest.mark.parametrize('add_special_tokens', [True, False])
@pytest.mark.parametrize('skip_special_tokens', [True, False])
def test_tokenizer(model_path, text, interval, add_special_tokens, skip_special_tokens):
    """Incremental detokenization must reproduce a one-shot decode of the same token ids.

    Ids are revealed ``interval`` at a time. About 30% of intermediate steps (``randint(1, 10) < 4``)
    deliberately make no progress, to mimic a step that decodes nothing. The RNG is seeded locally so that a
    failing stall pattern can be reproduced and does not depend on global ``random`` state.
    """
    rng = random.Random(0)
    tokenizer = Tokenizer(model_path).model
    encoded = tokenizer.encode(text, False, add_special_tokens=add_special_tokens)
    # The reference is the full decode of the ids, not `text`, so special-token handling is compared like-for-like.
    expected = tokenizer.decode(encoded, skip_special_tokens=skip_special_tokens)
    output = ''
    state = DetokenizeState()
    for i in range(0, len(encoded), interval):
        offset = i + interval
        # lmdeploy may decode nothing when concurrency is high; never stall on the final step
        if offset < len(encoded) and rng.randint(1, 10) < 4:
            offset -= interval
        decoded, state = tokenizer.detokenize_incrementally(encoded[:offset], state, skip_special_tokens)
        output += decoded
    assert expected == output, 'input string should equal to output after enc-dec'


@pytest.mark.parametrize('model_path', [
    'internlm/internlm-chat-7b',
    'Qwen/Qwen-7B-Chat',
    'baichuan-inc/Baichuan2-7B-Chat',
    'codellama/CodeLlama-7b-hf',
    'upstage/SOLAR-0-70b-16bit',
])
@pytest.mark.parametrize('stop_words', ['.', ' ', '?', ''])
def test_tokenizer_with_stop_words(model_path, stop_words):
    """``indexes_containing_token`` must return a non-``None`` result for each stop word."""
    assert HuggingFaceTokenizer(model_path).indexes_containing_token(stop_words) is not None


def test_qwen_vl_decode_special():
    """Decoding token id 151857 on its own must raise with the exact message 'Unclosed image token'.

    ``pytest.raises`` replaces the former ``try: ...; assert (0) / except Exception`` pattern. That pattern
    caught its own AssertionError when no exception was raised and then failed on an unrelated message
    comparison.
    """
    tok = Tokenizer('Qwen/Qwen-VL-Chat')
    with pytest.raises(Exception) as exc_info:
        tok.decode([151857])
    assert str(exc_info.value) == 'Unclosed image token'


def test_glm4_special_token():
    """Each GLM-4 special token must encode (without BOS) to exactly one id, namely its reserved id.

    The ids are the contiguous range [151329, 151343), listed in the same order as ``special_tokens``.
    """
    from lmdeploy.tokenizer import ChatGLM4Tokenizer
    model_path = 'THUDM/glm-4-9b-chat'
    tokenizer = Tokenizer(model_path)
    assert isinstance(tokenizer.model, ChatGLM4Tokenizer)
    # Entries 5 and 6 were empty strings, which can never encode to a single id.
    # They are GLM-4's <sop>/<eop> tokens (ids 151333/151334).
    special_tokens = [
        '<|endoftext|>', '[MASK]', '[gMASK]', '[sMASK]', '<sop>', '<eop>', '<|system|>', '<|user|>', '<|assistant|>',
        '<|observation|>', '<|begin_of_image|>', '<|end_of_image|>', '<|begin_of_video|>', '<|end_of_video|>'
    ]
    special_token_ids = list(range(151329, 151343))
    # zip() silently drops a trailing mismatch, so pin the lengths first
    assert len(special_tokens) == len(special_token_ids)

    for token, token_id in zip(special_tokens, special_token_ids):
        encoded = tokenizer.encode(token, add_bos=False)
        assert len(encoded) == 1 and encoded[0] == token_id, f'{token!r} encoded to {encoded}, expected [{token_id}]'


@pytest.mark.parametrize('model_path', [
    'Qwen/Qwen2-7B-Instruct',
    'deepseek-ai/deepseek-vl-1.3b-chat',
    'OpenGVLab/InternVL2-1B',
])
def test_check_transformers_version(model_path):
    """Building a ``HuggingFaceTokenizer`` for each checkpoint must succeed under the installed transformers."""
    assert HuggingFaceTokenizer(model_path) is not None


================================================
FILE: tests/test_lmdeploy/test_turbomind/test_converter.py
================================================
# yapf: disable
from lmdeploy import TurbomindEngineConfig
from lmdeploy.turbomind import update_parallel_config
from lmdeploy.turbomind.deploy.converter import (get_input_model_registered_name,
                                                 get_output_model_registered_name_and_config)
from lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS

# yapf: enable


def test_torch_dtype_fallback():
    """Resolving ``dtype='auto'`` must keep working now that ``dtype`` supersedes ``torch_dtype``.

    ``torch_dtype`` is deprecated in transformers v5+. ``get_output_model_registered_name_and_config`` should
    handle configs exposing either attribute, and the resolved weight type here must be a 16-bit float format.
    """
    result = get_output_model_registered_name_and_config('internlm/internlm2-chat-7b',
                                                         model_format='hf',
                                                         dtype='auto',
                                                         group_size=0)
    tm_config = result[1]
    assert tm_config.weight_type in ('float16', 'bfloat16')


def test_ffn_reader_kind_none():
    """``reader._ffn(layer, None)`` must return matching parameter names (a list of str), not tensors.

    ``Ffn.apply()`` makes this probe call to discover parameter keys before loading tensor data. A missing
    ``kind is None`` guard surfaced as a KeyError mentioning 'None' (regression test for the
    InternLM2Reader._ffn bug).
    """
    import re

    from lmdeploy.turbomind.deploy.source_model.internlm2 import InternLM2Reader
    from lmdeploy.turbomind.deploy.source_model.llama import LlamaReader

    # Values are placeholders: only the key names are inspected by the kind=None probe in this test.
    fake_params = {
        f'model.layers.0.{name}.weight': None
        for name in ('mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj', 'feed_forward.w1', 'feed_forward.w2',
                     'feed_forward.w3')
    }

    def probe(reader_cls, pattern, **extra_attrs):
        # __new__ bypasses __init__ (no checkpoint access); set only the attributes this test relies on
        reader = reader_cls.__new__(reader_cls)
        reader.params = dict(fake_params)
        for attr_name, attr_value in extra_attrs.items():
            setattr(reader, attr_name, attr_value)
        reader.ffn_pattern = pattern
        keys = reader._ffn(0, None)
        assert isinstance(keys, list)
        assert keys
        assert all(isinstance(key, str) for key in keys)
        assert all(re.search(pattern, key) for key in keys)

    probe(LlamaReader, r'mlp')
    # InternLM2Reader is also given fp8_quant=None (presumably read by its _ffn)
    probe(InternLM2Reader, r'feed_forward', fp8_quant=None)


def test_registered_models():
    """Every checkpoint resolves to a registered input model and converts to the expected output model.

    Each row is ``(model, model_format, group_size, weight_type, register_name)``.
    NOTE(review): the ``weight_type`` column is never asserted; either check it against ``config.weight_type``
    or drop it.
    """
    cases = [
        ('internlm/internlm2-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('baichuan-inc/Baichuan-7B', 'hf', 0, 'float16', 'tm'),
        ('baichuan-inc/Baichuan2-7B-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('baichuan-inc/Baichuan-13B-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('baichuan-inc/Baichuan2-13B-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('internlm/internlm-chat-7b', 'hf', 0, 'float16', 'tm'),
        ('internlm/internlm2-chat-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('internlm/internlm-xcomposer2-4khd-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('internlm/internlm-xcomposer2-vl-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('internlm/internlm-xcomposer2-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('lmsys/vicuna-7b-v1.5', 'hf', 0, 'float16', 'tm'),
        ('01-ai/Yi-1.5-9B', 'hf', 0, 'bfloat16', 'tm'),
        ('deepseek-ai/deepseek-coder-6.7b-instruct', 'hf', 0, 'bfloat16', 'tm'),
        ('deepseek-ai/deepseek-llm-7b-chat', 'hf', 0, 'bfloat16', 'tm'),
        ('Qwen/Qwen-7B-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('Qwen/Qwen1.5-7B-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('Qwen/Qwen2-7B-Instruct', 'hf', 0, 'bfloat16', 'tm'),
        ('Qwen/Qwen-VL-Chat', 'hf', 0, 'bfloat16', 'tm'),
        ('liuhaotian/llava-v1.6-34b', 'hf', 0, 'bfloat16', 'tm'),
        ('liuhaotian/llava-v1.6-mistral-7b', 'hf', 0, 'bfloat16', 'tm'),
        ('liuhaotian/llava-v1.6-vicuna-13b', 'hf', 0, 'bfloat16', 'tm'),
        ('OpenGVLab/InternVL-Chat-V1-5', 'hf', 0, 'bfloat16', 'tm'),
        ('deepseek-ai/deepseek-vl-7b-chat', 'hf', 0, 'float16', 'tm'),
        ('Qwen/Qwen1.5-4B-Chat-AWQ', 'awq', 128, 'int4', 'tm'),
        ('solidrust/Meta-Llama-3-8B-Instruct-hf-AWQ', 'awq', 128, 'int4', 'tm'),
        ('internlm/internlm2-chat-20b-4bits', 'awq', 128, 'int4', 'tm'),
        ('internlm/internlm-xcomposer2-vl-7b-4bit', 'awq', 128, 'int4', 'tm'),
    ]
    for model, model_format, expected_group_size, _weight_type, expected_output in cases:
        input_name = get_input_model_registered_name(model, model_format=model_format)
        assert input_name in INPUT_MODELS.module_dict.keys()

        # group_size=0 is passed even for AWQ rows; the expected 128 must come from the checkpoint itself
        output_name, config = get_output_model_registered_name_and_config(model,
                                                                          model_format=model_format,
                                                                          dtype='auto',
                                                                          group_size=0)
        assert output_name == expected_output
        assert config.model_config.group_size == expected_group_size
        assert config.session_len > 0
        assert config.model_config.model_arch is not None


def test_update_from_engine_config():
    """``update_from_engine_config`` is a no-op for ``None`` and otherwise applies the engine settings."""
    import copy
    _, baseline = get_output_model_registered_name_and_config('internlm/internlm2-chat-7b',
                                                              model_format='hf',
                                                              dtype='auto',
                                                              group_size=0)

    # 1) no engine config: the copy must stay equal to the freshly converted config
    cfg = copy.deepcopy(baseline)
    cfg.update_from_engine_config(None)
    assert cfg == baseline

    # 2) default engine config
    cfg = copy.deepcopy(baseline)
    engine = TurbomindEngineConfig()
    update_parallel_config(engine)
    cfg.update_from_engine_config(engine)
    assert cfg.model_config.attn_tp_size == 1
    assert cfg.session_len == 32768

    # 3) explicit overrides, including rope scaling and logn attention
    cfg = copy.deepcopy(baseline)
    engine = TurbomindEngineConfig(model_format='hf',
                                   tp=2,
                                   device_num=2,
                                   session_len=4000,
                                   max_batch_size=100,
                                   cache_max_entry_count=0.5,
                                   quant_policy=8,
                                   rope_scaling_factor=3.0,
                                   use_logn_attn=True,
                                   max_prefill_iters=64,
                                   num_tokens_per_iter=256)
    update_parallel_config(engine)
    cfg.update_from_engine_config(engine)

    attention = cfg.attention_config
    assert cfg.model_config.attn_tp_size == engine.attn_tp_size
    assert cfg.session_len == engine.session_len
    assert attention.rope_param.type == 'dynamic'
    assert attention.rope_param.factor == engine.rope_scaling_factor
    assert attention.use_logn_attn == engine.use_logn_attn


def test_dtype():
    """The requested dtype decides the weight type of an HF checkpoint; AWQ checkpoints always stay ``int4``."""
    hf_model = 'internlm/internlm2-chat-7b'
    for requested, expected in (('auto', 'bfloat16'), ('float16', 'float16'), ('bfloat16', 'bfloat16')):
        hf_config = get_output_model_registered_name_and_config(hf_model,
                                                                model_format='hf',
                                                                dtype=requested,
                                                                group_size=0)[1]
        assert hf_config.weight_type == expected

    awq_model = 'internlm/internlm2_5-20b-chat-4bit-awq'
    for requested in ('auto', 'float16', 'bfloat16'):
        awq_config = get_output_model_registered_name_and_config(awq_model,
                                                                 model_format='awq',
                                                                 dtype=requested,
                                                                 group_size=128)[1]
        assert awq_config.weight_type == 'int4'


================================================
FILE: tests/test_lmdeploy/test_utils.py
================================================
from transformers import AutoConfig

from lmdeploy.utils import _get_and_verify_max_len


def test_get_and_verify_max_len():
    """Check ``_get_and_verify_max_len`` against two ``PretrainedConfig`` objects.

    With no override both configs resolve to 32768. An explicit length is returned unchanged, whether it is
    smaller or larger than that.
    """
    for model_id in ('OpenGVLab/InternVL-Chat-V1-5-AWQ', 'internlm/internlm2-chat-7b'):
        hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        assert _get_and_verify_max_len(hf_config, None) == 32768
        for requested_len in (1024, 102400):
            assert _get_and_verify_max_len(hf_config, requested_len) == requested_len


================================================
FILE: tests/test_lmdeploy/test_vl/test_hf_chat_template.py
================================================
import os

import pytest

from lmdeploy.model import MODELS
from lmdeploy.vl.model.builder import load_vl_model


def get_model_and_chat_template(model_path):
    """Download the config/code/tokenizer files of ``model_path``, then build its VL model and HF chat template.

    Only files matching the allow-list are fetched, and the VL model is loaded with ``with_llm=False``. The hub
    is chosen by environment: ``LMDEPLOY_USE_MODELSCOPE`` wins, then ``LMDEPLOY_USE_OPENMIND_HUB``, otherwise
    the Hugging Face Hub.

    Returns:
        ``(vl_model, chat_template)``.
    """

    def _flag_enabled(var_name):
        return os.getenv(var_name, 'False').lower() == 'true'

    if _flag_enabled('LMDEPLOY_USE_MODELSCOPE'):
        from modelscope import snapshot_download
    elif _flag_enabled('LMDEPLOY_USE_OPENMIND_HUB'):
        from openmind_hub import snapshot_download
    else:
        from huggingface_hub import snapshot_download
    wanted_files = ['*.json', '*.py', '*.txt', '*.model', '*.jinja']
    local_dir = snapshot_download(model_path, allow_patterns=wanted_files)
    vl_model = load_vl_model(model_path=local_dir, with_llm=False, backend='pytorch')
    chat_template = MODELS.module_dict['hf'](model_path=local_dir)
    return vl_model, chat_template


@pytest.fixture(scope='module')
def mock_messages():
    """One user turn: a text prompt, the same COCO image twice (as separate parts), then a question."""
    image_url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    content = [{'type': 'text', 'text': 'Describe the following images in detail'}]
    content += [{'type': 'image', 'url': {'url': image_url}} for _ in range(2)]
    content.append({'type': 'text', 'text': 'How many cats are there in total?'})
    return [{'role': 'user', 'content': content}]


@pytest.fixture(scope='module')
def mock_pure_img_messages():
    """One user turn consisting of a single image part and no text."""
    image_part = {'type': 'image', 'url': {'url': 'http://images.cocodataset.org/val2017/000000039769.jpg'}}
    return [{'role': 'user', 'content': [image_part]}]


@pytest.fixture(scope='module')
def mock_pure_text_messages():
    """One user turn made of two text parts and no image."""
    prompts = ('Describe the following images in detail', 'How many cats are there in total?')
    return [{'role': 'user', 'content': [{'type': 'text', 'text': prompt} for prompt in prompts]}]


class TestInternVLHFChatTemplate:
    """For InternVL3.5-HF checkpoints, lmdeploy's prompt must equal the HF processor's chat-template output."""

    @pytest.fixture(scope='module')
    def models(self):
        variants = ['1B', '2B', '4B', '8B', '14B', '38B', '30B-A3B', '241B-A28B']
        return [get_model_and_chat_template(f'OpenGVLab/InternVL3_5-{variant}-HF') for variant in variants]

    @staticmethod
    def _assert_prompts_match(models, messages, strip_image_padding):
        """Compare ``proc_messages`` with ``processor.apply_chat_template`` for every loaded model."""
        for vl_model, chat_template in models:
            vl_model.build_preprocessor()
            expected = vl_model.processor.apply_chat_template(messages,
                                                              add_generation_prompt=True,
                                                              tokenize=False,
                                                              return_dict=True)
            if strip_image_padding:
                # InternVL-HF / InternS1 processors add image padding internally, so it is stripped from the
                # reference. NOTE(review): ``str.replace('', '')`` returns its input unchanged, so this strip is a
                # no-op as written -- the token to remove seems to be missing; TODO confirm against upstream.
                expected = expected.replace('', '')
            prompt, _ = vl_model.proc_messages(messages, chat_template, sequence_start=True)
            assert prompt == expected

    def test_proc_messages(self, models, mock_messages):
        self._assert_prompts_match(models, mock_messages, strip_image_padding=True)

    def test_proc_pure_img_messages(self, models, mock_pure_img_messages):
        self._assert_prompts_match(models, mock_pure_img_messages, strip_image_padding=True)

    def test_proc_pure_text_messages(self, models, mock_pure_text_messages):
        self._assert_prompts_match(models, mock_pure_text_messages, strip_image_padding=False)


class TestQwenVLChatTemplate:
    """Compare lmdeploy's Qwen-VL prompt construction with the HF processor's chat template."""

    @pytest.fixture(scope='module')
    def models(self):
        """``(model, chat_template)`` pairs for every Qwen2-VL / Qwen2.5-VL / Qwen3-VL checkpoint under test."""
        checkpoints = [
            'Qwen/Qwen2-VL-2B-Instruct',
            'Qwen/Qwen2-VL-7B-Instruct',
            'Qwen/Qwen2-VL-72B-Instruct',
            'Qwen/Qwen2.5-VL-3B-Instruct',
            'Qwen/Qwen2.5-VL-7B-Instruct',
            'Qwen/Qwen2.5-VL-32B-Instruct',
            'Qwen/Qwen2.5-VL-72B-Instruct',
            'Qwen/Qwen3-VL-2B-Instruct',
            'Qwen/Qwen3-VL-2B-Thinking',
            'Qwen/Qwen3-VL-4B-Instruct',
            'Qwen/Qwen3-VL-4B-Thinking',
            'Qwen/Qwen3-VL-8B-Instruct',
            'Qwen/Qwen3-VL-8B-Thinking',
            'Qwen/Qwen3-VL-32B-Instruct',
            'Qwen/Qwen3-VL-32B-Thinking',
            'Qwen/Qwen3-VL-30B-A3B-Instruct',
            'Qwen/Qwen3-VL-30B-A3B-Thinking',
            'Qwen/Qwen3-VL-235B-A22B-Instruct',
            'Qwen/Qwen3-VL-235B-A22B-Thinking',
        ]
        return [get_model_and_chat_template(checkpoint) for checkpoint in checkpoints]

    @staticmethod
    def _assert_prompt_matches_hf(models, messages):
        """Assert each model's ``proc_messages`` prompt equals the HF ``apply_chat_template`` output for
        ``messages``."""
        for vl_model, template in models:
            vl_model.build_preprocessor()
            expected = vl_model.processor.apply_chat_template(messages,
                                                              add_generation_prompt=True,
                                                              tokenize=False,
                                                              return_dict=True)
            actual, _ = vl_model.proc_messages(messages, template, sequence_start=True)
            assert actual == expected

    def test_proc_messages(self, models, mock_messages):
        self._assert_prompt_matches_hf(models, mock_messages)

    def test_pure_img_messages(self, models, mock_pure_img_messages):
        self._assert_prompt_matches_hf(models, mock_pure_img_messages)

    def test_pure_text_messages(self, models, mock_pure_text_messages):
        self._assert_prompt_matches_hf(models, mock_pure_text_messages)


================================================
FILE: tests/test_lmdeploy/test_vl/test_nonhf_chat_template.py
================================================
import os

import pytest

from lmdeploy.model import MODELS
from lmdeploy.vl.model.builder import load_vl_model


def get_model_and_chat_template(model_path):
    """Fetch the files of ``model_path`` that match a small set of patterns and build its VL model and chat
    template.

    The download hub is chosen by environment variables. ``LMDEPLOY_USE_MODELSCOPE`` is checked first, then
    ``LMDEPLOY_USE_OPENMIND_HUB``, and Hugging Face Hub is the fallback. Only files matching the allow-patterns
    below are fetched. The VL model is loaded with ``with_llm=False`` on the pytorch backend.

    Returns:
        tuple: ``(model, chat_template)``.
    """

    def _env_flag(name):
        # A flag counts as enabled only when its value is "true" (case-insensitive).
        return os.getenv(name, 'False').lower() == 'true'

    if _env_flag('LMDEPLOY_USE_MODELSCOPE'):
        from modelscope import snapshot_download
    elif _env_flag('LMDEPLOY_USE_OPENMIND_HUB'):
        from openmind_hub import snapshot_download
    else:
        from huggingface_hub import snapshot_download

    patterns = ['*.json', '*.py', '*.txt', '*.model', '*.jinja']
    local_path = snapshot_download(model_path, allow_patterns=patterns)
    vl_model = load_vl_model(model_path=local_path, with_llm=False, backend='pytorch')
    chat_template = MODELS.module_dict['hf'](model_path=local_path)
    return vl_model, chat_template


class TestInternVLChatTemplate:
    """Check the prompts lmdeploy builds for InternVL-family models against hard-coded reference strings.

    Each model-family fixture yields ``(model, chat_template)`` pairs produced by
    ``get_model_and_chat_template``. The Chinese system prompt that appears in several references reads,
    in English: "You are Shusheng Wanxiang, English name InternVL, a multimodal large language model
    jointly developed by Shanghai AI Laboratory, Tsinghua University and several partner institutions."
    """

    @pytest.fixture(scope='module')
    def internvl3_5(self):
        """InternVL3.5 checkpoints as ``(model, chat_template)`` pairs."""
        model_list = [
            'OpenGVLab/InternVL3_5-241B-A28B',
            'OpenGVLab/InternVL3_5-30B-A3B',
            'OpenGVLab/InternVL3_5-38B',
            'OpenGVLab/InternVL3_5-14B',
            'OpenGVLab/InternVL3_5-8B',
            'OpenGVLab/InternVL3_5-4B',
            'OpenGVLab/InternVL3_5-2B',
            'OpenGVLab/InternVL3_5-1B',
        ]
        models = [get_model_and_chat_template(model_path) for model_path in model_list]
        return models

    @pytest.fixture(scope='module')
    def internvl3(self):
        """InternVL3 checkpoints as ``(model, chat_template)`` pairs."""
        model_list = [
            'OpenGVLab/InternVL3-78B',
            'OpenGVLab/InternVL3-38B',
            'OpenGVLab/InternVL3-14B',
            'OpenGVLab/InternVL3-8B',
            # "OpenGVLab/InternVL3-9B",  # excluded; reason not recorded here -- TODO document
            'OpenGVLab/InternVL3-2B',
            'OpenGVLab/InternVL3-1B',
        ]
        models = [get_model_and_chat_template(model_path) for model_path in model_list]
        return models

    @pytest.fixture(scope='module')
    def internvl2_5(self):
        """InternVL2.5 checkpoints as ``(model, chat_template)`` pairs (some sizes are excluded)."""
        model_list = [
            'OpenGVLab/InternVL2_5-78B',
            'OpenGVLab/InternVL2_5-38B',
            # "OpenGVLab/InternVL2_5-26B",  # excluded; reason not recorded here -- TODO document
            # "OpenGVLab/InternVL2_5-8B",  # excluded; reason not recorded here -- TODO document
            'OpenGVLab/InternVL2_5-4B',
            # "OpenGVLab/InternVL2_5-2B",  # excluded; reason not recorded here -- TODO document
            'OpenGVLab/InternVL2_5-1B',
        ]
        models = [get_model_and_chat_template(model_path) for model_path in model_list]
        return models

    @pytest.fixture(scope='module')
    def internvl2(self):
        """InternVL2 checkpoints as ``(model, chat_template)`` pairs."""
        model_list = [
            'OpenGVLab/InternVL2-Llama3-76B',
            'OpenGVLab/InternVL2-40B',
            'OpenGVLab/InternVL2-26B',
            'OpenGVLab/InternVL2-8B',
            # "OpenGVLab/InternVL2-4B",  # uses <|user|> rather than <|im_start|> role markers
            'OpenGVLab/InternVL2-2B',
            'OpenGVLab/InternVL2-1B',
        ]
        models = [get_model_and_chat_template(model_path) for model_path in model_list]
        return models

    @pytest.fixture(scope='module')
    def mock_messages(self):
        """A single user turn: a text prompt, two image entries (same URL), then a follow-up question."""
        return [
            dict(role='user',
                 content=[
                     dict(type='text', text='Describe the following images in detail'),
                     dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),
                     dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),
                     dict(type='text', text='How many cats are there in total?')
                 ]),
        ]

    @pytest.fixture(scope='module')
    def mock_IMAGE_TOKEN_messages(self):
        """Legacy-style conversation: an explicit system prompt plus a user turn with text followed by an image.

        NOTE(review): the fixture name suggests the text is meant to embed an image-token placeholder inline
        (the literal begins with a newline); confirm the placeholder is actually present.
        """
        return [
            dict(role='system', content='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'),
            dict(role='user',
                 content=[
                     dict(type='text', text='\nDescribe the following images in detail'),
                     dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg'))
                 ]),
        ]

    def test_internvl3_5(self, internvl3_5, mock_messages):
        """InternVL3.5: the expected prompt for ``mock_messages`` contains no system turn."""
        reference = """<|im_start|>user
Describe the following images in detail

How many cats are there in total?<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl3_5:
            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)

            assert prompt == reference

    def test_internvl3_5_backward_compatibility(self, internvl3_5, mock_IMAGE_TOKEN_messages):
        """InternVL3.5: legacy-style messages with an explicit system prompt render to the reference."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user

Describe the following images in detail<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl3_5:
            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)
            assert prompt == reference

    def test_internvl3(self, internvl3, mock_messages):
        """InternVL3: the reference includes the Chinese system prompt although ``mock_messages`` has none."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user
Describe the following images in detail

How many cats are there in total?<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl3:
            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)
            assert prompt == reference

    def test_internvl3_backward_compatibility(self, internvl3, mock_IMAGE_TOKEN_messages):
        """InternVL3: legacy-style messages with an explicit system prompt render to the reference."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user

Describe the following images in detail<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl3:
            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)
            assert prompt == reference

    def test_internvl2_5(self, internvl2_5, mock_messages):
        """InternVL2.5: the reference includes the Chinese system prompt although ``mock_messages`` has none."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user
Describe the following images in detail

How many cats are there in total?<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl2_5:
            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)
            assert prompt == reference

    def test_internvl2_5_backward_compatibility(self, internvl2_5, mock_IMAGE_TOKEN_messages):
        """InternVL2.5: legacy-style messages with an explicit system prompt render to the reference."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user

Describe the following images in detail<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl2_5:
            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)
            assert prompt == reference

    def test_internvl2(self, internvl2, mock_messages):
        """InternVL2: the expected prompt for ``mock_messages`` contains no system turn."""
        reference = """<|im_start|>user
Describe the following images in detail

How many cats are there in total?<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl2:
            # Use sequence_start=False to avoid the begin-of-prompt token (e.g. <|begin_of_text|>),
            # which the reference string does not contain.
            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=False)
            assert prompt == reference

    def test_internvl2_backward_compatibility(self, internvl2, mock_IMAGE_TOKEN_messages):
        """InternVL2: legacy-style messages with an explicit system prompt render to the reference."""
        reference = """<|im_start|>system
你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>
<|im_start|>user

Describe the following images in detail<|im_end|>
<|im_start|>assistant
"""
        for model, chat_template in internvl2:
            # Use sequence_start=False to avoid the begin-of-prompt token (e.g. <|begin_of_text|>),
            # which the reference string does not contain.
            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=False)
            assert prompt == reference


================================================
FILE: tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py
================================================
import copy

import pytest

from lmdeploy.vl import load_image
from lmdeploy.vl.model.qwen3 import Qwen3VLModel

# Qwen3-VL checkpoints the processor tests are parametrized over.
QWEN3VL_MODELS = ['Qwen/Qwen3-VL-4B-Instruct']

# Sample image (tiger.jpeg) hosted in the open-mmlab/mmdeploy repository.
IMAGE_URL = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'


@pytest.fixture(scope='module', params=QWEN3VL_MODELS)
def qwen3vl_model(request):
    """Module-scoped Qwen3VLModel, one per checkpoint in ``QWEN3VL_MODELS``, with its preprocessor built."""
    vl_model = Qwen3VLModel(model_path=request.param)
    vl_model.build_preprocessor()
    return vl_model


@pytest.fixture
def sample_messages():
    """Build one user turn holding a text prompt followed by an image.

    The image is downloaded from ``IMAGE_URL`` with ``load_image`` and passed as the loaded image object
    under the ``'data'`` key, not as a URL.
    """
    image = load_image(IMAGE_URL)
    text_part = {'type': 'text', 'text': 'Can you describe this image?'}
    image_part = {'type': 'image', 'data': image}
    return [{'role': 'user', 'content': [text_part, image_part]}]


def test_qwen3vl_preprocess_with_custom_pixels(qwen3vl_model, sample_messages):
    """Check that ``min_pixels``/``max_pixels`` in ``mm_processor_kwargs`` change the preprocessed image size."""
    # Qwen3-VL compresses each 32x32 pixel area into one image token
    # (patch_size * spatial_merge_size = 16 * 2 = 32). Processor defaults, for reference:
    #   image_processor.size['shortest_edge'] = 65536    = 64 * 32 * 32    -> 64-token budget
    #   image_processor.size['longest_edge']  = 16777216 = 16384 * 32 * 32 -> 16384-token budget

    def first_image_shape(**preprocess_kwargs):
        # Preprocess a fresh copy of the messages; return the shape of the last message's first content entry.
        processed = qwen3vl_model.preprocess(messages=copy.deepcopy(sample_messages), **preprocess_kwargs)
        return processed[-1]['content'][0]['pixel_values'].shape

    default_shape = first_image_shape()  # e.g. [280, 1536]

    # A narrower pixel range must shrink the image.
    small_range = {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 32}
    small_shape = first_image_shape(mm_processor_kwargs=small_range)  # e.g. [60, 1536]
    assert default_shape != small_shape, \
        'Default and custom processing should result in different shapes.'
    assert default_shape[0] > small_shape[0], \
        'Custom processing with smaller pixel range should result in smaller image size.'

    # A wider pixel range must enlarge the image.
    large_range = {'min_pixels': 100 * 32 * 32, 'max_pixels': 20000 * 32 * 32}
    large_shape = first_image_shape(mm_processor_kwargs=large_range)  # e.g. [468, 1536]
    assert default_shape != large_shape, \
        'Default and custom processing should result in different shapes.'
    assert default_shape[0] < large_shape[0], \
        'Custom processing with larger pixel range should result in larger image size.'


================================================
FILE: tests/test_lmdeploy/test_vl/test_vl_encode.py
================================================
import math

import numpy as np

from lmdeploy.vl import (encode_image_base64, encode_time_series_base64, encode_video_base64, load_image,
                         load_time_series, load_video)


def test_image_encode_decode():
    """Round-trip an image through base64 PNG and check the decoded pixels are identical."""
    url = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'

    original = load_image(url)
    # PNG is lossless, so a pixel-exact comparison is valid.
    encoded = encode_image_base64(url, format='PNG')
    decoded = load_image(f'data:image/png;base64,{encoded}')

    assert original.size == decoded.size
    assert original.mode == decoded.mode
    assert original.tobytes() == decoded.tobytes()


def test_video_encode_decode():
    """Round-trip sampled video frames through base64 JPEG and compare them with directly loaded frames."""
    # Alternative source: 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4'
    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'
    frame_count = 4  # a few frames keep the test fast

    direct_frames, direct_meta = load_video(url, num_frames=frame_count)
    encoded = encode_video_base64(url, num_frames=frame_count, format='JPEG')
    decoded_frames, _ = load_video(f'data:video/jpeg;base64,{encoded}')

    # Reference metadata for the source clip. Only 'total_num_frames' and 'frames_indices' are asserted;
    # the remaining fields are kept for reference.
    expected_meta = {
        'total_num_frames': 498,
        'fps': 29.97002997002997,
        'duration': 16.616600000000002,
        'video_backend': 'opencv',
        'frames_indices': [0, 165, 331, 497]
    }

    assert direct_frames.shape == decoded_frames.shape
    # JPEG is lossy, so allow a small mean absolute pixel difference.
    pixel_error = np.mean(np.abs(direct_frames.astype(float) - decoded_frames.astype(float)))
    assert pixel_error < 2.0
    assert direct_meta['total_num_frames'] == expected_meta['total_num_frames']
    assert direct_meta['frames_indices'] == expected_meta['frames_indices']


def test_time_series_encode_decode():
    """Round-trip a ``.npy`` time series through base64 and check the arrays match."""
    # Alternative source: 'https://huggingface.co/internlm/Intern-S1-Pro/raw/main/0092638_seism.npy'
    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/0092638_seism.npy'

    original = load_time_series(url)
    encoded = encode_time_series_base64(url)
    decoded = load_time_series(f'data:time_series/npy;base64,{encoded}')

    assert original.shape == decoded.shape
    assert np.allclose(original, decoded)


def test_image_modes():
    """Encoding a grayscale ('L' mode) image and decoding it again should yield an RGB image."""
    from PIL import Image

    grayscale_img = Image.fromarray(np.zeros((100, 100), dtype=np.uint8)).convert('L')
    b64 = encode_image_base64(grayscale_img)  # expected to convert L -> RGB internally

    # NOTE(review): the data URL below assumes encode_image_base64 emits PNG by default -- confirm.
    img_out = load_image(f'data:image/png;base64,{b64}')
    assert img_out.mode == 'RGB'


def test_truncated_image():
    """The image at this URL (a truncated JPEG, per the test name) must still load at 1638x2048."""
    url = 'https://github.com/irexyc/lmdeploy/releases/download/v0.0.1/tr.jpeg'
    image = load_image(url)
    assert (image.width, image.height) == (1638, 2048)


def test_single_frame_video():
    """Sampling one frame yields a single frame, which encodes to a single base64 block."""
    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'
    frames, _ = load_video(url, num_frames=1)
    assert frames.shape[0] == 1

    encoded = encode_video_base64(frames)
    assert isinstance(encoded, str)
    # A single frame should produce exactly one JPEG block, so no comma separators.
    assert ',' not in encoded


def test_video_sampling_params():
    """Check how ``num_frames`` and ``fps`` bound the number of sampled frames."""
    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'

    def frames_for_fps(meta, fps):
        # Frames expected when sampling the whole clip at ``fps``, never fewer than one.
        return max(1, int(math.floor(meta['duration'] * fps)))

    # num_frames only: exactly that many frames and frame indices.
    frames, meta = load_video(url, num_frames=5)
    assert frames.shape[0] == 5
    assert len(meta['frames_indices']) == 5

    # fps only (source is ~29.97 fps and ~16.6 s long): the count follows the clip duration.
    frames, meta = load_video(url, fps=1)
    assert frames.shape[0] == frames_for_fps(meta, 1)

    # Both given: the tighter limit wins.
    # 10 fps x 16.6 s ~= 166 frames > 10 frames, so num_frames is the limit.
    frames, _ = load_video(url, num_frames=10, fps=10)
    assert frames.shape[0] == 10

    # 1 fps x 16.6 s ~= 16 frames < 100 frames, so fps is the limit.
    frames, meta = load_video(url, num_frames=100, fps=1)
    assert frames.shape[0] == frames_for_fps(meta, 1)


def test_invalid_inputs():
    """Every loader must raise when pointed at a non-existent local path."""
    import pytest

    missing_inputs = [
        (load_image, '/non_existent/path/image.jpg'),
        (load_video, '/non_existent/path/video.mp4'),
        (load_time_series, '/non_existent/path/data.npy'),
    ]
    for loader, missing_path in missing_inputs:
        with pytest.raises(Exception):
            loader(missing_path)